In [1]:
from model import ResNetUNet
import torch
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import glob
import re
import os
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import albumentations as A
from pathlib import Path

In [4]:
### Model Preparation
# Build the network and swap its input-facing convolutions for single-channel
# versions (in_channels=1), since the dataset yields 1-channel slices.
# NOTE(review): `base_model.conv1` and `layer0[0]` receive two *separate*
# Conv2d instances here; whether both are used in the forward pass depends on
# model.py -- confirm.
_stem_conv_kwargs = dict(kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

model = ResNetUNet()
model.base_model.conv1 = torch.nn.Conv2d(1, 64, **_stem_conv_kwargs)
model.layer0[0] = torch.nn.Conv2d(1, 64, **_stem_conv_kwargs)
model.conv_original_size0[0] = torch.nn.Conv2d(
    1, 64,
    kernel_size=(3, 3),
    stride=(1, 1),
    padding=(1, 1),
)

In [5]:
from torchvision import transforms
import torchvision.datasets as datasets

In [6]:
import warnings
# Silence every Python warning for the rest of the session (keeps cell output
# short, but also hides deprecation/runtime warnings from the libraries used).
warnings.filterwarnings('ignore')

In [7]:
def change_img_to_label_path(path):
    """Map an image file path to the path of its label file.

    The first path component that is exactly ``"image"`` is replaced by
    ``"label"``; every other component is kept as is.

    Args:
        path: a ``pathlib`` path containing an ``"image"`` component.

    Returns:
        A ``Path`` with that component swapped for ``"label"``.

    Raises:
        ValueError: if no component equals ``"image"``.
    """
    components = list(path.parts)
    image_pos = components.index("image")
    head, tail = components[:image_pos], components[image_pos + 1:]
    return Path(*head, "label", *tail)

In [8]:
# Root directory handed to ParseDataset, which globs it for
# "*train/*/image/*.npy" and "*val/*/image/*.npy" files.
data_path = Path("/scratch/scratch6/akansh12/Parse_data/processed_train/")

In [22]:
# Training pipeline: fixed 224x224 centre crop followed by random flips and
# rotations, each applied with the probability given by its `p`.
train_transforms = A.Compose(
    [
        A.CenterCrop(height=224, width=224),
        A.HorizontalFlip(p=0.3),
        A.VerticalFlip(p=0.2),
        A.RandomRotate90(p=0.1),
        A.Rotate(limit=(-30, 30), p=0.5),
    ]
)

# Evaluation pipeline: the same centre crop, plus a rotation whose limit is 0.
test_transforms = A.Compose(
    [
        A.CenterCrop(height=224, width=224),
        A.Rotate(limit=0, p=1),
    ]
)


In [28]:
class ParseDataset(Dataset):
    """Dataset of paired ``.npy`` image slices and segmentation masks.

    Image files are discovered under ``img_dir`` with the glob
    ``*train/*/image/*.npy`` or ``*val/*/image/*.npy``. The mask belonging to
    an image is loaded from the path returned by ``change_img_to_label_path``.

    Args:
        img_dir: ``pathlib.Path`` root directory to search.
        data_type: ``'train'`` or ``'val'``; selects which split is globbed.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
            ``None`` (the default) disables augmentation.

    Raises:
        ValueError: if ``data_type`` is neither ``'train'`` nor ``'val'``.
    """

    # Glob pattern (relative to ``img_dir``) for each supported split.
    _SPLIT_PATTERNS = {
        'train': "*train/*/image/*.npy",
        'val': "*val/*/image/*.npy",
    }

    def __init__(self, img_dir, data_type='train', transform=None):
        if data_type not in self._SPLIT_PATTERNS:
            # Previously an unknown split left `self.img_dir` unset and failed
            # later with an opaque AttributeError.
            raise ValueError(
                f"data_type must be 'train' or 'val', got {data_type!r}"
            )
        # Despite its name, `img_dir` (the attribute) is the sorted list of
        # image files; the name is kept for backward compatibility.
        self.img_dir = sorted(img_dir.glob(self._SPLIT_PATTERNS[data_type]))
        self.transforms = transform

    def __len__(self):
        """Return the number of image files found for the split."""
        return len(self.img_dir)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors, each with a leading axis of size 1."""
        img_path = self.img_dir[idx]
        img = np.load(img_path).astype("float32")
        mask = np.load(change_img_to_label_path(img_path))
        # Collapse every value into [0, 1] before casting the mask to float32.
        mask = np.clip(mask, 0, 1).astype("float32")

        # Only augment when a transform was supplied (the default is None).
        if self.transforms is not None:
            augmented = self.transforms(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        # A transform may hand back a non-contiguous view (e.g. a reversed
        # slice), which tensor construction can reject; a contiguous array is
        # always accepted and this is a no-op when it is already contiguous.
        img = np.ascontiguousarray(img)
        mask = np.ascontiguousarray(mask)
        return (
            torch.tensor(img).unsqueeze(0),
            torch.as_tensor(mask, dtype=torch.float32).unsqueeze(0),
        )

In [29]:
# Both splits are read from the same root; the training split uses the random
# augmentation pipeline, the validation split the evaluation pipeline.
train_data = ParseDataset(img_dir = data_path, transform = train_transforms)
val_data = ParseDataset(img_dir = data_path,data_type = 'val', transform = test_transforms)

In [30]:
# Sanity check: print the image and mask tensor shapes of the first training
# sample only (the loop stops after one iteration).
for x,y in train_data:
    print(x.shape)
    print(y.shape)
    break

torch.Size([1, 224, 224])
torch.Size([1, 224, 224])


In [31]:
from torch.utils.data import DataLoader
# Training batches are reshuffled every epoch; validation keeps a fixed order.
# num_workers=0 means samples are loaded in the main process.
train_dl = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=0)
val_dl = DataLoader(val_data, batch_size=32, shuffle=False, num_workers = 0)

In [32]:
# Sanity check: print shape and dtype of the first training batch only.
for img_b, mask_b in train_dl:
    print(img_b.shape,img_b.dtype)
    print(mask_b.shape, mask_b.dtype)
    break


torch.Size([64, 1, 224, 224]) torch.float32
torch.Size([64, 1, 224, 224]) torch.float32


In [33]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU, and
# move the model there. The trailing ';' suppresses the cell's output.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device);

In [34]:
def jaccard_coef_metric(inputs, target, eps=1e-7):
    """Jaccard (intersection-over-union) score of two masks.

    The score is computed over all elements at once as
    ``(intersection + eps) / (sum(target) + sum(inputs) - intersection + eps)``.

    Args:
        inputs: predicted mask; must support ``*`` with ``target`` and ``.sum()``.
        target: reference mask of a compatible shape.
        eps: small constant added to numerator and denominator.

    Returns:
        ``1.0`` when both masks sum to zero, otherwise the smoothed ratio.
    """
    overlap = (target * inputs).sum()
    target_total = target.sum()
    input_total = inputs.sum()

    # Two empty masks agree perfectly.
    if target_total == 0 and input_total == 0:
        return 1.0

    denominator = (target_total + input_total) - overlap + eps
    return (overlap + eps) / denominator

def dice_coef_metric(inputs, target):
    """Dice score of two masks, computed over all elements at once.

    Args:
        inputs: predicted mask; must support ``*`` with ``target`` and ``.sum()``.
        target: reference mask of a compatible shape.

    Returns:
        ``1.0`` when both masks sum to zero, otherwise
        ``2 * sum(target * inputs) / (sum(target) + sum(inputs))``.
    """
    doubled_overlap = 2.0 * (target * inputs).sum()
    target_total = target.sum()
    input_total = inputs.sum()
    combined_total = target_total + input_total

    # Two empty masks agree perfectly (and would otherwise divide by zero).
    if target_total == 0 and input_total == 0:
        return 1.0

    return doubled_overlap / combined_total

## Loss

def dice_coef_loss(inputs, target):
    """Smoothed Dice loss: ``1 - (2 * overlap + 1) / (sum(target) + sum(inputs) + 1)``.

    Args:
        inputs: predictions; must support ``*`` with ``target`` and ``.sum()``.
        target: reference mask of a compatible shape.

    Returns:
        The loss value; ``0`` when the smoothed Dice ratio equals one.
    """
    smooth = 1.0  # keeps the ratio defined when both masks are empty
    numerator = 2.0 * ((target * inputs).sum()) + smooth
    denominator = target.sum() + inputs.sum() + smooth
    dice_ratio = numerator / denominator
    return 1 - dice_ratio


def bce_dice_loss(inputs, target):
    """Sum of binary cross-entropy and smoothed Dice loss.

    Args:
        inputs: predicted values passed to ``nn.BCELoss`` and ``dice_coef_loss``.
            NOTE(review): ``nn.BCELoss`` expects probabilities in [0, 1] --
            confirm the model applies a sigmoid.
        target: reference mask with the same shape as ``inputs``.

    Returns:
        ``BCELoss(inputs, target) + dice_coef_loss(inputs, target)``.
    """
    dice_term = dice_coef_loss(inputs, target)
    bce_term = nn.BCELoss()(inputs, target)
    return bce_term + dice_term


In [35]:
def train_one_epoch(model, optimizer, lr_scheduler, metric,
                    dataloader, epoch, criterion=bce_dice_loss):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: network to optimise; switched to training mode.
        optimizer: optimiser stepped once per batch.
        lr_scheduler: scheduler stepped once at the end of the epoch, or
            ``None`` to keep the optimiser's learning rate untouched.
        metric: callable invoked as ``metric(binarised_predictions, targets)``
            with NumPy arrays; its result is averaged over the batches.
        dataloader: iterable yielding ``(data, target)`` tensor batches.
        epoch: epoch index, used only in the printed summary.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.

    Returns:
        Tuple ``(mean_loss, mean_metric, lr)`` where ``lr`` is the first
        learning rate reported by the scheduler after its step, or the first
        parameter group's learning rate when no scheduler is given.
    """
    print("Start Train ...")
    model.train()

    # Run on whatever device the model lives on instead of relying on a
    # module-level DEVICE global that is only defined in a later cell.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    losses = []
    accur = []

    for data, target in tqdm(dataloader):

        data = data.to(run_device).float()
        targets = target.to(run_device)

        outputs = model(data)

        # Binarise a detached copy of the predictions at 0.5 for the metric.
        # NOTE(review): this presumes the model emits probabilities -- confirm.
        out_cut = outputs.detach().cpu().numpy().copy()
        out_cut[out_cut < 0.5] = 0.0
        out_cut[out_cut >= 0.5] = 1.0

        train_dice = metric(out_cut, targets.detach().cpu().numpy())

        loss = criterion(outputs, targets)

        losses.append(loss.item())
        accur.append(train_dice)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if lr_scheduler is not None:
        lr_scheduler.step()
        lr = lr_scheduler.get_last_lr()[0]
    else:
        # Without a scheduler, report the optimiser's current learning rate
        # (previously this path crashed on `None.get_last_lr()`).
        lr = optimizer.param_groups[0]["lr"]

    mean_loss = np.array(losses).mean()
    mean_dice = np.array(accur).mean()
    print("Epoch [%d]" % (epoch),
          "Mean loss on train:", mean_loss,
          "Mean DICE on train:", mean_dice,
          "Learning Rate:", lr)

    return mean_loss, mean_dice, lr


def val_epoch(model, metric, dataloader, epoch, threshold=0.5):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Args:
        model: network to evaluate; switched to evaluation mode.
        metric: callable invoked as ``metric(binarised_predictions, targets)``
            with NumPy arrays, once per batch.
        dataloader: iterable yielding ``(data, targets)`` tensor batches.
        epoch: epoch index, used only in the printed summary.
        threshold: predictions below it become 0, the rest become 1.

    Returns:
        Mean of the per-batch metric values.
    """
    print("Start Validation ...")
    model.eval()

    # Run on whatever device the model lives on instead of relying on a
    # module-level DEVICE global that is only defined in a later cell.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    val_acc = []

    with torch.no_grad():
        for data, targets in tqdm(dataloader):

            data = data.to(run_device).float()
            targets = targets.to(run_device)

            outputs = model(data)

            # Binarise a copy of the predictions at `threshold`.
            out_cut = outputs.detach().cpu().numpy().copy()
            out_cut[out_cut < threshold] = 0.0
            out_cut[out_cut >= threshold] = 1.0

            val_dice = metric(out_cut, targets.detach().cpu().numpy())
            val_acc.append(val_dice)

    mean_val_dice = np.array(val_acc).mean()
    print("Epoch:  " + str(epoch) + "  Threshold:  " + str(threshold)\
          + " Mean Validation DICE Score:", mean_val_dice)

    return mean_val_dice



In [36]:
# Make every parameter trainable, then collect the trainable ones; the same
# `params` list is shared by all three stage optimizers below.
for param in model.parameters():
    param.requires_grad = True
    
params = [p for p in model.parameters() if p.requires_grad]

# Three training stages, consumed in lock-step by the training loop:
# stage i runs stage_epoch[i] epochs with stage_optimizer[i] / stage_scheduler[i].
stage_epoch =  [12, 8, 15] #[12, 8, 5]
stage_optimizer = [
    torch.optim.Adamax(params, lr=0.0002),
    torch.optim.SGD(params, lr=0.00009, momentum=0.9),
    torch.optim.Adam(params, lr=0.00005),
]

# CosineAnnealingLR positional arguments are (optimizer, T_max=4, eta_min=1e-6).
# NOTE(review): train_one_epoch steps the scheduler once per epoch, whereas
# CyclicLR is normally stepped once per batch -- with its default step size the
# learning rate will barely leave base_lr during an 8-epoch stage; confirm
# this is intended.
stage_scheduler = [
    torch.optim.lr_scheduler.CosineAnnealingLR(stage_optimizer[0], 4, 1e-6),
    torch.optim.lr_scheduler.CyclicLR(stage_optimizer[1], base_lr=1e-5, max_lr=2e-4),
    torch.optim.lr_scheduler.CosineAnnealingLR(stage_optimizer[2], 4, 1e-6),
]

In [37]:


# Staged training driver: runs every (epochs, optimizer, scheduler) stage in
# turn, records per-epoch history, checkpoints the model whenever the
# validation Dice matches or beats the best value so far, and reloads the best
# checkpoint of this run at the end of each stage.
DEVICE = device
weights_dir = "/scratch/scratch6/akansh12/challenges/parse2022/temp/weights"
# makedirs(exist_ok=True) also creates missing parent directories and is safe
# to re-run, unlike a bare os.mkdir.
os.makedirs(weights_dir, exist_ok=True)


loss_history = []
train_dice_history = []
val_dice_history = []
lr_history = []

# Path of the best checkpoint written during THIS run. Tracking it directly
# avoids picking up stale files from earlier runs that share the directory,
# and avoids sorting on a fragile character slice of the full path.
best_weights = None

for k, (num_epochs, optimizer, lr_scheduler) in enumerate(zip(stage_epoch, stage_optimizer, stage_scheduler)):
    for epoch in range(num_epochs):

        loss, train_dice, lr = train_one_epoch(model, optimizer, lr_scheduler,
                                               dice_coef_metric, train_dl, epoch)

        val_dice = val_epoch(model, dice_coef_metric, val_dl, epoch)

        # train history
        loss_history.append(loss)
        train_dice_history.append(train_dice)
        lr_history.append(lr)
        val_dice_history.append(val_dice)

        # save best weights (the history spans all stages, so "best" is global)
        best_dice = max(val_dice_history)
        if val_dice >= best_dice:
            best_weights = os.path.join(weights_dir, f"{val_dice:0.6f}_.pth")
            torch.save({'state_dict': model.state_dict()}, best_weights)

    print("\nNext stage\n")
    # Load the best weights
    if best_weights is None:
        # Nothing was checkpointed (e.g. the metric was NaN); keep the current
        # weights rather than crashing on a missing file.
        print("No checkpoint saved yet; keeping current weights")
    else:
        checkpoint = torch.load(best_weights, map_location=DEVICE)
        model.load_state_dict(checkpoint['state_dict'])

        # basename gives the checkpoint file name; split("/")[1] returned the
        # first directory of the absolute path instead.
        print(f'Loaded model: {os.path.basename(best_weights)}')



Start Train ...


  0%|          | 0/282 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Overlay the first training image (sorted order) with its label mask.
# `train_path` was never defined in this notebook, so the cell raised NameError
# on a fresh kernel; the sample is now looked up once under `data_path` with
# the same glob the training dataset uses.
# NOTE(review): confirm `data_path` is the directory `train_path` referred to.
sample_image_path = sorted(data_path.glob("*train/*/image/*.npy"))[0]
sample_label_path = change_img_to_label_path(sample_image_path)

fig = plt.figure()
plt.imshow(np.load(sample_image_path), cmap = "bone")
plt.imshow(np.load(sample_label_path), alpha = 0.4, cmap = "bone");