In [2]:
# Packages
import numpy as np
import os
import torch
from torchvision import models
from torchvision import datasets
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt


In [None]:
# Data loading
dataPath = '/blue/eel6935/guzikjar/HW2/LungImages'
maskPath = 'LungMasks'

In [3]:
# Classes

class downNet1(nn.Module):
    def __init__(self):
        super(downNet1, self).__init__()

        # First 2D convolutional layer, taking in 1 input channel (image),
        # outputting 64 convolutional features, with a square kernel size of 3
        self.conv1 = nn.Conv2d(1, 64, 3, 1)

        # Activation Function
        self.act1 = nn.ReLU()

        # Second 2D convolutional layer, taking in the 64 input layers,
        # outputting 64 convolutional features, with a square kernel size of 3
        self.conv2 = nn.Conv2d(64, 64, 3, 1)

        # Activation Function
        self.act2 = nn.ReLU()
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        output = x
        return output

    
class downNet2(nn.Module):
    def __init__(self):
        super(downNet2, self).__init__()
        
        self.conv1 = nn.Conv2d(64, 128, 3, 1)
        
        self.act1 = nn.ReLU()
        
        self.conv2 = nn.Conv2d(128, 128, 3, 1)
        
        self.act2 = nn.ReLU()
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        output = x
        return output
        
        
class downNet3(nn.Module):
    def __init__(self):
        super(downNet3, self).__init__()
        
        self.conv1 = nn.Conv2d(128, 256, 3, 1)
        
        self.act1 = nn.ReLU()
        
        self.conv2 = nn.Conv2d(256, 256, 3, 1)
        
        self.act2 = nn.ReLU()
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        output = x
        return output
        
        
class downNet4(nn.Module):
    def __init__(self):
        super(downNet4, self).__init__()
        
        self.conv1 = nn.Conv2d(256, 512, 3, 1)
        
        self.act1 = nn.ReLU()
        
        self.conv2 = nn.Conv2d(512, 512, 3, 1)
        
        self.act2 = nn.ReLU()
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        output = x
        return output
        
class upNet1(nn.Module):
    def __init__(self):
        super(upNet1, self).__init__()

        self.conv1 = nn.Conv2d(1024, 512, 3, 1)

        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv2d(512, 512, 3, 1)

        self.act2 = nn.ReLU()
        
        self.upSamp = nn.ConvTranspose2d(512, 256, 2, 2)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        x = self.upSamp(x)
        output = x
        return output

    
class upNet2(nn.Module):
    def __init__(self):
        super(upNet2, self).__init__()
        
        self.conv1 = nn.Conv2d(512, 256, 3, 1)
        
        self.act1 = nn.ReLU()
        
        self.conv2 = nn.Conv2d(256, 256, 3, 1)
        
        self.act2 = nn.ReLU()
        
        self.upSamp = nn.ConvTranspose2d(256, 128, 2, 2)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        x = self.upSamp(x)
        output = x
        return output
        
        
class upNet3(nn.Module):
    def __init__(self):
        super(upNet3, self).__init__()
        
        self.conv1 = nn.Conv2d(256, 128, 3, 1)
        
        self.act1 = nn.ReLU()
        
        self.conv2 = nn.Conv2d(128, 128, 3, 1)
        
        self.act2 = nn.ReLU()
        
        self.upSamp = nn.ConvTranspose2d(128, 64, 2, 2)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        x = self.upSamp(x)
        output = x
        return output
        
        
class upNet4(nn.Module):
    def __init__(self):
        super(upNet4, self).__init__()
        
        self.conv1 = nn.Conv2d(128, 64, 3, 1)
        
        self.act1 = nn.ReLU()
        
        self.conv2 = nn.Conv2d(64, 64, 3, 1)
        
        self.act2 = nn.ReLU()
        
        # 2 classes, so final output should be 2
        self.conv3 = nn.Conv2d(64, 2, 1, 1) 
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        x = self.conv3(x)
        output = x
        return output
        
        
class bottleNeck(nn.Module):
    def __init__(self):
        super(bottleNeck, self).__init__()
        
        self.conv1 = nn.Conv2d(512, 1024, 3, 1)
        
        self.act1 = nn.ReLU()
        
        self.conv2 = nn.Conv2d(1024, 1024, 3, 1)
        
        self.act2 = nn.ReLU()
        
        self.upSamp = nn.ConvTranspose2d(1024, 512, 2, 2)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        x = self.upSamp(x)
        output = x
        return output
        
     
class pool(nn.Module):
    def __init__(self):
        super(pool, self).__init__()

        self.pooling = nn.MaxPool2d(2, 2)
        
    def forward(self, x):
        x = self.pooling(x)
        output = x
        return output
        
       
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        
        # Encoder (Downsampling)
        self.down1 = downNet1()
        self.pool1 = pool()
        
        self.down2 = downNet2()
        self.pool2 = pool()
        
        self.down3 = downNet3()
        self.pool3 = pool()
        
        self.down4 = downNet4()
        self.pool4 = pool()
        
        # Bottleneck
        self.bottleneck = bottleNeck()
        
        # Decoder (Upsampling)
        self.up1 = upNet1()
        self.up2 = upNet2()
        self.up3 = upNet3()
        self.up4 = upNet4()
        
    def copy_and_crop(self, upsampled, skip):
        # Crop the skip connection to match the upsampled dimensions
        _, _, H, W = upsampled.size()
        skip = skip[:, :, :H, :W]  # Crop the skip connection
        return torch.cat((upsampled, skip), dim=1) 
        
    def forward(self, x):
        # Downsampling path
        enc1 = self.down1(x)
        enc1_pool = self.pool1(enc1)
        
        enc2 = self.down2(enc1_pool)
        enc2_pool = self.pool2(enc2)
        
        enc3 = self.down3(enc2_pool)
        enc3_pool = self.pool3(enc3)
        
        enc4 = self.down4(enc3_pool)
        enc4_pool = self.pool4(enc4)
        
        # Bottleneck
        bottleneck_out = self.bottleneck(enc4_pool)
        
        # Upsampling path
        dec1 = self.up1(bottleneck_out)
        dec1 = self.copy_and_crop(dec1, enc4)  # Skip connection
        
        dec2 = self.up2(dec1)
        dec2 = self.copy_and_crop(dec2, enc3)  # Skip connection
        
        dec3 = self.up3(dec2)
        dec3 = self.copy_and_crop(dec3, enc2)  # Skip connection
        
        dec4 = self.up4(dec3)
        dec4 = self.copy_and_crop(dec4, enc1)  # Skip connection
        
        return dec4
    
           

my_nn = UNet()
print(my_nn)

UNet(
  (down1): downNet1(
    (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
    (act1): ReLU()
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (act2): ReLU()
  )
  (pool1): pool(
    (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down2): downNet2(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (act1): ReLU()
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (act2): ReLU()
  )
  (pool2): pool(
    (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down3): downNet3(
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    (act1): ReLU()
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (act2): ReLU()
  )
  (pool3): pool(
    (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down4): downNet4(
    (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
  

In [6]:
# Run a test
random_data = torch.rand((1, 1, 572, 572))

result = my_nn(random_data)
print (result)

RuntimeError: Given groups=1, weight of size [512, 1024, 3, 3], expected input[1, 512, 56, 56] to have 1024 channels, but got 512 channels instead