In [3]:
# import lib
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math
import tqdm
from tqdm.auto import trange, tqdm

# import pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# to get the CIFAR-10 dataset
from torchvision import transforms
import torchvision
import torchvision.transforms as transforms

# to import pretrained models
from transformers import AutoImageProcessor, MobileNetV1Model
import timm

# import sklearn
from sklearn.model_selection import train_test_split

# set up device: use the current CUDA device when torch can see a GPU, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
from get_layersv2 import *

In [5]:
class Unet(nn.Module):
	"""U-Net assembled from the get_layersv2 encoder/decoder stages, with a global residual.

	The decoder/head output is added element-wise to the input image, so
	``last_layer`` must produce a tensor addable to the input (same shape in
	practice). This variant has no attention modules; ``UnetWithAT`` adds them.

	Args:
		verbose (bool): if True, print the tensor shape after every stage during
			``forward`` (debug aid). Defaults to False so training loops are not
			flooded with per-batch output.
	"""

	def __init__(self, verbose=False):
		super(Unet, self).__init__()
		# get encoder layers: (stages, stem) and (stages, head) from get_layersv2 (not visible here)
		self.encoder_blocks, self.first_layer = get_encoderv2_layers()
		self.decoder_blocks, self.last_layer = get_decoderv2_layers()
		self.verbose = verbose

	def _log(self, msg):
		# Shape tracing for debugging; silent unless the model was built with verbose=True.
		if self.verbose:
			print(msg)

	def forward(self, x):
		"""Run stem -> encoder -> decoder (with skips) -> head, then add the input back.

		Args:
			x (torch.Tensor): image batch whose dim 1 has 3 channels
				(the message suggests n x 3 x 224 x 224).

		Returns:
			torch.Tensor: ``last_layer(decoder output) + x``.

		Raises:
			AssertionError: if ``x.shape[1] != 3``.
		"""
		assert x.shape[1] == 3, "input image should have 3 channels(nx3x224x224)"

		residual = x
		self._log(f"input size: {x.shape}")

		x = self.first_layer(x)
		self._log(f"First size: {x.shape}")

		# Keep every encoder stage output: the deepest feeds the decoder directly,
		# the shallower ones are concatenated back in as skip connections.
		enc_outputs = []
		for indx, enc_block in enumerate(self.encoder_blocks):
			x = enc_block(x)
			self._log(f"Encoder block {indx} output shape: {x.shape}")
			enc_outputs.append(x)

		self._log(f"Rep size: {x.shape}")

		# Decoder stage i > 0 takes [previous decoder output, encoder output N-1-i]
		# along channels, where N = len(decoder_blocks). This pairs the stages
		# deepest-first only if the encoder and decoder have the same number of
		# stages. NOTE(review): confirm that against get_layersv2.
		for indx, dec_block in enumerate(self.decoder_blocks):
			if indx == 0:
				x = dec_block(x)
			else:
				skip = enc_outputs[len(self.decoder_blocks) - indx - 1]
				x = dec_block(torch.cat([x, skip], dim=1))
			self._log(f"Decoder block {indx} output shape: {x.shape}")

		x = self.last_layer(x)
		self._log(f"Last size: {x.shape}")

		# Global residual connection back to the input image.
		return x + residual

In [8]:

class h_sigmoid(nn.Module):
	"""Hard sigmoid: relu6(x + 3) / 6.

	Piecewise linear: 0 for x <= -3, 1 for x >= 3, and (x + 3) / 6 in between.

	Args:
		inplace (bool): forwarded to ``nn.ReLU6``. It only ever touches the
			temporary ``x + 3``, never the caller's tensor.
	"""

	def __init__(self, inplace=True):
		super().__init__()
		self.relu = nn.ReLU6(inplace=inplace)

	def forward(self, x):
		shifted = x + 3  # fresh tensor, so an in-place ReLU6 is safe here
		clipped = self.relu(shifted)
		return clipped / 6

class h_swish(nn.Module):
	"""Hard swish: x * h_sigmoid(x), i.e. x * relu6(x + 3) / 6.

	Args:
		inplace (bool): forwarded to the inner ``h_sigmoid`` (and from there to its
			``ReLU6``). The caller's tensor is not modified.
	"""

	def __init__(self, inplace=True):
		super().__init__()
		self.sigmoid = h_sigmoid(inplace=inplace)

	def forward(self, x):
		gate = self.sigmoid(x)
		return x * gate

class CoordAtt(nn.Module):
	"""Coordinate attention block; appears to follow Hou et al. (CVPR 2021).

	The input is average-pooled along width (one value per row) and along height
	(one value per column). The two descriptors pass through one shared
	1x1 conv -> BatchNorm -> hard-swish bottleneck. The result is then split back
	into a row gate and a column gate, each a 1x1 conv followed by a sigmoid, and
	both gates rescale the input element-wise.

	Args:
		inp (int): number of channels of the input tensor.
		oup (int): number of channels of each gate. It must broadcast against the
			input's channels for the final product (in practice ``oup == inp``).
		reduction (int): bottleneck reduction; hidden width is ``max(8, inp // reduction)``.
	"""

	def __init__(self, inp, oup, reduction=32):
		super().__init__()
		# (None, 1) keeps the height and collapses the width; (1, None) does the opposite.
		self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
		self.pool_w = nn.AdaptiveAvgPool2d((1, None))

		# Floor of 8 keeps the bottleneck usable for very few input channels (e.g. RGB).
		mip = max(8, inp // reduction)

		self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
		self.bn1 = nn.BatchNorm2d(mip)
		self.act = h_swish()

		self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
		self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

	def forward(self, x):
		# Unpacking also enforces a 4-D (n, c, h, w) input.
		_, _, height, width = x.size()

		# Row descriptor is (n, c, h, 1). The column descriptor is permuted to
		# (n, c, w, 1) so both stack along dim 2.
		row_desc = self.pool_h(x)
		col_desc = self.pool_w(x).permute(0, 1, 3, 2)

		# A single shared bottleneck over all h + w positions.
		shared = torch.cat([row_desc, col_desc], dim=2)
		shared = self.act(self.bn1(self.conv1(shared)))

		row_feat, col_feat = torch.split(shared, [height, width], dim=2)
		# Undo the permute so the column gate ends up as (n, oup, 1, w).
		col_feat = col_feat.permute(0, 1, 3, 2)

		gate_h = self.conv_h(row_feat).sigmoid()
		gate_w = self.conv_w(col_feat).sigmoid()

		return x * gate_w * gate_h

In [10]:
class UnetWithAT(nn.Module):
	"""U-Net with coordinate attention on both terms of the global residual sum.

	The encoder/decoder come from ``get_encoderv2_layers`` / ``get_decoderv2_layers``.
	Before the final sum, the head output is rescaled by one ``CoordAtt``
	(``ldra``) and the raw input image by another (``lra``).

	Args:
		lr (float): unused by this class. It is kept only so existing
			``UnetWithAT(lr=...)`` calls still work. NOTE(review): intent unclear.
	"""

	def __init__(self, lr=0.5):
		super(UnetWithAT, self).__init__()
		self.encoder_blocks, self.first_layer = get_encoderv2_layers()
		self.decoder_blocks, self.last_layer = get_decoderv2_layers()

		# Both attention modules act on 3-channel, image-space tensors.
		self.lra = CoordAtt(3, 3)   # attention on the input/skip path
		self.ldra = CoordAtt(3, 3)  # attention on the network output

	def _encode(self, x):
		"""Apply the stem and every encoder stage; return (deepest features, all stage outputs)."""
		x = self.first_layer(x)
		enc_outputs = []
		for enc_block in self.encoder_blocks:
			x = enc_block(x)
			enc_outputs.append(x)
		return x, enc_outputs

	def _decode(self, x, enc_outputs):
		"""Apply the decoder stages with skip concatenation, then the output head."""
		n_dec = len(self.decoder_blocks)
		for indx, dec_block in enumerate(self.decoder_blocks):
			if indx == 0:
				# The deepest encoder output feeds the first decoder stage directly.
				x = dec_block(x)
			else:
				# Pairs decoder stage i with encoder output n_dec-1-i. This matches
				# stages deepest-first only if the encoder and decoder have the same
				# number of stages. NOTE(review): confirm against get_layersv2.
				skip = enc_outputs[n_dec - indx - 1]
				x = dec_block(torch.cat([x, skip], dim=1))
		return self.last_layer(x)

	def forward(self, x):
		"""
		Performs forward pass through the U-Net model.

		Args:
			x (torch.Tensor): input image batch with 3 channels in dim 1
				(the assert message suggests n x 3 x 224 x 224).

		Returns:
			torch.Tensor: ``ldra(unet(x)) + lra(x)``. The head output must have
			the input's shape for the sum.

		Raises:
			AssertionError: if ``x.shape[1] != 3``.
		"""
		assert x.shape[1] == 3, "input image should have 3 channels(nx3x224x224)"

		bottleneck, enc_outputs = self._encode(x)
		decoded = self._decode(bottleneck, enc_outputs)

		# lra attention on skip connection, ldra attention on output
		return self.ldra(decoded) + self.lra(x)

In [11]:
# Smoke test: push one random 224x224 RGB image through UnetWithAT and check the
# output comes back at the input's shape (required by the residual sum).
torch.manual_seed(0)  # reproducible weights/input across Restart & Run All
ex_unet = UnetWithAT()
# Eval mode: a forward pass on random noise must not update BatchNorm running
# stats (which may belong to pretrained encoder layers — TODO confirm in get_layersv2).
ex_unet.eval()

# input
ex_in_u = torch.randn(1, 3, 224, 224)

# No autograd graph needed for a shape check.
with torch.no_grad():
	ex_out_u = ex_unet(ex_in_u)

assert ex_out_u.shape == ex_in_u.shape, f"unexpected output shape {tuple(ex_out_u.shape)}"
ex_out_u.shape

Rep size: torch.Size([1, 1280, 7, 7])
Last size: torch.Size([1, 3, 224, 224])
