# DenseNet Model Implementation

In [9]:
import os
import glob
import cv2
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
from torch.utils.data import random_split
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
import sklearn
from PIL import Image
import zipfile
import urllib.request
import os.path
from IPython.display import display

# Select the compute device (Apple MPS GPU, CUDA GPU, or CPU fallback)

In [12]:
def _select_device():
    """Return the torch device name to use: "mps" first, then "cuda", then "cpu".

    When MPS (Metal Performance Shader, Apple's GPU architecture) is available,
    a short status report is printed before choosing it.
    """
    if not torch.backends.mps.is_available():
        return "cuda" if torch.cuda.is_available() else "cpu"

    print("MPS is available!")
    if torch.backends.mps.is_built():
        print("MPS (Metal Performance Shader) is built in!")
    return "mps"


device = _select_device()
print(f"Using device: {device}")

MPS is available!
MPS (Metal Performance Shader) is built in!
Using device: mps


# Helper functions

In [13]:
def create_image_stack(images, labels, stack_size=9):
    """Group consecutive images into per-sample stacks, skipping each group's first image.

    ``images`` is consumed in consecutive groups of ``stack_size``. The first image
    of every group is treated as the background image and dropped; the remaining
    ``stack_size - 1`` images are rescaled to [0, 1], normalized with the ImageNet
    grayscale statistics, and stacked along a new leading axis. Trailing images
    that do not fill a complete group are ignored.

    Args:
        images: sequence of numpy arrays. Presumably (1, H, W) uint8 pixels in
            [0, 255] as produced by ``DataPreprocessing`` — TODO confirm for other callers.
        labels: sequence aligned with ``images``; the label of each group's first
            image is used for the whole stack.
        stack_size: number of consecutive images forming one sample (default 9).

    Returns:
        ``(stacked_data, stacked_labels)``: a list of float32 tensors of shape
        ``(stack_size - 1, *image.shape)`` and the list of their labels.

    Raises:
        ValueError: if ``stack_size`` is smaller than 2 (no image would remain
            after dropping the background image).
    """
    if stack_size < 2:
        raise ValueError(f"stack_size must be >= 2, got {stack_size}")

    # The ImageNet mean/std are defined for pixels in [0, 1]; raw 8-bit values
    # must be scaled by 1/255 first (torchvision's ToTensor normally does this).
    mean, std = 0.485, 0.229

    stacked_data = []
    stacked_labels = []
    # Only whole groups are used, i.e. len(images) // stack_size of them.
    for start in range(0, len(images) - stack_size + 1, stack_size):
        group = images[start + 1:start + stack_size]  # skip the background image
        frames = [
            (torch.from_numpy(img.astype(np.float32)) / 255.0 - mean) / std
            for img in group
        ]
        stacked_data.append(torch.stack(frames))
        stacked_labels.append(labels[start])

    return stacked_data, stacked_labels

In [14]:
def compute_score_with_logits(logits, labels):
    """Mask ``labels`` to the column each row's arg-max prediction points at.

    Args:
        logits: (batch, num_classes) scores; only their arg-max along dim 1 is used.
        labels: (batch, num_classes) target tensor (one-hot in this notebook).

    Returns:
        A float tensor shaped like ``labels`` that is nonzero only where the
        predicted class has a positive label. With one-hot labels, ``.sum()``
        is the number of correct predictions in the batch.
    """
    predictions = logits.detach().max(dim=1).indices
    # Allocate on the labels' own device so this works on CPU, CUDA and MPS alike
    # without consulting a global device setting.
    one_hots = torch.zeros(labels.size(), device=labels.device)
    one_hots.scatter_(1, predictions.view(-1, 1), 1)
    return one_hots * labels

In [15]:
def tile(a, dim, n_tile):
    """Repeat every slice of ``a`` along ``dim`` ``n_tile`` times, consecutively.

    For example, rows ``[r0, r1]`` tiled 3 times along dim 0 become
    ``[r0, r0, r0, r1, r1, r1]``. In this notebook it duplicates each sample's
    label once per crop so the targets line up with the flattened crop batch.

    ``torch.repeat_interleave`` does this directly and keeps the result on
    ``a``'s device (a CPU-built index tensor would fail for GPU/MPS inputs).
    """
    return torch.repeat_interleave(a, n_tile, dim=dim)

# Preprocessing

In [30]:
class DataPreprocessing(Dataset):
    """Load retina CT image series from disk and expose them as 8-image stacks.

    For every class in ``classEncoding``, up to ``classPaths[name][1] * 9``
    ``.jpg`` files are read from ``classPaths[name][0]`` as grayscale, resized to
    224x224 and stored with shape (1, 224, 224). Consecutive groups of 9 images
    are then turned into one sample by ``create_image_stack``, which drops the
    first (background) image of each group.

    Attributes:
        image_paths: flat list of the image files that were loaded, in load order.
        labels: one float label tensor per loaded image.
        images: one (1, 224, 224) uint8 array per loaded image.
        stacked_images: per-sample image stacks built from ``images``.
        stacked_labels: the label of each stack.
    """

    # Number of consecutive image files that make up one sample series.
    IMAGES_PER_SERIES = 9

    # classPaths --> className: [directory, number of series to load]
    classPaths = {

    'AGE_RMD':  ['../../Images/CT_RETINA/AGE_RMD_55/AR1-45_9Levels/', 40],
    'CSR':  ['../../Images/CT_RETINA/CSR_102/CR1-80_9Levels/', 5],
    'DIABETR': [ '../../Images/CT_RETINA/DIABETR_107/DR1-83_9Levels/', 40],
    'MACHOLE': ['../../Images/CT_RETINA/MACHOLE_102/MH1-80_9Levels/', 40],
    'NORMAL':  ['../../Images/CT_RETINA/NORMAL_206/NR1-160_9Levels/', 40]

        }

    # classEncoding --> className: one-hot label tensor.
    # Only the classes listed here are loaded (CSR and MACHOLE have paths above but
    # are currently excluded). The label length must match the model's output size.
    classEncoding = {
        'AGE_RMD': torch.FloatTensor([1, 0, 0]),
        'DIABETR': torch.FloatTensor([0, 1, 0]),
        'NORMAL': torch.FloatTensor([0, 0, 1]),
    }

    def __init__(self, classEncoding=classEncoding, classPaths=classPaths):
        self.image_paths = []
        self.labels = []
        self.images = []

        for class_name, label_tensor in classEncoding.items():
            directory, n_series = classPaths[class_name]
            # glob() returns files in an arbitrary, OS-dependent order; sort so each
            # series' images stay consecutive and the grouping is reproducible.
            # NOTE(review): assumes file names sort into series/level order — confirm.
            found = sorted(
                path
                for dir_path in glob.glob(directory)
                for path in glob.glob(os.path.join(dir_path, "*.jpg"))
            )
            img_paths = found[:n_series * self.IMAGES_PER_SERIES]

            self.image_paths.append(img_paths)
            # One label per image actually loaded, so labels stay aligned with
            # images even when a directory holds fewer files than requested.
            self.labels.extend(label_tensor.clone().float() for _ in img_paths)
            self.images.extend(self._load_image(path) for path in img_paths)

        self.image_paths = [path for group in self.image_paths for path in group]
        self.stacked_images, self.stacked_labels = create_image_stack(self.images, self.labels)

    @staticmethod
    def _load_image(image_path, size=(224, 224)):
        """Read one image as grayscale, resize it, and return it as a (1, H, W) array."""
        img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if img is None:  # imread does not raise on unreadable files
            raise FileNotFoundError(f"Could not read image: {image_path}")
        img = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
        return img[np.newaxis, :, :]  # add the channel axis -> (1, H, W)

    def __getitem__(self, index):
        """Return ``(image_stack, label)`` for sample ``index``."""
        return self.stacked_images[index], self.stacked_labels[index]

    def __len__(self):
        """Return the number of image stacks (samples)."""
        return len(self.stacked_images)

# Model

In [31]:
class DenseNet121(nn.Module):
    """Pretrained torchvision DenseNet-121 adapted to 1-channel input and 3 outputs.

    The stem convolution is replaced by a freshly initialised Conv2d with a
    single input channel, so its pretrained weights are not reused. The
    classifier head becomes a Linear layer to 3 units followed by a Sigmoid,
    giving independent per-class scores in (0, 1) for use with ``nn.BCELoss``.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.densenet121(pretrained=True)
        in_features = backbone.classifier.in_features

        # Grayscale stem: accept 1-channel images instead of RGB.
        backbone.features[0] = nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        # Output width must match the length of the label vectors.
        head_layers = [nn.Linear(in_features, 3), nn.Sigmoid()]
        backbone.classifier = nn.Sequential(*head_layers)

        self.model = backbone

    def forward(self, x):
        """Return per-class sigmoid scores for a batch of 1-channel images."""
        return self.model(x)

In [33]:
data = DataPreprocessing()

# 80/20 train/test split. The test size is derived from the train size so the two
# lengths always sum to len(data), as random_split requires.
n_train = math.ceil(len(data) * 0.8)
train_set, test_set = random_split(data, [n_train, len(data) - n_train])

trainloader = torch.utils.data.DataLoader(train_set, batch_size=2, shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False)

# ``.to(device)`` accepts "mps", "cuda" and "cpu" alike; nn.Module has no
# ``device()`` method, so a single code path serves every backend.
model = DenseNet121().to(device)
model = nn.DataParallel(model).to(device)



In [34]:
%%time

criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)  # alternatives: RMSprop, Adam

num_epochs = 100
for epoch in range(num_epochs):  # loop over the dataset multiple times
    # Re-enter training mode every epoch: the evaluation cells call model.eval(),
    # and re-running this cell afterwards would otherwise train in eval mode.
    model.train()

    running_loss = 0.0  # summed over the epoch's batches
    correct = 0
    total = 0
    for images, labels in trainloader:
        optimizer.zero_grad()

        images = images.to(device)
        # images is (batch, n_crops, C, H, W); fold the crop axis into the batch
        # so every crop is classified independently.
        _, n_crops, channels, height, width = images.size()
        image_batch = images.view(-1, channels, height, width)
        # Repeat each sample's label once per crop so targets line up with image_batch.
        labels = tile(labels, 0, n_crops).to(device)

        # forward + backward + optimize
        outputs = model(image_batch)
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # .item() keeps the counters as Python numbers instead of device tensors.
        correct += compute_score_with_logits(outputs, labels).sum().item()
        total += labels.size(0)

    accuracy = 100 * correct / total if total else 0.0
    print('Epoch: %d, loss: %.3f, Accuracy: %.3f' %
          (epoch + 1, running_loss, accuracy))

print('Finished Training')


SyntaxError: expected ':' (<unknown>, line 16)

In [None]:
# Switch the model to evaluation mode before running the test-set loop.
model.eval()

In [None]:
# Ensure evaluation mode even if this cell is run without the previous one.
model.eval()

correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images = images.to(device)
        # images is (batch, n_crops, C, H, W); score every crop separately.
        _, n_crops, channels, height, width = images.size()
        image_batch = images.view(-1, channels, height, width)
        # One copy of each sample's label per crop, matching image_batch.
        labels = tile(labels, 0, n_crops).to(device)
        outputs = model(image_batch)
        correct += compute_score_with_logits(outputs, labels).sum().item()
        total += labels.size(0)

if total:
    print('Accuracy on test set: %.3f' % (100 * correct / total))
else:
    print('Test set is empty; accuracy is undefined.')