<a href="https://colab.research.google.com/github/Eiko58/Hippocampus_segmentation/blob/main/Segnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torchvision import datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader as DataLoader
from torchvision import transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, Dataset, random_split
import glob
from google.colab import drive
import matplotlib.pyplot as plt
import os
import os.path
from pathlib import Path
import cv2
import copy
import torchvision.transforms.functional as TF
import random
import math
import pandas as pd
from torchvision.utils import save_image

In [2]:
# Mount Google Drive at /content/drive/ so the relative 'drive/MyDrive/...'
# paths used by the dataset and training code resolve from the Colab working directory.
drive.mount('/content/drive/')
# IPython shell escape (not plain Python): print the working directory
# (the recorded output shows /content).
!pwd

Mounted at /content/drive/
/content


In [3]:
class img_dataset(Dataset):
  """Paired (grayscale image, binary mask) dataset read from Google Drive.

  Images are read from ``drive/MyDrive/hippocampus/<kind>/Total/*.jpg`` and
  masks from ``drive/MyDrive/hippocampus/<kind>/label/*.jpg`` (relative to the
  current working directory), converted to grayscale and kept in memory.

  Args:
    kind: sub-folder to read, e.g. 'train', 'balanced_train', 'validation'.
    transforms: if truthy and ``kind`` is 'train' or 'balanced_train', apply
      random gamma / joint affine / blur augmentation in ``__getitem__``.
    center_crop: if truthy, center-crop image and mask to 150x150.

  Raises:
    FileNotFoundError: if OpenCV cannot read one of the files.
    ValueError: if the number of images and masks differ.
  """

  def __init__(self, kind='train', transforms=False, center_crop=False):
    super().__init__()
    self.center_crop = center_crop
    self.transforms = transforms
    self.kind = kind
    path_kind = 'drive/MyDrive/hippocampus/' + kind
    # glob.glob returns files in arbitrary (filesystem) order; sort both lists
    # so that image i and mask i are paired by file name.
    # NOTE(review): assumes images and masks have matching, sortable names — confirm.
    feature_files = sorted(glob.glob(path_kind + '/Total/*.jpg'))
    target_files = sorted(glob.glob(path_kind + '/label/*.jpg'))
    self.features = [self._read_gray(file) for file in feature_files]
    self.targets = [self._read_gray(file) for file in target_files]
    if len(self.features) != len(self.targets):
      raise ValueError(
        f"Something wrong with the dataset: {len(self.features)} images vs {len(self.targets)} labels")

  @staticmethod
  def _read_gray(path):
    """Read ``path`` with OpenCV and convert it from BGR to a 2-D grayscale array."""
    image = cv2.imread(path)
    if image is None:  # cv2.imread signals failure by returning None, not by raising
      raise FileNotFoundError(f'Could not read image: {path}')
    return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

  def __len__(self):
    return len(self.features)

  def __getitem__(self, index):
    """Return ``(image, mask)`` as float tensors of shape (1, H, W).

    The image is scaled to [0, 1]; the mask is scaled and rounded to {0, 1}.
    """
    feature, target = self.features[index], self.targets[index]
    feature_tensor = torch.unsqueeze(torch.tensor(feature) / 255, 0)
    # Masks are JPEGs, so compression can leave values between 0 and 255; round to {0, 1}.
    target_tensor = torch.unsqueeze(torch.round(torch.tensor(target) / 255), 0)
    if self.center_crop:
      feature_tensor = TF.center_crop(feature_tensor, 150)
      target_tensor = TF.center_crop(target_tensor, 150)
    if self.transforms and self.kind in ('train', 'balanced_train'):
      feature_tensor, target_tensor = self._augment(feature_tensor, target_tensor)
    return feature_tensor, target_tensor

  @staticmethod
  def _augment(feature_tensor, target_tensor):
    """Independently apply, each with probability 0.2: gamma, a joint affine warp, blur.

    Gamma and blur change only the image. The affine warp is applied to image
    and mask with identical parameters so the two stay aligned (shear/scale
    move pixels, so the mask must follow).
    """
    if random.uniform(0, 1) > 0.8:
      feature_tensor = TF.adjust_gamma(feature_tensor, random.uniform(0.5, 1.5))
    if random.uniform(0, 1) > 0.8:
      # Integer translation in [-2, 2] pixels, concentrated around 0.
      dx = np.random.binomial(4, 0.5) - 2
      dy = np.random.binomial(4, 0.5) - 2
      scale = random.uniform(0.9, 1.1)
      shear = random.uniform(-5, 5)
      feature_tensor = TF.affine(feature_tensor, angle=0, translate=[dx, dy], scale=scale, shear=shear)
      target_tensor = TF.affine(target_tensor, angle=0, translate=[dx, dy], scale=scale, shear=shear)
    if random.uniform(0, 1) > 0.8:
      feature_tensor = TF.gaussian_blur(feature_tensor, 3)
    return feature_tensor, target_tensor

In [4]:
class DoubleConv(nn.Module):
  """Two 3x3 same-padding convolutions, each followed by ReLU.

  With ``batchnorm`` a BatchNorm2d sits between each convolution and its ReLU.
  Stride 1 / padding 1 keep the spatial size unchanged.
  """

  def __init__(self, in_channels, out_channels, batchnorm=False):
    super(DoubleConv, self).__init__()
    self.batchnorm = batchnorm
    layers = []
    for conv_in in (in_channels, out_channels):
      layers.append(nn.Conv2d(conv_in, out_channels, kernel_size=3, stride=1, padding=1, bias=True))
      if self.batchnorm:
        layers.append(nn.BatchNorm2d(out_channels))
      layers.append(nn.ReLU(inplace=True))
    # A single nn.Sequential named `conv`, in the same layer order, keeps the
    # state_dict keys (conv.0.*, conv.1.*, ...) stable.
    self.conv = nn.Sequential(*layers)

  def forward(self, x):
    return self.conv(x)

def init_xavier(model, retrain_seed, layer_types=(nn.Linear,)):
  """Xavier-uniform initialise selected layers of ``model`` in place.

  Seeds torch's global RNG with ``retrain_seed`` first so the result is
  reproducible. Every submodule that is an instance of ``layer_types``, whose
  weight requires grad and whose bias (if it has one) requires grad, gets
  Xavier-uniform weights scaled by the ReLU gain and a zero bias.

  Args:
    model: ``nn.Module`` to initialise.
    retrain_seed: seed passed to ``torch.manual_seed``.
    layer_types: tuple of module classes to initialise. The default
      ``(nn.Linear,)`` keeps the original scope. The UNet/SegNet models in this
      notebook contain no ``nn.Linear``. Pass e.g.
      ``(nn.Conv2d, nn.ConvTranspose2d)`` to actually initialise them.
  """
  torch.manual_seed(retrain_seed)
  # calculate_gain expects lower-case names; 'ReLU' raises ValueError.
  gain = nn.init.calculate_gain('relu')

  def init_weights(m):
    if not isinstance(m, layer_types) or not m.weight.requires_grad:
      return
    if m.bias is not None and not m.bias.requires_grad:
      return
    nn.init.xavier_uniform_(m.weight, gain=gain)
    if m.bias is not None:
      with torch.no_grad():
        m.bias.zero_()

  model.apply(init_weights)

class UNet(nn.Module):
  """U-Net returning raw (pre-activation) ``out_channels`` maps.

  Args:
    in_channels: channels of the input image.
    out_channels: channels produced by the final 1x1 convolution.
    features: encoder channel widths, highest resolution first; the bottleneck
      uses ``2 * features[-1]`` channels. Any sequence is accepted and copied.
    batchnorm: use BatchNorm2d inside every DoubleConv.
    initialization: stored flag (``train`` reads it to decide on ``init_xavier``);
      also part of ``self.name``.
  """

  def __init__(self, in_channels=1, out_channels=1, features=(64, 128, 256, 512), batchnorm=False, initialization=False):
    super(UNet, self).__init__()
    features = tuple(features)  # immutable copy: no shared mutable default / caller aliasing
    self.batchnorm = batchnorm
    self.initialization = initialization
    self.name = f'UNet_batchnorm_{self.batchnorm}_initialization_{self.initialization}'
    self.downs = nn.ModuleList()
    self.ups = nn.ModuleList()
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    # Encoder: one DoubleConv per entry of `features`.
    for feature in features:
      self.downs.append(DoubleConv(in_channels, feature, batchnorm=self.batchnorm))
      in_channels = feature

    # Decoder: per level a 2x ConvTranspose2d, then a DoubleConv over the
    # concatenation with the skip connection (hence feature*2 input channels).
    for feature in reversed(features):
      self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
      self.ups.append(DoubleConv(feature * 2, feature, batchnorm=self.batchnorm))

    # Bottleneck
    self.bottleneck = DoubleConv(features[-1], features[-1] * 2, batchnorm=self.batchnorm)

    # Final 1x1 projection to out_channels.
    self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

  def forward(self, x):
    skip_connections = []
    for down in self.downs:
      x = down(x)
      skip_connections.append(x)  # highest resolution first
      x = self.pool(x)

    x = self.bottleneck(x)

    # self.ups alternates [ConvTranspose2d, DoubleConv, ConvTranspose2d, DoubleConv, ...].
    for up_conv, double_conv, skip in zip(self.ups[0::2], self.ups[1::2], reversed(skip_connections)):
      x = up_conv(x)
      if x.shape != skip.shape:
        # Odd sizes lose a row/column in max-pooling; resize to the skip's H x W.
        x = transforms.functional.resize(x, skip.shape[2:])
      x = double_conv(torch.cat((skip, x), dim=1))  # dim=1 is the channel dimension

    return self.final_conv(x)

In [5]:
s = nn.Sigmoid()
class diceloss(nn.Module):
  """Soft Dice loss on logits: ``1 - (2*sum(p*l) + smooth) / (sum(p) + sum(l) + smooth)``.

  ``outcome`` holds raw logits (a sigmoid is applied here) and ``label`` holds
  0/1 targets of the same shape. The sums run over all elements of the batch,
  giving one batch-level Dice score.

  Args:
    smooth: smoothing term added to BOTH numerator and denominator. With one
      shared value the loss stays in [0, 1], and predicting an empty mask for
      an empty label gives ~0 loss. Previously 1e-5 was added on top and 1e-8
      below, which could make the loss strongly negative.
  """

  def __init__(self, smooth=1e-5):
    super(diceloss, self).__init__()
    self.smooth = smooth

  def forward(self, outcome, label):
    probs = torch.sigmoid(outcome)
    intersection = (probs * label).sum()
    total = (probs + label).sum()
    return 1 - (2 * intersection + self.smooth) / (total + self.smooth)

In [6]:
def batch_loss(outputs, labels, batch_size, criterion):
  """Score each sample of a batch and bucket it by whether its label is empty.

  The first ``batch_size`` samples are scored individually with ``criterion``
  on the detached output, so this bookkeeping adds nothing to the graph.

  Returns:
    ``(empty, not_empty)``: lists of Python floats, in batch order, for samples
    whose label sums to zero and for all other samples respectively.
  """
  empty, not_empty = [], []
  for idx in range(batch_size):
    target = labels[idx]
    bucket = empty if torch.sum(target) == 0 else not_empty
    bucket.append(criterion(outputs[idx].detach(), target).item())
  return empty, not_empty

In [7]:
def _mean_or_nan(values):
  """Return ``np.mean(values)``, or nan for an empty list (no 'Mean of empty slice' warning)."""
  if len(values) == 0:
    return float('nan')
  return np.mean(np.array(values))


def _save_reference_batch(val_dataloader, path, batch_number=14):
  """Save images of one fixed validation batch to ``path``; return its inputs on ``device``.

  Walks ``val_dataloader`` up to its ``batch_number``-th batch (1-based; the
  last batch if the loader is shorter) and writes targets/inputs of samples 4 and 0.
  NOTE(review): the '_nonempty' / '_empty' file names assume sample 4 has a
  non-empty mask and sample 0 an empty one; that depends on the data and is not checked.
  """
  for count, (features, targets) in enumerate(val_dataloader, start=1):
    if count == batch_number:
      break
  features = features.to(device)
  save_image(targets[4], path + '/validation_target_nonempty.png')
  save_image(features[4], path + '/validation_feature_nonempty.png')
  save_image(targets[0], path + '/validation_target_empty.png')
  save_image(features[0], path + '/validation_feature_empty.png')
  return features


def train(p_trdata, transforms, model, criterion, epochs, seed, crop=False, balanced=False):
  """Train ``model`` and checkpoint the epoch with the lowest mean validation loss.

  Uses the module-level ``optimizer`` (already built on ``model``'s parameters)
  and ``device``.

  Args:
    p_trdata: fraction of the training set to keep (1 keeps all). The subset
      comes from ``random_split`` seeded with ``seed``.
    transforms: forwarded to ``img_dataset`` to enable training augmentation.
    model: network with ``name`` and ``initialization`` attributes. Its outputs
      go through a sigmoid before prediction previews are saved.
    criterion: loss called as ``criterion(outputs, labels)``.
    epochs: ``(start_epoch, end_epoch)``; runs ``range(start_epoch, end_epoch)``.
    seed: seed for the data subset and, if ``model.initialization``, ``init_xavier``.
    crop: center-crop images and masks to 150x150.
    balanced: read the 'balanced_train' split instead of 'train'.

  Writes to ``drive/MyDrive/models/<run name>/``:
    - ``model.pth``: the best state dict;
    - reference and per-epoch prediction PNGs;
    - ``losses.txt``: CSV of per-epoch mean losses.
  """
  # make dataloaders
  print('train')
  tr_kind = 'balanced_train' if balanced else 'train'
  tr_dataset = img_dataset(kind=tr_kind, transforms=transforms, center_crop=crop)
  if p_trdata != 1:
    keep = math.floor(len(tr_dataset) * p_trdata)
    tr_dataset, _ = random_split(tr_dataset, [keep, len(tr_dataset) - keep], generator=torch.Generator().manual_seed(seed))
  tr_dataloader = DataLoader(tr_dataset, batch_size=8, shuffle=True)
  print('val')
  val_dataset = img_dataset(kind='validation', transforms=False, center_crop=crop)
  val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False)

  print('initialization')
  if model.initialization:
    init_xavier(model, seed)
  s = nn.Sigmoid()  # for image saving

  # make folder (including missing parents)
  print('make folders')
  full_model_name = f'{model.name}_transforms_{transforms}_criterion_{criterion}_ptrdata_{p_trdata}_seed_{seed}_crop_{crop}_balanced_{balanced}'
  path = 'drive/MyDrive/models/' + full_model_name
  os.makedirs(path, exist_ok=True)

  # Per-epoch mean losses; key order is the CSV column order.
  history = {key: [] for key in ('tr_empty', 'tr_not_empty', 'tr_total', 'val_empty', 'val_not_empty', 'val_total')}
  min_val_loss = float("inf")
  start_epoch, end_epoch = epochs
  for epoch in range(start_epoch, end_epoch):
    print(epoch)

    # --- training pass ---
    model.train()
    empty_loss, nempty_loss, total_loss = [], [], []
    for inputs, labels in tr_dataloader:
      inputs = inputs.to(device)
      labels = labels.to(device)
      optimizer.zero_grad()
      outputs = model(inputs)
      loss = criterion(outputs, labels)
      # Per-sample losses split by empty / non-empty mask (monitoring only).
      el, nel = batch_loss(outputs, labels, outputs.shape[0], criterion)
      empty_loss.extend(el)
      nempty_loss.extend(nel)
      total_loss.append(loss.item())
      loss.backward()
      optimizer.step()
    history['tr_empty'].append(_mean_or_nan(empty_loss))
    history['tr_not_empty'].append(_mean_or_nan(nempty_loss))
    history['tr_total'].append(_mean_or_nan(total_loss))

    # --- validation pass ---
    model.eval()
    val_empty_loss, val_nempty_loss, val_total_loss = [], [], []
    with torch.no_grad():
      for inputs, labels in val_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        vel, vnel = batch_loss(outputs, labels, outputs.shape[0], criterion)
        val_empty_loss.extend(vel)
        val_nempty_loss.extend(vnel)
        val_total_loss.append(criterion(outputs, labels).item())
    mean_val_loss = _mean_or_nan(val_total_loss)
    history['val_empty'].append(round(_mean_or_nan(val_empty_loss), 5))
    history['val_not_empty'].append(round(_mean_or_nan(val_nempty_loss), 5))
    history['val_total'].append(round(mean_val_loss, 5))

    # Checkpoint when the (unrounded) mean validation loss improves.
    if mean_val_loss < min_val_loss:
      torch.save(model.state_dict(), path + '/model.pth')
      min_val_loss = mean_val_loss

    # Save a prediction preview on the same validation batch every epoch.
    if epoch == start_epoch:
      validation_features = _save_reference_batch(val_dataloader, path)
    with torch.no_grad():  # inference only: don't build an autograd graph
      predictions = s(model(validation_features))
    save_image(predictions[4], path + f'/validation_notempty_prediction_{epoch}.png')
    save_image(predictions[0], path + f'/validation_empty_prediction_{epoch}.png')

  # save training and validation errors as csv
  pd.DataFrame(history).to_csv(path + '/losses.txt')

In [8]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [13]:
# Experiment: UNet with BatchNorm, BCE-on-logits loss, full balanced training
# set, uncropped images, 20 epochs.
unet_run = dict(p_trdata=1.0, transforms=True, epochs=(0, 20), crop=False, balanced=True, seed=1234)
dice = diceloss()  # defined but not passed to train in this run
bce = nn.BCEWithLogitsLoss()
model = UNet(batchnorm=True, initialization=True)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.to(device)
train(model=model, criterion=bce, **unet_run)


train
val
initialization
make folders
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19


In [40]:


class Autoencoder(nn.Module):
  """Small convolutional encoder/decoder producing per-pixel values in (0, 1).

  Encoder: three unpadded 3x3 convolutions (1 -> 64 -> 128 -> 256). Each is
  followed by ReLU, then BatchNorm when ``batchnorm`` is set, then 2x2
  max-pooling. The decoders differ by branch:

  * ``batchnorm=True``: three unpadded 3x3 transposed convolutions
    (256 -> 128 -> 64 -> 1), with ReLU + BatchNorm after the first two and no
    upsampling.
  * ``batchnorm=False``: an extra 3x3 conv on the bottleneck. Each of it and
    the first two transposed convolutions is followed by ReLU and bilinear x2
    upsampling, then the last transposed convolution.

  The output passes through a sigmoid. Neither branch restores the input
  resolution: a 64x64 input gives 12x12 with batchnorm and 46x46 without.
  NOTE(review): the batchnorm decoder reuses the encoder's ``norm_2`` and
  ``norm_1`` modules, so their running statistics mix encoder and decoder activations.
  """

  def __init__(self, batchnorm=False, initialization=False):
    super(Autoencoder, self).__init__()
    self.batchnorm = batchnorm
    self.initialization = initialization
    self.name = f'Autoencoder_batchnorm_{self.batchnorm}_initialization_{self.initialization}'
    self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
    self.down_conv1 = nn.Conv2d(1, 64, kernel_size=3)
    self.norm_1 = nn.BatchNorm2d(64)
    self.down_conv2 = nn.Conv2d(64, 128, kernel_size=3)
    self.norm_2 = nn.BatchNorm2d(128)
    self.down_conv3 = nn.Conv2d(128, 256, kernel_size=3)
    self.norm_3 = nn.BatchNorm2d(256)
    self.up_conv1 = nn.ConvTranspose2d(256, 128, kernel_size=3)
    self.up_conv2 = nn.ConvTranspose2d(128, 64, kernel_size=3)
    self.up_conv3 = nn.ConvTranspose2d(64, 1, kernel_size=3)
    self.out = nn.Sigmoid()
    self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
    self.conv = nn.Conv2d(256, 256, kernel_size=3)

  def forward(self, image):
    # Encoder (shared by both branches; BatchNorm only when enabled).
    x = image
    stages = (
      (self.down_conv1, self.norm_1),
      (self.down_conv2, self.norm_2),
      (self.down_conv3, self.norm_3),
    )
    for conv, norm in stages:
      x = F.relu(conv(x))
      if self.batchnorm:
        x = norm(x)
      x = self.max_pool(x)

    # Decoder
    if self.batchnorm:
      x = self.norm_2(F.relu(self.up_conv1(x)))
      x = self.norm_1(F.relu(self.up_conv2(x)))
      x = self.up_conv3(x)
    else:
      x = self.upsample(F.relu(self.conv(x)))
      x = self.upsample(F.relu(self.up_conv1(x)))
      x = self.upsample(F.relu(self.up_conv2(x)))
      x = self.up_conv3(x)
    return self.out(x)


In [72]:
from torch.nn.modules import padding
class SegNet(nn.Module):
  """SegNet-style encoder/decoder for single-channel input; returns raw logits.

  Encoder: five stages of 3x3 same-padding convolutions (2, 2, 3, 3, 3 convs;
  64, 128, 256, 512, 512 channels). Each stage ends in 2x2 max-pooling whose
  indices are kept. Decoder: the mirror image, max-unpooling with the stored
  indices back to the exact pre-pool size, so odd spatial sizes round-trip. It
  ends in a 3x3 conv to one channel with no activation.

  ``batchnorm`` is checked in ``forward``. The BatchNorm2d modules always exist,
  so state_dict keys do not depend on the flag, but they are applied only when
  it is true. ``BNDe11`` is not used by ``forward``. ``initialization`` is only
  stored here (``train`` reads it).
  """

  def __init__(self, batchnorm = False, initialization = False):
    super(SegNet, self).__init__()
    self.batchnorm = batchnorm
    self.initialization = initialization
    self.name = f'SegNet_batchnorm_{self.batchnorm}_initialization_{self.initialization}'
    self.MaxEn = nn.MaxPool2d(2,2,return_indices=True)
    self.ConvEn11 = nn.Conv2d(1,64,kernel_size=3,padding=1)
    self.BNEn11 = nn.BatchNorm2d(64,momentum=0.5)
    self.ConvEn12 = nn.Conv2d(64,64,kernel_size=3, padding=1)
    self.BNEn12 = nn.BatchNorm2d(64,momentum=0.5)

    self.ConvEn21 = nn.Conv2d(64,128,kernel_size=3,padding=1)
    self.BNEn21 = nn.BatchNorm2d(128,momentum=0.5)
    self.ConvEn22 = nn.Conv2d(128,128,kernel_size=3, padding=1)
    self.BNEn22 = nn.BatchNorm2d(128,momentum=0.5)

    self.ConvEn31 = nn.Conv2d(128,256,kernel_size=3,padding=1)
    self.BNEn31 = nn.BatchNorm2d(256,momentum=0.5)
    self.ConvEn32 = nn.Conv2d(256,256,kernel_size=3, padding=1)
    self.BNEn32 = nn.BatchNorm2d(256,momentum=0.5)
    self.ConvEn33 = nn.Conv2d(256,256,kernel_size=3,padding=1)
    self.BNEn33 = nn.BatchNorm2d(256,momentum=0.5)

    self.ConvEn41 = nn.Conv2d(256,512,kernel_size=3,padding=1)
    self.BNEn41 = nn.BatchNorm2d(512,momentum=0.5)
    self.ConvEn42 = nn.Conv2d(512,512,kernel_size=3, padding=1)
    self.BNEn42 = nn.BatchNorm2d(512,momentum=0.5)
    self.ConvEn43 = nn.Conv2d(512,512,kernel_size=3,padding=1)
    self.BNEn43 = nn.BatchNorm2d(512,momentum=0.5)

    self.ConvEn51 = nn.Conv2d(512,512,kernel_size=3,padding=1)
    self.BNEn51 = nn.BatchNorm2d(512,momentum=0.5)
    self.ConvEn52 = nn.Conv2d(512,512,kernel_size=3, padding=1)
    self.BNEn52 = nn.BatchNorm2d(512,momentum=0.5)
    self.ConvEn53 = nn.Conv2d(512,512,kernel_size=3,padding=1)
    self.BNEn53 = nn.BatchNorm2d(512,momentum=0.5)

    #Decoding
    self.MaxDe = nn.MaxUnpool2d(2,2)

    self.ConvDe53 = nn.Conv2d(512,512,kernel_size=3,padding=1)
    self.BNDe53 = nn.BatchNorm2d(512,momentum=0.5)
    self.ConvDe52 = nn.Conv2d(512,512,kernel_size=3, padding=1)
    self.BNDe52 = nn.BatchNorm2d(512,momentum=0.5)
    self.ConvDe51 = nn.Conv2d(512,512,kernel_size=3,padding=1)
    self.BNDe51 = nn.BatchNorm2d(512,momentum=0.5)

    self.ConvDe43 = nn.Conv2d(512,512,kernel_size=3,padding=1)
    self.BNDe43 = nn.BatchNorm2d(512,momentum=0.5)
    self.ConvDe42 = nn.Conv2d(512,512,kernel_size=3, padding=1)
    self.BNDe42 = nn.BatchNorm2d(512,momentum=0.5)
    self.ConvDe41 = nn.Conv2d(512,256,kernel_size=3,padding=1)
    self.BNDe41 = nn.BatchNorm2d(256,momentum=0.5)

    self.ConvDe33 = nn.Conv2d(256,256,kernel_size=3,padding=1)
    self.BNDe33 = nn.BatchNorm2d(256,momentum=0.5)
    self.ConvDe32 = nn.Conv2d(256,256,kernel_size=3, padding=1)
    self.BNDe32 = nn.BatchNorm2d(256,momentum=0.5)
    self.ConvDe31 = nn.Conv2d(256,128,kernel_size=3,padding=1)
    self.BNDe31 = nn.BatchNorm2d(128,momentum=0.5)

    self.ConvDe22 = nn.Conv2d(128,128,kernel_size=3, padding=1)
    self.BNDe22 = nn.BatchNorm2d(128,momentum=0.5)
    self.ConvDe21 = nn.Conv2d(128,64,kernel_size=3,padding=1)
    self.BNDe21 = nn.BatchNorm2d(64,momentum=0.5)

    self.ConvDe12 = nn.Conv2d(64,64,kernel_size=3, padding=1)
    self.BNDe12 = nn.BatchNorm2d(64,momentum=0.5)
    self.ConvDe11 = nn.Conv2d(64,1,kernel_size=3,padding=1)
    self.BNDe11 = nn.BatchNorm2d(1,momentum=0.5)

  def _conv_bn_relu(self, x, conv, bn):
    """Apply ``conv``, then ``bn`` if ``self.batchnorm``, then ReLU."""
    x = conv(x)
    if self.batchnorm:
      x = bn(x)
    return F.relu(x)

  def forward(self, x):
    encoder = (
      ((self.ConvEn11, self.BNEn11), (self.ConvEn12, self.BNEn12)),
      ((self.ConvEn21, self.BNEn21), (self.ConvEn22, self.BNEn22)),
      ((self.ConvEn31, self.BNEn31), (self.ConvEn32, self.BNEn32), (self.ConvEn33, self.BNEn33)),
      ((self.ConvEn41, self.BNEn41), (self.ConvEn42, self.BNEn42), (self.ConvEn43, self.BNEn43)),
      ((self.ConvEn51, self.BNEn51), (self.ConvEn52, self.BNEn52), (self.ConvEn53, self.BNEn53)),
    )
    decoder = (
      ((self.ConvDe53, self.BNDe53), (self.ConvDe52, self.BNDe52), (self.ConvDe51, self.BNDe51)),
      ((self.ConvDe43, self.BNDe43), (self.ConvDe42, self.BNDe42), (self.ConvDe41, self.BNDe41)),
      ((self.ConvDe33, self.BNDe33), (self.ConvDe32, self.BNDe32), (self.ConvDe31, self.BNDe31)),
      ((self.ConvDe22, self.BNDe22), (self.ConvDe21, self.BNDe21)),
      ((self.ConvDe12, self.BNDe12),),
    )

    # Encoder: remember each stage's pre-pool size and pooling indices.
    pre_pool_sizes, pool_indices = [], []
    for stage in encoder:
      for conv, bn in stage:
        x = self._conv_bn_relu(x, conv, bn)
      pre_pool_sizes.append(x.size())
      x, indices = self.MaxEn(x)
      pool_indices.append(indices)

    # Decoder: unpool to the exact pre-pool size (including the input size for
    # the last stage, so odd H/W are restored), then the stage's convolutions.
    for stage, indices, size in zip(decoder, reversed(pool_indices), reversed(pre_pool_sizes)):
      x = self.MaxDe(x, indices, output_size=size)
      for conv, bn in stage:
        x = self._conv_bn_relu(x, conv, bn)

    # Final layer without BatchNorm/ReLU: raw logits.
    return self.ConvDe11(x)


In [74]:
# Shape smoke test: push one random single-channel 232x196 image through an
# untrained SegNet. The image is drawn before the model is built, keeping the RNG sequence.
image = torch.rand(1, 1, 232, 196)
print(image)
model = SegNet()
model(image)

tensor([[[[0.0583, 0.7006, 0.0518,  ..., 0.4546, 0.3720, 0.8920],
          [0.3819, 0.8610, 0.2775,  ..., 0.0515, 0.3947, 0.6595],
          [0.4078, 0.3445, 0.0616,  ..., 0.6648, 0.4819, 0.5768],
          ...,
          [0.2027, 0.6378, 0.6171,  ..., 0.5561, 0.0702, 0.9071],
          [0.2896, 0.4228, 0.0903,  ..., 0.7746, 0.5915, 0.4330],
          [0.1411, 0.4396, 0.0681,  ..., 0.7043, 0.6479, 0.4425]]]])




tensor([[[[-0.0511, -0.0540, -0.0558,  ..., -0.0561, -0.0553, -0.0514],
          [-0.0523, -0.0545, -0.0560,  ..., -0.0562, -0.0552, -0.0483],
          [-0.0505, -0.0525, -0.0558,  ..., -0.0565, -0.0554, -0.0490],
          ...,
          [-0.0507, -0.0555, -0.0577,  ..., -0.0547, -0.0563, -0.0485],
          [-0.0508, -0.0549, -0.0536,  ..., -0.0557, -0.0564, -0.0498],
          [-0.0433, -0.0470, -0.0468,  ..., -0.0457, -0.0457, -0.0436]]]],
       grad_fn=<ConvolutionBackward0>)

In [None]:
# Experiment: SegNet with BatchNorm, BCE-on-logits loss, half of the balanced
# training set, 150x150 center crops, 40 epochs.
segnet_run = dict(p_trdata=0.5, transforms=True, epochs=(0, 40), crop=True, balanced=True, seed=1234)
dice = diceloss()  # defined but not passed to train in this run
bce = nn.BCEWithLogitsLoss()
model = SegNet(batchnorm=True, initialization=True)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.to(device)
train(model=model, criterion=bce, **segnet_run)


train
val
initialization
make folders
0




1
2
3
4
5
