# DenseNet Model Implementation

In [1]:
import ast
import glob
import os
import os.path

import cv2
import dotenv
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import random_split
import torch.nn as nn
import torch.optim as optim
import torchvision

from helperFunctions import tile
from analysisFunctions import compute_score_with_logits
from dataClasses import DataPreprocessing
from dataClasses import DenseNet121
dotenv.load_dotenv()

True

In [3]:
 # Check if PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)
# Device priority: MPS first, then CUDA, then plain CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Data Preprocessing
data = DataPreprocessing()

# `trainMode` is a Python literal stored in the environment (e.g. "True" / "False").
# ast.literal_eval parses literals only, so the value can never execute code, and
# os.environ[...] raises a KeyError naming the variable when it is missing.
trainMode = ast.literal_eval(os.environ["trainMode"])

# The whole dataset goes to exactly one side of the split, depending on the mode.
if trainMode:
    trainSize, testSize = len(data), 0
    print('Train Images: ', trainSize)
else:
    trainSize, testSize = 0, len(data)
    print('Test Images: ', testSize)

train_set, test_set = random_split(data, [trainSize, testSize])

# Only the loader that matches the current mode is created.
if trainMode:
    trainloader = torch.utils.data.DataLoader(train_set, batch_size=2, shuffle=False, num_workers=0)
else:
    testloader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False)

# Wrap the network in DataParallel and move it to the selected device.
# `.to(device)` covers the mps, cuda and cpu cases with a single code path.
model = nn.DataParallel(DenseNet121()).to(device)

Using device: cuda
Test Images:  416


In [4]:

#%%time

# Parameters to save/load models.
# Flags stored in the environment are parsed with ast.literal_eval (literals only,
# no code execution); os.environ[...] fails with a KeyError naming a missing variable.
trainMode = ast.literal_eval(os.environ['trainMode'])

# saveModel must be a real boolean: the raw environment string 'False' is
# non-empty and therefore truthy, which would always trigger a save.
if trainMode:
    saveModel = bool(ast.literal_eval(os.getenv('saveModel') or 'False'))
else:
    saveModel = False

# The checkpoint file name encodes the class count and the down-sampled image size.
# NOTE: paths are built by plain concatenation, so SAVE_LOAD_MODEL_PATH is
# expected to end with a path separator.
modelPath = os.getenv('SAVE_LOAD_MODEL_PATH')
downImageSize = os.getenv('downImageSize')
nClasses = os.getenv('nClasses')
modelFname = "dense" + str(nClasses) + "class" + str(downImageSize) + "pix.pt"
modelPathFile = modelPath + modelFname

if trainMode:

    # INITIAL GUESS FROM PREVIOUSLY TRAINED MODEL (when a checkpoint exists).
    # torch.load unpickles a full model object: only load trusted files.
    # map_location places the tensors on the device selected for this session.
    # NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
    # which rejects full-model pickles -- confirm against the installed version.
    if os.path.isfile(modelPathFile):
        model = torch.load(modelPathFile, map_location=device)

    # Make sure training-mode layers are active; a loaded checkpoint or an earlier
    # run of this cell (which ends with model.eval()) may have left eval mode on.
    model.train()

    # Start every training run with a fresh per-epoch log.
    logPathFileE = modelPath + "logEpochs.csv"
    if os.path.isfile(logPathFileE):
        os.remove(logPathFileE)
        print (' log deleted')
    else:
        print ('log not found')

    criterion = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)  # RMSprop, Adam

    nEpochs = ast.literal_eval(os.environ["nEpochs"])
    useImLevels = ast.literal_eval(os.environ["useImLevels"])

    # The context manager closes the log even if training raises.
    with open(logPathFileE, "a") as logFileE:
        for epoch in range(nEpochs):  # loop over the dataset multiple times

            running_loss = 0.0
            correct = 0
            total = 0
            for images, labels, image_names in trainloader:

                # zero the parameter gradients
                optimizer.zero_grad()

                images = images.to(device)
                # duplicate the label for each crop of the image
                labels = tile(labels, 0, useImLevels).to(device)

                # Fold the crop axis into the batch axis:
                # (batch, crops, C, H, W) -> (batch * crops, C, H, W)
                n_batches, n_crops, channels, height, width = images.size()
                image_batch = images.view(-1, channels, height, width)

                # forward + backward + optimize
                outputs = model(image_batch)
                loss = criterion(outputs, labels.float())
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

                # .item() keeps the counter a plain Python number, so the log
                # below contains a number instead of a tensor repr.
                correct += compute_score_with_logits(outputs, labels, device).sum().item()
                total += labels.size(0)

            # Guard against an empty train loader.
            accuracy = 100 * correct / total if total else 0.0

            print('Epoch: %d, loss: %.3f, Accuracy: %.3f' %
                (epoch + 1, running_loss, accuracy))

            logEL = 'Epoch: ' + str(epoch +1) + ' Loss: ' + str(running_loss)
            logFileE.write(logEL + ' Accuracy: ' + str(accuracy) + "\n")

    print('Finished Training')
    if saveModel:
        torch.save(model, modelPathFile)

#load the model when doing only testing
else:
    # Same trust / map_location remarks as for the warm-start load above.
    model = torch.load(modelPathFile, map_location=device)

#model evaluation
model.eval()

DataParallel(
  (module): DenseNet121(
    (model): DenseNet(
      (features): Sequential(
        (conv0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu0): ReLU(inplace=True)
        (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (denseblock1): _DenseBlock(
          (denselayer1): _DenseLayer(
            (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu1): ReLU(inplace=True)
            (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu2): ReLU(inplace=True)
            (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          )
          (denselayer2): _DenseLayer

In [5]:
# Evaluate the model on the test loader: every image contributes several crops,
# each crop votes for its most probable class, and the image is bucketed by how
# strongly those votes agree with the ground-truth label.
useImLevels = ast.literal_eval(os.environ["useImLevels"])  # literals only, no code execution

# Outcome counters, one per bucket reported at the end of the cell.
correct = 0
total = 0
perfect8 = 0
sevenLev = 0
sixL = 0
problem = 0
normWArt =0
sickAsNorm = 0
wrongMajor = 0
multiDes = 0
normAsSick = 0
vovaAcc = 0

modelPath = os.getenv('SAVE_LOAD_MODEL_PATH')
downImageSize = os.getenv('downImageSize')
# Re-read here so this cell does not depend on a variable left by the training cell.
nClasses = os.getenv('nClasses')

logFname = "log" + str(nClasses) + "class" + str(downImageSize) + ".csv"
logPathFile = modelPath + logFname

# Start every evaluation with a fresh log.
if os.path.isfile(logPathFile):
    os.remove(logPathFile)
    print (' log deleted')
else:
    print ('log not found')

# Inference must run in eval mode regardless of which cells ran before this one.
model.eval()

# The context manager closes the log even if evaluation raises.
with open(logPathFile, "a") as logFile, torch.no_grad():
    for images, labels, image_names in testloader:
        images = images.to(device)
        labels = tile(labels, 0, useImLevels).to(device)

        # Fold the crop axis into the batch axis:
        # (batch, crops, C, H, W) -> (batch * crops, C, H, W)
        n_batches, n_crops, channels, height, width = images.size()
        image_batch = images.view(-1, channels, height, width)

        outputs = model(image_batch)
        prob = outputs.cpu().numpy()
        prob = np.round(100*prob)/100  # keep two decimals

        nCl = len(prob[0])
        normalIdx = nCl - 1  # the log messages below treat the last class as "Normal"

        # Most probable class per crop, then number of crops voting for each class.
        mpi = np.argmax(prob, axis=-1)
        fitN = np.bincount(mpi, minlength=nCl)
        topVotes = fitN.max()
        normVotes = fitN[normalIdx]

        labelIs = np.argmax(fitN)
        labelWas = np.argmax(labels[0].cpu().numpy())
        logString0 =  image_names[0][0].split("_1_")[0] + " L is: " + str(labelIs) + " L was: " + str(labelWas)
        logString1 = logString0 +  " Fit: " + str(fitN.astype(int))

        # Print all image content
        # print(logString1)
        isMatch = (labelIs == labelWas)
        if isMatch:
            vovaAcc = vovaAcc +1

        # Log only problematic images. Branch order matters: the first matching
        # rule wins. The literal 8 / 7 thresholds presumably assume 8 crops per
        # image (useImLevels == 8) -- TODO confirm before changing the crop count.
        if (topVotes == 8) and isMatch:
            perfect8 = perfect8 + 1
        elif (topVotes == 7) and isMatch:
            sevenLev = sevenLev + 1
        elif (topVotes > 4) and isMatch and (labelIs == normalIdx):
            sevenLev = sevenLev + 1
        elif (topVotes > 4) and isMatch:
            sixL = sixL + 1
            logFile.write('Minor Second: ' + logString1 + "\n")
        elif (normVotes > 2) and (labelWas == normalIdx):
            normWArt = normWArt + 1
            logFile.write('Likely Normal w Artifact : ' + logString1 + "\n")
        elif (normVotes > 3) and (labelWas < normalIdx):
            sickAsNorm = sickAsNorm + 1
            logFile.write('PROBLEM: Sick as Normal : ' + logString1 + "\n")
            print('PROBLEM: Sick as Normal : ' + logString1)
        elif (labelWas == normalIdx) and (labelWas != labelIs) and (normVotes < 3):
            normAsSick= normAsSick + 1
            logFile.write('Doctor look: Normal as Sick : ' + logString1 + "\n")
            print('Doctor look: Normal as Sick : ' + logString1)
        elif (topVotes > 4) and (labelIs != labelWas):
            wrongMajor = wrongMajor + 1
            logFile.write('Problem: Wrong Major Diagnosys : ' + logString1 + "\n")
            #print('Problem: Wrong Major Diagnosys : ' + logString1)
        elif (topVotes < 5) and (labelWas < normalIdx):
            multiDes = multiDes + 1
            logFile.write('Doctor look: Multi Diseases : ' + logString1 + "\n")
            #print('Doctor look: Multi Diseases : ' + logString1)
        else:  # none of the rules above matched
            problem = problem + 1
            logFile.write('Unknown Problem: ' + logString1 + "\n")
            print('Unknown Problem: ' + logString1)

        fitScore = compute_score_with_logits(outputs, labels, device).sum()
        correct += fitScore.item()
        total += labels.size(0)


nTested = len(testloader)  # batch_size is 1, so this is the number of test images


def _pct(count):
    """Return `count` as a percentage of the tested images, rounded to two decimals."""
    return round(10000*count/nTested)/100 if nTested else 0.0


print('Correct', correct)
print('Total', total)
print('vovaAcc', vovaAcc/nTested if nTested else 0.0)

print('Total images tested: ', nTested)
# print('Accuracy according to Standard Metric: %.3f' % (100 * correct / total))
print('Single Desease/normal w High Convidence,', perfect8+sevenLev, "Percentage", _pct(perfect8+sevenLev), '%')
print('Single Desease/normal may have 2nd minor,', sixL, "Percentage", _pct(sixL), '%')
print('Normal with Device Artifact,', normWArt , "Percentage",  _pct(normWArt), '%')
print('Problem: Wrong Major Diagnosys', wrongMajor , "Percentage", _pct(wrongMajor), '%')
print('PROBLEM: Sick as Normal', sickAsNorm , "Percentage", _pct(sickAsNorm), '%')
print('Doc Look: Multiple Diseases', multiDes , "Percentage", _pct(multiDes), '%')
print('Doc Look: Normal as Sick', normAsSick , "Percentage", _pct(normAsSick), '%')
print('Unknown Problems', problem , "Percentage", _pct(problem), '%')


log not found
PROBLEM: Sick as Normal : AMRD16 L is: 4 L was: 0 Fit: [3 0 0 0 5]
PROBLEM: Sick as Normal : DR97 L is: 4 L was: 2 Fit: [0 0 0 0 8]
PROBLEM: Sick as Normal : MH7 L is: 4 L was: 3 Fit: [0 1 0 2 5]
PROBLEM: Sick as Normal : CSR50 L is: 4 L was: 1 Fit: [2 0 0 0 6]
PROBLEM: Sick as Normal : CSR83 L is: 4 L was: 1 Fit: [0 0 0 0 8]
PROBLEM: Sick as Normal : CSR31 L is: 4 L was: 1 Fit: [2 0 0 0 6]
PROBLEM: Sick as Normal : DR76 L is: 4 L was: 2 Fit: [0 0 0 0 8]
PROBLEM: Sick as Normal : AMRD13 L is: 4 L was: 0 Fit: [3 0 0 0 5]
PROBLEM: Sick as Normal : CSR96 L is: 0 L was: 1 Fit: [4 0 0 0 4]
PROBLEM: Sick as Normal : CSR78 L is: 4 L was: 1 Fit: [0 3 0 0 5]
PROBLEM: Sick as Normal : DR8 L is: 4 L was: 2 Fit: [2 0 0 0 6]
PROBLEM: Sick as Normal : MH59 L is: 4 L was: 3 Fit: [1 2 0 1 4]
PROBLEM: Sick as Normal : CSR72 L is: 4 L was: 1 Fit: [0 0 0 0 8]
PROBLEM: Sick as Normal : DR25 L is: 4 L was: 2 Fit: [3 0 0 0 5]
PROBLEM: Sick as Normal : AMRD2 L is: 4 L was: 0 Fit: [0 0 0 0 8]
PR