In [1]:
import torch
import os
import pandas as pd
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time
import sys


import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torch.optim as optim
import torch.nn.functional as tfunc
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as func

from sklearn.metrics.ranking import roc_auc_score

from torch.utils.data import Dataset
from PIL import Image
from models.chexnet.DensenetModels import DenseNet121, DenseNet121_Sigmoid
from models.models import ResNet18
from tensorboardX import SummaryWriter
from models.chexnet.DatasetGenerator import DatasetGenerator

%load_ext autoreload
%autoreload 2

In [2]:
writer = SummaryWriter('./logs')

In [3]:
"""
Read images and corresponding labels.
"""
class ChestXrayDataSet(Dataset):
    
    def convert_to_ones(self, df, disease):
        df[disease] = df[disease].replace([-1.0], 1.0)
    
    def convert_to_zeros(self, df, disease):
        df[disease] = df[disease].replace([-1.0], 0.0)
        
    def convert_to_multi(self, df, disease):
        df[disease] = df[disease].replace([-1.0], 2.0)

    def __init__(self, data_dir, image_list_file, diseases=['Atelectasis', 'Consolidation', 'Edema','Cardiomegaly', 'Pleural Effusion'], side='Frontal', transform=None):
        """
        Args:
            data_dir: path to image directory.
            image_list_file: path to the file containing images
                with corresponding labels.
            transform: optional transform to be applied on a sample.
        """
        image_names = []
        labels = []
        chex_df = pd.read_csv(image_list_file)
        chex_df = chex_df.fillna(0.0)
        chex_df = chex_df.loc[chex_df['Frontal/Lateral'] == side]
        self.convert_to_ones(chex_df, 'Atelectasis')
        self.convert_to_ones(chex_df, 'Consolidation')
        self.convert_to_ones(chex_df, 'Edema')
        self.convert_to_ones(chex_df, 'Cardiomegaly')
        self.convert_to_ones(chex_df, 'Pleural Effusion')

#         chex_df_diseases = chex_df[diseases]
                         
#         if 'train' in image_list_file:
#             chex_df = chex_df
#         if len(diseases) == 1:
#             chex_df = chex_df.loc[chex_df['Pleural Effusion'] != -1] #U-Ignore
#         print(chex_df)
        labels = chex_df.as_matrix(columns=diseases)
        labels = list(labels)

        image_names = chex_df.as_matrix(columns=['Path']).flatten()
        image_names = [os.path.join(data_dir, im_name) for im_name in image_names]

        self.image_names = image_names
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index: the index of item
        Returns:
            image and its labels
        """
        image_name = self.image_names[index]
        image = Image.open(image_name).convert('RGB')
        label = torch.FloatTensor(self.labels[index])
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.image_names)

In [None]:
BATCH_SIZE = 8
transCrop = 224

DATA_DIR = './data'
TRAIN_IMAGE_LIST = './data/CheXpert-v1.0-small/train.csv'
VAL_IMAGE_LIST = './data/CheXpert-v1.0-small/valid.csv'

normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])


transformList = []
transformList.append(transforms.RandomResizedCrop(transCrop))
transformList.append(transforms.RandomHorizontalFlip())
transformList.append(transforms.ToTensor())
transformList.append(normalize)
transformSequence=transforms.Compose(transformList)

valid_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                image_list_file=VAL_IMAGE_LIST,
                                diseases =['Pleural Effusion'],
                                transform=transformSequence)

In [6]:
class ChexnetTrainer ():

    #---- Train the densenet network 
    #---- pathDirData - path to the directory that contains images
    #---- pathFileTrain - path to the file that contains image paths and label pairs (training set)
    #---- pathFileVal - path to the file that contains image path and label pairs (validation set)
    #---- nnArchitecture - model architecture 'DENSE-NET-121', 'DENSE-NET-169' or 'DENSE-NET-201'
    #---- nnIsTrained - if True, uses pre-trained version of the network (pre-trained on imagenet)
    #---- nnClassCount - number of output classes 
    #---- trBatchSize - batch size
    #---- trMaxEpoch - number of epochs
    #---- transResize - size of the image to scale down to (not used in current implementation)
    #---- transCrop - size of the cropped image 
    #---- launchTimestamp - date/time, used to assign unique name for the checkpoint file
    #---- checkpoint - if not None loads the model and continues training
    
    
    def train (pathDirData, pathFileTrain, pathFileVal, nnArchitecture, nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, transResize, transCrop, launchTimestamp, checkpoint):

        
        #-------------------- SETTINGS: NETWORK ARCHITECTURE
        if nnArchitecture == 'DENSE-NET-121': model = DenseNet121(nnClassCount, nnIsTrained).cuda()
        elif nnArchitecture == 'DENSE-NET-169': model = DenseNet169(nnClassCount, nnIsTrained).cuda()
        elif nnArchitecture == 'DENSE-NET-201': model = DenseNet201(nnClassCount, nnIsTrained).cuda()
        elif nnArchitecture == 'RES-NET-18': model = ResNet18(nnClassCount, nnIsTrained).cuda()
        
        model = torch.nn.DataParallel(model).cuda()
       
        #-------------------- SETTINGS: DATA TRANSFORMS |TRAIN|
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        transformList = []
        transformList.append(transforms.RandomResizedCrop(transCrop))
        transformList.append(transforms.RandomHorizontalFlip())
        transformList.append(transforms.ToTensor())
        transformList.append(normalize)      
        transformSequence=transforms.Compose(transformList)

        #-------------------- SETTINGS: DATASET BUILDER |TRAIN|
                    
        datasetTrain = ChestXrayDataSet(data_dir=pathDirData,image_list_file=pathFileTrain, transform=transformSequence)
        #datasetVal =   ChestXrayDataSet(data_dir=pathDirData, image_list_file=pathFileVal, diseases=['Pleural Effusion'], transform=transformSequence)
              
        dataLoaderTrain = DataLoader(dataset=datasetTrain, batch_size=trBatchSize, shuffle=True,  num_workers=4, pin_memory=True)
        #dataLoaderVal = DataLoader(dataset=datasetVal, batch_size=trBatchSize, shuffle=False, num_workers=24, pin_memory=True)
        
        
        
        #-------------------- SETTINGS: DATA TRANSFORMS, TEN CROPS |VAL|
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        #-------------------- SETTINGS: DATASET BUILDERS |VAL|
        transformList = []
        
        transformList.append(transforms.RandomResizedCrop(transCrop))
        transformList.append(transforms.RandomHorizontalFlip())
        transformList.append(transforms.ToTensor())
        transformList.append(normalize)      
        
        
#         transformList.append(transforms.Resize(transResize))
#         transformList.append(transforms.TenCrop(transCrop))
#        transformList.append(normalize)  
#         transformList.append(transforms.ToTensor())
#         transformList.append(normalize)
#         transformList.append(transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])))
#         transformList.append(transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])))
      
        transformSequence=transforms.Compose(transformList)
        
        datasetVal =   ChestXrayDataSet(data_dir=pathDirData, image_list_file=pathFileVal, transform=transformSequence)
        dataLoaderVal = DataLoader(dataset=datasetVal, batch_size=trBatchSize, shuffle=False, num_workers=4, pin_memory=True)
        
        
        
        
        
        #-------------------- SETTINGS: OPTIMIZER & SCHEDULER
        optimizer = optim.Adam (model.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        scheduler = ReduceLROnPlateau(optimizer, factor = 0.1, patience = 5, mode = 'min')
                
        #-------------------- SETTINGS: LOSS
        loss = torch.nn.BCELoss(size_average = True)
        
        #---- Load checkpoint 
        if checkpoint != None:
            modelCheckpoint = torch.load(checkpoint)
            model.load_state_dict(modelCheckpoint['state_dict'])
            optimizer.load_state_dict(modelCheckpoint['optimizer'])

        
        #---- TRAIN THE NETWORK
        counter = 0
        lossMIN = 100000
        
        for epochID in range (0, trMaxEpoch):
            
            timestampTime = time.strftime("%H%M%S")
            timestampDate = time.strftime("%d%m%Y")
            timestampSTART = timestampDate + '-' + timestampTime
                         
            lossTrain, counter = ChexnetTrainer.epochTrain (model, dataLoaderTrain, dataLoaderVal, optimizer, scheduler, trMaxEpoch, nnClassCount, loss, counter)
            lossVal, losstensor = ChexnetTrainer.epochVal (model, dataLoaderVal, optimizer, scheduler, trMaxEpoch, nnClassCount, loss, counter)
            
            timestampTime = time.strftime("%H%M%S")
            timestampDate = time.strftime("%d%m%Y")
            timestampEND = timestampDate + '-' + timestampTime

            scheduler.step(losstensor.item())
            writer.add_scalar('logs/train_loss_epoch', lossTrain, epochID)
            writer.add_scalar('logs/val_loss_epoch', lossVal, epochID)
            if lossVal < lossMIN:

                lossMIN = lossVal    
                torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN, 'optimizer' : optimizer.state_dict()}, 'm-' + launchTimestamp + '.pth.tar')
                print ('Epoch [' + str(epochID + 1) + '] [save] [' + timestampEND + '] loss= ' + str(lossVal))
            else:
                print ('Epoch [' + str(epochID + 1) + '] [----] [' + timestampEND + '] loss= ' + str(lossVal))
                     
    #-------------------------------------------------------------------------------- 
       
    def epochTrain (model, dataLoader, dataLoaderVal, optimizer, scheduler, epochMax, classCount, loss, counter):
        
        model.train()
        lossTrain = 0
        lossTrainNorm = 0
        
        avg_loss = 0.0
        for batchID, (input, target) in enumerate (dataLoader):

            target = target.cuda()
            varInput = torch.autograd.Variable(input)
            varTarget = torch.autograd.Variable(target)         
            varOutput = model(varInput)
            
#             lossvalue = loss(varOutput, varTarget)

            CEloss =  torch.nn.CrossEntropyLoss()
            BCEloss = torch.nn.BCELoss()

#             varTarget = varTarget.type(torch.long)
            L1 = BCEloss(varOutput[:,:1],varTarget[:,0]) 
            L2 = BCEloss(varOutput[:,1:2],varTarget[:,1])
            L3 = BCEloss(varOutput[:,2:3],varTarget[:,2])
            varTarget = varTarget.long()
            L4 = CEloss(varOutput[:,3:6],varTarget[:,3])
            L5 = CEloss(varOutput[:,6:9],varTarget[:,4])

            
            lossvalue = L1 + L2 + L3 + L4 + L5
            lossvalue /= 5

            print(lossvalue)
            avg_loss = avg_loss * (batchID)/(batchID+1) + lossvalue * 1.0/(batchID+ 1)
            lossTrain += lossvalue
            lossTrainNorm += 1
            
            optimizer.zero_grad()
            lossvalue.backward()
            optimizer.step()
            writer.add_scalar('logs/train_loss', avg_loss, counter)
            if batchID % 1000 == 0:
                ChexnetTrainer.epochVal(model, dataLoaderVal, optimizer, scheduler, epochMax, classCount, loss, counter)
            
            counter += 1
            
        outLoss = lossTrain/lossTrainNorm
        return outLoss, counter

                        
    #-------------------------------------------------------------------------------- 
        
    def epochVal (model, dataLoader, optimizer, scheduler, epochMax, classCount, loss, counter):
        
        model.eval ()
        
        lossVal = 0
        lossValNorm = 0
        
        losstensorMean = 0
        
        
        ###Old code-- didn't handle 5d shapes with crops. We should think about whether we want to crop on val
#         outGT = torch.FloatTensor().cuda()
#         outPRED = torch.FloatTensor().cuda()
       
        
#         for i, (input, target) in enumerate (dataLoader):
#             target = target.cuda()
#             outGT = torch.cat((outGT, target), 0)
        
#             bs, n_crops, c, h, w = input.size()
#             print("Val", input.size())
#             varInput = torch.autograd.Variable(input, volatile=True)
#             varTarget = torch.autograd.Variable(target, volatile=True)    
#             varOutput = model(varInput)
            
#             outMean = varOutput.view(bs, n_crops, -1).mean(1)
            
#             outPRED = torch.cat((outPRED, outMean.data), 0)

        ### Validation computes mean prediction over 10 crops, NOT like in training. 
        outGT = torch.FloatTensor().cuda()
        outPRED = torch.FloatTensor().cuda()
       
        model.eval()
        
        for i, (input, target) in enumerate(dataLoader):
            
            target = target.cuda()
            outGT = torch.cat((outGT, target), 0)
            
#             bs, n_crops, c, h, w = input.size()
            bs, c, h, w = input.size()


            varInput = torch.autograd.Variable(input.view(-1, c, h, w).cuda(), volatile=True)
            
            out = model(varInput)
#             outMean = out.view(bs, n_crops, -1).mean(1)
            outMean = out.view(bs, -1)
    
            outPRED = torch.zeros(out.shape[0], 5).cuda()
            outPRED[:,0] = outMean[:,0]
            outPRED[:,1] = outMean[:,1]
            outPRED[:,2] = outMean[:,2]
            outPRED[:,3] = torch.max(outMean[:,3:6],1)[0]
            outPRED[:,4] = torch.max(outMean[:,6:9],1)[0]
            
            
#             outPRED = torch.cat((outPRED, outMean.data), 0)
            
            varOutput = outPRED
            varTarget = outGT
            
#             losstensor = loss(varOutput, varTarget)

            CEloss =  torch.nn.CrossEntropyLoss()
            BCEloss = torch.nn.BCELoss()

#             varTarget = varTarget.type(torch.long)
            L1 = BCEloss(varOutput[:,:1],varTarget[:,0]) 
            L2 = BCEloss(varOutput[:,1:2],varTarget[:,1])
            L3 = BCEloss(varOutput[:,2:3],varTarget[:,2])
            varTarget = varTarget.long()
            L4 = CEloss(varOutput[:,3:6],varTarget[:,3])
            L5 = CEloss(varOutput[:,6:9],varTarget[:,4])

            
            losstensor = L1 + L2 + L3 + L4 + L5
            losstensor /= 5


            losstensorMean += losstensor
            lossVal += losstensor.item()
            lossValNorm += 1
            
        outLoss = lossVal / lossValNorm
        losstensorMean = losstensorMean / lossValNorm
        
        aurocIndividual = ChexnetTrainer.computeAUROC(outGT, outPRED, classCount)
        aurocMean = np.array(aurocIndividual).mean()
        
        #print("AUROC val", aurocMean)
        writer.add_scalar('logs/val_auroc', aurocMean, counter)
        
        return outLoss, losstensorMean
               
    #--------------------------------------------------------------------------------     
     
    #---- Computes area under ROC curve 
    #---- dataGT - ground truth data
    #---- dataPRED - predicted data
    #---- classCount - number of classes
    
    def computeAUROC (dataGT, dataPRED, classCount):
        
        outAUROC = []
        
        datanpGT = dataGT.cpu().numpy()
        datanpPRED = dataPRED.cpu().numpy()
        
        for i in range(classCount):
            outAUROC.append(roc_auc_score(datanpGT[:, i], datanpPRED[:, i]))
            
        return outAUROC
        
        
    #--------------------------------------------------------------------------------  
    
    #---- Test the trained network 
    #---- pathDirData - path to the directory that contains images
    #---- pathFileTrain - path to the file that contains image paths and label pairs (training set)
    #---- pathFileVal - path to the file that contains image path and label pairs (validation set)
    #---- nnArchitecture - model architecture 'DENSE-NET-121', 'DENSE-NET-169' or 'DENSE-NET-201'
    #---- nnIsTrained - if True, uses pre-trained version of the network (pre-trained on imagenet)
    #---- nnClassCount - number of output classes 
    #---- trBatchSize - batch size
    #---- trMaxEpoch - number of epochs
    #---- transResize - size of the image to scale down to (not used in current implementation)
    #---- transCrop - size of the cropped image 
    #---- launchTimestamp - date/time, used to assign unique name for the checkpoint file
    #---- checkpoint - if not None loads the model and continues training
    
    def test (pathDirData, pathFileTest, pathModel, nnArchitecture, nnClassCount, nnIsTrained, trBatchSize, transResize, transCrop, launchTimeStamp):   
        
        
        CLASS_NAMES = [ 'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
                'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia']
        
        cudnn.benchmark = True
        
        #-------------------- SETTINGS: NETWORK ARCHITECTURE, MODEL LOAD
        if nnArchitecture == 'DENSE-NET-121': model = DenseNet121(nnClassCount, nnIsTrained).cuda()
        elif nnArchitecture == 'DENSE-NET-169': model = DenseNet169(nnClassCount, nnIsTrained).cuda()
        elif nnArchitecture == 'DENSE-NET-201': model = DenseNet201(nnClassCount, nnIsTrained).cuda()
        
        model = torch.nn.DataParallel(model).cuda() 
        
        modelCheckpoint = torch.load(pathModel)
        model.load_state_dict(modelCheckpoint['state_dict'])

        #-------------------- SETTINGS: DATA TRANSFORMS, TEN CROPS
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        #-------------------- SETTINGS: DATASET BUILDERS
        transformList = []
        transformList.append(transforms.Resize(transResize))
        transformList.append(transforms.TenCrop(transCrop))
        transformList.append(transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])))
        transformList.append(transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])))
        transformSequence=transforms.Compose(transformList)
        
        datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest, transform=transformSequence)
        dataLoaderTest = DataLoader(dataset=datasetTest, batch_size=trBatchSize, num_workers=4, shuffle=False, pin_memory=True)
        
        outGT = torch.FloatTensor().cuda()
        outPRED = torch.FloatTensor().cuda()
       
        model.eval()
        
        for i, (input, target) in enumerate(dataLoaderTest):
            
            target = target.cuda()
            outGT = torch.cat((outGT, target), 0)
            
            bs, n_crops, c, h, w = input.size()
            
            varInput = torch.autograd.Variable(input.view(-1, c, h, w).cuda(), volatile=True)
            
            out = model(varInput)
            outMean = out.view(bs, n_crops, -1).mean(1)
            
            outPRED = torch.cat((outPRED, outMean.data), 0)

        aurocIndividual = ChexnetTrainer.computeAUROC(outGT, outPRED, nnClassCount)
        aurocMean = np.array(aurocIndividual).mean()
        
        print ('AUROC mean ', aurocMean)
        
        for i in range (0, len(aurocIndividual)):
            print (CLASS_NAMES[i], ' ', aurocIndividual[i])
        
     
        return
#-------------------------------------------------------------------------------- 


In [None]:
DATA_DIR = './data'
TRAIN_IMAGE_LIST = './data/CheXpert-v1.0-small/train.csv'
VAL_IMAGE_LIST = './data/CheXpert-v1.0-small/valid.csv'
valid_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                image_list_file=VAL_IMAGE_LIST,
                                diseases =['Pleural Effusion'],
                                transform=transformSequence)

nnIsTrained = True
nnArchitecture = 'RES-NET-18'
nnClassCount = 1
trBatchSize = 10
trMaxEpoch = 50
transResize = 256
transCrop = 224
launchTimestamp = ''
checkpoint = None

ChexnetTrainer.train(DATA_DIR,TRAIN_IMAGE_LIST,VAL_IMAGE_LIST,nnArchitecture, nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, transResize, transCrop, launchTimestamp, checkpoint)


In [None]:
DATA_DIR = './data'
TRAIN_IMAGE_LIST = './data/CheXpert-v1.0-small/train.csv'
VAL_IMAGE_LIST = './data/CheXpert-v1.0-small/valid.csv'
# valid_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
#                                 image_list_file=VAL_IMAGE_LIST,
#                                 diseases =['Pleural Effusion'],
#                                 transform=transformSequence)

nnIsTrained = True
nnArchitecture = 'DENSE-NET-121'
nnClassCount = 1
trBatchSize = 4
trMaxEpoch = 50
transResize = 256
transCrop = 224
launchTimestamp = 'dense'
checkpoint = None

ChexnetTrainer.train(DATA_DIR,TRAIN_IMAGE_LIST,VAL_IMAGE_LIST,nnArchitecture, nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, transResize, transCrop, launchTimestamp, checkpoint)

In [7]:
DATA_DIR = './data'
TRAIN_IMAGE_LIST = './data/CheXpert-v1.0-small/train.csv'
VAL_IMAGE_LIST = './data/CheXpert-v1.0-small/valid.csv'
valid_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                image_list_file=VAL_IMAGE_LIST)

nnIsTrained = True
nnArchitecture = 'DENSE-NET-121'
nnClassCount = 9
trBatchSize = 6
trMaxEpoch = 50
transResize = 256
transCrop = 224
launchTimestamp = ''
checkpoint = None

ChexnetTrainer.train(DATA_DIR,TRAIN_IMAGE_LIST,VAL_IMAGE_LIST,nnArchitecture, nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, transResize, transCrop, launchTimestamp, checkpoint)



tensor(0.8192, device='cuda:0', grad_fn=<DivBackward0>)




RuntimeError: invalid argument 2: non-empty vector or matrix expected at /opt/conda/conda-bld/pytorch_1556653114079/work/aten/src/THCUNN/generic/ClassNLLCriterion.cu:31

### From here is the working code for our base line

In [6]:
class ChexnetTrainer ():

    #---- Train the densenet network 
    #---- pathDirData - path to the directory that contains images
    #---- pathFileTrain - path to the file that contains image paths and label pairs (training set)
    #---- pathFileVal - path to the file that contains image path and label pairs (validation set)
    #---- nnArchitecture - model architecture 'DENSE-NET-121', 'DENSE-NET-169' or 'DENSE-NET-201'
    #---- nnIsTrained - if True, uses pre-trained version of the network (pre-trained on imagenet)
    #---- nnClassCount - number of output classes 
    #---- trBatchSize - batch size
    #---- trMaxEpoch - number of epochs
    #---- transResize - size of the image to scale down to (not used in current implementation)
    #---- transCrop - size of the cropped image 
    #---- launchTimestamp - date/time, used to assign unique name for the checkpoint file
    #---- checkpoint - if not None loads the model and continues training
    
    
    def train (pathDirData, pathFileTrain, pathFileVal, nnArchitecture, nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, transResize, transCrop, launchTimestamp, checkpoint):

        
        #-------------------- SETTINGS: NETWORK ARCHITECTURE
        if nnArchitecture == 'DENSE-NET-121-Sigmoid': model = DenseNet121_Sigmoid(nnClassCount, nnIsTrained).cuda()
        elif nnArchitecture == 'DENSE-NET-121': model = DenseNet121(nnClassCount, nnIsTrained)#.cuda()
        elif nnArchitecture == 'DENSE-NET-169': model = DenseNet169(nnClassCount, nnIsTrained)#.cuda()
        elif nnArchitecture == 'DENSE-NET-201': model = DenseNet201(nnClassCount, nnIsTrained)#.cuda()
        elif nnArchitecture == 'RES-NET-18': model = ResNet18(nnClassCount, nnIsTrained)#.cuda()
        
        model = torch.nn.DataParallel(model)#.cuda()
       
        #-------------------- SETTINGS: DATA TRANSFORMS |TRAIN|
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        transformList = []
        transformList.append(transforms.RandomResizedCrop(transCrop))
        transformList.append(transforms.RandomHorizontalFlip())
        transformList.append(transforms.ToTensor())
        transformList.append(normalize)      
        transformSequence=transforms.Compose(transformList)

        #-------------------- SETTINGS: DATASET BUILDER |TRAIN|
                    
        datasetTrain = ChestXrayDataSet(data_dir=pathDirData,image_list_file=pathFileTrain, transform=transformSequence)
        #datasetVal =   ChestXrayDataSet(data_dir=pathDirData, image_list_file=pathFileVal, diseases=['Pleural Effusion'], transform=transformSequence)
              
        dataLoaderTrain = DataLoader(dataset=datasetTrain, batch_size=trBatchSize, shuffle=True,  num_workers=4, pin_memory=True)
        #dataLoaderVal = DataLoader(dataset=datasetVal, batch_size=trBatchSize, shuffle=False, num_workers=24, pin_memory=True)
        
        
        
        #-------------------- SETTINGS: DATA TRANSFORMS, TEN CROPS |VAL|
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        #-------------------- SETTINGS: DATASET BUILDERS |VAL|
        transformList = []
        
        transformList.append(transforms.RandomResizedCrop(transCrop))
        transformList.append(transforms.RandomHorizontalFlip())
        transformList.append(transforms.ToTensor())
        transformList.append(normalize)      
        
        
#         transformList.append(transforms.Resize(transResize))
#         transformList.append(transforms.TenCrop(transCrop))
#        transformList.append(normalize)  
#         transformList.append(transforms.ToTensor())
#         transformList.append(normalize)
#         transformList.append(transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])))
#         transformList.append(transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])))
      
        transformSequence=transforms.Compose(transformList)
        
        datasetVal =   ChestXrayDataSet(data_dir=pathDirData, image_list_file=pathFileVal, transform=transformSequence)
        dataLoaderVal = DataLoader(dataset=datasetVal, batch_size=trBatchSize, shuffle=False, num_workers=4, pin_memory=True)
        
        
        
        
        
        #-------------------- SETTINGS: OPTIMIZER & SCHEDULER
        optimizer = optim.Adam (model.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        scheduler = ReduceLROnPlateau(optimizer, factor = 0.1, patience = 5, mode = 'min')
                
        #-------------------- SETTINGS: LOSS
        loss = torch.nn.BCELoss(size_average = True)
        
        #---- Load checkpoint 
        if checkpoint != None:
            modelCheckpoint = torch.load(checkpoint)
            model.load_state_dict(modelCheckpoint['state_dict'])
            optimizer.load_state_dict(modelCheckpoint['optimizer'])

        
        #---- TRAIN THE NETWORK
        counter = 0
        lossMIN = 100000
        
        for epochID in range (0, trMaxEpoch):
            
            timestampTime = time.strftime("%H%M%S")
            timestampDate = time.strftime("%d%m%Y")
            timestampSTART = timestampDate + '-' + timestampTime
                         
            lossTrain, counter = ChexnetTrainer.epochTrain (model, dataLoaderTrain, dataLoaderVal, optimizer, scheduler, trMaxEpoch, nnClassCount, loss, counter)
            lossVal, losstensor = ChexnetTrainer.epochVal (model, dataLoaderVal, optimizer, scheduler, trMaxEpoch, nnClassCount, loss, counter)
            
            timestampTime = time.strftime("%H%M%S")
            timestampDate = time.strftime("%d%m%Y")
            timestampEND = timestampDate + '-' + timestampTime

            scheduler.step(losstensor.item())
            writer.add_scalar('logs/train_loss_epoch', lossTrain, epochID)
            writer.add_scalar('logs/val_loss_epoch', lossVal, epochID)
            if lossVal < lossMIN:

                lossMIN = lossVal    
                torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN, 'optimizer' : optimizer.state_dict()}, 'm-' + launchTimestamp + '.pth.tar')
                print ('Epoch [' + str(epochID + 1) + '] [save] [' + timestampEND + '] loss= ' + str(lossVal))
            else:
                print ('Epoch [' + str(epochID + 1) + '] [----] [' + timestampEND + '] loss= ' + str(lossVal))
                     
    #-------------------------------------------------------------------------------- 
       
    def epochTrain (model, dataLoader, dataLoaderVal, optimizer, scheduler, epochMax, classCount, loss, counter):
        
        model.train()
        lossTrain = 0
        lossTrainNorm = 0
        
        avg_loss = 0.0
        for batchID, (input, target) in enumerate (dataLoader):
            
            target = target#.cuda()
            varInput = torch.autograd.Variable(input)
            varTarget = torch.autograd.Variable(target)         
            varOutput = model(varInput)
            
#             lossvalue = loss(varOutput, varTarget)

            CEloss =  torch.nn.CrossEntropyLoss()
            BCEloss = torch.nn.BCELoss()
            varOutput = varOutput.cpu()
            print(varOutput.shape)
#             varTarget = varTarget.type(torch.long)
            L1 = BCEloss(varOutput[:,:1],varTarget[:,0]) 
            L2 = BCEloss(varOutput[:,1:2],varTarget[:,1])
            L3 = BCEloss(varOutput[:,2:3],varTarget[:,2])
            varTarget = varTarget.long()
            L4 = CEloss(varOutput[:,3:6],varTarget[:,3])
            L5 = CEloss(varOutput[:,6:9],varTarget[:,4])

            
            lossvalue = L1 + L2 + L3 + L4 + L5
            lossvalue /= 5

            print(lossvalue)
            avg_loss = avg_loss * (batchID)/(batchID+1) + lossvalue * 1.0/(batchID+ 1)
            lossTrain += lossvalue
            lossTrainNorm += 1
            
            optimizer.zero_grad()
            lossvalue.backward()
            optimizer.step()
            writer.add_scalar('logs/train_loss', avg_loss, counter)
            if batchID % 1000 == 0:
                ChexnetTrainer.epochVal(model, dataLoaderVal, optimizer, scheduler, epochMax, classCount, loss, counter)
            
            counter += 1
            
        outLoss = lossTrain/lossTrainNorm
        return outLoss, counter

                        
    #-------------------------------------------------------------------------------- 
        
    def epochVal (model, dataLoader, optimizer, scheduler, epochMax, classCount, loss, counter):
        
        model.eval ()
        
        lossVal = 0
        lossValNorm = 0
        
        losstensorMean = 0
        
        
        ###Old code-- didn't handle 5d shapes with crops. We should think about whether we want to crop on val
#         outGT = torch.FloatTensor().cuda()
#         outPRED = torch.FloatTensor().cuda()
       
        
#         for i, (input, target) in enumerate (dataLoader):
#             target = target.cuda()
#             outGT = torch.cat((outGT, target), 0)
        
#             bs, n_crops, c, h, w = input.size()
#             print("Val", input.size())
#             varInput = torch.autograd.Variable(input, volatile=True)
#             varTarget = torch.autograd.Variable(target, volatile=True)    
#             varOutput = model(varInput)
            
#             outMean = varOutput.view(bs, n_crops, -1).mean(1)
            
#             outPRED = torch.cat((outPRED, outMean.data), 0)

        ### Validation computes mean prediction over 10 crops, NOT like in training. 
        outGT = torch.FloatTensor()#.cuda()
        outPRED = torch.FloatTensor()#.cuda()
       
        model.eval()
        
        for i, (input, target) in enumerate(dataLoader):
            
            target = target#.cuda()
            outGT = torch.cat((outGT, target), 0)
            
#             bs, n_crops, c, h, w = input.size()
            bs, c, h, w = input.size()


            varInput = torch.autograd.Variable(input.view(-1, c, h, w), volatile=True) # .cuda() for input
            
            out = model(varInput)
#             outMean = out.view(bs, n_crops, -1).mean(1)
            outMean = out.view(bs, -1)
            print(outMean.type)
            outPRED = torch.zeros(out.shape[0], 5)#.cuda()
            outPRED[:,0] = outMean[:,0]
            outPRED[:,1] = outMean[:,1]
            outPRED[:,2] = outMean[:,2]
            outPRED[:,3] = torch.max(outMean[:,3:6],1)[0]
            outPRED[:,4] = torch.max(outMean[:,6:9],1)[0]
            
            
#             outPRED = torch.cat((outPRED, outMean.data), 0)
            
            varOutput = outPRED
            varTarget = outGT
            
#             losstensor = loss(varOutput, varTarget)

            CEloss =  torch.nn.CrossEntropyLoss()
            BCEloss = torch.nn.BCELoss()

#             varTarget = varTarget.type(torch.long)
            L1 = BCEloss(varOutput[:,:1],varTarget[:,0]) 
            L2 = BCEloss(varOutput[:,1:2],varTarget[:,1])
            L3 = BCEloss(varOutput[:,2:3],varTarget[:,2])
            varTarget = varTarget.long()
            L4 = CEloss(varOutput[:,3:6],varTarget[:,3])
            L5 = CEloss(varOutput[:,6:9],varTarget[:,4])

            
            losstensor = L1 + L2 + L3 + L4 + L5
            losstensor /= 5


            losstensorMean += losstensor
            lossVal += losstensor.item()
            lossValNorm += 1
            
        outLoss = lossVal / lossValNorm
        losstensorMean = losstensorMean / lossValNorm
        
        aurocIndividual = ChexnetTrainer.computeAUROC(outGT, outPRED, classCount)
        aurocMean = np.array(aurocIndividual).mean()
        
        #print("AUROC val", aurocMean)
        writer.add_scalar('logs/val_auroc', aurocMean, counter)
        
        return outLoss, losstensorMean
               
    #--------------------------------------------------------------------------------     
     
    #---- Computes area under ROC curve 
    #---- dataGT - ground truth data
    #---- dataPRED - predicted data
    #---- classCount - number of classes
    
    def computeAUROC (dataGT, dataPRED, classCount):
        
        outAUROC = []
        
        datanpGT = dataGT.cpu().numpy()
        datanpPRED = dataPRED.cpu().numpy()
        
        for i in range(classCount):
            outAUROC.append(roc_auc_score(datanpGT[:, i], datanpPRED[:, i]))
            
        return outAUROC
        
        
    #--------------------------------------------------------------------------------  
    
    #---- Test the trained network 
    #---- pathDirData - path to the directory that contains images
    #---- pathFileTrain - path to the file that contains image paths and label pairs (training set)
    #---- pathFileVal - path to the file that contains image path and label pairs (validation set)
    #---- nnArchitecture - model architecture 'DENSE-NET-121', 'DENSE-NET-169' or 'DENSE-NET-201'
    #---- nnIsTrained - if True, uses pre-trained version of the network (pre-trained on imagenet)
    #---- nnClassCount - number of output classes 
    #---- trBatchSize - batch size
    #---- trMaxEpoch - number of epochs
    #---- transResize - size of the image to scale down to (not used in current implementation)
    #---- transCrop - size of the cropped image 
    #---- launchTimestamp - date/time, used to assign unique name for the checkpoint file
    #---- checkpoint - if not None loads the model and continues training
    
    def test (pathDirData, pathFileTest, pathModel, nnArchitecture, nnClassCount, nnIsTrained, trBatchSize, transResize, transCrop, launchTimeStamp):   
        
        
        CLASS_NAMES = [ 'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
                'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia']
        
        cudnn.benchmark = True
        
        #-------------------- SETTINGS: NETWORK ARCHITECTURE, MODEL LOAD
        if nnArchitecture == 'DENSE-NET-121-Sigmoid': model = DenseNet121_Sigmoid(nnClassCount, nnIsTrained).cuda()
        elif nnArchitecture == 'DENSE-NET-121': model = DenseNet121(nnClassCount, nnIsTrained).cuda()
        elif nnArchitecture == 'DENSE-NET-169': model = DenseNet169(nnClassCount, nnIsTrained).cuda()
        elif nnArchitecture == 'DENSE-NET-201': model = DenseNet201(nnClassCount, nnIsTrained).cuda()
        
        model = torch.nn.DataParallel(model).cuda() 
        
        modelCheckpoint = torch.load(pathModel)
        model.load_state_dict(modelCheckpoint['state_dict'])

        #-------------------- SETTINGS: DATA TRANSFORMS, TEN CROPS
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        #-------------------- SETTINGS: DATASET BUILDERS
        transformList = []
        transformList.append(transforms.Resize(transResize))
        transformList.append(transforms.TenCrop(transCrop))
        transformList.append(transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])))
        transformList.append(transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])))
        transformSequence=transforms.Compose(transformList)
        
        datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest, transform=transformSequence)
        dataLoaderTest = DataLoader(dataset=datasetTest, batch_size=trBatchSize, num_workers=4, shuffle=False, pin_memory=True)
        
        outGT = torch.FloatTensor().cuda()
        outPRED = torch.FloatTensor().cuda()
       
        model.eval()
        
        for i, (input, target) in enumerate(dataLoaderTest):
            
            target = target.cuda()
            outGT = torch.cat((outGT, target), 0)
            
            bs, n_crops, c, h, w = input.size()
            
            varInput = torch.autograd.Variable(input.view(-1, c, h, w).cuda(), volatile=True)
            
            out = model(varInput)
            outMean = out.view(bs, n_crops, -1).mean(1)
            
            outPRED = torch.cat((outPRED, outMean.data), 0)

        aurocIndividual = ChexnetTrainer.computeAUROC(outGT, outPRED, nnClassCount)
        aurocMean = np.array(aurocIndividual).mean()
        
        print ('AUROC mean ', aurocMean)
        
        for i in range (0, len(aurocIndividual)):
            print (CLASS_NAMES[i], ' ', aurocIndividual[i])
        
     
        return
#-------------------------------------------------------------------------------- 


In [7]:
DATA_DIR = './data'
TRAIN_IMAGE_LIST = './data/CheXpert-v1.0-small/train.csv'
VAL_IMAGE_LIST = './data/CheXpert-v1.0-small/valid.csv'
valid_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                image_list_file=VAL_IMAGE_LIST)

nnIsTrained = True
nnArchitecture = 'DENSE-NET-121'
nnClassCount = 9
trBatchSize = 6
trMaxEpoch = 50
transResize = 256
transCrop = 224
launchTimestamp = ''
checkpoint = None

ChexnetTrainer.train(DATA_DIR,TRAIN_IMAGE_LIST,VAL_IMAGE_LIST,nnArchitecture, nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, transResize, transCrop, launchTimestamp, checkpoint)



UnboundLocalError: local variable 'model' referenced before assignment

Original baseline without our cool loss function

In [4]:
class ChexnetTrainer ():

    #---- Train the densenet network 
    #---- pathDirData - path to the directory that contains images
    #---- pathFileTrain - path to the file that contains image paths and label pairs (training set)
    #---- pathFileVal - path to the file that contains image path and label pairs (validation set)
    #---- nnArchitecture - model architecture 'DENSE-NET-121', 'DENSE-NET-169' or 'DENSE-NET-201'
    #---- nnIsTrained - if True, uses pre-trained version of the network (pre-trained on imagenet)
    #---- nnClassCount - number of output classes 
    #---- trBatchSize - batch size
    #---- trMaxEpoch - number of epochs
    #---- transResize - size of the image to scale down to (not used in current implementation)
    #---- transCrop - size of the cropped image 
    #---- launchTimestamp - date/time, used to assign unique name for the checkpoint file
    #---- checkpoint - if not None loads the model and continues training
    
    def train (pathDirData, pathFileTrain, pathFileVal, nnArchitecture, nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, transResize, transCrop, launchTimestamp, checkpoint):

        
        #-------------------- SETTINGS: NETWORK ARCHITECTURE
        if nnArchitecture == 'DENSE-NET-121-Sigmoid': model = DenseNet121_Sigmoid(nnClassCount, nnIsTrained).cuda()
        elif nnArchitecture == 'DENSE-NET-121': model = DenseNet121(nnClassCount, nnIsTrained).cuda()
        elif nnArchitecture == 'DENSE-NET-169': model = DenseNet169(nnClassCount, nnIsTrained).cuda()
        elif nnArchitecture == 'DENSE-NET-201': model = DenseNet201(nnClassCount, nnIsTrained).cuda()
        
        model = torch.nn.DataParallel(model)#.cuda()
       
        #-------------------- SETTINGS: DATA TRANSFORMS |TRAIN|
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        transformList = []
        transformList.append(transforms.RandomResizedCrop(transCrop))
        transformList.append(transforms.RandomHorizontalFlip())
        transformList.append(transforms.ToTensor())
        transformList.append(normalize)      
        transformSequence=transforms.Compose(transformList)

        #-------------------- SETTINGS: DATASET BUILDER |TRAIN|
                    
        datasetTrain = ChestXrayDataSet(data_dir=pathDirData,image_list_file=pathFileTrain, transform=transformSequence)
        #datasetVal =   ChestXrayDataSet(data_dir=pathDirData, image_list_file=pathFileVal, diseases=['Pleural Effusion'], transform=transformSequence)
              
        dataLoaderTrain = DataLoader(dataset=datasetTrain, batch_size=trBatchSize, shuffle=True,  num_workers=4, pin_memory=True)
        #dataLoaderVal = DataLoader(dataset=datasetVal, batch_size=trBatchSize, shuffle=False, num_workers=24, pin_memory=True)
        
        
        
        #-------------------- SETTINGS: DATA TRANSFORMS, TEN CROPS |VAL|
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        #-------------------- SETTINGS: DATASET BUILDERS |VAL|
        transformList = []
        
        transformList.append(transforms.RandomResizedCrop(transCrop))
        transformList.append(transforms.RandomHorizontalFlip())
        transformList.append(transforms.ToTensor())
        transformList.append(normalize)      
        
        
#         transformList.append(transforms.Resize(transResize))
#         transformList.append(transforms.TenCrop(transCrop))
#        transformList.append(normalize)  
#         transformList.append(transforms.ToTensor())
#         transformList.append(normalize)
#         transformList.append(transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])))
#         transformList.append(transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])))
      
        transformSequence=transforms.Compose(transformList)
        
        datasetVal =   ChestXrayDataSet(data_dir=pathDirData, image_list_file=pathFileVal, transform=transformSequence)
        dataLoaderVal = DataLoader(dataset=datasetVal, batch_size=trBatchSize, shuffle=False, num_workers=4, pin_memory=True)
        
        #-------------------- SETTINGS: OPTIMIZER & SCHEDULER
        optimizer = optim.Adam (model.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        scheduler = ReduceLROnPlateau(optimizer, factor = 0.1, patience = 5, mode = 'min')
                
        #-------------------- SETTINGS: LOSS
        loss = torch.nn.BCELoss(size_average = True)
        
        #---- Load checkpoint 
        if checkpoint != None:
            modelCheckpoint = torch.load(checkpoint)
            model.load_state_dict(modelCheckpoint['state_dict'])
            optimizer.load_state_dict(modelCheckpoint['optimizer'])

        
        #---- TRAIN THE NETWORK
        
        lossMIN = 100000
        
        for epochID in range (0, trMaxEpoch):
            
            timestampTime = time.strftime("%H%M%S")
            timestampDate = time.strftime("%d%m%Y")
            timestampSTART = timestampDate + '-' + timestampTime
                         
            ChexnetTrainer.epochTrain (model, dataLoaderTrain, optimizer, scheduler, trMaxEpoch, nnClassCount, loss)
            lossVal, losstensor = ChexnetTrainer.epochVal (model, dataLoaderVal, optimizer, scheduler, trMaxEpoch, nnClassCount, loss)
            
            timestampTime = time.strftime("%H%M%S")
            timestampDate = time.strftime("%d%m%Y")
            timestampEND = timestampDate + '-' + timestampTime
            
            scheduler.step(losstensor.data[0])
            
            if lossVal < lossMIN:
                lossMIN = lossVal    
                torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN, 'optimizer' : optimizer.state_dict()}, 'm-' + launchTimestamp + '.pth.tar')
                print ('Epoch [' + str(epochID + 1) + '] [save] [' + timestampEND + '] loss= ' + str(lossVal))
            else:
                print ('Epoch [' + str(epochID + 1) + '] [----] [' + timestampEND + '] loss= ' + str(lossVal))
                     
    #-------------------------------------------------------------------------------- 
       
    def epochTrain (model, dataLoader, optimizer, scheduler, epochMax, classCount, loss):
        
        model.train()
        
        for batchID, (input, target) in enumerate (dataLoader):
                        
            target = target.cuda()
                 
            varInput = torch.autograd.Variable(input)
            varTarget = torch.autograd.Variable(target)         
            varOutput = model(varInput)
            
            lossvalue = loss(varOutput, varTarget)
                       
            optimizer.zero_grad()
            lossvalue.backward()
            optimizer.step()
            
    #-------------------------------------------------------------------------------- 
        
    def epochVal (model, dataLoader, optimizer, scheduler, epochMax, classCount, loss):
        
        model.eval ()
        
        lossVal = 0
        lossValNorm = 0
        
        losstensorMean = 0
        
        for i, (input, target) in enumerate (dataLoader):
            
            target = target.cuda()
            
            varInput = torch.autograd.Variable(input, volatile=True)
            varTarget = torch.autograd.Variable(target, volatile=True)    
            varOutput = model(varInput)
            
            losstensor = loss(varOutput, varTarget)
            losstensorMean += losstensor
            
            lossVal += losstensor.data[0]
            lossValNorm += 1
            
        outLoss = lossVal / lossValNorm
        losstensorMean = losstensorMean / lossValNorm
        
        return outLoss, losstensorMean
               
    #--------------------------------------------------------------------------------     
     
    #---- Computes area under ROC curve 
    #---- dataGT - ground truth data
    #---- dataPRED - predicted data
    #---- classCount - number of classes
    
    def computeAUROC (dataGT, dataPRED, classCount):
        
        outAUROC = []
        
        datanpGT = dataGT.cpu().numpy()
        datanpPRED = dataPRED.cpu().numpy()
        
        for i in range(classCount):
            outAUROC.append(roc_auc_score(datanpGT[:, i], datanpPRED[:, i]))
            
        return outAUROC
        
        
    #--------------------------------------------------------------------------------  
    
    #---- Test the trained network 
    #---- pathDirData - path to the directory that contains images
    #---- pathFileTrain - path to the file that contains image paths and label pairs (training set)
    #---- pathFileVal - path to the file that contains image path and label pairs (validation set)
    #---- nnArchitecture - model architecture 'DENSE-NET-121', 'DENSE-NET-169' or 'DENSE-NET-201'
    #---- nnIsTrained - if True, uses pre-trained version of the network (pre-trained on imagenet)
    #---- nnClassCount - number of output classes 
    #---- trBatchSize - batch size
    #---- trMaxEpoch - number of epochs
    #---- transResize - size of the image to scale down to (not used in current implementation)
    #---- transCrop - size of the cropped image 
    #---- launchTimestamp - date/time, used to assign unique name for the checkpoint file
    #---- checkpoint - if not None loads the model and continues training
    
#     def test (pathDirData, pathFileTest, pathModel, nnArchitecture, nnClassCount, nnIsTrained, trBatchSize, transResize, transCrop, launchTimeStamp):   
        
        
#         CLASS_NAMES = [ 'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
#                 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia']
        
#         cudnn.benchmark = True
        
#         #-------------------- SETTINGS: NETWORK ARCHITECTURE, MODEL LOAD
#         model = DenseNet121_Sigmoid(nnClassCount, nnIsTrained).cuda()
#         if nnArchitecture == 'DENSE-NET-121-Sigmoid': model = DenseNet121_Sigmoid(nnClassCount, nnIsTrained).cuda()
#         elif nnArchitecture == 'DENSE-NET-121': model = DenseNet121(nnClassCount, nnIsTrained).cuda()
#         elif nnArchitecture == 'DENSE-NET-169': model = DenseNet169(nnClassCount, nnIsTrained).cuda()
#         elif nnArchitecture == 'DENSE-NET-201': model = DenseNet201(nnClassCount, nnIsTrained).cuda()
        
#         model = torch.nn.DataParallel(model).cuda() 
        
#         modelCheckpoint = torch.load(pathModel)
#         model.load_state_dict(modelCheckpoint['state_dict'])

#         #-------------------- SETTINGS: DATA TRANSFORMS, TEN CROPS
#         normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
#         #-------------------- SETTINGS: DATASET BUILDERS
#         transformList = []
#         transformList.append(transforms.Resize(transResize))
#         transformList.append(transforms.TenCrop(transCrop))
#         transformList.append(transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])))
#         transformList.append(transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])))
#         transformSequence=transforms.Compose(transformList)
        
#         datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest, transform=transformSequence)
#         dataLoaderTest = DataLoader(dataset=datasetTest, batch_size=trBatchSize, num_workers=8, shuffle=False, pin_memory=True)
        
#         outGT = torch.FloatTensor().cuda()
#         outPRED = torch.FloatTensor().cuda()
       
#         model.eval()
        
#         for i, (input, target) in enumerate(dataLoaderTest):
            
#             target = target.cuda()
#             outGT = torch.cat((outGT, target), 0)
            
#             bs, n_crops, c, h, w = input.size()
            
#             varInput = torch.autograd.Variable(input.view(-1, c, h, w).cuda(), volatile=True)
            
#             out = model(varInput)
#             outMean = out.view(bs, n_crops, -1).mean(1)
            
#             outPRED = torch.cat((outPRED, outMean.data), 0)

#         aurocIndividual = ChexnetTrainer.computeAUROC(outGT, outPRED, nnClassCount)
#         aurocMean = np.array(aurocIndividual).mean()
        
#         print ('AUROC mean ', aurocMean)
        
#         for i in range (0, len(aurocIndividual)):
#             print (CLASS_NAMES[i], ' ', aurocIndividual[i])
        
     
#         return

In [None]:
DATA_DIR = './data'
TRAIN_IMAGE_LIST = './data/CheXpert-v1.0-small/train.csv'
VAL_IMAGE_LIST = './data/CheXpert-v1.0-small/valid.csv'
valid_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                image_list_file=VAL_IMAGE_LIST)

nnIsTrained = True
nnArchitecture = 'DENSE-NET-121-Sigmoid'
nnClassCount = 5
trBatchSize = 6
trMaxEpoch = 50
transResize = 256
transCrop = 224
launchTimestamp = ''
checkpoint = None

ChexnetTrainer.train(DATA_DIR,TRAIN_IMAGE_LIST,VAL_IMAGE_LIST,nnArchitecture, nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, transResize, transCrop, launchTimestamp, checkpoint)

