In [53]:
# import lib
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math
import tqdm
from tqdm.auto import trange, tqdm
import os

# import pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# to get CFIAR10 dataset
from torchvision import transforms
import torchvision
import torchvision.transforms as transforms

# to import pretrained models
from transformers import AutoImageProcessor, MobileNetV1Model
import timm

# import sklearn
from sklearn.model_selection import train_test_split
from PIL import Image


# set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

---
## Encoder layer

In [2]:
from get_layers import get_encoder_layers

# get encoder layers
encoder_blocks, image_stem_layer, image_processor = get_encoder_layers()

print("Encoder blocks len: ", len(encoder_blocks))
print("Image stem layer: ", image_stem_layer)
print("Image processor: ", image_processor)

Encoder blocks len:  5
Image stem layer:  MobileNetV1ConvLayer(
  (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
  (normalization): BatchNorm2d(32, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
  (activation): ReLU6()
)
Image processor:  MobileNetV1ImageProcessor {
  "_valid_processor_keys": [
    "images",
    "do_resize",
    "size",
    "resample",
    "do_center_crop",
    "crop_size",
    "do_rescale",
    "rescale_factor",
    "do_normalize",
    "image_mean",
    "image_std",
    "return_tensors",
    "data_format",
    "input_data_format"
  ],
  "crop_size": {
    "height": 224,
    "width": 224
  },
  "do_center_crop": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "MobileNetV1ImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge"

---
## Decoder layer

In [3]:
from get_layers import DecoderBlock

dec90 = DecoderBlock(512, 256)

dec90

Alert: skip connection is not implemented in the decoder block


DecoderBlock(
  (relu): ReLU(inplace=True)
  (cnn1): Conv2d(512, 1536, kernel_size=(1, 1), stride=(1, 1))
  (bnn1): BatchNorm2d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (upsample): Upsample(scale_factor=2.0, mode='nearest')
  (cnn2): Conv2d(1536, 1536, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bnn2): BatchNorm2d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (cnn3): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1))
  (bnn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [16]:
def get_decoder_blocks(out_sizes = [512, 256, 128, 64, 32]):
    decoder_blocks = []
    for i, out_size in enumerate(out_sizes):
        if i == 0:
            decoder_blocks.append(DecoderBlock(out_size*2 , out_size))
        else:
            decoder_blocks.append(DecoderBlock(out_size*4 , out_size))
    return decoder_blocks

decoder_blocks = get_decoder_blocks()

Alert: skip connection is not implemented in the decoder block
Alert: skip connection is not implemented in the decoder block
Alert: skip connection is not implemented in the decoder block
Alert: skip connection is not implemented in the decoder block
Alert: skip connection is not implemented in the decoder block


In [19]:
# print in and out channels of each decoder block
for i, block in enumerate(decoder_blocks):
    print(f"Decoder block {i}: {block.cnn1.in_channels} -> {block.cnn3.out_channels}")

Decoder block 0: 1024 -> 512
Decoder block 1: 1024 -> 256
Decoder block 2: 512 -> 128
Decoder block 3: 256 -> 64
Decoder block 4: 128 -> 32


---
## Unet architecture

In [57]:
class Unet(nn.Module):
    def __init__(self):
        '''
	        Unet model
        '''
        super(Unet, self).__init__()
        # get encoder layers
        encoder_blocks, image_stem_layer, image_processor = get_encoder_layers()
        # get decoder layers
        decoder_blocks = get_decoder_blocks()
        
        self.encoder_blocks = encoder_blocks
        self.image_stem_layer = image_stem_layer
        self.image_processor = image_processor
        
        self.decoder_blocks = decoder_blocks
        
    def forward(self, x, process_image = False):
        '''
	        Forward pass
	        x: input image (ideally should be of shape (batch_size, 3, 224, 224))
			process_image: whether to process image or not to the appropriate size
        '''
        # asset x is a tensor
        assert isinstance(x, torch.Tensor), "Input should be a tensor"
        if process_image:
            new_x = []
            for img in x:
                img = self.image_processor(img)['pixel_values'][0]
                new_x.append(img)
            x = torch.stack(new_x)
            x = x.permute(0, 3, 1, 2)
            
        # assetion to check if the image is of the right size
        assert x.shape[2] == 224 and x.shape[3] == 224, "Image size should be 224x224"
        
        x = image_stem_layer(x)
        print("Image stem layer output shape: ", x.shape)
        
        enc_outputs = []
        # run on each encoder block
        for indx, enc_block in enumerate(self.encoder_blocks):
            x = enc_block(x)
            enc_outputs.append(x)
            # print(f"Encoder block {indx} | output shape: {x.shape}")
            
        # last encoder block to dec 
        # print input and output shapes
        # print("\nLast encoder block output shape: ", x.shape)
        # print()
        
        for indx, dec_block in enumerate(self.decoder_blocks):
            if indx == 0:
                x = dec_block(x)
                # print(f"Decoder block {indx} | output shape: {x.shape}")
            else:
                x = dec_block(torch.cat([x, enc_outputs[len(self.decoder_blocks) - indx - 1]], dim=1))
                # print(f"Decoder block {indx} | output shape: {x.shape}")
        
        return x

In [58]:
unet = Unet()

# random input
x = torch.randn(1, 3, 224, 224)

# forward pass
y = unet(x)

# print output shape
print(y.shape)

Alert: skip connection is not implemented in the decoder block
Alert: skip connection is not implemented in the decoder block
Alert: skip connection is not implemented in the decoder block
Alert: skip connection is not implemented in the decoder block
Alert: skip connection is not implemented in the decoder block
Image stem layer output shape:  torch.Size([1, 32, 112, 112])
torch.Size([1, 32, 224, 224])


In [59]:
def load_img_to_tensor(img_dir):
    '''
        Load image to tensor
        img_dir: image directory
    '''
    # load all the images paths in the directory
    img_paths = [img_dir + img for img in os.listdir(img_dir)]
    
    images = []
    # load images to tensor
    for img_path in img_paths:
        img = Image.open(img_path)
        img = img.resize((224, 224))
        img = np.array(img)
        img = torch.tensor(img)
        images.append(img)
        
    return torch.stack(images)

input_img = load_img_to_tensor("..\\test\\input\\")

print(input_img.shape)

torch.Size([7, 224, 224, 3])


In [60]:
# now load the model
unet = Unet()

# send input image to the model
output_img = unet(input_img, process_image = True)

Alert: skip connection is not implemented in the decoder block
Alert: skip connection is not implemented in the decoder block
Alert: skip connection is not implemented in the decoder block
Alert: skip connection is not implemented in the decoder block
Alert: skip connection is not implemented in the decoder block


TypeError: expected Tensor as element 0 in argument 0, but got numpy.ndarray

In [56]:
# # Load an image with path
# image_path = "..\\test\\input\\IMG_0106.png"

# # Load image
# image = Image.open(image_path)

# # Show original image
# plt.imshow(image)
# plt.title("Original Image")
# plt.axis("off")
# plt.show()

# # Preprocess image
# preprocessed_image = image_processor(image)['pixel_values'][0]

# # Transpose the NumPy array
# # preprocessed_image = preprocessed_image.transpose((1, 2, 0))

# print("Preprocessed image shape: ", preprocessed_image.shape)

# # Show preprocessed image
# plt.imshow(preprocessed_image.transpose((1, 2, 0)))
# plt.title("Preprocessed Image")
# plt.axis("off")
# plt.show()