# Implementation of the U-Net architecture
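Reference: O. Ronneberger, P. Fischer, T. Brox, *U-Net: Convolutional Networks for Biomedical Image Segmentation* (2015), [arXiv:1505.04597](https://arxiv.org/abs/1505.04597)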

In [20]:
import torch
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.preprocessing import OneHotEncoder
import torch.nn as nn
import torchvision.transforms.functional as TF
from datetime import datetime
import torch.optim as optim
from tqdm import trange

## Load the NumPy arrays and define the dataset


In [None]:
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)

In [21]:
# To run on Colab, mount Google Drive (cell above) and load
# "/content/drive/MyDrive/project21/latest_dataset.npy" instead.
# NOTE(review): each element is indexed by MyDataset as an (H, W, >=4) array
# (RGB in channels 0-2, label map in channel 3) — confirm against the data file.
arr_data = np.load("../../data/train.npy")  # path of the dataset, relative to the working directory

In [22]:
# Quick sanity check of what was loaded: element count, then full array shape.
print("Number of arrays:", len(arr_data))
print(arr_data.shape)


((834, 256, 256, 4), (2001, 256, 256, 4), (169, 256, 256, 4))

In [23]:
def onehot(data, n):
    """One-hot encode a label map whose values are 0, 10, 20, ..., 90.

    Args:
        data: array of label values; every entry must be one of
            0, 10, 20, ..., 90.
        n: number of classes. Currently unused — the category list below is
            fixed at 10 classes; kept for backward compatibility.

    Returns:
        A float64 array of shape ``data.shape + (10,)`` in which channel ``k``
        is 1.0 exactly where ``data == 10 * k`` and 0.0 elsewhere.

    Raises:
        ValueError: if ``data`` contains a value outside the category list
            (the same condition under which the previous sklearn
            ``OneHotEncoder`` with ``handle_unknown='error'`` raised).
    """
    categories = np.array([0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
    data = np.asarray(data)
    # Compare every pixel against every category in one broadcast step,
    # instead of fitting a new sklearn encoder for every sample.
    matches = data[..., np.newaxis] == categories
    if not matches.any(axis=-1).all():
        unknown = np.setdiff1d(np.unique(data), categories)
        raise ValueError(
            f"Found unknown label values {unknown.tolist()}; "
            f"expected one of {categories.tolist()}"
        )
    # The output shape follows the input instead of a hard-coded 256x256,
    # so label maps of any spatial size can be encoded.
    return matches.astype(np.float64)


class MyDataset(Dataset):
    """Dataset over stacked RGB + label-map arrays.

    Each ``data[index]`` is indexed as an ``(H, W, C)`` array whose first three
    channels are the RGB image and whose fourth channel is the label map
    (values 0, 10, ..., 90; see ``onehot``). Items are returned as
    ``(image, label)``:

    * ``image`` — output of ``self.transform``: a ``(3, H, W)`` tensor
      normalised with the commonly used ImageNet mean/std.
    * ``label`` — a ``(10, H, W)`` float32 one-hot tensor.

    NOTE(review): ``ToTensor`` only rescales to [0, 1] when the array is
    uint8; for float arrays the values are passed through unscaled (and keep
    their dtype) — confirm the dtype of the stored arrays.
    """

    def __init__(self, data):
        self.data = data
        self.transform = transforms.Compose([transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                  std=[0.229, 0.224, 0.225])])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image = sample[:, :, :3]  # RGB image, (H, W, 3)
        label = sample[:, :, 3]   # label map, (H, W)

        # Convert the HWC numpy image to a normalised CHW tensor.
        if self.transform is not None:
            image = self.transform(image)

        label = onehot(label, 10)         # (H, W, 10) for this single sample
        label = label.transpose(2, 0, 1)  # (10, H, W), channels-first like the model output
        # The transpose is a non-contiguous view; copy it into a contiguous
        # float32 tensor (replaces the legacy torch.FloatTensor constructor).
        label = torch.from_numpy(np.ascontiguousarray(label)).float()

        return image, label

## Split into training and validation sets and build the data loaders

In [24]:
# Split 95% / 5% into training and held-out (validation) subsets. The seeded
# generator makes the split identical across notebook restarts, so a reloaded
# checkpoint is always evaluated on images it was not trained on.
carpart = MyDataset(arr_data)
train_size = int(0.95 * len(carpart))   # 95% for training
test_size = len(carpart) - train_size   # remaining ~5% for validation
train_dataset, test_dataset = random_split(
    carpart, [train_size, test_size], generator=torch.Generator().manual_seed(0)
)

In [25]:
# Shuffle only the training data. Evaluation order does not affect learning,
# and a fixed order means every epoch's test loss averages over the same
# batches (including the same, possibly smaller, last batch), so the values
# are comparable from epoch to epoch.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=2)

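## Create the U-Net architecture
U-Net consists of an encoder (contracting path) and a decoder (expanding path) joined by skip connections. We start with their shared building block, a double 3×3 convolution, and then assemble the full network.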


In [32]:
#Architect of U_net
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm and ReLU.

    Padding of 1 with stride 1 keeps H and W unchanged; only the channel
    count goes from ``in_channels`` to ``out_channels``. The convolutions
    have no bias because the BatchNorm right after them adds its own shift.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.conv(x)
        return features

class UNET(nn.Module):
    """U-Net for per-pixel classification.

    The encoder applies one ``DoubleConv`` per entry of ``features`` and
    2x2 max-pools after each one. A bottleneck ``DoubleConv`` doubles the
    deepest width. The decoder mirrors the encoder: each stage upsamples with
    a 2x2 transposed convolution, concatenates the matching encoder output
    (skip connection) and refines the result with another ``DoubleConv``.
    A final 1x1 convolution maps to ``out_channels``.

    Args:
        in_channels: channels of the input image (3 for RGB).
        out_channels: number of output channels (classes). The output is raw
            logits; no activation is applied.
        features: encoder widths, shallowest first. Any sequence is accepted;
            the default is an immutable tuple, so the default can never be
            shared or mutated across instances.
    """

    def __init__(self, in_channels=3, out_channels=10, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        self.name = "UNET"
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder (contracting path).
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder (expanding path): each stage is [upsample, DoubleConv].
        for feature in reversed(features):
            # Halves the channels and doubles H and W, e.g. with the default
            # widths and a 256x256 input: 1024 x 16x16 -> 512 x 32x32.
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            # Input is the skip connection concatenated with the upsampled map:
            # feature + feature channels.
            self.ups.append(DoubleConv(feature * 2, feature))

        # Bottleneck doubles the deepest width (512 -> 1024 with the defaults).
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        # 1x1 convolution producing the per-pixel class scores.
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # The decoder consumes the skip connections deepest-first.
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            # Max-pooling floors odd sizes, so when H or W is not divisible by
            # 2 ** len(features) the upsampled map is smaller than the skip;
            # resize it so the two can be concatenated along channels.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)

In [34]:
def test():
    """Smoke-test UNET on a random batch and print the output/input shapes.

    The output keeps the input's batch size and spatial size (H, W); only the
    channel dimension changes, from 3 input channels to 10 class channels.

    Raises:
        AssertionError: if the output shape is not (batch, 10, H, W).
    """
    x = torch.randn((3, 3, 256, 256))  # batch 3, 3 input channels, 256x256
    model = UNET(in_channels=3, out_channels=10)
    preds = model(x)
    print(preds.shape)
    print(x.shape)
    # Only channels change: (3, 3, 256, 256) -> (3, 10, 256, 256).
    assert preds.shape == (x.shape[0], 10, *x.shape[2:]), preds.shape


# test()

## Prepare for training 

In [37]:
def train(epo_num=10, train_loader=None, test_loader=None):
    """Train a fresh UNET and return it.

    Builds a UNET (3 input channels, 10 output channels) and optimises it with
    SGD (lr=1e-2, momentum=0.7) on a per-pixel binary cross-entropy between
    the model logits and the one-hot masks. After every epoch it evaluates on
    the held-out loader and prints the mean per-batch train and test losses.
    Every 5 epochs (0, 5, 10, ...) it pickles the whole model into the current
    working directory.

    NOTE(review): the masks are one-hot, i.e. the 10 classes are mutually
    exclusive, yet the loss treats each channel as an independent binary
    target; ``nn.CrossEntropyLoss`` on class indices may be a better fit.

    Args:
        epo_num: number of training epochs.
        train_loader: loader yielding ``(image, one_hot_mask)`` batches for
            training; defaults to the notebook-level ``train_dataloader``.
        test_loader: loader used for evaluation; defaults to the
            notebook-level ``test_dataloader``.

    Returns:
        The trained model, on the training device. Its ``forward`` returns raw
        logits; apply ``torch.sigmoid`` (or argmax over channels) for predictions.
    """
    if train_loader is None:
        train_loader = train_dataloader
    if test_loader is None:
        test_loader = test_dataloader

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = UNET(in_channels=3, out_channels=10)  # input is RGB, output is 10 classes
    model_name = model.name
    model = model.to(device)
    # BCEWithLogitsLoss fuses the sigmoid into the loss (log-sum-exp trick).
    # It gives the same loss value as sigmoid followed by BCELoss, but is
    # numerically stable when the logits saturate.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.7)

    # start timing
    prev_time = datetime.now()
    for epo in trange(epo_num):

        train_loss = 0.0
        model.train()
        for car, car_msk in train_loader:
            # car: (B, 3, H, W) images; car_msk: (B, 10, H, W) one-hot masks
            car = car.to(device)
            car_msk = car_msk.to(device)

            optimizer.zero_grad()
            output = model(car)  # raw logits, (B, 10, H, W)
            loss = criterion(output, car_msk)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        test_loss = 0.0
        model.eval()
        with torch.no_grad():
            for car, car_msk in test_loader:
                car = car.to(device)
                car_msk = car_msk.to(device)

                output = model(car)
                loss = criterion(output, car_msk)
                test_loss += loss.item()

        cur_time = datetime.now()
        # total_seconds() does not wrap after a day the way .seconds does.
        elapsed = int((cur_time - prev_time).total_seconds())
        h, remainder = divmod(elapsed, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        prev_time = cur_time

        # max(..., 1) avoids ZeroDivisionError when a loader is empty
        # (e.g. a dataset too small to leave any validation samples).
        mean_train_loss = train_loss / max(len(train_loader), 1)
        mean_test_loss = test_loss / max(len(test_loader), 1)

        print('epoch:', epo, '/', epo_num)
        print('epoch train loss = %f, epoch test loss = %f, %s'
              % (mean_train_loss, mean_test_loss, time_str))

        # save the model every 5 epochs
        if epo % 5 == 0:
            # File name (including the "trian" spelling) kept unchanged so
            # existing checkpoint names and globs still match.
            filename = f'{model_name}_{epo}_loss_trian_{round(mean_train_loss, 5)}.pt'
            torch.save(model, filename)
            print(f"\nSaving {filename}")
    return model


if __name__ == "__main__":
    # Train from scratch for 100 epochs; raise epo_num for a longer run.
    model = train(epo_num=100)

## Save the model

In [None]:
# Saving the module object pickles the whole model (not just its state_dict),
# so loading it later requires the UNET and DoubleConv class definitions to be
# importable under the same names.
torch.save(model, 'unet_model.pth')