# MENTOR training

Note: this notebook does not reflect the experiments performed in the paper in their entirety, as efforts were made to preserve anonymity during the review process.

The authors would like to thank the [Segmentation Models](https://github.com/qubvel/segmentation_models.pytorch) repository for examples and guidance during the creation of this work.

In [None]:
import torchvision.transforms as transforms
import sys
import torch.nn as nn
from torch.utils.data import Dataset as BaseDataset
import segmentation_models_pytorch as smp
import torch
import os
from torch.utils.data import DataLoader
from PIL import Image

In [None]:
# Dataset locations: x_* are the image directories, y_* the matching
# annotation directories, for the training and validation splits.
x_train_dir, y_train_dir = "/train", "/trainannot"
x_valid_dir, y_valid_dir = "/val", "/valannot"

In [None]:
class Dataset(BaseDataset):
    """Image / mask pairs read from two parallel directories.

    Every entry of ``images_dir`` is paired with the entry of the same name
    in ``masks_dir``; when the mask is opened, ``.jpg`` in its path is
    replaced by ``.png``.  Images are resized to 256x256, converted to a
    tensor and normalized with the mean/std given below.  Masks are resized
    to 256x256, reduced to a single grayscale channel and min-max rescaled.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the masks.
        classes: Optional class names.  Each one is lower-cased and looked
            up in ``CLASSES``.  ``None`` or an empty sequence selects no
            classes.
        augmentation: Optional callable invoked as
            ``augmentation(image=image, mask=mask)``; must return a mapping
            with ``'image'`` and ``'mask'`` entries.
        preprocessing: Optional callable invoked as
            ``preprocessing(image=image)``; must return a mapping with an
            ``'image'`` entry.

    Raises:
        ValueError: If a name in ``classes`` is not listed in ``CLASSES``.
    """

    # Known class names; the position of a name is its class value.
    # Empty by default (the original code referenced this attribute without
    # ever defining it) -- subclasses may override it.
    CLASSES = []

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index always refers to the same sample.
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # convert str names to class values on masks
        self.class_values = self._resolve_class_values(classes)

        self.augmentation = augmentation
        self.preprocessing = preprocessing

        # The pipelines are stateless, so build them once instead of on
        # every __getitem__ call.
        self._image_transform = transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self._mask_transform = transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.Grayscale(),
            transforms.ToTensor(),
        ])

    @classmethod
    def _resolve_class_values(cls, classes):
        """Map class names to their positions in ``CLASSES``."""
        if not classes:
            return []
        values = []
        for name in classes:
            try:
                values.append(cls.CLASSES.index(name.lower()))
            except ValueError:
                raise ValueError(
                    f"Unknown class {name!r}; expected one of {list(cls.CLASSES)}"
                ) from None
        return values

    @staticmethod
    def _normalize_mask(mask):
        """Min-max rescale ``mask`` so its values span ``[0, 1]``.

        A constant mask has no range to stretch; dividing by its zero span
        would turn every pixel into NaN, so it is returned unchanged.
        """
        low = torch.min(mask)
        span = torch.max(mask) - low
        if span > 0:
            return (mask - low) / span
        return mask

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair stored at index ``i``."""
        # read data -- the context managers close the underlying files as
        # soon as the pixels have been decoded.
        with Image.open(self.images_fps[i]) as raw_image:
            image = self._image_transform(raw_image.convert("RGB"))

        mask_path = self.masks_fps[i].replace(".jpg", ".png")
        with Image.open(mask_path) as raw_mask:
            mask = self._mask_transform(raw_mask.convert("RGB"))
        mask = self._normalize_mask(mask)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image)
            image = sample['image']

        return image, mask

    def __len__(self):
        """Return the number of entries found in ``images_dir``."""
        return len(self.ids)

In [None]:
# Backbone settings shared by the model and its preprocessing function.
ENCODER = 'resnet50'          # swap in another smp encoder name here
ENCODER_WEIGHTS = 'imagenet'  # pretrained weights loaded into the encoder
ACTIVATION = None             # no activation is attached to the model head
DEVICE = 'cuda'

# Input preprocessing that matches the chosen encoder / weights pair.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

# Single-channel UNet++ on top of the pretrained encoder
# (smp.Unet was also tried in its place).
model = smp.UnetPlusPlus(
    classes=1,
    activation=ACTIVATION,
    encoder_weights=ENCODER_WEIGHTS,
    encoder_name=ENCODER,
)

In [None]:
# Mean-squared error between the predicted map and the target mask.
compare_loss = nn.MSELoss()

# Adam with a single parameter group covering the whole model.
solver = torch.optim.Adam([{"params": model.parameters(), "lr": 0.0001}])

In [None]:
# No class names are selected: the name `CLASSES` that used to be passed here
# is not defined anywhere in this notebook (NameError), and the dataset never
# reads the resulting class values when producing samples.
train_dataset = Dataset(
    x_train_dir,
    y_train_dir,
    augmentation=None,
    preprocessing=None,
    classes=[],
)

valid_dataset = Dataset(
    x_valid_dir,
    y_valid_dir,
    augmentation=None,
    preprocessing=None,
    classes=[],
)

# Training batches are shuffled, and the incomplete last batch is dropped so
# every optimisation step sees exactly 4 samples.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=1, pin_memory=True, drop_last=True)
# Validation visits every sample once, in a fixed order, so runs are
# comparable (drop_last has no effect with batch_size=1).
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True, drop_last=False)

# Loaders keyed by the phase names used by the training loop.
d = {'train': train_loader, 'test': valid_loader}

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at this cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

## Training Loop

In [None]:
# Training history:
#   'iterations' - loss of every training batch
#   'epoch'      - mean training batch loss per epoch
#   'validation' - mean validation batch loss per epoch
#   'train_acc' / 'val_acc' - per-sample mean loss per epoch (despite the key
#                             names this is a loss, lower is better)
log = {'iterations': [], 'epoch': [], 'validation': [], 'train_acc': [], 'val_acc': []}
log_path = "./mentor_weights"
# torch.save does not create missing directories.
os.makedirs(log_path, exist_ok=True)
train_loss = []
test_loss = []
# Lowest validation loss seen so far (the name is historical).
bestAccuracy = float('inf')
bestEpoch = 0
train_step = 0
val_step = 0
# NOTE(review): range(11) runs 11 epochs (0..10) although the original comment
# said "training for only 10 epochs" -- confirm which one is intended.
for epoch in range(11):
    for phase in ['train', 'test']:
        train = (phase == 'train')
        if train:
            model.train()
        else:
            model.eval()
        tloss = 0.   # sum of batch losses
        metric = 0.  # sum of batch losses weighted by batch size
        tot = 0      # number of samples seen
        c = 0        # number of batches seen
        with torch.set_grad_enabled(train):
            for batch_idx, data in enumerate(d[phase]):

                # Data and ground truth
                images, masks = data
                images = images.to(device)
                masks = masks.to(device)

                outputs = model(images)
                prob_mask = outputs.sigmoid()
                loss = compare_loss(prob_mask, masks)

                # Accumulate plain Python floats: adding the loss tensor
                # itself would keep autograd history alive across iterations.
                batch_loss = loss.item()
                # `data` is the (images, masks) pair, so len(data) is always
                # 2; the number of samples is the leading batch dimension.
                batch_size = images.size(0)
                metric += batch_loss * batch_size
                tot += batch_size

                # Optimization of weights for training data
                if train:
                    train_step += 1
                    solver.zero_grad()
                    loss.backward()
                    solver.step()
                    log['iterations'].append(batch_loss)
                else:
                    val_step += 1

                tloss += batch_loss
                c += 1

        if c == 0:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise RuntimeError(
                f"The '{phase}' loader produced no batches; check the dataset "
                "directories and the batch_size / drop_last settings."
            )

        # Logging of train and test results
        if train:
            log['epoch'].append(tloss / c)
            log['train_acc'].append(metric / tot)
            print('Epoch: ', epoch, 'Train loss: ', tloss / c, 'Metric: ', metric / tot)
            train_loss.append(tloss / c)

        else:
            log['validation'].append(tloss / c)
            log['val_acc'].append(metric / tot)
            print('Epoch: ', epoch, 'Test loss:', tloss / c, 'Metric: ', metric / tot)
            test_loss.append(tloss / c)
            accuracy = tloss / c
            # Keep the checkpoint with the lowest validation loss; ties
            # overwrite the earlier checkpoint.
            if accuracy <= bestAccuracy:
                print("New model saved!")
                bestAccuracy = accuracy
                bestEpoch = epoch
                save_best_model = os.path.join(log_path, 'final_model.pth')
                states = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': solver.state_dict(),
                }
                torch.save(states, save_best_model)