In [1]:
import torch
import pandas as pd
import numpy as np
import albumentations as ATransforms
from albumentations.pytorch import ToTensorV2
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms.functional as TF

import os
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)

cuda


## Hyperparameters

In [3]:
EPOCHS = 50
BATCH_SIZE = 16
LEARNING_RATE = 1e-3
NUM_WORKERS = 4

LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_FILE = "checkpoint-test.pth"

IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
IMAGE_CHANNELS = 3




## Model Architecture

In [4]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3,stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


In [5]:
class UNET(nn.Module):
    def __init__(self, in_channels=IMAGE_CHANNELS, out_channels=IMAGE_CHANNELS):
        super(UNET, self).__init__()
        self.conv1 = DoubleConv(in_channels, 64)
        self.conv2 = DoubleConv(64, 128)
        self.conv3 = DoubleConv(128, 256)
        self.conv4 = DoubleConv(256, 512)
        self.conv5 = DoubleConv(512, 1024)

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv6 = DoubleConv(1024, 512)
        self.conv7 = DoubleConv(512, 256)
        self.conv8 = DoubleConv(256, 128)
        self.conv9 = DoubleConv(128, 64)

        self.tconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.tconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.tconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.tconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)

        self.bottleneck = DoubleConv(1024, 1024)
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        c1 = self.conv1(x)
        skip_connections.append(c1)
        c2 = self.conv2(self.maxpool(c1))
        skip_connections.append(c2)
        c3 = self.conv3(self.maxpool(c2))
        skip_connections.append(c3)
        c4 = self.conv4(self.maxpool(c3))
        skip_connections.append(c4)
        out = self.conv5(self.maxpool(c4))
        out = self.bottleneck(out)
        skip_connections = skip_connections[::-1]
        out = self.tconv1(out)
        if out.shape != skip_connections[0].shape:
            out = TF.resize(out, size=skip_connections[0].shape[2:])
        out = torch.cat([out, skip_connections[0]], dim=1)
        out = self.conv6(out)

        out = self.tconv2(out)
        if out.shape != skip_connections[1].shape:
            out = TF.resize(out, size=skip_connections[1].shape[2:])
        out = torch.cat((out, skip_connections[1]), dim=1)
        out = self.conv7(out)

        out = self.tconv3(out)
        if out.shape != skip_connections[2].shape:
            out = TF.resize(out, size=skip_connections[2].shape[2:])
        out = torch.cat((out, skip_connections[2]), dim=1)
        out = self.conv8(out)

        out = self.tconv4(out)
        if out.shape != skip_connections[3].shape:
            out = TF.resize(out, size=skip_connections[3].shape[2:])
        out = torch.cat((out, skip_connections[3]), dim=1)
        out = self.conv9(out)

        out = self.out(out)
        return out
    




## Dataloader Preparation

In [6]:
class PolypDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)
        self.masks = os.listdir(mask_dir)

    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.masks[index])
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask > 0] = 1 # try with mask == 255.0
        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

In [7]:
augmentations = ATransforms.Compose([
    ATransforms.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    ATransforms.Rotate(limit=35, p=1.0),
    ATransforms.HorizontalFlip(p=0.5),
    ATransforms.VerticalFlip(p=0.1),
    ATransforms.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
    ToTensorV2(),
])

In [8]:
train_loader = DataLoader(
    PolypDataset('../data/train/images/', '../data/train/masks', transform=augmentations),
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers=NUM_WORKERS
)
val_loader = DataLoader(
    PolypDataset('../data/val/images/', '../data/val/masks'),
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers=NUM_WORKERS
)
test_loader = DataLoader(
    PolypDataset('../data/test/images/', '../data/test/masks'),
    batch_size = 1,
    shuffle = True,
    num_workers=NUM_WORKERS
)

## Train

In [9]:
model = UNET(in_channels=3, out_channels=1).to(DEVICE)


In [10]:
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS, eta_min=1e-7)


In [11]:
def train():
    all_losses = []
    for epoch in range(EPOCHS):
        total_loss = 0
        print(f"Epoch {epoch+1}/{EPOCHS}")
        model.train()
        for data, targets in tqdm(train_loader):
            data = data.to(device=DEVICE)
            # targets = targets.float().unsqueeze(1).to(device=DEVICE)
            targets = targets.unsqueeze(1).to(device=DEVICE)
            scores = model(data)
            loss = criterion(scores, targets)
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()
        print(f"Total loss: {total_loss}")
        all_losses.append(total_loss)

        # checkpoint = {
        #     "state_dict": model.state_dict(),
        #     "optimizer": optimizer.state_dict(),
        # }

        # save_checkpoint(checkpoint, filename=CHECKPOINT_FILE)

        with torch.no_grad():
            for image, mask in tqdm(val_loader):
                image = image.to(device=DEVICE)
                mask = mask.unsqueeze(1).to(device=DEVICE)
                preds = torch.sigmoid(model(image))
                preds = (preds > 0.5).float()
                

    return all_losses



In [12]:
train()

Epoch 1/100


  0%|          | 0/32 [00:00<?, ?it/s]

100%|██████████| 32/32 [00:15<00:00,  2.00it/s]


Epoch 2/100


100%|██████████| 32/32 [00:14<00:00,  2.19it/s]


Epoch 3/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 4/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 5/100


100%|██████████| 32/32 [00:15<00:00,  2.13it/s]


Epoch 6/100


100%|██████████| 32/32 [00:15<00:00,  2.12it/s]


Epoch 7/100


100%|██████████| 32/32 [00:15<00:00,  2.12it/s]


Epoch 8/100


100%|██████████| 32/32 [00:15<00:00,  2.12it/s]


Epoch 9/100


100%|██████████| 32/32 [00:15<00:00,  2.12it/s]


Epoch 10/100


100%|██████████| 32/32 [00:15<00:00,  2.09it/s]


Epoch 11/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 12/100


100%|██████████| 32/32 [00:14<00:00,  2.13it/s]


Epoch 13/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 14/100


100%|██████████| 32/32 [00:15<00:00,  2.13it/s]


Epoch 15/100


100%|██████████| 32/32 [00:14<00:00,  2.13it/s]


Epoch 16/100


100%|██████████| 32/32 [00:15<00:00,  2.13it/s]


Epoch 17/100


100%|██████████| 32/32 [00:15<00:00,  2.13it/s]


Epoch 18/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 19/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 20/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 21/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 22/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 23/100


100%|██████████| 32/32 [00:14<00:00,  2.15it/s]


Epoch 24/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 25/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 26/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 27/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 28/100


100%|██████████| 32/32 [00:15<00:00,  2.13it/s]


Epoch 29/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 30/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 31/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 32/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 33/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 34/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 35/100


100%|██████████| 32/32 [00:14<00:00,  2.14it/s]


Epoch 36/100


 38%|███▊      | 12/32 [00:06<00:10,  1.96it/s]


KeyboardInterrupt: 

In [14]:
torch.save(model.state_dict(), '../models/model-1.pth')

In [None]:
# def evaluate(model, loader):
#     model.eval()
#     total = 0
#     correct = 0
#     with torch.no_grad():
#         for x, y in loader:
#             x = x.to(DEVICE)
#             y = y.unsqueeze(1).to(DEVICE)
#             preds = torch.sigmoid(model(x))
#             preds = (preds > 0.5).float()
#             total += y.flatten().size(0)
#             correct += (preds == y).float().sum().item()
#     return correct / total