# Toy Transformer 
This file has been created to understand the attention/transformer structure and basic method to build transformer in PyTorch.  "

## 0. Preparation - package import

In [2]:
import os
import json
import torch
import random
import cv2
import numpy as np
import torchvision
from pathlib import Path
from torch import nn
import torch.utils.data as Data
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.nn import functional as F
from torchvision import models
import matplotlib.pyplot as plt 
import matplotlib

%matplotlib inline


## 1. Dataset and Dataloader
### Train / valid split based on the **folds** argument

In [3]:
data_path = Path("C:/Users/Siyao/Downloads/EndoVis2017Data")
train_path = data_path / "cropped_train"

def get_split(fold):
    """Split train and valid dataset based on the No. of folder"""
    folds = {0: [1, 3],
             1: [2, 5],
             2: [4, 8],
             3: [6, 7]}
    train_path = data_path / 'cropped_train'

    train_file_names = []
    val_file_names = []

    for instrument_id in range(1, 9):
        if instrument_id in folds[fold]:
            val_file_names += list((train_path / ('instrument_dataset_' + str(instrument_id)) / 'images').glob('*'))
        else:
            train_file_names += list((train_path / ('instrument_dataset_' + str(instrument_id)) / 'images').glob('*'))

    return train_file_names, val_file_names

train_file_names, val_file_names = get_split(0)

### Function to load image or mask

In [4]:
def load_image(path):
    img = cv2.imread(str(path))
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

binary_factor = 255
parts_factor = 85
instrument_factor = 32

def load_mask(path, problem_type="instruments", mask_folder="instruments_masks",factor=instrument_factor):
    if problem_type == 'binary':
        mask_folder = 'binary_masks'
        factor = binary_factor
    elif problem_type == 'parts':
        mask_folder = 'parts_masks'
        factor = parts_factor
    elif problem_type == 'instruments':
        factor = instrument_factor
        mask_folder = 'instruments_masks'

    mask = cv2.imread(str(path).replace('images', mask_folder).replace('jpg', 'png'), 0)

    return (mask / factor).astype(np.uint8)

### Dataset for training and validation

In [5]:
class RoboticsDataset(Dataset):
    """Dataset that only loads single frame"""

    def __init__(self, file_names, to_augment=False, transform=None, mode='train', problem_type=None):
        self.file_names = file_names
        self.to_augment = to_augment
        self.transform = transform
        self.mode = mode
        self.problem_type = problem_type

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        img_file_name = self.file_names[idx]
        image = load_image(img_file_name)
        mask = load_mask(img_file_name, self.problem_type)

        # data = {"image": image, "mask": mask}
        # augmented = self.transform(**data)
        # image, mask = augmented["image"], augmented["mask"]

        # if self.mode == 'train':
        if self.problem_type == 'binary':
            return torch.from_numpy(image), torch.from_numpy(np.expand_dims(mask, 0)).float(), str(img_file_name)
        else:
            return torch.from_numpy(image), torch.from_numpy(mask).long(), str(img_file_name)
        # else:
        #     return torch.from_numpy(image), str(img_file_name)

train_data_single = RoboticsDataset(train_file_names, problem_type="instrument")
valid_data_single = RoboticsDataset(val_file_names, mode='valid')

### Instrument Dataset
1. Mutiple image stacked as data
2. The label used instrument type

In [6]:
### Creating the lists of file name, that starts from tau frames after the first frame
### which avoid the first few frames having no previous frames issue.

tau = 3
train_img_path = [str(i) for i in train_file_names]
train_frame_name = [i for i in train_img_path if int(i[-7:-4])>=tau]
valid_img_path = [str(i) for i in val_file_names] 
valid_frame_name = [i for i in valid_img_path if int(i[-7:-4])>=tau]

In [7]:
class InstrumentDataset(Dataset):
    """Dataset that loads multiple frame"""

    def __init__(self, file_names, problem_type="Instrument", tau=3):
        self.file_names = file_names
        self.problem_type = problem_type
        self.tau = tau      # tau is the number of frames should be combiend
    
    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        current_frame = self.file_names[idx]
        mask = load_mask(current_frame, self.problem_type)
        frames_ls = []
        for i in range(tau):
            to_find = "frame"+current_frame[-7:-4]
            to_repl = "frame"+ '%03d' % (int(current_frame[-7:-4])-i)
            frame = current_frame.replace(to_find, to_repl)
            frame_array = load_image(frame)
            frame_tensor = torch.from_numpy(frame_array)
            frames_ls.append(frame_tensor)
        frames_stack = torch.stack(frames_ls, 0)
        # permute the tensor from [tau, H, W, C] to [tau, C, H, W]
        frames_tensor = frames_stack.permute(0,3,1,2)
        return frames_tensor.float(), torch.from_numpy(mask).float(), str(current_frame) 

In [8]:
### Traning and validation data with multi-frames input
# The data in 4D tensor (tau, w, h, c), label in 3D tensor ()
training_data_frames = InstrumentDataset(train_frame_name)
valid_data_frames = InstrumentDataset(valid_frame_name)

### Dataloader

In [9]:
batch_size = 1
training_data_loader = Data.DataLoader(training_data_frames, batch_size=batch_size, shuffle=True)
valid_data_loader = Data.DataLoader(valid_data_frames, batch_size=batch_size, shuffle=False)

In [10]:
a,b,c = next(iter(training_data_loader))
a.shape

torch.Size([1, 3, 3, 1024, 1280])

## Model
### CNN Backbone

In [11]:
model = models.resnet101(pretrained=True) # load the pretrained model
for param in model.parameters(): param.requires_grad_(False)

In [24]:
model = models.resnet101(pretrained=True)
for param in model.parameters():
    param.requires_grad_(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print("finish")

finish


In [None]:
class CNNBackbone(nn.Modules):
    def __init__(self, c_in, n_out):
        super(CNNBackbone, self).__init__()
        self.c_in = c_in
        self.n_out = n_out

        


In [None]:
class PositionalEncoding(nn.Module):
    "Implement the PE function."
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], 
                         requires_grad=False)
        return self.dropout(x)

In [None]:
class MyTransSeg(nn.Module):
    """Use a CNN to extract information from each frame"""
    def __init__(self, c_in, n_classes, use_gt=True):
        super(MyTransSeg, self).__init__():
        self.c_in = c_in
        self.n_classes = n_classes
        self.use_gt = use_gt
        self.backbone = models.resnet101(pretrained=True)
        self.backbone.eval()
        self.postionencoding = PositionalEncoding()

In [34]:
def clones(module, N):
    "Produce N identical layers; stack N modules."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

In [38]:
a[0].shape

torch.Size([3, 3, 1024, 1280])

In [12]:
a = a.permute(0,1,4,2,3)
a.shape
a.to(device)

torch.Size([1, 3, 3, 1024, 1280])

In [159]:
a.shape

torch.Size([1, 3, 1024, 1280, 3])

### Positional Encoding

In [None]:
class PositionalEncoding(nn.Module):
    "Implement the PE function."
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], 
                         requires_grad=False)
        return self.dropout(x)