In [None]:
# import lib
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math
import tqdm
# from tqdm.auto import trange, tqdm
import os

# import pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# to get CIFAR10 dataset
from torchvision import transforms
import torchvision
import torchvision.transforms as transforms

# to import pretrained models
from transformers import AutoImageProcessor, MobileNetV1Model
import timm

# import sklearn
from sklearn.model_selection import train_test_split
from PIL import Image


# Select the compute device once: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

---
## Encoder layer

In [None]:
from get_layers import get_encoder_layers

# Fetch the encoder-side pieces in a single call: the encoder blocks, the
# stem layer applied before them, and the image processor.
encoder_blocks, image_stem_layer, image_processor = get_encoder_layers()

# Short report of what was returned, one labelled line per item.
for label, value in (
    ("Encoder blocks len: ", len(encoder_blocks)),
    ("Image stem layer: ", image_stem_layer),
    ("Image processor: ", image_processor),
):
    print(label, value)

---
## Decoder layer

In [None]:
from get_layers import get_get_decoder_layers

# NOTE: despite its name, `get_decoder_layers` is not a function; it holds the
# value returned by get_get_decoder_layers(), i.e. the collection of decoder
# blocks that later cells iterate over.
get_decoder_layers = get_get_decoder_layers()

# Last expression of the cell: displays how many decoder blocks were returned.
len(get_decoder_layers)

In [None]:
# Report the channel flow of every decoder block: the input channels of its
# first conv (cnn1) and the output channels of its last conv (cnn3).
for i, block in enumerate(get_decoder_layers):
    summary = f"Decoder block {i}: {block.cnn1.in_channels} -> {block.cnn3.out_channels}"
    print(summary)

---
## Unet architecture

In [None]:
class Unet(nn.Module):
    '''
    U-Net style model assembled from externally provided layers.

    The encoder side (encoder blocks, stem layer, image processor) is obtained
    from ``get_encoder_layers()`` and the decoder blocks from
    ``get_get_decoder_layers()``. Both are looked up as module-level names, so
    they must already be imported when the class is instantiated.
    '''

    def __init__(self):
        '''
        Build the model by fetching the encoder and decoder layers.
        '''
        super(Unet, self).__init__()
        # get encoder layers
        encoder_blocks, image_stem_layer, image_processor = get_encoder_layers()
        # get decoder layers
        decoder_layers = get_get_decoder_layers()

        # A plain Python list of modules is invisible to nn.Module: its
        # parameters would be skipped by .parameters(), .to(), .eval() and
        # state_dict(). Register such lists explicitly.
        self.encoder_blocks = self._as_module_list(encoder_blocks)
        self.image_stem_layer = image_stem_layer
        self.image_processor = image_processor

        # Attribute name kept unchanged (it holds the decoder blocks themselves,
        # not a getter) so existing code that reads it keeps working.
        self.get_decoder_layers = self._as_module_list(decoder_layers)

    @staticmethod
    def _as_module_list(layers):
        '''
        Wrap a list/tuple made only of nn.Module objects in nn.ModuleList;
        return anything else (e.g. an existing nn.Module container) unchanged.
        '''
        if isinstance(layers, (list, tuple)) and all(isinstance(layer, nn.Module) for layer in layers):
            return nn.ModuleList(layers)
        return layers

    def _preprocess(self, x):
        '''
        Run every image of the batch through self.image_processor and stack the
        results into one (batch, channels, height, width) tensor on x's device.
        '''
        target_device = x.device
        # torch.as_tensor accepts both tensors and NumPy arrays, so the stack
        # below works whichever of the two the processor returns.
        processed = [
            torch.as_tensor(self.image_processor(img)['pixel_values'][0])
            for img in x
        ]
        x = torch.stack(processed)
        # Convert to channels-first only when the processor produced
        # channel-last images; channels-first output must not be permuted.
        # NOTE(review): assumes 3-channel images - confirm for the processor
        # returned by get_encoder_layers().
        if x.dim() == 4 and x.shape[1] != 3 and x.shape[-1] == 3:
            x = x.permute(0, 3, 1, 2)
        return x.to(target_device)

    def forward(self, x, process_image = False):
        '''
        Forward pass.

        x: input tensor; must be (batch_size, channels, 224, 224) by the time
           it reaches the stem layer (ideally (batch_size, 3, 224, 224)).
        process_image: when True, each image in x is first passed through
           self.image_processor to bring it to the appropriate size.

        Returns the output of the last decoder block.
        Raises AssertionError if x is not a tensor or is not a 4-D batch of
        224x224 images after the optional preprocessing.
        '''
        # assert x is a tensor
        assert isinstance(x, torch.Tensor), "Input should be a tensor"
        if process_image:
            x = self._preprocess(x)

        # assertion to check that the image batch has the right size
        assert x.dim() == 4 and x.shape[2] == 224 and x.shape[3] == 224, "Image size should be 224x224"

        # Use the stem layer owned by this instance, not a notebook global.
        x = self.image_stem_layer(x)

        # Run every encoder block, keeping each output for the skip connections.
        enc_outputs = []
        for enc_block in self.encoder_blocks:
            x = enc_block(x)
            enc_outputs.append(x)

        # The first decoder block consumes the last encoder output directly.
        # Every later block i receives, concatenated along the channel axis,
        # the running tensor and enc_outputs[num_decoders - i - 1].
        # NOTE(review): this indexing presumably expects as many encoder blocks
        # as decoder blocks, with matching spatial sizes - confirm.
        num_decoders = len(self.get_decoder_layers)
        for indx, dec_block in enumerate(self.get_decoder_layers):
            if indx > 0:
                skip = enc_outputs[num_decoders - indx - 1]
                x = torch.cat([x, skip], dim=1)
            x = dec_block(x)

        return x

In [None]:
# Smoke test: build the model and check the output shape for a dummy batch.
unet = Unet()
# This is an inference-only check, so switch to eval mode: layers such as
# BatchNorm/Dropout (if the model contains any) then behave deterministically
# and no running statistics are updated by the random input.
unet.eval()

# random input
x = torch.randn(1, 3, 224, 224)

# forward pass without building an autograd graph (nothing is trained here)
with torch.no_grad():
    y = unet(x)

# print output shape
print(y.shape)

In [None]:
# Build a fresh Unet instance for running on a real image.
unet = Unet()

# Send the input image through the model; process_image=True asks forward() to
# run each image through the model's image processor first.
# NOTE(review): `input_img` is not defined in any cell visible in this notebook.
# It must exist before this cell runs, and forward() asserts it is a
# torch.Tensor - TODO confirm where it is created.
output_img = unet(input_img, process_image = True)

In [None]:
# # Load an image with path
# image_path = "..\\test\\input\\IMG_0106.png"

# # Load image
# image = Image.open(image_path)

# # Show original image
# plt.imshow(image)
# plt.title("Original Image")
# plt.axis("off")
# plt.show()

# # Preprocess image
# preprocessed_image = image_processor(image)['pixel_values'][0]

# # Transpose the NumPy array
# # preprocessed_image = preprocessed_image.transpose((1, 2, 0))

# print("Preprocessed image shape: ", preprocessed_image.shape)

# # Show preprocessed image
# plt.imshow(preprocessed_image.transpose((1, 2, 0)))
# plt.title("Preprocessed Image")
# plt.axis("off")
# plt.show()