In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
import torchsummary
from torch_snippets import *

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import os

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG the notebook draws from. torch.manual_seed seeds all devices
# (CPU and CUDA), so a separate torch.cuda.manual_seed call is unnecessary.
# NumPy is seeded as well because later cells sample with np.random.
SEED = 36
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# ---- Hyper-parameters / configuration ----

# Side length (pixels) that Image_Dataset resizes each half-image to.
IMAGE_SIZE = 128
# Presumably the optimizer learning rate (not read in this part of the notebook).
LR = 0.0001
# Not read in this part of the notebook; TODO confirm where it is meant to apply.
BIAS = False
# Presumably the DataLoader batch size (not read in this part of the notebook).
BATCH_SIZE = 2

# NOTE(review): the dataset cells below read from "maps/maps/train/" instead
# of these paths -- confirm which directory layout is intended.
path_2_train = "Pix2Pix/maps/train"
path_2_val = "Pix2Pix/maps/val"

In [None]:
def show_numpy(img, cmap="RdYlBu"):
    """Display an image with matplotlib, axes hidden.

    Accepted layouts:
      * channels-last ``(H, W, C)`` arrays, e.g. ``np.array(pil_image)`` -- shown as-is;
      * channels-first ``(C, H, W)`` arrays/tensors with ``C`` in {1, 3}, e.g. the
        output of ``transforms.ToTensor()`` -- converted to ``(H, W, C)``;
      * 2-D ``(H, W)`` arrays.
    A 3-D input whose first AND last axes both look like channels is treated as
    channels-last.

    Args:
        img: image data -- a NumPy array or anything ``np.asarray`` accepts; torch
            tensors are detached and moved to the CPU first.
        cmap: colormap; matplotlib only applies it to single-channel images and
            ignores it for RGB/RGBA data.
    """
    if hasattr(img, "detach"):  # torch tensor: drop autograd history, copy to host
        img = img.detach().cpu().numpy()
    arr = np.asarray(img)

    # CHW -> HWC. (Transposing an HWC array with axes (1, 0, 2) would swap rows and
    # columns and show the picture mirrored along its diagonal.)
    if arr.ndim == 3 and arr.shape[0] in (1, 3) and arr.shape[-1] not in (1, 3, 4):
        arr = np.transpose(arr, (1, 2, 0))
    # imshow wants single-channel data as 2-D so the colormap is applied.
    if arr.ndim == 3 and arr.shape[-1] == 1:
        arr = arr[..., 0]

    plt.imshow(arr, cmap=cmap)
    plt.axis("off")


In [None]:
import PIL.Image as Image

# Preview one training file. Each file stores its two images side by side, so
# split at half the width (presumably 600 px for the maps data).
# `with` closes the file handle once the pixels are copied into the array.
with Image.open("maps/maps/train/4.jpg") as pair_image:
    x = np.array(pair_image)
half_width = x.shape[1] // 2
first_half = x[:, :half_width, :]
second_half = x[:, half_width:, :]
show_numpy(first_half)

In [None]:
# Training images are numbered from 1 (1.jpg, 2.jpg, ...). randint's upper bound
# is exclusive, so this draws an index in 1..1000 and can never produce "0.jpg".
random_index = int(np.random.randint(1, 1001))
# Path to ONE randomly chosen training file (despite the variable name).
all_dataset = os.path.join("maps/maps/train/", f"{random_index}.jpg")

all_dataset

In [None]:
import torchvision.transforms as transforms

class Image_Dataset(Dataset):
    """Dataset of side-by-side image pairs stored as JPEG files in one folder.

    Every file holds two images concatenated horizontally; the split point is
    half the file's width. ``dataset[i]`` returns ``(label, target, image)``:

    * ``label``  -- right half, run through ``ToTensor`` and resized to
      ``(image_size, image_size)``;
    * ``target`` -- left half, same processing;
    * ``image``  -- the whole file as an RGB NumPy array, not resized.

    Files whose stem is an integer are ordered numerically (``1.jpg, 2.jpg,
    ..., 10.jpg``), so for a folder numbered ``1..N`` item ``i`` is
    ``f"{i + 1}.jpg"``; any other ``.jpg``/``.jpeg`` names follow, sorted by name.

    Args:
        path: directory that contains the image files.
        image_size: output side length for both halves; ``None`` (default)
            uses the notebook-level ``IMAGE_SIZE`` constant.
    """

    def __init__(self, path, image_size=None):
        self.path_2_images = os.fspath(path)
        size = IMAGE_SIZE if image_size is None else image_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((size, size)),
        ])
        # Index into the real listing instead of assuming every number 1..N
        # exists: non-image entries are ignored, and reading past the end raises
        # IndexError (which cleanly stops a `for` loop) rather than FileNotFoundError.
        self.files = sorted(
            (name for name in os.listdir(self.path_2_images)
             if name.lower().endswith((".jpg", ".jpeg"))),
            key=self._sort_key,
        )
        # Backward-compatible alias for the attribute __len__ used to read.
        self.len = self.files

    @staticmethod
    def _sort_key(name):
        """Sort integer-named files numerically first, then all others by name."""
        stem = os.path.splitext(name)[0]
        return (0, int(stem), "") if stem.isdecimal() else (1, 0, name)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        full_image = os.path.join(self.path_2_images, self.files[index])
        # `with` closes the file handle; convert("RGB") guarantees an (H, W, 3)
        # array so the channel slicing below also works for grayscale/palette files.
        with Image.open(full_image) as pil_image:
            image = np.array(pil_image.convert("RGB"))

        half = image.shape[1] // 2
        label = self.transform(image[:, half:, :])
        target = self.transform(image[:, :half, :])

        return label, target, image

In [None]:
data_ds = Image_Dataset("maps/maps/train/")

# Peek at the first sample (if any) to check the tensor / array shapes.
first_sample = next(iter(data_ds), None)
if first_sample is not None:
    x, y, z = first_sample
    print(x.shape, y.shape, z.shape)

In [None]:
class ResConv(nn.Module):
    """Residual block: two (3x3 conv -> BatchNorm -> ReLU) stages plus a shortcut.

    The shortcut is ``nn.Identity`` when ``in_ch == out_ch`` and a 1x1 convolution
    otherwise, so it always matches the main branch's channel count. Both 3x3
    convolutions use ``padding=1``, so the spatial size is preserved.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.

    ``forward(x)`` returns a tuple ``(out, residue)``: ``out`` is the block output
    (``out_ch`` channels) and ``residue`` is the unmodified input ``x``.
    """

    def __init__(self, in_ch, out_ch):
        super(ResConv, self).__init__()
        self.conv_1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv_2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.bn_1 = nn.BatchNorm2d(out_ch)
        self.bn_2 = nn.BatchNorm2d(out_ch)
        if in_ch != out_ch:
            self.res = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        else:
            self.res = nn.Identity()

    def forward(self, x):
        residue = x
        out = F.relu(self.bn_1(self.conv_1(x)))
        out = F.relu(self.bn_2(self.conv_2(out)))

        # Out-of-place add: F.relu saves its output for the backward pass, so an
        # in-place `out += ...` here makes .backward() raise
        # "modified by an inplace operation".
        out = out + self.res(x)

        return out, residue  # Return the residual connection


class DownConv(nn.Module):
    """Encoder stage: ResConv, 2x max-pool, then 3x3 conv -> ReLU -> BatchNorm.

    Args:
        in_ch: channels of the incoming tensor.
        out_ch: channels produced by the ResConv and kept by the following conv.

    ``forward(x)`` returns ``(features, residue)``: ``features`` is the processed
    tensor after ``MaxPool2d(2)`` (spatially downsampled 2x), and ``residue`` is
    the second value returned by the inner ``ResConv`` (in this notebook, that
    ResConv's input).
    """

    def __init__(self, in_ch, out_ch):
        super(DownConv, self).__init__()
        # Registration order is kept so parameters are created in the same order.
        self.res = ResConv(in_ch, out_ch)
        self.downsample = nn.MaxPool2d(2)
        self.conv = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        features, skip = self.res(x)
        pooled = self.downsample(features)
        activated = self.relu(self.conv(pooled))
        return self.bn(activated), skip


class UpConv(nn.Module):
    """Decoder stage: ResConv, 2x transposed-conv upsampling, concat with a skip.

    Args:
        in_ch: channels of the incoming decoder tensor (kept by the ResConv).
        out_ch: channels produced by the transposed convolution.
        batch_norm_features: channels after concatenation, i.e. ``out_ch`` plus
            the skip tensor's channel count; BatchNorm2d needs this exact value.

    ``forward(x, x_skip)`` returns ``(merged, residue)``. ``merged`` is
    ``BatchNorm(ReLU(cat([up(x), x_skip])))``; ``ConvTranspose2d(kernel 4,
    stride 2, padding 1)`` exactly doubles height and width, so ``x_skip`` must
    be twice ``x``'s spatial size. ``residue`` is the inner ResConv's second
    return value.
    """

    def __init__(self, in_ch, out_ch, batch_norm_features):
        super(UpConv, self).__init__()
        self.res = ResConv(in_ch, in_ch)
        self.upsample = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(batch_norm_features)

    def forward(self, x, x_skip):
        refined, residue = self.res(x)
        # Channel-wise concatenation with the skip connection.
        merged = torch.cat((self.upsample(refined), x_skip), dim=1)
        return self.bn(self.relu(merged)), residue

class Attention(nn.Module):
    """Additive attention gate that produces a one-channel weight map for a skip.

    ``forward(x_skip, x_lower)``:
      1. projects ``x_skip`` with a 1x1, stride-2 conv (``in_ch_skip`` ->
         ``in_ch_skip // 2`` channels);
      2. projects ``x_lower`` with a 1x1 conv (``in_ch_skip // 2`` channels in
         and out);
      3. adds the two, applies ReLU, a 1x1 conv down to one channel and a sigmoid;
      4. upsamples the map 2x bilinearly.

    The returned map has values in [0, 1] and is NOT applied to ``x_skip`` here;
    callers multiply it in. For step 3's elementwise sum, ``x_lower`` must have
    ``in_ch_skip // 2`` channels and the spatial size of the stride-2 projection
    of ``x_skip``.

    Args:
        in_ch_skip: channel count of the skip tensor.
    """

    def __init__(self, in_ch_skip):
        super(Attention, self).__init__()
        self.in_ch_skip = in_ch_skip
        self.in_ch_lower = in_ch_skip // 2

        self.conv_x = nn.Conv2d(self.in_ch_skip, self.in_ch_lower, kernel_size=1, stride=2)
        self.conv_g = nn.Conv2d(self.in_ch_lower, self.in_ch_lower, kernel_size=1)
        self.conv_1 = nn.Conv2d(self.in_ch_lower, 1, kernel_size=1)

        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear")
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x_skip, x_lower):
        combined = self.conv_x(x_skip) + self.conv_g(x_lower)
        logits = self.conv_1(self.relu(combined))
        return self.upsample(torch.sigmoid(logits))




Now we build the U-Net model, with attention gates on its skip connections:
[]-> [][] ----------------------------------------------->O->[][] -> []- final output # 128
        ->[][] ----------------------------->O->[][]->[]--^----^                      # 64
              ->[][]-------------->O->[][]->[]^---^                                   # 32
                    ------->16[][]-^--------^                                         # 16

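[] -> ResConv block
O  -> attention gate (the `Attention` block: weights a skip connection using the lower-resolution decoder features)
image -> 128-64-32-16-32-64-128 (spatial resolution at each stage) -> new image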

num_block = 2+2+2+2+3+3 = 8+6 = 14 (± 1, still to be decided)


In [None]:
class UNet(nn.Module):
    """Attention U-Net that maps a 3-channel image to a 3-channel image.

    Shapes for a (B, 3, H, W) input. H and W must be divisible by 8 because there
    are three 2x down/up-sampling stages.

        stem         ResConv(3, 32)      -> (B, 32,  H,   W)
        conv_down_1  DownConv(32, 64)    -> (B, 64,  H/2, W/2), skip_1 = (B, 32,  H,   W)
        conv_down_2  DownConv(64, 128)   -> (B, 128, H/4, W/4), skip_2 = (B, 64,  H/2, W/2)
        conv_down_3  DownConv(128, 256)  -> (B, 256, H/8, W/8), skip_3 = (B, 128, H/4, W/4)
        bottleneck   ResConv(256, 256)   -> (B, 256, H/8, W/8)
        level 1      gate_1/att_1/conv_up_1 with skip_3 -> (B, 256, H/4, W/4)
        level 2      gate_2/att_2/conv_up_2 with skip_2 -> (B, 128, H/2, W/2)
        level 3      gate_3/att_3/conv_up_3 with skip_1 -> (B, 64,  H,   W)
        out_proj     Conv2d(64, 3, 3x3)  -> (B, 3,   H,   W)

    The skips are the second values returned by each DownConv, i.e. the input
    of its ResConv. At each decoder level:
      * ``gate_k`` (1x1 conv) projects the decoder tensor to ``skip_channels // 2``
        channels, the gating width ``Attention`` expects;
      * ``att_k`` turns (skip, projection) into a one-channel weight map;
      * the skip is multiplied by that map;
      * ``conv_up_k`` upsamples the decoder tensor and concatenates the
        weighted skip.
    """

    def __init__(self):
        super().__init__()
        # Not read by any layer below; kept so existing references keep working.
        self.in_ch = 3
        self.out_ch = 64

        # Encoder. Without the stem, the full-resolution skip would be the raw
        # RGB input (DownConv hands back its ResConv's input as the skip).
        self.stem = ResConv(3, 32)
        self.conv_down_1 = DownConv(32, 64)
        self.conv_down_2 = DownConv(64, 128)
        self.conv_down_3 = DownConv(128, 256)

        self.bottleneck = ResConv(256, 256)

        # Decoder level 1: H/8 -> H/4, skip_3 has 128 channels.
        self.gate_1 = nn.Conv2d(256, 64, kernel_size=1)
        self.att_1 = Attention(in_ch_skip=128)
        self.conv_up_1 = UpConv(256, 128, 128 + 128)

        # Decoder level 2: H/4 -> H/2, skip_2 has 64 channels.
        self.gate_2 = nn.Conv2d(256, 32, kernel_size=1)
        self.att_2 = Attention(in_ch_skip=64)
        self.conv_up_2 = UpConv(256, 64, 64 + 64)

        # Decoder level 3: H/2 -> H, skip_1 has 32 channels.
        self.gate_3 = nn.Conv2d(128, 16, kernel_size=1)
        self.att_3 = Attention(in_ch_skip=32)
        self.conv_up_3 = UpConv(128, 32, 32 + 32)

        self.out_proj = nn.Conv2d(64, 3, kernel_size=3, padding=1)

    @staticmethod
    def _decode(x, skip, gate, att, up):
        """One decoder level: weight the skip by its attention map, then upsample + concat."""
        weights = att(skip, gate(x))  # (B, 1, H_skip, W_skip)
        x, _ = up(x, skip * weights)
        return x

    def forward(self, x):
        x, _ = self.stem(x)

        x, skip_1 = self.conv_down_1(x)
        x, skip_2 = self.conv_down_2(x)
        x, skip_3 = self.conv_down_3(x)

        x, _ = self.bottleneck(x)

        x = self._decode(x, skip_3, self.gate_1, self.att_1, self.conv_up_1)
        x = self._decode(x, skip_2, self.gate_2, self.att_2, self.conv_up_2)
        x = self._decode(x, skip_1, self.gate_3, self.att_3, self.conv_up_3)

        return self.out_proj(x)

In [None]:
model = UNet()

# Total number of scalar weights across every parameter tensor of the model.
total_params = sum(p.numel() for p in model.parameters())
total_params

In [None]:
# Smoke test: one forward pass on a random batch of one 3-channel image at the
# resolution Image_Dataset produces (IMAGE_SIZE x IMAGE_SIZE).
sample_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)
# no_grad skips building an autograd graph for this check. The model is still in
# training mode, so BatchNorm running statistics are updated by this call.
with torch.no_grad():
    sample_output = model(sample_input)
# The generator should return an image with the same shape as its input.
assert sample_output.shape == sample_input.shape, tuple(sample_output.shape)
sample_output.shape