<a href="https://colab.research.google.com/github/adast/TransResUNet/blob/master/TransResUNet_b16_fullres_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install the necessary libraries; download the dataset, the model code, and the pre-trained weights

In [None]:
# Install the third-party packages imported later in this notebook.
! pip install transformers
! pip install ml_collections
! pip install torchinfo

# Unpack the train/test archives from Google Drive into ./dataset, skipping
# the __MACOSX metadata folders, then flatten the nested train/train and
# test/test directories left by the archives.
# NOTE(review): these paths assume Google Drive is already mounted at
# /content/drive; no mount call is visible in this notebook - confirm.
! 7z x "/content/drive/MyDrive/Hand Segmentation/train.zip" -o./dataset/train '-xr!__MACOSX'
! 7z x "/content/drive/MyDrive/Hand Segmentation/test.zip" -o./dataset/test '-xr!__MACOSX'
! mv dataset/train/train/* dataset/train/ && rm -rf dataset/train/train
! mv dataset/test/test/* dataset/test/ && rm -rf dataset/test/test
! cp "/content/drive/MyDrive/Hand Segmentation/sample_submission.csv" sample_submission.csv

# Fetch the model code and the R50+ViT-B_16 weights file (the "%2B" in the
# URL is the URL-encoded "+" of the file name used by the model config).
! git clone https://github.com/adast/TransResUNet
! wget https://storage.googleapis.com/vit_models/imagenet21k/R50%2BViT-B_16.npz

## Import libraries, set manual seed, define utility functions

In [2]:
# TransUnet
import sys
sys.path.insert(0, './TransResUNet')
from TransResUNet.models.trans_resunet import TransResUNet

# Transformers
from transformers import AdamW, get_linear_schedule_with_warmup

# Pytorch
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchinfo import summary

# Others
import ml_collections
import os
import glob
import math
import numpy as np
import pandas as pd
import random
from tqdm.notebook import tqdm
from PIL import Image

# Make computations repeatable: seed Python's `random`, NumPy and PyTorch
# (including every CUDA device) with one shared value.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)

# Compute on GPU if available, otherwise fall back to the CPU.
# `device` is the string later passed to `.to(...)` for the model and batches.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Run length encoding
def rle_encoding(x):
    """Run-length encode a binary mask.

    The mask is scanned column by column (down, then right). Every run of
    consecutive mask pixels is reported as a ``start, length`` pair, where
    ``start`` is the 1-based flat index of the first pixel of the run.

    x: numpy array of shape (height, width), 1 - mask, 0 - background
    Returns the runs as a flat list ``[start, length, start, length, ...]``
    """
    # Transposing before flattening yields column-major (Fortran) order.
    foreground = np.flatnonzero(x.T.flatten() == 1)

    run_lengths = []
    previous = None
    for position in foreground:
        if previous is not None and position == previous + 1:
            # Still inside the current run: just grow its length.
            run_lengths[-1] += 1
        else:
            # First pixel, or a gap since the last one: open a new run.
            run_lengths += [position + 1, 1]
        previous = position
    return run_lengths

# Dice score
def dice_score(y_true, y_pred):
    """Compute the Dice coefficient between a reference and a predicted mask.

    The score is ``2 * sum(y_pred where y_true == 1) / (sum(y_pred) + sum(y_true))``.

    y_true: tensor holding the reference mask (1 - mask, 0 - background)
    y_pred: tensor of the same shape holding the predicted mask
    Returns a 0-dim tensor with the score. When both masks are completely
    empty the score is defined as 1.0 (perfect agreement) instead of the NaN
    that the 0 / 0 division would otherwise produce.
    """
    overlap = torch.sum(y_pred[y_true == 1]) * 2.0
    total = torch.sum(y_pred) + torch.sum(y_true)
    if total == 0:
        # Nothing to segment and nothing predicted: avoid 0 / 0 -> NaN, which
        # would otherwise turn any average over batches into NaN as well.
        return torch.ones_like(overlap)
    return overlap / total

## Define dataset class

In [3]:
class HandSegmentationDataset(Dataset):
    """Dataset of (image, segmentation mask) pairs read from PNG files.

    Every sub-directory of ``path`` is expected to contain an ``images`` and
    a ``segmentation`` folder. Within a sub-directory the sorted image files
    are paired positionally with the sorted segmentation files.
    """

    def __init__(self, path: str):
        # Shared transforms: PIL image -> tensor, then per-channel
        # normalisation with mean 0.5 and std 0.5.
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        # List of (image_path, segmentation_path) tuples.
        self.paths = []

        candidates = (os.path.join(path, entry) for entry in os.listdir(path))
        sample_dirs = sorted(candidate for candidate in candidates if os.path.isdir(candidate))

        for sample_dir in sample_dirs:
            images = sorted(glob.glob(f'{sample_dir}/images/*.png'))
            masks = sorted(glob.glob(f'{sample_dir}/segmentation/*.png'))
            # zip() stops at the shorter list, so surplus files are ignored.
            for pair in zip(images, masks):
                assert pair[0] != pair[1]
                self.paths.append(pair)

    def __len__(self):
        """Return the number of (image, segmentation) pairs found."""
        return len(self.paths)

    def __getitem__(self, index: int):
        """Load one sample as a dict of tensors.

        Returns a dict with the normalised RGB image under 'pixel_values'
        and the binarised mask (1.0 where the grey value exceeds 0.1 after
        tensor conversion, else 0.0) under 'segmentations'.
        """
        image_path, mask_path = self.paths[index]

        rgb = Image.open(image_path).convert('RGB')
        pixel_values = self.normalize(self.to_tensor(rgb))

        mask = self.to_tensor(Image.open(mask_path).convert('L'))

        return {
            'pixel_values': torch.squeeze(pixel_values),
            'segmentations': (mask > 0.1).float(),
        }

## Initialize dataset and dataloaders

In [4]:
# Samples per batch and the `num_workers` value handed to both DataLoaders.
BATCH_SIZE = 3
NUM_WORKERS = 2

# Create and split dataset to train and val: 80% of the samples (rounded
# down) go to training, the remainder to validation.
# NOTE(review): random_split is called without an explicit generator, so the
# split presumably depends on the global torch seed set earlier - confirm
# before comparing runs or resuming from a checkpoint.
dataset = HandSegmentationDataset('dataset/train/')
train_size = int(len(dataset) * 0.8)
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create dataloaders: only the training loader shuffles its samples.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

## Define hyperparams. Initialize model, optimizer, scheduler, criterion

In [5]:
# Number of full passes over the training set; also used to size the
# learning-rate schedule below.
EPOCHS = 30
# Learning rate handed to the AdamW optimizer.
LEARNING_RATE = 5e-5

def get_r50_b16_config():
    """Returns the R50+ViT-B/16 configuration."""
    # ResNet backbone.
    # Using three bottleneck blocks results in a downscaling of 2^(1 + 3)=16 which
    # results in an effective patch size of /16.
    resnet = ml_collections.ConfigDict()
    resnet.num_layers = (3, 4, 9)
    resnet.width_factor = 1

    # Transformer encoder.
    transformer = ml_collections.ConfigDict()
    transformer.num_special_tokens = 1
    transformer.patch_size = 16
    transformer.hidden_size = 768
    transformer.mlp_dim = 3072
    transformer.num_heads = 12
    transformer.num_layers = 12
    transformer.attention_dropout_rate = 0.0
    transformer.dropout_rate = 0.1

    # Decoder head.
    decoder = ml_collections.ConfigDict()
    decoder.head_channels = 512

    # Top-level settings, followed by the sub-configs in their original order.
    config = ml_collections.ConfigDict()
    config.image_size = (480, 640)
    config.n_classes = 1
    config.pre_trained_path = 'R50+ViT-B_16.npz'
    config.resnet = resnet
    config.transformer = transformer
    config.decoder = decoder

    return config

def get_r50_l32_config():
    """Returns the R50+ViT-L/32 configuration."""
    # ResNet backbone.
    # Using four bottleneck blocks results in a downscaling of 2^(1 + 4)=32 which
    # results in an effective patch size of /32.
    resnet = ml_collections.ConfigDict()
    resnet.num_layers = (3, 4, 6, 3)
    resnet.width_factor = 1

    # Transformer encoder.
    transformer = ml_collections.ConfigDict()
    transformer.num_special_tokens = 1
    transformer.patch_size = 32
    transformer.hidden_size = 1024
    transformer.mlp_dim = 4096
    transformer.num_heads = 16
    transformer.num_layers = 24
    transformer.attention_dropout_rate = 0.0
    transformer.dropout_rate = 0.1

    # Decoder head.
    decoder = ml_collections.ConfigDict()
    decoder.head_channels = 512

    # Top-level settings, followed by the sub-configs in their original order.
    config = ml_collections.ConfigDict()
    config.image_size = (480, 640)
    config.n_classes = 1
    config.pre_trained_path = 'R50+ViT-L_32.npz'
    config.resnet = resnet
    config.transformer = transformer
    config.decoder = decoder

    return config

# Build the model from the R50+ViT-B/16 config and move it to the selected device.
config = get_r50_b16_config()
model = TransResUNet(config)
model.to(device)

# Loss and optimizer
# BCEWithLogitsLoss is applied to the raw model outputs; the validation code
# applies the sigmoid itself before thresholding.
criterion = nn.BCEWithLogitsLoss()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
# The scheduler is stepped once per training batch, so the schedule length is
# batches-per-epoch times the number of epochs, with a warmup lasting one
# epoch's worth of batches.
total_steps = len(train_dataloader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps = len(train_dataloader),
    num_training_steps = total_steps
)

Resized position embedding: torch.Size([1, 197, 768]) to torch.Size([1, 1601, 768])
Position embedding grid-size from [14, 14] to [40, 40]


## Show model structure and info

In [6]:
# Print a per-layer parameter-count summary of the model.
# `summary` is already imported in the imports cell; the import is repeated
# here so the cell can be run on its own.
from torchinfo import summary
summary(model)

Layer (type:depth-idx)                                       Param #
TransResUNet                                                 --
├─HybridVit: 1-1                                             --
│    └─Embeddings: 2-1                                       --
│    │    └─ResNetV2: 3-1                                    11,894,848
│    │    └─Conv2d: 3-2                                      787,200
│    │    └─Dropout: 3-3                                     --
│    └─Encoder: 2-2                                          --
│    │    └─ModuleList: 3-4                                  85,054,464
│    │    └─LayerNorm: 3-5                                   1,536
├─Conv2dReLU: 1-2                                            --
│    └─Conv2d: 2-3                                           3,538,944
│    └─BatchNorm2d: 2-4                                      1,024
│    └─ReLU: 2-5                                             --
├─ModuleList: 1-3                                            --
│

## Train and validation functions

In [7]:
def train_epoch(model, dataloader, loss_fn, optimizer, scheduler, device, writer=None, epoch_index=0):
    """Train ``model`` for one pass over ``dataloader``.

    For every batch the loss is computed, back-propagated with the gradient
    norm clipped to 1.0, and both the optimizer and the scheduler are
    stepped. When ``writer`` is given, the per-batch loss is logged under
    'Loss/train (per batch)' at a global step derived from ``epoch_index``.

    Returns the mean of the per-batch losses.
    """
    model.train()

    steps_per_epoch = len(dataloader)
    first_step = epoch_index * steps_per_epoch  # Global step of this epoch's first batch
    batch_losses = []

    progress = tqdm(dataloader, total=steps_per_epoch, desc="Training on batches")
    for step, batch in enumerate(progress, start=first_step):
        inputs = batch['pixel_values'].to(device)
        targets = batch['segmentations'].to(device)

        # Forward pass
        loss = loss_fn(model(inputs), targets)
        loss_value = loss.item()
        batch_losses.append(loss_value)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        # Clip the gradient norm to prevent gradient explosion
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()  # The learning rate is updated once per batch

        if writer is not None:
            writer.add_scalar('Loss/train (per batch)', loss_value, step)

    return np.mean(batch_losses)


def val_epoch(model, dataloader, loss_fn, device, writer=None, epoch_index=0):
    """Evaluate ``model`` on one pass over ``dataloader`` without gradients.

    For every batch the loss is computed on the raw outputs; the outputs are
    then passed through a sigmoid, thresholded at 0.5 and scored with
    ``dice_score``. When ``writer`` is given, the per-batch loss and Dice
    score are logged at a global step derived from ``epoch_index``.

    Returns a tuple ``(mean loss, mean Dice score)`` over the batches.
    """
    model.eval()

    steps_per_epoch = len(dataloader)
    first_step = epoch_index * steps_per_epoch  # Global step of this epoch's first batch
    batch_losses = []
    batch_scores = []

    with torch.no_grad():
        progress = tqdm(dataloader, total=steps_per_epoch, desc="Validation on batches")
        for step, batch in enumerate(progress, start=first_step):
            inputs = batch['pixel_values'].to(device)
            targets = batch['segmentations'].to(device)

            # Forward pass; the loss is taken on the raw outputs
            logits = model(inputs)
            loss = loss_fn(logits, targets)

            # Turn the outputs into a hard 0/1 mask before scoring
            predictions = (torch.sigmoid(logits) > 0.5).float()
            score = dice_score(targets, predictions)

            loss_value = loss.item()
            score_value = score.item()
            batch_losses.append(loss_value)
            batch_scores.append(score_value)

            if writer is not None:
                writer.add_scalar('Loss/val (per batch)', loss_value, step)
                writer.add_scalar('Dice/val (per batch)', score_value, step)

    return np.mean(batch_losses), np.mean(batch_scores)

## Train model

In [None]:
# Output locations on Google Drive for TensorBoard event files and checkpoints.
TENSORBOARD_DIR = '/content/drive/MyDrive/rucode_2021/segmentation/transunet-r50-b16-fullres/tensorboard'
CHECKPOINTS_DIR = '/content/drive/MyDrive/rucode_2021/segmentation/transunet-r50-b16-fullres/checkpoints'
# Create the checkpoint directory (and any missing parents); a no-op if it
# already exists.
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)

# Tensorboard
writer = SummaryWriter(log_dir=TENSORBOARD_DIR)

try:
    # Loop through each epoch.
    for epoch in tqdm(range(EPOCHS), desc="Epoch"):
        # Perform one full pass over the training and validation sets
        train_loss = train_epoch(model, train_dataloader, criterion, optimizer, scheduler, device, writer, epoch)
        val_loss, val_metric = val_epoch(model, val_dataloader, criterion, device, writer, epoch)

        # Populate tensorboard
        writer.add_scalar('Loss/train (per epoch)', train_loss, epoch)
        writer.add_scalar('Loss/val (per epoch)', val_loss, epoch)
        writer.add_scalar('Dice/val (per epoch)', val_metric, epoch)

        # Print loss and Dice values to see how training evolves.
        print(f'train_loss: {train_loss:.5f} - val_loss: {val_loss:.5f} - dice: {val_metric:.5f}\n')

        # Save a checkpoint after every epoch; model, optimizer and scheduler
        # state are stored together with the epoch index.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()
        }, f"{CHECKPOINTS_DIR}/epoch-{epoch}_vl_{val_loss:.5f}_dice_{val_metric:.5f}.pt")
finally:
    # Flush buffered events and close the event file even when training is
    # interrupted or fails, so the last logged scalars are not lost.
    writer.close()