<a href="https://colab.research.google.com/github/Namtk214/Unet/blob/main/UNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import os
import random
import numpy as np
import pandas as pf
import matplotlib.pyplot as plt

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models, utils
from sklearn.model_selection import train_test_split

In [None]:
class FirstFeature(nn.Module):
    """Lift the input image to ``out_channels`` feature maps.

    A bias-free 1x1 convolution followed by ReLU, so the spatial size is
    unchanged and every output value is non-negative.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        features = self.conv(x)
        return features

class ConvBlock(nn.Module):
    """Two (3x3 Conv -> BatchNorm -> LeakyReLU) stages that keep H and W.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
    """
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv = nn.Sequential(
            # stride=1 (was 3): with padding=1 a 3x3 conv then preserves the
            # spatial size. Downsampling is done by MaxPool2d in Encoder, and
            # the decoder's skip concatenation needs matching sizes.
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
        )
    def forward(self, x):
        return self.conv(x)

class Encoder(nn.Module):
    """One downsampling stage: a 2x2 max-pool followed by a ConvBlock."""

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        pool = nn.MaxPool2d(2)
        block = ConvBlock(in_channels, out_channels)
        self.encoder = nn.Sequential(pool, block)

    def forward(self, x):
        downsampled = self.encoder(x)
        return downsampled

class Decoder(nn.Module):
    """One upsampling stage of the U-Net.

    1. Upsample the deeper feature map by 2 (bilinear).
    2. Reduce it to ``out_channels`` with a 1x1 conv.
    3. Concatenate the skip tensor along the channel axis.
    4. Refine with a ConvBlock.

    Args:
        in_channels: channels of the tensor coming from the deeper level. The
            refining ConvBlock also expects this many channels, i.e. the
            concatenation ``out_channels + skip.shape[1]`` must equal
            ``in_channels`` (e.g. 1024 -> 512 with a 512-channel skip).
        out_channels: channels produced by the stage.
    """
    def __init__(self, in_channels, out_channels):
        super(Decoder, self).__init__()
        self.conv = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
        )
        # forward() uses this block; it was previously never created, so
        # every call raised AttributeError.
        self.conv_block = ConvBlock(in_channels, out_channels)

    def forward(self, x, skip):
        x = self.conv(x)
        # MaxPool2d floors odd sizes, so 2x upsampling can come out one pixel
        # short of the skip tensor; resize to the skip's H, W so the concat
        # works. align_corners=True matches UpsamplingBilinear2d.
        if x.shape[-2:] != skip.shape[-2:]:
            x = nn.functional.interpolate(
                x, size=skip.shape[-2:], mode="bilinear", align_corners=True
            )
        x = torch.cat((x, skip), dim=1)
        x = self.conv_block(x)
        return x

class FinalOutput(nn.Module):
    """Project features to ``out_channels`` and squash them into [-1, 1].

    A bias-free 1x1 convolution followed by Tanh; spatial size is unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv = nn.Sequential(projection, nn.Tanh())

    def forward(self, x):
        squashed = self.conv(x)
        return squashed

In [None]:
class Unet(nn.Module):
    """U-Net with four encoder and four decoder stages joined by skip connections.

    Args:
        n_channel: channels of the input image.
        n_classes: channels of the output map (passed through Tanh).
        features: five channel widths, from the shallowest level to the
            bottleneck. A tuple by default so the default cannot be mutated
            and shared between instances; any indexable sequence works.
    """
    def __init__(self, n_channel=3, n_classes=3, features=(64, 128, 256, 512, 1024)):
        super(Unet, self).__init__()
        self.n_channel = n_channel
        self.n_classes = n_classes

        self.in_conv1 = FirstFeature(n_channel, features[0])
        self.in_conv2 = ConvBlock(features[0], features[0])

        self.enc1 = Encoder(features[0], features[1])
        self.enc2 = Encoder(features[1], features[2])
        self.enc3 = Encoder(features[2], features[3])
        self.enc4 = Encoder(features[3], features[4])

        self.dec1 = Decoder(features[4], features[3])
        self.dec2 = Decoder(features[3], features[2])
        self.dec3 = Decoder(features[2], features[1])
        self.dec4 = Decoder(features[1], features[0])

        self.out_conv = FinalOutput(features[0], n_classes)

    def forward(self, x):
        # Encoder path: keep every level's output for the skip connections.
        x = self.in_conv1(x)
        x1 = self.in_conv2(x)
        x2 = self.enc1(x1)
        x3 = self.enc2(x2)
        x4 = self.enc3(x3)
        x5 = self.enc4(x4)
        # Decoder path: each stage fuses the matching encoder level.
        x = self.dec1(x5, x4)
        x = self.dec2(x, x3)
        x = self.dec3(x, x2)
        x = self.dec4(x, x1)
        x = self.out_conv(x)
        return x



Super Resolution

In [None]:
# Scratch cell: these two lines belong inside the super-resolution model
# (the first in __init__, the second in forward). `self`, `x` and
# LOW_IMG_HEIGHT are not defined at notebook top level.
# transforms.Resize takes the target size as ONE argument; a second
# positional int is interpreted as the interpolation mode, not the width.
# NOTE(review): both dimensions use LOW_IMG_HEIGHT -- confirm a square
# output is intended (or use a LOW_IMG_WIDTH constant for the second entry).
self.resize_fnc = transforms.Resize((LOW_IMG_HEIGHT * 4, LOW_IMG_HEIGHT * 4), antialias=True)
x = self.resize_fnc(x)

In [None]:
class FirstFeatureNoSkip(nn.Module):
    """Lift the input image to ``out_channels`` feature maps (no-skip variant).

    A bias-free 1x1 convolution followed by ReLU; spatial size is unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv = nn.Sequential(pointwise, nn.ReLU())

    def forward(self, x):
        activated = self.conv(x)
        return activated

class ConvBlockNoSkip(nn.Module):
    """Two (3x3 Conv -> BatchNorm -> LeakyReLU) stages that keep H and W.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
    """
    def __init__(self, in_channels, out_channels):
        super(ConvBlockNoSkip, self).__init__()
        self.conv = nn.Sequential(
            # stride=1 (was 3): with padding=1 the 3x3 conv preserves the
            # spatial size; downsampling is MaxPool2d's job in EncoderNoSkip.
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class EncoderNoSkip(nn.Module):
    """One downsampling stage (no-skip variant): 2x2 max-pool, then ConvBlockNoSkip."""

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        pool = nn.MaxPool2d(2)
        block = ConvBlockNoSkip(in_channels, out_channels)
        self.encoder = nn.Sequential(pool, block)

    def forward(self, x):
        downsampled = self.encoder(x)
        return downsampled

class DecoderNoSkip(nn.Module):
    """Upsample by 2, reduce channels, optionally fuse a skip tensor, then refine.

    Args:
        in_channels: channels of the tensor passed to ``forward``.
        out_channels: channels produced by the stage.
        skip_channels: channels of the ``skip`` tensor that ``forward`` will
            concatenate.
            - ``None`` (default) keeps the original sizing: the refining
              ConvBlockNoSkip expects ``in_channels`` inputs, i.e. a skip with
              ``in_channels - out_channels`` channels.
            - ``0`` builds a decoder that is called without any skip tensor.
    """
    def __init__(self, in_channels, out_channels, skip_channels=None):
        super(DecoderNoSkip, self).__init__()
        self.conv = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
        )
        if skip_channels is None:
            block_in = in_channels
        else:
            block_in = out_channels + skip_channels
        self.conv_block = ConvBlockNoSkip(block_in, out_channels)

    def forward(self, x, skip=None):
        x = self.conv(x)
        if skip is not None:
            # MaxPool2d floors odd sizes; align to the skip's H, W before
            # concatenating (align_corners=True matches UpsamplingBilinear2d).
            if x.shape[-2:] != skip.shape[-2:]:
                x = nn.functional.interpolate(
                    x, size=skip.shape[-2:], mode="bilinear", align_corners=True
                )
            x = torch.cat((x, skip), dim=1)
        x = self.conv_block(x)
        return x

class FinalOutputNoSkip(nn.Module):
    """Bias-free 1x1 conv to ``out_channels`` followed by Tanh (outputs in [-1, 1])."""
    def __init__(self, in_channels, out_channels):
        super(FinalOutputNoSkip, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Tanh()
        )
    def forward(self, x):
        return self.conv(x)

class SR_Unet_NoSkip(nn.Module):
    """Super-resolution encoder/decoder without skip connections.

    ``forward`` performs these steps:
    1. Enlarge the input by ``scale_factor`` (bilinear).
    2. Run it through a U-Net-shaped encoder/decoder whose decoders receive
       no skip tensors.
    3. Map the result to ``n_classes`` channels with Tanh.

    Args:
        n_channel: channels of the input image.
        n_classes: channels of the output image.
        features: five channel widths, shallowest level to bottleneck. A tuple
            by default so the default cannot be mutated and shared.
        scale_factor: upscaling applied to the input before the network.
            NOTE(review): the notebook's scratch cell resized to
            ``LOW_IMG_HEIGHT * 4``; 4 is assumed to be the intended factor.
    """
    def __init__(self, n_channel=3, n_classes=3, features=(64, 128, 256, 512, 1024), scale_factor=4):
        super(SR_Unet_NoSkip, self).__init__()
        self.n_channel = n_channel
        self.n_classes = n_classes

        # forward() called self.resize_fnc without it ever being defined.
        # Bilinear with align_corners=False is transforms.Resize's tensor
        # default; scaling (rather than a fixed size) works for any input size.
        self.resize_fnc = nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=False)

        self.in_conv1 = FirstFeatureNoSkip(n_channel, features[0])
        self.in_conv2 = ConvBlockNoSkip(features[0], features[0])

        self.enc1 = EncoderNoSkip(features[0], features[1])
        self.enc2 = EncoderNoSkip(features[1], features[2])
        self.enc3 = EncoderNoSkip(features[2], features[3])
        self.enc4 = EncoderNoSkip(features[3], features[4])

        # skip_channels=0: forward() passes no skip tensors, so each refining
        # block must take out_channels inputs rather than in_channels.
        self.dec1 = DecoderNoSkip(features[4], features[3], skip_channels=0)
        self.dec2 = DecoderNoSkip(features[3], features[2], skip_channels=0)
        self.dec3 = DecoderNoSkip(features[2], features[1], skip_channels=0)
        self.dec4 = DecoderNoSkip(features[1], features[0], skip_channels=0)

        self.out_conv = FinalOutputNoSkip(features[0], n_classes)

    def forward(self, x):
        x = self.resize_fnc(x)
        x = self.in_conv1(x)
        x = self.in_conv2(x)

        x = self.enc1(x)
        x = self.enc2(x)
        x = self.enc3(x)
        x = self.enc4(x)

        x = self.dec1(x)
        x = self.dec2(x)
        x = self.dec3(x)
        x = self.dec4(x)

        x = self.out_conv(x)
        return x

In [None]:
class ImageDataset(Dataset):
    """Paired (input, target) images read from a flat directory.

    Every regular file in ``img_dir`` is opened with PIL and converted to RGB.
    The same source image produces both the input and the target; they differ
    only in the optional resize applied to each. Both are min-max normalized
    to [0, 1].

    Args:
        img_dir: directory containing the image files.
        is_train: when True, apply a random horizontal flip shared by the
            input and the target.
        low_res_size: optional size for the input image (int or (h, w)), as
            accepted by ``transforms.functional.resize``. ``None`` keeps the
            original size.
        high_res_size: optional size for the target image, same convention.
    """
    def __init__(self, img_dir, is_train=True, low_res_size=None, high_res_size=None):
        self.img_dir = img_dir
        self.is_train = is_train
        self.low_res_size = low_res_size
        self.high_res_size = high_res_size
        # Sorted so index -> file is deterministic (os.listdir order is
        # arbitrary); sub-directories are dropped because Image.open fails on them.
        self.img_list = sorted(
            name for name in os.listdir(img_dir)
            if os.path.isfile(os.path.join(img_dir, name))
        )

    def __len__(self):
        return len(self.img_list)

    @staticmethod
    def _resize(image, size):
        """Resize ``image`` to ``size``; return it unchanged when ``size`` is None."""
        if size is None:
            return image
        # antialias matters when downscaling to the low-resolution input.
        return transforms.functional.resize(image, size, antialias=True)

    def normalize(self, x):
        """Min-max scale ``x`` to [0, 1]; a constant tensor maps to zeros (not NaN)."""
        lo, hi = x.min(), x.max()
        span = hi - lo
        if span == 0:
            return torch.zeros_like(x)
        return (x - lo) / span

    def random_jitter(self, input_image, target_image):
        """With probability 0.5, horizontally flip input and target together."""
        if torch.rand([]) < 0.5:
            input_image = transforms.functional.hflip(input_image)
            target_image = transforms.functional.hflip(target_image)

        return input_image, target_image

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_list[idx])
        # .convert() is a PIL method (the old code called it on a NumPy array);
        # the context manager closes the file handle after decoding.
        with Image.open(img_path) as img:
            image = transforms.ToTensor()(img.convert('RGB'))
        input_image = self._resize(image, self.low_res_size).type(torch.float32)
        target_image = self._resize(image, self.high_res_size).type(torch.float32)
        input_image, target_image = self.normalize(input_image), self.normalize(target_image)
        if self.is_train:
            input_image, target_image = self.random_jitter(input_image, target_image)

        return input_image, target_image