In [1]:
# import lib
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math
import tqdm
from tqdm.auto import trange, tqdm

# import pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# to get CIFAR10 dataset
from torchvision import transforms
import torchvision
import torchvision.transforms as transforms

# to import pretrained models
from transformers import AutoImageProcessor, MobileNetV1Model
import timm

# import sklearn
from sklearn.model_selection import train_test_split

# Select the compute device once: the GPU when CUDA is usable, the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


---
## Encoder layer

In [2]:
from get_layers import get_encoder_layers

# Fetch the three encoder-side pieces returned by the project helper:
# the list of encoder blocks, the image stem layer and the image processor.
encoder_blocks, image_stem_layer, image_processor = get_encoder_layers()

# Summarise what came back, one "caption value" line per piece.
encoder_summary = (
    ("Encoder blocks len: ", len(encoder_blocks)),
    ("Image stem layer: ", image_stem_layer),
    ("Image processor: ", image_processor),
)
for caption, item in encoder_summary:
    print(caption, item)

Encoder blocks len:  5
Image stem layer:  MobileNetV1ConvLayer(
  (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
  (normalization): BatchNorm2d(32, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
  (activation): ReLU6()
)
Image processor:  MobileNetV1ImageProcessor {
  "_valid_processor_keys": [
    "images",
    "do_resize",
    "size",
    "resample",
    "do_center_crop",
    "crop_size",
    "do_rescale",
    "rescale_factor",
    "do_normalize",
    "image_mean",
    "image_std",
    "return_tensors",
    "data_format",
    "input_data_format"
  ],
  "crop_size": {
    "height": 224,
    "width": 224
  },
  "do_center_crop": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "MobileNetV1ImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge"

---
## Decoder layer

In [3]:
from get_layers import DecoderBlock

# Build a single stand-alone decoder block just to inspect its layer layout.
# The two positional arguments show up in the repr as the input channels of the
# first conv (512) and the output channels of the last conv (256).
dec90 = DecoderBlock(512, 256)

# Bare last expression: the notebook renders the module repr as the cell output.
dec90

Alert: skip connection is not implemented in the decoder block


DecoderBlock(
  (relu): ReLU(inplace=True)
  (cnn1): Conv2d(512, 1536, kernel_size=(1, 1), stride=(1, 1))
  (bnn1): BatchNorm2d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (upsample): Upsample(scale_factor=2.0, mode='nearest')
  (cnn2): Conv2d(1536, 1536, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bnn2): BatchNorm2d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (cnn3): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1))
  (bnn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [4]:
# One decoder block per U-Net stage, built deepest-first (stage 4 down to stage 0).
# Each pair is (input channels, output channels). From stage 3 onward the input is
# the previous decoder output concatenated with an encoder feature map, which is
# why those input widths are written as sums.
dec_4_block, dec_3_block, dec_2_block, dec_1_block, dec_0_block = (
    DecoderBlock(in_ch, out_ch)
    for in_ch, out_ch in (
        (1024, 512),
        (512 + 512, 256),
        (256 + 256, 128),
        (128 + 128, 64),
        (64 + 64, 32),
    )
)

Alert: skip connection is not implemented in the decoder block
Alert: skip connection is not implemented in the decoder block
Alert: skip connection is not implemented in the decoder block
Alert: skip connection is not implemented in the decoder block
Alert: skip connection is not implemented in the decoder block


----
## Unet dry run

In [6]:
# Shape-only dry run of the encoder/decoder stack on one random 224x224 RGB image.
input_rand_image = torch.rand((1, 3, 224, 224))

# Only tensor shapes are inspected here, so skip building the autograd graph.
# NOTE(review): no_grad does not change train/eval mode; BatchNorm layers that are
# in training mode still update their running statistics during this pass.
# Confirm the pretrained encoder is in eval mode before relying on its statistics.
with torch.no_grad():
    output = image_stem_layer(input_rand_image)
    print("Image stem layer output: ", output.shape)

    # Keep every encoder activation: they are the skip connections for the decoder.
    enc_outputs = []
    for indx, enc_block in enumerate(encoder_blocks):
        output = enc_block(output)
        enc_outputs.append(output)
        print(f"Encoder block {indx} | output shape: {output.shape}")

    # The deepest encoder activation is the decoder's starting point.
    print("Last encoder block output shape: ", output.shape)

    # The deepest decoder stage has no skip connection.
    output = dec_4_block(output)
    print("Decoder block 4 | output shape: ", output.shape)

    # Every remaining stage consumes the previous decoder output concatenated
    # channel-wise (dim=1) with the encoder activation of the same stage index.
    skip_stages = (
        (3, dec_3_block),
        (2, dec_2_block),
        (1, dec_1_block),
        (0, dec_0_block),
    )
    for stage, dec_block in skip_stages:
        output = dec_block(torch.cat([output, enc_outputs[stage]], dim=1))
        print(f"Decoder block {stage} | output shape: ", output.shape)

# Sanity check: the decoder should bring the feature map back to the input resolution.
assert output.shape[-2:] == input_rand_image.shape[-2:], (
    f"decoder output {tuple(output.shape[-2:])} does not match "
    f"input resolution {tuple(input_rand_image.shape[-2:])}"
)


Image stem layer output:  torch.Size([1, 32, 112, 112])
Encoder block 0 | output shape: torch.Size([1, 64, 112, 112])
Encoder block 1 | output shape: torch.Size([1, 128, 56, 56])
Encoder block 2 | output shape: torch.Size([1, 256, 28, 28])
Encoder block 3 | output shape: torch.Size([1, 512, 14, 14])
Encoder block 4 | output shape: torch.Size([1, 1024, 7, 7])
Last encoder block output shape:  torch.Size([1, 1024, 7, 7])
