<a href="https://colab.research.google.com/github/NatRoj/AML-AVGAN-22/blob/main/Main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
!pip install SimpleITK



In [11]:
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Load the image volume and its vessel mask. Both paths are relative, so the
# files must sit in the current working directory (they were originally taken
# from a shared OneDrive folder).
imfn = "image1.nii.gz"  # image filename
vssmskfn = "image1-vessels.nii.gz"  # vessel mask filename

im = sitk.ReadImage(imfn)
vssmsk = sitk.ReadImage(vssmskfn)

In [40]:
# Convert both volumes to NumPy. GetArrayFromImage returns axes in (z, y, x)
# order, so indexing axis 0 selects one 2-D slice of the volume.
SLICE_IDX = 300  # 0-based index along axis 0, i.e. the 301st slice

imnd = sitk.GetArrayFromImage(im)
vsnd = sitk.GetArrayFromImage(vssmsk)
# The mask is overlaid slice-by-slice on the image, so their grids must agree.
assert imnd.shape == vsnd.shape, "image and vessel mask must have the same dimensions"

fig, ax = plt.subplots(1, 2, figsize=(8, 8))
ax[0].imshow(imnd[SLICE_IDX], cmap="gray")
ax[0].set_title(f"Image, slice {SLICE_IDX}")
ax[1].imshow(vsnd[SLICE_IDX], cmap="gray")
ax[1].set_title(f"Vessel mask, slice {SLICE_IDX}")
plt.show()

NameError: ignored

In [None]:
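# Snippets that may be useful later (not executed):
# Binarize the vessel mask: sitk.Threshold keeps voxels inside [lower, upper] and sets all
# others to outsideValue, so here any label > 1 becomes 1. TODO: confirm the mask's label values.
# vssmskbin = sitk.Threshold(vssmsk, lower = 0, upper = 1, outsideValue = 1) # vessel mask made binary
# Mask the image: voxels where the binary mask equals maskingValue (0) are set to outsideValue (-1024).
# immasked = sitk.Mask(im, vssmskbin, maskingValue = 0, outsideValue = -1024) # image masked
# Save an image to disk; `outfilename` must be defined first.
# sitk.WriteImage(im,outfilename)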

In [29]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import torchvision.transforms as T

In [18]:
class Block(nn.Module):
  """Two stacked unpadded 3x3 convolutions, each followed by ReLU.

  The convolutions use PyTorch's default padding of 0, so each one trims one
  pixel from every border: the output is 4 pixels smaller than the input in
  each spatial dimension (e.g. 572 -> 568).

  Args:
    inCh: number of input channels.
    outCh: number of output channels (shared by both convolutions).
  """

  def __init__(self, inCh, outCh):
    super().__init__()
    # Attribute names stay conv1/conv2 so state_dict keys are unchanged.
    self.conv1 = nn.Conv2d(inCh, outCh, kernel_size=3)
    self.conv2 = nn.Conv2d(outCh, outCh, kernel_size=3)

  def forward(self, x):
    """Apply conv1 -> ReLU -> conv2 -> ReLU and return the activations."""
    hidden = F.relu(self.conv1(x))
    return F.relu(self.conv2(hidden))

In [19]:
# Sanity check: one Block on a 572x572 input. Each unpadded 3x3 conv trims
# 2 pixels, so the spatial size drops by 4. no_grad avoids building an
# autograd graph (and keeping activations) for a pure shape check.
enc_block = Block(1,64)
x = torch.randn(1,1,572,572)
with torch.no_grad():
  enc_block_out = enc_block(x)
assert enc_block_out.shape == (1, 64, 568, 568)
enc_block_out.shape

torch.Size([1, 64, 568, 568])

In [27]:
class Enc(nn.Module):
  """U-Net contracting path: Blocks separated by 2x2 max-pooling.

  Args:
    channels: channel progression; Block i maps channels[i] -> channels[i+1].
      The default expects a 3-channel input.

  forward(x) returns a list with every Block's output, highest resolution
  first, for use as decoder skip connections. The last entry is the deepest
  (bottleneck) feature.
  """

  def __init__(self,channels=(3,64,128,256,512,1024)):
    super().__init__()

    self.encBlocks = nn.ModuleList(
        [Block(cin, cout) for cin, cout in zip(channels[:-1], channels[1:])]
    )

    self.pool = nn.MaxPool2d(kernel_size=2)

  def forward(self,x):
    trans = []
    for idx, block in enumerate(self.encBlocks):
      # Downsample only *between* blocks: pooling after the last block
      # produced a tensor that was never used.
      if idx > 0:
        x = self.pool(x)
      x = block(x)
      trans.append(x)

    return trans

In [28]:
# Sanity check: encoder feature maps for a 3x572x572 input, highest resolution
# first. `trans` is reused by the decoder check below. no_grad avoids keeping
# activations for backprop during a pure shape check.
enc = Enc()
x = torch.randn(1,3,572,572)
with torch.no_grad():
  trans = enc(x)
for feat in trans:
  print(feat.shape)
assert [feat.shape[1] for feat in trans] == [64, 128, 256, 512, 1024]

torch.Size([1, 64, 568, 568])
torch.Size([1, 128, 280, 280])
torch.Size([1, 256, 136, 136])
torch.Size([1, 512, 64, 64])
torch.Size([1, 1024, 28, 28])


In [33]:
class Dec(nn.Module):
  """U-Net expanding path: upsample, concatenate a cropped skip feature, run a Block.

  Args:
    channels: channel progression, deepest first. Stage i upsamples
      channels[i] -> channels[i+1] with a stride-2 transposed conv and
      concatenates the matching skip feature along dim 1. The skip feature is
      expected to carry channels[i+1] channels, so the concatenation has
      channels[i] channels. Block(channels[i], channels[i+1]) is then applied.
  """

  def __init__(self,channels=(1024,512,256,128,64)):
    super().__init__()
    self.channels = channels
    stage_pairs = [(channels[k], channels[k + 1]) for k in range(len(channels) - 1)]
    # All upsampling layers are created before all Blocks.
    self.upconv = nn.ModuleList(
        [nn.ConvTranspose2d(cin, cout, kernel_size=2, stride=2) for cin, cout in stage_pairs]
    )
    self.decBlocks = nn.ModuleList([Block(cin, cout) for cin, cout in stage_pairs])

  def forward(self,x,encFeat):
    """Run every decoder stage; encFeat[i] is the skip feature for stage i (deepest first)."""
    for stage in range(len(self.channels) - 1):
      upsampled = self.upconv[stage](x)
      skip = self.crop(encFeat[stage], upsampled)
      x = self.decBlocks[stage](torch.cat([upsampled, skip], dim=1))
    return x

  def crop(self,encTrans,x):
    """Center-crop encTrans to the spatial size (H, W) of x; x is unpacked as (N, C, H, W)."""
    _, _, height, width = x.shape
    return T.CenterCrop([height, width])(encTrans)

In [34]:
# Sanity check: decode a random 1024-channel bottleneck. The skip connections
# are the encoder features `trans` from the encoder check above, deepest first
# with the bottleneck excluded. no_grad avoids keeping activations for backprop.
dec = Dec()
x = torch.randn(1,1024,28,28)
with torch.no_grad():
  dec_out = dec(x, trans[::-1][1:])
assert dec_out.shape == (1, 64, 388, 388)
dec_out.shape

torch.Size([1, 64, 388, 388])

In [38]:
class UNet(nn.Module):
  """U-Net: Enc -> Dec with skip connections -> 1x1 conv head.

  Args:
    encChannels: channel progression passed to Enc.
    decChannels: channel progression passed to Dec (deepest first).
    numClass: number of output channels produced by the 1x1 head.
    retainDim: if True, resize the output to outSize with F.interpolate;
      otherwise return the decoder's native size. Because the convolutions are
      unpadded, that native size is smaller than the input.
    outSize: target (H, W) used when retainDim is True.
  """

  def __init__(
      self,
      encChannels = (3,64,128,256,512,1024),
      decChannels = (1024,512,256,128,64),
      numClass = 1,
      retainDim = False,
      outSize = (572,572)
  ):
    super().__init__()

    self.enc = Enc(encChannels)
    self.dec = Dec(decChannels)
    self.head = nn.Conv2d(decChannels[-1],numClass,1)
    self.retainDim = retainDim
    self.outSize = outSize

  def forward(self,x):
    encTrans = self.enc(x)
    # The deepest encoder feature feeds the decoder; the remaining features,
    # deepest first, are its skip connections.
    deepest_first = encTrans[::-1]
    out = self.dec(deepest_first[0], deepest_first[1:])
    out = self.head(out)

    if self.retainDim:
      # Resize to outSize using F.interpolate's default mode (nearest).
      out = F.interpolate(out, self.outSize)
    return out

In [39]:
# Sanity check: full U-Net on a 3x572x572 input. With retainDim=False the
# output is 388x388, smaller than the input because of the unpadded
# convolutions. no_grad avoids keeping activations for backprop.
unet = UNet()
x = torch.randn(1,3,572,572)
with torch.no_grad():
  unet_out = unet(x)
assert unet_out.shape == (1, 1, 388, 388)
unet_out.shape

torch.Size([1, 1, 388, 388])