In [None]:
# Google Colab only: install the Hugging Face `transformers` package into the
# runtime, then clear the pip log from the cell output.
from IPython.display import clear_output
!pip install transformers
clear_output()

In [None]:
# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

# Huggingface Pytorch Transformers
import transformers
from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation, get_linear_schedule_with_warmup

# albumentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# OpenCV
import cv2
from google.colab.patches import cv2_imshow

# Others
import os
import math
from PIL import Image
from tqdm import tqdm

from collections import defaultdict
import random
import numpy as np
import pandas as pd

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

The next cell only works on Google Colab: it mounts Google Drive and switches to the project folder.

In [None]:
import os
from google.colab import drive
# Mount the personal Google Drive inside the Colab VM.
drive.mount('/content/drive')
# Make the project folder the working directory so that the relative
# ./data, ./models and ./runs paths used further down resolve inside it.
path = r"/content/drive/MyDrive/COMP6200 Master Project"
os.chdir(path)
# Show the resolved working directory as the cell output.
os.path.abspath(os.curdir)

In [None]:
# Environment report only; the device actually used is chosen in params['device'].

# PyTorch version and whether a CUDA device is visible.
print(torch.__version__)
print(torch.cuda.is_available())

# When CUDA is available, also print the number of GPUs and the name of GPU 0.
if torch.cuda.is_available():
  print(torch.cuda.device_count())
  print(torch.cuda.get_device_name(0))

# Iris Segmentation Datasets

## Preprocessing

In [None]:
def binary_the_mask(mask_path):
  """Load a mask image from disk and binarise it to the values {0, 1}.

  Args:
    mask_path: Path of the mask file; it is read as single-channel grayscale.

  Returns:
    A np.uint8 array in which pixels brighter than 127 are 1 and all
    others are 0.

  Raises:
    FileNotFoundError: If OpenCV cannot read ``mask_path``. ``cv2.imread``
      signals failure by returning None rather than raising, which would
      otherwise surface as an opaque error inside ``cv2.threshold``.
  """
  mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
  if mask is None:
    raise FileNotFoundError('Could not read mask image: {}'.format(mask_path))

  # THRESH_BINARY with maxval=1: value > 127 -> 1, otherwise 0.
  _, mask_binary = cv2.threshold(mask, 127, 1, cv2.THRESH_BINARY)
  return mask_binary.astype(np.uint8)

class DatasetSeg(Dataset):
    """Image / binary-mask pairs of one iris segmentation dataset folder.

    ``root`` must contain an ``image`` folder and a ``SegmentationClass``
    folder. For every file listed in ``image`` the sample is built from
    ``<stem><image extension>`` and ``<stem><mask extension>``, where the
    extensions depend on ``dataset_name`` (see ``_EXTENSIONS``).

    Args:
        root: Folder holding the ``image`` and ``SegmentationClass`` folders.
        dataset_name: One of the keys of ``_EXTENSIONS``.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries.
    """

    # (image extension, mask extension) for each supported dataset name.
    _EXTENSIONS = {
        'CASIA-Iris-Africa': ('.JPEG', '.png'),
        'CASIA-Iris-Asia/CASIA-distance': ('.JPEG', '.png'),
        'CASIA-Iris-Asia/CASIA-Iris-Complex/Occlusion': ('.jpg', '.png'),
        'CASIA-Iris-Asia/CASIA-Iris-Complex/Off_angle': ('.jpg', '.png'),
        'CASIA-Iris-Asia/CASIA-Iris-M1': ('.JPG', '.png'),
        'UBIRIS_v2_seg': ('.tiff', '.tiff'),
        'MICHE_seg': ('.JPEG', '.png'),
    }

    def __init__(self, root, dataset_name, transform=None):
        self.root = root
        self.dataset_name = dataset_name
        self.images_dir = 'image'
        self.masks_dir = 'SegmentationClass'

        self.images_path = os.path.join(self.root, self.images_dir)
        self.masks_path = os.path.join(self.root, self.masks_dir)
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping (and therefore any seeded split) reproducible.
        self.images_list = sorted(os.listdir(self.images_path))
        self.masks_list = sorted(os.listdir(self.masks_path))

        self.transform = transform

    def __len__(self):
        """Number of files in the image folder."""
        return len(self.images_list)

    def __getitem__(self, idx):
        """Return ``(image, mask_binary)`` for sample ``idx``.

        The image is read with ``cv2.IMREAD_UNCHANGED``; the mask is loaded
        through ``binary_the_mask``. When a transform was given, its
        ``'image'`` and ``'mask'`` outputs are returned instead.

        Raises:
            ValueError: If ``dataset_name`` is not a supported dataset.
            FileNotFoundError: If the image file cannot be read.
        """
        try:
            image_ext, mask_ext = self._EXTENSIONS[self.dataset_name]
        except KeyError:
            raise ValueError(
                'Unknown dataset_name {!r}; expected one of {}'.format(
                    self.dataset_name, sorted(self._EXTENSIONS)))

        # splitext strips only the final extension, so stems containing dots survive.
        image_filename = os.path.splitext(self.images_list[idx])[0]

        image_file = os.path.join(self.images_path, image_filename + image_ext)
        image = cv2.imread(image_file, cv2.IMREAD_UNCHANGED)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError('Could not read image: {}'.format(image_file))

        mask_binary = binary_the_mask(os.path.join(self.masks_path, image_filename + mask_ext))

        if self.transform is not None:
            augmentation = self.transform(image=image, mask=mask_binary)
            return augmentation['image'], augmentation['mask']

        return image, mask_binary

class DatasetAfterSplit(Dataset):
    """Wrap an indexable collection of ``(image, mask)`` pairs with a joint transform.

    Args:
        subset: Anything supporting ``len()`` and ``subset[idx] -> (image, mask)``.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, subset, transform=None):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        image, mask_binary = self.subset[idx]

        # Without a transform the pair is passed through untouched.
        if self.transform is None:
            return image, mask_binary

        augmented = self.transform(image=image, mask=mask_binary)
        return augmented['image'], augmented['mask']


class DatasetMaskTransform(Dataset):
    """Wrap ``(image, mask)`` pairs, transforming only the mask.

    When ``transform`` is given, the image is passed through ``ToTensorV2``
    and the mask is passed to ``transform`` under the ``image=`` keyword; the
    transform's ``'image'`` output becomes the returned mask. Without a
    transform both items are returned unchanged.
    """

    def __init__(self, subset, transform=None):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        image, mask_binary = self.subset[idx]

        if self.transform is None:
            return image, mask_binary

        image_tensor = ToTensorV2()(image=image)['image']
        # The mask goes in as `image=` because the pipeline is image-only here.
        mask_transformed = self.transform(image=mask_binary)['image']
        return image_tensor, mask_transformed

In [None]:
# Target (height, width) every sample is resized to.
size = (400,400)
# Resize pipeline passed to DatasetSeg, which applies it to image and mask together.
resize_transform = A.Compose(
    [
        A.Resize(height=size[0], width=size[1], interpolation=cv2.INTER_NEAREST),
    ]
)
# UBIRIS v2 train and test folders, both resized on load.
UBIRIS_v2_seg_train = DatasetSeg('./data/UBIRIS_v2_seg/train', 'UBIRIS_v2_seg', transform=resize_transform)
UBIRIS_v2_seg_test = DatasetSeg('./data/UBIRIS_v2_seg/test', 'UBIRIS_v2_seg', transform=resize_transform)

In [None]:
# Sanity check: load the first resized training sample and print both shapes.
image_resized, mask_binary_resized = UBIRIS_v2_seg_train[0]
print(image_resized.shape)
print(mask_binary_resized.shape)

## Split Data 

In [None]:
def split_train_dataset(train_dataset_all, train_fraction=0.8, seed=2022):
  """Randomly split a dataset into a training part and a validation part.

  Args:
    train_dataset_all: Dataset (anything supporting ``len()`` and indexing).
    train_fraction: Share of the samples assigned to the training part; the
      training size is ``int(len(dataset) * train_fraction)`` and the
      validation part receives the remainder.
    seed: Seed of the generator that drives the split, so repeated calls on
      the same dataset give the same partition.

  Returns:
    ``(train_dataset, val_dataset)`` as produced by
    ``torch.utils.data.random_split``.

  Raises:
    ValueError: If ``train_fraction`` is outside ``[0, 1]``.
  """
  if not 0.0 <= train_fraction <= 1.0:
    raise ValueError('train_fraction must be in [0, 1], got {}'.format(train_fraction))

  total = len(train_dataset_all)
  train_size = int(total * train_fraction)
  val_size = total - train_size
  train_dataset, val_dataset = torch.utils.data.random_split(
      train_dataset_all,
      [train_size, val_size],
      generator=torch.Generator().manual_seed(seed))

  return train_dataset, val_dataset

# Seeded train/validation split of the UBIRIS v2 training folder.
UBIRIS_v2_seg_train_dataset, UBIRIS_v2_seg_val_dataset = split_train_dataset(UBIRIS_v2_seg_train)

In [None]:
# Number of samples in the training split.
len(UBIRIS_v2_seg_train_dataset)

In [None]:
# Number of samples in the validation split.
len(UBIRIS_v2_seg_val_dataset)

In [None]:
# Training augmentation pipeline. DatasetAfterSplit calls it with image= and
# mask=, so each transform receives the image/mask pair together.
train_transform = A.Compose(
    [
      # Geometric jitter; borders exposed by the warp are filled with 0.
      A.ShiftScaleRotate(shift_limit=0.12, scale_limit=0.15, rotate_limit=90, border_mode=cv2.BORDER_CONSTANT, value=0, p=0.5),
      A.HorizontalFlip(p=0.5),
     
      # Photometric jitter.
      A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, always_apply=False, p=0.5),
      A.RandomGamma(p=0.5),

      # With probability 0.5, apply exactly one of the two distortions.
      A.OneOf([
          A.ElasticTransform(alpha=52, sigma=31, alpha_affine=31,border_mode=cv2.BORDER_CONSTANT, value=0, p=1),
          A.OpticalDistortion(distort_limit=0.55, shift_limit=0.12,border_mode=cv2.BORDER_CONSTANT, value=0, p=1),
      ], p=0.5),
     
      # With probability 0.5, apply exactly one of the two degradations.
      A.OneOf([
          A.GaussNoise(p=1),
          A.Downscale(p=0.5),
      ], p=0.5),
      # Same per-channel mean/std as val_test_transform below (UBIRIS_v2_seg statistics).
      A.Normalize(mean=(0.30213674, 0.35349771, 0.49928186), std=(0.14017686, 0.14544913, 0.17964584)), 
    ]
)

# Validation / test pipeline: normalisation only, no augmentation.
val_test_transform = A.Compose(
    [
     # UBIRIS_v2_seg mean and std
     A.Normalize(mean=(0.30213674, 0.35349771, 0.49928186), std=(0.14017686, 0.14544913, 0.17964584)),

    ]
)

In [None]:
# Randomly pick the training samples that are ALSO added as augmented copies
# (300 of them, or every sample when the training split is smaller than that).
num_to_augment = min(300, len(UBIRIS_v2_seg_train_dataset))
UBIRIS_v2_seg_train_dataset_to_aug, _ = torch.utils.data.random_split(UBIRIS_v2_seg_train_dataset, 
                                                                      [num_to_augment, len(UBIRIS_v2_seg_train_dataset)-num_to_augment], 
                                                                      generator=torch.Generator().manual_seed(2022))
UBIRIS_v2_seg_train_dataset_aug1 = DatasetAfterSplit(UBIRIS_v2_seg_train_dataset_to_aug, transform=train_transform)

# The un-augmented samples only went through the resize so far, i.e. they are
# still raw pixel values. Normalise them with the same statistics as the
# augmented, validation and test samples so the model never sees a mix of
# raw and normalised inputs.
UBIRIS_v2_seg_train_dataset_normalized = DatasetAfterSplit(UBIRIS_v2_seg_train_dataset, transform=val_test_transform)

# Final training set = all normalised originals + the augmented copies.
UBIRIS_v2_seg_train_dataset_all = torch.utils.data.ConcatDataset([UBIRIS_v2_seg_train_dataset_normalized, UBIRIS_v2_seg_train_dataset_aug1])

print(len(UBIRIS_v2_seg_train_dataset_aug1))
print(len(UBIRIS_v2_seg_train_dataset_all))

In [None]:
# Validation and test samples are only normalised (no augmentation).
UBIRIS_v2_seg_val_dataset_all = DatasetAfterSplit(UBIRIS_v2_seg_val_dataset, transform=val_test_transform)
UBIRIS_v2_seg_test_dataset_all = DatasetAfterSplit(UBIRIS_v2_seg_test, transform=val_test_transform)
# Print the sizes of both evaluation sets.
print(len(UBIRIS_v2_seg_val_dataset_all))
print(len(UBIRIS_v2_seg_test_dataset_all))

In [None]:
# Pipeline that DatasetMaskTransform applies to the mask only (the mask is
# passed under the image= keyword): nearest-neighbour resize to 400x400,
# then conversion with ToTensorV2.
mask_resize_transform = A.Compose(
    [
     A.Resize(height=400, width=400, interpolation=cv2.INTER_NEAREST),
     ToTensorV2()
    ]
)


In [None]:
# Final wrapping step: ToTensorV2 on the image, mask_resize_transform on the mask.
# NOTE(review): each name is rebound to a wrapper around its previous value, so
# running this cell twice wraps the datasets twice; re-run the cells above first.
UBIRIS_v2_seg_train_dataset_all = DatasetMaskTransform(UBIRIS_v2_seg_train_dataset_all, transform=mask_resize_transform)
UBIRIS_v2_seg_val_dataset_all = DatasetMaskTransform(UBIRIS_v2_seg_val_dataset_all, transform=mask_resize_transform)
UBIRIS_v2_seg_test_dataset_all = DatasetMaskTransform(UBIRIS_v2_seg_test_dataset_all, transform=mask_resize_transform)

## Settings

Hugging Face checkpoints that can be used as `params['model']`:\
apple/deeplabv3-mobilevit-small\
apple/deeplabv3-mobilevit-x-small\
apple/deeplabv3-mobilevit-xx-small

In [None]:
# Run configuration shared by the cells below.
params = {
    # Hugging Face checkpoint the model is initialised from.
    'model': "apple/deeplabv3-mobilevit-small",
    # First CUDA device when available, otherwise the CPU.
    'device': "cuda:0" if torch.cuda.is_available() else "cpu",
    # DataLoader settings.
    'batch_size': 16,
    'num_workers': 0,
    # NOTE(review): 'lr' is not read anywhere below; train_validate_test builds
    # AdamW with its default arguments.
    'lr': 0.0001,
    # Number of train + validate passes in train_validate_test.
    'epochs': 5,
    # Label metadata copied onto the model config.
    'num_labels': 1,
    'label2id':{
      'iris': 1,
    },
    'id2label':{
      '1': 'iris',
    },
}

## DataLoader


In [None]:
# Training batches are reshuffled every epoch.
train_loader = DataLoader(UBIRIS_v2_seg_train_dataset_all,
                          batch_size=params['batch_size'],
                          shuffle=True,
                          num_workers=params['num_workers'],
                          pin_memory=True)

# Evaluation loaders keep a fixed order: shuffling brings no benefit when no
# gradient step is taken, and a stable order keeps the per-batch TensorBoard
# images and the indices of the `predict` output comparable between runs.
val_loader = DataLoader(UBIRIS_v2_seg_val_dataset_all,
                          batch_size=params['batch_size'],
                          shuffle=False,
                          num_workers=params['num_workers'],
                          pin_memory=True)

test_loader = DataLoader(UBIRIS_v2_seg_test_dataset_all,
                         batch_size=params['batch_size'],
                         shuffle=False,
                         num_workers=params['num_workers'],
                         pin_memory=True)

## Fine tuning models

In [None]:
# Load the pretrained Hugging Face MobileViT semantic-segmentation model
# named by params['model'].
IrisViT_seg_model = MobileViTForSemanticSegmentation.from_pretrained(params['model'])

In [None]:
# del pytorch_model
# del trainer
# del IrisViT_seg_model
# torch.cuda.empty_cache()

In [None]:
# Replace the classifier's final convolution with a 256 -> 1 channel 1x1
# convolution followed by nearest-neighbour upsampling of the logits to 640x640.
IrisViT_seg_model.segmentation_head.classifier.convolution = nn.Sequential(nn.Conv2d(256, 1, kernel_size=(1,1), stride=(1,1)),
                                                                           nn.UpsamplingNearest2d((640, 640)))

# Disabled experiment: initialise the classifier weights to 0.
# IrisViT_seg_model.segmentation_head.classifier.convolution.weight.data.fill_(0.0)
# IrisViT_seg_model.segmentation_head.classifier.convolution.bias.data.fill_(0.0)
# Overwrite the label metadata on the model config with the values from params.
IrisViT_seg_model.config.id2label = params['id2label']
IrisViT_seg_model.config.label2id = params['label2id']
IrisViT_seg_model.config.num_labels = params['num_labels']


# Load a previously saved state dict (the "all_datasets" checkpoint) into the
# modified model, mapping the tensors onto the configured device.
model_save_path = r"./models/segmentation/all_datasets/IrisViT_seg_model_small.pth"
IrisViT_seg_model.load_state_dict(torch.load(model_save_path, map_location=params['device']))


# Switch the output size to 400x400, the size the masks are resized to in this notebook.
IrisViT_seg_model.segmentation_head.classifier.convolution[1] = nn.UpsamplingNearest2d((400, 400))

# Move the model to the configured device (GPU when available).
IrisViT_seg_model.to(params['device'])

In [None]:
# print(IrisViT_seg_model)
# print(IrisViT_seg_model.config)
# print(IrisViT_seg_model.segmentation_head.classifier.convolution.weight)
# print(IrisViT_seg_model.segmentation_head.classifier.convolution.bias)
# print(IrisViT_seg_model.segmentation_head.classifier.convolution.weight.shape)
# print(IrisViT_seg_model.segmentation_head.classifier.convolution.bias.shape)

In [None]:
# for name, param in IrisViT_seg_model.named_parameters():
#   print(name)

## Loss Function

In [None]:
# Copy from https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook
class DiceBCELoss(nn.Module):
    """Sum of binary cross-entropy and soft Dice loss, computed from raw logits.

    ``weight`` and ``size_average`` are accepted only to keep the constructor
    signature; they are not used.
    """

    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``BCE + (1 - Dice)`` over all elements of the batch.

        Args:
            inputs: Raw (pre-sigmoid) logits.
            targets: Float tensor with the same number of elements as
                ``inputs``, holding the target probabilities (0 or 1).
            smooth: Constant added to the Dice numerator and denominator so
                the ratio stays defined when both masks are empty.

        Returns:
            Scalar loss tensor.
        """
        # reshape (unlike view) also accepts non-contiguous tensors.
        logits = inputs.reshape(-1)
        targets = targets.reshape(-1)

        probabilities = torch.sigmoid(logits)

        intersection = (probabilities * targets).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (probabilities.sum() + targets.sum() + smooth)

        # BCE straight from the logits: same value as BCE on the sigmoid
        # output, but it stays finite and keeps a useful gradient when the
        # sigmoid saturates to exactly 0 or 1.
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='mean')

        return bce + dice_loss

## Segmentation Evaluation

In [None]:
def get_true_false_positive_negative(pred_mask, true_mask):
    """Confusion-matrix cells of one 2-D mask pair, as fractions of all pixels.

    Values greater than 0 count as foreground in both masks.

    Returns:
        Dict with the keys 'true_positive', 'false_positive', 'true_negative'
        and 'false_negative'; each value is a 0-dim tensor and the four sum to 1.
    """
    h,w = true_mask.size()
    num_pixel = h*w

    pred_mask = pred_mask>0
    true_mask = true_mask>0

    true_positive = (true_mask & pred_mask).sum()
    false_positive = (~true_mask & pred_mask).sum()
    true_negative = (~(true_mask | pred_mask)).sum()
    false_negative = (true_mask & (~pred_mask)).sum()

    return {'true_positive': true_positive/num_pixel,
            'false_positive': false_positive/num_pixel,
            'true_negative': true_negative/num_pixel,        
            'false_negative': false_negative/num_pixel}


def _ratio_or_default(numerator, denominator, default):
    """Return numerator/denominator, or `default` (as a tensor) when the denominator is 0."""
    if denominator == 0:
        return torch.full_like(numerator, float(default))
    return numerator / denominator


def _mean_and_var(values):
    """Mean and sample variance of per-image scores as 0-dim CPU float tensors.

    The sample variance of a single value is undefined (NaN); it is reported
    as 0 so a final batch of size one does not put NaN into the logs.
    """
    scores = torch.tensor([float(v) for v in values], dtype=torch.float32)
    mean = torch.mean(scores)
    var = torch.zeros(()) if scores.numel() == 1 else torch.var(scores)
    return mean, var


# [0,1] The lower the better 0 is the best, 1 is the worst
def get_E1(batch_size, pred_masks, true_masks):
    """Batch mean of the fraction of misclassified pixels (fp + fn) per image."""
    e1_sum = 0
    for i in range(batch_size):
      tpfn = get_true_false_positive_negative(pred_masks[i], true_masks[i])
      fp, fn = tpfn['false_positive'], tpfn['false_negative']
      e1_sum += (fp+fn)

    return e1_sum/batch_size

# [0,1] The lower the better, 0 is the best, 1 is the worst
def get_E2(batch_size, pred_masks, true_masks):
    """Batch mean of 0.5 * (fp + fn) per image.

    NOTE(review): as written this is exactly half of E1, because fp and fn are
    fractions of all pixels; confirm this is the intended E2 definition.
    """
    e2_sum = 0
    for i in range(batch_size):
        tpfn = get_true_false_positive_negative(pred_masks[i], true_masks[i])
        fp, fn = tpfn['false_positive'], tpfn['false_negative']
        e2_sum += 0.5*(fp+fn)

    return e2_sum/batch_size

# [0,1] 0 is the worst, 1 is the best, mean higher best, var lower best
def get_Precision(batch_size, pred_masks, true_masks):
    """Per-image precision tp / (tp + fp), plus its batch mean and variance.

    When an image has no predicted foreground the ratio is 0/0; it is then
    scored 1 if the ground truth is empty as well (same convention as
    get_mIoU) and 0 otherwise, instead of NaN.

    Returns:
        (precision_list, precision_mean, precision_var)
    """
    precision_list=[]

    for i in range(batch_size):
        tpfn = get_true_false_positive_negative(pred_masks[i], true_masks[i])
        tp, fp, fn = tpfn['true_positive'], tpfn['false_positive'], tpfn['false_negative']
        precision_list.append(_ratio_or_default(tp, tp+fp, 1.0 if fn == 0 else 0.0))

    precision_mean, precision_var = _mean_and_var(precision_list)

    return precision_list, precision_mean, precision_var

# [0,1] 0 is the worst, 1 is the best, mean higher best, var lower best
def get_Recall(batch_size, pred_masks, true_masks):
    """Per-image recall tp / (tp + fn), plus its batch mean and variance.

    When an image has no ground-truth foreground the ratio is 0/0; it is then
    scored 1 if the prediction is empty as well and 0 otherwise, instead of NaN.

    Returns:
        (recall_list, recall_mean, recall_var)
    """
    recall_list = []
    for i in range(batch_size):
        tpfn = get_true_false_positive_negative(pred_masks[i], true_masks[i])
        tp, fp, fn = tpfn['true_positive'], tpfn['false_positive'], tpfn['false_negative']
        recall_list.append(_ratio_or_default(tp, tp+fn, 1.0 if fp == 0 else 0.0))

    recall_mean, recall_var = _mean_and_var(recall_list)

    return recall_list, recall_mean, recall_var

# [0,1] 0 is the worst, 1 is the best, mean higher best, var lower best
def get_F1(batch_size, pred_masks, true_masks):
    """Per-image F1 = 2PR / (P + R), plus its batch mean and variance.

    An image whose precision and recall are both 0 gets F1 = 0 instead of NaN.

    Returns:
        (f1_list, f1_mean, f1_var)
    """
    f1_list = []
    precision_list, _, _, = get_Precision(batch_size, pred_masks, true_masks)
    recall_list, _, _, = get_Recall(batch_size, pred_masks, true_masks)
    for i in range(batch_size):
        precision, recall = precision_list[i], recall_list[i]
        f1_list.append(_ratio_or_default(2*recall*precision, recall+precision, 0.0))

    f1_mean, f1_var = _mean_and_var(f1_list)

    return f1_list, f1_mean, f1_var


# [0,1] 0 is the worst, 1 is the best The higher the better
def get_mIoU(batch_size, pred_masks, true_masks):
    """Batch mean of the foreground IoU tp / (tp + fn + fp) per image.

    An image where prediction and ground truth are both empty scores 1. The
    result is always a tensor, so callers can rely on ``.item()``.
    """
    miou_sum = 0
    for i in range(batch_size):
        tfpn = get_true_false_positive_negative(pred_masks[i], true_masks[i])
        tp, fp, fn = tfpn['true_positive'], tfpn['false_positive'], tfpn['false_negative']
        miou_sum += _ratio_or_default(tp, tp+fn+fp, 1.0)

    return miou_sum/batch_size

def evaluate_segmentation(pred_masks, true_masks):
    """Compute all segmentation metrics for one batch.

    Args:
        pred_masks: Predicted masks, indexed along the first dimension; each
            entry must be 2-D.
        true_masks: Ground-truth masks with the same layout.

    Returns:
        Dict of 0-dim tensors with the keys 'E1', 'mIoU', 'E2',
        'Precision_Mean', 'Precision_Var', 'Recall_Mean', 'Recall_Var',
        'F1_Mean' and 'F1_Var'.
    """
    batch_size = true_masks.size()[0]

    e1 = get_E1(batch_size,  pred_masks, true_masks)
    miou = get_mIoU(batch_size, pred_masks, true_masks)

    e2 = get_E2(batch_size,  pred_masks, true_masks)
    _, precision_mean, precision_var = get_Precision(batch_size, pred_masks, true_masks)
    _, recall_mean, recall_var = get_Recall(batch_size, pred_masks, true_masks)
    _, f1_mean, f1_var = get_F1(batch_size, pred_masks, true_masks)

    return {'E1': e1, 
            'mIoU': miou,
            'E2': e2,
            'Precision_Mean': precision_mean,
            'Precision_Var': precision_var, 
            'Recall_Mean': recall_mean,
            'Recall_Var': recall_var, 
            'F1_Mean': f1_mean,
            'F1_Var': f1_var}

## Train, Validation, Test

In [None]:
class MetricMonitor:
    """Running averages of named metrics, printable as one summary line.

    ``metrics`` maps each metric name to a dict with the running sum
    (``"val"``), the number of updates (``"count"``) and their ratio (``"avg"``).
    """

    def __init__(self, float_precision=5):
        # Number of decimals used when the averages are rendered by __str__.
        self.float_precision = float_precision
        self.reset()

    def reset(self):
        """Forget every metric recorded so far."""
        def _fresh_entry():
            return {"val": 0, "count": 0, "avg": 0}

        self.metrics = defaultdict(_fresh_entry)

    def update(self, metric_name, val):
        """Add one observation of ``metric_name`` and refresh its average."""
        entry = self.metrics[metric_name]

        entry["val"] += val
        entry["count"] += 1
        entry["avg"] = entry["val"] / entry["count"]

    def __str__(self):
        rendered = []
        for name, entry in self.metrics.items():
            rendered.append(f"{name}: {entry['avg']:.{self.float_precision}f}")
        return " | ".join(rendered)

In [None]:
# Per-phase naming: (TensorBoard tag and log sub-directory, scalar-key prefix,
# label shown in the progress bar).
_PHASE_NAMES = {
    'train': ('train', 'Train', 'Train'),
    'validate': ('validate', 'Validate', 'Validation'),
    'test': ('test', 'Test', 'Test'),
}
_TENSORBOARD_ROOT = './runs/UBIRIS_v2_seg/deeplabv3-mobilevit-small'


def _run_epoch(phase, loader, model, criterion, epoch, params, optimizer=None):
    """Run one pass over `loader`, logging loss, metrics and predicted masks.

    The model is put in training mode and updated after every batch when an
    `optimizer` is given; otherwise it runs in eval mode without gradients.

    Scalars are logged at step `epoch * len(loader) + batch_index`, so every
    batch gets its own point on the TensorBoard x-axis. The predicted-mask
    images keep `global_step=epoch`; their tag already names epoch and batch.
    The writer is closed when the pass ends, which flushes the event files.
    """
    tag, key_prefix, label = _PHASE_NAMES[phase]
    training = optimizer is not None

    metric_monitor = MetricMonitor()
    model.train(training)
    stream = tqdm(loader)
    num_batches = len(loader)

    with SummaryWriter('{}/{}'.format(_TENSORBOARD_ROOT, tag)) as writer:
        for i, (images, masks) in enumerate(stream, start=0):
            images = images.float().to(params['device'], non_blocking=True)
            masks = masks.float().to(params['device'], non_blocking=True)

            with torch.set_grad_enabled(training):
                logits = model(images).logits
                loss = criterion(logits, masks)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            with torch.no_grad():
                # Threshold the sigmoid probabilities at 0.5; kept in NCHW for add_images.
                predicted_masks = (torch.sigmoid(logits) >= 0.5).float()
                evaluations = evaluate_segmentation(predicted_masks.squeeze(1), masks.squeeze(1))

            # Record all masks
            writer.add_images('{tag}_predicted_iris_masks/Epoch {epoch}, Batch {batch}'.format(tag=tag, epoch=epoch, batch=i),
                              predicted_masks, global_step=epoch, walltime=None, dataformats='NCHW')

            step = epoch * num_batches + i
            writer.add_scalar('{}/Loss'.format(tag), loss.item(), global_step=step)
            for name in ('E1', 'E2', 'mIoU'):
                # float() accepts both 0-dim tensors and plain Python numbers.
                writer.add_scalar('{}/{}'.format(tag, name), float(evaluations[name]), global_step=step)
            for name in ('Precision', 'Recall', 'F1'):
                writer.add_scalars('{}/{}_Mean_and_Var'.format(tag, name),
                                   {'{}_{}_Mean'.format(key_prefix, name): float(evaluations[name + '_Mean']),
                                    '{}_{}_Var'.format(key_prefix, name): float(evaluations[name + '_Var'])},
                                   global_step=step)

            metric_monitor.update("Loss", loss.item())
            metric_monitor.update("E1", float(evaluations['E1']))
            metric_monitor.update("mIoU", float(evaluations['mIoU']))

            stream.set_description(
                "Epoch: {epoch}. {label}. {metric_monitor}".format(epoch=epoch, label=label, metric_monitor=metric_monitor)
            )


def validate(val_loader, model, criterion, epoch, params):
    """Evaluate `model` on `val_loader` (eval mode, no gradients) and log under 'validate'."""
    _run_epoch('validate', val_loader, model, criterion, epoch, params)

def test(test_loader, model, criterion, epoch, params):
    """Evaluate `model` on `test_loader` (eval mode, no gradients) and log under 'test'."""
    _run_epoch('test', test_loader, model, criterion, epoch, params)

def train(train_loader, model, criterion, optimizer, epoch, params):
    """Train `model` for one epoch on `train_loader` and log under 'train'."""
    _run_epoch('train', train_loader, model, criterion, epoch, params, optimizer=optimizer)

In [None]:
def train_validate_test(model, train_loader, val_loader, params, test_loader=None):
    """Fine-tune `model`, validating after every epoch, then run one test pass.

    Args:
        model: Model to train; it is updated in place and also returned.
        train_loader: Loader passed to `train` every epoch.
        val_loader: Loader passed to `validate` every epoch.
        params: Settings dict; 'device' and 'epochs' are read here.
        test_loader: Loader for the final `test` pass. When omitted, the
            notebook-level `test_loader` variable is used, as before.

    Returns:
        The trained model.

    Raises:
        NameError: If no `test_loader` is passed and none is defined at
            notebook level. This is raised before training starts rather
            than after the last epoch.
    """
    if test_loader is None:
        test_loader = globals().get('test_loader')
        if test_loader is None:
            raise NameError("name 'test_loader' is not defined; pass test_loader explicitly")

    criterion = DiceBCELoss().to(params['device'])

    # All parameters are trained. NOTE(review): AdamW runs with its default
    # arguments; params['lr'] is not applied here.
    optimizer = optim.AdamW(model.parameters())

    for epoch in range(params["epochs"]):
        train(train_loader, model, criterion, optimizer, epoch, params)
        validate(val_loader, model, criterion, epoch, params)

    # Single test pass after training, logged as epoch 0.
    test(test_loader, model, criterion, 0, params)
    return model

In [None]:
# Run the full train / validate / test cycle; the returned model is the same object, trained in place.
IrisViT_seg_model_trained = train_validate_test(IrisViT_seg_model, train_loader=train_loader, val_loader=val_loader, params=params)

In [None]:
def predict(model, params, test_loader):
    """Predict a binary mask for every sample produced by `test_loader`.

    Args:
        model: Model whose output exposes a `.logits` tensor with a
            channel dimension of size 1 at position 1.
        params: Settings dict; only 'device' is read.
        test_loader: Iterable of `(images, masks)` batches.

    Returns:
        List with one `(predicted_mask, original_mask)` pair of NumPy arrays
        per sample. `predicted_mask` has the channel dimension removed and
        holds 0.0 / 1.0; `original_mask` is the sample's mask exactly as
        batched, so its shape does not depend on the batch size.
    """
    model.eval()
    predictions = []
    with torch.no_grad():
        for images, masks in test_loader:
            # Cast to float like train/validate/test do before the forward pass.
            images = images.float().to(params["device"], non_blocking=True)
            logits = model(images).logits
            probabilities = torch.sigmoid(logits.squeeze(1))
            predicted_masks = (probabilities >= 0.5).float().cpu().numpy()
            # Iterate over the batch dimension directly: squeezing dim 0 would
            # drop it for a batch of one and yield differently shaped masks.
            predictions.extend(zip(predicted_masks, masks.cpu().numpy()))
    return predictions

In [None]:
# Collect (predicted mask, original mask) pairs for the whole test loader.
predictions = predict(IrisViT_seg_model_trained, params, test_loader=test_loader)

In [None]:
def visualize(image, mask):
    """Show `image` (left) and `mask` (right) side by side without axes."""
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5,5))
    for panel, picture in zip(axes, (image, mask)):
        panel.axis('off')
        panel.imshow(picture)


In [None]:
# Show the first prediction next to its ground-truth mask; squeeze(0) drops
# the mask's leading size-1 dimension so imshow receives a 2-D array.
predicted_mask, masks = predictions[0]
visualize(predicted_mask, masks.squeeze(0))

In [None]:
# Load the TensorBoard notebook extension and open the training logs.
%load_ext tensorboard
%tensorboard --logdir './runs/UBIRIS_v2_seg/deeplabv3-mobilevit-small/train'

In [None]:
# Open the validation logs (needs the tensorboard extension loaded above).
%tensorboard --logdir './runs/UBIRIS_v2_seg/deeplabv3-mobilevit-small/validate'

In [None]:
# Open the test logs (needs the tensorboard extension loaded above).
%tensorboard --logdir './runs/UBIRIS_v2_seg/deeplabv3-mobilevit-small/test'

In [None]:
# Save the fine-tuned weights (state dict only) for the UBIRIS v2 run.
model_save_path = r"./models/segmentation/UBIRIS_v2_seg/IrisViT_seg_model_small.pth"
# torch.save does not create missing parent folders, so make sure the target
# directory exists instead of losing the trained weights to a failed write.
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
torch.save(IrisViT_seg_model_trained.state_dict(), model_save_path)