In [None]:
# Setup
import sys
import subprocess
import pkg_resources


def setup(required):
    required = set(required)
    installed = {pkg.key for pkg in pkg_resources.working_set}
    missing   = required - installed
    if missing:
        # implement pip as a subprocess:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', *missing])
        
# Install any of the listed third-party packages that are currently missing.
setup({'numpy', 'matplotlib', 'pandas', 'torch', 'wandb', 'tqdm', 'opencv-python', 'torchvision', })

In [None]:
# Parameters
import torch

# --- Paths -----------------------------------------------------------------
# NOTE(review): machine-specific absolute path; point it at your local copy
# of the dataset before running.
INPUT_DIR = r"C:\Users\nmttu\Downloads\bkai-igh-neopolyp"
OUTPUT_DIR = "out"                      # predicted masks are written here
TRAIN_DIR = f"{INPUT_DIR}/train/train"
TEST_DIR = f"{INPUT_DIR}/test/test"
CHECKPOINT_DIR = r"checkpoint"          # train() saves one checkpoint per epoch here
CHECKPOINT_PATH = r"scratch_11152023_121350_5.pth"  # checkpoint file name inside CHECKPOINT_DIR

# --- Data loading ----------------------------------------------------------
num_workers = 2

# --- Training hyper-parameters ---------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
learning_rate = 10**-3
# Final learning rate of the exponential decay. It must be positive: train()
# derives the per-epoch factor as (end_learning_rate / learning_rate) ** (1 / num_epochs).
# The previous expression `10*678*-5` evaluated to -33900 (typo for 10**-5).
end_learning_rate = 10**-5
num_epochs = 50
batch_size = 1
img_size = (736, 960)                   # size every image is resized to
pretrain = "scratch"                    # prefix of the wandb run name and checkpoint files
split_ratio = [0.8, 0.2]                # train / validation fractions for random_split
# Class weights follow the label order produced by ImageDataset:
# 0 = red, 1 = green, 2 = background.
loss_func = torch.nn.CrossEntropyLoss(weight=torch.tensor([0.45, 0.45, 0.1], device=device))

# Number of training samples between two wandb reports.
report_step = 200

In [None]:
# Utils
import os
import torch
from torchvision import transforms
from torchvision.transforms import functional as TF
import math
import wandb
from datetime import datetime, timezone, timedelta

def argmax2img(arr):
    """Turn a map of class indices into a 3-channel float image.

    Pixels equal to 0 light up channel 0 and pixels equal to 1 light up
    channel 1; every other value leaves all three channels at zero.
    A 4-dimensional input is processed element by element along its first
    axis and the results are stacked.
    """
    if arr.ndim == 4:
        return torch.stack([argmax2img(single) for single in arr])

    height, width = arr.shape[-2], arr.shape[-1]
    canvas = torch.zeros(3, height, width)
    canvas[0, :, :] = arr == 0
    canvas[1, :, :] = arr == 1
    return canvas.float()

def transform(img, gt, random_seed: int = None):
    """Apply one shared random geometric augmentation to an image and its mask.

    Args:
        img: Image tensor.
        gt: Ground-truth tensor; it is stacked with ``img``, so both must have
            the same shape.
        random_seed: If given, seeds torch's global RNG before sampling.

    Returns:
        Tuple ``(img, gt)`` after a random perspective warp and a random
        rotation. ``ColorJitter`` is applied to the image only.
    """
    if random_seed is not None:
        torch.manual_seed(random_seed)
    # RandomRotation expects degrees, not radians: 45 samples an angle from
    # (-45, +45). The previous `math.pi / 4` limited it to roughly +/-0.79 degrees.
    transformation = transforms.Compose([transforms.RandomPerspective(p = 1),
                                          transforms.RandomRotation(degrees = 45, expand = True),])
    # Stack image and mask so a single call transforms both tensors together.
    stacked = torch.stack(tensors = [img, gt], dim = 0)
    stacked = transformation(stacked)
    img, gt = stacked[0, :, :], stacked[1, :, :]
    # NOTE(review): ColorJitter() with its default arguments appears to leave
    # the image unchanged -- confirm and choose explicit jitter strengths.
    img = transforms.ColorJitter()(img)
    return img, gt

def _mean_loss(dataloader, model, loss_func):
    """Return the mean per-batch loss over ``dataloader`` (None if it is empty).

    Runs without gradient tracking; the caller is responsible for switching
    the model between ``eval()`` and ``train()``.
    """
    total = 0.0
    batches = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            # Crop the target to the spatial size of the model output.
            targets = TF.center_crop(targets, [outputs.shape[-2], outputs.shape[-1]])
            total += loss_func(outputs, targets).item()
            batches += 1
    return total / batches if batches else None


def train(dataloader, model, num_epochs = 100, learning_rate = 10**-3, loss_func = None, pretrain_name="scratch", start=0, end_learning_rate = None, report_step = 1000, vali_dataloader = None):
    """Train ``model`` with Adam, logging to wandb and checkpointing every epoch.

    Args:
        dataloader: Yields ``(input, target)`` training batches.
        model: Network to optimise; it is saved as a whole object in each checkpoint.
        num_epochs: Number of epochs to run.
        learning_rate: Initial Adam learning rate.
        loss_func: Callable ``loss_func(pred, target)``; defaults to an
            unweighted ``torch.nn.CrossEntropyLoss`` when None.
        pretrain_name: Prefix of the checkpoint file names.
        start: Epoch offset; epochs are numbered ``start + 1 .. start + num_epochs``.
        end_learning_rate: If given, the learning rate decays exponentially so
            that it reaches this value after ``num_epochs`` epochs.
        report_step: Number of training samples between two wandb reports.
        vali_dataloader: Optional loader evaluated at every report.

    Raises:
        ValueError: If a decay is requested with a non-positive learning rate.
    """
    if loss_func is None:
        loss_func = torch.nn.CrossEntropyLoss()
    # Create the checkpoint folder up front so the first save cannot fail
    # after a whole epoch of training.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    if end_learning_rate is not None:
        # A non-positive ratio would make the fractional power below complex.
        if end_learning_rate <= 0 or learning_rate <= 0:
            raise ValueError("learning_rate and end_learning_rate must be positive")
        expo = (end_learning_rate / learning_rate) ** (1/num_epochs)
    else:
        expo = 1
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = expo)

    last_epoch = start + num_epochs
    accu_loss = 0.0
    accu_batches = 0  # batches accumulated since the last report
    for e in range(1 + start, 1 + last_epoch):
        for idx, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            pred = model(x)
            # Crop the target to the spatial size of the model output.
            y = TF.center_crop(y, [pred.shape[-2], pred.shape[-1]])
            loss = loss_func(pred, y)
            loss.backward()
            optimizer.step()
            print(f"Epoch {e}/{last_epoch} --- Batch {idx + 1}/{len(dataloader)} --- Loss {loss.item():.4f}", end='\r')
            accu_loss += loss.detach()
            accu_batches += 1
            if ((idx + 1)*dataloader.batch_size) % report_step == 0:
                # Average over the batches actually accumulated; this stays
                # correct for batch sizes > 1 and across epoch boundaries.
                wandb.log({"Train Loss": accu_loss.item() / accu_batches})
                accu_loss = 0.0
                accu_batches = 0
                table = wandb.Table(columns=["Predict", "Target"])
                table.add_data(wandb.Image(pred[0]),
                               wandb.Image(argmax2img(y[0])))
                wandb.log({"Comparision": table})

                # Validation
                if vali_dataloader is not None:
                    model.eval()
                    vali_loss = _mean_loss(vali_dataloader, model, loss_func)
                    if vali_loss is not None:
                        wandb.log({"Validation Loss": vali_loss})
                    model.train()
        scheduler.step()

        # Timestamp uses a fixed UTC+7 offset.
        checkpoint_path = f"{pretrain_name}_{datetime.now(tz=timezone(timedelta(hours=7))).strftime(r'%m%d%Y_%H%M%S')}_{e}.pth"
        torch.save({"model":model,
                        "optimizer_state_dict":optimizer.state_dict()}, os.path.join(CHECKPOINT_DIR, checkpoint_path))

In [None]:
# Dataset
import os

import torch
from torchvision.io import read_image
from torchvision.transforms import functional as TF


class ImageDataset(torch.utils.data.Dataset):
    """Image / mask pairs read from a fixed directory layout.

    Training samples are read from ``<directory>/train/train`` and their masks
    from ``<directory>/train_gt/train_gt`` (same file names). Test samples are
    read from ``<directory>/test/test`` and get an all-zero mask.
    """

    def __init__(self, directory:str, transform=None, standard_size = (608, 800), train = True):
        '''Param:
        `directory` (string): Root directory of the dataset (see class docstring for the layout).
        `transform` (callable, optional): Called as `transform(sample, gt)`; must return the transformed pair.
        `standard_size` (sequence): Size every sample and mask is resized to. (608, 800) by default.
        `train` (bool): Read the training split (with masks) when True, the test split otherwise.

        Raises `FileNotFoundError` when the expected sample directory is missing.'''

        self.__train = train
        self.transform = transform
        self.__standard_size = standard_size # Every image is resized to this one standard size.
        self.__list_samples__ = [] # A list containing samples' file paths.
        self.__list_gt__ = [] # A list containing ground truth images' file paths.

        sample_dir = f"{directory}/train/train" if self.__train else f"{directory}/test/test"
        try:
            # Sorted so that sample indices are stable across runs and machines.
            list_files = sorted(os.listdir(sample_dir))
        except FileNotFoundError as err:
            err_msg = f"Directory {directory} does not exist!"
            raise FileNotFoundError(err_msg) from err

        for file_name in list_files: # Sample and ground truth pair should have the same name.
            self.__list_samples__.append(f'{sample_dir}/{file_name}')
            if self.__train:
                self.__list_gt__.append(f'{directory}/train_gt/train_gt/{file_name}')

    def __len__(self):
        """Number of samples found in the split."""
        return len(self.__list_samples__)

    def __getitem__(self, idx):
        """Return ``(sample, label_map)`` for the sample at ``idx``.

        ``sample`` is the image scaled by 1/255 and resized to the standard
        size; ``label_map`` is a long tensor with 0 = red, 1 = green and
        2 = background.
        """
        sample = read_image(self.__list_samples__[idx]) / 255
        if self.__train:
            gt = read_image(self.__list_gt__[idx]) / 255
        else:
            gt = torch.zeros(size=sample.shape)
        if self.transform:
            sample, gt = self.transform(sample, gt)
        sample = TF.resize(sample, self.__standard_size)
        gt = TF.resize(gt, self.__standard_size)
        # Binarise each channel: only strongly saturated pixels count.
        gt = ((gt - 0.8) > 0).float()
        # To class label
        red = gt[0]
        green = gt[1]
        # Clamp so a pixel with BOTH channels set cannot produce background = -1
        # (and hence an invalid label of -1); such pixels are labelled green.
        background = 1 - torch.clamp(red + green, max=1) # No red and no green --> background = 1
        gt = green + background*2 # gt[..] = 0 means red, gt[..] = 1 means green, gt[..] = 2 means background
        return sample, gt.long()

In [None]:
# Model
from torch import nn
import torch
import torchvision.transforms.functional as TF

class ConvReLU(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, optionally preceded by another module.

    ``pre_activation`` is applied to the input first when it is an
    ``nn.Module``; any other value (including None) is replaced by
    ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, kernel_size = 3, pre_activation = None, padding = 0):
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.pre_activation = pre_activation if isinstance(pre_activation, nn.Module) else nn.Identity()

    def forward(self, x):
        """Run pre-activation, convolution, batch norm and ReLU in that order."""
        return self.relu(self.batchnorm(self.conv(self.pre_activation(x))))
    
class StackedConvReLU(nn.Module):
    """A chain of ``ConvReLU`` blocks, optionally preceded by another module.

    The first block maps ``in_channels`` to ``out_channels``; the remaining
    ``stack - 1`` blocks keep ``out_channels``. At least one block is always
    created.
    """

    def __init__(self, in_channels, out_channels, stack = 2, kernel_size = 3, pre_activation = None, padding = 0):
        super().__init__()
        layers = [ConvReLU(in_channels, out_channels, kernel_size, padding=padding)]
        layers.extend(ConvReLU(out_channels, out_channels, kernel_size, padding=padding)
                      for _ in range(stack - 1))
        self.stack = nn.Sequential(*layers)
        # Anything that is not an nn.Module (including None) becomes a no-op.
        self.pre_activation = pre_activation if isinstance(pre_activation, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply the pre-activation, then the stacked blocks."""
        return self.stack(self.pre_activation(x))
    
class Up(StackedConvReLU):
    """Decoder stage: concatenates a skip tensor with the input before the conv stack."""

    def __init__(self, in_channels, out_channels, stack = 2, kernel_size = 3, pre_activation = None, padding = 1):
        '''Same construction as the parent class, but with a default padding of 1.
        `forward` takes an extra skip tensor that is concatenated with the input.'''
        super().__init__(in_channels, out_channels, stack, kernel_size, pre_activation, padding)

    def forward(self, x, skip):
        """Pre-activate ``x``, center-crop ``skip`` to its size, concatenate on dim 1, run the stack."""
        upsampled = self.pre_activation(x)
        target_hw = [upsampled.shape[-2], upsampled.shape[-1]]
        cropped_skip = TF.center_crop(skip, target_hw)
        return self.stack(torch.cat([upsampled, cropped_skip], dim=1))
    
class UNet(nn.Module):
    """U-Net built from ``StackedConvReLU`` encoder stages and ``Up`` decoder stages.

    Args:
        encoder_channels: Output channels of each encoder level, shallowest first.
        in_channels: Channels of the input image.
        out_channels: Channels produced by the final 1x1 convolution.
    """

    def __init__(self, encoder_channels: list, in_channels: int = 3, out_channels: int = 3):
        super().__init__()
        self.encoder = nn.Sequential(StackedConvReLU(in_channels, encoder_channels[0]))
        # Every deeper level halves the resolution with a max-pool first.
        self.encoder.extend([StackedConvReLU(encoder_channels[idx],
                                             encoder_channels[idx + 1],
                                             pre_activation = nn.MaxPool2d(2)) for idx in range(0, len(encoder_channels) - 1)])

        # Each decoder stage upsamples to encoder_channels[idx - 1] channels and
        # concatenates a skip tensor with the same number of channels, so its
        # conv stack receives 2 * encoder_channels[idx - 1] channels. Spelling
        # that out (instead of assuming encoder_channels[idx] is exactly double)
        # lets channel lists that do not double at every level work too.
        self.decoder = nn.Sequential(*[Up(2 * encoder_channels[idx - 1],
                                                       encoder_channels[idx - 1],
                                                       pre_activation = nn.ConvTranspose2d(encoder_channels[idx], encoder_channels[idx - 1] , 2 , 2)) for idx in range(len(encoder_channels)-1, 0, -1)])
        self.head = nn.Conv2d(in_channels = encoder_channels[0], out_channels = out_channels, kernel_size = 1)

    def forward(self, x):
        """Encode ``x``, decode with skip connections, and apply the 1x1 head."""
        features = []
        for layer in self.encoder:
            x = layer(x)
            features.append(x)
        # Skips run from the second-deepest level up to the shallowest; the
        # deepest feature map is the decoder's starting input, not a skip.
        skips = features[-2::-1]

        for layer, skip in zip(self.decoder, skips):
            x = layer(x, skip)
        return self.head(x)

In [None]:
# Train
import logging

logging.basicConfig(level=logging.ERROR)

from datetime import datetime, timedelta, timezone

import torch
import wandb

# NOTE(review): both splits are views of the same `train_img` dataset, so the
# validation samples also go through `transform` (random augmentation) -- confirm
# this is intended.
train_img = ImageDataset(directory=INPUT_DIR, train=True, transform=transform, standard_size=img_size)
train_set, vali_set = torch.utils.data.random_split(train_img, split_ratio)
train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers,)
# The validation loss is averaged over the whole loader, so its order is
# irrelevant and shuffling is unnecessary.
vali_dataloader = torch.utils.data.DataLoader(vali_set, batch_size=batch_size, shuffle=False, num_workers=num_workers,)
model = UNet([16, 32, 64, 128, 256], 3, )
model.to(device)
wandb.login()
wandb.init(
    project="IntroToDLSegmentation",
    name=f'{pretrain}_weightedCE_{datetime.now(tz=timezone(timedelta(hours=7)))}',
    config={"learning_rate":learning_rate,
            "architecture": "UNet",
            "dataset": "BK NeoPolyp",
            "epochs": num_epochs,
           },
)
try:
    train(train_dataloader, model, num_epochs = num_epochs, learning_rate = learning_rate, loss_func= loss_func, pretrain_name=pretrain, end_learning_rate = end_learning_rate, report_step = report_step, vali_dataloader=vali_dataloader)
finally:
    # Close the wandb run even when training raises or is interrupted.
    wandb.finish()

In [None]:
import os

import numpy as np
import torch
from torchvision.io import read_image, write_png
from torchvision import transforms
from torchvision.transforms import functional as TF
from torchvision import transforms
from tqdm import tqdm
import cv2

# Create the output folder if needed; an existing folder is fine.
os.makedirs(OUTPUT_DIR, exist_ok=True)
to_tensor = transforms.ToTensor()

#model = torch.load(f"{CHECKPOINT_DIR}/{CHECKPOINT_PATH}", map_location=device)["model"]
# Use trained model above
model.eval()

with torch.no_grad():
    for file_name in tqdm(os.listdir(TEST_DIR)):
        img = read_image(f"{TEST_DIR}/{file_name}") / 255
        # read_image returns a (C, H, W) tensor, so the spatial size is the
        # LAST two dimensions. (`shape[:-1]` produced [C, H] and resized the
        # prediction to 3 x H.)
        original_size = list(img.shape[-2:])
        resized = TF.resize(img.unsqueeze(0), img_size)
        pred = model(resized.to(device)).cpu()

        # NOTE(review): the encoder convolutions are unpadded, so `pred` is
        # likely smaller than `resized` (train() center-crops its targets to
        # the prediction size); stretching it back to the full original size
        # may misalign the mask -- confirm.
        resized_pred = TF.resize(pred, original_size)
        res = (argmax2img(resized_pred[0].argmax(dim=0)) * 255).type(torch.uint8)
        # Strip whatever extension the input has instead of assuming 5 characters.
        stem = os.path.splitext(file_name)[0]
        write_png(res, f"{OUTPUT_DIR}/{stem}.png", compression_level=0)


import pandas as pd


def rle_to_string(runs):
    """Render the run values as one space-separated string."""
    return ' '.join(map(str, runs))

def rle_encode_one_mask(mask):
    """Run-length encode the non-zero pixels of ``mask``.

    The mask is flattened and every non-zero value is treated as foreground.
    The result is a space-separated string of ``start length`` pairs with
    1-based start positions.
    """
    flat = mask.flatten()
    flat[flat > 0] = 255

    # When the mask touches either end, pad with zeros so that the first and
    # last runs still produce a transition on both sides.
    touches_edge = bool(flat[0] or flat[-1])
    if touches_edge:
        padded = np.zeros([len(flat) + 2], dtype=flat.dtype)
        padded[1:-1] = flat
        flat = padded

    # Padding shifts every position by one, hence the smaller offset.
    offset = 1 if touches_edge else 2
    runs = np.where(flat[1:] != flat[:-1])[0] + offset
    # Odd entries are run ends; turn them into lengths.
    runs[1::2] -= runs[:-1:2]
    return rle_to_string(runs)

def mask2string(dir):
    """Run-length encode the first two colour channels of every image in ``dir``.

    Args:
        dir: Directory containing the mask images.

    Returns:
        Dict with two parallel lists: ``'ids'`` holds ``"<name>_<channel>"``
        (``<name>`` is the file name up to its first dot, channel 0 or 1) and
        ``'strings'`` holds the matching RLE strings.

    Raises:
        ValueError: If a file in ``dir`` cannot be read as an image.
    """
    strings = []
    ids = []
    # Sorted so that the output order is deterministic.
    for image_id in sorted(os.listdir(dir)):
        path = os.path.join(dir, image_id)
        if not os.path.isfile(path):
            # Skip sub-directories such as Jupyter's .ipynb_checkpoints.
            continue
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Could not read image file: {path}")
        # Reverse the channel axis of what cv2.imread returned.
        img = bgr[:,:,::-1]
        name = image_id.split('.')[0]
        for channel in range(2):
            ids.append(f'{name}_{channel}')
            strings.append(rle_encode_one_mask(img[:,:,channel]))
    return {
        'ids': ids,
        'strings': strings,
    }


# Encode every predicted mask and write the result as a two-column CSV.
dir = OUTPUT_DIR
res = mask2string(dir)
df = pd.DataFrame({'Id': res['ids'], 'Expected': res['strings']}, columns=['Id', 'Expected'])
df.to_csv(r'output.csv', index=False)