<a href="https://colab.research.google.com/github/SeanSDarcy2001/CISProgrammingAssignments/blob/main/downstreamTasks/lung/lungNoduleUNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

directory = "gdrive/My Drive/deepDRR_3dseg"


Mounted at /content/gdrive


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms 
from torch.utils.data import Dataset, DataLoader
from skimage import io
import os
import numpy as np


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    #might need to add padding here?
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.down_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.down_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=3, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        x = self.conv(x)
        return self.sig(x)

In [3]:
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 16)
        self.down1 = Down(16, 32)
        self.down2 = Down(32, 64)
        self.down3 = Down(64, 128)
        factor = 2 if bilinear else 1
        self.down4 = Down(128, 256 // factor)
        self.up1 = Up(256, 128 // factor, bilinear)
        self.up2 = Up(128, 64 // factor, bilinear)
        self.up3 = Up(64, 32 // factor, bilinear)
        self.up4 = Up(32, 16, bilinear)
        self.outc = OutConv(16, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

In [8]:
## Image Transforms
img_transform = transforms.Compose([
        transforms.ToTensor(),
])


## Image Dataloader
class ImageDataset(Dataset):
    
    """
    ImageDataset
    """
    
    def __init__(self,
                 input_dir,
                 transforms=None):
        """
        Args:
            input_dir (str): Path to either colorization or segmentation directory
            op (str): One of "train", "val", or "test" signifying the desired split
            mask_json_path (str): Path to mapping.json file
            transforms (list or None): Image transformations to apply upon loading.
        """
        self.transform = transforms
        self.data_dir = input_dir
        self.cases = os.listdir(os.path.join(self.data_dir, "images"))
    

    def __len__(self):
        """
        
        """
        return len(next(os.walk(self.data_dir))[1])

    def __getitem__(self,
                    idx):
        """
        
        """
        case = self.cases[idx - 1]

        img = io.imread(os.path.join(self.data_dir, "images", case))
        mask = io.imread(os.path.join(self.data_dir, "masks", case))
      
        ## Transform image and mask
        if self.transform:
            img, mask = self.img_transform(img, mask)

        #mask = F.one_hot(mask, num_classes = 2)
      
        return img, mask

    def img_transform(self,
                      img,
                      mask):
        """
        
        """
        ## Apply Transformations to Image and Mask
        img = self.transform(img)
        mask = self.transform(mask)
        return img, mask

In [5]:
def dice_score_image(prediction, target, n_classes):
    '''
      computer the mean dice score for a single image

      Reminders: A false positive is a result that indicates a given condition exists, when it does not
               A false negative is a test result that indicates that a condition does not hold, while in fact it does
      Args:
          prediction (tensor): predictied labels of the image
          target (tensor): ground truth of the image
          n_classes (int): number of classes
    
      Returns:
          m_dice (float): Mean dice score over classes
    '''

    smooth = 1
    prediction = prediction[0]
    target=target[0]
    prediction = prediction.view(-1, prediction.size(0) * prediction.size(1)).cpu().numpy() 
    dice_classes = np.zeros(n_classes)

    for cl in range(n_classes): 
      label = target[cl]
      label = label.view(-1, label.size(0) * label.size(1)).cpu().numpy() 
      TP = 0
      FP = 0
      FN = 0
      for i in range(len(label[0][:])):
        if label[0][i] == 1 and prediction[0][i] == cl: 
          TP += 1 
        elif label[0][i] == 0 and prediction[0][i] == cl:
          FP += 1 
        elif label[0][i] == 1 and prediction[0][i] != cl:
          FN += 1 
      #When there is no grount truth of the class in this image
      #Give 1 dice score if False Positive pixel number is 0,
      #give 0 dice score if False Positive pixel number is not 0 (> 0). 
      if (TP + FN == 0):
        if FP == 0: 
          dice_classes[cl] = 1
        else: 
          dice_classes[cl] = 0 
      else:
        dice_classes[cl] = (2*TP+smooth)/(2*TP+FN+FP+smooth) 
    return dice_classes.mean()


def dice_score_dataset(model, dataloader, num_classes, use_gpu=False):
    """
    Compute the mean dice score on a set of data.
    
    Note that multiclass dice score can be defined as the mean over classes of binary
    dice score. Dice score is computed per image. Mean dice score over the dataset is the dice
    score averaged across all images.
    
    Reminders: A false positive is a result that indicates a given condition exists, when it does not
               A false negative is a test result that indicates that a condition does not hold, while in fact it does
     
    Args:
        model (UNET class): Your trained model
        dataloader (DataLoader): Dataset for evaluation
        num_classes (int): Number of classes
    
    Returns:
        m_dice (float): Mean dice score over the input dataset
    """
    ## Number of Batches and Cache over Dataset 
    n_batches = len(dataloader)
    scores = np.zeros(n_batches)
    ## Evaluate
    model.eval()
    idx = 0
    for data in dataloader:
        ## Format Data
        img, target = data
        if use_gpu:
            img = img.cuda()
            target = target.cuda()
        ## Make Predictions
        out = model(img)
        #print(n_classes)
        prediction = torch.argmax(out, dim = 1)
        #print(prediction[0])
        scores[idx] = dice_score_image(prediction, target, n_classes)
        idx += 1
    ## Average Dice Score Over Images
    m_dice = scores.mean()
    return m_dice

class DICELoss(nn.Module):
    def __init__(self):
      super(DICELoss, self).__init__()
    
    def forward(self, out, target):
        smoothing = 1
        intersect = (out * target).sum(dim = 2).sum(dim = 2)
        diceScore = (2.* intersect + smoothing) / (smoothing + out.sum(dim=2).sum(dim=2) + target.sum(dim =2).sum(dim = 2))
        loss = torch.mean(1 - diceScore) 
        return loss

In [6]:
## Batch Size
train_batch_size = 1

## Learning Rate
learning_rate = 0.001

# Epochs (Consider setting high and implementing early stopping)
num_epochs = 200

In [9]:
n_classes = 2
model = UNet(1, n_classes)

if torch.cuda.is_available():
    model.cuda()

## Initialize Dataloaders
train_dataset=ImageDataset(input_dir=directory, transforms=img_transform)
train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=False)

## Initialize Optimizer and Learning Rate Scheduler
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate, betas = [.9, .999])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

#Losses
lossfxn = DICELoss()


print("Start Training...")

tLoss = []
vLoss = []


for epoch in range(num_epochs):
    ########################### Training #####################################
    print("\nEPOCH " +str(epoch+1)+" of "+str(num_epochs)+"\n")

  

    model.train()
    batchLoss = []
    for data in train_dataloader:
        optimizer.zero_grad()
        ## Format Data
        img, target = data
        img = img.cuda()
        target = target.cuda()
        ## Make Predictions
        #out = model(img)
        #prediction = torch.argmax(model(img), dim = 1)
        l = lossfxn(model(img), target)
        batchLoss.append(l.item())
        l.backward()
        optimizer.step()
    epochLoss = np.mean(batchLoss)
    tLoss.append(epochLoss)
    scheduler.step()
    print("Training Loss:", epochLoss)
  

  

Start Training...

EPOCH 1 of 200

Training Loss: 0.996692419052124

EPOCH 2 of 200

Training Loss: 0.9961113731066386

EPOCH 3 of 200

Training Loss: 0.9955704808235168

EPOCH 4 of 200

Training Loss: 0.9955242077509562

EPOCH 5 of 200

Training Loss: 0.9950094421704611

EPOCH 6 of 200

Training Loss: 0.9948886632919312

EPOCH 7 of 200

Training Loss: 0.9946417411168417

EPOCH 8 of 200

Training Loss: 0.9946794509887695

EPOCH 9 of 200

Training Loss: 0.9944619337717692

EPOCH 10 of 200

Training Loss: 0.9943185846010844

EPOCH 11 of 200

Training Loss: 0.9941792885462443

EPOCH 12 of 200

Training Loss: 0.9941595991452535

EPOCH 13 of 200

Training Loss: 0.9941262205441793

EPOCH 14 of 200

Training Loss: 0.9940953056017557

EPOCH 15 of 200

Training Loss: 0.9940719803174337

EPOCH 16 of 200

Training Loss: 0.9940452377001444

EPOCH 17 of 200

Training Loss: 0.9940088589986166

EPOCH 18 of 200

Training Loss: 0.9939751227696737

EPOCH 19 of 200

Training Loss: 0.9939441482226054

EPO

KeyboardInterrupt: ignored