In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import json

from enetModules.ENet import ENet
from enetModules.ENet import RDDNeck
from enetModules.ENet import UBNeck
from enetModules.ENet import ASNeck
from enetModules.Utils import Utils
from torch.optim.lr_scheduler import StepLR
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image

In [2]:
# ---- Run configuration -------------------------------------------------------
cuda = False                  # set to True to train on the first GPU when one is available
batch_size = 32
epochs = 40
learn_rate = 5e-4
save_every = 5                # checkpoint / loss-log interval, in epochs
n_classes = 2
weight_decay = 2e-4
test_val_split_size = 0.10    # fraction of the samples held out for validation
description = "Diese System wurde mit den unbehandelten River Blindness Datensatz trainiert."

path_load_model = "./modelBackups/Model_29_11_22_best_SM/ckpt-enet-80-0.19326677545905113.pth"
path_save_model = "./transferTrainModel/"
path_images = "./content/RiverBlindness/img/"
path_labels = "./content/RiverBlindness/labels/"

# open(..., "w") does not create missing directories, so make sure the log
# folder exists before anything is written into it (fresh checkout / new run).
os.makedirs(path_save_model + "logs", exist_ok=True)

# Snapshot of the hyper-parameters of this run, stored next to the checkpoints.
settings = {
    "cuda": cuda,
    "batch_size": batch_size,
    "epochs": epochs,
    "learn_rate": learn_rate,
    "n_classes": n_classes,
    "weight_decay": weight_decay,
    "test_val_split_size": test_val_split_size,
    "description": description
}

with open(path_save_model + "logs/settings.json", "w", encoding="utf-8") as f:
    json.dump(settings, f, ensure_ascii=False, indent=4)

In [3]:
def freezeLayer(layer):
    """Turn off gradient tracking for every parameter yielded by ``layer.parameters()``.

    The parameters are modified in place; nothing is returned.
    """
    parameters = layer.parameters()
    for tensor in parameters:
        tensor.requires_grad = False

def unfreezeLayer(layer):
    """Turn gradient tracking back on for every parameter yielded by ``layer.parameters()``.

    The parameters are modified in place; nothing is returned.
    """
    parameters = layer.parameters()
    for tensor in parameters:
        tensor.requires_grad = True

In [4]:
# Labels are opened below under the image's own file name, so the label
# listing is only needed for the count sanity check.
img_filenames = np.array(os.listdir(path_images))
inputs = []

label_filenames = np.array(os.listdir(path_labels))
labels = []

assert len(img_filenames) == len(label_filenames), \
    "number of image files and label files differs"

target_size = (512, 512)

# Reading images and labels
# NOTE(review): image and label are augmented by two separate calls; they only
# stay pixel-aligned if Utils.imageAugmentation is deterministic - confirm.
for file in img_filenames:
    img = Image.open(path_images + file)
    augmented_imgs = Utils.imageAugmentation(img)

    for aug_img in augmented_imgs:
        # The interpolation flag must be passed by keyword: the third positional
        # parameter of cv2.resize is `dst`, so a positional flag is not used as
        # the interpolation mode.
        img = cv2.resize(np.array(aug_img), target_size, interpolation=cv2.INTER_NEAREST)
        inputs.append(img[:, :, 0:3])  # cutting out potential alpha channel

    label = Image.open(path_labels + file)
    augmented_labels = Utils.imageAugmentation(label)

    for aug_label in augmented_labels:
        # Nearest-neighbour keeps the label values intact; any blending
        # interpolation could produce values that are not valid class ids.
        label = cv2.resize(np.array(aug_label), target_size, interpolation=cv2.INTER_NEAREST)
        labels.append(label)

# Stack to (N, H, W, C) and reorder to the channel-first (N, C, H, W) layout.
inputs = torch.tensor(np.stack(inputs, axis=0)).permute(0, 3, 1, 2)

labels = torch.tensor(np.array(labels))

# Shuffles the datasets and holds out the validation fraction
input_train, input_val, label_train, label_val = train_test_split(
    inputs, labels, test_size=test_val_split_size)

# Ceil division: a trailing, partially filled batch still counts as a batch.
batch_count_train = (len(input_train) + batch_size - 1) // batch_size
batch_count_val = (len(input_val) + batch_size - 1) // batch_size

print("[INFO]Starting to define the class weights...")
class_weights = Utils.get_class_weights(labels, n_classes)
print("[INFO]Fetched all class weights successfully!")

# Checking for cuda: the GPU is used only when it is both present and requested
use_gpu = torch.cuda.is_available() and cuda
if use_gpu:
    print("[INFO]CUDA is available!")
    device = torch.device("cuda:0")
else:
    print("[INFO]CUDA isn't available!")
    device = torch.device("cpu")

print("[INFO]Training model form path ====> " + path_load_model)

# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
preTrainedModel = torch.load(path_load_model, map_location=device)

# Rebuild the network and restore the stored weights onto the chosen device.
enet = ENet(n_classes)
enet.load_state_dict(preTrainedModel['state_dict'])
enet = enet.to(device)

print ("[INFO]Model Loaded!")
print ("[INFO]Freezing Layer.")

# Sub-modules whose parameters are excluded from gradient updates,
# listed in the order in which they are frozen.
frozen_block_names = (
    "init",
    "b10", "b11", "b12", "b13", "b14",
    "b20", "b21", "b22", "b23", "b24", "b25", "b26", "b27", "b28",
    "b31", "b32", "b33", "b34", "b35", "b36", "b37", "b38",
    "b40", "b41", "b42",
    "b50", "b51",
)

for block_name in frozen_block_names:
    freezeLayer(getattr(enet, block_name))

# Set a new output layer
# Only the final transposed convolution is replaced by a freshly initialised
# one. Other blocks could be re-initialised the same way by assigning new
# RDDNeck / ASNeck / UBNeck modules to the matching enet.bXX attribute.
#
# The model was already moved to `device` before this point, so the new layer
# has to be moved explicitly; otherwise it stays on the CPU and the forward
# pass fails with a device mismatch as soon as `device` is a GPU.
enet.fullconv = nn.ConvTranspose2d(in_channels=16,
                                    out_channels=n_classes,
                                    kernel_size=3,
                                    stride=2,
                                    padding=1,
                                    output_padding=1,
                                    bias=False).to(device)

print ("[INFO]Model Instantiated!")

# Class-weighted cross entropy; the weights live on the same device as the model.
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weights).to(device))
optimizer = torch.optim.Adam(enet.parameters(),
                                lr=learn_rate,
                                weight_decay=weight_decay)
print("[INFO]Defined the loss function and the optimizer")

[INFO]Starting to define the class weights...
[INFO]Fetched all class weights successfully!
[INFO]CUDA isn't available!
[INFO]Training model form path ====> ./modelBackups/Model_29_11_22_best_SM/ckpt-enet-80-0.19326677545905113.pth
[INFO]Model Loaded!
[INFO]Freezing Layer.
[INFO]Model Instantiated!
[INFO]Defined the loss function and the optimizer


In [5]:
print("[INFO]Staring Training...")

# Loss history, one entry per epoch; each entry is the sum of that epoch's batch losses.
train_losses, val_losses = [], []

for e in range(1, epochs + 1):
    print ("-"*15,"Epoch %d" % e , "-"*15)

    # ---- training pass ----
    enet.train()
    train_loss = 0

    for batch_idx in tqdm(range(batch_count_train)):
        # Slicing past the end is safe, so the last batch may simply be shorter.
        start = batch_size * batch_idx
        stop = batch_size * (batch_idx + 1)
        X_train = input_train[start:stop].to(device)
        y_train = label_train[start:stop].to(device)

        optimizer.zero_grad()

        out = enet(X_train.float())
        loss = criterion(out, y_train.long())
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    print()
    train_losses.append(train_loss)

    print ('Epoch {}/{}...'.format(e, epochs),
            'Loss {:6f}'.format(train_loss))

    # ---- validation pass (no gradients are recorded) ----
    with torch.no_grad():
        print()
        print("Validation:")

        # Validates the model
        enet.eval()
        val_loss = 0

        for batch_idx in tqdm(range(batch_count_val)):
            start = batch_size * batch_idx
            stop = batch_size * (batch_idx + 1)
            X_val = input_val[start:stop].to(device)
            y_val = label_val[start:stop].to(device)

            out = enet(X_val.float())
            loss = criterion(out, y_val.long())

            val_loss += loss.item()

        print('Loss {:6f}'.format(val_loss))

        val_losses.append(val_loss)

    # ---- periodic checkpoint + loss logs ----
    if e % save_every == 0:
        checkpoint = {'epochs': e, 'state_dict': enet.state_dict()}

        torch.save(checkpoint, path_save_model + 'ckpt-enet-{}-{}.pth'.format(e, train_loss))

        # Both histories are rewritten completely on every save.
        for log_name, history in (("logs/trainLosses.json", train_losses),
                                  ("logs/valLosses.json", val_losses)):
            with open(path_save_model + log_name, "w", encoding="utf-8") as f:
                json.dump(history, f, ensure_ascii=False, indent=4)

        print()
        print('Model and Losses saved!')

    print ('Epoch {}/{}...'.format(e, epochs))
print("[INFO]Training Process complete!")

[INFO]Staring Training...
--------------- Epoch 1 ---------------


100%|██████████| 11/11 [01:36<00:00,  8.80s/it]



Epoch 1/1... Loss 9.692383

Validation:


100%|██████████| 2/2 [00:05<00:00,  2.64s/it]

Loss 1.485765
Epoch 1/1...
[INFO]Training Process complete!





