<a href="https://colab.research.google.com/github/DanielBugelnig/U-Net/blob/arthur/u_net_updated_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [3]:
# Creating U-Net architecture

# page 4 of the paper: (https://arxiv.org/pdf/1505.04597.pdf)

# Network Architecture
# The network architecture is illustrated in Figure 1. It consists of a contracting
# path (left side) and an expansive path (right side). The contracting path follows
# the typical architecture of a convolutional network. It consists of the repeated
# application of two 3x3 convolutions (unpadded convolutions), each followed by
# a rectified linear unit (ReLU) and a 2x2 max pooling operation with stride 2
# for downsampling. At each downsampling step we double the number of feature
# channels. Every step in the expansive path consists of an upsampling of the
# feature map followed by a 2x2 convolution (“up-convolution”) that halves the
# number of feature channels, a concatenation with the correspondingly cropped
# feature map from the contracting path, and two 3x3 convolutions, each followed by a ReLU. The cropping is necessary due to the loss of border pixels in
# every convolution. At the final layer a 1x1 convolution is used to map each 64-
# component feature vector to the desired number of classes. In total the network
# has 23 convolutional layers.
# To allow a seamless tiling of the output segmentation map (see Figure 2), it
# is important to select the input tile size such that all 2x2 max-pooling operations
# are applied to a layer with an even x- and y-size.

# The reduced output size within a single tile (e.g., 388x388 for a 572x572 input) ensures that the predictions are based on full context,
# avoiding incomplete or invalid segmentations near the borders.




#pytorch libraries
import torch
import torch.nn as nn
import torch.nn.functional as F #for ReLu
import torchvision.transforms.functional as Trans
from torchinfo import summary

class UNet(nn.Module):
    """U-Net (Ronneberger et al., 2015) built from unpadded 3x3 convolutions.

    Every 3x3 convolution uses ``padding=0``, so the output map is smaller than
    the input: a 572x572 input yields a 388x388 output. ``forward`` returns raw
    per-pixel scores (logits); no sigmoid/softmax is applied.

    Args:
        input_number: number of input channels (1 for grayscale, 3 for RGB).
        output_number: number of channels produced by the final 1x1 convolution.
    """

    def __init__(self, input_number, output_number):
        super().__init__()
        self.input_number = input_number
        self.output_number = output_number

        # Encoder (contracting path). Shape comments are channels x H x W for a
        # 572x572 input, batch dimension omitted.
        self.conv1 = nn.Conv2d(self.input_number, 64, kernel_size=3, padding=0)  # 64x570x570
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=0)  # 64x568x568
        self.maxPool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 64x284x284

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=0)  # 128x282x282
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=0)  # 128x280x280
        self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 128x140x140

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=0)  # 256x138x138
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=0)  # 256x136x136
        self.maxPool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 256x68x68

        self.conv7 = nn.Conv2d(256, 512, kernel_size=3, padding=0)  # 512x66x66
        self.conv8 = nn.Conv2d(512, 512, kernel_size=3, padding=0)  # 512x64x64
        self.maxPool4 = nn.MaxPool2d(kernel_size=2, stride=2)  # 512x32x32

        self.conv9 = nn.Conv2d(512, 1024, kernel_size=3, padding=0)  # 1024x30x30
        self.conv10 = nn.Conv2d(1024, 1024, kernel_size=3, padding=0)  # 1024x28x28

        # Decoder (expansive path). Each 2x2 transposed convolution with stride 2
        # doubles H and W and halves the channel count:
        #   out = stride * (in - 1) + kernel_size - 2 * padding + output_padding
        #       = 2 * (28 - 1) + 2 - 0 + 0 = 56
        # The first 3x3 conv of each stage takes twice the channels because the
        # cropped encoder skip connection is concatenated in front.
        self.upconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)  # 512x56x56
        self.conv1b = nn.Conv2d(1024, 512, kernel_size=3, padding=0)  # 512x54x54
        self.conv2b = nn.Conv2d(512, 512, kernel_size=3, padding=0)  # 512x52x52

        self.upconv2 = nn.ConvTranspose2d(512, 256, 2, 2)  # 256x104x104
        self.conv3b = nn.Conv2d(512, 256, 3, padding=0)  # 256x102x102
        self.conv4b = nn.Conv2d(256, 256, 3, padding=0)  # 256x100x100

        self.upconv3 = nn.ConvTranspose2d(256, 128, 2, 2)  # 128x200x200
        self.conv5b = nn.Conv2d(256, 128, 3, padding=0)  # 128x198x198
        self.conv6b = nn.Conv2d(128, 128, 3, padding=0)  # 128x196x196

        self.upconv4 = nn.ConvTranspose2d(128, 64, 2, 2)  # 64x392x392
        self.conv7b = nn.Conv2d(128, 64, 3, padding=0)  # 64x390x390
        self.conv8b = nn.Conv2d(64, 64, 3, padding=0)  # 64x388x388
        self.final_conv = nn.Conv2d(64, self.output_number, kernel_size=1, padding=0)  # output_number x388x388

    def cropConcat(self, encoder, decoder):
        """Center-crop ``encoder`` to ``decoder``'s H x W and concatenate them.

        The cropped encoder features come first along the channel dimension
        (dim=1), followed by the decoder features.
        """
        _, _, height, width = decoder.shape
        cropped_enc = Trans.center_crop(encoder, [height, width])
        return torch.cat((cropped_enc, decoder), dim=1)

    def forward(self, x):
        """Return logits of shape (N, output_number, H', W'); 388x388 for a 572x572 input."""
        # Encoder: x1..x4 are the pre-pooling activations reused as skips.
        x = F.relu(self.conv1(x))
        x1 = F.relu(self.conv2(x))  # 64x568x568
        x = self.maxPool1(x1)

        x = F.relu(self.conv3(x))
        x2 = F.relu(self.conv4(x))  # 128x280x280
        x = self.maxPool2(x2)

        x = F.relu(self.conv5(x))
        x3 = F.relu(self.conv6(x))  # 256x136x136
        x = self.maxPool3(x3)

        x = F.relu(self.conv7(x))
        x4 = F.relu(self.conv8(x))  # 512x64x64
        x = self.maxPool4(x4)

        x = F.relu(self.conv9(x))
        x = F.relu(self.conv10(x))  # 1024x28x28

        # Decoder: upsample, concatenate the cropped skip, then two 3x3 convs.
        x = self.upconv1(x)  # 512x56x56
        x = self.cropConcat(x4, x)  # 1024x56x56
        x = F.relu(self.conv1b(x))
        x = F.relu(self.conv2b(x))

        x = self.upconv2(x)  # 256x104x104
        x = self.cropConcat(x3, x)  # 512x104x104
        x = F.relu(self.conv3b(x))
        x = F.relu(self.conv4b(x))

        x = self.upconv3(x)  # 128x200x200
        x = self.cropConcat(x2, x)  # 256x200x200
        x = F.relu(self.conv5b(x))
        x = F.relu(self.conv6b(x))

        x = self.upconv4(x)  # 64x392x392
        x = self.cropConcat(x1, x)  # 128x392x392
        x = F.relu(self.conv7b(x))
        x = F.relu(self.conv8b(x))
        # No sigmoid/softmax here: the raw logits are returned.
        return self.final_conv(x)


# Build a single-channel-in / single-channel-out U-Net and print its layer table.
model = UNet(input_number=1, output_number=1)
summary(model, input_size=(1, 1, 572, 572))  # batch x channels x height x width


Layer (type:depth-idx)                   Output Shape              Param #
UNet                                     [1, 1, 388, 388]          --
├─Conv2d: 1-1                            [1, 64, 570, 570]         640
├─Conv2d: 1-2                            [1, 64, 568, 568]         36,928
├─MaxPool2d: 1-3                         [1, 64, 284, 284]         --
├─Conv2d: 1-4                            [1, 128, 282, 282]        73,856
├─Conv2d: 1-5                            [1, 128, 280, 280]        147,584
├─MaxPool2d: 1-6                         [1, 128, 140, 140]        --
├─Conv2d: 1-7                            [1, 256, 138, 138]        295,168
├─Conv2d: 1-8                            [1, 256, 136, 136]        590,080
├─MaxPool2d: 1-9                         [1, 256, 68, 68]          --
├─Conv2d: 1-10                           [1, 512, 66, 66]          1,180,160
├─Conv2d: 1-11                           [1, 512, 64, 64]          2,359,808
├─MaxPool2d: 1-12                        [1, 51

In [None]:
import torch
from torch import optim, nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchvision.transforms.functional as Trans
from PIL import Image
import matplotlib.pyplot as plt
import random

# Ask PyTorch's CUDA caching allocator to release its unused cached memory.
torch.cuda.empty_cache()
# The triple-quoted string below is an earlier, disabled draft of the Dataset
# class (it applied a callable `transform` to both image and label). As a bare
# string expression it is evaluated and discarded, so it has no runtime effect;
# the active Dataset class is defined after it.
'''
#Loading data
#https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
class Dataset(Dataset):
  def __init__(self, image_path, label_path, transform=None):
    self.images = Image.open(image_path)
    self.labels = Image.open(label_path)
    self.transform = transform

  def __len__(self):
    return self.images.n_frames

  def __getitem__(self, idx):
    #find specific frame
    self.images.seek(idx)
    self.labels.seek(idx)
    #grayscale conversion (if necessary)
    image = self.images.convert("L")
    label = self.labels.convert("L")
    if self.transform:
      image = self.transform(image)
      label = self.transform(label)
    return image, label
'''
#Transforming dataset from Daniel
class Dataset(Dataset):
    """Paired image/label dataset backed by two multi-frame image files.

    Item ``idx`` is frame ``idx`` of the image file paired with frame ``idx`` of
    the label file. Both are converted to single-channel ("L") PIL images and
    returned as tensors via ``to_tensor``.

    Args:
        image_path: path of the multi-frame input image file.
        label_path: path of the multi-frame label file. The dataset length is
            taken from the image file only; it is assumed the label file has
            the same number of frames (TODO confirm).
        transform: on/off switch for augmentation. When truthy, random
            horizontal and vertical flips (each with probability 0.5) are
            applied identically to image and label. It is never called.
    """

    def __init__(self, image_path, label_path, transform=True):
        self.images = Image.open(image_path)
        self.labels = Image.open(label_path)
        self.transform = transform

    def __len__(self):
        """Return the number of frames in the image file."""
        return self.images.n_frames

    def __getitem__(self, idx):
        """Return ``(image, label)`` tensors for frame ``idx``."""
        self.images.seek(idx)
        self.labels.seek(idx)
        image = self.images.convert("L")
        label = self.labels.convert("L")

        # Check the instance flag; the imported ``transforms`` module is always
        # truthy and would make augmentation impossible to turn off.
        if self.transform:
            if random.random() > 0.5:
                image = Trans.hflip(image)
                label = Trans.hflip(label)
            if random.random() > 0.5:
                image = Trans.vflip(image)
                label = Trans.vflip(label)

        return Trans.to_tensor(image), Trans.to_tensor(label)

# ``Dataset`` never calls ``transform``; at most it uses it as an on/off switch
# for random flip augmentation. This Compose object is kept only so the
# module-level name ``transform`` stays defined.
transform = transforms.Compose([transforms.ToTensor()])

# Original (un-mirrored) stacks:
#train_dataset = Dataset("ISBI-2012-challenge/train-volume.tif", "ISBI-2012-challenge/train-labels.tif", transform=True)
#test_dataset = Dataset("ISBI-2012-challenge/test-volume.tif", "ISBI-2012-challenge/test-labels.tif", transform=False)

# Mirrored input stacks. Random flips are requested for the training split
# only; evaluation data should not be randomly augmented.
train_dataset = Dataset("ISBI-2012-challenge/train-mirror.tif", "ISBI-2012-challenge/train-labels.tif", transform=True)
test_dataset = Dataset("ISBI-2012-challenge/test-mirror.tif", "ISBI-2012-challenge/test-labels.tif", transform=False)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Training setup, following https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(1,1).to(device)
# SGD with high momentum (0.99; the U-Net paper also uses 0.99). Adam alternative:
#optimizer = optim.Adam(model.parameters(), 0.001)
optimizer = optim.SGD(model.parameters(), 0.0001, momentum=0.99)
# The model has a single output channel and the targets are float maps, so use
# BCEWithLogitsLoss. With one channel, the softmax inside nn.CrossEntropyLoss
# is constant (always 1), so that loss cannot learn anything here.
criterion = nn.BCEWithLogitsLoss()


def train(model, dataloader, criterion, optimizer, nrOfEpochs):
  """Train ``model`` for ``nrOfEpochs`` epochs; return last-epoch loss and metrics.

  For each batch: forward pass, center-crop the labels to the spatial size of
  the output (the unpadded U-Net produces a smaller map than its input),
  compute ``criterion``, back-propagate and take an optimizer step.

  Pixel-wise accuracy / precision / recall / F1 are computed per sample during
  the last epoch only. A pixel is predicted positive when its output is > 0,
  i.e. sigmoid(output) > 0.5. This assumes ``model`` returns raw logits, as
  ``nn.BCEWithLogitsLoss`` expects.

  Args:
    model: network to train; batches are moved to the device of its parameters.
    dataloader: iterable of ``(images, labels)`` batches that supports ``len()``.
    criterion: loss called as ``criterion(outputs, cropped_labels)``.
    optimizer: optimizer over ``model``'s parameters.
    nrOfEpochs: number of passes over ``dataloader``.

  Returns:
    ``(train_loss, accuracy, precision, recall, f1)``:
    - ``train_loss`` is the mean batch loss of the last epoch.
    - Each metric is the mean over all samples of the last epoch.
    - A metric whose denominator is zero counts as 0 instead of NaN.
    - All values are 0.0 if no epoch ran.
  """

  def _ratio(num, den):
    # Zero-denominator convention (e.g. no predicted positives) instead of NaN,
    # which would otherwise poison the averages.
    return num / den if den else 0.0

  device = next(model.parameters()).device
  model.train()
  train_loss = 0.0
  sums = [0.0, 0.0, 0.0, 0.0]  # accuracy, precision, recall, f1
  n_samples = 0
  for epoch in range(nrOfEpochs):
    is_last_epoch = epoch == nrOfEpochs - 1
    running_loss = 0.0
    for images, labels in dataloader:
      images, labels = images.to(device), labels.to(device)
      optimizer.zero_grad()
      outputs = model(images)
      # Crop to the actual output size rather than a hard-coded 388x388.
      labels = Trans.center_crop(labels, list(outputs.shape[-2:]))
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      running_loss += loss.item()
      if is_last_epoch:
        # Logit 0 is the sigmoid-0.5 decision boundary.
        preds = (outputs.detach() > 0).int()
        targets = labels.int()
        # Iterate the real batch (the final batch may be smaller than
        # dataloader.batch_size).
        for pred, target in zip(preds, targets):
          tp = (target * pred).sum().item()
          tn = ((1 - target) * (1 - pred)).sum().item()
          fp = ((1 - target) * pred).sum().item()
          fn = (target * (1 - pred)).sum().item()
          prec = _ratio(tp, tp + fp)
          rec = _ratio(tp, tp + fn)
          sums[0] += _ratio(tp + tn, tp + tn + fp + fn)
          sums[1] += prec
          sums[2] += rec
          sums[3] += _ratio(2 * prec * rec, prec + rec)
          n_samples += 1
    train_loss = running_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{nrOfEpochs}\nLoss:{train_loss}")
  # Average per sample, not per batch.
  n = max(n_samples, 1)
  return train_loss, sums[0] / n, sums[1] / n, sums[2] / n, sums[3] / n


#Evaluation
def test(model, dataloader, criterion, plot_index=6):
  """Evaluate ``model`` on ``dataloader``; print and return loss and metrics.

  Labels are center-cropped to the output's spatial size. A pixel is predicted
  positive when its output is > 0, i.e. sigmoid(output) > 0.5. This assumes
  ``model`` returns raw logits, as ``nn.BCEWithLogitsLoss`` expects.

  Args:
    model: network to evaluate; batches are moved to the device of its parameters.
    dataloader: iterable of ``(images, labels)`` batches that supports ``len()``.
    criterion: loss called as ``criterion(outputs, cropped_labels)``.
    plot_index: zero-based position (in iteration order) of the sample to show
      as image / ground truth / prediction / difference. ``None`` disables
      plotting. Defaults to 6, the previously hard-coded sample.

  Returns:
    ``(test_loss, accuracy, precision, recall, f1)``:
    - ``test_loss`` is the mean batch loss.
    - The metrics are per-sample means of the pixel-wise scores.
    - A metric whose denominator is zero counts as 0 instead of NaN.
  """

  def _ratio(num, den):
    # Zero-denominator convention (e.g. no predicted positives) instead of NaN.
    return num / den if den else 0.0

  device = next(model.parameters()).device
  model.eval()
  running_loss = 0.0
  sums = [0.0, 0.0, 0.0, 0.0]  # accuracy, precision, recall, f1
  n_samples = 0
  with torch.no_grad():
    for images, labels in dataloader:
      images, labels = images.to(device), labels.to(device)
      outputs = model(images)
      # Crop to the actual output size rather than a hard-coded 388x388.
      labels = Trans.center_crop(labels, list(outputs.shape[-2:]))
      running_loss += criterion(outputs, labels).item()
      # Logit 0 is the sigmoid-0.5 decision boundary.
      preds = (outputs > 0).int()
      targets = labels.int()
      # Iterate the real batch (the final batch may be smaller than
      # dataloader.batch_size).
      for j in range(preds.shape[0]):
        pred, target = preds[j], targets[j]
        if n_samples == plot_index:
          panels = (
            ("Image", images[j]),
            ("Ground Truth", target),
            ("Prediction", pred),
            ("Difference", (pred != target).int()),
          )
          plt.figure()
          for pos, (title, tensor) in enumerate(panels, start=1):
            plt.subplot(2, 2, pos)
            plt.title(title)
            plt.imshow(tensor.cpu().numpy().squeeze(), cmap='viridis')
          plt.tight_layout()
          plt.show()

        tp = (target * pred).sum().item()
        tn = ((1 - target) * (1 - pred)).sum().item()
        fp = ((1 - target) * pred).sum().item()
        fn = (target * (1 - pred)).sum().item()
        prec = _ratio(tp, tp + fp)
        rec = _ratio(tp, tp + fn)
        sums[0] += _ratio(tp + tn, tp + tn + fp + fn)
        sums[1] += prec
        sums[2] += rec
        sums[3] += _ratio(2 * prec * rec, prec + rec)
        n_samples += 1

  # Average per sample, not per batch.
  n = max(n_samples, 1)
  accuracy, precision, recall, f1 = (s / n for s in sums)
  print(f"Avg Accuracy:{accuracy}")
  print(f"Avg Precision:{precision}")
  print(f"Avg Recall:{recall}")
  print(f"Avg F1 Score:{f1}\n")

  test_loss = running_loss / max(len(dataloader), 1)
  print(f"Test loss:{test_loss}")
  return test_loss, accuracy, precision, recall, f1

# Training schedule: ``nrEpx10`` outer loops, each training for
# ``EPOCHS_PER_LOOP`` epochs and then evaluating on the test set.
nrEpx10 = 1
EPOCHS_PER_LOOP = 10

# One entry per outer loop; ``x`` holds the cumulative epoch count.
x = [0]*nrEpx10
trainLoss = [0]*nrEpx10
trainAcc = [0]*nrEpx10
trainPrec = [0]*nrEpx10
trainRec = [0]*nrEpx10
trainF1 = [0]*nrEpx10
testLoss = [0]*nrEpx10
testAcc = [0]*nrEpx10
testPrec = [0]*nrEpx10
testRec = [0]*nrEpx10
testF1 = [0]*nrEpx10

# To resume from saved weights, load the state dict into the model
# (torch.load on its own only returns the dict):
#model.load_state_dict(torch.load('unet_50ep_lr0001_mirr.pth', map_location=device))

for i in range(nrEpx10):
  print(f"\nLoop {i+1}\n")
  x[i] = EPOCHS_PER_LOOP*(i+1)
  trainLoss[i], trainAcc[i], trainPrec[i], trainRec[i], trainF1[i] = train(model, train_loader, criterion, optimizer, EPOCHS_PER_LOOP)
  testLoss[i], testAcc[i], testPrec[i], testRec[i], testF1[i] = test(model, test_loader, criterion)

# One panel per metric; the shared figure legend is taken from the last panel.
# NOTE(review): dpi=1200 renders a very large figure; lower it for interactive use.
fig = plt.figure(dpi=1200)
# A line through a single point draws nothing, so show markers in that case.
marker = 'o' if len(x) < 2 else None
panels = [
  (1, trainLoss, testLoss, 'Loss'),
  (2, trainAcc, testAcc, 'Accuracy'),
  (3, trainPrec, testPrec, 'Precision'),
  (4, trainRec, testRec, 'Recall'),
  (5, trainF1, testF1, 'F1'),
]
for position, train_series, test_series, metric_name in panels:
  plt.subplot(2, 3, position)
  plt.plot(x, train_series, marker=marker, label='Train')
  plt.plot(x, test_series, marker=marker, label='Test')
  plt.xlabel('Epochs')
  plt.ylabel(metric_name)
handles, labels = plt.gca().get_legend_handles_labels()
fig.legend(handles, labels, loc='lower right')
plt.tight_layout()
plt.show()


#Save model - change name to: unet _ nr of epochs _ learning rate 0.xxx as lrxxx _ which dataset
#torch.save(model.state_dict(), 'unet_5000ep_lr0001_mirr_transf_SGD.pth')



Loop 1

Epoch 1/10
Loss:0.7191595176855723
Epoch 2/10
Loss:0.7093304673830668
