In [None]:
!pip install -r requirements.txt

In [1]:
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import cv2

import numpy as np

import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.optim import Adam, SGD, AdamW

from torch.utils.data import DataLoader, Dataset
import cv2
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

from externals.utils import set_seed, make_dirs, cfg_init
from externals.dataloading import read_image_mask, read_image_mask_downsampling, get_train_valid_dataset, get_transforms, CustomDataset
from externals.models import CNN3D_Segformer, Unet3D_Segformer, CNN3D_Unet, CNN3D_MANet, CNN3D_EfficientUnetplusplusb5, CNN3D_SegformerB4
from externals.metrics import AverageMeter, calc_fbeta
from externals.training_procedures import get_scheduler, scheduler_step, criterion
from torch.optim.swa_utils import AveragedModel, SWALR

In [2]:
class CFG:
    """Experiment configuration, read directly as class attributes.

    All output locations are derived from ``comp_name`` and ``exp_name`` so that
    every artefact of one experiment ends up under ``outputs_path``.
    """

    # ============== competition / data =============
    is_multiclass = False
    comp_name = 'vesuvius'
    comp_dir_path = './input/'
    comp_folder_name = 'vesuvius-challenge-ink-detection'
    comp_dataset_path = comp_dir_path + comp_folder_name + '/'

    exp_name = 'mean_32_channels'

    # ============== pred target =============
    target_size = 1

    # ============== model cfg =============
    model_name = '3dcnn_segformer'

    # ============== training cfg =============
    size = tile_size = 1024
    stride = tile_size // 4
    in_chans = 16

    train_batch_size = 9
    valid_batch_size = train_batch_size
    use_amp = True

    scheduler = 'GradualWarmupSchedulerV2'
    epochs = 30

    # AdamW warmup: the base lr is divided by warmup_factor
    # (presumably the warmup scheduler scales it back up — confirm in get_scheduler).
    warmup_factor = 10
    lr = 1e-4 / warmup_factor

    # ============== fold =============
    valid_id = 1

    # ============== fixed =============
    min_lr = 1e-6
    weight_decay = 1e-6
    max_grad_norm = 100
    num_workers = 16
    seed = 42

    # ============== output paths =============
    print('set dataset path')
    outputs_path = f'working/outputs/{comp_name}/{exp_name}/'
    submission_dir = f'{outputs_path}submissions/'
    submission_path = f'{submission_dir}submission_{exp_name}.csv'
    model_dir = f'{outputs_path}{comp_name}-models/'
    figures_dir = f'{outputs_path}figures/'
    log_dir = f'{outputs_path}logs/'
    log_path = f'{log_dir}{exp_name}.txt'

set dataset path


In [3]:
# Project helper from externals.utils; presumably seeds RNGs / creates the
# output directories declared in CFG — TODO confirm.
cfg_init(CFG)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Model selection

In [4]:
# Model selection: instantiate the architecture used for this run. Alternatives
# (CNN3D_Unet, CNN3D_MANet, ...) are imported from externals.models and are
# presumably interchangeable here since each is built from CFG — confirm.
# The cell output shows pretrained nvidia/mit-b3 weights being downloaded.
model = CNN3D_Segformer(CFG) 

Downloading (…)lve/main/config.json:   0%|          | 0.00/70.0k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/179M [00:00<?, ?B/s]

Some weights of the model checkpoint at nvidia/mit-b3 were not used when initializing SegformerForSemanticSegmentation: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b3 and are newly initialized: ['decode_head.linear_c.3.proj.weight', 'decode_head.batch_norm.bias', 'decode_head.linear_c.0.proj.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.linear_c.2.proj.bias', 'decode_head

In [5]:
# Smoke test: one dummy forward pass to check the output shape.
# Runs in eval mode without autograd so the check neither builds a graph nor
# updates BatchNorm running statistics (the decode head has a batch_norm)
# before training starts; train mode is restored afterwards.
# NOTE(review): the dummy depth is 32 while CFG.in_chans is 16 — confirm which
# depth the real batches have.
model.eval()
with torch.no_grad():
    smoke_out = model(torch.ones(4, 1, 32, 128, 128))
model.train()
smoke_out.shape

torch.Size([4, 1, 128, 128])

## Training and validation functions

In [None]:
def train_fn(train_loader, model, criterion, optimizer, device, scaler=None):
    """Run one training epoch and return the running average loss.

    Args:
        train_loader: iterable yielding ``(images, labels)`` batches.
        model: network to optimise; switched to train mode here.
        criterion: callable ``criterion(preds, labels)`` returning a scalar loss.
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.
        scaler: optional ``GradScaler`` to reuse across epochs so the AMP loss
            scale is not reset every epoch. When omitted, a fresh scaler is
            created for this epoch (``enabled=CFG.use_amp``).

    Returns:
        ``AverageMeter.avg`` over per-batch ``loss.item()`` values, updated
        with ``n=batch size`` (presumably a batch-size-weighted mean).
    """
    model.train()

    if scaler is None:
        scaler = GradScaler(enabled=CFG.use_amp)
    losses = AverageMeter()

    for images, labels in tqdm(train_loader, total=len(train_loader)):
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)

        with autocast(CFG.use_amp):
            y_preds = model(images)
            loss = criterion(y_preds, labels)

        losses.update(loss.item(), batch_size)
        scaler.scale(loss).backward()

        # The gradients are still multiplied by the loss scale at this point.
        # Unscale them first so max_grad_norm applies to the true gradient norm
        # (a no-op when AMP is disabled).
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    return losses.avg

def valid_fn(valid_loader, model, criterion, device, valid_xyxys, valid_mask_gt):
    """Evaluate ``model`` on validation tiles and stitch a full-fragment mask.

    Overlapping tile probabilities are merged with a per-pixel geometric mean.

    Args:
        valid_loader: non-shuffled loader of ``(images, labels)`` tile batches,
            yielding tiles in the same order as ``valid_xyxys``.
        model: network to evaluate; switched to eval mode here.
        criterion: callable ``criterion(preds, labels)`` returning a scalar loss.
        device: device the batches are moved to.
        valid_xyxys: sequence of ``(x1, y1, x2, y2)`` boxes, one per tile, in
            mask pixel coordinates.
        valid_mask_gt: ground-truth mask; only its shape is used here.

    Returns:
        tuple: ``(average validation loss, stitched probability mask)``. Pixels
        not covered by any tile are 0.
    """
    mask_pred = np.ones(valid_mask_gt.shape)    # running product of probabilities
    mask_count = np.zeros(valid_mask_gt.shape)  # number of tiles covering each pixel

    model.eval()
    losses = AverageMeter()

    offset = 0  # index in valid_xyxys of the first tile of the current batch
    for images, labels in tqdm(valid_loader, total=len(valid_loader)):
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)

        with torch.no_grad():
            y_preds = model(images)
            # y_preds appear to be logits (sigmoid is applied below only to
            # build the mask); criterion is assumed to accept logits — confirm
            # against externals.training_procedures.
            loss = criterion(y_preds, labels)
        losses.update(loss.item(), batch_size)

        # Accumulate the product and coverage count for the geometric mean.
        probs = torch.sigmoid(y_preds).to('cpu').numpy()
        for i, (x1, y1, x2, y2) in enumerate(valid_xyxys[offset:offset + batch_size]):
            mask_pred[y1:y2, x1:x2] *= probs[i].squeeze(0)
            mask_count[y1:y2, x1:x2] += 1
        offset += batch_size

    # Geometric mean only where at least one tile contributed. This avoids 1/0
    # on uncovered pixels and keeps genuine probabilities of exactly 1.0.
    covered = mask_count > 0
    mask_pred[covered] = np.power(mask_pred[covered], 1.0 / mask_count[covered])
    mask_pred[~covered] = 0
    return losses.avg, mask_pred

## Main: data loading and training loop

In [None]:

def load_data(CFG):
    """Build the training loader and, when a validation fragment is set, the validation objects.

    Args:
        CFG: configuration class; ``CFG.valid_id`` selects the held-out fragment
            (``None`` means train on everything).

    Returns:
        ``train_loader`` when ``CFG.valid_id`` is None. Otherwise returns
        ``(train_loader, valid_loader, valid_xyxys, valid_mask_gt)``, where:

        - ``valid_xyxys`` is the stacked array of validation tile boxes.
        - ``valid_mask_gt`` is the fragment's ``inklabels.png`` divided by 255
          (values in [0, 1] assuming an 8-bit image).
        - ``valid_mask_gt`` is zero-padded on the bottom/right by
          ``tile_size - dim % tile_size`` pixels per axis.

    Raises:
        FileNotFoundError: if the validation ``inklabels.png`` cannot be read.
    """
    has_valid = CFG.valid_id is not None
    if not has_valid:
        train_images, train_masks = get_train_valid_dataset(CFG)
    else:
        train_images, train_masks, valid_images, valid_masks, valid_xyxys = get_train_valid_dataset(CFG)
        valid_xyxys = np.stack(valid_xyxys)
        fragment_id = CFG.valid_id

        mask_path = CFG.comp_dataset_path + f"train/{fragment_id}/inklabels.png"
        valid_mask_gt = cv2.imread(mask_path, 0)
        if valid_mask_gt is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError(f"could not read validation mask: {mask_path}")
        valid_mask_gt = valid_mask_gt / 255
        # An exact multiple of tile_size still receives one full tile of padding.
        # Presumably this mirrors the padding done by the external image loader,
        # so the tile boxes stay in range — keep the two in sync.
        pad0 = CFG.tile_size - valid_mask_gt.shape[0] % CFG.tile_size
        pad1 = CFG.tile_size - valid_mask_gt.shape[1] % CFG.tile_size
        valid_mask_gt = np.pad(valid_mask_gt, [(0, pad0), (0, pad1)], constant_values=0)
        valid_dataset = CustomDataset(
            valid_images, CFG, labels=valid_masks, transform=get_transforms(data='valid', cfg=CFG))
        # No shuffling and no dropped batch: valid_fn maps the i-th predicted
        # tile to valid_xyxys[i] by position.
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=CFG.valid_batch_size,
                                  shuffle=False,
                                  num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
    train_dataset = CustomDataset(
        train_images, CFG, labels=train_masks, transform=get_transforms(data='train', cfg=CFG))
    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.train_batch_size,
                              shuffle=True,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
    if not has_valid:
        return train_loader
    return train_loader, valid_loader, valid_xyxys, valid_mask_gt


In [None]:
def return_augs(CFG):
    """Build the albumentations pipelines for training and validation.

    Both pipelines begin with a resize to ``CFG.size`` x ``CFG.size`` and end
    with a per-channel Normalize (mean 0, std 1 over ``CFG.in_chans`` channels)
    followed by ``ToTensorV2``. The training pipeline adds flips, rotations,
    brightness/contrast, channel dropout, shift-scale-rotate, noise/blur, grid
    distortion and coarse dropout in between.

    Returns:
        tuple[list, list]: ``(train_aug_list, valid_aug_list)``.
    """
    def _resize():
        return A.Resize(CFG.size, CFG.size)

    def _finalize():
        # Zero mean / unit std: presumably only rescales by Normalize's default
        # max_pixel_value — confirm against the albumentations version in use.
        return [
            A.Normalize(mean=[0] * CFG.in_chans, std=[1] * CFG.in_chans),
            ToTensorV2(transpose_mask=True),
        ]

    hole_size = int(CFG.size * 0.05)

    train_augs = [_resize()]
    train_augs += [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=.5),
        A.RandomBrightnessContrast(p=0.25, brightness_limit=.2, contrast_limit=.2),
        A.ChannelDropout(channel_drop_range=(1, 2), p=.25),
        A.ShiftScaleRotate(p=0.25),
        A.OneOf(
            [A.GaussNoise(var_limit=[10, 50]), A.GaussianBlur(), A.MotionBlur()],
            p=0.25,
        ),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.25),
        A.CoarseDropout(max_holes=1, max_width=hole_size, max_height=hole_size,
                        mask_fill_value=0, p=0.25),
    ]
    train_augs += _finalize()

    valid_augs = [_resize(), *_finalize()]
    return train_augs, valid_augs

In [None]:
train_aug_list, valid_aug_list = return_augs(CFG)
CFG.train_aug_list, CFG.valid_aug_list = train_aug_list, valid_aug_list
# Snapshot of the config's own attributes, skipping dunder entries. The check is
# on the "__" prefix, so names like "x_y" are kept and one-letter names do not
# raise IndexError.
cfg_pairs = {name: vars(CFG)[name] for name in dir(CFG) if not name.startswith("__")}
model_name = f"{CFG.exp_name}_{CFG.model_name}"


In [None]:
# Wrap for multi-GPU data parallelism. The guard keeps a re-run of this cell from
# nesting DataParallel inside DataParallel, which would break the
# `model.module.state_dict()` checkpoints saved during training.
if not isinstance(model, torch.nn.DataParallel):
    model = torch.nn.DataParallel(model)
model.to(device)
# Running average of the weights (SWA); updates begin once epoch > swa_start.
swa_model = AveragedModel(model)
swa_start = 2

In [None]:

EARLY_STOP_PATIENCE = 8  # consecutive non-improving epochs tolerated before stopping

best_counter = 0
best_loss = np.inf
best_score = 0
optimizer = AdamW(model.parameters(), lr=CFG.lr)
# NOTE(review): swa_lr=0.05 is thousands of times CFG.lr for AdamW. Also, once
# epoch > swa_start, both swa_scheduler and `scheduler` change the learning rate
# every epoch. The usual SWA recipe steps only swa_scheduler after SWA starts.
# Confirm this combination is intended.
swa_scheduler = SWALR(optimizer, swa_lr=0.05)
scheduler = get_scheduler(CFG, optimizer)

has_valid = CFG.valid_id is not None
if has_valid:
    train_loader, valid_loader, valid_xyxys, valid_mask_gt = load_data(CFG)
else:
    train_loader = load_data(CFG)

for epoch in range(CFG.epochs):
    # ---- train ----
    avg_loss = train_fn(train_loader, model, criterion, optimizer, device)
    if epoch > swa_start:
        # fold the current weights into the running SWA average
        swa_model.update_parameters(model)
        swa_scheduler.step()

    if has_valid:
        # ---- eval ----
        avg_val_loss, mask_pred = valid_fn(
            valid_loader, model, criterion, device, valid_xyxys, valid_mask_gt)

        scheduler_step(scheduler, avg_val_loss, epoch)

        best_dice, best_th, best_metrics = calc_fbeta(valid_mask_gt, mask_pred)
        score = best_dice  # checkpoint selection uses the best F-beta, not the loss

        print({"dice": best_dice, "avg_train_loss": avg_loss, "avg_val_loss": avg_val_loss,
               "ctp": best_metrics[0], "cfp": best_metrics[1],
               "ctn": best_metrics[2], "cfn": best_metrics[3]})

        if score > best_score:
            best_loss = avg_val_loss
            best_score = score
            best_counter = 0
            torch.save(model.module.state_dict(),
                       CFG.model_dir + f"{model_name}_best.pth")
        else:
            best_counter += 1
            if best_counter > EARLY_STOP_PATIENCE:
                break
        torch.save(model.module.state_dict(),
                   CFG.model_dir + f"{model_name}_final.pth")

        # One figure per epoch, closed after display, instead of stacking a new
        # image onto the same implicit axes every epoch.
        fig, ax = plt.subplots()
        ax.imshow(mask_pred > best_th)
        ax.set_title(f"epoch {epoch}: validation prediction > {best_th}")
        plt.show()
        plt.close(fig)
    else:
        print({"avg_train_loss": avg_loss})
        scheduler_step(scheduler, avg_loss, epoch)
        if epoch % 5 == 0:
            torch.save(model.module.state_dict(),
                       CFG.model_dir + f"{model_name}_{epoch}_final.pth")

# Recompute BatchNorm running statistics for the averaged weights, then save them.
# NOTE(review): swa_model wraps the DataParallel model, so these keys carry an
# extra "module." prefix compared with the other checkpoints — confirm loaders expect that.
if swa_model.n_averaged.item() > 0:
    torch.optim.swa_utils.update_bn(train_loader, swa_model)
    torch.save(swa_model.module.state_dict(),
               CFG.model_dir + f"{model_name}_final_swa.pth")
else:
    # update_parameters was never called, so swa_model still holds the weights
    # copied when it was created; saving them as "SWA" would be misleading.
    print("SWA model was never updated; skipping SWA checkpoint")
        
    