In [1]:
# Huggingface
from transformers import AutoFeatureExtractor, AutoModel
from transformers import AdamW, get_linear_schedule_with_warmup

# Pytorch
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import Dataset, TensorDataset, random_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

# Visualization libraries
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rc

# Others
from typing import Tuple
from PIL import Image
import os
import glob
import math
import numpy as np
import pandas as pd
import random
from tqdm.notebook import tqdm

# Logging
from importlib import reload
import logging
# Reload the logging module before configuring it — presumably so that
# basicConfig takes effect even when the kernel has already installed handlers
# on the root logger (basicConfig does nothing in that case).
reload(logging)
# NOTE(review): '%I' is the 12-hour clock and no '%p' (AM/PM) is included, so
# timestamps are ambiguous; confirm whether '%H' (24-hour) was intended.
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%I:%M:%S')

# Configure visualization output: render figures inline at "retina" resolution.
%matplotlib inline
%config InlineBackend.figure_format='retina'
# Global seaborn theme; the 'muted' palette chosen here is replaced by the
# custom six-colour palette set on the following line.
sns.set(style='whitegrid', palette='muted', font_scale=1.2)
sns.set_palette(sns.color_palette(["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]))

# Make computations repeatable: seed Python's `random`, NumPy and PyTorch
# (the default generator plus every CUDA device) with the same value.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)

# Compute on GPU if available, otherwise fall back to the CPU.
# `device` is a plain string ("cuda" or "cpu").
device = "cuda" if torch.cuda.is_available() else "cpu"

# Run length encoding
def rle_encoding(x):
    """Run-length encode a binary mask.

    The mask is scanned column by column (top to bottom, then left to right)
    and every maximal run of pixels equal to 1 is reported as a pair
    ``start, length`` where ``start`` is the 1-based position of the first
    pixel of the run in that scan order.

    x: numpy array of shape (height, width), 1 - mask, 0 - background
    Returns a flat list ``[start_1, length_1, start_2, length_2, ...]``;
    an empty list when the mask contains no pixel equal to 1.
    """
    # Transposing before flattening yields column-major (Fortran) order.
    column_major = x.T.flatten()
    mask_positions = np.where(column_major == 1)[0]

    encoded = []
    last_position = -2  # sentinel: guarantees the first mask pixel opens a run
    for position in mask_positions:
        if position - last_position > 1:
            # Gap since the previous mask pixel: open a new run (1-based start)
            # with a length counter that is incremented just below.
            encoded.append(position + 1)
            encoded.append(0)
        encoded[-1] = encoded[-1] + 1
        last_position = position
    return encoded

# Dice score
def dice_score(y_true, y_pred):
    """Dice coefficient between a ground-truth mask and a prediction.

    Computes ``2 * sum(y_pred[y_true == 1]) / (sum(y_pred) + sum(y_true))``.

    Args:
        y_true: ground-truth tensor; elements equal to 1 mark the mask. It is
            used as a boolean index into ``y_pred``, so both tensors must have
            compatible shapes.
        y_pred: prediction tensor (assumed non-negative, e.g. a binary mask or
            probabilities — TODO confirm against callers).

    Returns:
        A 0-dim tensor holding the Dice score. When both ``y_true`` and
        ``y_pred`` sum to zero (empty mask, empty prediction) the score is
        defined as 1.0 — a perfect match — instead of the NaN that 0/0 would
        produce.
    """
    intersection = torch.sum(y_pred[y_true == 1])
    total = torch.sum(y_pred) + torch.sum(y_true)

    # Guard the 0/0 case without a Python-level branch (no host sync, and no
    # NaN leaking through the unused branch if gradients are ever taken).
    both_empty = total == 0
    safe_total = torch.where(both_empty, torch.ones_like(total), total)
    score = intersection * 2.0 / safe_total
    return torch.where(both_empty, torch.ones_like(score), score)

In [2]:
from models.trans_resunet import TransResUNet
import ml_collections
from torchinfo import summary

def get_r50_b16_config():
    """Build the R50 + ViT-B/16 hybrid configuration.

    Returns:
        An ``ml_collections.ConfigDict`` with the top-level fields
        ``image_size``, ``n_classes`` and ``pre_trained_path`` plus two nested
        sections, ``resnet`` and ``transformer``.
    """
    def _section(fields):
        # Populate a fresh ConfigDict attribute by attribute, in the given order.
        section = ml_collections.ConfigDict()
        for field_name, field_value in fields.items():
            setattr(section, field_name, field_value)
        return section

    # Using three bottleneck blocks results in a downscaling of 2^(1 + 3)=16,
    # i.e. an effective patch size of /16.
    resnet_section = _section({
        'num_layers': (3, 4, 9),
        'width_factor': 1,
    })

    transformer_section = _section({
        'num_special_tokens': 1,
        'patch_size': 16,
        'hidden_size': 768,
        'mlp_dim': 3072,
        'num_heads': 12,
        'num_layers': 12,
        'attention_dropout_rate': 0.0,
        'dropout_rate': 0.1,
    })

    return _section({
        'image_size': (480, 640),
        'n_classes': 1,
        'pre_trained_path': 'imagenet21k_R50+ViT-B_16.npz',
        'resnet': resnet_section,
        'transformer': transformer_section,
    })

def get_r50_l32_config():
    """Build the R50 + ViT-L/32 hybrid configuration.

    Returns:
        An ``ml_collections.ConfigDict`` with the top-level fields
        ``image_size``, ``n_classes`` and ``pre_trained_path`` plus two nested
        sections, ``resnet`` and ``transformer``.
    """
    def _section(fields):
        # Populate a fresh ConfigDict attribute by attribute, in the given order.
        section = ml_collections.ConfigDict()
        for field_name, field_value in fields.items():
            setattr(section, field_name, field_value)
        return section

    # Using four bottleneck blocks results in a downscaling of 2^(1 + 4)=32,
    # i.e. an effective patch size of /32.
    resnet_section = _section({
        'num_layers': (3, 4, 6, 3),
        'width_factor': 1,
    })

    transformer_section = _section({
        'num_special_tokens': 1,
        'patch_size': 32,
        'hidden_size': 1024,
        'mlp_dim': 4096,
        'num_heads': 16,
        'num_layers': 24,
        'attention_dropout_rate': 0.0,
        'dropout_rate': 0.1,
    })

    return _section({
        'image_size': (480, 640),
        'n_classes': 1,
        'pre_trained_path': 'imagenet21k_R50+ViT-L_32.npz',
        'resnet': resnet_section,
        'transformer': transformer_section,
    })

In [3]:
# Select the model configuration. The R50 + ViT-L/32 variant is kept as a
# commented-out alternative; the R50 + ViT-B/16 variant is the one in use.
# config = get_r50_l32_config()
config = get_r50_b16_config()
# Instantiate the network from the chosen configuration.
net = TransResUNet(config)

Resized position embedding: torch.Size([1, 197, 768]) to torch.Size([1, 1601, 768])
Position embedding grid-size from [14, 14] to [40, 40]


In [6]:
# Print a torchinfo layer-by-layer summary of `net` for an input of size
# (4, 3, 480, 640), expanding nested modules down to depth 5.
summary(net, (4, 3, 480, 640), depth=5)

Layer (type:depth-idx)                                       Output Shape              Param #
TransResUNet                                                 --                        --
├─HybridVit: 1                                               --                        --
│    └─Encoder: 2                                            --                        --
│    │    └─ModuleList: 3-1                                  --                        --
├─ModuleList: 1-1                                            --                        --
├─HybridVit: 1-2                                             [4, 1201, 768]            --
│    └─Embeddings: 2-1                                       [4, 1201, 768]            --
│    │    └─ResNetV2: 3-2                                    [4, 1024, 30, 40]         --
│    │    │    └─Sequential: 4-1                             [4, 64, 240, 320]         --
│    │    │    │    └─StdConv2d: 5-1                         [4, 64, 240, 320]         9,408
│ 