In [25]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import cv2
import os
import matplotlib.pyplot as plt
import nibabel as nlb
import numpy as np
from tqdm import tqdm

In [26]:
device = 'cufa' if torch.cuda.is_available() else 'cpu'
device

'cpu'

In [18]:
class Unet(nn.Module):
    def __init__(self, img_channel, classnum):
        super().__init__()
        self.img_channel = img_channel
        self.classnum = classnum
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.downSample1 = self._convBlock(64, firstDown=True)
        self.downSample2 = self._convBlock(128)
        self.downSample3 = self._convBlock(256)
        self.downSample4 = self._convBlock(512)
        self.upsammple1 = self._convBlockTranspos(1024,firstUp=True)
        self.upsammple2 = self._convBlockTranspos(512)
        self.upsammple3 = self._convBlockTranspos(256)
        self.upsammple4 = self._convBlockTranspos(128)
        self.lastBlock = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            nn.Conv2d(64, 64, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.Conv2d(64, self.classnum, kernel_size=1, padding='same')
        )
        
    def forward(self, x:torch.tensor):
        out1 = self.downSample1(x)
        out = self.pool(out1)
        out2 = self.downSample2(out)
        out = self.pool(out2)
        out3 = self.downSample3(out)
        out = self.pool(out3)
        out4 = self.downSample4(out)
        out = self.pool(out4)

        out = self.upsammple1(out)
        out = torch.concat([out, out4], dim=1)
        out = self.upsammple2(out)
        out = torch.concat([out, out3], dim=1)
        out = self.upsammple3(out)
        out = torch.concat([out, out2], dim=1)
        out = self.upsammple4(out)
        out = torch.concat([out, out1], dim=1)
        out = self.lastBlock(out)

        return out


    def _convBlock(self, in_channel, firstDown=False):
        if firstDown:
            block = nn.Sequential(
                nn.Conv2d(self.img_channel, in_channel, kernel_size=3, padding='same'),
                nn.ReLU(),
                nn.Dropout2d(0.2),
                nn.Conv2d(in_channel, in_channel, kernel_size=3, padding='same'),
                nn.ReLU()
            )
        else:
            block = nn.Sequential(
            nn.Conv2d(in_channel//2, in_channel, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            nn.Conv2d(in_channel, in_channel, kernel_size=3, padding='same'),
            nn.ReLU()
        )
        
        return block
    
    def _convBlockTranspos(self, in_channel, firstUp=False):
        if firstUp:
            block = nn.Sequential(
                nn.Conv2d(in_channel//2, in_channel, kernel_size=3, padding='same'),
                nn.BatchNorm2d(in_channel),
                nn.ReLU(inplace=True),
                nn.Dropout2d(0.2),
                nn.Conv2d(in_channel, in_channel, kernel_size=3, padding='same'),
                nn.BatchNorm2d(in_channel),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(in_channel, in_channel//2, kernel_size=2, stride=2),
            )
        else:
            block = nn.Sequential(
                nn.Conv2d(in_channel*2, in_channel, kernel_size=3, padding='same'),
                nn.BatchNorm2d(in_channel),
                nn.ReLU(inplace=True),
                nn.Dropout2d(0.2),
                nn.Conv2d(in_channel, in_channel, kernel_size=3, padding='same'),
                nn.BatchNorm2d(in_channel),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(in_channel, in_channel//2, kernel_size=2, stride=2),
            )

        return block

In [22]:
x = torch.rand((1, 3, 240, 240))
model = Unet(3, 2)
model(x)

tensor([[[[-0.0617, -0.0757, -0.0528,  ..., -0.0860, -0.0541, -0.0735],
          [-0.0495, -0.0856, -0.0670,  ..., -0.0738, -0.0806, -0.0647],
          [-0.0672, -0.0759, -0.0627,  ..., -0.0750, -0.0882, -0.0790],
          ...,
          [-0.0637, -0.0973, -0.0790,  ..., -0.0943, -0.0674, -0.0699],
          [-0.0596, -0.0802, -0.0707,  ..., -0.0776, -0.0656, -0.0657],
          [-0.0578, -0.0540, -0.0703,  ..., -0.0686, -0.0626, -0.0585]],

         [[-0.0628, -0.0735, -0.0776,  ..., -0.0707, -0.0626, -0.0617],
          [-0.0613, -0.0749, -0.0649,  ..., -0.0725, -0.0562, -0.0772],
          [-0.0606, -0.0558, -0.0809,  ..., -0.0566, -0.0543, -0.0631],
          ...,
          [-0.0674, -0.0619, -0.0868,  ..., -0.0747, -0.0471, -0.0549],
          [-0.0722, -0.0688, -0.0691,  ..., -0.0881, -0.0855, -0.0565],
          [-0.0491, -0.0717, -0.0690,  ..., -0.0572, -0.0740, -0.0614]]]],
       grad_fn=<ConvolutionBackward1>)