# STEP 01: Take a Look at an Image and Its Label

In [None]:
import matplotlib.pyplot as plt
import cv2 as cv
import numpy as np
import os

root_train_imgs = "../__HW7_DATA/rgb_images/"
root_test_imgs = "../__HW7_DATA/rgb_images(test_set)/"
root_train_mask = "../__HW7_DATA/semantic_annotations/gtLabels/"

# Index of the image/mask pair to preview; change it to look at another sample.
img_num = 2

# os.listdir returns entries in arbitrary order, so sort both listings to keep
# image i paired with mask i (assumes matching file names sort identically).
train_imgs = sorted(os.listdir(root_train_imgs))
print(f"First 5 train imgs are {train_imgs[:5]}")
img = cv.imread(os.path.join(root_train_imgs, train_imgs[img_num]))
img_resize = cv.resize(img, (256, 256), interpolation=cv.INTER_CUBIC)
print(img_resize.shape)
# OpenCV decodes images as BGR while matplotlib expects RGB.
plt.imshow(cv.cvtColor(img_resize, cv.COLOR_BGR2RGB))
plt.show()

train_mask = sorted(os.listdir(root_train_mask))
print(f"First 5 train mask are {train_mask[:5]}")
mask = cv.imread(os.path.join(root_train_mask, train_mask[img_num]))
# Nearest-neighbour resizing never blends neighbouring label values together.
mask_resize = cv.resize(mask, (256, 256), interpolation=cv.INTER_NEAREST)
print(mask_resize.shape)
# Scale label values down for imshow's [0, 1] float range
# (presumably class ids are < 10 — TODO confirm against the label spec).
mask_resize_draw = mask_resize / 10
plt.imshow(mask_resize_draw)
plt.show()

# STEP 02: Build Custom Dataset Class

In [None]:
import cv2 as cv
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
from keras.utils import to_categorical

class RoadDataset(Dataset):
    """Paired road image / semantic-mask dataset read from two directories.

    Images and masks are matched by position after sorting each directory
    listing, so their file names are expected to sort in the same order.

    Args:
        width: width every image and mask is resized to.
        height: height every image and mask is resized to.
        path_to_imgs: directory containing the input images.
        path_to_mask: directory containing the label masks.
        transform: optional callable applied to the resized RGB image
            (e.g. ``transforms.ToTensor()``). When ``None`` the resized
            numpy array is returned unchanged.
    """

    def __init__(self, width, height, path_to_imgs, path_to_mask, transform=None):
        self.height = height
        self.width = width
        self.path_to_img = path_to_imgs
        self.path_to_mask = path_to_mask

        # os.listdir order is arbitrary and can differ between directories;
        # sorting keeps image i aligned with mask i.
        self.train_imgs = sorted(os.listdir(path_to_imgs))
        self.train_mask = sorted(os.listdir(path_to_mask))

        # The dataset length follows the image directory.
        self.length = len(self.train_imgs)
        self.transform = transform

    @staticmethod
    def _read(path):
        """Read ``path`` with OpenCV, raising instead of returning ``None``."""
        image = cv.imread(path)
        if image is None:
            raise FileNotFoundError(f"could not read image file: {path}")
        return image

    def __getitem__(self, index):
        """Return ``(image, mask)`` for sample ``index``.

        ``image`` is the resized RGB image, passed through ``transform`` when
        one is set. ``mask`` is channel 0 of the resized mask, a
        ``(height, width)`` array of label values.
        """
        img = self._read(os.path.join(self.path_to_img, self.train_imgs[index]))
        msk = self._read(os.path.join(self.path_to_mask, self.train_mask[index]))

        img_resize = cv.resize(img, (self.width, self.height), interpolation=cv.INTER_CUBIC)
        # Nearest-neighbour so resizing never invents label values between classes.
        msk_resize = cv.resize(msk, (self.width, self.height), interpolation=cv.INTER_NEAREST)

        # OpenCV decodes as BGR; convert so the model and any plots see RGB.
        img_rgb = cv.cvtColor(img_resize, cv.COLOR_BGR2RGB)

        # cv.imread's default mode yields 3 channels; keep channel 0 as the label map.
        msk_one_channel = msk_resize[:, :, 0]

        img_out = self.transform(img_rgb) if self.transform is not None else img_rgb
        return (img_out, msk_one_channel)

    def __len__(self):
        return self.length
    
# ToTensor turns the HxWxC uint8 image into a CxHxW float tensor scaled to
# [0, 1] (per torchvision's ToTensor documentation).
trans = transforms.Compose([transforms.ToTensor()])

# DataLoader settings: every sample is resized to width x height.
width, height = 256, 256
batch_size = 10

Train_Dataset = RoadDataset(
    width,
    height,
    root_train_imgs,
    root_train_mask,
    trans,
)
Train_Dataloader = DataLoader(
    Train_Dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

# STEP 03: See Model Summary

In [None]:
from torchsummary import summary
import torch
import pytorch_unet

# Set to True to print a layer-by-layer summary of a 10-class UNet for a
# 3x256x256 input (uses torchsummary).
show_model_summary = False
if show_model_summary:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = pytorch_unet.UNet(10).to(device)
    summary(model, input_size=(3, 256, 256))

# STEP 04: Train UNet

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

def _show_predictions(inputs, labels, outputs):
    """Plot the first sample of a batch: predicted mask, ground-truth mask, input image."""
    # Class axis is dim 0 of a single sample's logits (assumes (N, C, H, W)
    # logits, as CrossEntropyLoss with (N, H, W) targets requires).
    pred_argmax = outputs.detach()[0].argmax(dim=0).cpu().numpy()
    print(np.unique(pred_argmax))
    plt.imshow(pred_argmax / 10)
    plt.show()

    # Ground-truth mask
    lab_np = labels.detach()[0].cpu().numpy()
    print(np.unique(lab_np))
    plt.imshow(lab_np / 10)
    plt.show()

    # Input image, CHW -> HWC for matplotlib
    inp_np = inputs.detach()[0].cpu().numpy().transpose((1, 2, 0))
    plt.imshow(inp_np)
    plt.show()
    print("============================")


def train_model(model, optimizer, scheduler, num_epochs=25, dataloader=None,
                device=None, plot_every=10):
    """Train a segmentation model with pixel-wise cross-entropy and return it.

    Each epoch makes one pass over ``dataloader`` and then steps ``scheduler``.
    Before returning, the weights from the epoch with the lowest mean
    *training* loss are restored (no validation split is used here).

    Args:
        model: network mapping an image batch to per-class logits.
        optimizer: optimizer over ``model.parameters()``.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of passes over the data.
        dataloader: iterable of ``(inputs, labels)`` batches; defaults to the
            module-level ``Train_Dataloader``.
        device: device batches are moved to; defaults to the device of the
            model's first parameter.
        plot_every: plot prediction / ground truth / input every this many
            batches; ``0`` disables plotting.

    Returns:
        ``model`` with the best epoch's weights loaded.
    """
    import copy
    import time

    if dataloader is None:
        dataloader = Train_Dataloader
    if device is None:
        device = next(model.parameters()).device

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float("inf")
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        since = time.time()

        for param_group in optimizer.param_groups:
            print("LR", param_group['lr'])
        # Once per epoch, not once per parameter group.
        model.train()

        epoch_samples = 0
        running_loss = 0.0

        for batch_idx, (inputs, labels) in enumerate(tqdm(dataloader), start=1):
            # CrossEntropyLoss needs integer (long) class indices as targets.
            inputs, labels = inputs.to(device), labels.to(device).long()
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if plot_every and batch_idx % plot_every == 0:
                print("loss:", loss.item())
                _show_predictions(inputs, labels, outputs)

            loss.backward()
            optimizer.step()

            n = inputs.size(0)
            running_loss += loss.item() * n
            epoch_samples += n

        scheduler.step()

        # With no samples the epoch cannot become the "best" one.
        epoch_loss = running_loss / epoch_samples if epoch_samples else float("inf")
        print(f"samples: {epoch_samples}, mean loss: {epoch_loss:.4f}")
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    # Load the best (lowest training loss) weights seen during training.
    model.load_state_dict(best_model_wts)
    return model

In [None]:
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy

# Prefer the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

num_class = 10
model = pytorch_unet.UNet(num_class).to(device)

# Adam at lr=1e-4; StepLR multiplies the rate by 0.1 every 25 scheduler steps
# (train_model steps the scheduler once per epoch).
optimizer_ft = optim.Adam(model.parameters(), lr=1e-4)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=25, gamma=0.1)

model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=30)

# STEP 05: Test UNet

In [None]:
# Only the image is used below. No test-mask directory is defined among the
# root_* paths, so the training mask directory is passed as a placeholder:
# the returned mask does not belong to the test image and must not be used
# for evaluation. (Indices beyond the number of training masks would fail.)
Test_Dataset = RoadDataset(width, height, root_test_imgs, root_train_mask, trans)
Test_Dataloader = DataLoader(Test_Dataset, batch_size=1, shuffle=False, num_workers=0)
img, _ = next(iter(Test_Dataloader))
img = img.to(device)

# eval() puts layers such as BatchNorm/Dropout (if the UNet has any) into
# inference mode; no_grad() avoids building an autograd graph.
model.eval()
with torch.no_grad():
    pred = model(img)
print(pred.shape)