In [1]:
# import lib
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math
import tqdm
# from tqdm.auto import trange, tqdm

# import pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# to get the CIFAR-10 dataset
from torchvision import transforms
import torchvision
import torchvision.transforms as transforms

# to import pretrained models
from transformers import AutoImageProcessor, MobileNetV1Model
import timm

# import sklearn
from sklearn.model_selection import train_test_split

# set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


---
## Layers

In [2]:
class ChannelShuffle(nn.Module):
	"""Interleave channels across groups (the ShuffleNet channel-shuffle op).

	The channel axis is split into ``group`` equal groups and the
	(group, channels_per_group) axes are swapped, so input channel
	``g * channels_per_group + k`` becomes output channel ``k * group + g``.

	Args:
		group: number of channel groups; must be greater than 1. Note that the
			default of 1 is rejected by the check in ``__init__``, so callers
			must always pass an explicit value.
	"""

	def __init__(self, group=1):
		assert group > 1, f"ChannelShuffle needs group > 1, got {group}"
		super(ChannelShuffle, self).__init__()
		self.group = group

	def forward(self, x):
		"""Shuffle the channels of a 4-D (batch, channels, height, width) tensor.

		Adapted from https://github.com/Randl/ShuffleNetV2-pytorch/blob/master/model.py

		Raises:
			AssertionError: if the channel count is not divisible by ``group``.
		"""
		# Read the shape from the tensor itself; going through ``.data`` is a
		# legacy idiom that bypasses autograd tracking.
		batchsize, num_channels, height, width = x.size()
		assert num_channels % self.group == 0, (
			f"channels ({num_channels}) must be divisible by group ({self.group})")
		channels_per_group = num_channels // self.group
		# (b, c, h, w) -> (b, g, c/g, h, w) -> swap group/channel axes -> flatten.
		# contiguous() is required because view() cannot follow transpose() on a
		# non-contiguous tensor (https://github.com/pytorch/pytorch/issues/764).
		x = x.view(batchsize, self.group, channels_per_group, height, width)
		x = x.transpose(1, 2).contiguous()
		return x.view(batchsize, -1, height, width)


class identity(nn.Module):
	"""Pass-through module with no parameters: ``forward`` returns its input object."""

	def __init__(self):
		super().__init__()

	def forward(self, x):
		# No copy is made; the very same tensor object is handed back.
		passthrough = x
		return passthrough

# based on FB block


class DecoderLayer(nn.Module):
	"""Upsampling decoder layer: expand -> upsample -> 5x5 conv -> project.

	Pipeline (BatchNorm after every conv, ReLU after the first two):
		1x1 ConvTranspose2d   in_channels -> in_channels * expansion
		nearest-neighbour upsample by ``scale_factor``
		5x5 Conv2d, padding 2 (spatial size unchanged)
		1x1 Conv2d            in_channels * expansion -> out_channels

	An input of shape (N, in_channels, H, W) yields
	(N, out_channels, H * scale_factor, W * scale_factor).

	Args:
		in_channels: channels of the incoming feature map.
		out_channels: channels produced by the final projection.
		expansion: width multiplier for the hidden (expanded) channels.
		scale_factor: spatial upsampling factor (default 2, the original value).
	"""

	def __init__(self, in_channels, out_channels, expansion=3, scale_factor=2):
		super(DecoderLayer, self).__init__()
		hidden = in_channels * expansion

		# A 1x1, stride-1 transposed conv acts like a 1x1 conv; it is kept as
		# ConvTranspose2d so the shape of conv1's parameters stays unchanged.
		self.conv1 = nn.ConvTranspose2d(in_channels, hidden, kernel_size=1, stride=1)
		self.bn1 = nn.BatchNorm2d(hidden)
		self.relu1 = nn.ReLU(inplace=True)

		self.upsample = nn.Upsample(scale_factor=scale_factor, mode='nearest')

		self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=5, padding=2, stride=1)
		self.bn2 = nn.BatchNorm2d(hidden)
		self.relu2 = nn.ReLU(inplace=True)

		# Linear projection: no activation after the final BatchNorm.
		self.conv3 = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1)
		self.bn3 = nn.BatchNorm2d(out_channels)

		# TODO: add a skip connection. Because the main path upsamples, the
		# shortcut would also need upsampling (plus a 1x1 conv when
		# in_channels != out_channels) before it can be added.

	def forward(self, x):
		"""Decode x of shape (N, C_in, H, W) into (N, C_out, H*s, W*s)."""
		x = self.relu1(self.bn1(self.conv1(x)))
		x = self.upsample(x)
		x = self.relu2(self.bn2(self.conv2(x)))
		return self.bn3(self.conv3(x))


# Smoke test: a random (1, 256, 32, 32) feature map pushed through
# DecoderLayer(256 -> 256) should come out with its spatial size doubled.
x = torch.randn(1, 256, 32, 32)

model = DecoderLayer(in_channels=256, out_channels=256)
y = model(x)

print(y.shape)

torch.Size([1, 256, 64, 64])


In [3]:
# # Initialize the MobileNetV1 model
# image_processor = AutoImageProcessor.from_pretrained(
# 	"google/mobilenet_v1_1.0_224")
# mobilenet_v1_model = MobileNetV1Model.from_pretrained(
# 	"google/mobilenet_v1_1.0_224")

In [4]:
# # a function that returns the MobileNet layers grouped into blocks, plus the conv stem and image processor
# def get_encoder_layers():
# 	# download the model
# 	image_processor = AutoImageProcessor.from_pretrained(
# 		"google/mobilenet_v1_1.0_224")
# 	model = MobileNetV1Model.from_pretrained(
# 		"google/mobilenet_v1_1.0_224")
	
# 	mobilenet_seq_blocks = []
# 	# block 1 will contain 4 layer of model.layer
# 	block = nn.Sequential(*list(model.layer)[:2])
# 	mobilenet_seq_blocks.append(block)
# 	# print("-"*30,"\n\nblock 1:", block)
		
# 	block = nn.Sequential(*list(model.layer)[2:4])
# 	mobilenet_seq_blocks.append(block)
# 	# print("-"*30,"\n\nblock 2:", block)

# 	block = nn.Sequential(*list(model.layer)[4:8])
# 	mobilenet_seq_blocks.append(block)
# 	# print("-"*30,"\n\nblock 3:", block)	

# 	block = nn.Sequential(*list(model.layer)[8:12])
# 	mobilenet_seq_blocks.append(block)
# 	# print("-"*30,"\n\nblock 4:", block)

# 	block = nn.Sequential(*list(model.layer)[12:])
# 	mobilenet_seq_blocks.append(block)
# 	# print("-"*30,"\n\nblock 5:", block)

# 	# printing the input and output channels of the first and last layers of each block
# 	# for i, block in enumerate(mobilenet_seq_blocks):
# 	# 	# Extracting the first and last layers of the block
# 	# 	first_layer = block[0]
# 	# 	last_layer = block[-1]
		
# 	# 	# Get the input and output channels of the first and last layers
# 	# 	input_channel_first_layer = first_layer.convolution.in_channels
# 	# 	output_channel_last_layer = last_layer.convolution.out_channels
		
# 	# 	print(f"Block {i + 1}:")
# 	# 	print("Input channel of the first layer:", input_channel_first_layer)
# 	# 	print("Output channel of the last layer:", output_channel_last_layer)
 
# 	return mobilenet_seq_blocks, model.conv_stem, image_processor

In [6]:
from get_layers import get_encoder_layers

# get encoder layers from the local `get_layers` module.
# NOTE(review): judging by the commented-out prototype above (presumably what
# get_layers implements — verify), this unpacks a list of MobileNet layer
# blocks, the stem conv layer, and the Hugging Face image processor.
mobilenet_seq_blocks, conv_stem, image_processor = get_encoder_layers()

In [7]:
# FB decoder block
class DecoderBlock(nn.Module):
    """FB-style decoder block: 1x1 expand -> x2 nearest upsample -> 5x5 conv -> 1x1 project.

    Maps (N, in_channels, H, W) to (N, out_channels, 2H, 2W). BatchNorm follows
    every conv; ReLU follows the first two (the final projection is linear).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the final 1x1 projection.
        expansion: width multiplier for the hidden channels.
        verbose: print the tensor shape after each stage. Defaults to True to
            keep the original debugging output; pass False for real use.
        depthwise: if True, make the 5x5 conv depthwise (groups = hidden
            channels). The default False keeps the original dense 5x5 conv.
    """

    def __init__(self, in_channels, out_channels, expansion=3, verbose=True, depthwise=False):
        super(DecoderBlock, self).__init__()
        hidden = in_channels * expansion
        self.verbose = verbose
        # Single stateless ReLU module, reused after cnn1 and cnn2.
        self.relu = nn.ReLU(inplace=True)

        self.cnn1 = nn.Conv2d(in_channels, hidden, kernel_size=1, stride=1)
        self.bnn1 = nn.BatchNorm2d(hidden)

        # nearest neighbor x2
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # 5x5 conv hidden -> hidden; padding 2 keeps the spatial size.
        # Dense by default (weight: hidden x hidden x 5 x 5), depthwise on request.
        self.cnn2 = nn.Conv2d(hidden, hidden, kernel_size=5, padding=2, stride=1,
                              groups=hidden if depthwise else 1)
        self.bnn2 = nn.BatchNorm2d(hidden)

        self.cnn3 = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1)
        self.bnn3 = nn.BatchNorm2d(out_channels)

    def _log(self, label, x):
        # Debug shape trace, only emitted when ``verbose`` is set.
        if self.verbose:
            print(label, x.shape)

    def forward(self, x):
        """Decode (N, C_in, H, W) -> (N, C_out, 2H, 2W), tracing shapes if verbose."""
        self._log("input shape:", x)

        x = self.relu(self.bnn1(self.cnn1(x)))
        self._log("after cnn1 shape:", x)

        x = self.upsample(x)
        self._log("after upsample shape:", x)

        x = self.relu(self.bnn2(self.cnn2(x)))
        self._log("after cnn2 shape:", x)

        x = self.bnn3(self.cnn3(x))
        self._log("after cnn3 shape:", x)

        return x

----
## Tests

In [None]:
# Sanity check: DecoderBlock(256 -> 128) on a 28x28 map should give (1, 128, 56, 56).
decoder_block = DecoderBlock(256, 128)

x = torch.randn(1, 256, 28, 28)  # random input

y = decoder_block(x)  # forward pass

print(y.shape)

In [None]:
# Send a random (1, 3, 224, 224) tensor through the MobileNet encoder and print
# the shape entering and leaving each block.
# Fix: use the `conv_stem` unpacked from get_encoder_layers(). The global
# `model` at this point is the DecoderLayer from the Layers section, which has
# no `conv_stem` attribute.
input_img = torch.randn(1, 3, 224, 224)

with torch.no_grad():  # shape probe only; no autograd graph needed
	# run the first layer (stem)
	input_img = conv_stem(input_img)

	for block_idx, block in enumerate(mobilenet_seq_blocks):
		# input of the block
		print(f"\ninput shape: {input_img.shape}")
		input_img = block(input_img)
		print(f"Block {block_idx} output shape: {input_img.shape}")

In [None]:
# an encoding block takes 512 14 14 and outputs 1024 7 7

	# print(f"input shape: {x.shape}")
 
	# # first layer conv normalisation and relu
	# print(f"\nlayer 1: {in_channels} {in_channels}")
	# x = nn.Conv2d(in_channels, in_channels, kernel_size=(3,3), stride=(1,1), groups=in_channels, bias=False)(x)
	# print(f"conv1 shape: {x.shape}")
 
	# x = nn.BatchNorm2d(in_channels)(x)
	# x = nn.ReLU6()(x)
	# print(f"norm and act shape: {x.shape}")
 
	# # second layer pointwise conv normalisation and relu
	# print(f"\nlayer 2: {in_channels} {in_channels}")
	# x = nn.Conv2d(in_channels, in_channels, kernel_size=(1,1), stride=(1,1), bias=False)(x)
	# print(f"conv2 shape: {x.shape}")
 
	# x = nn.BatchNorm2d(in_channels)(x)
	# x = nn.ReLU6()(x)
	# print(f"norm and act shape: {x.shape}")
 
	# # third layer pointwise conv normalisation and relu
	# print(f"\nlayer 3: {in_channels} {in_channels}")
	# x = nn.Conv2d(in_channels, in_channels, kernel_size=(3,3), stride=(2,2), groups=in_channels, bias=False)(x)
	# print(f"conv3 shape: {x.shape}")
 
	# x = nn.BatchNorm2d(in_channels)(x)
	# x = nn.ReLU6()(x)
	# print(f"output shape: {x.shape}")
 
	# # fourth layer pointwise conv normalisation and relu
	# print(f"\nlayer 4: {in_channels} {out_channels}")
	# x = nn.Conv2d(in_channels, out_channels, kernel_size=(1,1), stride=(1,1), bias=False)(x)
	# print(f"norm and act shape: {x.shape}")
	
	# x = nn.BatchNorm2d(out_channels)(x)
	# x = nn.ReLU6()(x)
	# print(f"output shape: {x.shape}")
  
	# return x
 
def encoding_block(x, in_channels, out_channels, layers=None):
	"""Run ``x`` through a sequence of encoder layers, printing each layer and its output shape.

	Debugging helper for inspecting shapes inside one MobileNet block.

	Args:
		x: input feature map.
		in_channels: unused; kept for signature compatibility.
		out_channels: unused; kept for signature compatibility.
		layers: iterable of callables applied in order. Defaults to the
			second-to-last entry of the notebook-global ``mobilenet_seq_blocks``
			(the original hard-coded behaviour).

	Returns:
		The output of the last layer, or ``x`` itself if ``layers`` is empty.
	"""
	if layers is None:
		layers = mobilenet_seq_blocks[-2]

	# Log input shape, layer description and output shape for every layer.
	for block_idx, block in enumerate(layers):
		print("-"*40, f"\n\n\ninput shape: {x.shape}")
		print(f"Block {block_idx} \n block description: {block}")
		x = block(x)
		print(f"Block {block_idx} output shape: {x.shape}")

	return x

def decoding_block (x, in_channels, out_channels, expansion=3):
	"""Shape probe for an FB-style decoder block, printing shapes after each stage.

	Fresh, randomly initialised layers are built and applied on every call.
	They are not stored or trained, so the result is only meaningful for
	checking shapes.

	Stages:
		1. 1x1 Conv2d in_channels -> in_channels*expansion, BatchNorm, ReLU
		2. depthwise 3x3 ConvTranspose2d, stride 2 (doubles H and W), BatchNorm, ReLU
		3. 1x1 Conv2d -> out_channels, BatchNorm (no activation)

	Args:
		x: input of shape (N, in_channels, H, W).
		in_channels: channel count of ``x``.
		out_channels: channel count of the result.
		expansion: width multiplier for the hidden channels.

	Returns:
		Tensor of shape (N, out_channels, 2H, 2W), on ``x``'s device and dtype.
	"""
	hidden = in_channels * expansion
	# Every layer is moved to the input's device/dtype. Otherwise a CUDA or
	# non-float32 input fails against the default CPU float32 weights.
	target = dict(device=x.device, dtype=x.dtype)

	print(f"input shape: {x.shape}")

	# stage 1: pointwise expansion conv, normalisation and relu
	print(f"\nlayer 1: {in_channels} {hidden}")
	x = nn.Conv2d(in_channels, hidden, kernel_size=1, stride=1).to(**target)(x)
	print(f"conv1 shape: {x.shape}")

	x = nn.BatchNorm2d(hidden).to(**target)(x)
	x = nn.ReLU(inplace=True)(x)
	print(f"norm and act shape: {x.shape}")

	# stage 2: depthwise transposed conv, H_out = (H-1)*2 - 2 + 3 + 1 = 2H,
	# then normalisation and relu
	print(f"\nlayer 2: {hidden} {hidden}")
	x = nn.ConvTranspose2d(hidden, hidden, kernel_size=(3, 3), stride=(2, 2), padding=1,
						   output_padding=1, groups=hidden, bias=False).to(**target)(x)
	print(f"conv2 shape: {x.shape}")

	x = nn.BatchNorm2d(hidden).to(**target)(x)
	x = nn.ReLU(inplace=True)(x)
	print(f"norm and act shape: {x.shape}")

	# stage 3: pointwise projection conv and normalisation (linear, no relu)
	print(f"\nlayer 3: {hidden} {out_channels}")
	x = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1).to(**target)(x)
	print(f"conv3 shape: {x.shape}")

	x = nn.BatchNorm2d(out_channels).to(**target)(x)
	print(f"output shape: {x.shape}")

	return x
 

In [None]:
# Feed random (1, 512, 14, 14) tensors through the shape probes.
input_img = torch.randn(1, 512, 14, 14)   # input for encoding_block (call disabled below)
output_img = torch.randn(1, 512, 14, 14)  # input for decoding_block

# encoder probe currently disabled:
# end_out = encoding_block(input_img, 512, 1024)
dec_out = decoding_block(output_img, 512, 512)

# print(f"final output shape: {end_out.shape}")