In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import os
import json

from unetModules.unet_model import UNet
from unetModules.Utils import Utils
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
from PIL import Image

In [2]:
# --- Run configuration -------------------------------------------------------
# These module-level names are read by the data-loading, model-setup and
# training cells further down.
cuda = True            # request the GPU; it is only used if CUDA is actually available
batch_size = 1
epochs = 120
learn_rate = 5e-4      # Adam learning rate
save_every = 5         # checkpoint / loss-log interval, in epochs
n_classes = 2
n_channels = 3
weight_decay = 2e-4    # Adam weight decay
description = "Diese Unet System wurde mit den unbehandelten River Blindness Datensatz trainiert. Datenaugmentierung wurde angewendet und der Datensatz wurde um weiter Bilder erweitert(+Red). Validation set wurde separiert."

path_save_model = "./model/"
# Alternative dataset (Schistosoma mansoni); swap with the four paths below to use it.
#path_images = "../../content/SchistosomaMansoni/img/"
#path_labels = "../../content/SchistosomaMansoni/labels/"
#path_images_val = "../../content/SchistosomaMansoni/val_img/"
#path_labels_val = "../../content/SchistosomaMansoni/val_labels/"
path_images = "../../content/RiverBlindness/img/"
path_labels = "../../content/RiverBlindness/labels/"
path_images_val = "../../content/RiverBlindness/val_img/"
path_labels_val = "../../content/RiverBlindness/val_labels/"

# Snapshot of the hyper-parameters that is stored next to the checkpoints.
settings = {
    "cuda": cuda,
    "batch_size": batch_size,
    "epochs": epochs,
    "learn_rate": learn_rate,
    "n_classes": n_classes,
    "n_channels": n_channels,
    "weight_decay": weight_decay,
    "description": description
}

# open(..., "w") does not create missing parent directories, so make sure
# "<path_save_model>/logs" exists before writing into it.
os.makedirs(os.path.join(path_save_model, "logs"), exist_ok=True)

# ensure_ascii=False keeps non-ASCII characters in the description readable.
with open(path_save_model + "logs/settings.json", "w", encoding="utf-8") as f:
    json.dump(settings, f, ensure_ascii=False, indent=4)

In [3]:
# --- Training set: read, augment, resize -------------------------------------
# os.listdir returns entries in arbitrary order; sorting makes the sample
# order reproducible between runs and machines.
img_filenames = np.array(sorted(os.listdir(path_images)))
label_filenames = np.array(sorted(os.listdir(path_labels)))

# Only the counts are compared here; each label is looked up below under the
# same file name as its image.
assert len(img_filenames) == len(label_filenames)

train_images = []
train_masks = []

for file in img_filenames:
    img = Image.open(path_images + file)
    label = Image.open(path_labels + file)
    augmented_imgs, augmented_labels = Utils.imageAugmentation(img, label)

    # `interpolation` MUST be passed by keyword: the third positional
    # parameter of cv2.resize is `dst`, so a positional INTER_NEAREST never
    # reaches the interpolation setting and the default method is used.
    for aug_img in augmented_imgs:
        resized_img = cv2.resize(np.array(aug_img), (512, 512),
                                 interpolation=cv2.INTER_NEAREST)
        train_images.append(resized_img[:, :, 0:3])  # cutting out potential alpha channel

    # Nearest-neighbour resizing only copies existing pixel values, so the
    # resized mask contains no values that were absent from the source mask.
    for aug_label in augmented_labels:
        resized_label = cv2.resize(np.array(aug_label), (512, 512),
                                   interpolation=cv2.INTER_NEAREST)
        train_masks.append(resized_label)

# Stack to (N, H, W, C) and reorder the axes to (N, C, H, W).
input_train = torch.tensor(np.stack(train_images, axis=0)).permute(0, 3, 1, 2)

# Masks are stacked along a new leading sample axis.
label_train = torch.tensor(np.array(train_masks))

# --- Validation set: read and resize (no augmentation) -----------------------
# os.listdir returns entries in arbitrary order; sorting makes the sample
# order reproducible between runs and machines.
val_img_filenames = np.array(sorted(os.listdir(path_images_val)))
val_label_filenames = np.array(sorted(os.listdir(path_labels_val)))

# Only the counts are compared here; each label is looked up below under the
# same file name as its image.
assert len(val_img_filenames) == len(val_label_filenames)

val_images = []
val_masks = []

# Validation set should not be augmented
for file in val_img_filenames:
    # `interpolation` MUST be passed by keyword: the third positional
    # parameter of cv2.resize is `dst`, so a positional INTER_NEAREST never
    # reaches the interpolation setting and the default method is used.
    val_img = Image.open(path_images_val + file)
    val_img = cv2.resize(np.array(val_img), (512, 512),
                         interpolation=cv2.INTER_NEAREST)
    val_images.append(val_img[:, :, 0:3])  # drop a potential alpha channel

    # Nearest-neighbour resizing only copies existing pixel values, so the
    # resized mask contains no values that were absent from the source mask.
    val_label = Image.open(path_labels_val + file)
    val_label = cv2.resize(np.array(val_label), (512, 512),
                           interpolation=cv2.INTER_NEAREST)
    val_masks.append(val_label)

# Stack to (N, H, W, C) and reorder the axes to (N, C, H, W).
input_val = torch.tensor(np.stack(val_images, axis=0)).permute(0, 3, 1, 2)

# Masks are stacked along a new leading sample axis.
label_val = torch.tensor(np.array(val_masks))

# Number of batches per pass = ceil(samples / batch_size): every full batch
# plus one shorter trailing batch when the sample count is not a multiple of
# batch_size.
train_full_batches, train_leftover = divmod(len(input_train), batch_size)
batch_count_train = train_full_batches + (1 if train_leftover else 0)

val_full_batches, val_leftover = divmod(len(input_val), batch_size)
batch_count_val = val_full_batches + (1 if val_leftover else 0)

# --- Class weights ------------------------------------------------------------
# The weights are handed to CrossEntropyLoss below; how they are derived from
# the training labels is up to Utils.get_class_weights (presumably class
# frequency based -- confirm in Utils).
print("[INFO]Starting to define the class weights...")
class_weights = Utils.get_class_weights(label_train, n_classes)
print("[INFO]Fetched all class weights successfully!")

# --- Model --------------------------------------------------------------------
unet = UNet(n_channels=n_channels, n_classes=n_classes)
print("[INFO]Model Instantiated!")

# --- Device selection ---------------------------------------------------------
# The first GPU is used only when CUDA is present AND enabled via the `cuda`
# flag; in every other case the model stays on the CPU.
use_gpu = torch.cuda.is_available() and cuda
if not use_gpu:
    print("[INFO]CUDA isn't available!")
    device = torch.device("cpu")
else:
    print("[INFO]CUDA is available!")
    device = torch.device("cuda:0")

unet = unet.to(device)

# --- Loss and optimizer -------------------------------------------------------
# The weight tensor has to live on the same device as the model output.
loss_weights = torch.FloatTensor(class_weights).to(device)
criterion = nn.CrossEntropyLoss(weight=loss_weights)
optimizer = torch.optim.Adam(
    unet.parameters(),
    lr=learn_rate,
    weight_decay=weight_decay,
)
print("[INFO]Defined the loss function and the optimizer")

[INFO]Starting to define the class weights...
[INFO]Fetched all class weights successfully!
[INFO]Model Instantiated!
[INFO]CUDA is available!
[INFO]Defined the loss function and the optimizer


In [4]:
print("[INFO]Starting Training...")

# Per-epoch losses. Each entry is the SUM of the batch losses of that epoch
# (not the mean), so values scale with the number of batches.
train_losses = []
val_losses = []

# The loss logs are written into "<path_save_model>/logs"; create it up front
# so the first save cannot fail on a missing directory.
os.makedirs(os.path.join(path_save_model, "logs"), exist_ok=True)

for epoch in range(1, epochs + 1):
    print ("-"*15,"Epoch %d" % epoch , "-"*15)

    # ---- Training pass -------------------------------------------------------
    unet.train()
    train_loss = 0

    for batch_idx in tqdm(range(batch_count_train)):
        # Batches are taken in stored order; the last one may be shorter.
        start = batch_size * batch_idx
        stop = batch_size * (batch_idx + 1)
        X_train = input_train[start:stop].to(device)
        y_train = label_train[start:stop].to(device)

        optimizer.zero_grad()

        # The model receives float inputs; CrossEntropyLoss needs integer
        # (long) class-index targets.
        out = unet(X_train.float())
        loss = criterion(out, y_train.long())
        loss.backward()
        optimizer.step()

        # .item() yields a plain Python float, so no autograd graph is kept
        # alive by the running total.
        train_loss += loss.item()

    print()
    train_losses.append(train_loss)

    print ('Epoch {}/{}...'.format(epoch, epochs),
            'Loss {:6f}'.format(train_loss))

    # ---- Validation pass -----------------------------------------------------
    # eval() switches the model to inference mode; no_grad() disables gradient
    # tracking for the whole pass.
    unet.eval()
    with torch.no_grad():
        print()
        print("Validation:")

        val_loss = 0

        for batch_idx in tqdm(range(batch_count_val)):
            start = batch_size * batch_idx
            stop = batch_size * (batch_idx + 1)
            X_val = input_val[start:stop].to(device)
            y_val = label_val[start:stop].to(device)

            out = unet(X_val.float())
            loss = criterion(out, y_val.long())

            val_loss += loss.item()

        print('Loss {:6f}'.format(val_loss))

        val_losses.append(val_loss)

    # ---- Checkpoint + loss logs ----------------------------------------------
    # Save every `save_every` epochs AND after the final epoch, so the last
    # model and the complete loss history are kept even when `epochs` is not
    # a multiple of `save_every`.
    if epoch % save_every == 0 or epoch == epochs:
        checkpoint = {
            'epochs' : epoch,
            'state_dict' : unet.state_dict()
        }

        torch.save(checkpoint, path_save_model + 'ckpt-unet-{}-{}.pth'.format(epoch, train_loss))

        # The log files are rewritten with the complete history on each save.
        with open(path_save_model + "logs/trainLosses.json", "w", encoding="utf-8") as f:
            json.dump(train_losses, f, ensure_ascii=False, indent=4)

        with open(path_save_model + "logs/valLosses.json", "w", encoding="utf-8") as f:
            json.dump(val_losses, f, ensure_ascii=False, indent=4)

        print()
        print('Model and Losses saved!')

    print ('Epoch {}/{}...'.format(epoch, epochs))
print("[INFO]Training Process complete!")

[INFO]Starting Training...
--------------- Epoch 1 ---------------


  0%|          | 0/164 [00:04<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB (GPU 0; 8.00 GiB total capacity; 6.85 GiB already allocated; 0 bytes free; 7.25 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF