<a href="https://colab.research.google.com/github/Najme-naseri/Image-Procssing/blob/main/UnetImp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import matplotlib as plt
import numpy as np
import PIL
import torchvision
import torchvision.transforms as T
from PIL import Image
import pdb

In [2]:
class double_conv(nn.Module):
  """Two unpadded 3x3 convolutions, each followed by BatchNorm and ReLU.

  This is the basic U-Net building block (conv -> BN -> ReLU, twice).
  ``nn.Conv2d`` is created without padding, so each convolution trims one
  pixel from every border. An input of spatial size (H, W) therefore produces
  an output of size (H - 4, W - 4).

  Args:
    in_c: number of input channels.
    out_c: number of output channels of both convolutions.
  """
  def __init__(self, in_c, out_c):
    super().__init__()
    self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3)
    self.bn1 = nn.BatchNorm2d(out_c)
    self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3)
    self.bn2 = nn.BatchNorm2d(out_c)
    # ReLU has no parameters, so a single module instance is reused for
    # both activations.
    self.relu = nn.ReLU(inplace = True)

  def forward(self, inputs):
    # Nonlinearity after the first conv/BN pair. Without it, conv1 -> bn1 ->
    # conv2 -> bn2 is one affine map, and the two convolutions collapse into
    # a single linear layer.
    x = self.relu(self.bn1(self.conv1(inputs)))
    x = self.relu(self.bn2(self.conv2(x)))
    return x


In [3]:
class encoder(nn.Module):
  """U-Net contracting step: a ``double_conv`` followed by 2x2 max pooling.

  Args:
    in_c: channels of the incoming feature map.
    out_c: channels produced by the convolutions.

  ``forward`` returns a pair ``(features, pooled)``:
    - ``features`` is the full-resolution ``double_conv`` output, which the
      U-Net keeps as a skip connection.
    - ``pooled`` is that map downsampled by a factor of 2 and passed to the
      next stage.
  """
  def __init__(self, in_c, out_c):
    super().__init__()
    self.conv = double_conv(in_c, out_c)
    # MaxPool2d's stride defaults to its kernel size, giving a
    # non-overlapping 2x2 pool.
    self.maxpool = nn.MaxPool2d(2)

  def forward(self, inputs):
    features = self.conv(inputs)
    pooled = self.maxpool(features)
    return features, pooled


In [4]:
def crop_img(tensor, target_tensor):
  """Center-crop ``tensor`` spatially to the size of ``target_tensor``.

  Used for U-Net skip connections. Because the convolutions are unpadded, the
  encoder feature map is larger than the upsampled decoder map, so it is
  trimmed around the center before concatenation.

  Height and width are cropped independently, so non-square maps work. The
  result always has exactly the target's spatial size, even when the size
  difference is odd; in that case the extra row or column is dropped from
  the bottom or right.

  Args:
    tensor: tensor indexed as (N, C, H, W) to be cropped.
    target_tensor: tensor whose dims 2 and 3 give the desired (H, W).

  Returns:
    A slice (view) of ``tensor`` with the target's height and width.

  Raises:
    ValueError: if the target is larger than ``tensor`` in either dimension.
  """
  target_h, target_w = target_tensor.shape[2], target_tensor.shape[3]
  h, w = tensor.shape[2], tensor.shape[3]
  if target_h > h or target_w > w:
    raise ValueError(
        f"cannot crop tensor of size {(h, w)} to larger size {(target_h, target_w)}")
  top = (h - target_h) // 2
  left = (w - target_w) // 2
  # Slice by start + target length so an odd difference cannot leave an
  # extra row or column.
  return tensor[:, :, top:top + target_h, left:left + target_w]

In [5]:
class decoder(nn.Module):
  """U-Net expanding step: upsample, merge the skip connection, convolve.

  Steps:
    1. A learned 2x2 / stride-2 transposed convolution doubles the spatial
       size and maps ``in_c`` channels to ``out_c``.
    2. The skip feature map is center-cropped to the upsampled size.
    3. The two maps are concatenated along dim 1 (upsampled features first).
    4. A ``double_conv`` fuses the ``2 * out_c`` channels back to ``out_c``.

  Args:
    in_c: channels of the incoming (deeper) feature map.
    out_c: output channels. The skip tensor is presumably expected to carry
      ``out_c`` channels too, since ``double_conv`` is built for
      ``2 * out_c`` inputs.
  """
  def __init__(self, in_c, out_c):
    super().__init__()
    self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)
    self.conv = double_conv(2 * out_c, out_c)

  def forward(self, inputs, skip):
    upsampled = self.up(inputs)
    merged = torch.cat((upsampled, crop_img(skip, upsampled)), dim=1)
    return self.conv(merged)


In [6]:

class Unet(nn.Module):
  """U-Net built from unpadded convolutions (the original Ronneberger et al. layout).

  Architecture:
    - Four ``encoder`` stages contract the input (64 -> 128 -> 256 -> 512
      channels).
    - A ``double_conv`` bottleneck widens it to 1024 channels.
    - Four ``decoder`` stages expand it back to 64 channels, each merging the
      matching encoder skip connection.
    - A final 1x1 convolution maps to ``out_channels`` raw, unactivated
      scores per pixel.

  Because every 3x3 convolution is unpadded, the output is spatially smaller
  than the input. For example, a 572x572 input yields a 388x388 output.

  Args:
    in_channels: channels of the input image. Defaults to 1, the previously
      hard-coded value.
    out_channels: channels of the output map. Defaults to 1, the previously
      hard-coded value.
  """
  def __init__(self, in_channels=1, out_channels=1):
    super().__init__()
    # Attribute names are kept stable so existing state_dicts still load.
    self.e1 = encoder(in_channels, 64)
    self.e2 = encoder(64, 128)
    self.e3 = encoder(128, 256)
    self.e4 = encoder(256, 512)

    self.b = double_conv(512, 1024)

    self.d1 = decoder(1024, 512)
    self.d2 = decoder(512, 256)
    self.d3 = decoder(256, 128)
    self.d4 = decoder(128, 64)

    self.output = nn.Conv2d(64, out_channels, kernel_size=1)

  def forward(self, inputs):
    # Contracting path: s* are the full-resolution maps kept for the skips,
    # p* the pooled maps fed to the next stage.
    s1, p1 = self.e1(inputs)
    s2, p2 = self.e2(p1)
    s3, p3 = self.e3(p2)
    s4, p4 = self.e4(p3)

    b = self.b(p4)

    # Expanding path: consume the skips deepest-first.
    d1 = self.d1(b, s4)
    d2 = self.d2(d1, s3)
    d3 = self.d3(d2, s2)
    d4 = self.d4(d3, s1)

    return self.output(d4)


In [23]:
# Smoke test: push one random single-channel 572x572 image through an
# untrained U-Net and display the input and the predicted map.
img_tensor = torch.rand((1, 1, 572, 572))
transform = T.ToPILImage()

# Drop the batch dimension, giving (C, H, W), instead of hard-coding the size.
img1 = transform(img_tensor[0])
img1.show()

model = Unet()
# Inference mode: BatchNorm uses its running statistics.
model.eval()

# Call the module (not .forward) so any registered hooks run. no_grad
# because no backward pass is needed here.
with torch.no_grad():
  result = model(img_tensor)

# The final 1x1 conv has no activation, so the scores are unbounded.
# ToPILImage expects float images in [0, 1], so min-max rescale for display.
out = result[0]
out = (out - out.min()) / (out.max() - out.min()).clamp_min(1e-8)
img2 = transform(out)
img2.show()