## LinkNet Architecture

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18
import sys
sys.path.append('../dataset')
from Datasets import BaseDataset

import numpy as np
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader
import torchvision.transforms as transforms
from sklearn.metrics import accuracy_score

# Set the device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


### Model

In [2]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Encoder(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(Encoder, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.basic_block = BasicBlock(out_channels, out_channels)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)

        out = self.basic_block(out)

        return out


class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Decoder, self).__init__()
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.deconv(x)
        out = self.bn(out)
        out = self.relu(out)

        return out
    

class LinkNet(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super(LinkNet, self).__init__()

        self.num_classes = num_classes

        # Load the pretrained ResNet18 model
        resnet = resnet18(pretrained=pretrained)

        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool

        self.encoder1 = resnet.layer1
        self.encoder2 = resnet.layer2
        self.encoder3 = resnet.layer3
        self.encoder4 = resnet.layer4

        self.decoder4 = Decoder(512, 256)
        self.decoder3 = Decoder(256, 128)
        self.decoder2 = Decoder(128, 64)
        self.decoder1 = Decoder(64, 64)

        self.final_deconv1 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
        self.final_bn1 = nn.BatchNorm2d(32)
        self.final_relu1 = nn.ReLU(inplace=True)
        self.final_conv2 = nn.Conv2d(32, num_classes, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        # Encoder
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        enc1 = self.encoder1(x)
        enc2 = self.encoder2(enc1)
        enc3 = self.encoder3(enc2)
        enc4 = self.encoder4(enc3)

        # Decoder
        dec4 = self.decoder4(enc4)
        dec4 = torch.add(dec4, enc3)  # Use torch.add() instead of +=
        dec3 = self.decoder3(dec4)
        dec3 = torch.add(dec3, enc2)  # Use torch.add() instead of +=
        dec2 = self.decoder2(dec3)
        dec2 = torch.add(dec2, enc1)  # Use torch.add() instead of +=
        dec1 = self.decoder1(dec2)

        # Final Convolution
        x = self.final_deconv1(dec1)
        x = self.final_bn1(x)
        x = self.final_relu1(x)
        x = self.final_conv2(x)

        # Upsample to the original input size
        x = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)

        return x[:, :self.num_classes, :, :]


### Load Train and Test Data using BaseDataSet

In [3]:
#transform
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),  # Resizing to smaller dimensions for quicker training
    transforms.ToTensor()  # Convert image to PyTorch tensor
])

dataset = BaseDataset('../data/', transform=transform)

train_ratio = 0.8
train_size = int(train_ratio * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

### Training

In [5]:
# Set the device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("DEVICE USED: ", device)
# Create an instance of the LinkNet model
model = LinkNet(num_classes=3, pretrained=True).to(device)

# Define the loss function
criterion = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Define the number of training epochs
num_epochs = 10


train_loss = []
train_accuracy = []
test_loss = []
test_accuracy = []

# Training loop
for epoch in range(num_epochs):
    model.train()  # Set the model to train mode
    running_loss = 0.0
    correct = 0
    total = 0

    # Iterate over the training dataset
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        # Update the weights
        optimizer.step()

        _, predicted = torch.max(outputs.data, 1)
        labels = torch.argmax(labels, dim=1) 
        correct += (predicted == labels).sum().item()
        total += torch.numel(labels)

        # Update the running loss
        running_loss += loss.item() * inputs.size(0)

    train_accuracy.append(100 * correct / total)
    train_loss.append(running_loss / i)

    # Testing
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, data in enumerate(test_loader, 0):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        _, predicted = torch.max(outputs.data, 1)
        labels = torch.argmax(labels, dim=1) 
        correct += (predicted == labels).sum().item()
        total += torch.numel(labels)

        running_loss += loss.item()
    
    test_accuracy.append(100 * correct / total)
    test_loss.append(running_loss / i)

    # Print the loss and accuracy for each epoch
    print(f'Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss[-1]:.4f} | Train Accuracy: {train_accuracy[-1]:.4f}')


DEVICE USED:  cuda




Epoch 1/10 | Train Loss: 1.2083 | Train Accuracy: 91.0175
Epoch 2/10 | Train Loss: 0.8035 | Train Accuracy: 94.1621
Epoch 3/10 | Train Loss: 0.6012 | Train Accuracy: 95.7985
Epoch 4/10 | Train Loss: 0.5229 | Train Accuracy: 95.7052
Epoch 5/10 | Train Loss: 0.4433 | Train Accuracy: 91.9628
Epoch 6/10 | Train Loss: 0.4997 | Train Accuracy: 92.0408
Epoch 7/10 | Train Loss: 0.4306 | Train Accuracy: 93.2241
Epoch 8/10 | Train Loss: 0.3040 | Train Accuracy: 91.6414
Epoch 9/10 | Train Loss: 0.2851 | Train Accuracy: 91.2715
Epoch 10/10 | Train Loss: 0.2012 | Train Accuracy: 92.5580


In [None]:
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt


def plot_segmentation(image, ground_truth, predicted):
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(image)
    ax[0].set_title('Input Image')
    ax[0].axis('off')
    ax[1].imshow(ground_truth)
    ax[1].set_title('Ground Truth Mask')
    ax[1].axis('off')
    ax[2].imshow(predicted)
    ax[2].set_title('Predicted Mask')
    ax[2].axis('off')
    plt.tight_layout()
    plt.show()


# Assuming you have the following variables:
# input_image: The input image tensor of shape (3, height, width)
# ground_truth_mask: The ground truth mask tensor of shape (1, height, width)
# predicted_mask: The predicted mask tensor of shape (1, height, width)

# Convert tensors to numpy arrays and transpose dimensions
input_image_np = TF.to_pil_image(input_image)
ground_truth_mask_np = ground_truth_mask.squeeze().cpu().numpy()
predicted_mask_np = predicted_mask.squeeze().cpu().numpy()

# Plot the segmentation results
plot_segmentation(input_image_np, ground_truth_mask_np, predicted_mask_np)
