# IONetv2 model

---

In [1]:
# import lib
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math
import tqdm
from tqdm.auto import trange, tqdm

# import pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# to get CFIAR10 dataset
from torchvision import transforms
import torchvision
import torchvision.transforms as transforms

# to import pretrained models
from transformers import AutoImageProcessor, MobileNetV1Model
import timm

# import sklearn
from sklearn.model_selection import train_test_split

# set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# import fb decoder blocks
from get_layersv2 import DecoderBlock

---
## Misc

In [3]:
def cvtImg(img):
    img = img.permute([0, 2, 3, 1])
    img = img - img.min()
    img = (img / img.max())
    return img.numpy().astype(np.float32)


def show_examples(x):
    plt.figure(figsize=(10, 10))
    imgs = cvtImg(x)
    for i in range(25):
        plt.subplot(5, 5, i+1)
        plt.imshow(imgs[i])
        plt.axis('off')

---
## Dino ResNet50 model

In [4]:
from transformers import AutoFeatureExtractor, ResNetModel
from PIL import Image

feature_extractor = AutoFeatureExtractor.from_pretrained('Ramos-Ramos/dino-resnet-50')
model = ResNetModel.from_pretrained('Ramos-Ramos/dino-resnet-50')



In [5]:
model.embedder

ResNetEmbeddings(
  (embedder): ResNetConvLayer(
    (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
  )
  (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)

In [6]:
model.encoder.stages[2]

ResNetStage(
  (layers): Sequential(
    (0): ResNetBottleNeckLayer(
      (shortcut): ResNetShortCut(
        (convolution): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (layer): Sequential(
        (0): ResNetConvLayer(
          (convolution): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activation): ReLU()
        )
        (1): ResNetConvLayer(
          (convolution): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activation): ReLU()
        )
        (2): ResNetConvLayer(
          (convolution): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=Fa

---
## MobileNetV2: Inverted Residuals and Linear Bottlenecks

In [7]:
from transformers import MobileNetV2Config, MobileNetV2Model

# Initializing a "mobilenet_v2_1.0_224" style configuration
configuration = MobileNetV2Config()

# Initializing a model from the "mobilenet_v2_1.0_224" style configuration
model = MobileNetV2Model(configuration)

model

MobileNetV2Model(
  (conv_stem): MobileNetV2Stem(
    (first_conv): MobileNetV2ConvLayer(
      (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      (activation): ReLU6()
    )
    (conv_3x3): MobileNetV2ConvLayer(
      (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
      (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      (activation): ReLU6()
    )
    (reduce_1x1): MobileNetV2ConvLayer(
      (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
    )
  )
  (layer): ModuleList(
    (0): MobileNetV2InvertedResidual(
      (expand_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias

In [8]:
# get first layer
first_layer = model.conv_stem

print(first_layer)

MobileNetV2Stem(
  (first_conv): MobileNetV2ConvLayer(
    (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
    (activation): ReLU6()
  )
  (conv_3x3): MobileNetV2ConvLayer(
    (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
    (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
    (activation): ReLU6()
  )
  (reduce_1x1): MobileNetV2ConvLayer(
    (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
  )
)


In [9]:
model_layers = model.layer

# mobileNetv2_encoder = nn.ModuleList()
list_en = []

encoder_layers_idx = [
	[0],
    [1,2],
	[3,4,5],
	[6,7,8,9],
	[10,11,12],
	[13,14,15],
]

for idx, layer in enumerate(encoder_layers_idx):
	# mobileNetv2_encoder.append(nn.Sequential(*[model_layers[i] for i in layer]))
	list_en.append(nn.Sequential(*[model_layers[i] for i in layer]))
 
list_en[0]

Sequential(
  (0): MobileNetV2InvertedResidual(
    (expand_1x1): MobileNetV2ConvLayer(
      (convolution): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (normalization): BatchNorm2d(96, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      (activation): ReLU6()
    )
    (conv_3x3): MobileNetV2ConvLayer(
      (convolution): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), groups=96, bias=False)
      (normalization): BatchNorm2d(96, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      (activation): ReLU6()
    )
    (reduce_1x1): MobileNetV2ConvLayer(
      (convolution): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (normalization): BatchNorm2d(24, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
    )
  )
)

In [10]:
for layer_idx, layer in enumerate(list_en):
	print(f"\nBlock {layer_idx}:")
	print(f"  input channels {layer_idx}: {layer[0].expand_1x1.convolution.in_channels}")
	print(f"  output channels {layer_idx}: {layer[-1].reduce_1x1.convolution.out_channels}")
    # break
    # if isinstance(layer.convolution, nn.Conv2d):
    # print(f"  Layer {layer_idx}: Input: {layer.convolution.in_channels}, Output: {layer.convolution.out_channels}, Kernel: {layer.convolution.kernel_size}")
    
print(f"\nBlock {5}:")
list_en.append(model.conv_1x1)
# print(list_en[-1])
print(f"  input channels {5}: {list_en[-1].convolution.in_channels}")
print(f"  output channels {5}: {list_en[-1].convolution.out_channels}")


Block 0:
  input channels 0: 16
  output channels 0: 24

Block 1:
  input channels 1: 24
  output channels 1: 32

Block 2:
  input channels 2: 32
  output channels 2: 64

Block 3:
  input channels 3: 64
  output channels 3: 96

Block 4:
  input channels 4: 96
  output channels 4: 160

Block 5:
  input channels 5: 160
  output channels 5: 320

Block 5:
  input channels 5: 320
  output channels 5: 1280


In [17]:
def get_encoderv2_layers():
	list_en = nn.ModuleList()
	encoder_layers_idx = [
		[0],
  		[1,2],
		[3,4,5],
		[6,7,8,9],
		[10,11,12],
		[13,14,15],
	]
	for idx, layer in enumerate(encoder_layers_idx):
		# list_en.append(nn.Sequential(*[model_layers[i] for i in layer]))
		list_en.append(nn.Sequential(*[model_layers[i] for i in layer]))
	list_en.append(model.conv_1x1)
	return list_en, model.conv_stem

def get_decoderv2_layers():
	decoder_layers = nn.ModuleList()
	decoder_layers.append(DecoderBlock(1280, 320, do_up_sampling=False)) 
	decoder_layers.append(DecoderBlock(320+320, 160, do_up_sampling=False))
	decoder_layers.append(DecoderBlock(160+160, 96, do_up_sampling=True))
	decoder_layers.append(DecoderBlock(96+96, 64, do_up_sampling=False))
	decoder_layers.append(DecoderBlock(64+64, 32, do_up_sampling=True))
	decoder_layers.append(DecoderBlock(32+32, 24, do_up_sampling=True))
	decoder_layers.append(DecoderBlock(24+24, 16, do_up_sampling=True))
	
	# a conv layer to get the final output
	out_stem = nn.Sequential(
		nn.Sequential(
			nn.ConvTranspose2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False),
			nn.BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True),
			nn.ReLU6(),
		),
		nn.Sequential(
			nn.ConvTranspose2d(32,32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False,padding=1),
			nn.BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True),
			nn.ReLU6(),
		),
		nn.Sequential(
			nn.ConvTranspose2d(32, 3, kernel_size=(3, 3), stride=(2, 2), bias=False, padding=1, output_padding=1),
			nn.BatchNorm2d(3, eps=0.001, momentum=0.997, affine=True, track_running_stats=True), 
		)
	)
 
	return decoder_layers, out_stem

In [18]:
# forward pass random image 3x224x224
input = torch.randn(1, 3, 224, 224)

# get encoder layers
ex_enc_list,ex_first_layer = get_encoderv2_layers()

# forward pass
x = input
print(f"Input: {x.size()}")

# preprocess input
x = ex_first_layer(x)
print(f"First Layer: {x.size()}")

for idx, layer in enumerate(ex_enc_list):
	x = layer(x)
	print(f"Layer {idx}: {x.size()}")
 
print()

# ex_first_layer
# ex_enc_list[0]

Input: torch.Size([1, 3, 224, 224])
First Layer: torch.Size([1, 16, 112, 112])
Layer 0: torch.Size([1, 24, 56, 56])
Layer 1: torch.Size([1, 32, 28, 28])
Layer 2: torch.Size([1, 64, 14, 14])
Layer 3: torch.Size([1, 96, 14, 14])
Layer 4: torch.Size([1, 160, 7, 7])
Layer 5: torch.Size([1, 320, 7, 7])
Layer 6: torch.Size([1, 1280, 7, 7])



In [19]:
out_stem = nn.Sequential(
		nn.Sequential(
			nn.ConvTranspose2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False),
			nn.BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True),
			nn.ReLU6(),
		),
		nn.Sequential(
			nn.ConvTranspose2d(32,32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False,padding=1),
			nn.BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True),
			nn.ReLU6(),
		),
		nn.Sequential(
			nn.ConvTranspose2d(32, 3, kernel_size=(3, 3), stride=(2, 2), bias=False, padding=1, output_padding=1),
			nn.BatchNorm2d(3, eps=0.001, momentum=0.997, affine=True, track_running_stats=True), 
		)
	)

x_in = torch.randn(1, 16, 112, 112)
x = x_in

x_out = out_stem(x)

print(f"Input: {x_in.size()}")
print(f"Output: {x_out.size()}")

x_in = torch.randn(1, 3, 224, 224)
x = x_in

x_out = model.conv_stem(x)

print(f"Input: {x_in.size()}")
print(f"Output: {x_out.size()}")

Input: torch.Size([1, 16, 112, 112])
Output: torch.Size([1, 3, 224, 224])
Input: torch.Size([1, 3, 224, 224])
Output: torch.Size([1, 16, 112, 112])


In [20]:
class Unet(nn.Module):
	def __init__(self):
		super(Unet, self).__init__()
		# get encoder layers
		self.encoder_blocks, self.first_layer = get_encoderv2_layers()
		self.decoder_blocks, self.last_layer = get_decoderv2_layers()
  
	def forward(self,x):
		assert x.shape[1] == 3, "input image should have 3 channels(nx3x224x224)"
  
		temp_in_x = x
		print(f"input size: {x.shape}") 

		x = self.first_layer(x)
		print(f"First size: {x.shape}") 
  
		enc_outputs = []
		for indx, enc_block in enumerate(self.encoder_blocks):
				x = enc_block(x)
				print(f"Encoder block {indx} output shape: {x.shape}")
				enc_outputs.append(x)
    
		print(f"Rep size: {x.shape}") 
  
		for indx, dec_block in enumerate(self.decoder_blocks):
			if indx == 0:
				print("we")
				x = dec_block(x)
				
			else:
				x = dec_block(
					torch.cat([x, enc_outputs[len(self.decoder_blocks) - indx - 1]], dim=1))
			print(f"Decoder block {indx} output shape: {x.shape}")
   
		x = self.last_layer(x)
		print(f"Last size: {x.shape}")
  
		# attention
  
		return x + temp_in_x

In [21]:
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, expansion=2, do_up_sampling=True):
        # TODO: adding skip connections
        """
        Decoder block module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            expansion (int, optional): Expansion factor. Default is 3.
        """
        super(DecoderBlock, self).__init__()
        self.relu = nn.ReLU(inplace=True)

        self.cnn1 = nn.Conv2d(in_channels, in_channels *
                              expansion, kernel_size=1, stride=1)
        self.bnn1 = nn.BatchNorm2d(in_channels*expansion)

        # nearest neighbor x2
        self.do_up_sampling = do_up_sampling
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # DW conv/ c_in*exp x 5 x 5 x c_in*exp
        self.cnn2 = nn.Conv2d(in_channels*expansion, in_channels *
                              expansion, kernel_size=5, padding=2, stride=1)
        self.bnn2 = nn.BatchNorm2d(in_channels*expansion)

        self.cnn3 = nn.Conv2d(in_channels*expansion,
                              out_channels, kernel_size=1, stride=1)
        self.bnn3 = nn.BatchNorm2d(out_channels)
        
        self.identity = nn.Sequential(
			nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
			nn.ReLU(inplace=True)
		)

    def forward(self, x):
        """
        Forward pass through the decoder block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        temp_x = self.identity(x)
        x = self.cnn1(x)
        x = self.bnn1(x)
        x = self.relu(x)

        if self.do_up_sampling:
            x = self.upsample(x)
            temp_x = self.upsample(temp_x)

        x = self.cnn2(x)
        x = self.bnn2(x)
        x = self.relu(x)

        x = self.cnn3(x)
        x = self.bnn3(x)

        return x + temp_x

In [22]:
ex_unet = Unet()
  
# input 
ex_in_u = torch.randn(1, 3, 224, 224)

ex_out_u = ex_unet(ex_in_u)

input size: torch.Size([1, 3, 224, 224])
First size: torch.Size([1, 16, 112, 112])
Encoder block 0 output shape: torch.Size([1, 24, 56, 56])
Encoder block 1 output shape: torch.Size([1, 32, 28, 28])
Encoder block 2 output shape: torch.Size([1, 64, 14, 14])
Encoder block 3 output shape: torch.Size([1, 96, 14, 14])
Encoder block 4 output shape: torch.Size([1, 160, 7, 7])
Encoder block 5 output shape: torch.Size([1, 320, 7, 7])
Encoder block 6 output shape: torch.Size([1, 1280, 7, 7])
Rep size: torch.Size([1, 1280, 7, 7])
we
Decoder block 0 output shape: torch.Size([1, 320, 7, 7])
Decoder block 1 output shape: torch.Size([1, 160, 7, 7])
Decoder block 2 output shape: torch.Size([1, 96, 14, 14])
Decoder block 3 output shape: torch.Size([1, 64, 14, 14])
Decoder block 4 output shape: torch.Size([1, 32, 28, 28])
Decoder block 5 output shape: torch.Size([1, 24, 56, 56])
Decoder block 6 output shape: torch.Size([1, 16, 112, 112])
Last size: torch.Size([1, 3, 224, 224])
