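# Semantic Segmentation with U-Net

Binary (pet vs. background) segmentation of the Oxford-IIIT Pet dataset with a pretrained U-Net loaded from PyTorch Hub.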

In [3]:
import torch 
import torch.nn as nn 
import torch.optim as optim
import torchvision 
import torchvision.transforms as transforms 
import torchvision.datasets as datasets
from torch.utils.data import DataLoader 

# Configuration: compute device, training length, batch size and RNG seed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 10
batch_size = 16
# Seed torch's global RNG (used e.g. by the shuffling train DataLoader) so that
# re-running the notebook from a fresh kernel reproduces the same run.
seed = 0
torch.manual_seed(seed)

In [9]:
# Data: Oxford-IIIT Pet images and segmentation masks, both resized to 128x128,
# downloaded into ./data on first run, then wrapped in DataLoaders.

# Images: Resize (default bilinear interpolation) then ToTensor -> float in [0, 1].
transform = transforms.Compose([
    transforms.Resize(size=(128, 128)),
    transforms.ToTensor(),
])

# Masks hold integer class codes, so they must be resized with NEAREST: the
# default bilinear interpolation would blend neighbouring codes into values
# that are not valid classes. ToTensor then divides by 255; the later cells
# multiply by 255 to recover the codes.
transform_target = transforms.Compose([
    transforms.Resize(size=(128, 128), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

train_dataset = datasets.OxfordIIITPet(root='./data', split='trainval', target_types='segmentation',
                                       transform=transform, target_transform=transform_target, download=True)
test_dataset = datasets.OxfordIIITPet(root='./data', split='test', target_types='segmentation',
                                      transform=transform, target_transform=transform_target, download=True)

# DataLoaders: shuffle only the training split; keep test order stable.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [14]:
# U-Net model: pretrained U-Net from PyTorch Hub with 3 input channels and a
# single output channel (the foreground mask).
# If torch.hub refuses to load the repo because it looks like a fork,
# uncomment the workaround below before calling torch.hub.load.
# torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
                       in_channels=3, out_channels=1, init_features=32, pretrained=True)
model.to(device)

# Loss and optimizer.
# This hub U-Net's forward() ends in torch.sigmoid, i.e. it already returns
# probabilities, so the matching loss is BCELoss. BCEWithLogitsLoss would apply
# a second sigmoid, squashing every prediction into (0.5, 0.73) and crippling
# the gradients.
# Alternatives: CE / Focal Loss (better for imbalanced datasets) / IoU (Jaccard) loss.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


Using cache found in C:\Users\13476/.cache\torch\hub\mateuszbuda_brain-segmentation-pytorch_master
Downloading: "https://github.com/mateuszbuda/brain-segmentation-pytorch/releases/download/v1.0/unet-e012d006.pt" to C:\Users\13476/.cache\torch\hub\checkpoints\unet-e012d006.pt


In [15]:
# Training loop: the segmentation target is reduced to a binary mask
# (class code 1 -> 1.0, every other code -> 0.0) and fitted with `criterion`.
model.train()
for epoch in range(num_epochs):
    for idx, (image, label) in enumerate(train_loader):
        image, label = image.to(device), label.to(device)

        # forward
        output = model(image)

        # Undo ToTensor's /255 scaling to recover the integer class codes.
        # Round before comparing: exact float equality on (k/255)*255 is fragile.
        label = ((label * 255).round() == 1).float()

        # loss - backward - optim
        loss = criterion(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (idx + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
        


KeyboardInterrupt: 

In [None]:
from sklearn.metrics import jaccard_score  
from torchmetrics.classification import BinaryJaccardIndex

import matplotlib.pyplot as plt
import numpy as np
# Evaluation on the test split: pixel accuracy plus foreground IoU (Jaccard
# index), computed two ways as a cross-check:
#   * IoU       - per-batch sklearn `jaccard_score`, averaged over batches
#   * IoU_by_me - per-batch |pred AND gt| / |pred OR gt|, averaged over batches
# All three metrics are fractions in [0, 1] and are printed as percentages.


def _show_strip(batch, title):
    """Display a (B, C, H, W) tensor batch as one horizontal strip of B images.

    Single-channel batches are squeezed to a 2-D array so matplotlib renders
    them through its colormap.
    """
    strip = np.concatenate(batch.permute(0, 2, 3, 1).cpu().numpy(), axis=1)
    if strip.shape[-1] == 1:
        strip = strip[..., 0]
    fig, ax = plt.subplots(figsize=(15, 10))
    ax.imshow(strip)
    ax.set_title(title)
    ax.axis('off')
    plt.show()
    plt.close(fig)  # several strips are drawn per run; release each figure


model.eval()
IoU = 0.0
IoU_by_me = 0.0
n_batch = 0
correct = 0
total = 0
with torch.no_grad():
    for idx, (images, labels) in enumerate(test_loader):
        images, labels = images.to(device), labels.to(device)

        # Assumes the model's forward already applies a sigmoid (outputs are
        # probabilities), hence the 0.5 decision threshold.
        outputs = (model(images) > 0.5).float()

        # Undo ToTensor's /255 scaling to recover the integer class codes;
        # round instead of exact float equality. Code 1 is the foreground.
        labels = ((labels * 255).round() == 1).float()

        n_batch += 1
        total += labels.numel()
        correct += (outputs == labels).sum().item()

        y_true = labels.reshape(-1).cpu().numpy()
        y_pred = outputs.reshape(-1).cpu().numpy()
        # zero_division=1.0: a batch with no foreground in either mask is a
        # perfect match (kept consistent with the manual IoU below).
        IoU += jaccard_score(y_true, y_pred, average='binary', zero_division=1.0)

        pred_fg = outputs.bool().reshape(-1)
        true_fg = labels.bool().reshape(-1)
        inters = torch.logical_and(pred_fg, true_fg).sum().item()
        union = torch.logical_or(pred_fg, true_fg).sum().item()
        # Guard the empty-union case instead of producing NaN.
        batch_iou = inters / union if union > 0 else 1.0
        assert 0.0 <= batch_iou <= 1.0
        IoU_by_me += batch_iou

        # Visualize image / ground truth / prediction for a few test batches.
        if 2 < idx < 10:
            _show_strip(images, 'Images')
            _show_strip(labels, 'Ground truth')
            _show_strip(outputs, 'Prediction')

assert n_batch > 0, "test_loader yielded no batches"

accuracy = correct / total
print(f'Accuracy on the test set: {100 * accuracy:.2f}%')
IoU = IoU / n_batch
print(f'IoU on the test set: {100 * IoU:.2f}%')

IoU_by_me = IoU_by_me / n_batch
print(f'IoU by me on the test set: {100 * IoU_by_me:.2f}%')