<a href="https://colab.research.google.com/github/JuhiRaj/3DSimulations/blob/main/LaUNET2D221_SPECT_Image2Image.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Model


In [None]:
from IPython.core.display import Image
import torch
import torch.nn as nn
from torchsummary import summary


def double_conv(in_c, out_c):
    """Build a block of two 3x3 convolutions, each followed by a ReLU.

    Both convolutions use ``padding="same"``, so the spatial size of the
    input is preserved; only the channel count changes.

    Args:
        in_c: Number of channels expected by the first convolution.
        out_c: Number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` holding conv, ReLU, conv, ReLU in that order.
    """
    # Layers are created in the same order they are applied so that seeded
    # weight initialisation stays reproducible.
    first_conv = nn.Conv2d(in_c, out_c, kernel_size=3, padding="same")
    second_conv = nn.Conv2d(out_c, out_c, kernel_size=3, padding="same")
    return nn.Sequential(
        first_conv,
        nn.ReLU(inplace=True),
        second_conv,
        nn.ReLU(inplace=True),
    )

def double_conv_up(in_c, out_c):
    """Build the decoder-side block of two 3x3 convolutions with ReLUs.

    The first convolution maps ``in_c`` channels to ``out_c``; the second
    keeps ``out_c`` channels. Both use ``padding="same"``, so the spatial
    size is left unchanged.

    Args:
        in_c: Number of channels entering the block.
        out_c: Number of channels leaving the block.

    Returns:
        An ``nn.Sequential`` holding conv, ReLU, conv, ReLU in that order.
    """
    stages = []
    # Stage inputs: the block input width first, then the block output width.
    for stage_in in (in_c, out_c):
        stages.append(nn.Conv2d(stage_in, out_c, kernel_size=3, padding="same"))
        stages.append(nn.ReLU(inplace=True))
    return nn.Sequential(*stages)


def crop_img(tensor, target_tensor):
    """Center-crop ``tensor`` to the height and width of ``target_tensor``.

    Both arguments are indexed as 4-D tensors whose last two dimensions are
    height and width. The crop window is centred; when the size difference
    is odd, the extra row/column is dropped from the bottom/right so the
    result always matches the target size exactly.

    Args:
        tensor: Tensor to crop; indexed as ``[:, :, h, w]``.
        target_tensor: Tensor whose dimensions 2 and 3 give the output size.

    Returns:
        A view of ``tensor`` with the same leading dimensions and the
        target's height and width.

    Raises:
        ValueError: If ``tensor`` is smaller than ``target_tensor`` along
            height or width, since cropping cannot enlarge it.
    """
    target_h, target_w = target_tensor.size()[2], target_tensor.size()[3]
    height, width = tensor.size()[2], tensor.size()[3]

    if height < target_h or width < target_w:
        raise ValueError(
            f"cannot crop a {height}x{width} tensor to the larger size "
            f"{target_h}x{target_w}"
        )

    # Derive the end index from the start index and the target size (rather
    # than from the halved difference) so odd differences still yield exactly
    # target_h x target_w.
    h_start = (height - target_h) // 2
    w_start = (width - target_w) // 2

    return tensor[:, :, h_start:h_start + target_h, w_start:w_start + target_w]

class UNet(nn.Module):
    """Five-level 2-D U-Net that adds a learned residual to its input.

    The encoder applies four ``double_conv`` + 2x2 max-pool stages
    (1 -> 64 -> 128 -> 256 -> 512 channels) followed by a 1024-channel
    bottleneck. The decoder mirrors it: each transposed convolution doubles
    the spatial size, its output is concatenated with the matching encoder
    feature map, and a ``double_conv_up`` block halves the channel count.
    A final 1x1 convolution maps the 64 channels back to one.

    The input is divided by its mean before entering the network, and the
    network output is multiplied by the same factor before being added to
    the input.
    """

    def __init__(self):
        super().__init__()
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder.
        self.down_conv_1 = double_conv(1, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)

        # Decoder. The ``output_padding`` values below are the defaults that
        # line up with a 250x250 input; ``forward`` overrides them per call
        # by passing an explicit ``output_size``, so other input sizes work.
        self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024,
                                             out_channels=512,
                                             kernel_size=2,
                                             stride=2,
                                             output_padding=1)
        self.up_conv_1 = double_conv_up(1024, 512)

        self.up_trans_2 = nn.ConvTranspose2d(in_channels=512,
                                             out_channels=256,
                                             kernel_size=2,
                                             stride=2,
                                             output_padding=0)
        self.up_conv_2 = double_conv_up(512, 256)

        self.up_trans_3 = nn.ConvTranspose2d(in_channels=256,
                                             out_channels=128,
                                             kernel_size=2,
                                             stride=2,
                                             output_padding=1)
        self.up_conv_3 = double_conv_up(256, 128)

        self.up_trans_4 = nn.ConvTranspose2d(in_channels=128,
                                             out_channels=64,
                                             kernel_size=2,
                                             stride=2,
                                             output_padding=0)
        self.up_conv_4 = double_conv_up(128, 64)

        self.out = nn.Conv2d(
            in_channels=64,
            out_channels=1,
            kernel_size=1,
            stride=1
        )

    @staticmethod
    def _upsample_and_merge(up_layer, features, skip):
        """Upsample ``features`` to ``skip``'s spatial size and concatenate.

        Max-pooling floors odd sizes, so doubling alone cannot tell whether
        the encoder map was ``2n`` or ``2n + 1`` pixels wide. Passing
        ``output_size`` lets the transposed convolution pick the padding
        that matches the skip connection exactly.
        """
        upsampled = up_layer(features, output_size=list(skip.shape[-2:]))
        return torch.cat([upsampled, skip], 1)

    def forward(self, image):
        """Run the network on a batch of single-channel images.

        Args:
            image: Tensor of shape ``(batch, 1, H, W)``. ``H`` and ``W`` must
                each be at least 16 so that four 2x2 poolings leave at least
                one pixel.

        Returns:
            Tensor with the same shape as ``image``: the input plus the
            rescaled network output.
        """
        # NOTE: the mean is taken over the whole batch, so each sample's
        # normalisation depends on the other samples in the batch.
        mean = torch.mean(image)
        # An all-zero input would otherwise divide by zero and yield NaNs.
        scale = torch.where(mean == 0, torch.ones_like(mean), mean)

        # Encoder: keep the pre-pooling maps for the skip connections.
        x1 = self.down_conv_1(image / scale)
        x3 = self.down_conv_2(self.max_pool_2x2(x1))
        x5 = self.down_conv_3(self.max_pool_2x2(x3))
        x7 = self.down_conv_4(self.max_pool_2x2(x5))

        # Bottleneck.
        x9 = self.down_conv_5(self.max_pool_2x2(x7))

        # Decoder.
        x11 = self.up_conv_1(self._upsample_and_merge(self.up_trans_1, x9, x7))
        x13 = self.up_conv_2(self._upsample_and_merge(self.up_trans_2, x11, x5))
        x15 = self.up_conv_3(self._upsample_and_merge(self.up_trans_3, x13, x3))
        x17 = self.up_conv_4(self._upsample_and_merge(self.up_trans_4, x15, x1))

        # Undo the input scaling and add the result to the input (residual).
        return self.out(x17) * scale + image


if __name__ == "__main__":
    # Smoke test: push one random single-channel 250x250 image through a
    # freshly initialised network and report the output shape.
    # The image is drawn before the model is built so the random stream is
    # consumed in the same order on every run.
    image = torch.rand(1, 1, 250, 250)
    model = UNet()
    print('1', model(image).shape)


from prettytable import PrettyTable

def count_parameters(model):
    """Print a table of trainable parameter sizes and return their total.

    Parameters with ``requires_grad`` set to ``False`` are left out of both
    the table and the total.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.

    Returns:
        The summed ``numel()`` of all trainable parameters.
    """
    trainable = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for name, size in trainable:
        table.add_row([name, size])
    total_params = sum(size for _, size in trainable)

    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

# Print the per-parameter table for the model built in the smoke test.
# NOTE(review): `model` is only bound inside the `if __name__ == "__main__":`
# block, so this unconditional call presumably relies on the file being run
# as a script or notebook cell — confirm before importing it as a module.
count_parameters(model)


x10 size: torch.Size([1, 512, 31, 31])
x7 size: torch.Size([1, 512, 31, 31])
Concatenated size: torch.Size([1, 1024, 31, 31])
1 torch.Size([1, 1, 250, 250])
+----------------------+------------+
|       Modules        | Parameters |
+----------------------+------------+
| down_conv_1.0.weight |    576     |
|  down_conv_1.0.bias  |     64     |
| down_conv_1.2.weight |   36864    |
|  down_conv_1.2.bias  |     64     |
| down_conv_2.0.weight |   73728    |
|  down_conv_2.0.bias  |    128     |
| down_conv_2.2.weight |   147456   |
|  down_conv_2.2.bias  |    128     |
| down_conv_3.0.weight |   294912   |
|  down_conv_3.0.bias  |    256     |
| down_conv_3.2.weight |   589824   |
|  down_conv_3.2.bias  |    256     |
| down_conv_4.0.weight |  1179648   |
|  down_conv_4.0.bias  |    512     |
| down_conv_4.2.weight |  2359296   |
|  down_conv_4.2.bias  |    512     |
| down_conv_5.0.weight |  4718592   |
|  down_conv_5.0.bias  |    1024    |
| down_conv_5.2.weight |  9437184   |
|  down

31030593