# Libraries

In [1]:
import sys
from torch import nn, cat,rand
import torch

from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import os

import pandas as pd
import shutil

# Split the dataset into two parts: train and validation

In [2]:
# Source folders holding the full, not-yet-split dataset
raw_data_images_dir = "data/raw_data/raw_images"
raw_data_mask_dir = "data/raw_data/raw_masks"

# Destination folders for the train / validation split
dest_train_images_dir, dest_train_masks_dir = "data/train/train_images", "data/train/train_masks"
dest_val_images_dir, dest_val_masks_dir = "data/val/val_images", "data/val/val_masks"

In [3]:
# Build the image/mask table. Each mask name is derived from its image name with
# the same "<id>.jpg" -> "<id>_mask.gif" rule that CarvanaDataset uses, instead of
# pairing two independent os.listdir() results by position: listdir order is
# arbitrary, so positional pairing can put an image and its mask into different
# splits (CarvanaDataset would then fail to find the mask).
image_names = sorted(os.listdir(raw_data_images_dir))  # sorted -> reproducible split
mask_names = [name.replace(".jpg", "_mask.gif") for name in image_names]

available_masks = set(os.listdir(raw_data_mask_dir))
missing_masks = [name for name in mask_names if name not in available_masks]
if missing_masks:
    raise FileNotFoundError(
        f"{len(missing_masks)} image(s) have no mask in {raw_data_mask_dir!r}, "
        f"e.g. {missing_masks[:3]}"
    )

images_data_dir_df = pd.DataFrame({"images": image_names, "masks": mask_names})

# Split datasets (seeded so the split is reproducible across runs)
train_size = 0.8
train_dataset = images_data_dir_df.sample(frac=train_size, random_state=200)
test_dataset = images_data_dir_df.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)

In [50]:
#Copy to desired folder
def copy_data(src_images_dir, src_images_name, dest_images_dir):
    """Copy ``src_images_dir/src_images_name`` into the folder ``dest_images_dir``.

    The destination folder is created if needed. Without it, ``shutil.copy``
    would either fail (missing parent) or treat ``dest_images_dir`` as a *file*
    path, silently overwriting it with every copied image.

    Returns:
        The path of the newly created copy.
    """
    os.makedirs(dest_images_dir, exist_ok=True)
    return shutil.copy(os.path.join(src_images_dir, src_images_name), dest_images_dir)

# Copy every split's images and masks into its destination folders,
# in the order: train images, train masks, val images, val masks.
copy_jobs = (
    (train_dataset["images"], raw_data_images_dir, dest_train_images_dir),
    (train_dataset["masks"], raw_data_mask_dir, dest_train_masks_dir),
    (test_dataset["images"], raw_data_images_dir, dest_val_images_dir),
    (test_dataset["masks"], raw_data_mask_dir, dest_val_masks_dir),
)
for file_names, src_dir, dest_dir in copy_jobs:
    for file_name in file_names:
        copy_data(src_dir, file_name, dest_dir)

# Hyperparameters

In [4]:
# Optimisation
LEARNING_RATE = 1e-4
BATCH_SIZE = 16
NUM_EPOCHS = 3

# Hardware / data loading
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 0
PIN_MEMORY = True

# Training resolution (images are resized down from the originals)
IMAGE_HEIGHT = 160  # 1280 originally
IMAGE_WIDTH = 240  # 1918 originally

# Flag for resuming from a saved checkpoint
LOAD_MODEL = False

# Data locations produced by the split/copy cells
TRAIN_IMG_DIR, TRAIN_MASK_DIR = dest_train_images_dir, dest_train_masks_dir
VAL_IMG_DIR, VAL_MASK_DIR = dest_val_images_dir, dest_val_masks_dir

# Datasets

In [8]:
class CarvanaDataset(Dataset):
    """Image/mask segmentation dataset.

    Every file ``<id>.jpg`` in ``image_dir`` is paired with the mask
    ``<id>_mask.gif`` in ``mask_dir``.

    Args:
        image_dir: folder with the input images.
        mask_dir: folder with the masks.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``"image"`` and ``"mask"`` entries
            (albumentations-style).

    Items are ``(image, mask)``. Without a transform, ``image`` is an RGB
    ``uint8`` array and ``mask`` a ``float32`` array in which 255 has been
    mapped to 1.0.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir order is arbitrary; sort so index -> file is reproducible
        self.images = sorted(os.listdir(image_dir))

    @property
    def transfrom(self):
        """Backward-compatible alias for the previously misspelled attribute."""
        return self.transform

    @transfrom.setter
    def transfrom(self, value):
        self.transform = value

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_name = self.images[index]
        img_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(self.mask_dir, image_name.replace(".jpg", "_mask.gif"))

        # Context managers close the files; Pillow may otherwise keep the
        # handle open (notably for GIFs) until garbage collection.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(mask_path) as msk:
            mask = np.array(msk.convert("L"), dtype=np.float32)
        # Map the 255 foreground value to 1.0 so masks can serve as BCE targets
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

# Transformations

In [9]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
def _normalize_and_tensorize():
    """Shared tail of both pipelines: scale 0-255 pixels to [0, 1], then convert to a tensor.

    Built fresh on every call so the two pipelines do not share transform instances.
    """
    return [
        A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ]


# Training: resize, random rotation/flips, then normalize
train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        *_normalize_and_tensorize(),
    ],
)

# Validation: deterministic resize + normalize only
val_transforms = A.Compose(
    [A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH), *_normalize_and_tensorize()],
)

# DataLoader

In [10]:
from torch.utils.data import DataLoader

# Pinned host memory only speeds up host->GPU copies, so honour PIN_MEMORY on CUDA only.
pin_memory = PIN_MEMORY and DEVICE == "cuda"

train_params = {'batch_size': BATCH_SIZE,
                'shuffle': True,
                'num_workers': NUM_WORKERS,
                'pin_memory': pin_memory,
                }

# Validation order need not be random; a fixed order keeps metrics and the
# per-batch prediction images saved each epoch comparable across epochs.
test_params = {'batch_size': BATCH_SIZE,
               'shuffle': False,
               'num_workers': NUM_WORKERS,
               'pin_memory': pin_memory,
               }

training_set = CarvanaDataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR, transform=train_transform)
testing_set = CarvanaDataset(VAL_IMG_DIR, VAL_MASK_DIR, transform=val_transforms)

training_loader = DataLoader(dataset=training_set, **train_params)
testing_loader = DataLoader(dataset=testing_set, **test_params)

In [11]:
# Pull one validation batch to sanity-check tensor shapes
data = next(iter(testing_loader))
sample_images, sample_masks = data
print("Images: ", sample_images.shape)
print("Masks: ", sample_masks.shape)

Images:  torch.Size([16, 3, 160, 240])
Masks:  torch.Size([16, 160, 240])


# Model

In [12]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Padding 1 with stride 1 keeps the spatial size; the convolutions carry no
    bias because each is followed by BatchNorm.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    """U-Net encoder/decoder producing per-pixel logits.

    Architecture:
        - Encoder: one ``DoubleConv`` per entry of ``features``, each followed
          by 2x2 max-pooling.
        - Bottleneck: a ``DoubleConv`` that doubles the deepest width.
        - Decoder: mirrors the encoder. Each stage is a 2x2 transposed
          convolution, then concatenation with the matching encoder activation
          (skip connection), then a ``DoubleConv``.
        - Head: a final 1x1 convolution maps to ``out_channels``.

    No activation is applied to the output; the notebook pairs it with
    ``BCEWithLogitsLoss`` and ``sigmoid``.

    Args:
        in_channels: number of input channels (3 for RGB).
        out_channels: number of output channels.
        features: encoder stage widths, shallow to deep. Any non-empty sequence
            of ints; defaults to ``(64, 128, 256, 512)``.

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(
            self, in_channels=3, out_channels=1, features=(64, 128, 256, 512),
    ):
        super(UNET, self).__init__()
        # Work on a private list: the default is an immutable tuple (no shared
        # mutable default) and any sequence supplied by the caller is accepted.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one stage width")

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET: each stage maps the running width to `feature`
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET: a transposed conv (2x upsample, feature*2 -> feature
        # channels) followed by a DoubleConv over the concatenation with the
        # skip connection (feature + feature = feature*2 channels)
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return ``out_channels`` logit maps with the same height/width as ``x``."""
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        # Deepest skip first, to match the decoder order
        skip_connections = skip_connections[::-1]

        # self.ups alternates [upsample, DoubleConv, upsample, DoubleConv, ...]
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            # Max-pooling floors odd sizes, so after upsampling the map can be one
            # pixel smaller than the skip; resize to match before concatenating.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)
# RGB input -> one logit channel per pixel (binary mask)
model = UNET(out_channels=1, in_channels=3)
model  # display the architecture

UNET(
  (ups): ModuleList(
    (0): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
    (1): DoubleConv(
      (conv): Sequential(
        (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (2): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
    (3): DoubleConv(
      (conv): Sequential(
        (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), p

## Prediction shape test

In [55]:
# Shape sanity check on an odd spatial size (161 is not divisible by 2**4, the
# default number of pooling stages), which exercises UNET's resize branch.
# Dedicated names are used so this cell does NOT overwrite the 3-channel global
# `model` that the optimizer and training cells rely on; rebinding it to a
# 1-channel model would make training fail on RGB batches after Restart & Run All.
test_input = torch.randn((3, 1, 161, 161))
shape_test_model = UNET(in_channels=1, out_channels=1)
test_preds = shape_test_model(test_input)
assert test_preds.shape == test_input.shape
del shape_test_model, test_preds  # free the throwaway model

# Train Function

In [13]:
from tqdm import tqdm
def train_fn(loader, model, optimizer, loss_fn, scaler, device=None):
    """Run one training epoch over ``loader``.

    Args:
        loader: iterable of ``(images, masks)`` batches. Masks are cast to float
            and given a channel dim via ``unsqueeze(1)``. This assumes they
            arrive as (N, H, W), matching the model's (N, 1, H, W) logits;
            confirm for other datasets.
        model: network returning logits.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: criterion called as ``loss_fn(predictions, targets)``.
        scaler: gradient scaler providing ``scale``/``step``/``update``.
        device: device to train on; defaults to the notebook-level ``DEVICE``.
    """
    if device is None:
        device = DEVICE
    device_type = torch.device(device).type
    # Mixed precision is used on CUDA only; on CPU autocast is simply disabled
    # instead of requesting CUDA autocast without a GPU.
    use_amp = device_type == "cuda"

    model.train()
    with tqdm(loader, unit="batch") as tepoch:
        for data, targets in tepoch:
            data = data.to(device=device)
            targets = targets.float().unsqueeze(1).to(device=device)

            # forward
            with torch.autocast(device_type=device_type, enabled=use_amp):
                predictions = model(data)
                loss = loss_fn(predictions, targets)

            # backward (the scaler handles loss scaling for fp16 gradients)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # update tqdm loop
            tepoch.set_postfix(loss=loss.item())

# Optimizer

In [14]:
loss_fn = nn.BCEWithLogitsLoss()  # the model outputs raw logits
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Loss scaling only matters for CUDA fp16 autocast; when disabled the scaler is a pass-through.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

# Optionally resume from the checkpoint written by the training loop
# ({"state_dict": ..., "optimizer": ...}). This is done inline because the
# load_checkpoint helper is only defined in a later cell.
if LOAD_MODEL:
    checkpoint = torch.load("my_checkpoint.pth.tar", map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])



# Check Accuracy

In [15]:
def check_accuracy(loader, model, device):
    """Evaluate pixel accuracy and mean per-batch Dice score of ``model`` on ``loader``.

    Predictions are ``sigmoid(logits) > 0.5``. Masks get a channel dim via
    ``unsqueeze(1)`` to line up with the (N, 1, H, W) predictions. The model is
    put in eval mode and returned to train mode afterwards, even if an error
    occurs.

    Args:
        loader: iterable of ``(images, masks)`` batches.
        model: network returning logits.
        device: device the batches are moved to.

    Returns:
        ``(accuracy_percent, dice)`` as Python floats. Both are ``nan`` when
        the loader yields no batches.
    """
    num_correct = 0
    num_pixels = 0
    dice_total = 0.0
    num_batches = 0
    model.eval()

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).float().unsqueeze(1)

                preds = (torch.sigmoid(model(x)) > 0.5).float()
                # .item() keeps running totals as plain numbers so prints are readable
                num_correct += (preds == y).sum().item()
                num_pixels += preds.numel()

                # Dice over the whole batch; epsilon avoids division by zero
                dice_total += ((2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)).item()
                num_batches += 1
    finally:
        model.train()

    accuracy = num_correct / num_pixels * 100 if num_pixels else float("nan")
    dice = dice_total / num_batches if num_batches else float("nan")
    print(f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}")
    print(f"Dice score: {dice}")
    return accuracy, dice
    

# Save Model Checkpoint

In [16]:
import torchvision
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce and serialize ``state`` to ``filename`` with ``torch.save``.

    Args:
        state: picklable object, e.g. a dict of model/optimizer state dicts.
        filename: destination path; an existing file is overwritten.
    """
    message = "=> Saving checkpoint"
    print(message)
    torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model, optimizer=None):
    """Restore model (and optionally optimizer) state from a checkpoint dict.

    Args:
        checkpoint: dict with a ``"state_dict"`` entry and, optionally, an
            ``"optimizer"`` entry, as written by the training loop.
        model: module that receives ``checkpoint["state_dict"]``.
        optimizer: if given and the checkpoint has an ``"optimizer"`` entry,
            its state is restored too so training can resume where it stopped.
    """
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    if optimizer is not None and "optimizer" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer"])

def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    """Save thresholded predictions and ground-truth masks for every batch as PNGs.

    For batch ``idx`` it writes ``<folder>/pred_<idx>.png`` (``sigmoid > 0.5``
    predictions) and ``<folder>/<idx>.png`` (ground-truth masks, with a
    channel dim added). ``folder`` is created if missing, and paths are joined
    with ``os.path.join``, so a trailing slash is optional. The model is put
    back in train mode afterwards, even if an error occurs.
    """
    os.makedirs(folder, exist_ok=True)
    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()
            torchvision.utils.save_image(
                preds, os.path.join(folder, f"pred_{idx}.png")
            )
            torchvision.utils.save_image(
                y.unsqueeze(1), os.path.join(folder, f"{idx}.png")
            )
    finally:
        model.train()

# Train

In [17]:
for epoch in range(NUM_EPOCHS):
    # One optimisation pass over the training data
    train_fn(training_loader, model, optimizer, loss_fn, scaler)

    # Persist model + optimizer state (the same file is overwritten every epoch)
    checkpoint = dict(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    save_checkpoint(checkpoint)

    # Report validation accuracy / Dice
    check_accuracy(testing_loader, model, device=DEVICE)

    # Dump validation predictions next to their ground truth for visual inspection
    save_predictions_as_imgs(testing_loader, model, folder="saved_images/", device=DEVICE)
    

  2%|▏         | 5/255 [01:40<1:21:20, 19.52s/batch, loss=0.546]