In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import tensorboardX as ts
import os
import matplotlib.pyplot as plt
import numpy as np
import sys
import tqdm.autonotebook as tqdm
import logging
import torchinfo
import gc
import copy
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
import segmentation_models_pytorch as smp
import albumentations as A
from model import *
from dataset import *
from ColonNext import ColonNext
from unext import UNext
import torch.optim as optim
import torch.cuda.amp as amp
import logging


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Drop unreferenced Python objects and return cached CUDA blocks before starting.
gc.collect()
torch.cuda.empty_cache()

# Seed the torch and numpy generators so weight initialisation and shuffling repeat
# from run to run.
# NOTE(review): randomness used by the augmentations inside DataLoader worker
# processes may need separate seeding — confirm against the albumentations version.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Runtime handles shared by the rest of the notebook.
device = "cuda" if torch.cuda.is_available() else "cpu"
writer = ts.SummaryWriter("runs/Unext")
logger = logging.getLogger("Unext")
logging.basicConfig(level=logging.INFO)

# Hyperparameters.
lr = 1e-4
momentum = 0.9
weight_decay = 1e-5
num_workers = 2
epoches = 200
batch_size = 24

In [3]:
dataset = read_dataset_csv("train.csv")
print("Load Dataset")

# Training pipeline: resize to 256x256, then random rotations, flips and elastic warps.
train_transform = A.Compose([
    A.Resize(height=256, width=256),
    A.RandomRotate90(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.ElasticTransform(p=0.5),
])
# Evaluation pipeline: resize only.
test_transform = A.Compose([
    A.Resize(height=256, width=256),
])

# Hold out the last 1/25 of the rows for evaluation. The split point is computed
# explicitly because the negative-slice form (`dataset[:-k]` / `dataset[-k:]`)
# breaks when k == 0: the training slice becomes empty and the test slice becomes
# the whole dataset.
num_test = len(dataset) // 25
split_idx = len(dataset) - num_test
train_ds = UWMDataset(dataframe=dataset[:split_idx], transforms=train_transform, num_classes=4)
test_ds = UWMDataset(dataframe=dataset[split_idx:], transforms=test_transform, num_classes=4)

# NOTE(review): the configuration cell defines batch_size = 24, but the training
# loader uses a literal 8 — confirm which one is intended.
train_loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=8,
    num_workers=num_workers,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=1,
    num_workers=num_workers,
    shuffle=True,
)

# Pull one batch from each loader so a broken dataset fails here rather than mid-training.
batch_train = next(iter(train_loader))
batch_test = next(iter(test_loader))

Load Dataset


In [4]:
# Same device selection as in the configuration cell: CUDA when available, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Disabled: load a previously saved checkpoint and inspect its stored accuracy.
# load = torch.load("chekpoint/ColonNext_19.pth")
# load["acc"]

In [5]:
# Disabled alternatives that build a teacher network from a loaded checkpoint:
# teacher = SwinFormer(num_classes=4).to(device)
# teacher.load_state_dict(load["net_state"])
# teacher = UNext(num_classes=4).to(device)
# teacher.load_state_dict(load["net"])

# Build the UNext network, then move it to the selected device.
model = UNext(4)
model = model.to(device)

In [6]:
# Tversky and focal losses, both in multilabel mode and both moved to the device.
# `criterion` combines them with equal weight.
loss_func_1, loss_func_2 = (
    loss_cls(mode=smp.losses.MULTILABEL_MODE).to(device)
    for loss_cls in (smp.losses.TverskyLoss, smp.losses.FocalLoss)
)

In [7]:
# len(dataset)

In [8]:
# Print a per-layer table (kernel size, output shape, parameter count, mult-adds)
# for a batch of eight 3x320x320 inputs, with rows labelled by variable name.
print(
    torchinfo.summary(
        model,
        input_size=(8, 3, 320, 320),
        col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
        row_settings=["var_names"],
        device=device,
    )
)

Layer (type (var_name))                                 Kernel Shape              Output Shape              Param #                   Mult-Adds
UNext (UNext)                                           --                        [8, 4, 320, 320]          --                        --
├─IntermediateLayerGetter (encoder)                     --                        [8, 768, 10, 10]          --                        --
│    └─Conv2dNormActivation (0)                         --                        [8, 96, 80, 80]           --                        --
│    │    └─Conv2d (0)                                  [4, 4]                    [8, 96, 80, 80]           4,704                     240,844,800
│    │    └─LayerNorm2d (1)                             --                        [8, 96, 80, 80]           192                       1,536
│    └─Sequential (1)                                   --                        [8, 96, 80, 80]           --                        --
│    │    └─CNBlock (0

In [9]:
def structure_loss(pred, mask):
    """Boundary-weighted BCE plus boundary-weighted IoU loss.

    Each pixel gets the weight ``1 + 5 * |avg_pool(mask) - mask|``, so pixels whose
    31x31 neighbourhood average differs from their own label count more.

    Args:
        pred: Raw logits; the sigmoid is applied inside this function.
        mask: Float target tensor with the same shape as ``pred``. Both are reduced
            over dims (2, 3), so a 4-D ``(N, C, H, W)`` layout is assumed.

    Returns:
        Scalar tensor: the mean over the remaining dims of weighted BCE + weighted IoU.
    """
    functional = nn.functional

    weit = 1 + 5 * torch.abs(
        functional.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask
    )

    # `reduction='none'` keeps the per-pixel loss map. The previous `reduce='none'`
    # passed a truthy value to the deprecated boolean flag, which gives a mean
    # reduction and made the per-pixel weighting below a no-op.
    wbce = functional.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask) * weit).sum(dim=(2, 3))
    union = ((pred + mask) * weit).sum(dim=(2, 3))
    # The +1 terms smooth the ratio so empty masks do not divide by zero.
    wiou = 1 - (inter + 1) / (union - inter + 1)

    return (wbce + wiou).mean()

In [10]:
def metrics(out, targets):
    """Return the micro-averaged IoU of `out` against `targets`.

    `targets` is moved to the notebook-level `device`; `out` is thresholded at 0.5
    in multilabel mode to obtain the confusion statistics.
    """
    stats = smp.metrics.get_stats(
        out,
        targets.to(device),
        mode=smp.losses.MULTILABEL_MODE,
        threshold=0.5,
    )
    true_pos, false_pos, false_neg, true_neg = stats
    return smp.metrics.iou_score(true_pos, false_pos, false_neg, true_neg, reduction="micro")

In [11]:
def dice_coef(y_true, y_pred, thr=0.5, dim=(2,3), epsilon=0.001):
    """Mean Dice coefficient between `y_true` and `y_pred` binarised at `thr`.

    Args:
        y_true: Target tensor; cast to float32.
        y_pred: Prediction tensor; values strictly above `thr` count as positive.
        thr: Binarisation threshold for `y_pred`.
        dim: Dims summed over when counting overlap and totals.
        epsilon: Smoothing term added to numerator and denominator.

    Returns:
        The Dice ratio averaged over dims 1 and 0 of the per-item result.
    """
    truth = y_true.to(torch.float32)
    binarised = (y_pred > thr).to(torch.float32)

    overlap = (truth * binarised).sum(dim=dim)
    total = truth.sum(dim=dim) + binarised.sum(dim=dim)

    per_item = (2 * overlap + epsilon) / (total + epsilon)
    return per_item.mean(dim=(1, 0))

In [12]:
def criterion(output, targets):
    """Sum an equally weighted Tversky + focal loss over every head in `output`.

    Args:
        output: Mapping whose values are prediction tensors, one per model head.
        targets: Target tensor; moved to the notebook-level `device` for each call.

    Returns:
        The accumulated loss, or the integer 0 when `output` has no values.
    """
    losses = 0
    for head_pred in output.values():
        # Disabled alternative:
        # losses += structure_loss(pred=head_pred.contiguous(), mask=targets.float().to(device))
        tversky_term = loss_func_1(y_pred=head_pred.contiguous(), y_true=targets.to(device))
        focal_term = loss_func_2(y_pred=head_pred.contiguous(), y_true=targets.to(device))
        losses += tversky_term * 0.5
        losses += focal_term * 0.5
    return losses

In [13]:
# Adam over all model parameters. The learning rate comes from the `lr` constant in
# the configuration cell (1e-4) instead of a duplicated literal.
optimizer = optim.Adam(model.parameters(), lr=lr)
# NOTE(review): `train` calls scheduler.step() once per batch, so T_max=6000 is
# counted in batches, not epochs — confirm this is the intended schedule length.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=6000, eta_min=1e-6)
# Gradient scaler used with the autocast blocks in the training code.
scaler = amp.GradScaler()
# Disabled: restore scaler / optimizer / scheduler state from a loaded checkpoint.
# scaler.load_state_dict(load["scaler"])
# optimizer.load_state_dict(load["optimizer"])
# scheduler.load_state_dict(load["scheduler"])

In [14]:
class AverageMeter(object):
    """Tracks the latest value and the running weighted mean of a series.

    Attributes (all reset to 0 by `reset`):
        val: most recent value passed to `update`.
        sum: running total of `val * n`.
        count: running total of the weights `n`.
        avg: `sum / count` after the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` with weight `n` and refresh the running mean."""
        weighted = val * n
        self.val = val
        self.count += n
        self.sum += weighted
        self.avg = self.sum / self.count

In [15]:
def overfit_test():
    """Run 25 optimisation steps on one fixed training batch and log each step.

    Uses the notebook-level `model`, `train_loader`, `criterion`, `optimizer`,
    `scaler`, `device` and `logger`. The meters are re-created on every step, so
    the logged averages are the values of that step alone.
    """
    first_batch = next(iter(train_loader))
    images, labels = first_batch
    images = images.float().to(device)
    labels = labels.to(device)

    for i in range(25):
        model.train()
        loss_total, metric_total = AverageMeter(), AverageMeter()

        # Forward pass, loss and Dice metric under mixed precision.
        with torch.cuda.amp.autocast():
            predictions = model(images)
            step_loss = criterion(predictions, labels)
            step_metric = dice_coef(y_pred=predictions["out1"], y_true=labels)

        # Scaled backward pass and optimiser step.
        optimizer.zero_grad()
        scaler.scale(step_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_total.update(step_loss)
        metric_total.update(step_metric)
        torch.cuda.empty_cache()
        logger.info(f'Overfit Test: Epoch:{i} Loss:{loss_total.avg:.4}, Metric:{metric_total.avg:.4} ')

In [16]:
# overfit_test()

In [17]:
def train(epoch, net, trainloader, criterion, optimizer, scaler, scheduler, name):
    """Train `net` for one pass over `trainloader` with mixed precision.

    Args:
        epoch: Epoch index, used for the progress text and as the TensorBoard step.
        net: Model to optimise. It must return a mapping that contains "out1".
        trainloader: Iterable of (inputs, targets) batches.
        criterion: Callable taking (outputs, targets) and returning the loss.
        optimizer: Optimiser that updates `net`.
        scaler: Gradient scaler for the scaled backward pass and optimiser step.
        scheduler: LR scheduler; stepped once per batch.
        name: Unused in this function.

    Returns:
        `net` after the epoch's updates.
    """
    net.train()
    loss_total = AverageMeter()
    metric_total = AverageMeter()
    loop = tqdm.tqdm(trainloader, total=len(trainloader))
    for inputs, targets in loop:
        inputs, targets = inputs.float().to(device), targets.to(device)
        with torch.cuda.amp.autocast():
            # Run the `net` argument, not the notebook-global `model`, so the
            # function trains whichever network the caller passes in.
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            metric = dice_coef(y_pred=outputs["out1"], y_true=targets)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Detach before accumulating so the running sum does not keep every
        # iteration's autograd history alive.
        loss_total.update(loss.detach())
        metric_total.update(metric)
        scheduler.step()
        torch.cuda.empty_cache()
        loop.set_description(f'Train: Epoch:{epoch} Loss:{loss_total.avg:.4} Metric:{metric_total.avg:.4} ')

    writer.add_scalar('Loss/train', loss_total.avg.item(), epoch)
    writer.add_scalar('Metric/train', metric_total.avg.item(), epoch)
    logger.info(f'Train: Epoch:{epoch} Loss:{loss_total.avg:.4} Metric:{metric_total.avg:.4} ')
    return net


def test(epoch, net, testloader, criterion, scaler, optimizer, scheduler, checkpoint, name):
    """Evaluate `net` on `testloader`, log the results and offer a checkpoint.

    Args:
        epoch: Epoch index, used for the progress text and as the TensorBoard step.
        net: Model to evaluate. It must return a mapping that contains "out1".
        testloader: Iterable of (inputs, targets) batches.
        criterion: Callable taking (outputs, targets) and returning the loss.
        scaler, optimizer, scheduler: Passed to `checkpoint.save` only.
        checkpoint: Object whose `save` method decides whether to write a file.
        name: Filename stem passed to `checkpoint.save`.
    """
    net.eval()
    loss_total = AverageMeter()
    metric_total = AverageMeter()
    with torch.no_grad():
        loop = tqdm.tqdm(testloader, total=len(testloader))
        for inputs, targets in loop:
            inputs, targets = inputs.float().to(device), targets.to(device)
            with torch.cuda.amp.autocast():
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                metric = dice_coef(y_pred=outputs["out1"], y_true=targets)
            loss_total.update(loss)
            metric_total.update(metric)
            # Label matches the logger line below (previously read "Train: Test:").
            loop.set_description(f'Test: Epoch:{epoch} Loss:{loss_total.avg:.4} Metric:{metric_total.avg:.4} ')
        writer.add_scalar('Loss/test', loss_total.avg.item(), epoch)
        writer.add_scalar('Metric/test', metric_total.avg.item(), epoch)
        logger.info(f'Test:  Epoch:{epoch} Loss:{loss_total.avg:.4}  Metric:{metric_total.avg:.4} ')

    # Offer the epoch's mean metric to the checkpoint object.
    checkpoint.save(net=net, acc=metric_total.avg.item(), filename=name, scaler=scaler, epoch=epoch, optimizer=optimizer, scheduler=scheduler)
    torch.cuda.empty_cache()
    print()


class Checkpoint(object):
    """Writes training state to disk whenever the reported accuracy improves."""

    def __init__(self):
        # Best accuracy seen so far; `save` only writes when it is beaten.
        self.best_acc = 0.
        # Output directory, created on construction if missing.
        self.folder = 'chekpoint'
        os.makedirs(self.folder, exist_ok=True)

    def save(self, net, acc, filename, scaler, optimizer, scheduler, epoch=-1):
        """Save a checkpoint if `acc` is strictly greater than the best so far.

        The file is written to `<folder>/<filename>_<epoch>.pth` and holds the
        state dicts of `net`, `scaler`, `optimizer` and `scheduler`, plus `acc`
        and `epoch`.
        """
        if not acc > self.best_acc:
            return

        logger.info('Saving checkpoint...')
        state = dict(
            net=net.state_dict(),
            acc=acc,
            epoch=epoch,
            scaler=scaler.state_dict(),
            optimizer=optimizer.state_dict(),
            scheduler=scheduler.state_dict(),
        )
        target_dir = os.path.abspath(self.folder)
        file_name = f"{filename}_{epoch}" + '.pth'
        torch.save(state, os.path.join(target_dir, file_name))
        self.best_acc = acc

    def load(self, net):
        """Placeholder; does nothing."""
        pass

In [18]:
# Checkpoint helper shared by all epochs; it remembers the best metric so far.
ckpt = Checkpoint()

# Epoch range for this run.
start = 0
end = 50

for epoch in range(start, end):
    # One training pass; `train` returns the network it was given.
    teacher = train(
        epoch=epoch,
        net=model,
        trainloader=train_loader,
        criterion=criterion,
        optimizer=optimizer,
        scaler=scaler,
        scheduler=scheduler,
        name="Unext",
    )
    # Evaluate on the held-out split and offer the result to the checkpoint.
    test(
        epoch=epoch,
        net=teacher,
        testloader=test_loader,
        criterion=criterion,
        scaler=scaler,
        optimizer=optimizer,
        scheduler=scheduler,
        checkpoint=ckpt,
        name="Unext",
    )

Train: Epoch:0 Loss:0.9764 Metric:0.6299 :  72%|███████▏  | 1432/1991 [08:20<03:15,  2.86it/s]


KeyboardInterrupt: 