# Notebook for UNet training and testing.

- If you want to test synthetic data with our trained model, download `revised_model_epoch_29.pth` from OneDrive to the `/evaluation/segmentation/` folder.
- If you want to train UNet on CholecSeg8k data, download the CholecSeg8k dataset from OneDrive to the `evaluation/segmentation/cholecseg8k` folder.

## Setup

In [1]:
# install dependencies
# !pip install numpy
# !pip install matplotlib
# !pip install torch
# !pip install torchvision
# !pip install tqdm
# !pip install ipywidgets

In [2]:
import torch
import torch.nn as nn
import transforms as T
import cv2
import random
import torchvision
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import glob
import os
from tqdm.notebook import tqdm
import json
import numpy as np
import shutil
from PIL import Image, ImageColor
import io
import torchvision.transforms as transforms

# Restrict this process to the GPU with index 1. The variable has to be exported
# BEFORE the first CUDA query below: once the CUDA runtime has been initialised,
# later changes to CUDA_VISIBLE_DEVICES are no longer picked up.
# On a single-GPU machine change the value to "0" or remove the line.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NUM_EPOCHS = 30
DO_TRAINING = False  # False: load FINAL_MODEL_PATH and only run the evaluation branch
FINAL_MODEL_PATH = "revised_model_epoch_29.pth"
MODEL_NAME = "fcn_resnet50"  # looked up by name in torchvision.models.segmentation

# Defining a color used to depict each semantic class being segmented
META_DATA_ORIGINAL = [
    ("black_background", (0,0,0)),
    ("abdominal_wall", (33, 191, 197)),
    ("liver", (231, 126, 9)),
    ("gastrointestinal_tract", (209, 53, 84)),
    ("fat", (80, 155, 4)),
    ("grasper", (255, 207, 210)),
    ("connective_tissue", (169, 52, 199)),
    ("blood", (229, 18, 18)),
    ("cystic_duct", (149, 50, 18)),
    ("l-hook_electrocautery", (46, 43, 180)),
    ("gallbladder", (148, 55, 66)),
    ("hepatic_vein", (214, 51, 149)),
    ("liver_ligament", (240, 79, 10)),
]

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# THESE 2 LINES REDUCE / MERGE CLASSES.
# To drop a class, map it to 255 (the value the loss and the confusion matrix ignore).
# To merge a class into another one, map it to the desired target class.
# The remaining class ids have to be contiguous, from 0 to N-1.

CLASSES_TO_IGNORE = ["black_background","gastrointestinal_tract", "connective_tissue", "blood", "cystic_duct", "l-hook_electrocautery","hepatic_vein", "liver_ligament"]
# raw mask value -> training label id.
# NOTE(review): keys 13-16 have no entry in META_DATA_ORIGINAL; presumably extra
# instrument labels present in the synthetic masks - confirm against the preprocessing step.
REPLACE_CLASS = {0:255, 1:0, 2:1,3:255,4:2, 5:3,6:255,7:255, 8:255, 9:3, 10:4,11:255, 12:255, 13:3, 14:3, 15:3, 16:3}
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Classes that are actually predicted; its length is used as the number of model outputs.
META_DATA = [x for x in META_DATA_ORIGINAL if x[0] not in CLASSES_TO_IGNORE]


# Optimizer parameters
learning_rate = 0.00125
momentum = 0.9
power = 0.9  # exponent of the polynomial learning-rate decay
weight_decay = 1e-4

## Helper functions and classes

Define some reusable functions and classes that we will use throughout this notebook.

In [3]:
def cat_list(images, fill_value=0):
    """Stack tensors of possibly different sizes into one padded batch tensor.

    Every dimension of the result is as large as the largest input along that
    dimension. Each input is copied into the leading corner of its slot along
    the last two dimensions; the rest of the slot keeps ``fill_value``.

    Args:
        images: non-empty sequence of tensors (assumed to share the same number
            of dimensions, at least two).
        fill_value: value used for the padding.

    Returns:
        Tensor of shape ``(len(images), *max_shape)`` with the dtype and device
        of ``images[0]``.
    """
    # Per-dimension maximum over all inputs = shape every sample is padded to.
    padded_shape = tuple(max(dims) for dims in zip(*(img.shape for img in images)))
    batch = images[0].new_full((len(images), *padded_shape), fill_value)
    for idx, img in enumerate(images):
        rows, cols = img.shape[-2], img.shape[-1]
        batch[idx][..., :rows, :cols].copy_(img)
    return batch

def collate_fn(batch):
    """Collate ``(image, target)`` samples into two padded batch tensors.

    Images are padded with 0 and targets with 255, the label value that the
    loss and the confusion matrix in this notebook ignore.
    """
    images, targets = zip(*batch)
    return cat_list(images, fill_value=0), cat_list(targets, fill_value=255)

def criterion(inputs, target):
    """Cross-entropy loss between the model outputs and the ground truth.

    Args:
        inputs: dict of logits keyed by head name; must contain ``"out"`` and,
            when it has more than one entry, ``"aux"``.
        target: ground-truth labels; pixels labelled 255 are ignored.

    Returns:
        The ``"out"`` loss alone when it is the only head, otherwise
        ``out + 0.5 * aux``.
    """
    per_head = {
        head: nn.functional.cross_entropy(logits, target, ignore_index=255)
        for head, logits in inputs.items()
    }
    main_loss = per_head["out"]
    if len(per_head) == 1:
        return main_loss
    return main_loss + 0.5 * per_head["aux"]


# Helper class to compute relevant metrics using a confusion matrix
# see: https://en.wikipedia.org/wiki/Confusion_matrix
class ConfusionMatrix:
    """Accumulates a ``num_classes x num_classes`` confusion matrix.

    Rows index the ground-truth class, columns the predicted class.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # Allocated lazily on the first update so it lives on the same device as the data.
        self.mat = None

    def update(self, a, b):
        """Add one batch of flattened labels.

        Args:
            a: 1-D tensor of ground-truth labels. Entries outside
                ``[0, num_classes)`` (e.g. the ignore label 255) are skipped.
            b: 1-D tensor of predicted labels, same length as ``a``.
        """
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.no_grad():
            k = (a >= 0) & (a < n)
            # Encode each (truth, prediction) pair as one index into the flattened matrix.
            inds = n * a[k].to(torch.int64) + b[k]
            self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)

    def reset(self):
        """Zero the accumulated counts; a no-op when nothing has been accumulated yet."""
        if self.mat is not None:
            self.mat.zero_()

    def compute(self):
        """Return ``(global accuracy, per-class accuracy, per-class IoU)``.

        A class without any ground-truth pixel gets a NaN accuracy, and a class
        that is neither present nor predicted gets a NaN IoU (0 / 0).
        """
        h = self.mat.float()
        acc_global = torch.diag(h).sum() / h.sum()
        acc = torch.diag(h) / h.sum(1)
        iou = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iou
    
    # Return overall accuracy, per-class accuracy, per-class Intersection over Union (IoU) and mean IoU
    def __str__(self):
        acc_global, acc, iou = self.compute()
        return ("global correct: {:.2f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.2f}").format(
            acc_global.item() * 100,
            [f"{i:.1f}" for i in (acc * 100).tolist()],
            [f"{i:.1f}" for i in (iou * 100).tolist()],
            iou.mean().item() * 100,
        )
    

## CholecSeg8k
The CholecSeg8k dataset [1] consists of a subset of Cholec80 [2] annotated with semantic segmentation labels (13 semantic classes) for 17 video clips.


1. _Hong, W-Y., C-L. Kao, Y-H. Kuo, J-R. Wang, W-L. Chang, and C-S. Shih. "CholecSeg8k: A Semantic Segmentation Dataset for Laparoscopic Cholecystectomy Based on Cholec80." arXiv preprint arXiv:2012.12453 (2020)._

2. _Twinanda, Andru P., Sherif Shehata, Didier Mutter, Jacques Marescaux, Michel De Mathelin, and Nicolas Padoy. "Endonet: a deep architecture for recognition tasks on laparoscopic videos." IEEE transactions on medical imaging 36, no. 1 (2016): 86-97._

## Dataset class

In [4]:
# We define a dataset class that delivers images and the corresponding ground truth segmentation masks
# from CholecSeg8k. Please refer to Lecture 6 for more info on torch Datasets.
def map_values(x):
    """Translate one raw label through REPLACE_CLASS; unknown labels pass through unchanged."""
    if x in REPLACE_CLASS:
        return REPLACE_CLASS[x]
    return x

class CholecDatasetSegm(torch.utils.data.Dataset):
    """Dataset of (image, segmentation mask) pairs listed in a JSON file.

    The JSON file holds a list of records with a ``"file_name"`` and a
    ``"mask_name"`` entry; both are joined onto ``root_dir``.

    Args:
        gt_json: path of the JSON file listing the samples.
        meta_data: class metadata; stored on the instance, not used by the class itself.
        root_dir: directory prepended to every image / mask path.
        data_split: name of the split; stored on the instance only.
        transforms: optional callable ``(img, target) -> (img, target)``.
    """

    def __init__(self, gt_json, meta_data, root_dir = "./cholecseg8k", data_split = "train", transforms = None):
        self.gt_json = gt_json
        self.root_dir = root_dir
        self.data_split = data_split
        self.transforms = transforms
        # Context manager so the file handle is closed deterministically.
        with open(gt_json) as gt_file:
            gt_data = json.load(gt_file)
        self.images = [os.path.join(self.root_dir, g["file_name"]) for g in gt_data]
        self.targets = [os.path.join(self.root_dir, g["mask_name"]) for g in gt_data]
        self.metadata = meta_data
    
    def __len__(self):
        return len(self.images)

    @staticmethod
    def _remap_labels(labels):
        """Apply REPLACE_CLASS to a uint8 label array with a single table lookup."""
        # Identity table for every possible 8-bit label, overridden by REPLACE_CLASS.
        # Built per call so that edits to REPLACE_CLASS take effect without recreating the dataset.
        lookup = np.arange(256, dtype=np.int64)
        for src, dst in REPLACE_CLASS.items():
            if 0 <= src < 256:
                lookup[src] = dst
        return lookup[labels]
    
    def __getitem__(self, index: int):
        """Load sample ``index`` and return ``(img, target)`` after the optional transforms."""
        img = Image.open(self.images[index]).convert("RGB")
        target = Image.open(self.targets[index]).convert("L")
        # Nearest-neighbour resampling keeps the mask values valid label ids.
        target = target.resize(img.size, resample=Image.NEAREST)
        
        if CLASSES_TO_IGNORE:
            # Vectorised remapping instead of a Python call per pixel.
            target = self._remap_labels(np.array(target))

            target = Image.fromarray(target.astype(np.uint8), mode='L')
        if self.transforms is not None:
            img, target = self.transforms(img, target)        
        return img, target

In [5]:
class SegmentationPresetTrain:
    """Training-time transform pipeline applied jointly to an image and its mask.

    Steps, in order: random resize between ``0.5 * base_size`` and
    ``2.0 * base_size``, optional random horizontal flip, random crop to
    ``crop_size``, conversion to a float tensor and normalisation with
    ``mean`` / ``std``. The flip step is left out when ``hflip_prob`` is not
    positive.
    """

    def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        smallest = int(0.5 * base_size)
        largest = int(2.0 * base_size)

        steps = [T.RandomResize(smallest, largest)]
        if hflip_prob > 0:
            steps += [T.RandomHorizontalFlip(hflip_prob)]
        steps += [
            T.RandomCrop(crop_size),
            T.PILToTensor(),
            T.ConvertImageDtype(torch.float),
            T.Normalize(mean=mean, std=std),
        ]
        self.transforms = T.Compose(steps)

    def __call__(self, img, target):
        """Run the composed pipeline on ``(img, target)`` and return its result."""
        return self.transforms(img, target)


class SegmentationPresetEval:
    """Evaluation-time transform pipeline applied jointly to an image and its mask.

    Resizes with identical lower and upper bound ``base_size``, converts to a
    float tensor and normalises with ``mean`` / ``std``. No flip or crop is
    applied.
    """

    def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        steps = [
            T.RandomResize(base_size, base_size),
            T.PILToTensor(),
            T.ConvertImageDtype(torch.float),
            T.Normalize(mean=mean, std=std),
        ]
        self.transforms = T.Compose(steps)

    def __call__(self, img, target):
        """Run the composed pipeline on ``(img, target)`` and return its result."""
        return self.transforms(img, target)

In [6]:
# Defining Data Loaders for the training and testing splits.
# Please refer to Lecture 6 for more info on torch Data Loaders.


def get_transform(train=True):
    """Return the training preset (base 512, crop 400) or the evaluation preset (base 400)."""
    if not train:
        return SegmentationPresetEval(base_size=400)
    return SegmentationPresetTrain(base_size=512, crop_size=400)
    
# Train loader
# NOTE: this cell runs regardless of DO_TRAINING, so both JSON files below must exist
# even when the notebook is only used for evaluation.
dataset = CholecDatasetSegm("./cholecseg8k/train_final.json", META_DATA, data_split="train", transforms=get_transform())
# Number of model outputs = number of classes kept after dropping CLASSES_TO_IGNORE.
num_classes = len(META_DATA)
train_sampler = torch.utils.data.RandomSampler(dataset)
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    sampler=train_sampler,
    collate_fn=collate_fn,
    # Drop the last incomplete batch; len(data_loader) is later used as iterations per epoch.
    drop_last=True,
)

# Test loader
# Evaluation preset (no random crop / flip), one image per batch, fixed order.
dataset_test = CholecDatasetSegm("./cholecseg8k/val_final.json", META_DATA, data_split="val", transforms=get_transform(False))
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, sampler=test_sampler, collate_fn=collate_fn)

## CREATE IRCAD TEST DATASET

In [20]:
# mIOU for mixed style

data_path_25_66 = '/path/to/data' #output_vid25_66
data_path_01_49 = '/path/to/data' #output_vid01_49
data_path_52_56 = '/path/to/data' #output_vid52_56

data_path_seg = '/path/to/segmentation_maps' # Segmentation maps need to have CholecSeg classes. They were generated in the preprocessing step

# get list of files
rand_dict = {1: data_path_25_66,
            2: data_path_01_49,
            3: data_path_52_56}

# Seed before the first use of `random` so that both the style choice and the
# sampled subset below are reproducible.
random.seed(420420)

# Only frames that exist in all three style folders are candidates.
filenames_common =set(os.listdir(rand_dict[1]))& set(os.listdir(rand_dict[2])) & set(os.listdir(rand_dict[3]))
# Iterate in sorted order: the iteration order of a set of strings changes between
# interpreter runs (hash randomisation), which would make the seeded sample below
# pick different files on every kernel restart.
# NOTE(review): randint(1, 1) always selects style 1 (data_path_25_66); for a truly
# mixed-style test set this presumably should be randint(1, 3) - confirm before changing,
# since it alters which images are evaluated.
filenames = [os.path.join(rand_dict[random.randint(1, 1)], f) for f in sorted(filenames_common)]

filenames = random.sample(filenames,10)

# get list of corresponding segmentation masks
masks = [os.path.join(data_path_seg, os.path.basename(f)) for f in filenames]

# Save json with test files + masks
data = []
for f, m in zip(filenames, masks):
    data.append({'file_name': f, 'mask_name': m})
with open('test_synthetic.json', 'w') as f:
    json.dump(data, f)

#Test IRCAD
# root_dir='' because the JSON written above already stores full paths.
transforms_ircad = SegmentationPresetEval(base_size=512)
dataset_test_ircad = CholecDatasetSegm("test_synthetic.json", META_DATA, root_dir='',data_split="val", transforms=transforms_ircad)
test_sampler_ircad = torch.utils.data.SequentialSampler(dataset_test_ircad)
data_loader_test_ircad = torch.utils.data.DataLoader(dataset_test_ircad, batch_size=1, sampler=test_sampler_ircad, collate_fn=collate_fn)

## Segmentation model

In [22]:
# Build the segmentation network named by MODEL_NAME from torchvision, starting from pretrained weights.
model = torchvision.models.segmentation.__dict__[MODEL_NAME](pretrained=True)
# Replace layer 4 of the main and auxiliary heads with 1x1 convolutions that output
# `num_classes` channels (presumably the final classification layer of each head - confirm
# against the torchvision model definition).
model.classifier[4] = nn.Conv2d(512, num_classes, 1)
model.aux_classifier [4] = nn.Conv2d(256, num_classes, 1)
model = model.to(DEVICE)



## Optimizer and learning rate scheduler

In [23]:
# Backbone and main classifier train with the base learning rate ...
params_to_optimize = [
    {"params": [p for p in model.backbone.parameters() if p.requires_grad]},
    {"params": [p for p in model.classifier.parameters() if p.requires_grad]},
]
# ... while the auxiliary classifier gets a 10x larger learning rate.
params = [p for p in model.aux_classifier.parameters() if p.requires_grad]
params_to_optimize.append({"params": params, "lr": learning_rate * 10})

iters_per_epoch = len(data_loader)
optimizer = torch.optim.SGD(params_to_optimize, lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
# Polynomial decay: the multiplier goes from 1 to 0 over iters_per_epoch * NUM_EPOCHS scheduler steps
# (the scheduler is stepped once per training iteration, not once per epoch).
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: (1 - x / (iters_per_epoch * NUM_EPOCHS)) ** power)

## Helper function for training and validation for one epoch

In [24]:
# Helper function to train
def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device):
    """Run one pass over ``data_loader``, updating the model after every batch.

    The learning-rate scheduler is stepped after every optimizer step.

    Returns:
        Tuple ``(mean batch loss, learning rate of the first parameter group
        after the last step)``.
    """
    model.train()
    progress = tqdm(data_loader)
    running_loss = 0.0
    for image, target in progress:
        image = image.to(device)
        target = target.to(device)
        loss = criterion(model(image), target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        current_lr = optimizer.param_groups[0]["lr"]
        progress.set_description("train_loss: {:.3f} lr: {:.3f}".format(batch_loss, current_lr))
    running_loss /= len(data_loader)
    return running_loss, optimizer.param_groups[0]["lr"]

# Helper function to evaluate
def evaluate(model, data_loader, device, num_classes):
    """Run the model over ``data_loader`` without gradients and return the filled ConfusionMatrix.

    Predictions are the argmax over dimension 1 of the model's ``"out"`` output.
    """
    model.eval()
    confusion = ConfusionMatrix(num_classes)
    progress = tqdm(data_loader)
    with torch.no_grad():
        for image, target in progress:
            image = image.to(device)
            target = target.to(device)
            logits = model(image)["out"]
            predicted = logits.argmax(1)
            confusion.update(target.flatten(), predicted.flatten())
            progress.set_description("eval")
    return confusion

In [25]:
# Either train for NUM_EPOCHS (validating and checkpointing after every epoch), or load the
# final checkpoint and evaluate it on the synthetic test set.
if DO_TRAINING:
    pbar = tqdm(range(NUM_EPOCHS))
    # Train and evaluate after each epoch
    for epoch in pbar:
        train_loss, last_lr = train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, DEVICE)
        confmat = evaluate(model, data_loader_test, device=DEVICE, num_classes=num_classes)
        acc_global, acc, iu = confmat.compute()    
        pbar.set_description(
            "train_loss: {:.3f} last_lr: {:.3f} acc_global: {:.3f} iou: {:.3f}".format(
                train_loss, last_lr, acc_global.item() * 100, iu.mean().item() * 100
            )
        )
        print("confmat:", confmat)
        # One checkpoint per epoch; the last one ("revised_model_epoch_29.pth" for 30 epochs)
        # matches FINAL_MODEL_PATH.
        torch.save(model.state_dict(), "revised_model_epoch_"+str(epoch)+".pth")
else:
    # NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
    # load_state_dict returns (missing_keys, unexpected_keys); the second is printed as "invalid keys".
    m,v = model.load_state_dict(torch.load(FINAL_MODEL_PATH, map_location=DEVICE))
    print("=> loaded model weights from {} \nmissing keys = {}  invalid keys {}".format(FINAL_MODEL_PATH, m, v))


    # Evaluate on the synthetic (IRCAD) test loader and report global accuracy and mean IoU in percent.
    confmat = evaluate(model, data_loader_test_ircad, device=DEVICE, num_classes=num_classes)
    acc_global, acc, iu = confmat.compute()    
    print(
        "acc_global: {:.3f} iou: {:.3f}".format(
            acc_global.item() * 100, iu.mean().item() * 100
        )
    )
    print("confmat:", confmat)


=> loaded model weights from revised_model_epoch_29.pth 
missing keys = []  invalid keys []


  0%|          | 0/10 [00:00<?, ?it/s]

acc_global: 82.877 iou: 65.003
confmat: global correct: 82.88
average row correct: ['84.9', '74.8', '97.3', '86.2', '90.7']
IoU: ['61.9', '72.6', '87.9', '63.8', '38.8']
mean IoU: 65.00
