In [None]:
%load_ext autoreload
%autoreload 2

import os
os.environ["CUDA_VISIBLE_DEVICES"]="1,0"

In [None]:
import torch
from torch import nn
from tqdm import tqdm
import os, random, gc
import numpy as np
from src.fastai_fix import *
import tifffile
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import albumentations as A
import cv2

In [None]:
class CONFIG:
    """Notebook-wide settings shared by the seeding, data and training cells."""
    # Filesystem locations: dataset root and the directory handed to the Learner.
    path = 'data/'
    out = 'experiments/init'

    # Reproducibility and data-loading knobs.
    seed = 2023
    bs = 8
    num_workers = 12
    
def seed_everything(seed):
    """Seed every random number generator used by the notebook.

    Seeds Python's `random`, NumPy and PyTorch (CPU and all CUDA devices),
    exports `PYTHONHASHSEED`, and puts cuDNN into deterministic mode.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (the notebook exposes two).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick a different convolution algorithm on each
    # run, which defeats `deterministic = True`; it has to be off for
    # reproducible results (at some cost in speed).
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once, before any dataset or model is built.
seed_everything(CONFIG.seed)
# Make sure the experiment directory exists; a no-op when it is already there.
os.makedirs(CONFIG.out, exist_ok=True)

def WrapperAdamW(param_groups,**kwargs):
    """Optimizer factory that wraps `torch.optim.AdamW` in an `OptimWrapper`.

    Args:
        param_groups: parameters / parameter groups to optimise.
        **kwargs: optimizer hyper-parameters (e.g. `eps` bound through
            `partial`). They are forwarded to `OptimWrapper`; previously they
            were accepted but silently discarded.

    Returns:
        The `OptimWrapper` instance.
    """
    return OptimWrapper(param_groups,torch.optim.AdamW,**kwargs)

from src.radam import Over9000
def WrapperOver9000(param_groups,**kwargs):
    """Optimizer factory that wraps `Over9000` in an `OptimWrapper`.

    Args:
        param_groups: parameters / parameter groups to optimise.
        **kwargs: optimizer hyper-parameters (e.g. `eps` bound through
            `partial`). They are forwarded to `OptimWrapper`; previously they
            were accepted but silently discarded.

    Returns:
        The `OptimWrapper` instance.
    """
    return OptimWrapper(param_groups,opt=Over9000,**kwargs)

In [None]:
def img2tensor(img,dtype:np.dtype=np.float32):
    """Convert an HxW or HxWxC NumPy image into a CxHxW torch tensor.

    A 2-D input gets a trailing channel axis first. The array is cast to
    `dtype` without copying when it already has that dtype.
    """
    if img.ndim == 2:
        img = img[:, :, None]
    channels_first = img.transpose(2, 0, 1)
    return torch.from_numpy(channels_first.astype(dtype, copy=False))

def get_aug():
    """Build the training augmentation pipeline.

    Three stages are composed in order: a random shift/scale/rotate (p=0.75),
    one of gamma or brightness/contrast jitter (p=0.5), and one of motion or
    Gaussian blur (p=0.25).
    """
    geometric = A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.75)
    intensity = A.OneOf([
        A.RandomGamma(gamma_limit=(50, 150), always_apply=True),
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.2, always_apply=True),
    ], p=0.5)
    blur = A.OneOf([
        A.MotionBlur(always_apply=True),
        A.GaussianBlur(always_apply=True),
    ], p=0.25)
    return A.Compose([geometric, intensity, blur], p=1)

class ContrailsDataset(Dataset):
    """Image-sequence / mask pairs read from TIFF files.

    Samples live in `<path>/train_adj2` (train) or `<path>/val_adj2`
    (validation). A file is an input image when the last underscore-separated
    token of its stem is `img`; the mask file name is obtained by replacing
    `img` with `mask`.

    Args:
        path: dataset root directory.
        train: selects the train or the validation sub-directory.
        tfms: optional callable invoked as `tfms(image=..., mask=...)` that
            returns a dict with 'image' and 'mask' entries.
        repeat: the dataset length is multiplied by this factor; indices wrap
            around the file list.
    """
    def __init__(self, path, train=True, tfms=None, repeat=1):
        split_dir = 'train_adj2' if train else 'val_adj2'
        self.path = os.path.join(path, split_dir)
        self.fnames = sorted(
            entry for entry in os.listdir(self.path)
            if entry.split('.')[0].split('_')[-1] == 'img')
        self.train = train
        self.tfms = tfms
        self.nc = 3
        self.repeat = repeat

    def __len__(self):
        return len(self.fnames) * self.repeat

    def __getitem__(self, idx):
        # Wrap the index so `repeat` > 1 cycles through the same files.
        fname = self.fnames[idx % len(self.fnames)]
        img = tifffile.imread(os.path.join(self.path, fname))

        # Split the channel axis into (nc, frames), keep at most the first
        # five frames, then merge the two axes back together.
        hw = img.shape[:2]
        img = img.reshape(*hw, self.nc, -1)[:, :, :, :5]
        img = img.reshape(*hw, -1)
        mask = tifffile.imread(os.path.join(self.path, fname.replace('img','mask')))

        if self.tfms is not None:
            out = self.tfms(image=img, mask=mask)
            img = out['image']
            mask = out['mask']

        # Only the image is upsampled 2x; the mask keeps its resolution.
        doubled = (2*img.shape[1], 2*img.shape[0])
        img = cv2.resize(img, doubled, interpolation=cv2.INTER_CUBIC)
        img = img2tensor(img/255)
        mask = img2tensor(mask/255)
        # Restore the separate (nc, frames) axes in front of the spatial axes.
        img = img.view(self.nc, -1, *img.shape[1:])

        return img, mask

In [None]:
class F_th(Metric):
    """F-beta score maximised over a sweep of probability thresholds.

    True-positive, false-positive and false-negative counts are accumulated
    per threshold across batches; `value` returns the best score of the sweep.

    Args:
        ths: iterable of thresholds applied to the sigmoid of the predictions.
        beta: beta of the F-beta score (1 gives F1 / Dice).
    """
    def __init__(self, ths=np.arange(0.1,0.9,0.01), beta=1):
        self.beta = beta
        self.ths = ths

    def reset(self):
        # One counter per threshold.
        n_ths = len(self.ths)
        self.tp, self.fp, self.fn = (torch.zeros(n_ths) for _ in range(3))

    def accumulate(self, learn):
        probs = torch.sigmoid(learn.pred.float())
        truth = (learn.y > 0.5).float()
        # presumably flattens both tensors and checks that their sizes agree
        probs, truth = flatten_check(probs, truth)
        for slot, cutoff in enumerate(self.ths):
            hit = (probs > cutoff).float()
            self.tp[slot] += (hit*truth).float().sum().item()
            self.fp[slot] += (hit*(1-truth)).float().sum().item()
            self.fn[slot] += ((1-hit)*truth).float().sum().item()

    @property
    def value(self):
        # The per-threshold curve is kept on the instance as `dices`.
        b2 = self.beta**2
        self.dices = (1 + b2)*self.tp/((1 + b2)*self.tp + b2*self.fn + self.fp + 1e-6)
        return self.dices.max()

from src.lovasz import lovasz_hinge
def loss_comb(x,y):
    """BCE-with-logits plus a lightly weighted symmetric Lovasz hinge term.

    The Lovasz hinge is evaluated on (x, y) and on the flipped problem
    (-x, 1-y); their sum is scaled by 0.01*0.5 and added to the BCE loss.
    """
    bce = F.binary_cross_entropy_with_logits(x,y)
    lovasz_fg = lovasz_hinge(x,y,per_image=False)
    lovasz_bg = lovasz_hinge(-x,1-y,per_image=False)
    return bce + 0.01*0.5*(lovasz_fg + lovasz_bg)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from src.coat import CoaT,coat_lite_mini,coat_lite_small,coat_lite_medium
    
class LayerNorm2d(nn.Module):
    """Layer normalisation over the channel axis (dim 1) of an NCHW tensor.

    Each spatial position is normalised across its channels, then scaled and
    shifted by learnable per-channel `weight` and `bias`.
    """
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        # Biased variance across channels.
        var = centered.pow(2).mean(1, keepdim=True)
        normed = centered / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
    
def icnr_init(x, scale=2, init=nn.init.kaiming_normal_):
    "ICNR init of `x`, with `scale` and `init` function"
    # `x` is a conv weight; its first axis is split into groups of scale**2
    # consecutive entries that end up sharing the same initial values.
    out_ch, in_ch, kh, kw = x.shape
    replicas = scale**2
    base_ch = int(out_ch/replicas)
    seed = init(x.new_zeros([base_ch, in_ch, kh, kw]))
    seed = seed.transpose(0, 1).contiguous().view(base_ch, in_ch, -1)
    tiled = seed.repeat(1, 1, replicas).contiguous()
    return tiled.view([in_ch, out_ch, kh, kw]).transpose(0, 1)

class PixelShuffle_ICNR(nn.Sequential):
    """Upsample by `scale` with a 1x1 conv, LayerNorm2d, GELU and PixelShuffle.

    Args:
        ni: input channels.
        nf: output channels after the shuffle; defaults to `ni`.
        scale: upsampling factor of the pixel shuffle.
        blur: when true, append a replication pad and a 2x2 stride-1 average
            pool after the shuffle.
    """
    def __init__(self, ni, nf=None, scale=2, blur=True):
        nf = ni if nf is None else nf
        layers = [nn.Conv2d(ni, nf*(scale**2), 1), LayerNorm2d(nf*(scale**2)),
                  nn.GELU(), nn.PixelShuffle(scale)]
        # The ICNR initialisation has to use the same factor as the shuffle;
        # it used to fall back to icnr_init's default of 2 regardless of `scale`.
        layers[0].weight.data.copy_(icnr_init(layers[0].weight.data, scale=scale))
        if blur: layers += [nn.ReplicationPad2d((1,0,1,0)), nn.AvgPool2d(2, stride=1)]
        # Single base-class initialisation, once the layer list is complete.
        super().__init__(*layers)
    
class FPN(nn.Module):
    """Projects several feature maps, upsamples them and concatenates them.

    Each input gets its own conv(3x3) -> GELU -> LayerNorm2d -> conv(3x3)
    branch mapping `input_channels[i]` to `output_channels[i]`.
    """
    def __init__(self, input_channels:list, output_channels:list):
        super().__init__()
        branches = []
        for in_ch, out_ch in zip(input_channels, output_channels):
            wide = out_ch*2
            branches.append(nn.Sequential(
                nn.Conv2d(in_ch, wide, kernel_size=3, padding=1),
                nn.GELU(),
                LayerNorm2d(wide),
                nn.Conv2d(wide, out_ch, kernel_size=3, padding=1)))
        self.convs = nn.ModuleList(branches)

    def forward(self, xs:list, last_layer):
        # Branch i is upsampled by 2**(n_branches - i), so earlier entries of
        # `xs` are enlarged the most; `last_layer` is appended unchanged.
        n_levels = len(self.convs)
        pieces = []
        for level, (conv, feat) in enumerate(zip(self.convs, xs)):
            pieces.append(F.interpolate(conv(feat), scale_factor=2**(n_levels-level), mode='bilinear'))
        pieces.append(last_layer)
        return torch.cat(pieces, dim=1)

class UnetBlock(nn.Module):
    """Decoder step: upsample the deeper map, merge it with a skip connection.

    Args:
        up_in_c: channels of the deeper feature map (halved by the shuffle).
        x_in_c: channels of the skip connection.
        nf: output channels; defaults to max(up_in_c//2, 32).
        blur: forwarded to `PixelShuffle_ICNR`.
        **kwargs: forwarded to `PixelShuffle_ICNR`.
    """
    def __init__(self, up_in_c:int, x_in_c:int, nf:int=None, blur:bool=False,
                 **kwargs):
        super().__init__()
        half = up_in_c//2
        self.shuf = PixelShuffle_ICNR(up_in_c, half, blur=blur, **kwargs)
        self.bn = LayerNorm2d(x_in_c)
        if nf is None:
            nf = max(half, 32)
        self.conv1 = nn.Sequential(nn.Conv2d(half + x_in_c, nf, 3, padding=1), nn.GELU())
        self.conv2 = nn.Sequential(nn.Conv2d(nf, nf, 3, padding=1), nn.GELU())
        self.relu = nn.GELU()

    def forward(self, up_in:torch.Tensor, left_in:torch.Tensor) -> torch.Tensor:
        upsampled = self.shuf(up_in)
        skip = self.bn(left_in)
        # The activation is applied to the concatenation, before the two convs.
        merged = self.relu(torch.cat([upsampled, skip], dim=1))
        hidden = self.conv1(merged)
        return self.conv2(hidden)
    
class UpBlock(nn.Module):
    """Pixel-shuffle upsampling followed by a small convolutional head.

    Args:
        up_in_c: input channels; the shuffle reduces them to up_in_c//4.
        nf: output channels; defaults to max(up_in_c//4, 16).
        blur: forwarded to `PixelShuffle_ICNR`.
        **kwargs: forwarded to `PixelShuffle_ICNR`.
    """
    def __init__(self, up_in_c:int, nf:int=None, blur:bool=True,
                 **kwargs):
        super().__init__()
        quarter = up_in_c//4
        self.shuf = PixelShuffle_ICNR(up_in_c, quarter, blur=blur, **kwargs)
        if nf is None:
            nf = max(quarter, 16)
        # The normalisation is skipped for very narrow feature maps.
        norm = LayerNorm2d(quarter) if quarter >= 16 else nn.Identity()
        self.conv = nn.Sequential(nn.Conv2d(quarter, quarter, 3, padding=1),
                                  norm,
                                  nn.GELU(),
                                  nn.Conv2d(quarter, nf, 1))

    def forward(self, up_in:torch.Tensor) -> torch.Tensor:
        upsampled = self.shuf(up_in)
        return self.conv(upsampled)
    
class LSTM_block(nn.Module):
    """Runs an LSTM along dim 1 of a 5-D tensor, independently per pixel.

    Args:
        n: feature size of the input (dim 2); also the output feature size.
        bidirectional: use a bidirectional LSTM with n//2 hidden units per direction.
        num_layers: number of stacked LSTM layers.
        **kwargs: accepted and ignored.
    """
    def __init__(self, n, bidirectional=False, num_layers=1, **kwargs):
        super().__init__()
        hidden = n//2 if bidirectional else n
        self.lstm = nn.LSTM(n, hidden, batch_first=True,
                            bidirectional=bidirectional, num_layers=num_layers)

    def forward(self,x):
        _, steps, feats, height, width = x.shape
        # (B,T,C,H,W) -> (B*H*W, T, C): every pixel becomes its own sequence.
        seqs = x.flatten(-2,-1).permute(0,3,1,2).flatten(0,1)
        out, _ = self.lstm(seqs)
        # Back to (B,T,C,H,W).
        out = out.view(-1, height, width, steps, feats)
        return out.permute(0,3,4,1,2)
    
class CoaT_ULSTM(nn.Module):
    """CoaT encoder, per-pixel LSTM temporal mixing and a U-Net/FPN decoder.

    NOTE(review): a second class with the same name is defined further down in
    this cell; once the cell has run, that later definition replaces this one.

    The input to `forward` is 5-D with the frame axis at dim 2. Every frame is
    encoded separately. The two shallowest encoder levels keep only the last
    frame, the two deepest levels are passed through an `LSTM_block` and then
    reduced to the last frame.

    Args:
        pre: path to a checkpoint whose 'model' entry is loaded into the
            encoder (non-strict); `None` skips loading.
        arch: encoder size, one of 'mini', 'small' or 'medium'.
        num_classes: number of output channels.
        ps: probability of the `Dropout2d` applied before the final head.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if `arch` is not one of the supported names (it is a
            subclass of the `Exception` raised before, so existing handlers
            keep working).
    """
    def __init__(self, pre='coat_lite_medium_a750cd63.pth', arch='medium', num_classes=1, ps=0, **kwargs):
        super().__init__()
        if arch == 'mini':
            self.enc = coat_lite_mini(return_interm_layers=True)
            nc = [64,128,320,512]
        elif arch == 'small':
            self.enc = coat_lite_small(return_interm_layers=True)
            nc = [64,128,320,512]
        elif arch == 'medium':
            self.enc = coat_lite_medium(return_interm_layers=True)
            nc = [128,256,320,512]
        else: raise ValueError('Unknown model')

        if pre is not None:
            # Load onto the CPU so the checkpoint does not depend on the device
            # it was saved from; load_state_dict copies into the existing params.
            sd = torch.load(pre, map_location='cpu')['model']
            # Non-strict load: the printed result lists missing/unexpected keys.
            print(self.enc.load_state_dict(sd,strict=False))

        self.lstm = nn.ModuleList([LSTM_block(nc[-2]),LSTM_block(nc[-1])])
        self.dec4 = UnetBlock(nc[-1],nc[-2],384)
        self.dec3 = UnetBlock(384,nc[-3],192)
        self.dec2 = UnetBlock(192,nc[-4],96)
        self.fpn = FPN([nc[-1],384,192],[32]*3)
        self.drop = nn.Dropout2d(ps)
        #self.final_conv = nn.Conv2d(96+32*3, num_classes, 3, padding=1)
        #self.final_conv = nn.Sequential(UpBlock(96+32*3, 32),
        #                                UpBlock(32, num_classes, blur=True))
        self.final_conv = nn.Sequential(UpBlock(96+32*3, num_classes, blur=True))
        # Extra bilinear scale applied to the output; 0 disables the step.
        self.up_result=1

    def forward(self, x):
        nt = x.shape[2]
        # Fold the frame axis into the batch so every frame is encoded alone.
        x = x.permute(0,2,1,3,4).flatten(0,1)
        encs = self.enc(x)
        # The encoder returns a mapping; keep its values in iteration order.
        encs = [encs[k] for k in encs]
        # view(-1, nt, ...) restores the frame axis; [:,-1] selects the last frame.
        encs = [encs[0].view(-1,nt,*encs[0].shape[1:])[:,-1],
                encs[1].view(-1,nt,*encs[1].shape[1:])[:,-1],
                self.lstm[-2](encs[2].view(-1,nt,*encs[2].shape[1:]))[:,-1],
                self.lstm[-1](encs[3].view(-1,nt,*encs[3].shape[1:]))[:,-1]]
        dec4 = encs[-1]
        dec3 = self.dec4(dec4,encs[-2])
        dec2 = self.dec3(dec3,encs[-3])
        dec1 = self.dec2(dec2,encs[-4])
        x = self.fpn([dec4, dec3, dec2], dec1)
        x = self.final_conv(self.drop(x))
        if self.up_result != 0: x = F.interpolate(x,scale_factor=self.up_result,mode='bilinear')
        return x

    # Parameter groups for discriminative learning rates: [encoder, everything
    # else]. A DataParallel wrapper is unwrapped first.
    split_layers = lambda model: \
                 (lambda m: [list(m.enc.parameters()),
                 list(m.lstm.parameters()) +
                 list(m.dec4.parameters()) +
                 list(m.dec3.parameters()) + list(m.dec2.parameters()) +
                 list(m.fpn.parameters()) +
                 list(m.final_conv.parameters())]) \
                 (model if not isinstance(model, nn.DataParallel) else model.module)
    
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import math
from timm.models.layers import drop_path, to_2tuple, trunc_normal_

class DropPath(nn.Module):
    """Module wrapper around the functional `drop_path`.

    `drop_prob` and the module's training flag are passed through to
    `drop_path` on every forward call.
    """
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f'p={self.drop_prob}'
    
class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear -> Dropout.

    Args:
        in_features: input width.
        hidden_features: hidden width; falls back to `in_features`.
        out_features: output width; falls back to `in_features`.
        act_layer: activation class, instantiated without arguments.
        drop: dropout probability applied after the second linear layer.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Dropout is deliberately not applied after the activation (left out,
        # as in the original BERT implementation); only the output is dropped.
        hidden = self.act(self.fc1(x))
        return self.drop(self.fc2(hidden))

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar positions.

    For an input of shape (...,) the output has shape (..., 2*(dim//2)): the
    sines of all frequencies followed by the cosines.

    Args:
        dim: embedding size.
        M: base of the geometric frequency progression.
    """
    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        n_freqs = self.dim // 2
        log_step = math.log(self.M) / n_freqs
        # Frequencies exp(-i * log(M) / n_freqs) for i = 0 .. n_freqs-1.
        freqs = torch.exp(torch.arange(n_freqs, device=x.device) * (-log_step))
        angles = x[...,None] * freqs[None,...]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
    
#BEiTv2 block
class Block(nn.Module):
    """Pre-norm cross-attention block followed by an MLP, both residual.

    Query, key and value are separate inputs and all three are normalised
    with the same `norm1`. The residual connection is taken from the query.

    Args:
        dim: embedding size.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of `dim`.
        drop: dropout used both inside the attention and in the MLP.
        drop_path: stochastic-depth rate; 0 replaces it with an identity.
        init_values: if given, initial value of the learnable per-channel
            scales `gamma_1` / `gamma_2` applied to the two branches.
        act_layer: activation class for the MLP.
        norm_layer: normalisation class.
        qkv_bias, qk_scale, attn_drop, **kwargs: accepted but not used.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, **kwargs):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=drop, batch_first=True)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)

        if init_values is None:
            self.gamma_1, self.gamma_2 = None, None
        else:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)

    def forward(self, xq, xk, xv, attn_mask=None, key_padding_mask=None):
        scaled = self.gamma_1 is not None

        attended = self.attn(self.norm1(xq), self.norm1(xk), self.norm1(xv),
                             attn_mask=attn_mask,
                             key_padding_mask=key_padding_mask,
                             need_weights=False)[0]
        if scaled:
            attended = self.gamma_1 * attended
        x = xq + self.drop_path(attended)

        ff = self.mlp(self.norm2(x))
        if scaled:
            ff = self.gamma_2 * ff
        return x + self.drop_path(ff)
    
class Tmixer(nn.Module):
    """Temporal mixer: the last frame attends over all frames and positions.

    The input is (B, N, C, H, W) with N frames. Queries are the tokens of the
    last frame. Keys are the tokens of every frame with a sinusoidal frame
    embedding added. Values are the same tokens without that embedding. The
    output is (B, C, H, W).

    Args:
        n: channel count C; each block uses n//64 attention heads.
        head_size: accepted but not used.
        num_layers: number of stacked `Block`s.
        **kwargs: accepted and ignored.
    """
    def __init__(self, n, head_size=32, num_layers=2, **kwargs):
        super().__init__()
        self.seq_enc = SinusoidalPosEmb(n)
        self.blocks = nn.ModuleList([Block(n, n//64) for _ in range(num_layers)])

    def forward(self,x):
        B,N,C,H,W = x.shape
        # (B,N,C,H,W) -> (B,N,H*W,C): one token per frame and pixel.
        tokens = x.flatten(-2,-1).permute(0,1,3,2)

        frame_ids = torch.arange(N, device=tokens.device)
        enc = self.seq_enc(frame_ids).view(1,N,1,C)
        query = tokens[:,-1] + enc[:,-1]
        keys = (tokens + enc).flatten(1,2)
        values = tokens.flatten(1,2)

        # Only the query is updated; keys and values are reused by every block.
        for blk in self.blocks:
            query = blk(query, keys, values)

        return query.view(B,H,W,C).permute(0,3,1,2)
    
class CoaT_ULSTM(nn.Module):
    """CoaT encoder, transformer temporal mixing and a U-Net/FPN decoder.

    NOTE(review): despite the name there is no LSTM in this variant, and it
    replaces the same-named class defined earlier in this cell.

    The input to `forward` is 5-D with the frame axis at dim 2. Every frame is
    encoded separately. The two shallowest encoder levels keep only the last
    frame, the two deepest levels are reduced over time by a `Tmixer`.

    Args:
        pre: path to a checkpoint whose 'model' entry is loaded into the
            encoder (non-strict); `None` skips loading.
        arch: encoder size, one of 'mini', 'small' or 'medium'.
        num_classes: number of output channels.
        ps: probability of the `Dropout2d` applied before the final head.
        num_layers: number of attention blocks in each `Tmixer`.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if `arch` is not one of the supported names (it is a
            subclass of the `Exception` raised before, so existing handlers
            keep working).
    """
    def __init__(self, pre='coat_lite_medium_a750cd63.pth', arch='medium', num_classes=1, ps=0,
                 num_layers=2, **kwargs):
        super().__init__()
        if arch == 'mini':
            self.enc = coat_lite_mini(return_interm_layers=True)
            nc = [64,128,320,512]
        elif arch == 'small':
            self.enc = coat_lite_small(return_interm_layers=True)
            nc = [64,128,320,512]
        elif arch == 'medium':
            self.enc = coat_lite_medium(return_interm_layers=True)
            nc = [128,256,320,512]
        else: raise ValueError('Unknown model')

        if pre is not None:
            # Load onto the CPU so the checkpoint does not depend on the device
            # it was saved from; load_state_dict copies into the existing params.
            sd = torch.load(pre, map_location='cpu')['model']
            # Non-strict load: the printed result lists missing/unexpected keys.
            print(self.enc.load_state_dict(sd,strict=False))

        self.mixer = nn.ModuleList([Tmixer(nc[-2],num_layers=num_layers),
                                    Tmixer(nc[-1],num_layers=num_layers)])
        self.dec4 = UnetBlock(nc[-1],nc[-2],384)
        self.dec3 = UnetBlock(384,nc[-3],192)
        self.dec2 = UnetBlock(192,nc[-4],96)
        self.fpn = FPN([nc[-1],384,192],[32]*3)
        self.drop = nn.Dropout2d(ps)
        #self.final_conv = nn.Conv2d(96+32*3, num_classes, 3, padding=1)
        #self.final_conv = nn.Sequential(UpBlock(96+32*3, 32),
        #                                UpBlock(32, num_classes, blur=True))
        self.final_conv = nn.Sequential(UpBlock(96+32*3, num_classes, blur=True))
        # Extra bilinear scale applied to the output; 0 disables the step.
        self.up_result=1

    def forward(self, x):
        nt = x.shape[2]
        # Fold the frame axis into the batch so every frame is encoded alone.
        x = x.permute(0,2,1,3,4).flatten(0,1)
        encs = self.enc(x)
        # The encoder returns a mapping; keep its values in iteration order.
        encs = [encs[k] for k in encs]
        # view(-1, nt, ...) restores the frame axis. The shallow levels take the
        # last frame directly; the mixers already return a single frame.
        encs = [encs[0].view(-1,nt,*encs[0].shape[1:])[:,-1],
                encs[1].view(-1,nt,*encs[1].shape[1:])[:,-1],
                self.mixer[-2](encs[2].view(-1,nt,*encs[2].shape[1:])),
                self.mixer[-1](encs[3].view(-1,nt,*encs[3].shape[1:]))]
        dec4 = encs[-1]
        dec3 = self.dec4(dec4,encs[-2])
        dec2 = self.dec3(dec3,encs[-3])
        dec1 = self.dec2(dec2,encs[-4])
        x = self.fpn([dec4, dec3, dec2], dec1)
        x = self.final_conv(self.drop(x))
        if self.up_result != 0: x = F.interpolate(x,scale_factor=self.up_result,mode='bilinear')
        return x

    # Parameter groups for discriminative learning rates: [encoder, everything
    # else]. A DataParallel wrapper is unwrapped first.
    split_layers = lambda model: \
                 (lambda m: [list(m.enc.parameters()),
                 list(m.mixer.parameters()) +
                 list(m.dec4.parameters()) +
                 list(m.dec3.parameters()) + list(m.dec2.parameters()) +
                 list(m.fpn.parameters()) +
                 list(m.final_conv.parameters())]) \
                 (model if not isinstance(model, nn.DataParallel) else model.module)

In [None]:
# Experiment 1: default CoaT_ULSTM (num_layers=2 temporal mixer).
# Final weights are written to OUT/<fname>_<fold>.pth; note that OUT differs
# from CONFIG.out, which is the directory handed to the Learner.
OUT = 'experiments'
fname = 'Seq_CoaT_512_1'

# Single "fold": the loop only runs once, the train/val split is fixed on disk.
for fold in range(1):
    ds_train = ContrailsDataset(CONFIG.path, train=True, tfms=get_aug())
    ds_val = ContrailsDataset(CONFIG.path, train=False, tfms=None)
    
    # Build on the GPU, then replicate across the visible devices.
    model = CoaT_ULSTM().cuda()
    model = nn.DataParallel(model)
    
    data = ImageDataLoaders.from_dsets(ds_train,
                                   ds_val,
                                   bs=CONFIG.bs,
                                   num_workers=CONFIG.num_workers,
                                   pin_memory=True
                                  ).cuda()

    # NOTE(review): stock fastai's GradientAccumulation counts *samples*
    # (n_acc), not batches. Unless src.fastai_fix changes that, 32//CONFIG.bs
    # is smaller than the batch size and no accumulation takes place — confirm.
    # `partial` is presumably provided by the star import of src.fastai_fix.
    learn = Learner(data, 
                    model,
                    path = CONFIG.out, 
                    loss_func=loss_comb,
                    metrics=[F_th()],
                    cbs=[
                    GradientClip(3.0),
                    GradientAccumulation(32//CONFIG.bs),
                    CSVLogger(),
                    SaveModelCallback(monitor='f_th'),
                    ],
                    opt_func=partial(WrapperOver9000,eps=1e-4),
                    #opt_func=partial(WrapperAdamW,eps=1e-4),
                    #splitter = SwinFormer.split_layers
                   ).to_fp16()
    
    learn.fit_one_cycle(24, lr_max=3.5e-4, pct_start=0.1)
    # `.module` unwraps DataParallel so the checkpoint keys carry no prefix.
    torch.save(learn.model.module.state_dict(),os.path.join(OUT,f'{fname}_{fold}.pth'))
    
    # `learn` is intentionally kept alive: the next cell reads its metric.
    #del learn, data, ds_train, ds_val
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Rebuild the per-threshold F-beta curve from the counters held by the first
# metric of the most recent `learn`, and plot it with the best threshold marked.
metric = learn.metrics[0]
ths = metric.ths
beta_sq = metric.beta**2
weighted_tp = (1 + beta_sq)*metric.tp
dices = weighted_tp/(weighted_tp + beta_sq*metric.fn + metric.fp + 1e-6)

best_dice = dices.max()
best_thr = ths[dices.argmax()]
# Vertical extent of the curve, used to place the text labels.
d = best_dice - dices.min()
label_x = ths[-1]-0.1

plt.figure(figsize=(8,4))
plt.plot(ths, dices, color='blue')
plt.vlines(x=best_thr, ymin=dices.min(), ymax=best_dice, colors='black')
plt.text(label_x, best_dice-0.1*d, f'DICE = {best_dice:.3f}', fontsize=12);
plt.text(label_x, best_dice-0.2*d, f'TH = {best_thr:.3f}', fontsize=12);
plt.show()

In [None]:
#Transformer mixer, 4 layers
# Experiment 2: same pipeline as the first training cell, but with
# num_layers=4 in the temporal mixer. Final weights go to OUT/<fname>_<fold>.pth.
OUT = 'experiments'
fname = 'Seq_CoaT_512_2'

# Single "fold": the loop only runs once, the train/val split is fixed on disk.
for fold in range(1):
    ds_train = ContrailsDataset(CONFIG.path, train=True, tfms=get_aug())
    ds_val = ContrailsDataset(CONFIG.path, train=False, tfms=None)
    
    # Build on the GPU, then replicate across the visible devices.
    model = CoaT_ULSTM(num_layers=4).cuda()
    model = nn.DataParallel(model)
    
    data = ImageDataLoaders.from_dsets(ds_train,
                                   ds_val,
                                   bs=CONFIG.bs,
                                   num_workers=CONFIG.num_workers,
                                   pin_memory=True
                                  ).cuda()

    # NOTE(review): stock fastai's GradientAccumulation counts *samples*
    # (n_acc), not batches. Unless src.fastai_fix changes that, 32//CONFIG.bs
    # is smaller than the batch size and no accumulation takes place — confirm.
    # NOTE(review): SaveModelCallback is given no file name here, and `path` is
    # the same CONFIG.out as in the first training cell — check that this run
    # does not overwrite that run's best checkpoint.
    learn = Learner(data, 
                    model,
                    path = CONFIG.out, 
                    loss_func=loss_comb,
                    metrics=[F_th()],
                    cbs=[
                    GradientClip(3.0),
                    GradientAccumulation(32//CONFIG.bs),
                    CSVLogger(),
                    SaveModelCallback(monitor='f_th'),
                    ],
                    opt_func=partial(WrapperOver9000,eps=1e-4),
                    #opt_func=partial(WrapperAdamW,eps=1e-4),
                    #splitter = SwinFormer.split_layers
                   ).to_fp16()
    
    learn.fit_one_cycle(24, lr_max=3.5e-4, pct_start=0.1)
    # `.module` unwraps DataParallel so the checkpoint keys carry no prefix.
    torch.save(learn.model.module.state_dict(),os.path.join(OUT,f'{fname}_{fold}.pth'))
    
    #del learn, data, ds_train, ds_val
    gc.collect()
    torch.cuda.empty_cache()