# Semantic Segmentation with PyTorch

Mount Google Drive in Colab so the project files are accessible.

In [None]:
from google.colab import drive

# Mount point for Google Drive; the project folder lives under "<mount>/My Drive".
DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)

Import the necessary libraries and set parameters.

In [None]:
import os
import time
import json

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models

In [None]:
project_name = "RNE_2022_Segmentation_Colab"
project_path = "/content/drive/My Drive/" + project_name
train_dataset_path = project_path + "/SimulationDataset/train"
test_dataset_path = project_path + "/SimulationDataset/test"

model_type = "encdec"  # encdec / fcn / unet / pspnet

# Each model type writes its checkpoint, preview images and metric record to
# its own "results_<model_type>" folder. Reject unknown types here instead of
# failing much later with a NameError on `results_path`.
if model_type not in ("encdec", "fcn", "unet", "pspnet"):
    raise ValueError(
        "Unknown model_type {!r}; expected one of encdec / fcn / unet / pspnet".format(model_type)
    )
results_path = project_path + "/results_" + model_type

# os.mkdir (not os.makedirs) on purpose: if Drive is not mounted the parent
# folder is missing and we want a loud error rather than silently creating
# the whole path on the Colab VM's local disk.
if not os.path.isdir(results_path):
    os.mkdir(results_path)

# Parameters
num_class = 3
input_h, input_w = 256, 256
batch_size = 16
epochs = 10
lr = 1e-4
use_gpu = torch.cuda.is_available()

## Simulation Dataset

In [None]:
class SimDataset(Dataset):
    """Simulation segmentation dataset read from ``<path>/img`` and ``<path>/label``.

    Each item is a dict with
      * ``"X"``: float32 tensor (3, H, W), the RGB image scaled to [0, 1];
      * ``"Y"``: float32 tensor (n_class, H, W), a one-hot label built from the
        red channel of the label PNG (178 -> class 0, 255 -> class 1,
        0 -> class 2; any other value leaves that pixel all-zero).
    Image and label are flipped horizontally together with probability
    ``flip_rate``.

    Args:
        path: split folder containing ``img/`` and ``label/`` sub-folders.
        n_class: number of label channels (the colour mapping fills 3).
        flip_rate: per-sample probability of a horizontal flip.
        train: stored as ``self.train``; not used by this class.
        size: (width, height) both images are resized to (PIL convention).
    """

    # Red-channel value in the label PNG -> class channel index.
    _LABEL_VALUES = (178, 255, 0)

    def __init__(self, path, n_class=num_class, flip_rate=0.5, train=True, size=(256, 256)):
        self.img_folder_path = os.path.join(path, "img")
        self.label_folder_path = os.path.join(path, "label")
        # os.listdir order is arbitrary; sort it so indices (and therefore the
        # test-set preview images) are reproducible between runs.
        self.file_list = sorted(os.listdir(self.img_folder_path))
        self.n_class = n_class
        self.flip_rate = flip_rate
        self.train = train
        self.size = tuple(size)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_name = self.file_list[idx]
        img_path = os.path.join(self.img_folder_path, file_name)
        # The label file name is derived from the second "_"-separated token
        # of the image file name.
        label_name = "jetbot_" + file_name.split("_")[1] + "_layer.png"
        label_path = os.path.join(self.label_folder_path, label_name)

        # convert("RGB") guarantees exactly three channels even for RGBA or
        # greyscale PNGs; otherwise the image would not match VGG's 3-channel
        # input and label_img[:, :, 0] would fail on a 2-D array.
        with Image.open(img_path) as raw_img:
            img = np.asarray(raw_img.convert("RGB").resize(self.size, Image.NEAREST), dtype=float) / 255.0
        with Image.open(label_path) as raw_label:
            label_red = np.asarray(raw_label.convert("RGB").resize(self.size, Image.NEAREST))[:, :, 0]

        label = np.zeros((img.shape[0], img.shape[1], self.n_class), dtype=float)
        for channel, value in enumerate(self._LABEL_VALUES):
            label[label_red == value, channel] = 1

        if np.random.sample() < self.flip_rate:
            img = np.fliplr(img)
            label = np.fliplr(label)

        # HWC -> CHW. .copy() is required because fliplr returns a
        # negative-stride view that torch.from_numpy cannot wrap.
        img = torch.from_numpy(img.copy()).float().permute(2, 0, 1)
        label = torch.from_numpy(label.copy()).float().permute(2, 0, 1)
        return {"X": img, "Y": label}

# Load dataset.
# DataLoader settings shared by both splits.
loader_kwargs = dict(batch_size=batch_size, num_workers=2)

# Training split: random horizontal flips (rate 0.5), reshuffled every epoch.
train_data = SimDataset(path=train_dataset_path, flip_rate=0.5)
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
# Test split: no augmentation, fixed order.
test_data = SimDataset(path=test_dataset_path, flip_rate=0.0)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

## Network Model
### VGG16 Feature Extractor (pretrained)

In [None]:
class Vgg16(nn.Module):
    """VGG16 convolutional backbone that returns intermediate feature maps.

    Only ``vgg16().features`` is used (the fully connected classifier is
    deleted). ``forward`` returns the outputs of layers 3, 8, 15, 22 and 29 of
    that Sequential, i.e. the last ReLU of each of the five conv stages, just
    before the stage's max-pool. For a 256x256 input their shapes are
    (64,256,256), (128,128,128), (256,64,64), (512,32,32), (512,16,16).

    Args:
        pretrained: load ImageNet weights (IMAGENET1K_V1) when True.
    """

    # Indices into vgg16().features whose outputs are returned by forward().
    _TAP_INDICES = (3, 8, 15, 22, 29)

    def __init__(self, pretrained=True):
        super(Vgg16, self).__init__()
        self.vggnet = self._build_vgg16(pretrained)
        del self.vggnet.classifier  # Remove fully connected layers to save memory.
        features = list(self.vggnet.features)
        self.layers = nn.ModuleList(features).eval()

    @staticmethod
    def _build_vgg16(pretrained):
        """Construct torchvision's VGG16, optionally with ImageNet weights."""
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            return models.vgg16(pretrained=pretrained)
        # Newer torchvision deprecates positional/`pretrained=`; IMAGENET1K_V1
        # is what `pretrained=True` used to load.
        return models.vgg16(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def forward(self, x):
        results = []
        last_tap = self._TAP_INDICES[-1]
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx in self._TAP_INDICES:
                results.append(x)
            if idx == last_tap:
                break  # the remaining max-pool's output would never be used
        return results

vgg_model = Vgg16()
# Only move to the GPU when one is available: an unconditional .cuda()
# raises on CPU-only runtimes even though `use_gpu` is False there.
if use_gpu:
    vgg_model = vgg_model.cuda()
print(vgg_model.layers)

### Encoder-Decoder

In [None]:
class DeConv2d(nn.Module):
    """2x nearest-neighbour upsampling followed by a 2-D convolution.

    Acts as a learnable "deconvolution": ``nn.Upsample`` doubles the spatial
    size and ``nn.Conv2d`` maps ``in_channel`` to ``out_channel`` channels
    (its kernel_size/stride/padding/dilation also affect the output size).
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride, padding, dilation):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(in_channel, out_channel, **conv_options)

    def forward(self, x):
        return self.conv(self.up(x))

class EncoderDecoder(nn.Module):
    """Plain encoder-decoder: VGG16 encoder plus four 2x upsampling decoder stages.

    Only the deepest backbone feature map (``pretrained_net(x)[4]``) is decoded;
    there are no skip connections. Each decoder stage is
    DeConv2d -> ReLU -> BatchNorm2d, and a final 1x1 convolution maps the
    64 channels to ``n_class`` logits.
    """

    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)

        # Decoder channel progression 512 -> 512 -> 256 -> 128 -> 64.
        # The attribute names deconv1..4 / bn1..4 (and their registration order)
        # are part of the state_dict keys stored in segnet.pt checkpoints.
        widths = (512, 512, 256, 128, 64)
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, "deconv{}".format(stage),
                    DeConv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, dilation=1))
            setattr(self, "bn{}".format(stage), nn.BatchNorm2d(c_out))

        self.classifier = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, x):
        # Deepest backbone feature map, (512,16,16) for a 256x256 input.
        out = self.pretrained_net(x)[4]
        stages = (
            (self.deconv1, self.bn1),  # -> (512,32,32)
            (self.deconv2, self.bn2),  # -> (256,64,64)
            (self.deconv3, self.bn3),  # -> (128,128,128)
            (self.deconv4, self.bn4),  # -> (64,256,256)
        )
        for deconv, bn in stages:
            out = bn(self.relu(deconv(out)))
        return self.classifier(out)

### Fully Convolutional Network (FCN)


In [None]:
class FCN(nn.Module):
    """Fully Convolutional Network on top of a feature backbone (to be implemented).

    ``pretrained_net`` is the backbone passed in by the model-construction cell
    (``vgg_model``); ``n_class`` is the number of output channels.
    """

    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        #####################################
        #TODO
        #####################################

    def forward(self, x):
        #####################################
        #TODO
        #####################################
        # A def whose body is only comments is an IndentationError, which
        # aborts this cell (and "Run all") even when another model_type is used.
        raise NotImplementedError("FCN.forward is not implemented yet")

### U-Net

In [None]:
class UNet(nn.Module):
    """U-Net style decoder on top of a feature backbone (to be implemented).

    ``pretrained_net`` is the backbone passed in by the model-construction cell
    (``vgg_model``); ``n_class`` is the number of output channels.
    """

    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        #####################################
        #TODO
        #####################################

    def forward(self, x):
        #####################################
        #TODO
        #####################################
        # A def whose body is only comments is an IndentationError, which
        # aborts this cell (and "Run all") even when another model_type is used.
        raise NotImplementedError("UNet.forward is not implemented yet")

### PSPNet

In [None]:
class PSPNet(nn.Module):
    """Pyramid Scene Parsing Network head on top of a feature backbone (to be implemented).

    ``pretrained_net`` is the backbone passed in by the model-construction cell
    (``vgg_model``); ``n_class`` is the number of output channels.
    """

    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        #####################################
        #TODO
        #####################################

    def forward(self, x):
        #####################################
        #TODO
        #####################################
        # A def whose body is only comments is an IndentationError, which
        # aborts this cell (and "Run all") even when another model_type is used.
        raise NotImplementedError("PSPNet.forward is not implemented yet")

Construct models.

In [None]:
# Build the network selected by `model_type` on the shared VGG16 backbone.
# An if/elif chain (rather than a dict of classes) only evaluates the chosen
# class name, so an unfinished model cell does not break the others.
if model_type == "encdec":
    seg_model = EncoderDecoder(pretrained_net=vgg_model, n_class=num_class)
elif model_type == "fcn":
    seg_model = FCN(pretrained_net=vgg_model, n_class=num_class)
elif model_type == "unet":
    seg_model = UNet(pretrained_net=vgg_model, n_class=num_class)
elif model_type == "pspnet":
    seg_model = PSPNet(pretrained_net=vgg_model, n_class=num_class)
else:
    raise ValueError(
        "Unknown model_type {!r}; expected one of encdec / fcn / unet / pspnet".format(model_type)
    )

# Respect `use_gpu`; an unconditional .cuda() crashes on CPU-only runtimes.
if use_gpu:
    seg_model = seg_model.cuda()
# Independent sigmoid + binary cross-entropy per class channel against the
# one-hot label maps produced by SimDataset.
criterion = nn.BCEWithLogitsLoss()
# seg_model.parameters() includes the backbone's, so VGG16 is fine-tuned too.
optimizer = optim.Adam(seg_model.parameters(), lr=lr)

## Training and Validation

In [None]:
def train(seg_model, train_loader, test_loader):
    """Train ``seg_model`` for ``epochs`` epochs, evaluating after each one.

    Uses the notebook-level ``epochs``, ``optimizer``, ``criterion``,
    ``use_gpu`` and ``results_path``. After every epoch it evaluates on
    ``test_loader`` via ``eval``. It saves a preview of the first test batch as
    ``<results_path>/<epoch, zero-padded to 3>.jpg`` and overwrites
    ``<results_path>/segnet.pt`` with the current weights. At the end the
    per-epoch pixel accuracy and mIoU are written to
    ``<results_path>/record.json`` and the best epochs are printed.
    """
    pixel_acc_list = []
    mIOU_list = []
    model_path = os.path.join(results_path, "segnet.pt")

    for epoch in range(1, epochs + 1):
        ts = time.time()
        # eval() switches the model to inference mode, so switch back every
        # epoch. Otherwise, from epoch 2 on, BatchNorm would normalise with
        # frozen running statistics (and stop updating them) while training.
        seg_model.train()
        for step, batch in enumerate(train_loader):
            optimizer.zero_grad()
            inputs, labels = batch["X"], batch["Y"]
            if use_gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()

            outputs = seg_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if step % 10 == 0:
                print("epoch:{:2}, iter:{:2}, loss: {:.4f}".format(epoch, step, loss.item()))

        print("Finish epoch:{:2}, time elapsed: {:.4f}".format(epoch, time.time() - ts))

        print("Start evaluation ...")
        acc, miou = eval(seg_model, test_loader)
        pixel_acc_list.append(float(acc))
        mIOU_list.append(float(miou))

        print("Output test results ...")
        file_name = os.path.join(results_path, str(epoch).zfill(3) + ".jpg")
        first_batch = next(iter(test_loader), None)
        if first_batch is not None:
            inputs, labels = first_batch["X"], first_batch["Y"]
            if use_gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()
            # Preview in inference mode and without building an autograd graph.
            seg_model.eval()
            with torch.no_grad():
                outputs = seg_model(inputs)
            save_result(file_name, inputs, labels, outputs)

        print("Save model ...")
        torch.save(seg_model.state_dict(), model_path)
        print("========================================")

    if not pixel_acc_list:
        print("No epochs were run (epochs < 1); nothing to record.")
        return

    highest_pixel_acc = max(pixel_acc_list)
    highest_mIOU = max(mIOU_list)
    # list.index returns the first epoch reaching the maximum (0-based).
    highest_pixel_acc_epoch = pixel_acc_list.index(highest_pixel_acc)
    highest_mIOU_epoch = mIOU_list.index(highest_mIOU)

    # Extract evaluation record
    record_path = os.path.join(results_path, "record.json")
    with open(record_path, "w") as fp:
        json.dump({"acc": pixel_acc_list, "iou": mIOU_list}, fp)

    print("The highest mIOU is {} and is achieved at epoch-{}".format(highest_mIOU, highest_mIOU_epoch + 1))
    print("The highest pixel accuracy  is {} and is achieved at epoch-{}".format(highest_pixel_acc, highest_pixel_acc_epoch + 1))

In [None]:
def eval(seg_model, test_loader):
    """Evaluate ``seg_model`` on ``test_loader``.

    Puts the model in eval mode and leaves it there; callers that continue
    training must call ``seg_model.train()`` again.

    Predictions and targets are per-pixel argmaxes over the class channel.
    A pixel whose one-hot label is all zeros (an unmapped colour in
    SimDataset) therefore counts as class 0.

    Returns:
        (mean pixel accuracy, mean IoU) over all test images. Per-class IoU
        is averaged over images with ``nanmean``, then averaged over classes.
        Both values are NaN if the loader yields no batches.
    """
    seg_model.eval()
    total_ious = []
    pixel_accs = []

    # Inference only: skipping autograd saves memory and time.
    with torch.no_grad():
        for batch in test_loader:  # batches of `batch_size` images
            inputs = batch["X"].float()
            if use_gpu:
                inputs = inputs.cuda()

            output = seg_model(inputs).cpu().numpy()

            # (N, C, H, W) -> (N, H, W) class-index maps.
            pred = output.argmax(axis=1)
            target = batch["Y"].cpu().numpy().argmax(axis=1)

            for p, t in zip(pred, target):
                total_ious.append(iou(p, t))
                pixel_accs.append(pixel_acc(p, t))

    if not pixel_accs:
        print("pix_acc: nan, meanIoU: nan (test set is empty)")
        return float("nan"), float("nan")

    # Calculate average IoU: (num_images, n_class) -> (n_class, num_images),
    # then a per-class mean that ignores NaN (class absent from an image).
    per_class_iou = np.nanmean(np.array(total_ious).T, axis=1)
    mean_iou = np.nanmean(per_class_iou)
    mean_pixel_acc = np.array(pixel_accs).mean()
    print("pix_acc: {:.4f}, meanIoU: {:.4f}".format(mean_pixel_acc, mean_iou))
    return mean_pixel_acc, mean_iou

# Calculates class intersections over unions
def iou(pred, target, n_class=None):
    """Return the per-class intersection-over-union of two class-index maps.

    Args:
        pred: integer class-index array.
        target: integer class-index array of the same shape as ``pred``.
        n_class: number of classes to score; defaults to the notebook-level
            ``num_class``.

    Returns:
        A list with one value per class. A class that appears in neither
        ``pred`` nor ``target`` gets NaN so that ``np.nanmean`` skips it.
        A class predicted but absent from ``target`` scores 0.
    """
    if n_class is None:
        n_class = num_class
    ious = []
    for cls in range(n_class):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = np.logical_and(pred_inds, target_inds).sum()
        union = np.logical_or(pred_inds, target_inds).sum()
        if union == 0:
            ious.append(float("nan"))  # class absent from both: exclude from the mean
        else:
            ious.append(float(intersection) / union)
    return ious

def pixel_acc(pred, target):
    """Return the fraction of entries where ``pred`` equals ``target``.

    The denominator counts entries with ``target == target``, i.e. every
    non-NaN target entry (all entries for integer class-index maps).
    """
    matches = np.equal(pred, target).sum()
    valid = np.equal(target, target).sum()
    return matches / valid

def save_result(file_name, input, label, output, n_samples=3):
    """Save a grid comparing inputs, ground-truth labels and predictions.

    Each row is [input | label | prediction] for one sample. Rows for the
    first ``n_samples`` samples of the batch are stacked vertically. Labels
    and predictions are drawn by using the class channels directly as RGB
    (class 0 -> red, 1 -> green, 2 -> blue), so this assumes 3 classes.

    Args:
        file_name: destination image path (format chosen by PIL from the extension).
        input: (N, 3, H, W) tensor with values in [0, 1].
        label: (N, C, H, W) one-hot label tensor.
        output: (N, C, H, W) logits; the per-pixel argmax is drawn.
        n_samples: maximum number of rows; clipped to the batch size.
    """
    n_rows = min(n_samples, input.shape[0])
    if n_rows == 0:
        return
    input_np = input[:n_rows].detach().cpu().numpy().transpose(0, 2, 3, 1)
    label_np = label[:n_rows].detach().cpu().numpy().transpose(0, 2, 3, 1)
    output_np = output[:n_rows].detach().cpu().numpy().transpose(0, 2, 3, 1)

    # One-hot encode the per-pixel argmax in one vectorised step instead of a
    # per-pixel Python loop: eye(C)[class_map] has shape (n, H, W, C). This
    # also works for any image size, not just 256x256.
    n_channels = output_np.shape[-1]
    pred_np = np.eye(n_channels, dtype=np.float32)[output_np.argmax(axis=-1)]

    # Horizontally join input | label | prediction per sample, then stack the
    # samples vertically.
    rows = [np.hstack((input_np[k], label_np[k], pred_np[k])) for k in range(n_rows)]
    grid = np.vstack(rows)
    new_im = Image.fromarray(np.uint8(np.clip(grid, 0.0, 1.0) * 255))
    new_im.save(file_name)

In [None]:
# Run the full training/evaluation loop with the model and loaders built above.
train(seg_model=seg_model, train_loader=train_loader, test_loader=test_loader)

In [None]:
# Used for evaluation
#load_path = results_path + "/" + "segnet.pt"
#seg_model.load_state_dict(torch.load(load_path))
#eval(seg_model, test_loader)