<a href="https://colab.research.google.com/github/Eiko58/Hippocampus_segmentation/blob/main/hippocampus_seg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [1]:
import torch
import torch.nn as nn


In [23]:
def double_conv(in_c, out_c):
  conv = nn.Sequential(
      nn.Conv2d(in_c,out_c,kernel_size=3),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_c,out_c,kernel_size=3),
      nn.ReLU(inplace=True),
  )
  return conv

def crop_img(tensor, target_tensor):
  target_size = target_tensor.size()[2]
  tensor_size = tensor.size()[2]
  delta = tensor_size - target_size
  delta = delta //2
  return tensor[:,:,delta:tensor_size-delta,delta:tensor_size-delta]


In [21]:
class UNet(nn.Module):
  def __init__(self):
    super(UNet,self).__init__()
    self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
    self.down_conv1 = double_conv(1,64)
    self.down_conv2 = double_conv(64,128)
    self.down_conv3 = double_conv(128,256)
    self.down_conv4 = double_conv(256,512)
    self.down_conv5 = double_conv(512,1024)

    self.up_trans1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2,stride=2)
    self.up_conv1 = double_conv(1024,512)
    self.up_trans2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2,stride=2)
    self.up_conv2 = double_conv(512,256)
    self.up_trans3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2,stride=2)
    self.up_conv3 = double_conv(256,128)
    self.up_trans4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2,stride=2)
    self.up_conv4 = double_conv(128,64)

    self.out = nn.Conv2d(in_channels=64,out_channels=2,kernel_size=1)

  def forward(self,image):
    #encoder
    x1 = self.down_conv1(image)
    x2 = self.max_pool(x1)
    #print(x1.size())
    #print(x2.size())
    x3 = self.down_conv2(x2)
    x4 = self.max_pool(x3)
    #print(x3.size())
    #print(x4.size())
    x5 = self.down_conv3(x4)
    x6 = self.max_pool(x5)
    #print(x5.size())
    #print(x6.size())
    x7 = self.down_conv4(x6)
    x8 = self.max_pool(x7)
    #print(x7.size())
    #print(x8.size())
    x9 = self.down_conv5(x8)

    x = self.up_trans1(x9)
    y = crop_img(x7,x)
    #print(x9.size())
    #print(x.size())
    #print(y.size())
    x = self.up_conv1(torch.cat([x,y],1))

    #decoder
    x = self.up_trans2(x)
    y = crop_img(x5,x)
    x = self.up_conv2(torch.cat([x,y],1))

    x = self.up_trans3(x)
    y = crop_img(x3,x)
    x = self.up_conv3(torch.cat([x,y],1))

    x = self.up_trans4(x)
    y = crop_img(x1,x)
   
    x = self.up_conv4(torch.cat([x,y],1))
    x = self.out(x)

    return x



In [24]:
image = torch.rand((1,1,572,572))
model = UNet()
print(model(image))

torch.Size([1, 64, 568, 568])
torch.Size([1, 64, 284, 284])
torch.Size([1, 128, 280, 280])
torch.Size([1, 128, 140, 140])
torch.Size([1, 256, 136, 136])
torch.Size([1, 256, 68, 68])
torch.Size([1, 512, 64, 64])
torch.Size([1, 512, 32, 32])
torch.Size([1, 1024, 28, 28])
torch.Size([1, 512, 56, 56])
torch.Size([1, 512, 56, 56])
tensor([[[[ 0.0773,  0.0858,  0.0793,  ...,  0.0809,  0.0772,  0.0769],
          [ 0.0798,  0.0851,  0.0868,  ...,  0.0754,  0.0787,  0.0752],
          [ 0.0778,  0.0745,  0.0788,  ...,  0.0749,  0.0816,  0.0801],
          ...,
          [ 0.0793,  0.0721,  0.0755,  ...,  0.0742,  0.0830,  0.0759],
          [ 0.0745,  0.0743,  0.0732,  ...,  0.0754,  0.0764,  0.0748],
          [ 0.0809,  0.0789,  0.0741,  ...,  0.0819,  0.0749,  0.0729]],

         [[-0.0057, -0.0049,  0.0005,  ..., -0.0090, -0.0035, -0.0045],
          [-0.0044, -0.0048, -0.0054,  ..., -0.0043, -0.0084, -0.0022],
          [-0.0041, -0.0054, -0.0003,  ..., -0.0059, -0.0014, -0.0037],
       