# Semantic segmentation of flair1 data with a basic UNet

In [None]:
# import all the necessary modules
import matplotlib.pyplot as plt
from utils import *
from dataset import *
from model import *
from metric import *
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

## Prepare the data
1) Download the data and split it into training, validation and test sets. 
2) Analyse the balance of classes in the dataset. 
3) Create a Dataloader
4) Load the data

In [None]:
# Notebook-wide settings read by the cells below.
max_epochs = 10  # number of iterations of the epoch loop
use_gpu = True  # train()/val() call .cuda() on every batch when this is set
batch_size = 32  # batch size handed to every DataLoader
folder_path = "../../data/flair1/"  # dataset root passed to Flair1Dataset

In [None]:
# Draw one shuffled batch from the "toy_tr" split and plot its first
# image/mask pair. `colors` and `dict_classes` are not defined in this
# notebook; presumably they come from one of the star imports.
toy_ds = Flair1Dataset(folder_path, split="toy_tr")
toy_dl = torch.utils.data.DataLoader(
    dataset=toy_ds,
    shuffle=True,
    batch_size=batch_size,
)
img, msk = next(iter(toy_dl))
plot_image_mask(img[0], msk[0], colors, dict_classes)

In [None]:
# Build the three dataset splits, in the order train -> valid -> test.
train_ds, val_ds, test_ds = (
    Flair1Dataset(folder_path, split=split_name)
    for split_name in ("train", "valid", "test")
)

In [None]:
# Only the training loader reshuffles its samples each epoch. The validation
# and test loaders keep a fixed order: the loss is averaged per batch, so a
# shuffled evaluation set would make the reported loss vary from run to run
# whenever the last batch is smaller than the others.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=False)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False)

## Create the model

## Create the functions for training, validation and test, with plots of the loss

In [None]:
# main training function (trains for one epoch)
def train(model, train_loader, loss_fn, optimizer):
    """Train ``model`` for one epoch over ``train_loader``.

    Args:
        model: network to optimise; it is put in training mode.
        train_loader: iterable of ``(image, target)`` batches that supports
            ``len()`` and whose ``dataset`` exposes ``n_classes``.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar
            tensor.
        optimizer: optimizer stepped once per batch.

    Returns:
        ``(model, mean_loss, average_accuracy, classwise_accuracy, mIoU)``.
        ``mean_loss`` is the sum of the per-batch losses divided by
        ``len(train_loader)``; the accuracies come from ``PixelwiseMetrics``
        and the mIoU from ``ConfMatrix``.
    """
    model.train()

    # Send every batch to the device that holds the model weights instead of
    # relying on a notebook-level flag, so data and model can never end up on
    # different devices. A model without parameters leaves the batch in place.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    n_classes = train_loader.dataset.n_classes
    pixel_acc = PixelwiseMetrics(n_classes)
    conf_mat = ConfMatrix(n_classes)
    training_loss = 0.0

    # The context manager closes the progress bar even if a batch raises.
    with tqdm(total=len(train_loader), desc="[Train]") as pbar:
        for image, target in train_loader:
            if device is not None:
                image, target = image.to(device), target.to(device)

            optimizer.zero_grad()

            # One forward pass and one loss evaluation per batch; the same
            # loss tensor feeds both the running total and backpropagation.
            prediction = model(image)
            loss = loss_fn(prediction, target)
            loss.backward()
            optimizer.step()

            training_loss += loss.item()

            # Index of the highest score along dim 1, shared by both metrics.
            predicted_classes = prediction.max(1)[1]
            pixel_acc.add_batch(target, predicted_classes)
            conf_mat.add_batch(target, predicted_classes)

            pbar.update()

    training_loss /= len(train_loader)
    return (
        model,
        training_loss,
        pixel_acc.get_average_accuracy(),
        pixel_acc.get_classwise_accuracy(),
        conf_mat.get_mIoU(),
    )

# main validation function (validates current model)
def val(model, val_loader, loss_fn, optimizer):
    """Evaluate ``model`` on ``val_loader`` without updating its weights.

    Args:
        model: network to evaluate; it is put in evaluation mode for the
            loop and switched back to training mode before returning.
        val_loader: iterable of ``(image, target)`` batches that supports
            ``len()`` and whose ``dataset`` exposes ``n_classes``.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar
            tensor.
        optimizer: unused; kept so existing call sites keep working.

    Returns:
        ``(mean_loss, average_accuracy, classwise_accuracy, mIoU)``.
        ``mean_loss`` is the sum of the per-batch losses divided by
        ``len(val_loader)``.
    """
    model.eval()

    # Send every batch to the device that holds the model weights instead of
    # relying on a notebook-level flag, so data and model can never end up on
    # different devices. A model without parameters leaves the batch in place.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    n_classes = val_loader.dataset.n_classes
    val_pixel_acc = PixelwiseMetrics(n_classes)
    val_conf_mat = ConfMatrix(n_classes)
    loss = 0.0

    # Nothing in the loop needs gradients, so the whole loop runs under
    # no_grad; the context manager closes the progress bar on errors too.
    with tqdm(total=len(val_loader), desc="[Val]") as pbar, torch.no_grad():
        for image, target in val_loader:
            if device is not None:
                image, target = image.to(device), target.to(device)

            prediction = model(image)
            loss += loss_fn(prediction, target).item()

            # Index of the highest score along dim 1, shared by both metrics.
            predicted_classes = prediction.max(1)[1]
            val_conf_mat.add_batch(target, predicted_classes)
            val_pixel_acc.add_batch(target, predicted_classes)

            pbar.update()

    # Callers alternate train() and val(), so hand the model back in
    # training mode, as before.
    model.train()
    return (
        loss / len(val_loader),
        val_pixel_acc.get_average_accuracy(),
        val_pixel_acc.get_classwise_accuracy(),
        val_conf_mat.get_mIoU(),
    )

def export_model(model, optimizer=None, name=None, step=None):
    """Write a checkpoint dictionary to disk with ``torch.save``.

    Args:
        model: object whose ``state_dict()`` is stored under
            ``"model_state_dict"``.
        optimizer: optional; when given, its ``state_dict()`` is stored under
            ``"optimizer_state_dict"``.
        name: optional output path, used verbatim when given.
        step: optional step number; stored under ``"step"`` and, when no
            ``name`` is given, appended to the default file name.
    """
    # Pick the output file: an explicit name wins, otherwise "checkpoint"
    # with an optional "_step_<step>" suffix.
    if name is not None:
        out_file = name
    elif step is None:
        out_file = "checkpoint"
    else:
        out_file = "checkpoint" + "_step_" + str(step)

    # Assemble the payload; optional entries are added only when provided.
    data = {"model_state_dict": model.state_dict()}
    optional_entries = (
        ("step", step, lambda value: value),
        ("optimizer_state_dict", optimizer, lambda value: value.state_dict()),
    )
    for key, source, to_payload in optional_entries:
        if source is not None:
            data[key] = to_payload(source)

    torch.save(data, out_file)

## Run the algorithm for a tiny subset

In [None]:
# Train and validate on the toy splits for max_epochs epochs, printing the
# metrics of every epoch, then plot the training and validation loss curves.
toy_tr_ds = Flair1Dataset(folder_path, split="toy_tr")
toy_tr_dl = torch.utils.data.DataLoader(toy_tr_ds, batch_size=batch_size, shuffle=True)
toy_vl_ds = Flair1Dataset(folder_path, split="toy_vl")
# Validation batches keep a fixed order so the per-batch-averaged validation
# loss is reproducible from one run to the next.
toy_vl_dl = torch.utils.data.DataLoader(toy_vl_ds, batch_size=batch_size, shuffle=False)

n_classes = toy_tr_ds.n_classes
n_inputs = toy_vl_ds.n_inputs
model = UNet(n_classes=n_classes, n_channels=n_inputs) #pixel value from 1 to 18
# train()/val() put every batch on the GPU when use_gpu is set, so the model
# has to live there as well; move it before the optimizer is created.
if use_gpu:
    model = model.cuda()
train_pixelAcc = PixelwiseMetrics(n_classes)
# Targets equal to 255 are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01, weight_decay=5e-4)
train_losses = []
val_losses = []

for epoch in range(max_epochs):
    print("Epoch: ", epoch)
    # run training for one epoch
    model, tr_loss, tr_mean_acc, tr_class_acc, tr_mIoU = train(model, toy_tr_dl, loss_fn, optimizer)
    train_losses.append(tr_loss)
    # run validation; its mIoU gets its own name so that it cannot overwrite
    # the training mIoU before that one is printed
    val_loss, val_mean_acc, val_class_acc, val_mIoU = val(model, toy_vl_dl, loss_fn, optimizer)
    val_losses.append(val_loss)
    print("-----------------------------------------------------------------------")
    print("-----------------------------------------------------------------------")
    print("Training loss: ", tr_loss)
    print("Training mean accuracy: ", tr_mean_acc)
    print("Training classwise accuracy: ", tr_class_acc)
    print("Training mIoU: ", tr_mIoU)
    print("-----------------------")
    print("Validation loss: ", val_loss)
    print("Validation mean accuracy: ", val_mean_acc)
    print("Validation classwise accuracy: ", val_class_acc)
    print("Validation mIoU: ", val_mIoU)
plt.plot(train_losses, label='Training loss')
plt.plot(val_losses, label='Validation loss')
plt.legend(frameon=False)
plt.show()

In [None]:
# Evaluate the trained model once more on the toy validation loader and
# print the four metrics returned by val().
val_loss, val_mean_acc, val_class_acc, mIoU = val(model, toy_vl_dl, loss_fn, optimizer)
print("-----------------------")
for metric_label, metric_value in (
    ("Validation loss: ", val_loss),
    ("Validation mean accuracy: ", val_mean_acc),
    ("Validation classwise accuracy: ", val_class_acc),
    ("Validation mIoU: ", mIoU),
):
    print(metric_label, metric_value)

In [None]:
# Print the validation accuracy class by class, one line per entry.
n_reported = len(val_class_acc)
for i in range(n_reported):
    # The metric keys are numbered from 0, while dict_classes is looked up
    # with the same index shifted by one.
    label_id = i + 1
    metric_key = "pixelclass_" + str(i)
    print(label_id, dict_classes[label_id], " : ", val_class_acc[metric_key])

In [None]:
# Plot the class balance of the toy training split. Per the original note,
# get_per_per_class() gives the percentage of each class in the dataset.
toy_tr_ds = Flair1Dataset(folder_path, split="toy_tr")
per_class = toy_tr_ds.get_per_per_class()
plot_per_classes(
    per_class,
    dict_classes,
    colors,
    title="toy training dataset",
)