# Training of models (Unet and DeepLabV3)
Code inspired by Eugenia Anello. Source: https://github.com/eugeniaring/Medium-Articles/blob/main/Pytorch/denAE.ipynb

**The main differences:**
* Integration with the Weights & Biases platform
* Unet architecture and usage of DeepLabV3
* New dataset class
* Changed way of displaying the output
* Whole new output processing and more

In [None]:
# root_path       - path to the wound dataset
# csv_path_folder - path to the folder containing CSVs with the lists of image names used in each dataset split (test, train, validate)
# model_type      - switch between different model types. Three model types are supported: 'Unet-Sigmoid', 'Unet-ReLU' and 'DeepLabV3'
# unet_relu_path  - path to the model definition (.py file). Must be set if that model is being used
# unet_sigm_path  - path to the model definition (.py file). Must be set if that model is being used
#
# NOTE(review): unet_relu_path points into a folder named 'sigmoid-model', and the
# model-selection code picks unet_relu_path when model_type == 'Unet-Sigmoid'.
# The two apparent swaps may cancel each other out - confirm which file holds
# which model before renaming either variable.

root_path =       '../input/350pics/dataset'
csv_path_folder = '../input/wound-dataset-splitedlist/sort whole/'

model_type =      'Unet-Sigmoid' # 'Unet-Sigmoid' # 'Unet-ReLU' # 'DeepLabV3'

unet_relu_path =  '../input/sigmoid-model/unet_model.py'
unet_sigm_path =  '../input/nn-model/unet_model.py'

**INITIALIZATION**

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from skimage import io
import seaborn as sb
import pandas as pd
import numpy as np
import random
import wandb
import time
import csv
import os

import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms#, datasets
from torch.utils.data import DataLoader,random_split

# Seed Python's `random` module and PyTorch's default generator with a fixed value
random.seed(22)
torch.random.manual_seed(22)

In [None]:
# Authenticate with Weights & Biases and start a run in the given project
wandb.login()
wandb.init(project="wound_image_processing", entity="------") # insert username

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Selected device: {device}')

In [None]:
# Global values and settings for the NN

lr = 0.00005  # 0.00001
num_epochs = 1000
batch_size = 4
save_checkpoint = True

# Record the hyperparameters on the active W&B run.
# Assigning a plain dict to `wandb.config` would only rebind the module
# attribute and leave the run's own config untouched, so the values would never
# reach the dashboard; `update` writes them into the run that wandb.init() started.
wandb.config.update({
    "learning_rate": lr,
    "epochs": num_epochs,
    "batch_size": batch_size,
    "save_checkpoint": save_checkpoint,
})

# Marking each class with the color (for example: 0 (granulation) is red color (255, 0, 0) in RGB model)
# Keys are class indices, values are [R, G, B] triples.

switcher = {
        0: [255, 0, 0], # granulation tissue
        1: [0, 255, 0], # slough tissue
        2: [0, 0, 255], # necrotic tissue
        3: [0, 0, 0]    # background
    }

# ANSI escape sequences used to highlight important console output:
# CSTART switches to a red background (SGR 41), CEND resets all attributes (SGR 0)

CSTART = '\033[41m'
CEND = '\033[0m'

In [None]:
# model = Unet()

# Reject an unknown model_type before anything is loaded
assert model_type in ('Unet-Sigmoid', 'Unet-ReLU', 'DeepLabV3'), "model name should be 'Unet-Sigmoid', 'Unet-ReLU' or 'DeepLabV3'"

if model_type.startswith('Unet'):
    # Loading model from the python file
    from shutil import copyfile
    # NOTE(review): the names look inverted - 'Unet-Sigmoid' selects unet_relu_path and
    # 'Unet-ReLU' selects unet_sigm_path. unet_relu_path itself is configured to point at
    # '../input/sigmoid-model/...', so the two swaps may cancel out. Confirm which file
    # holds which model before changing either side.
    unetpath = unet_relu_path if model_type == 'Unet-Sigmoid' else unet_sigm_path
    # The definition is copied to ../working/ and then imported as `unet_model`
    # (presumably ../working is the notebook's working directory - TODO confirm)
    copyfile(src = unetpath, dst = "../working/unet_model.py")
    from unet_model import Unet
    model = Unet()
    
elif model_type == 'DeepLabV3':
    from torchvision.models.segmentation.deeplabv3 import DeepLabHead
    from torchvision import models
    model = models.segmentation.deeplabv3_resnet101(pretrained=True, progress=True)
    # Swap the classifier head for one with 4 output channels
    # (matches the 4 entries of `switcher`: three tissue types + background)
    outputchannels=4
    model.classifier = DeepLabHead(2048, outputchannels)
else:
    # Not reachable while the assert above is active (asserts are skipped under `python -O`)
    raise Exception("model name should be 'Unet-Sigmoid', 'Unet-ReLU' or 'DeepLabV3'")
    
# Adam optimizer over all model parameters, using the global learning rate `lr`
optim = torch.optim.Adam(model.parameters(), lr=lr)
model.to(device)

**DOWNLOADING WOUND DATASET**

In [None]:
# Creating the class to work with the dataset

class WoundDataset(data.Dataset):
    """Dataset of wound images paired with their segmentation masks.

    Images are read from ``<root>/imgs`` and masks from ``<root>/masks``; the
    mask of an image is looked up under the same file name.

    Args:
        root: dataset directory containing the ``imgs`` and ``masks`` folders.
        transform: optional callable applied to the image and, in a separate
            call, to the mask.
        csv_file: optional CSV file whose second column lists the image file
            names to keep; its first row is skipped as a header. When omitted,
            every file found in ``imgs`` is used.
    """

    def __init__(self, root, transform=None, csv_file=None):
        self.img_orig = os.path.join(root, 'imgs')   # folder with the resized images (512 * 512)
        self.img_mask = os.path.join(root, 'masks')  # folder with the masks to these images
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs and machines.
        available = sorted(os.listdir(self.img_orig))

        if csv_file is None:
            self.imglist = available
        else:
            with open(csv_file, newline='') as csvfile:
                rows = csv.reader(csvfile, delimiter=',', quotechar='|')
                next(rows, None)  # skip the header row
                # A set gives O(1) membership tests; rows that are too short
                # (e.g. blank lines) are ignored instead of raising IndexError.
                wanted = {row[1] for row in rows if len(row) > 1}
            self.imglist = [filename for filename in available if filename in wanted]

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.imglist)

    def __getitem__(self, idx):
        """Load and return the ``(image, mask)`` pair stored at position ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.img_orig, self.imglist[idx])
        msk_name = os.path.join(self.img_mask, self.imglist[idx])

        img = io.imread(img_name)
        # Keep only the first three channels of the mask (drops a fourth,
        # presumably alpha, channel when the file has one).
        msk = io.imread(msk_name)[:, :, :3]

        if self.transform:
            # NOTE(review): image and mask are transformed by two independent
            # calls, so a random transform would not stay aligned between them.
            img = self.transform(img)
            msk = self.transform(msk)

        return img, msk

In [None]:
# Function for coding the input to the NN

def convert_into_1d(old_tensor):
    """Encode a batch of colour masks as per-pixel class indices.

    Each pixel is assigned the index of its strongest channel. Pixels whose
    strongest channel stays below half of the largest value found in the whole
    batch are assigned the extra "background" class, whose index equals the
    number of channels (3 for RGB masks).

    Args:
        old_tensor: tensor indexed as (batch, channel, height, width).

    Returns:
        int64 tensor of shape (batch, height, width). It stays on the device of
        ``old_tensor``, so callers should move the mask to the training device
        before calling.
    """
    num_channels = old_tensor.shape[1]

    # Kept as a tensor so that no device -> host synchronisation is needed.
    color_threshold = 0.5 * old_tensor.max()

    # Strongest channel per pixel; ties resolve to the lowest channel index.
    strongest_value, max_idxs = old_tensor.max(dim=1)

    # A pixel is background when no channel reaches the threshold. The second
    # test covers an all-black batch: the threshold is 0 there, and without it
    # every pixel would tie and fall into class 0 instead of background.
    background = (strongest_value < color_threshold) | (strongest_value <= 0)

    return max_idxs.masked_fill(background, num_channels)

In [None]:
# Additional function that helps to match an exact class (0-3) with the color (that it is marked in RGB color model)

def converting_to_rgb_layers(x):
    """Look up the RGB colour registered in ``switcher`` for the class id in ``x[0]``.

    Returns the string "error" when the class id has no entry.
    """
    class_id = int(x[0])
    return switcher.get(class_id, "error")

In [None]:
# Function for decoding final mask (NN output)

def convert_into_3d(output_mask, palette=None):
    """Decode per-class scores into an RGB image.

    Args:
        output_mask: tensor (or array) indexed as (class, height, width); the
            winning class of each pixel is the argmax over dimension 0.
        palette: optional mapping ``class index -> [R, G, B]``. Defaults to the
            module-level ``switcher``. Indices missing from the mapping but
            below its largest key are drawn black.

    Returns:
        Integer NumPy array of shape (height, width, 3) with the colour of the
        winning class for every pixel.

    Raises:
        IndexError: if a winning class index is larger than every palette key.
    """
    if palette is None:
        palette = switcher

    # Row i of the lookup table holds the colour of class i.
    lookup = np.zeros((max(palette) + 1, 3), dtype=int)
    for class_id, rgb in palette.items():
        lookup[class_id] = rgb

    # as_tensor avoids copy-constructing a tensor from a tensor (which warns).
    scores = torch.as_tensor(output_mask).detach().cpu()
    dominant_layers = torch.argmax(scores, dim=0).numpy()

    # One vectorised gather instead of a Python callback for every pixel.
    return lookup[dominant_layers]

In [None]:
# Creating the dataset

# No csv_file is given, so this dataset covers every file found in the imgs folder
dataset = WoundDataset(root = root_path, transform = transforms.ToTensor())
dataset

In [None]:
# Creating a function for displaying images with their masks
# First 3 images are permanent, the last 3 are changing

def _show_panel(n_rows, n_cols, position, picture, title=None):
    """Draw ``picture`` without axis ticks in one cell of the subplot grid."""
    ax = plt.subplot(n_rows, n_cols, position)
    ax.imshow(picture)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    if title is not None:
        ax.set_title(title)


def display_data(dataset, model=None, n_base=3, n_rand=3):
    """Plot images with their original masks and, optionally, predicted masks.

    The first ``n_base`` columns always show the first items of ``dataset``;
    the remaining ``n_rand`` columns show randomly drawn items.

    Args:
        dataset: indexable collection of ``(image, mask)`` tensor pairs in
            channel-first layout.
        model: optional network; when given it is switched to eval mode and a
            third row with its decoded output is drawn.
        n_base: number of fixed columns.
        n_rand: number of random columns.

    Raises:
        IndexError: if ``dataset`` is empty.
    """
    n = n_base + n_rand
    last_index = len(dataset) - 1
    if last_index < 0:
        raise IndexError("display_data needs a non-empty dataset")

    # `is not None` rather than truthiness: containers such as an empty
    # nn.Sequential define __len__ and would evaluate to False.
    has_model = model is not None
    n_rows = 3 if has_model else 2
    if has_model:
        model.eval()

    plt.figure(figsize=(25, 7))

    for i in range(n):
        if i >= n_base:
            # Random items come from [n, last_index]; the lower bound is
            # clamped so that datasets with n or fewer items do not make
            # randint raise on an empty range.
            index = random.randint(min(n, last_index), last_index)
        else:
            index = min(i, last_index)

        # The title is placed above the middle column only.
        is_title_column = i == n // 2
        image, mask = dataset[index]

        _show_panel(n_rows, n, i + 1,
                    np.transpose(np.array(image), (1, 2, 0)),
                    'Input images' if is_title_column else None)
        _show_panel(n_rows, n, n + i + 1,
                    np.transpose(np.array(mask), (1, 2, 0)),
                    'Original masks' if is_title_column else None)

        if has_model:
            with torch.no_grad():
                output_mask = model(image.to(device).unsqueeze(0))
            if model_type == 'DeepLabV3':
                # for DeepLabV3 the prediction is stored under the 'out' key
                output_mask = output_mask['out']
            output_mask = output_mask.squeeze(0).cpu()
            output_mask = convert_into_3d(output_mask)  # decoding the output
            _show_panel(n_rows, n, 2 * n + i + 1, output_mask,
                        'Generated masks' if is_title_column else None)

    plt.subplots_adjust(left=0.1, bottom=0.1, right=0.7, top=0.9, wspace=0.3, hspace=0.3)

    plt.show()

In [None]:
# Preview the whole dataset; no model is passed, so only images and original masks are drawn
display_data(dataset)

**SPLITTING THE DATASET**

In [None]:
# Dividing the whole dataset
# Ratio: 70% - train, 15% - test, 15% - validation
# NOTE(review): the ratios and sizes below are only consumed by the commented-out
# random_split calls; the split actually used comes from the CSV files.

train_ratio = 0.7
test_ratio = 0.15

train_size = int(np.floor(train_ratio * len(dataset)))
test_size = int(np.floor(test_ratio * len(dataset)))

# CHANGE - NAST - Load from csvs as new datasets
# Previous random split, kept for reference:
# data_train, data_test_val = torch.utils.data.random_split(dataset, [train_size, len(dataset)-train_size])
# data_test, data_val = torch.utils.data.random_split(data_test_val, [test_size, len(data_test_val)-test_size])
data_train = WoundDataset(root = root_path, transform = transforms.ToTensor(), csv_file=csv_path_folder+'train.csv')
data_test =  WoundDataset(root = root_path, transform = transforms.ToTensor(), csv_file=csv_path_folder+'test.csv')
data_val =   WoundDataset(root = root_path, transform = transforms.ToTensor(), csv_file=csv_path_folder+'val.csv')

# Only the training loader shuffles. drop_last=True on the train and test loaders
# discards a final incomplete batch; the validation loader keeps it.
train_loader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2)
test_loader =  torch.utils.data.DataLoader(data_test, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=2)
val_loader =   torch.utils.data.DataLoader(data_val, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=2)

**PERFORMING SMALL DATA ANALYSIS**

In [None]:
# Report how many samples the full dataset and each split contain

# Kept as a string because it is also embedded in checkpoint file names
size_of_the_dataset = str(len(dataset))

for label, count in (
    ('Whole dataset len: ', size_of_the_dataset),
    ('Train dataset len: ', len(data_train)),
    ('Test  dataset len: ', len(data_test)),
    ('Val   dataset len: ', len(data_val)),
):
    print(label, count)

In [None]:
# Count the pixels of each class in every mask

def count_color_ratio(dataset):
    """Count, for every mask in ``dataset``, how many pixels each colour covers.

    A pixel counts towards a colour when that channel is greater than zero, and
    towards black when none of its first three channels is lit.

    Args:
        dataset: indexable collection of ``(image, mask)`` pairs, where the mask
            is a tensor indexed as (channel, height, width).

    Returns:
        pandas.DataFrame with one row per mask and the columns
        ``number_of_RED_pixels``, ``number_of_GREEN_pixels``,
        ``number_of_BLUE_pixels`` and ``number_of_BLACK_pixels``.
    """
    columns = ['number_of_RED_pixels', 'number_of_GREEN_pixels',
               'number_of_BLUE_pixels', 'number_of_BLACK_pixels']
    rows = []

    for index in range(len(dataset)):
        _, mask = dataset[index]

        lit = mask[:3] > 0
        r_total, g_total, b_total = (int(lit[channel].sum()) for channel in range(3))

        # Count black pixels directly. Subtracting the three channel totals from
        # the image area would subtract a pixel lit in several channels more
        # than once and under-count the background.
        black_total = int((~lit.any(dim=0)).sum())

        rows.append((r_total, g_total, b_total, black_total))

    return pd.DataFrame(rows, columns=columns)

In [None]:
# Show a few counted results

# Pixel statistics are computed on the training split only.
# sample(n=10) needs at least 10 rows (sampling is without replacement by default).
df = count_color_ratio(data_train)
df.sample(n=10)

In [None]:
# Displaying the results

labels = (
    'GRANULATION TISSUE (red pixels)',
    'SLOUGH TISSUE (green pixels)',
    'NECROTIC TISSUE (blue pixels)',
)

# Total number of pixels of each tissue colour over all counted masks
red_sum, green_sum, blue_sum = (
    df[column].sum()
    for column in ('number_of_RED_pixels', 'number_of_GREEN_pixels', 'number_of_BLUE_pixels')
)

# Pie chart of the three tissue classes (background is left out)
fig, ax = plt.subplots()
ax.pie(
    [red_sum, green_sum, blue_sum],
    labels = labels,
    autopct = '%1.1f%%',
    colors = ['#b56576', '#eaac8b', '#6d597a'],
)
plt.show()

Because the classes are imbalanced, the neural network has to be trained with a weighted loss function.

For better performance, an equal class distribution is required. There are 3 classes of tissue, so the target share of each class is 100/3 ≈ 33%.

In [None]:
sum_all = red_sum + green_sum + blue_sum
basic_ratio = 33.3

# Weight of a tissue class = target share (in percent) / 100 * total pixels / class pixels,
# so a rarer class receives a larger weight. Background keeps the weight 1.
red_ratio, green_ratio, blue_ratio = (
    basic_ratio/class_sum/100*sum_all
    for class_sum in (red_sum, green_sum, blue_sum)
)
black_ratio = 1

print("Red ratio: %.2f\nGreen ratio: %.2f\nBlue ratio: %.2f\n" % (red_ratio, green_ratio, blue_ratio))

**INITIALIZING UNET NEURAL NETWORK**

Choosing loss function:

In [None]:
# Unweighted alternative: loss_fn = torch.nn.CrossEntropyLoss()
# Class weights in class-index order: granulation, slough, necrotic, background
class_weights = [float(red_ratio), float(green_ratio), float(blue_ratio), black_ratio]
weights = torch.tensor(class_weights, device=device)
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)

**TRAINING THE MODEL**

In [None]:
def train_model():
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optim``, ``loss_fn`` and ``device``.

    Returns:
        Mean of the per-batch training losses.
    """
    model.train()
    is_unet = model_type.startswith('Unet')
    batch_losses = []

    for image_batch, mask_batch in train_loader:
        # Colour masks become per-pixel class indices for the loss
        targets = convert_into_1d(mask_batch.to(device)).long()

        prediction = model(image_batch.to(device))
        if not is_unet:
            # the non-Unet model returns a mapping; the prediction sits under 'out'
            prediction = prediction['out']

        batch_loss = loss_fn(prediction, targets)

        optim.zero_grad()
        batch_loss.backward()
        optim.step()

        batch_losses.append(batch_loss.detach().cpu().numpy())

    return np.mean(batch_losses)

def validate_model():
    """Evaluate the model on ``val_loader`` without updating any weights.

    Uses the module-level ``model``, ``loss_fn`` and ``device``.

    Returns:
        Mean of the per-batch validation losses.
    """
    model.eval()
    is_unet = model_type.startswith('Unet')
    batch_losses = []

    with torch.no_grad():
        for image_batch, mask_batch in val_loader:
            # Colour masks become per-pixel class indices for the loss
            targets = convert_into_1d(mask_batch.to(device)).long()

            prediction = model(image_batch.to(device))
            if not is_unet:
                # the non-Unet model returns a mapping; the prediction sits under 'out'
                prediction = prediction['out']

            batch_loss = loss_fn(prediction, targets)
            batch_losses.append(batch_loss.detach().cpu().numpy())

    return np.mean(batch_losses)

# Training cycles

In [None]:
previous_val_loss = []
# Despite its name this holds the mean of the last FIVE validation losses.
# It starts as a big number so that nothing is highlighted before five values exist.
mean_of_last_three_val_loss = 10000

# How often (in epochs) predictions are displayed and a checkpoint is written:
# every 20 epochs for the Unet variants, every 3 epochs otherwise.
report_every = 20 if model_type.startswith('Unet') else 3

for epoch in range(num_epochs):
    train_loss = train_model()
    val_loss = validate_model()

    if len(previous_val_loss) >= 5:
        mean_of_last_three_val_loss = np.mean(previous_val_loss[-5:])

    # Highlight the epoch when validation loss is worse than the recent average
    if mean_of_last_three_val_loss < val_loss:
        print(CSTART)
    print('EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}'.format(epoch + 1, num_epochs, train_loss, val_loss) + CEND)
    previous_val_loss.append(val_loss)

    if (epoch + 1) % report_every == 0:
        display_data(data_val, model)
        # Honour the global switch; previously checkpoints were written regardless of it
        if save_checkpoint:
            torch.save(model.state_dict(), str(epoch + 1) + '_epoch_' + size_of_the_dataset + '_images.pth')

    wandb.log({"train_loss": train_loss, "val_loss": val_loss})


In [None]:
# Show images, original masks and generated masks for samples of the test split
display_data(data_test, model)