# CIL Road Segmentation


## preparation
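Download the Oxford-IIIT Pet dataset from https://www.robots.ox.ac.uk/~vgg/data/pets/ and extract it into `img_seg_animals/` (the notebook changes into that directory and expects `images/` and `annotations/` inside it).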

When running on Colab, load the data from Google Drive (uncomment the cells below).

In [None]:
#from google.colab import drive

In [None]:
#drive.mount("/content/gdrive")
#cd gdrive/My Drive/img_seg_animals/

In [None]:
cd img_seg_animals

## explore data

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

In [None]:
# annotations/list.txt: skip the 6-line header, then one space-separated row per image.
column_names = ["stem", "class_id", "species", "breed"]
df = pd.read_csv("annotations/list.txt", sep=" ", header=None,
                 skiprows=6, names=column_names)
df.head()

In [None]:
# Breed name = file stem without its trailing "_<number>". Breed names may themselves
# contain underscores (e.g. "english_setter" vs "english_cocker_spaniel"), so split on
# the LAST underscore; splitting on the first one would merge different breeds.
df["class_name"] = df.stem.map(lambda x: x.rsplit("_", 1)[0])
df["image_path"] = df.stem.map(lambda x: f"images/{x}.jpg")
df["annotations"] = df.stem.map(lambda x: f"annotations/trimaps/trimaps/{x}.png")
df.tail()

Species 1 stands for cats and 2 for dogs; `class_id` identifies the breed.

Now we want to look at the distribution of classes.

In [None]:
plt.figure(figsize=(15,5))
# class_id is an integer label: one bar per id (discrete bins, like the species plot)
# instead of automatic bins that can merge neighbouring ids.
sns.histplot(df.class_id, discrete=True)
plt.title("class id")

In [None]:
plt.figure(figsize=(15, 5))
plot = sns.histplot(df.class_name)
# Breed names are long: turn the x tick labels vertical so they stay readable.
for tick_label in plot.get_xticklabels():
    tick_label.set_rotation(90)
plt.title("Class name")

In [None]:
plt.figure(figsize=(15, 5))
# species: 1 = cat, 2 = dog -> one bar each
plot = sns.histplot(df.species, discrete=True)
plt.title("species")

There are more dogs than cats, but the breeds are balanced enough for our purpose (there is no breed with, say, only 2 images that we would need to remove).

Now let's look at the images themselves.

In [None]:
fig, ax = plt.subplots(3, 2, figsize=(10, 15))

# Left column: the photo; right column: its trimap annotation.
for i in range(3):
    img = Image.open(df.loc[i, "image_path"])
    annot = Image.open(df.loc[i, "annotations"])
    left, right = ax[i]
    left.imshow(img)
    right.imshow(annot)
plt.show()

In [None]:
# Distinct label values in one trimap (inside / outside / border). Index the first
# image explicitly instead of relying on `i` leaked from the plotting loop above.
set(Image.open(df.annotations.iloc[0]).getdata())

The images have different sizes, which makes the problem harder. The annotations do not look too hard to learn, since there are few very thin structures.

## data

In [None]:
from torch.utils.data import DataLoader, Dataset
import albumentations
from albumentations.pytorch import ToTensorV2
import cv2
from sklearn.model_selection import StratifiedKFold
import numpy as np
from torch.utils.data import DataLoader
import torch


In [None]:
class animal_data(Dataset):
    """Oxford-IIIT Pet images paired with binary foreground masks.

    Each item is ``(img, mask)``:

    * ``img`` -- float tensor, channels first, rescaled as ``(x - 128) / 300``
      (``x * 300 + 128`` restores the original 0..255 values for plotting).
    * ``mask`` -- integer tensor: trimap value 2 becomes 0, every other value
      (1 and 3) becomes 1.

    Args:
        df: frame with ``image_path`` and ``annotations`` columns (paths to the
            image and its trimap PNG). Rows are read positionally (``iloc``).
        tfm: optional albumentations-style callable ``tfm(image=..., mask=...)``
            returning a dict with ``"image"`` and ``"mask"`` tensors (the
            pipelines in this notebook end with ``ToTensorV2``). When ``None``,
            the arrays are converted to tensors directly.
    """

    def __init__(self, df, tfm=None):
        self.df = df
        self.tfm = tfm  # transformations

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        img = np.asarray(Image.open(self.df.image_path.iloc[i]).convert('RGB'))
        trimap = np.asarray(Image.open(self.df.annotations.iloc[i]))
        # Labels are 1, 2, 3; only 2 maps to 0, so 1 and 3 end up in the same class.
        mask = np.where(trimap != 2, 1, 0)
        if self.tfm is not None:
            augmented = self.tfm(image=img, mask=mask)
            img, mask = augmented["image"], augmented["mask"]
        else:
            # No transform: mimic ToTensorV2 (HWC array -> CHW tensor, mask as-is),
            # otherwise the tensor rescaling below fails on a numpy array.
            img = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))
            mask = torch.from_numpy(mask)
        img = (img.to(torch.float) - 128) / 300  # rescale values
        return img, mask

In [None]:
# Training augmentation: flips, size normalisation, random zoom/rotation, colour jitter,
# then a 256x256 crop.
transformations = albumentations.Compose([
    albumentations.HorizontalFlip(p=0.4),
    albumentations.VerticalFlip(p=0.3),
    albumentations.SmallestMaxSize(256),  # somewhat control img size: shorter side -> 256
    # Random zoom must come AFTER SmallestMaxSize, otherwise the resize undoes it.
    # Only zoom in (factor 1.0-1.2) so the shorter side never drops below the 256 crop.
    albumentations.RandomScale(scale_limit=(0.0, 0.2)),
    albumentations.Rotate(border_mode=cv2.BORDER_CONSTANT, mask_value=0),
    albumentations.RandomBrightnessContrast(p=0.3),
    #albumentations.augmentations.crops.transforms.RandomCropFromBorders(p=.4),#somehow not working
    albumentations.RandomCrop(256, 256),
    ToTensorV2(),
])

# Validation: deterministic resize + center crop, so metrics do not change between runs.
val_transformations = albumentations.Compose([
    albumentations.SmallestMaxSize(256),
    albumentations.CenterCrop(256, 256),
    ToTensorV2(),
])

# missing : Mixup, Cutmix, RandAugment, Random erasing

In [None]:
# Stratify by breed so train and validation share the class distribution.
skf = StratifiedKFold(n_splits=5)  # no shuffling -> deterministic folds
fold_iter = skf.split(df, df.class_id)
train_idx, val_idx = next(fold_iter)  # fold 0 is the validation split
train_df, val_df = df.iloc[train_idx], df.iloc[val_idx]

train_ds, val_ds = (
    animal_data(train_df, tfm=transformations),
    animal_data(val_df, tfm=val_transformations),
)

In [None]:
batch_size = 1  # also read by the model classes; mini_Scg_Net reshapes to a batch of 1
loader_opts = dict(shuffle=True, drop_last=True)
train_dataloader = DataLoader(train_ds, batch_size=batch_size, **loader_opts)
test_dataloader = DataLoader(val_ds, batch_size=1, **loader_opts)

In [None]:

# Show a few augmented training samples next to their masks.
for _ in range(4):
    img, mask = next(iter(train_dataloader))
    img, mask = img[0], mask[0]
    # The dataset stores (x - 128) / 300; imshow would clip those values to [0, 1],
    # so undo the scaling to display the real colours.
    rgb = (img * 300 + 128).round().clamp(0, 255).to(torch.uint8)
    plt.subplot(1,2,1)
    plt.imshow(rgb.numpy().transpose(1,2,0))
    plt.xticks([]); plt.yticks([])

    plt.subplot(1,2,2)
    plt.imshow(mask)
    plt.xticks([]); plt.yticks([])
    plt.show()


## models to explore

In [None]:
import torch
from torchvision.models import resnet50, ResNet50_Weights

# Pretrained ResNet-50 (ImageNet weights); its layers are reused as backbones below.
resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

import torch_geometric
from torch_geometric.nn import GCNConv

In [None]:
# initial model for comparison
class CNN_might_help(torch.nn.Module):  # just trash to see if everything works
    """Baseline: ResNet-50 stem + layer1/layer2, a small conv head, 8x bilinear upsampling.

    The backbone reuses the module-level ``resnet`` layers directly, so those weights
    are shared with any other model built from the same ``resnet`` object.
    ``forward`` returns sigmoid probabilities with a single channel for the binary task.
    """

    def __init__(self):
        super(CNN_might_help, self).__init__()

        self.batchsize = batch_size  # batch_size
        self.num_classes = 2
        # Binary segmentation needs one sigmoid channel; otherwise one channel per class.
        out_channels = 1 if self.num_classes == 2 else self.num_classes

        self.backbone = torch.nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,  # batch_size * 512 * size/8
            #resnet.layer3,# og size / 16
        )
        # requires_grad_ must be CALLED: assigning to it only replaced the method with a
        # bool and changed nothing. Pass False to freeze the pretrained backbone.
        self.backbone.requires_grad_(True)

        self.cnn_pt2 = torch.nn.Sequential(  # less compression than using more resnet layers
            torch.nn.Conv2d(512, 256, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 128, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 32, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, out_channels, kernel_size=3, padding=1),
        )

        # get back dimensions (cheaper than a U-Net-like decoder; the model is already slow)
        self.upsampler = torch.nn.Upsample(scale_factor=8, mode="bilinear")

    def forward(self, x):
        x = self.backbone(x)
        x = self.cnn_pt2(x)
        x = self.upsampler(x)
        return torch.sigmoid(x)  # F.sigmoid is deprecated

### SCG Net
ResNet backbone + GNN

The model has been inspired by https://arxiv.org/pdf/2009.01599.pdf but has been changed/adjusted to the task.

In [None]:
#pip install torch_geometric

In [None]:
class mini_Scg_Net(torch.nn.Module):  # just trash to see if everything works
    """Reduced SCG-Net: ResNet-50 features, a self-constructed pixel graph, a 2-layer GCN.

    ``forward``:
    1. ``backbone`` (ResNet-50 stem + layer1 + layer2 of the module-level ``resnet``)
       and ``cnn_extension`` produce a 128-channel feature map.
    2. ``encoder_mean`` projects it to ``out_size`` channels. Each pixel becomes one
       graph node, and ``M`` holds the embeddings as ``(out_size, h * w)``.
    3. The adjacency ``A = relu(M^T M)`` (sparse) links pixels with positively
       correlated embeddings.
    4. Two ``GCNConv`` layers refine the node features over ``A``. The result is
       reshaped back to ``(1, out_size, h, w)``, upsampled 8x and passed through a sigmoid.

    Returns ``(probabilities, M, A)``. Only batch size 1 is supported because one
    graph is built per feature map.

    Note: the backbone reuses the ``resnet`` module objects, so its weights are shared
    with any other model built from the same ``resnet``.
    """

    def __init__(self):
        super(mini_Scg_Net, self).__init__()
        self.batchsize = batch_size  # batch_size

        self.backbone = torch.nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,  # batch_size * 512 * size/8
            #resnet.layer3,# og size / 16
        )
        # requires_grad_ must be CALLED: assigning to it only replaced the method with a
        # bool and changed nothing. Pass False to freeze the pretrained backbone.
        self.backbone.requires_grad_(True)
        num_classes = 2
        # Binary segmentation uses a single sigmoid channel; otherwise one per class.
        self.out_size = num_classes if 2 != num_classes else 1

        self.cnn_extension = torch.nn.Sequential(  # less compression than using more resnet layers
            torch.nn.Conv2d(512, 512, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 256, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 128, kernel_size=3, padding=1),
            torch.nn.ReLU(),
        )
        # encoder
        self.encoder_mean = torch.nn.Conv2d(128, self.out_size, kernel_size=3, padding=1)
        # Head of the disabled variational variant (x = M + exp(encoder_var(x)) * noise).
        # Unused in forward; kept so the state_dict layout stays the same.
        self.encoder_var = torch.nn.Conv2d(128, self.out_size, kernel_size=1, padding=0)

        # GNN
        self.conv1 = GCNConv(self.out_size, self.out_size)
        self.conv2 = GCNConv(self.out_size, self.out_size)

        # get back dimensions (cheaper than a U-Net-like decoder; the model is already slow)
        self.upsampler = torch.nn.Upsample(scale_factor=8, mode="bilinear")

        # Node count for a 256x256 input; kept for reference, forward uses the real size.
        self.reduzed_size = int(self.batchsize*(256/8)**2)

    def forward(self, x):
        x = self.backbone(x)
        x = self.cnn_extension(x)  # F
        if x.shape[0] != 1:
            # GCNConv receives a single edge_index, i.e. one graph = one image.
            raise ValueError(f"mini_Scg_Net supports batch size 1 only, got {x.shape[0]}")
        h, w = x.shape[-2:]

        # Node embeddings: (out_size, h * w), one column per pixel.
        M = self.encoder_mean(x).reshape((self.out_size, -1))
        x = M.t()  # GNN layers expect (num_nodes, channels)
        # Self-constructed graph: edge weight = positive part of the embedding similarity.
        A = torch.nn.functional.relu(torch.inner(x, x)).to_sparse()
        edge_idx, weights = torch_geometric.utils.to_edge_index(A)

        # GNN
        x = self.conv1(x, edge_idx, edge_weight=weights)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x, edge_idx, edge_weight=weights)

        # Back to image layout; the reshape mirrors the flattening of M above and uses
        # the actual feature-map size instead of a hard-coded 32x32.
        x = x.t().reshape((1, self.out_size, h, w))
        x = self.upsampler(x)
        return torch.sigmoid(x), M, A  #, sigma

This may also be due to a sub-optimal KL-divergence implementation, but the VAE version did not work at all for me. The plain autoencoder roughly marked the correct part of the image, but its predictions were not very good.

### U-net
https://github.com/milesial/Pytorch-UNet

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages; padding=1 keeps the spatial size.

    ``mid_channels`` (the width between the two convs) falls back to
    ``out_channels`` when it is not given.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        stages = []
        for c_in, c_out in ((in_channels, mid_channels), (mid_channels, out_channels)):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Halve the resolution with a 2x2 max-pool, then apply a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upsample ``x1`` by 2, pad it to ``x2``'s size, concatenate ``[x2, x1]`` and DoubleConv.

    With ``bilinear=True`` the upsampling is parameter-free and the DoubleConv maps
    ``in_channels -> in_channels // 2 -> out_channels``. Otherwise a 2x2 stride-2
    transposed convolution upsamples and halves the channels of ``x1``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
            return
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Tensors are N, C, H, W. Pad x1 up to x2's height/width; an odd difference
        # puts the extra row/column at the bottom/right.
        dy = x2.size(2) - x1.size(2)
        dx = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        return self.conv(torch.cat([x2, x1], dim=1))


class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to the output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

In [None]:
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        x = torch.nn.functional.sigmoid(x) 
        return x

    def use_checkpointing(self):#unused
        self.inc = torch.utils.checkpoint(self.inc)
        self.down1 = torch.utils.checkpoint(self.down1)
        self.down2 = torch.utils.checkpoint(self.down2)
        self.down3 = torch.utils.checkpoint(self.down3)
        self.down4 = torch.utils.checkpoint(self.down4)
        self.up1 = torch.utils.checkpoint(self.up1)
        self.up2 = torch.utils.checkpoint(self.up2)
        self.up3 = torch.utils.checkpoint(self.up3)
        self.up4 = torch.utils.checkpoint(self.up4)
        self.outc = torch.utils.checkpoint(self.outc)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
# Shape check with a 2-class U-Net (other candidates: mini_Scg_Net(), CNN_might_help()).
test123 = UNet(n_channels=3, n_classes=2).to(device)
data_iter = iter(train_dataloader)

In [None]:
# Grab one training batch for the forward-pass shape check below.
X, label = next(data_iter)
#print(X.shape)

In [None]:
# Forward pass only to inspect the output shape, so no autograd graph is needed.
# (mini_Scg_Net would return (output, M, A) instead of a single tensor.)
with torch.no_grad():
    output = test123(X.to(device))

output.shape

## Training

In [None]:
#pip install torchmetrics

In [None]:
from torch.optim.lr_scheduler import LinearLR
from tqdm import tqdm  # tqdm.notebook
import math
#from torchmetrics import Dice

In [None]:
num_epochs = 20  # training epochs; also the length (total_iters) of the LinearLR schedule

In [None]:
#curr_model.train()

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Model to train (alternatives: mini_Scg_Net(), CNN_might_help()).
# n_classes=1: binary task with a single sigmoid output channel.
curr_model = UNet(n_channels=3, n_classes=1).to(device)
print(device)

In [None]:
#torch.autograd.detect_anomaly(True)

Choose which weights to load, if you already have a trained model.

In [None]:
#curr_model.load_state_dict(torch.load("SCG-net/model_weights.pth", map_location=torch.device(device)),strict=False)
#curr_model.load_state_dict(torch.load("CNN/model_weights.pth", map_location=torch.device(device)))
#curr_model.load_state_dict(torch.load("unet/model_weights.pth", map_location=torch.device(device)))

In [None]:
# AdamW over trainable parameters only (a frozen backbone would be skipped).
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, curr_model.parameters()),
    lr=2e-3, weight_decay=1e-8,
)
# LinearLR factors MULTIPLY the base lr. The intended decay 2e-3 -> 2e-5 is factors
# 1.0 -> 0.01. Passing the learning rates themselves as factors gave 4e-6 -> 4e-8.
scheduler = LinearLR(optimizer=optimizer,
                     start_factor=1.0,
                     end_factor=1e-2,
                     total_iters=num_epochs,
                     last_epoch=-1)

In [None]:
class DiceLoss(torch.nn.Module):  # binary
    """Soft binary Dice loss: ``1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``.

    ``inputs`` must already be probabilities (the models in this notebook apply a
    sigmoid in ``forward``). All elements of the batch are flattened together, so the
    overlap is computed over the whole batch rather than per image.

    ``weight`` and ``size_average`` are accepted for signature compatibility only;
    they are not used.
    """

    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        # reshape (not view) so non-contiguous tensors, e.g. after permute/transpose, work
        inputs = inputs.reshape(-1)
        # integer masks are cast to the prediction dtype so all sums are floating point
        targets = targets.reshape(-1).to(inputs.dtype)

        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice



In [None]:
def _alert():
    """Audible alert: winsound.Beep on Windows, the terminal bell elsewhere."""
    try:
        from winsound import Beep
    except ImportError:  # winsound only exists on Windows
        print("\a", end="")
    else:
        Beep(300, 5000)


def train(model, data_loader, device, num_epochs=100, optimizer=None, scheduler=None,
          save_path=None, save_every=5):
    """Train the model with the soft Dice loss.

    Args:
        model: segmentation network returning per-pixel probabilities (a single tensor).
        data_loader: iterable of ``(X, y)`` batches that supports ``len``.
        device: device the batches are moved to (the model must already be there).
        num_epochs: number of passes over ``data_loader``.
        optimizer: optimizer to step; defaults to the notebook-global ``optimizer``.
        scheduler: LR scheduler stepped once per epoch; defaults to the notebook-global
            ``scheduler`` if one exists, otherwise no scheduling.
        save_path: if given, ``model.state_dict()`` is written there every
            ``save_every`` epochs.
        save_every: checkpoint interval in epochs.

    Stops early (with an audible alert) when the epoch loss becomes NaN.
    """
    if optimizer is None:
        optimizer = globals()["optimizer"]
    if scheduler is None:
        scheduler = globals().get("scheduler")
    loss_func = DiceLoss()
    # An evaluation cell may have left the model in eval mode (frozen BatchNorm stats).
    model.train()
    n_batches = max(len(data_loader), 1)

    for epoch in range(1, num_epochs + 1):
        iteration_loss = 0.
        for X, y in tqdm(data_loader):
            # Reset the optimizer.
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)

            outputs = model(X)

            # Compute the loss and do the backward pass.
            # (Earlier SCG-Net experiments added KL-divergence and adjacency-diagonal
            # regularisers to this loss; they are not used.)
            loss = loss_func(outputs, y)
            loss.backward()

            optimizer.step()
            iteration_loss += loss.item()

        print(f'Epoch {epoch}/{num_epochs}, ',
              f'Train Loss: {iteration_loss / n_batches:.4f}, ',
              )

        if math.isnan(iteration_loss):  # something broke again
            _alert()
            break

        if save_path is not None and epoch % save_every == 0:
            print("save the weights")
            torch.save(model.state_dict(), save_path)
        if scheduler is not None:
            scheduler.step()


In [None]:
train(model=curr_model, data_loader=train_dataloader, device=device, num_epochs=num_epochs)

In [None]:
# "Training finished" signal. winsound exists only on Windows; elsewhere ring the terminal bell.
try:
    from winsound import Beep
except ImportError:
    print("\a", end="")
else:
    Beep(300, 5000)

In [None]:
#from google.colab import output
#output.eval_js('new Audio("https://upload.wikimedia.org/wikipedia/commons/0/05/Beep-09.ogg").play()')

In [None]:
import os

# Save the trained weights (uncomment the line matching the model that was trained).
#torch.save(curr_model.state_dict(), "SCG-net/model_weights.pth")
#torch.save(curr_model.state_dict(), "CNN/model_weights.pth")
os.makedirs("unet", exist_ok=True)  # torch.save does not create missing directories
torch.save(curr_model.state_dict(), "unet/model_weights.pth")

## Test
Now we evaluate the model on the validation split.

In [None]:
from torchmetrics.classification import F1Score, Accuracy

In [None]:
curr_model.eval()  # evaluation mode for the test cells below (e.g. BatchNorm uses running statistics)

In [None]:
# Binary pixel-wise metrics over the predicted probabilities.
metric_kwargs = {"task": "binary"}
f1 = F1Score(**metric_kwargs)
acc = Accuracy(**metric_kwargs)

In [None]:
f_list = []
acc_list = []
curr_model.eval()  # make sure BatchNorm etc. run in inference mode
with torch.no_grad():
    for i, (X, y) in enumerate(tqdm(test_dataloader), start=1):
        X, y = X.to(device), y.squeeze()
        outputs = curr_model(X).squeeze().cpu()
        f_list.append(f1(outputs, y))
        acc_list.append(acc(outputs, y))
        if i % 100 == 0:
            # Running means over the images seen SO FAR. Dividing by the full loader
            # length underestimated both scores until the loop finished.
            print("accuracy: ", sum(acc_list) / len(acc_list))
            print("f1-score: ", sum(f_list) / len(f_list))

In [None]:
print("accuracy: ", sum(acc_list) / len(acc_list))  # mean per-image accuracy over evaluated images

In [None]:
print("f1-score: ", sum(f_list) / len(f_list))  # mean per-image F1 over evaluated images

In [None]:
#curr_model.to("cpu")
fig, ax = plt.subplots(1, 4, figsize=(20, 5))

img, mask = next(iter(test_dataloader))
img_plot = img.squeeze()

og_color = ax[0]
# Undo the dataset's (x - 128) / 300 scaling to get back 0..255 RGB values.
restored = np.clip(np.round(img_plot.numpy().transpose(1, 2, 0) * 300 + 128), 0, 255)
og_color.imshow(restored.astype(np.uint8))
og_color.title.set_text("original(colors restored)")

og_permuted = ax[1]
og_permuted.imshow(img_plot.numpy().transpose(1, 2, 0))
og_permuted.title.set_text("original(permuted)")

mask = mask.squeeze()
sol = ax[2]
sol.imshow(mask.numpy())
sol.title.set_text("ground truth")

# Inference only. UNet / CNN_might_help return a tensor while mini_Scg_Net returns
# (output, M, A); unpacking three values unconditionally failed for the U-Net.
with torch.no_grad():
    outputs = curr_model(img.to(device))
if isinstance(outputs, tuple):
    outputs = outputs[0]
outputs = torch.where(outputs.squeeze() > 0.5, 1, 0)

# for multi class
#outputs = torch.permute(outputs.squeeze(), (1,2,0))
#outputs = torch.max(outputs.detach(), -1)

pred = ax[3]
pred.imshow(outputs.cpu())
pred.title.set_text("prediction")

# Hide ticks on every panel (plt.xticks only affected the current axes).
for panel in ax:
    panel.set_xticks([])
    panel.set_yticks([])
