# Toy Transformer 
This notebook explores the attention/transformer architecture and the basic steps for building a transformer in PyTorch.

## 0. Preparation - package import

In [2]:
import copy
import json
import math
import os
import random
from pathlib import Path

import cv2
import matplotlib
import matplotlib.pyplot as plt 
import numpy as np
import torch
import torch.utils.data as Data
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torchvision import models, transforms

%matplotlib inline

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## 1. Dataset and Dataloader
### Train / valid split based on the **fold** argument

In [None]:
data_path = Path("C:/Users/Siyao/Downloads/EndoVis2017Data")
train_path = data_path / "cropped_train"

def get_split(fold, root=None):
    """Split the instrument sequences into train and validation file lists.

    Sequences ``instrument_dataset_1`` .. ``instrument_dataset_8`` are assigned
    to validation if their id is listed for ``fold``; all others go to training.

    Args:
        fold: Fold index, one of 0-3.
        root: Directory containing the ``instrument_dataset_<id>`` folders.
            Defaults to ``data_path / "cropped_train"`` (evaluated at call time).

    Returns:
        ``(train_file_names, val_file_names)``: lists of ``Path`` objects for the
        files in each sequence's ``images`` folder. Sequences are visited in id
        order and files within a sequence are sorted by name.

    Raises:
        KeyError: if ``fold`` is not a known fold index.
    """
    folds = {0: [1, 3],
             1: [2, 5],
             2: [4, 8],
             3: [6, 7]}
    if fold not in folds:
        raise KeyError(f"fold must be one of {sorted(folds)}, got {fold!r}")
    root = data_path / "cropped_train" if root is None else Path(root)

    train_file_names = []
    val_file_names = []

    for instrument_id in range(1, 9):
        image_dir = root / f"instrument_dataset_{instrument_id}" / "images"
        # Path.glob yields files in arbitrary filesystem order; sort so the
        # split (and anything indexed by position) is reproducible.
        images = sorted(image_dir.glob("*"))
        target = val_file_names if instrument_id in folds[fold] else train_file_names
        target.extend(images)

    return train_file_names, val_file_names

train_file_names, val_file_names = get_split(0)

### Function to load image or mask

In [None]:
def load_image(path):
    """Read an image file with OpenCV and return it in RGB channel order.

    Args:
        path: Image path (``str`` or ``Path``).

    Returns:
        The decoded image as a numpy array converted from OpenCV's BGR order to RGB.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    img = cv2.imread(str(path))
    if img is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check the error surfaces later as an opaque cvtColor failure.
        raise FileNotFoundError(f"could not read image: {path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Divisors that map stored mask pixel values to class indices for each task.
binary_factor = 255
parts_factor = 85
instrument_factor = 32

def load_mask(path, problem_type="instruments", mask_folder="instruments_masks",factor=instrument_factor):
    """Load the grayscale mask that belongs to an image file.

    The mask is looked up next to the image's ``images`` folder: for
    ``.../<seq>/images/<name>.<ext>`` it reads ``.../<seq>/<mask_folder>/<name>.png``.

    Args:
        path: Path of the *image* whose mask is wanted.
        problem_type: ``'binary'``, ``'parts'`` or ``'instruments'`` select a
            fixed mask folder and divisor. Any other value (including None)
            falls back to the ``mask_folder`` / ``factor`` arguments.
        mask_folder: Mask directory name, used only for unrecognised problem types.
        factor: Divisor, used only for unrecognised problem types.

    Returns:
        ``uint8`` array of ``mask / factor`` truncated to integers.

    Raises:
        FileNotFoundError: if the mask file cannot be read.
    """
    if problem_type == 'binary':
        mask_folder, factor = 'binary_masks', binary_factor
    elif problem_type == 'parts':
        mask_folder, factor = 'parts_masks', parts_factor
    elif problem_type == 'instruments':
        mask_folder, factor = 'instruments_masks', instrument_factor

    image_path = Path(path)
    # Swap only the directory that holds the image (not every 'images'/'jpg'
    # substring anywhere in the path) and force the .png extension.
    mask_dir = image_path.parent
    if mask_dir.name == 'images':
        mask_dir = mask_dir.with_name(mask_folder)
    mask_path = mask_dir / image_path.with_suffix('.png').name

    mask = cv2.imread(str(mask_path), 0)  # flag 0: read as single-channel grayscale
    if mask is None:
        raise FileNotFoundError(f"could not read mask: {mask_path}")

    return (mask / factor).astype(np.uint8)

### Dataset for training and validation

In [None]:
class RoboticsDataset(Dataset):
    """Single-frame dataset yielding ``(image, mask, path)`` triples.

    ``image`` is the raw RGB array from ``load_image`` wrapped as a tensor.
    For ``problem_type == 'binary'`` the mask gets a leading channel axis and
    is returned as float; otherwise it is returned as a long tensor.
    ``to_augment``, ``transform`` and ``mode`` are stored but not used yet.
    """

    def __init__(self, file_names, to_augment=False, transform=None, mode='train', problem_type=None):
        self.file_names = file_names
        self.to_augment = to_augment
        self.transform = transform
        self.mode = mode
        self.problem_type = problem_type

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        path = self.file_names[idx]
        raw_image = load_image(path)
        raw_mask = load_mask(path, self.problem_type)

        image_tensor = torch.from_numpy(raw_image)
        if self.problem_type != 'binary':
            return image_tensor, torch.from_numpy(raw_mask).long(), str(path)

        # Binary masks are given an explicit channel axis: (1, H, W).
        mask_tensor = torch.from_numpy(raw_mask[np.newaxis]).float()
        return image_tensor, mask_tensor, str(path)

# "instruments" is the problem_type load_mask recognises for instrument-type
# masks; the previous "instrument" / None values only reached the same masks
# through load_mask's fallback defaults.
train_data_single = RoboticsDataset(train_file_names, problem_type="instruments")
valid_data_single = RoboticsDataset(val_file_names, mode='valid', problem_type="instruments")

### Instrument Dataset
1. Multiple consecutive frames are stacked to form one input sample.
2. The label is the instrument-type mask of the current frame.

In [None]:
### Creating the lists of file name, that starts from tau frames after the first frame
### which avoid the first few frames having no previous frames issue.

tau = 3
train_img_path = [str(i) for i in train_file_names]
train_frame_name = [i for i in train_img_path if int(i[-7:-4])>=tau]
valid_img_path = [str(i) for i in val_file_names] 
valid_frame_name = [i for i in valid_img_path if int(i[-7:-4])>=tau]

In [None]:
# NOTE: the Resize transform below is disabled, so frames load at full resolution.
class InstrumentDataset(Dataset):
    """Multi-frame dataset: stacks the current frame with its predecessors.

    Each item is ``(frames, mask, path)``:
      * ``frames``: float tensor of shape ``(tau, C, H, W)``, produced by
        ``ToTensor`` on each frame. Index 0 is the current frame and index
        ``i`` is the frame ``i`` steps earlier.
      * ``mask``: float tensor of the current frame's mask from ``load_mask``.
      * ``path``: the current frame's path as a string.

    File names must end in a three-digit frame index followed by a
    four-character extension, e.g. ``frame042.jpg``.
    """

    def __init__(self, file_names, problem_type="Instrument", tau=3):
        self.file_names = file_names
        # Passed straight to load_mask; "Instrument" is not one of its named
        # types, so load_mask's default (instrument) masks are used.
        self.problem_type = problem_type
        self.tau = tau      # number of consecutive frames stacked per sample
        self.transform = transforms.Compose([
                                transforms.ToPILImage(),
                                # transforms.Resize([256,320]),
                                transforms.ToTensor()
                            ])

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        """Return ``(frames, mask, path)`` for sample ``idx``.

        Raises:
            ValueError: if the frame has fewer than ``tau - 1`` predecessors.
        """
        current_frame = self.file_names[idx]
        mask = load_mask(current_frame, self.problem_type)

        frame_no = int(current_frame[-7:-4])
        if frame_no < self.tau - 1:
            raise ValueError(
                f"{current_frame!r} has fewer than {self.tau - 1} preceding frames")
        to_find = "frame" + current_frame[-7:-4]

        frames_ls = []
        for i in range(self.tau):
            # Rewrite the frame index in the path to point at frame (frame_no - i).
            frame = current_frame.replace(to_find, "frame%03d" % (frame_no - i))
            frames_ls.append(self.transform(load_image(frame)))

        # ToTensor already yields (C, H, W), so stacking gives (tau, C, H, W).
        frames_stack = torch.stack(frames_ls, 0)
        return frames_stack.float(), torch.from_numpy(mask).float(), str(current_frame)

In [None]:
### Training and validation data with multi-frame input
# Each sample: frames as a (tau, C, H, W) float tensor, the current frame's
# mask as a float tensor, and the frame path.
# Pass tau explicitly so the stacked window matches the tau used to filter
# train_frame_name / valid_frame_name.
training_data_frames = InstrumentDataset(train_frame_name, tau=tau)
valid_data_frames = InstrumentDataset(valid_frame_name, tau=tau)

### Dataloader

In [None]:
batch_size = 1

# Shuffle the training split each epoch; keep validation order fixed.
training_data_loader = Data.DataLoader(
    dataset=training_data_frames,
    batch_size=batch_size,
    shuffle=True,
)
valid_data_loader = Data.DataLoader(
    dataset=valid_data_frames,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Pull one batch to sanity-check tensor shapes and the source paths.
a, b, c = next(iter(training_data_loader))
for label, shape in (("Data shape", a.shape), ("Mask shape", b.shape)):
    print(f"{label}: {shape}")
print(f"Path: {c}")

In [None]:
# First frame of the first clip, reordered from (C, H, W) to (H, W, C) for matplotlib.
img = a[0, 0].permute(1, 2, 0)
plt.imshow(img)

## 2. Model
### CNN Backbone

In [4]:
# Alternative backbone: models.segmentation.fcn_resnet101(pretrained=True).eval()
model = models.resnet101(pretrained=True)
model.eval()
model.to(device)
# Example forward pass on one clip: y = model(a[0].to(device))

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [5]:
from torchinfo import summary

# Treat the 3 stacked frames of one clip as the batch: (frames, channels, H, W).
summary(model=model, input_size=(3, 3, 224, 320))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   --                        --
├─Conv2d: 1-1                            [3, 64, 112, 160]         9,408
├─BatchNorm2d: 1-2                       [3, 64, 112, 160]         128
├─ReLU: 1-3                              [3, 64, 112, 160]         --
├─MaxPool2d: 1-4                         [3, 64, 56, 80]           --
├─Sequential: 1-5                        [3, 256, 56, 80]          --
│    └─Bottleneck: 2-1                   [3, 256, 56, 80]          --
│    │    └─Conv2d: 3-1                  [3, 64, 56, 80]           4,096
│    │    └─BatchNorm2d: 3-2             [3, 64, 56, 80]           128
│    │    └─ReLU: 3-3                    [3, 64, 56, 80]           --
│    │    └─Conv2d: 3-4                  [3, 64, 56, 80]           36,864
│    │    └─BatchNorm2d: 3-5             [3, 64, 56, 80]           128
│    │    └─ReLU: 3-6                    [3, 64, 56, 80]           --
│ 

In [None]:
class CNNBackbone(nn.Module):
    """Per-frame ResNet-101 feature extractor for stacked clips.

    Every frame of every clip is encoded independently by an ImageNet-pretrained
    ResNet-101 whose final ``fc`` layer is replaced by a freshly initialised
    ``nn.Linear(in_features, n_out)``.

    Args:
        c_in: Number of input channels. Stored only; the pretrained ResNet
            stem is built for 3-channel input.
        n_out: Size of the per-frame feature vector.
    """

    def __init__(self, c_in=3, n_out=2048):
        super(CNNBackbone, self).__init__()
        self.c_in = c_in
        self.n_out = n_out
        self.cnn = models.resnet101(pretrained=True).eval()
        # Replace the ImageNet classifier head; this new layer is randomly initialised.
        self.cnn.fc = nn.Linear(self.cnn.fc.in_features, self.n_out)

    def forward(self, x):
        """Encode each frame of each clip.

        Args:
            x: Tensor of shape ``(batch, tau, C, H, W)``.

        Returns:
            Tensor of shape ``(batch, tau, n_out)``.

        Raises:
            ValueError: if ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape (batch, tau, C, H, W), got {tuple(x.shape)}")
        batch, tau = x.shape[:2]
        # Fold clips into the batch axis so every frame of every clip is encoded
        # (previously only x[0], the first clip, was used).
        feats = self.cnn(x.flatten(0, 1))
        return feats.view(batch, tau, -1)

### Position Encoding

In [None]:
class PositionEmbeddingSine(nn.Module):
    """
    ## https://github.com/Epiphqny/VisTR/blob/master/models/position_encoding.py
    Sine/cosine position embedding for video clips (time, height, width).

    This is the encoding from "Attention Is All You Need", generalised to clips
    of images. Separate encodings of the frame index, row index and column
    index are concatenated along the channel axis.

    Args:
        num_pos_feats: Channels generated per axis. Must be even because
            sine and cosine halves are interleaved. The output has
            ``3 * num_pos_feats`` channels.
        num_frames: Frames per clip. The mask's leading dimension is split
            into ``(n // num_frames, num_frames)``.
        temperature: Base of the geometric frequency progression.
        normalize: If True, rescale each cumulative index so its last value
            along every axis equals ``scale``.
        scale: Normalisation range. Only allowed with ``normalize=True``;
            defaults to ``2 * pi``.
    """
    def __init__(self, num_pos_feats=64, num_frames = 3, temperature=10000, normalize=False, scale=None):
        super().__init__()
        if num_pos_feats % 2:
            raise ValueError("num_pos_feats must be even (sin/cos halves are interleaved)")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.frames = num_frames
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: "NestedTensor"):
        """Compute the embedding for a padded batch of clips.

        Args:
            tensor_list: Object exposing ``.tensors`` (only its device is used)
                and ``.mask``. The mask is a boolean tensor of shape
                ``(n, H, W)``, where ``n`` is a multiple of ``num_frames``.
                True entries do not advance the cumulative position index
                (padding, in VisTR).

        Returns:
            Tensor of shape ``(n // num_frames, num_frames, 3 * num_pos_feats, H, W)``.
            Channels are ordered z (time), then y (row), then x (column).

        Raises:
            ValueError: if ``tensor_list.mask`` is None.
        """
        x = tensor_list.tensors
        mask = tensor_list.mask
        if mask is None:
            raise ValueError("tensor_list.mask must not be None")
        n,h,w = mask.shape
        mask = mask.reshape(n//self.frames, self.frames,h,w)
        not_mask = ~mask
        # 1-based running position along time, rows and columns, counting only unmasked cells.
        z_embed = not_mask.cumsum(1, dtype=torch.float32)
        y_embed = not_mask.cumsum(2, dtype=torch.float32)
        x_embed = not_mask.cumsum(3, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scale
            y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale

        # Each sin/cos pair shares one frequency: temperature ** (2i / num_pos_feats).
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, :, None] / dim_t
        pos_z = z_embed[:, :, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
        pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
        pos_z = torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
        pos = torch.cat((pos_z, pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3)
        return pos


def build_position_encoding(args):
    """Build the clip position embedding from a config namespace.

    Args:
        args: Object with ``hidden_dim``, ``position_embedding`` and
            ``num_frames`` attributes. One third of ``hidden_dim`` goes to
            each of the three axes.

    Raises:
        ValueError: if ``args.position_embedding`` is not ``'v2'`` or ``'sine'``.
    """
    N_steps = args.hidden_dim // 3
    if args.position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        return PositionEmbeddingSine(N_steps, num_frames = args.num_frames, normalize=True)
    raise ValueError(f"not supported {args.position_embedding}")

In [None]:
def clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``.

    Each copy is made with ``copy.deepcopy``, so the layers do not share
    parameters. Uses the file-level ``import copy``.
    """
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

In [None]:
### Define sparse attention 
# https://github.com/wilson1yan/VideoGPT/blob/master/videogpt/attention.py
class SparseAttention(nn.Module):
    """Block-sparse attention using DeepSpeed's sparse-attention kernels.

    Layouts, causal masks and compiled ops are cached at class level, keyed by
    ``shape``, so all instances with the same shape share them.

    NOTE(review): this port has no ``forward`` method yet, so calling an
    instance will fail. ``StridedSparsityConfig`` must also be defined
    elsewhere before instantiation.
    """
    # Class-level caches shared by every instance, keyed by ``shape``.
    ops = dict()
    attn_mask = dict()
    block_layout = dict()

    def __init__(self, shape, n_head, causal, num_local_blocks=4, block=32,
                 attn_dropout=0.): # does not use attn_dropout
        super().__init__()
        self.causal = causal
        self.shape = shape

        self.sparsity_config = StridedSparsityConfig(shape=shape, n_head=n_head,
                                                     causal=causal, block=block,
                                                     num_local_blocks=num_local_blocks)

        # Build the layout (and the causal mask, if needed) once per shape.
        if self.shape not in SparseAttention.block_layout:
            SparseAttention.block_layout[self.shape] = self.sparsity_config.make_layout()
        if causal and self.shape not in SparseAttention.attn_mask:
            SparseAttention.attn_mask[self.shape] = self.sparsity_config.make_sparse_attn_mask()

    def get_ops(self):
        """Return the cached ``(sdd_nt matmul, dsd_nn matmul, softmax)`` ops for ``self.shape``.

        The ops are created on first use for each shape.

        Raises:
            ImportError: if DeepSpeed's sparse-attention ops are unavailable.
        """
        try:
            from deepspeed.ops.sparse_attention import MatMul, Softmax
        except ImportError as err:
            # Narrowed from a bare ``except:`` so unrelated errors (and
            # KeyboardInterrupt) are no longer masked as an install problem.
            raise ImportError('Error importing deepspeed. Please install using `DS_BUILD_SPARSE_ATTN=1 pip install deepspeed`') from err
        if self.shape not in SparseAttention.ops:
            sparsity_layout = self.sparsity_config.make_layout()
            # Q @ K^T: sparse output from dense inputs, K transposed.
            sparse_dot_sdd_nt = MatMul(sparsity_layout,
                                       self.sparsity_config.block,
                                       'sdd',
                                       trans_a=False,
                                       trans_b=True)

            # attn @ V: dense output from a sparse attention matrix and dense V.
            sparse_dot_dsd_nn = MatMul(sparsity_layout,
                                       self.sparsity_config.block,
                                       'dsd',
                                       trans_a=False,
                                       trans_b=False)

            sparse_softmax = Softmax(sparsity_layout, self.sparsity_config.block)

            SparseAttention.ops[self.shape] = (sparse_dot_sdd_nt,
                                               sparse_dot_dsd_nn,
                                               sparse_softmax)
        return SparseAttention.ops[self.shape]

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention over an N-D grid of tokens (adapted from VideoGPT).

    Inputs are ``(batch, *shape, dim)`` tensors. Heads are split out of the
    last dimension, attention is computed either densely ("full") or with the
    DeepSpeed-backed ``SparseAttention`` ("sparse"), and the heads are merged
    and projected back to ``dim_q``.

    Args:
        shape: Token-grid shape excluding batch and channel dims, e.g. ``(T, H, W)``.
        dim_q: Channel size of the query input and of the output.
        dim_v: Channel size of the key and value inputs.
        n_head: Number of heads. Per-head sizes are ``dim_q // n_head`` for
            queries/keys and ``dim_v // n_head`` for values.
        n_layer: Used only to scale the output-projection initialisation.
        causal: If True, each flattened grid position attends only to itself
            and earlier positions.
        attn_type: ``"sparse"`` (default) or ``"full"``, case-insensitive.
        attn_kwargs: Extra keyword arguments for ``SparseAttention``. Ignored
            for ``"full"``.

    Raises:
        ValueError: for an unknown ``attn_type``.
    """

    def __init__(self, shape, dim_q, dim_v, n_head, n_layer,
                 causal=True, attn_type="Sparse", attn_kwargs=None):
        super().__init__()
        self.causal = causal
        self.shape = shape
        self.attn_type = attn_type.lower()

        self.d_k = dim_q // n_head
        self.d_v = dim_v // n_head
        self.n_head = n_head

        self.w_qs = nn.Linear(dim_q, n_head * self.d_k, bias=False) # q
        self.w_qs.weight.data.normal_(std=1.0 / np.sqrt(dim_q))

        self.w_ks = nn.Linear(dim_v, n_head * self.d_k, bias=False) # k
        self.w_ks.weight.data.normal_(std=1.0 / np.sqrt(dim_v))

        self.w_vs = nn.Linear(dim_v, n_head * self.d_v, bias=False) # v
        self.w_vs.weight.data.normal_(std=1.0 / np.sqrt(dim_v))

        self.fc = nn.Linear(n_head * self.d_v, dim_q, bias=True) # c
        # Std shrinks with depth (n_layer), presumably to keep residual updates
        # small in deep stacks.
        self.fc.weight.data.normal_(std=1.0 / np.sqrt(dim_q * n_layer))

        if self.attn_type == "full":
            self.attn = None
        elif self.attn_type == "sparse":
            self.attn = SparseAttention(shape, n_head, causal, **(attn_kwargs or {}))
        else:
            raise ValueError(f"unsupported attn_type: {attn_type!r}")

    @staticmethod
    def _full_attention(q, k, v, causal):
        """Dense scaled dot-product attention across all grid positions.

        ``q``/``k`` are ``(b, n_head, *shape, d_k)`` and ``v`` is
        ``(b, n_head, *shape, d_v)``. Returns ``(b, n_head, *shape, d_v)``.
        """
        grid_shape = q.shape[2:-1]
        q = q.flatten(2, -2)
        k = k.flatten(2, -2)
        v = v.flatten(2, -2)
        scores = (q @ k.transpose(-2, -1)) * (q.shape[-1] ** -0.5)
        if causal:
            keep = torch.ones(scores.shape[-2], scores.shape[-1],
                              dtype=torch.bool, device=scores.device).tril()
            scores = scores.masked_fill(~keep, float("-inf"))
        out = torch.softmax(scores, dim=-1) @ v
        return out.reshape(*out.shape[:2], *grid_shape, out.shape[-1])

    def forward(self, q, k, v, decode_step=None, decode_idx=None):
        """Attend ``q`` over ``k``/``v``.

        Each input is ``(b, *shape, dim)``; the output is ``(b, *shape, dim_q)``.
        ``decode_step``/``decode_idx`` are forwarded to the sparse attention
        module only.
        """
        n_head = self.n_head

        def split_heads(t, d):
            # (b, *shape, n_head * d) -> (b, n_head, *shape, d)
            return t.reshape(*t.shape[:-1], n_head, d).movedim(-2, 1)

        q = split_heads(self.w_qs(q), self.d_k)
        k = split_heads(self.w_ks(k), self.d_k)
        v = split_heads(self.w_vs(v), self.d_v)

        if self.attn is None:
            a = self._full_attention(q, k, v, self.causal)
        else:
            a = self.attn(q, k, v, decode_step, decode_idx)

        # (b, n_head, *shape, d) -> (b, *shape, n_head * d)
        a = a.movedim(1, -2).flatten(start_dim=-2)
        return self.fc(a) # (b, *shape, dim_q)

In [None]:
class SpaTemSelfAtteBlock(nn.Module):
    """Post-norm transformer block.

    Multi-head self-attention -> Add & Norm -> feed-forward -> Add & Norm.

    Args:
        shape: Token-grid shape, forwarded to ``MultiHeadAttention``.
        embd_dim: Channel size of the input and output.
        n_head: Number of attention heads.
        n_layer: Forwarded to ``MultiHeadAttention`` (initialisation scale).
        dropout: Dropout applied after the feed-forward hidden activation.
        attn_type: Attention implementation, forwarded to ``MultiHeadAttention``.
        attn_dropout: Forwarded to the sparse attention module via ``attn_kwargs``.
        class_cond_dim: Accepted for interface compatibility; currently unused.
        frame_cond_shape: Accepted for interface compatibility; currently unused.
    """
    def __init__(self, shape, embd_dim, n_head, n_layer, dropout=0.2,
                 attn_type="Sparse", attn_dropout=0.0, class_cond_dim=None,
                 frame_cond_shape=None):
        super().__init__()

        self.shape = shape
        # Multi-head attention sub-layer (self-attention: q = k = v = x).
        self.mh_attn = MultiHeadAttention(shape, embd_dim, embd_dim, n_head, n_layer,
                                          attn_type=attn_type,
                                          attn_kwargs=dict(attn_dropout=attn_dropout))
        # Plain LayerNorm; class_cond_dim was previously passed as LayerNorm's eps.
        self.norm_1 = nn.LayerNorm(embd_dim)

        # Feed-forward sub-layer with a 4x hidden expansion.
        self.fc = nn.Sequential(
            nn.Linear(in_features=embd_dim, out_features=embd_dim * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(in_features=embd_dim * 4, out_features=embd_dim),
        )
        self.norm_2 = nn.LayerNorm(embd_dim)

    def forward(self, x, cond=None, decode=None):
        """Apply the block to ``x`` of shape ``(batch, *shape, embd_dim)``.

        ``cond`` and ``decode`` are accepted but currently unused.
        """
        y = self.norm_1(x + self.mh_attn(x, x, x))
        return self.norm_2(y + self.fc(y))
        

In [None]:
class EncoderLayer(nn.Module):
    """Use a CNN to extract information from each frame"""
    def __init__(self, c_in, n_classes, use_gt=True, embd_dim, dropout):
        super(MyTransSeg, self).__init__():
            self.backbone = CNNBackbone()
            self.pos_en = PositionEmbeddingSine(embd_dim, dropout)
            self.encoder = 