In [None]:
# Colab only: mount Google Drive and switch into the project folder so the
# relative paths used by later cells ('data_train.npz', 'labels_train.npz',
# checkpoints, plots) resolve inside /content/drive/MyDrive/FYP/.
from google.colab import drive
drive.mount("/content/drive", force_remount=True)
%cd /content/drive/MyDrive/FYP/

model.py

In [None]:
import torch # credit to @aladdinpersson
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Both convolutions use stride 1 and padding 1, so height and width are
    preserved; only the channel count changes (in_channels -> out_channels).
    The convolutions have no bias because BatchNorm follows each of them.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Attribute name `conv` is part of the state_dict keys (conv.0.weight, ...).
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        out = self.conv(x)
        return out

class UNET(nn.Module):
    """U-Net for 2-D segmentation: encoder/decoder with skip connections.

    Args:
        in_channels: channels of the input image. Default 1.
        out_channels: channels of the output map (e.g. one logit map per
            class). Default 6.
        features: channel widths of the encoder levels, shallowest first. The
            decoder mirrors them and the bottleneck uses ``2 * features[-1]``.
            Any iterable of ints is accepted. The default is a tuple, so no
            mutable object is shared between instances.

    The input height/width need not be divisible by ``2 ** len(features)``.
    When an upsampled decoder map is smaller than its skip connection (an odd
    size floored by max-pooling), it is resized to match before concatenation.
    """

    def __init__(
        self, in_channels=1, out_channels=6, features=(64, 128, 256, 512)
    ):
        super(UNET, self).__init__()
        # Materialise once: the widths are iterated twice and indexed below,
        # which a one-shot iterator/generator would not survive.
        features = tuple(features)
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Keep this creation order (downs, ups, bottleneck, final_conv): it
        # determines how a seeded RNG is consumed by parameter initialisation.

        # Downsampling (encoder path): one DoubleConv per level.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Upsampling (decoder path): (2x transposed conv, DoubleConv) pairs,
        # deepest level first.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            # Input is the upsampled map concatenated with the skip: 2 * feature channels.
            self.ups.append(DoubleConv(feature * 2, feature))

        # Bottleneck doubles the deepest encoder width.
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Final 1x1 convolution at the output layer.
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to (N, out_channels, H, W) outputs."""
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        # self.ups alternates [upsample, double-conv, upsample, double-conv, ...].
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            # Channel counts always agree here. Only H/W can differ, when
            # pooling floored an odd size.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)

generator.py

In [None]:
import cv2

def DataGeneratorZXY(
    img_data,
    label_data,
    size
):
    """Cut a 3-D volume and its label volume into resized 2-D slices.

    Slices are taken first along the last axis (``[:, :, k]``), then along the
    middle axis (``[:, k, :]``). Image slices use area interpolation. Label
    slices use nearest-neighbour interpolation so class values are never
    blended.

    Args:
        img_data: 3-D image volume (numpy-style indexable).
        label_data: 3-D label volume, indexed the same way as ``img_data``.
        size: ``dsize`` passed to ``cv2.resize``, i.e. ``(width, height)``.

    Returns:
        ``(img, label)``: two lists of resized 2-D slices in matching order.
    """
    img, label = [], []

    def _add_pair(img_slice, label_slice):
        # Resize one image/label slice pair and append it to the outputs.
        img.append(cv2.resize(img_slice, size, interpolation=cv2.INTER_AREA))
        label.append(cv2.resize(label_slice, size, interpolation=cv2.INTER_NEAREST))

    # Pass 1: slices along the last axis.
    for k in range(len(img_data[0, 0, :])):
        _add_pair(img_data[:, :, k], label_data[:, :, k])

    # Pass 2: slices along the middle axis.
    for k in range(len(img_data[0, :, 0])):
        _add_pair(img_data[:, k, :], label_data[:, k, :])

    return img, label

data.py

In [None]:
#from generator import DataGeneratorZXY
from torchvision import transforms
from torch.utils.data import Dataset
import numpy as np

class DataGenerator(Dataset):
    """Dataset of resized 2-D image/label slices cut from 3-D training volumes.

    Args:
        data_path: ``.npz`` archive holding the image volume under key 'data'.
        label_path: ``.npz`` archive holding the label volume under key 'labels'.
        size: ``(width, height)`` every slice is resized to. Default (128, 256).

    Items are ``(image_tensor, label_tensor)``:
    - The image goes through ``transforms.ToTensor``.
    - The label is converted WITHOUT ``ToTensor``, because ``ToTensor``
      rescales uint8 arrays by 1/255, which would destroy integer class ids.
    """
    def __init__(self, data_path='data_train.npz', label_path='labels_train.npz', size=(128, 256)):
        self.size = size
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code. Only use this with trusted files.
        # mmap_mode is not passed: np.load ignores it for .npz archives.
        # `with` closes the archive's file handle once the array is read.
        with np.load(data_path, allow_pickle=True) as archive:
            self.img_data = archive['data']
        with np.load(label_path, allow_pickle=True) as archive:
            self.label_data = archive['labels']
        self.img, self.label = DataGeneratorZXY(self.img_data, self.label_data, self.size)
        self.transform = transforms.Compose([transforms.ToTensor()])
        # TODO: multi-plane wide-angle slicing and horizontal-flip augmentation
        # are not implemented yet.

    def __len__(self):
        return len(self.img)

    def __getitem__(self, index):
        """Return ``(image, label)`` tensors, each channel-first (C, H, W)."""
        import torch  # this cell does not import torch at module level

        image = self.transform(self.img[index])
        label = np.asarray(self.label[index])
        # Same layout as ToTensor (HxW -> 1xHxW, HxWxC -> CxHxW), with the
        # dtype and values preserved exactly.
        label = label[None] if label.ndim == 2 else label.transpose(2, 0, 1)
        return image, torch.from_numpy(np.ascontiguousarray(label))

early-stopping.py

In [None]:
import torch 
import numpy as np

class EarlyStopping: # thanks pytorchtools.py
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Lower-case np.inf: the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and checkpoint ``model`` on improvement.

        Args:
            val_loss: scalar validation loss (lower is better).
            model: module whose ``state_dict`` is saved to ``self.path`` when the
                loss improves by more than ``delta``.

        Sets ``self.early_stop`` once ``patience`` consecutive calls brought no
        improvement. A NaN loss counts as "no improvement".
        """
        if np.isnan(float(val_loss)):
            # NaN compares False against everything. Without this guard it
            # would reach the "improved" branch, overwrite the best checkpoint
            # with a diverged model and make every later epoch look better.
            self.trace_func('EarlyStopping: validation loss is NaN, counted as no improvement')
            self._register_no_improvement()
            return

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self._register_no_improvement()
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def _register_no_improvement(self):
        """Advance the patience counter and raise ``early_stop`` when it runs out."""
        self.counter += 1
        self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

utils.py

In [None]:
#from data import DataGenerator
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split

# Compute device string: 'cuda' when a GPU is visible, otherwise 'cpu'.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

def save_checkpoint(
    model, 
    optimizer, 
    e, 
    loss,
    acc, 
    filename='checkpoint.pth.tar'
    ):
    """Serialize model/optimizer state plus training metadata with ``torch.save``.

    Args:
        model: module whose ``state_dict`` is stored under 'state_dict'.
        optimizer: optimizer whose ``state_dict`` is stored under 'optimizer'.
        e: epoch number, stored under 'epoch'.
        loss: loss record, stored as-is under 'loss'.
        acc: accuracy record, stored as-is under 'accuracy'.
        filename: destination path. Default 'checkpoint.pth.tar'.
    """
    print('=> Saving checkpoint')
    payload = dict(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
        epoch=e,
        loss=loss,
        accuracy=acc,
    )
    torch.save(payload, filename)

def resume_checkpoint(
    model,
    optimizer,
    filename='checkpoint.pth.tar'
    ):
    """Restore model/optimizer state from a checkpoint written by ``save_checkpoint``.

    Args:
        model: module to load the saved 'state_dict' into. Loading is
            non-strict, and any missing/unexpected keys are reported.
        optimizer: optimizer to load the saved 'optimizer' state into.
        filename: checkpoint path. Default 'checkpoint.pth.tar'.

    Returns:
        ``(model, optimizer, next_epoch, loss, acc)``, where ``next_epoch`` is
        the saved epoch + 1 and ``loss``/``acc`` are the saved records unchanged.
    """
    print('=> Loading checkpoint')
    # map_location='cpu' lets a checkpoint written on a GPU machine load on a
    # CPU-only one. load_state_dict then copies the values onto whatever
    # device the model/optimizer parameters already live on.
    checkpoint = torch.load(filename, map_location='cpu')

    # Non-strict loading is kept, but mismatches are surfaced instead of
    # silently ignored.
    result = model.load_state_dict(checkpoint['state_dict'], strict=False)
    if result.missing_keys or result.unexpected_keys:
        print('=> Warning: checkpoint/model key mismatch '
              f'(missing: {result.missing_keys}, unexpected: {result.unexpected_keys})')
    optimizer.load_state_dict(checkpoint['optimizer'])

    epoch = checkpoint['epoch'] + 1
    loss = checkpoint['loss']
    acc = checkpoint['accuracy']
    return model, optimizer, epoch, loss, acc
    
def batch_accuracy(
    prediction, # torch.Size([16, 6, 256, 128]
    gt # torch.Size([16, 1, 256, 128]
    ):
    """Pixel accuracy of segmentation logits against 1-based ground-truth labels.

    Args:
        prediction: per-class scores, (N, C, H, W). The predicted class is
            ``argmax`` over dim 1, shifted by +1 to match the label values.
        gt: labels, (N, 1, H, W), holding class ids that start at 1.

    Returns:
        Fraction of pixels whose predicted class equals the label, as a float.
    """
    # argmax must run on the float scores. Casting to an integer dtype first
    # truncates them toward zero, so most pixels tie and collapse to index 0.
    predicted_class = torch.argmax(prediction, dim=1) + 1
    target_class = torch.squeeze(gt.long(), dim=1)
    return (target_class == predicted_class).float().mean().item()

def split_train_test_val(
    dataset,
    train_frac=0.9,
    val_frac=0.05,
    seed=None
    ):
    """Randomly split ``dataset`` into train / validation / test subsets.

    Args:
        dataset: any sized, indexable dataset accepted by ``random_split``.
        train_frac: fraction of samples used for training (floored). Default 0.9.
        val_frac: fraction used for validation (floored). Default 0.05.
        seed: optional int. When given, the permutation comes from a dedicated
            ``torch.Generator`` seeded with it, so the same split can be
            recreated (e.g. when resuming). ``None`` uses torch's global RNG.

    Returns:
        ``(train_data, val_data, test_data)`` subsets. The test subset receives
        every sample not assigned to train/val.

    Raises:
        ValueError: if the fractions request more samples than the dataset has.
    """
    n = len(dataset)
    len_train = int(train_frac * n)
    len_val = int(val_frac * n)
    len_test = n - len_train - len_val
    if min(len_train, len_val, len_test) < 0:
        raise ValueError(
            f'invalid split fractions train={train_frac}, val={val_frac} for {n} samples')
    lengths = [len_train, len_val, len_test]

    if seed is None:
        train_data, val_data, test_data = random_split(dataset, lengths)
    else:
        generator = torch.Generator().manual_seed(seed)
        train_data, val_data, test_data = random_split(dataset, lengths, generator=generator)
    return train_data, val_data, test_data

def get_loaders(
    bs
    ):
    """Build the dataset, split it, and return ``(train_loader, val_loader)``.

    The test loader is not returned; it is serialized to 'test_loader.pth'
    with ``torch.save`` for later evaluation.

    Args:
        bs: batch size used for all three loaders.
    """
    dataset = DataGenerator()
    train_data, val_data, test_data = split_train_test_val(dataset)

    # Pinned host memory only speeds up host->GPU copies. On CPU-only
    # machines it just triggers a warning.
    pin = torch.cuda.is_available()

    train_loader = DataLoader(train_data, batch_size=bs, shuffle=True, pin_memory=pin)
    # Evaluation loaders keep a fixed order. Epoch-to-epoch validation
    # averages and saved images then refer to the same batches.
    val_loader = DataLoader(val_data, batch_size=bs, shuffle=False, pin_memory=pin)
    test_loader = DataLoader(test_data, batch_size=bs, shuffle=False, pin_memory=pin)

    # NOTE(review): this pickles the whole loader (including the dataset).
    # Loading it back needs torch.load(..., weights_only=False) on recent torch.
    torch.save(test_loader, 'test_loader.pth')
    return train_loader, val_loader

def save_prediction_as_imgs(
    data, # torch.Size([16, 1, 256, 128])
    label, # torch.Size([16, 1, 256, 128])
    prediction, # torch.Size([16, 6, 256, 128])
    e, 
    ix,
    ):
    """Save one image / label / prediction comparison figure per batch sample.

    Args:
        data: input batch, (N, 1, H, W) per the shape comments above.
        label: ground-truth batch with 1-based class ids, (N, 1, H, W).
        prediction: per-class scores, (N, C, H, W); ``argmax + 1`` gives class ids.
        e: epoch number (used in title and file name).
        ix: batch index (used in title and file name).

    Files go to '<cwd>/Validation images/' and are named
    'Epoch_{e}_Batch_{ix}.{i}_Validation of random images.png', where i starts at 1.
    Tensors may live on any device; they are moved to the CPU for plotting.
    """
    script_dir = os.path.abspath('')
    results_dir = os.path.join(script_dir, 'Validation images/')
    os.makedirs(results_dir, exist_ok=True)

    # imshow needs host memory: detach and move everything off the GPU.
    data = torch.squeeze(data.detach().cpu(), dim=1)    # (N, H, W)
    label = torch.squeeze(label.detach().cpu(), dim=1)  # (N, H, W)
    # Class indices start at 0; +1 restores the original 1-based class naming.
    prediction = torch.argmax(prediction.detach().cpu(), dim=1) + 1

    # Tick labels on top for these figures only. rc_context restores the
    # global rcParams afterwards instead of leaking them into later plots.
    top_ticks = {
        'xtick.bottom': False, 'xtick.labelbottom': False,
        'xtick.top': True, 'xtick.labeltop': True,
    }
    with plt.rc_context(top_ticks):
        for i, (D, L, P) in enumerate(zip(data, label, prediction), 1):
            fig, ax = plt.subplots(1, 3, figsize=(20, 18))
            try:
                im0 = ax[0].imshow(D.numpy(), aspect='auto', cmap='RdBu')
                im1 = ax[1].imshow(L.numpy(), aspect='auto')
                im2 = ax[2].imshow(P.numpy(), aspect='auto')
                ax[0].set_title('Image')
                ax[1].set_title('Label')
                ax[2].set_title('Prediction')
                ax[0].set_ylabel('Height')
                ax[1].set_xlabel('Width')
                for axis, image, cbar_label in ((ax[0], im0, r'Amplitude'),
                                                (ax[1], im1, r'Class'),
                                                (ax[2], im2, r'Class')):
                    cbar = fig.colorbar(image, ax=axis, orientation='horizontal', pad=0.02)
                    cbar.ax.set_ylabel(cbar_label, fontsize=8)
                fig.suptitle('Validation of random images (Epoch:{}, Batch:{}.{})'.format(e, ix, i))
                file_name = 'Epoch_{}_Batch_{}.{}_Validation of random images.png'.format(e, ix, i)
                fig.savefig(os.path.join(results_dir, file_name))
            finally:
                # Close (not merely clear) each figure. One figure is created
                # per sample per batch, and open figures accumulate in memory.
                plt.close(fig)

def loss_EarlyStopping_plot( # thanks MNIST_Early_Stopping_example.ipynb
    train_loss,
    val_loss
    ):
    """Plot per-epoch training/validation loss and mark the best validation epoch.

    Epochs are numbered from 1. A dashed red vertical line marks the first
    epoch with the lowest validation loss. The y-axis is fixed to [0, 0.5] for
    a consistent scale across runs. The figure is saved to
    'loss_EarlyStopping_plot.png' and left open.

    Side effect: sets the global rcParams so x tick labels are drawn at the
    bottom (not the top).
    """
    plt.rcParams.update({
        'xtick.bottom': True, 'xtick.labelbottom': True,
        'xtick.top': False, 'xtick.labeltop': False,
    })

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(1, 1, 1)

    ax.plot(range(1, len(train_loss) + 1), train_loss, label='Training Loss')
    ax.plot(range(1, len(val_loss) + 1), val_loss, label='Validation Loss')

    # Epoch (1-based) with the lowest validation loss.
    best_epoch = val_loss.index(min(val_loss)) + 1
    ax.axvline(best_epoch, linestyle='--', color='r', label='Early Stopping Checkpoint')

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_ylim(0, 0.5)                  # consistent scale
    ax.set_xlim(0, len(train_loss) + 1)  # consistent scale
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    fig.savefig('loss_EarlyStopping_plot.png', bbox_inches='tight')

def loss_acc_plot( # thanks Rath S. R. (2021). Saving and Loading the Best Model in PyTorch. Debugger cafe.
    train_loss, 
    train_acc, 
    val_loss, 
    val_acc
    ):
    """Save accuracy curves to 'Accuracy.png' and loss curves to 'Loss.png'.

    Each metric gets its own new figure with a training curve and a validation
    curve, plotted against the list index. The figures are left open.
    """
    # (train series, colour, label, val series, colour, label, y label, file)
    panels = (
        (train_acc, 'green', 'train accuracy', val_acc, 'blue', 'validation accuracy',
         'Accuracy', 'Accuracy.png'),
        (train_loss, 'orange', 'train loss', val_loss, 'red', 'validation loss',
         'Loss', 'Loss.png'),
    )
    for tr_vals, tr_color, tr_label, va_vals, va_color, va_label, y_label, out_file in panels:
        plt.figure(figsize=(10, 7))
        plt.plot(tr_vals, color=tr_color, linestyle='-', label=tr_label)
        plt.plot(va_vals, color=va_color, linestyle='-', label=va_label)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True)
        plt.savefig(out_file)

def save_model(
    model
    ):
    """Write ``model``'s state_dict to 'complete_trained_model.pth' in the working directory."""
    print('Completed experiment \n=> Saving model')
    weights = model.state_dict()
    torch.save(weights, 'complete_trained_model.pth')

train.py

In [None]:
#from utils import (save_checkpoint, resume_checkpoint, batch_accuracy, get_loaders, save_prediction_as_imgs, loss_EarlyStopping_plot, loss_acc_plot, save_model)
#from early-stopping import EarlyStopping
#from model import UNET
import torch
import torch.cuda
import torch.cuda.amp
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import wandb
from tqdm import tqdm

# Compute device string used when building the model in `train`:
# 'cuda' when a GPU is visible, otherwise 'cpu'.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

def train_fn(model, train_loader, criterion, optimizer, scaler, epoch):
    """Run one training epoch and return the epoch-averaged loss and accuracy.

    Args:
        model: network to train. Batches are moved to the device of its parameters.
        train_loader: iterable of ``(data, label)`` batches. Labels hold class
            ids starting at 1.
        criterion: loss called as ``criterion(logits, target)``, where target
            holds class indices starting at 0 with shape (N, H, W)
            (e.g. ``nn.CrossEntropyLoss``).
        optimizer: optimizer over ``model``'s parameters.
        scaler: GradScaler-like object providing ``scale``/``step``/``update``.
        epoch: epoch number, only used for wandb logging.

    Returns:
        ``(model, optimizer, [mean_batch_loss], [mean_batch_accuracy])``.
    """
    import numpy as np  # this cell does not import numpy itself

    dev = next(model.parameters()).device
    # Mixed precision only on CUDA; elsewhere autocast is disabled.
    use_amp = dev.type == 'cuda'

    train_losses, avg_train_losses, train_accs, avg_train_accs = [], [], [], []
    train_loop = tqdm(train_loader)
    model.train()

    # wandb.watch registers logging hooks on the model, so do it once per
    # model rather than once per batch.
    if not getattr(model, '_wandb_watched', False):
        wandb.watch(model)
        model._wandb_watched = True

    for ix, (data, label) in enumerate(train_loop, 1):
        # Inputs and targets must be on the same device as the model.
        data, label = data.to(dev), label.to(dev)
        # Labels are 1-based; the loss expects 0-based indices without the channel dim.
        target = torch.squeeze(label.long() - 1, dim=1)

        # forward propagation
        with torch.autocast(device_type='cuda' if use_amp else 'cpu', enabled=use_amp):
            prediction = model(data)
            loss = criterion(prediction, target)
        acc = batch_accuracy(prediction, label)

        # tracker, wandb
        train_losses.append(loss.item())
        train_accs.append(acc)
        wandb.log({'train_loss': loss.item(), 'train_accuracy': acc, 'epoch': epoch, 'batch': ix})

        # backward propagation (a disabled GradScaler passes the loss through unchanged)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        train_loop.set_postfix({'train_loss': loss.item(), 'train_acc': acc})

    train_loss, train_acc = np.average(train_losses), np.average(train_accs)
    avg_train_losses.append(train_loss)
    avg_train_accs.append(train_acc)

    return model, optimizer, avg_train_losses, avg_train_accs

def val_fn(val_loader, model, criterion, epoch, early_stopping):
    """Evaluate ``model`` for one epoch, save prediction images, update early stopping.

    Args:
        val_loader: iterable of ``(data, label)`` batches. Labels hold class
            ids starting at 1.
        model: network to evaluate (eval mode, no gradients). Batches are
            moved to the device of its parameters.
        criterion: loss called as ``criterion(logits, target)`` with 0-based
            class-index targets of shape (N, H, W).
        epoch: epoch number for wandb logging and image file names.
        early_stopping: callable such as ``EarlyStopping``, called once with
            the epoch-mean loss and the model. Its ``early_stop`` flag is read.

    Returns:
        ``([mean_val_loss], [mean_val_accuracy])``.
    """
    import numpy as np  # this cell does not import numpy itself

    dev = next(model.parameters()).device
    val_loop = tqdm(val_loader)
    val_losses, avg_val_losses, val_accs, avg_val_accs = [], [], [], []
    model.eval()

    with torch.no_grad():
        for ix, (data, label) in enumerate(val_loop, 1):
            # Inputs and targets must be on the same device as the model.
            data, label = data.to(dev), label.to(dev)
            # Labels are 1-based; CrossEntropy-style losses want 0-based [N, H, W] indices.
            target = torch.squeeze(label.long() - 1, dim=1)

            # forward pass
            prediction = model(data)
            loss = criterion(prediction, target)
            acc = batch_accuracy(prediction, label)

            # loss/accuracy tracker, tqdm, wandb
            val_losses.append(loss.item())
            val_accs.append(acc)
            val_loop.set_postfix({'val_loss': loss.item(), 'val_acc': acc})
            wandb.log({'val_loss': loss.item(), 'val_accuracy': acc, 'epoch': epoch, 'batch': ix})
            # TODO: optionally log wandb.plot.confusion_matrix(preds=argmax(prediction, 1),
            # y_true=label squeezed to (N, H, W), six class names).

            # Plotting needs host tensors.
            save_prediction_as_imgs(data.cpu(), label.cpu(), prediction.cpu(), epoch, ix)

        val_loss, val_acc = np.average(val_losses), np.average(val_accs)
        avg_val_losses.append(val_loss)
        avg_val_accs.append(val_acc)

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print('=> Early stopping')

    return avg_val_losses, avg_val_accs

def _last_scalar(value):
    """Return the latest value of a checkpointed loss/accuracy record as a float.

    ``train`` passes the whole per-epoch history lists to ``save_checkpoint``,
    so the record may be a list/tuple (its last entry is used), a 0-d tensor
    or numpy scalar (``.item()``), or a plain number.

    Raises:
        ValueError: if the record is an empty sequence.
    """
    if isinstance(value, (list, tuple)):
        if not value:
            raise ValueError('checkpoint contains an empty loss/accuracy history')
        value = value[-1]
    if hasattr(value, 'item'):
        value = value.item()
    return float(value)


def train(model, lr, epoch, bs, criterion, optimizer, scaler, patience, load_checkpoint, entity, project):
    """Build the experiment from its configuration and run train/validate epochs.

    Args:
        model: name of a model class defined in this module (e.g. 'UNET'), or
            the class itself. It is called as ``model(in_channels=1, out_channels=6)``.
        lr: learning rate.
        epoch: last epoch to run (epochs are numbered from 1).
        bs: batch size for all loaders.
        criterion: name of a ``torch.nn`` loss class, e.g. 'CrossEntropyLoss'.
        optimizer: name of a ``torch.optim`` class, e.g. 'Adam'.
        scaler: name of a ``torch.cuda.amp`` scaler class, e.g. 'GradScaler'.
        patience: EarlyStopping patience in epochs.
        load_checkpoint: if True, resume via ``resume_checkpoint`` before training.
        entity, project: wandb entity and project names.

    Training ends after ``epoch`` or as soon as early stopping triggers. The
    best-validation weights are kept in 'EarlyStopping_checkpoint.pth', while
    ``save_model`` stores the weights from the last epoch that ran.
    """
    # Pass the config to init so it is attached to the run. Rebinding
    # `wandb.config` to a plain dict after init does not update the run.
    wandb.init(project=project, entity=entity, config={
        'learning_rate': lr,
        'epochs': epoch,
        'batch_size': bs
    })
    early_stopping = EarlyStopping(patience=patience, verbose=True, path='EarlyStopping_checkpoint.pth')
    train_loss, val_loss, train_acc, val_acc = [], [], [], []

    # Resolve the configured components. A plain name lookup replaces eval(),
    # which would execute any expression passed in as `model`.
    model_cls = model if callable(model) else globals()[model]
    model = model_cls(in_channels=1, out_channels=6).to(device)
    criterion = getattr(nn, criterion)()
    optimizer = getattr(optim, optimizer)(model.parameters(), lr=lr)
    # The scaler only does anything on CUDA; disable it explicitly elsewhere.
    scaler = getattr(torch.cuda.amp, scaler)(enabled=torch.cuda.is_available())
    train_loader, val_loader = get_loaders(bs=bs)

    start_epoch = 1
    if load_checkpoint:
        # Resume a partially trained model.
        model, optimizer, start_epoch, resume_train_loss, resume_train_acc = resume_checkpoint(model, optimizer)
        # Validate the restored model so the resumed epoch has a val point too.
        avg_val_loss, avg_val_acc = val_fn(val_loader, model, criterion, start_epoch - 1, early_stopping)
        train_loss.append(_last_scalar(resume_train_loss))
        train_acc.append(_last_scalar(resume_train_acc))
        val_loss.append(avg_val_loss.pop())
        val_acc.append(avg_val_acc.pop())

    for e in range(start_epoch, epoch + 1):
        model, optimizer, avg_train_loss, avg_train_acc = train_fn(model, train_loader, criterion, optimizer, scaler, e)
        train_loss.append(avg_train_loss.pop())
        train_acc.append(avg_train_acc.pop())
        save_checkpoint(model, optimizer, e, train_loss, train_acc)
        avg_val_loss, avg_val_acc = val_fn(val_loader, model, criterion, e, early_stopping)
        val_loss.append(avg_val_loss.pop())
        val_acc.append(avg_val_acc.pop())
        # EarlyStopping only raises a flag; stopping is up to the loop.
        if early_stopping.early_stop:
            break

    loss_EarlyStopping_plot(train_loss, val_loss)
    loss_acc_plot(train_loss, train_acc, val_loss, val_acc)
    save_model(model)

main.py

In [None]:
#from train import train

# Resuming is off by default: split_train_test_val redraws the random
# train/val/test split on every run, so resumed loaders would not match.
load_checkpoint = False

# Hyperparameters.
lr = 1e-3
bs = 16
epoch = 10
patience = 3                    # EarlyStopping patience (epochs without improvement)

# Components given by name and resolved inside `train`.
model = 'UNET'                  # model class
criterion = 'CrossEntropyLoss'  # torch.nn loss class
optimizer = 'Adam'              # torch.optim class, e.g. 'SGD', 'Adagrad'
scaler = 'GradScaler'           # torch.cuda.amp scaler class

# Weights & Biases target: fill these in before running.
entity = 'your_name'            # wandb username
project = 'your_project_name'   # wandb project name

def main():
    """Launch one training run from the module-level configuration values."""
    run_settings = {
        'model': model,
        'lr': lr,
        'epoch': epoch,
        'bs': bs,
        'criterion': criterion,
        'optimizer': optimizer,
        'scaler': scaler,
        'patience': patience,
        'load_checkpoint': load_checkpoint,
        'entity': entity,
        'project': project,
    }
    train(**run_settings)


if __name__ == "__main__":
    main()