<a href="https://colab.research.google.com/github/DongDong-Zoez/ComputerVision/blob/main/Image%20Segmentation/UNet/ColorNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train UNet from scratch

In [None]:
# Mount Google Drive at /content/gdrive so that files stored under
# /content/gdrive/MyDrive (dataset, checkpoints) are reachable from this runtime.
from google.colab import drive
drive.mount('/content/gdrive')

# Split the images into training and validation sets

In [None]:
#import splitfolders
#splitfolders.ratio(input='/content/gdrive/MyDrive/AnimeFace/', output='/content/gdrive/MyDrive/AnimeFace/split', seed=1337, ratio=(0.8, 0.2))

# UNet Architecture

TODO:

1. In the encoder, replace MaxPool2d with a strided Conv2d and compare the results.
2. In the decoder, replace Upsample with ConvTranspose2d and compare the results.

NOTE:

1. in_channels = 3 for RGB images
2. change out_channels to your custom setting

In [1]:
import torch
from torch import nn

class DownSampleLayer(nn.Module):
    """Encoder stage: a double 3x3 convolution followed by a strided-conv downsample.

    Args:
        in_ch: Number of channels of the incoming feature map.
        out_ch: Number of channels produced by this stage.

    ``forward`` returns a pair ``(features, reduced)``: ``features`` keeps the
    input's spatial size and is meant to be reused as a skip connection, while
    ``reduced`` is the same tensor passed through a stride-2 convolution.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        # Layers are instantiated in registration order so that parameter
        # initialisation and state_dict keys are the same as before.
        feature_layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]
        self.DoubleConv = nn.Sequential(*feature_layers)

        # Stride-2 convolution used in place of nn.MaxPool2d(2, 2): it halves
        # height and width (rounding up) while keeping the channel count.
        reduce_layers = [
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]
        self.downsample = nn.Sequential(*reduce_layers)

    def forward(self, x):
        features = self.DoubleConv(x)
        reduced = self.downsample(features)
        return features, reduced
    
class UpSampleLayer(nn.Module):
    """Decoder stage: double 3x3 convolution, 2x transposed-conv upsample, skip concat.

    Args:
        in_ch: Number of channels of the incoming feature map.
        out_ch: Number of channels after upsampling; the double convolution
            works at ``out_ch * 2`` channels before the transposed convolution
            reduces them to ``out_ch``.

    ``forward(x, copy_crop)`` returns the upsampled features concatenated with
    ``copy_crop`` along the channel axis, so ``copy_crop`` must have twice the
    height and width of ``x``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        wide = out_ch * 2

        # Same construction order as the registered order, so seeded
        # initialisation and state_dict keys are unchanged.
        conv_layers = [
            nn.Conv2d(in_ch, wide, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(wide),
            nn.ReLU(),
            nn.Conv2d(wide, wide, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(wide),
            nn.ReLU(),
        ]
        self.DoubleConv = nn.Sequential(*conv_layers)

        # kernel 3 / stride 2 / padding 1 / output_padding 1 exactly doubles
        # the spatial size.
        expand_layers = [
            nn.ConvTranspose2d(wide, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]
        self.Upsample = nn.Sequential(*expand_layers)

    def forward(self, x, copy_crop):
        upsampled = self.Upsample(self.DoubleConv(x))
        # Upsampled features first, skip connection second.
        return torch.cat((upsampled, copy_crop), dim=1)
    
class UNet(nn.Module):
    """Four-stage U-Net mapping an image to a same-sized image with ``channels`` planes.

    The encoder (``d1``..``d4``) halves the spatial size four times and the
    decoder (``u1``..``u4``) doubles it four times, concatenating the matching
    encoder features at each step. Input height and width should therefore be
    multiples of 16; otherwise the skip concatenation fails on a size mismatch.
    The final layer is a Sigmoid, so outputs lie in (0, 1).

    Args:
        channels: Number of output planes (e.g. 3 for a colour prediction).
        in_channels: Number of input planes. Defaults to 1 (grayscale input).
    """

    def __init__(self, channels=1, in_channels=1):
        super().__init__()

        self.channels = channels
        self.in_channels = in_channels
        out_channels = [16, 32, 64, 128, 256]  # feature widths per stage

        self.d1 = DownSampleLayer(in_ch=in_channels, out_ch=out_channels[0])
        self.d2 = DownSampleLayer(in_ch=out_channels[0], out_ch=out_channels[1])
        self.d3 = DownSampleLayer(in_ch=out_channels[1], out_ch=out_channels[2])
        self.d4 = DownSampleLayer(in_ch=out_channels[2], out_ch=out_channels[3])

        # From u2 on, each decoder stage consumes the previous stage's output
        # concatenated with a skip connection, hence the doubled input widths.
        self.u1 = UpSampleLayer(in_ch=out_channels[3], out_ch=out_channels[3])
        self.u2 = UpSampleLayer(in_ch=out_channels[4], out_ch=out_channels[2])
        self.u3 = UpSampleLayer(in_ch=out_channels[3], out_ch=out_channels[1])
        self.u4 = UpSampleLayer(in_ch=out_channels[2], out_ch=out_channels[0])

        self.output = nn.Sequential(
            nn.Conv2d(in_channels=out_channels[1], out_channels=out_channels[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=out_channels[0]),
            nn.ReLU(),

            nn.Conv2d(in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=out_channels[0]),
            nn.ReLU(),

            nn.Conv2d(in_channels=out_channels[0], out_channels=self.channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Run the encoder/decoder and return the Sigmoid-activated prediction."""
        # c* are full-resolution skip features, d* their downsampled versions.
        c1, d1 = self.d1(x)
        c2, d2 = self.d2(d1)
        c3, d3 = self.d3(d2)
        c4, d4 = self.d4(d3)

        u1 = self.u1(d4, c4)
        u2 = self.u2(u1, c3)
        u3 = self.u3(u2, c2)
        u4 = self.u4(u3, c1)

        return self.output(u4)

    def save_model(self, filename):
        """Write this model's ``state_dict`` to ``filename``."""
        torch.save(self.state_dict(), filename)

    def load_model(self, filename, cpu=False):
        """Load a ``state_dict`` saved by :meth:`save_model` into this model.

        Args:
            filename: Path of the checkpoint file.
            cpu: When true, remap all stored tensors onto the CPU so that a
                checkpoint written on a GPU machine can be loaded without CUDA.
        """
        # The previous CPU branch re-ran __init__ with attributes this class
        # never defines (nbase, nout, ...) and therefore always raised
        # AttributeError; remapping the storage location is all that is needed.
        map_location = torch.device('cpu') if cpu else None
        self.load_state_dict(torch.load(filename, map_location=map_location))

# Build the Dataset

Note: We use the data from [Kaggle Anime Face](https://www.kaggle.com/soumikrakshit/anime-faces)

In [None]:
import torch
import cv2
import os
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import numpy as np
import albumentations

# Side lengths (in pixels) every image is resized to before entering the network.
WIDTH = HEIGHT = 64

# Default preprocessing pipeline applied to each image array read from disk.
# NOTE(review): Normalize with mean 0.5 / std 0.5 maps values to [-1, 1], while
# UNet ends in a Sigmoid (outputs in (0, 1)) - confirm the target range is intended.
# NOTE(review): ColorJitter is random and is also applied wherever this default
# pipeline is reused, including the validation data - confirm that is intended.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize([HEIGHT, WIDTH]),
        transforms.ColorJitter(contrast=0.5, hue=0.25),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

class AnimeFacesDataset(Dataset):
    """Dataset of (grayscale input, colour target) pairs built from image files.

    Args:
        transforms: Callable applied to each image array returned by
            ``cv2.imread``; expected to return a channel-first array/tensor.
        path: Directory that is searched (recursively) for image files.
    """

    def __init__(self, transforms=transform, path='/content/gdrive/MyDrive/AnimeFace/split/train/data/'):
        self.path = path
        self.transforms = transforms
        # Collect files from every directory under `path`, joining each name
        # with the directory it was found in. Previously the list was
        # overwritten for each visited directory, built from `path` instead of
        # the visited directory, and left undefined when nothing was walked.
        # Sorting gives a reproducible index -> file mapping.
        self.imgs = sorted(
            os.path.join(root, name)
            for root, _dirs, files in os.walk(self.path)
            for name in files
        )

    def _load(self, idx):
        """Read image ``idx`` from disk and return it after ``self.transforms``."""
        filename = self.imgs[idx]
        img = cv2.imread(filename)
        if img is None:
            # cv2.imread signals an unreadable/non-image file by returning None.
            raise FileNotFoundError(f'Could not read image file: {filename}')
        return self.transforms(img)

    def __getitem__(self, idx):
        """Return ``(X, Y)``: ``Y`` is the transformed image, ``X`` its one-channel luma."""
        Y = self._load(idx)
        # cv2.imread yields channels in B, G, R order, so the standard
        # R/G/B luma weights go to channels 2/1/0 (consistent with show_img).
        gray = Y[2, :, :] * 0.299 + Y[1, :, :] * 0.587 + Y[0, :, :] * 0.114
        # Add the leading channel axis without assuming a fixed image size.
        X = gray[None, :, :]

        return X, Y

    def __len__(self):
        return len(self.imgs)

    def show_img(self, idx):
        """Plot sample ``idx``: colour image on the left, grayscale on the right."""
        chw = np.asarray(self._load(idx))
        # Reverse the channel axis (BGR -> RGB) and move channels last for matplotlib.
        rgb = np.transpose(chw[::-1], (1, 2, 0))
        gray = rgb[..., 0] * 0.299 + rgb[..., 1] * 0.587 + rgb[..., 2] * 0.114

        # NOTE(review): with the default transform the values are normalised to
        # [-1, 1]; matplotlib clips RGB floats to [0, 1] - confirm this is acceptable.
        plt.subplot(1,2,2)
        plt.title('Gray scale')
        plt.axis('off')
        plt.imshow(gray, cmap='Greys_r')
        plt.subplot(1,2,1)
        plt.title('Except Output')
        plt.axis('off')
        plt.imshow(rgb)


In [None]:
# Build the dataset with its default transform and default (training) path.
anime = AnimeFacesDataset()

# Preview one sample: the colour image next to its grayscale version.
idx = 10
anime.show_img(idx)

# Train

In [None]:
from torch import optim
import torch.nn as nn
from tqdm import tqdm

def train(net, device, path, loss_func, epochs=100, batch_size=32, lr=0.001,
          save_path='/content/gdrive/MyDrive/best_model_n.pth'):
    """Train ``net`` on the images under ``path`` and checkpoint the best weights.

    Args:
        net: Model to optimise; called as ``net(image)``.
        device: Device the batches are moved to before the forward pass.
        path: Directory handed to ``AnimeFacesDataset``.
        loss_func: Callable ``loss_func(pred, label)`` returning a scalar loss tensor.
        epochs: Number of passes over the dataset.
        batch_size: Mini-batch size for the data loader.
        lr: Learning rate for the Adam optimiser.
        save_path: File the ``state_dict`` is written to whenever a batch
            reaches a new lowest loss.

    Returns:
        The lowest per-batch loss seen, as a Python float
        (``float('inf')`` if no batch was processed).
    """
    anime = AnimeFacesDataset(path=path)
    data_loader = torch.utils.data.DataLoader(dataset=anime, batch_size=batch_size, shuffle=True)

    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = loss_func

    best_loss = float('inf')

    net.train()

    for epoch in tqdm(range(epochs)):
        for image, label in data_loader:

            optimizer.zero_grad()

            image = image.to(device=device)
            label = label.to(device=device)

            pred = net(image)
            loss = criterion(pred, label)

            # Track the best loss as a plain float so no tensor (and its
            # autograd history) is kept alive across iterations.
            loss_value = loss.item()
            print('Loss/train', loss_value)

            # Saved before the optimiser step: these are the weights that
            # actually produced `loss_value`.
            if loss_value < best_loss:
                best_loss = loss_value
                torch.save(net.state_dict(), save_path)

            loss.backward()
            optimizer.step()

    return best_loss


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(channels=3)
net.to(device=device)

# Resume from an earlier checkpoint only when one exists, so a first run can
# start from scratch. map_location lets a GPU-saved checkpoint load on a
# CPU-only runtime (and vice versa).
checkpoint = '/content/gdrive/MyDrive/best_model_n.pth'
if os.path.exists(checkpoint):
    net.load_state_dict(torch.load(checkpoint, map_location=device))

path = "/content/gdrive/MyDrive/AnimeFace/split/train/data/"
train(net, device, path, nn.MSELoss())

# Inference

In [None]:
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(channels=3)

# Load the trained weights; without this the predictions come from a randomly
# initialised network. map_location keeps this working on CPU-only runtimes.
checkpoint = '/content/gdrive/MyDrive/best_model_n.pth'
if os.path.exists(checkpoint):
    net.load_state_dict(torch.load(checkpoint, map_location=device))
else:
    print(f'Checkpoint not found, running with untrained weights: {checkpoint}')

net.to(device=device)
net.eval()

# cv2.imwrite does not create missing directories, so make sure they exist.
out_root = '/content/gdrive/MyDrive/AnimeFace/ColorUNet'
for sub_dir in ('Pred', 'X', 'Y'):
    os.makedirs(os.path.join(out_root, sub_dir), exist_ok=True)

with torch.no_grad():

    anime = AnimeFacesDataset(path='/content/gdrive/MyDrive/AnimeFace/split/val/data/')
    # No shuffling at inference time: keeps the output numbering reproducible.
    data_loader = torch.utils.data.DataLoader(dataset=anime, batch_size=1, shuffle=False)

    for batch_idx, (images, targets) in enumerate(tqdm(data_loader), 1):
        images = images.to(device)
        targets = targets.to(device)

        output = net(images)

        # Take the single item of each batch and scale it by 255 for writing.
        # NOTE(review): the dataset's default transform normalises inputs and
        # targets to [-1, 1], so scaling by 255 alone yields negative values
        # for them - confirm whether they should be de-normalised first.
        output = np.array(output.data.cpu()[0]) * 255
        images = np.array(images.data.cpu()[0]).squeeze(0) * 255
        targets = np.array(targets.data.cpu()[0]) * 255

        # Channel-first -> channel-last, the layout cv2.imwrite expects.
        output = np.transpose(output, [1,2,0])
        targets = np.transpose(targets, [1,2,0])

        cv2.imwrite(f'{out_root}/Pred/{batch_idx}.png', output)
        cv2.imwrite(f'{out_root}/X/{batch_idx}.png', images)
        cv2.imwrite(f'{out_root}/Y/{batch_idx}.png', targets)