In [None]:
import pandas as pd
import os
import numpy as np
from tqdm import tqdm
import pydicom
from PIL import Image
import pandas as pd
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.metrics import classification_report, roc_auc_score, roc_curve, precision_recall_curve
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve, precision_recall_curve, matthews_corrcoef, auc, accuracy_score, recall_score, precision_score, f1_score
import torchvision
from torchvision.transforms import InterpolationMode
from torchvision import transforms, datasets, models
import torch
from torch import optim, cuda
from torch.utils.data import DataLoader, sampler
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset
import time
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)
from timeit import default_timer as timer
from collections import OrderedDict

from hachoir.parser import createParser # pip install --user hachoir
from hachoir.metadata import extractMetadata

import timm

In [None]:
# Expose only GPU index 4 to this process. The variable is read when CUDA is
# first initialised, so this cell has to run before any CUDA call.
os.environ.update(CUDA_VISIBLE_DEVICES="4")

In [None]:
# Image-level metadata table. Later cells read its `png_path` and `empi_anon`
# columns plus the label columns named in `y_list`.
image_table_csv = 'image_table.csv'
data_df = pd.read_csv(image_table_csv)


In [None]:
# Spatial size every image is resized to before it is turned into a tensor.
imgtransResize = (320, 320)

# Per-channel normalisation (the mean/std values commonly used with
# ImageNet-pretrained backbones).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training pipeline: random augmentations first, then resize -> tensor -> normalise.
transformList = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomInvert(),
    transforms.RandomRotation(degrees=(-15, 15), expand=False, interpolation=InterpolationMode.BICUBIC),
    transforms.ColorJitter(brightness=0.5, contrast=0.5),  # try 0.75 next
    # transforms.RandAugment(num_ops=2, magnitude=9, interpolation=InterpolationMode.BICUBIC),
    transforms.Resize(imgtransResize, interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    normalize,
]
transformSequence = transforms.Compose(transformList)

# Validation pipeline: deterministic, no augmentation.
val_transformList = [
    transforms.Resize(imgtransResize, interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    normalize,
]
val_transformSequence = transforms.Compose(val_transformList)



In [None]:
class png_loader(Dataset):
    """Dataset that serves RGB images and float label vectors from a DataFrame.

    Each row of the table provides an image path (column ``png_path``) and one
    numeric value per label column.
    """

    def __init__(self, image_list_file, transform=None, label_columns=None, root=None):
        """Collect image paths and label vectors from ``image_list_file``.

        Args:
            image_list_file: DataFrame with a ``png_path`` column and the label columns.
            transform: optional callable applied to the RGB PIL image in ``__getitem__``.
            label_columns: names of the label columns. Defaults to the
                notebook-level ``y_list`` (the original behaviour).
            root: string prepended to every ``png_path``. Defaults to the
                notebook-level ``root_dir``, looked up when an image is loaded.

        Raises:
            KeyError: if ``png_path`` or a label column is missing.
            ValueError: if a label value cannot be converted to float.
        """
        if label_columns is None:
            # Backward-compatible default: the label list defined in the notebook.
            label_columns = y_list
        label_columns = list(label_columns)

        # Column-wise extraction replaces the row-by-row iterrows() loop, which
        # also mutated a Series slice through integer keys on a string index.
        self.image_names = image_list_file['png_path'].tolist()
        # Plain lists of Python floats convert cleanly to torch.FloatTensor.
        self.labels = image_list_file[label_columns].astype(float).values.tolist()
        self.transform = transform
        self.root = root

    def __getitem__(self, index):
        """Take the index of item and returns the image and its labels"""
        filename = self.image_names[index]
        # Resolve the path prefix lazily so the notebook may define `root_dir`
        # after the dataset object has been created.
        prefix = self.root if self.root is not None else root_dir

        # The context manager closes the underlying file once the pixel data has
        # been copied by convert(); the original left the handle open.
        with Image.open(prefix + str(filename)) as handle:
            image = handle.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        return image, torch.FloatTensor(self.labels[index])

    def __len__(self):
        """Number of images (rows) in the dataset."""
        return len(self.image_names)

In [None]:
# Names of the label columns the classifier predicts (one output logit each).
y_list = ["bilateral"]


In [None]:
# Mini-batch size used by both data loaders, and the number of label columns.
training_BatchSize, nnClassCount = 8, len(y_list)

In [None]:
# One `empi_anon` value per row of the image table (duplicates are kept).
empi_list = data_df["empi_anon"].tolist()

In [None]:
# Seeded shuffle, then keep the first 1000 ids for the validation split.
# NOTE(review): empi_list has one entry per table row, so if an id occurs on
# several rows these 1000 picks may cover fewer than 1000 distinct ids — confirm
# that this is intended.
shuffled_empi = shuffle(empi_list, random_state=42)
val_list = shuffled_empi[:1000]

In [None]:
# Split by `empi_anon`: every row whose id was drawn for validation goes to the
# validation frame, everything else to the training frame.
in_validation = data_df.empi_anon.isin(val_list)
train_df = data_df[~in_validation]
validate_df = data_df[in_validation]

In [None]:
# Datasets: augmented pipeline for training, deterministic pipeline for validation.
datasetTrain = png_loader(train_df, transformSequence)
datasetValid = png_loader(validate_df, val_transformSequence)

# Settings common to both loaders.
loader_kwargs = dict(batch_size=training_BatchSize, num_workers=32, pin_memory=True)

# Training batches are reshuffled every epoch and the incomplete last batch is dropped.
dataLoaderTrain = DataLoader(dataset=datasetTrain, shuffle=True, drop_last=True, **loader_kwargs)
# Validation keeps the table order and every sample.
dataLoaderVal = DataLoader(dataset=datasetValid, shuffle=False, **loader_kwargs)


In [None]:
class Model_Trainer():
    """Static training / evaluation routines for a BCE-with-logits classifier.

    None of the methods take ``self``; call them on the class, e.g.
    ``Model_Trainer.train(...)``. Tensors are moved to whatever device the
    model's parameters live on, so the routines do not hard-code CUDA.
    """

    @staticmethod
    def _model_device(model):
        """Return the device of the model's first parameter (CPU if it has none)."""
        first_param = next(model.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    @staticmethod
    def _resolve_scaler(scaler):
        """Return ``scaler``, else a notebook-level ``scaler``, else a fresh GradScaler."""
        if scaler is None:
            # Earlier versions of this notebook relied on a global named `scaler`.
            scaler = globals().get('scaler')
        if scaler is None:
            scaler = torch.cuda.amp.GradScaler()
        return scaler

    @staticmethod
    def train (model, dataLoaderTrain, dataLoaderVal, nnClassCount, trMaxEpoch, launchTimestamp, checkpoint):
        """Train ``model`` for ``trMaxEpoch`` epochs, validating and checkpointing each epoch.

        Args:
            model: network producing one logit per label.
            dataLoaderTrain: iterable of ``(images, targets)`` training batches.
            dataLoaderVal: iterable of ``(images, targets)`` validation batches.
            nnClassCount: number of labels (forwarded to the helpers).
            trMaxEpoch: number of epochs to run.
            launchTimestamp: string appended to the notebook-level ``dir_str``.
            checkpoint: path of a checkpoint to resume from, or ``None``.

        Returns:
            ``(batchs, losst)`` as returned by ``epochTrain`` for the last epoch;
            two empty lists when ``trMaxEpoch`` is 0.
        """
        device = Model_Trainer._model_device(model)

        #SETTINGS: OPTIMIZER & SCHEDULER
        optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
        # NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm
        # against the installed version.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=0, min_lr=1e-6, verbose=True)
        #SETTINGS: LOSS
        loss = torch.nn.BCEWithLogitsLoss()
        # One gradient scaler for the whole run so its scale carries over between epochs.
        scaler = Model_Trainer._resolve_scaler(None)

        # Names printed next to the AUROC values: prefer the notebook-level
        # `class_names`, then `y_list`, then generated placeholders.
        class_names = (globals().get('class_names')
                       or globals().get('y_list')
                       or ['class_' + str(k) for k in range(nnClassCount)])

        #LOAD CHECKPOINT
        # The original condition also read an undefined `use_gpu` flag; a requested
        # checkpoint is now always loaded.
        if checkpoint is not None:
            modelCheckpoint = torch.load(checkpoint, map_location=device)
            model.load_state_dict(modelCheckpoint['state_dict'])
            optimizer.load_state_dict(modelCheckpoint['optimizer'])

        #TRAIN THE NETWORK
        lossMIN = 100000
        dir_path = dir_str+launchTimestamp
        # Create the checkpoint folder up front so torch.save cannot fail on a
        # missing directory after a full epoch of work.
        os.makedirs('save_bad_exam_models', exist_ok=True)

        # Defined before the loop so the return below is valid for zero epochs.
        batchs, losst = [], []
        for epochID in range(0, trMaxEpoch):

            # Per-epoch clock, so the "seconds elapsed in epoch" progress text is accurate.
            epoch_start = timer()

            batchs, losst = Model_Trainer.epochTrain(model, dataLoaderTrain, optimizer, trMaxEpoch, epochID, nnClassCount, loss, epoch_start, lossMIN, launchTimestamp, dir_path, scaler=scaler)
            outLoss, ground_truth, prediction, aurocMean, aurocIndividual = Model_Trainer.test(model, dataLoaderVal, nnClassCount, class_names, loss)

            outLoss = outLoss.cpu().detach().numpy()
            print("\nval loss: " + str(outLoss))
            lossMIN = min(lossMIN, float(outLoss))

            # Single clock read, so date and time cannot straddle a rollover.
            timestampEND = time.strftime("%d%m%Y-%H%M%S")
            scheduler.step(float(outLoss))

            torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict()}, 'save_bad_exam_models/bad_exam_detection_' + timestampEND + '_val_loss-' + str(outLoss) + '.pth.tar')

        return batchs, losst
    #--------------------------------------------------------------------------------

    @staticmethod
    def epochTrain(model, dataLoader, optimizer, epochMax, epochID, classCount, loss, start, lossMIN, launchTimestamp, dir_path, scaler=None):
        """Run one optimisation pass over ``dataLoader`` with mixed precision.

        Args:
            model: network to optimise (switched to train mode).
            dataLoader: sized iterable of ``(images, targets)`` batches.
            optimizer: optimiser stepped once per batch.
            epochID: epoch index, used in the progress message only.
            loss: callable ``loss(logits, targets)``.
            start: ``timer()`` reference for the elapsed-time message.
            scaler: optional ``GradScaler``; when omitted, a notebook-level
                ``scaler`` is used if present, otherwise a new one is created.
            epochMax, classCount, lossMIN, launchTimestamp, dir_path: accepted
                for interface compatibility; not used here.

        Returns:
            ``(batch, losstrain)``: an empty list and the per-batch loss values.
        """
        scaler = Model_Trainer._resolve_scaler(scaler)
        device = Model_Trainer._model_device(model)

        batch = []
        losstrain = []
        # Batch count of the loader actually passed in (the original read the
        # notebook globals `datasetTrain` / `training_BatchSize`).
        total_batches = max(len(dataLoader), 1)

        model.train()

        # Iterate the `dataLoader` argument, not the global `dataLoaderTrain`.
        for batchID, (varInput, target) in enumerate(dataLoader):

            varInput = varInput.to(device, non_blocking=True)
            # Flatten targets and logits to 1-D. Unlike squeeze(), reshape(-1)
            # keeps a batch of one as shape (1,) instead of a 0-d tensor.
            varTarget = target.reshape(-1).to(device, non_blocking=True)

            with torch.cuda.amp.autocast():
                varOutput = model(varInput).reshape(-1)
                lossvalue = loss(varOutput, varTarget)

            optimizer.zero_grad()
            scaler.scale(lossvalue).backward()
            scaler.step(optimizer)
            scaler.update()

            losstrain.append(lossvalue.item())
            print(
                f'Epoch: {epochID}\t{100 * (batchID / total_batches):.1f}% complete. {timer() - start:.1f} seconds elapsed in epoch. Training loss: ' + str(round(np.mean(losstrain), 4)),
                end='\r')

        return batch, losstrain

    @staticmethod
    def epochVal(model, dataLoader, optimizer, loss):
        """Evaluate ``model`` on ``dataLoader`` without gradient tracking.

        Args:
            model: network to evaluate (switched to eval mode).
            dataLoader: iterable of ``(images, targets)`` batches.
            optimizer: unused; kept for interface compatibility.
            loss: callable ``loss(logits, targets)``.

        Returns:
            ``(outLoss, outGT, outPRED)``: mean batch loss (0-d tensor) and the
            flattened targets and logits of all batches.

        Raises:
            ValueError: if ``dataLoader`` yields no batches.
        """
        device = Model_Trainer._model_device(model)
        model.eval()

        lossVal = torch.zeros((), device=device)
        lossValNorm = 0
        gt_chunks, pred_chunks = [], []

        with torch.no_grad():
            for varInput, target in dataLoader:
                # Inputs were previously left on the CPU; move both tensors.
                varInput = varInput.to(device, non_blocking=True)
                target = target.float().reshape(-1).to(device, non_blocking=True)

                varOutput = model(varInput).float().reshape(-1)

                lossVal = lossVal + loss(varOutput, target)
                lossValNorm += 1

                gt_chunks.append(target)
                pred_chunks.append(varOutput)

        if lossValNorm == 0:
            raise ValueError("epochVal received an empty data loader")

        outGT = torch.cat(gt_chunks, 0)
        outPRED = torch.cat(pred_chunks, 0)
        outLoss = lossVal / lossValNorm

        print("val_loss: " + str(outLoss))

        return outLoss, outGT, outPRED

    @staticmethod
    def computeAUROC (dataGT, dataPRED, classCount):
        """Return a one-element list holding the ROC AUC of ``dataPRED`` against ``dataGT``.

        Args:
            dataGT: tensor of ground-truth labels.
            dataPRED: tensor of scores, same shape as ``dataGT``.
            classCount: unused; kept for interface compatibility.
        """
        datanpGT = dataGT.detach().cpu().float().numpy()
        datanpPRED = dataPRED.detach().cpu().float().numpy()

        return [roc_auc_score(datanpGT, datanpPRED)]

    @staticmethod
    def test(model, dataLoaderTest, nnClassCount, class_names, loss):
        """Evaluate ``model`` on ``dataLoaderTest`` and report loss and AUROC.

        Args:
            model: network to evaluate (switched to eval mode).
            dataLoaderTest: sized iterable of ``(images, targets)`` batches.
            nnClassCount: forwarded to ``computeAUROC``.
            class_names: labels printed next to each AUROC value.
            loss: callable ``loss(logits, targets)``.

        Returns:
            ``(outLoss, outGT, outPRED, aurocMean, aurocIndividual)``: mean batch
            loss (0-d tensor), flattened targets, flattened logits, mean AUROC
            and the list of AUROC values.

        Raises:
            ValueError: if ``dataLoaderTest`` yields no batches.
        """
        device = Model_Trainer._model_device(model)
        model.eval()

        lossVal = torch.zeros((), device=device)
        lossValNorm = 0
        gt_chunks, pred_chunks = [], []

        div = len(dataLoaderTest)
        with torch.no_grad():
            for batch_num, (images, target) in enumerate(dataLoaderTest):

                target = target.reshape(-1).to(device, non_blocking=True)
                varInput = images.to(device, non_blocking=True)

                with torch.cuda.amp.autocast():
                    # reshape(-1) keeps a final batch of one sample 1-D.
                    varOutput = model(varInput).reshape(-1)
                    losstensor = loss(varOutput, target)

                lossVal = lossVal + losstensor.float()
                lossValNorm += 1

                gt_chunks.append(target)
                # Store float32 logits so concatenation never mixes precisions.
                pred_chunks.append(varOutput.float())

                print(
                    str(round(batch_num/div,3)),
                    end='\r')

        if lossValNorm == 0:
            raise ValueError("test received an empty data loader")

        outGT = torch.cat(gt_chunks, 0)
        outPRED = torch.cat(pred_chunks, 0)
        outLoss = lossVal / lossValNorm

        aurocIndividual = Model_Trainer.computeAUROC(outGT, outPRED, nnClassCount)
        aurocMean = np.array(aurocIndividual).mean()

        print ('\nAUROC mean ', round(aurocMean, 5))

        for i in range (0, len(aurocIndividual)):
            print (class_names[i], ' ', round(aurocIndividual[i],5))

        print("\n")

        return outLoss, outGT, outPRED, aurocMean, aurocIndividual


In [None]:
# Pretrained ConvNeXt backbone from timm, taking 3-channel input.
backbone_name = 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_320'
model = timm.create_model(backbone_name, pretrained=True, in_chans=3)


In [None]:
# Input width of the backbone's final fully connected layer; the replacement
# head must accept the same number of features.
pretrained_fc = model.head.fc
num_features = pretrained_fc.in_features


In [None]:
# Swap the pretrained classifier for a single linear layer that emits one
# logit per label column in `y_list`.
label_logits = nn.Linear(num_features, len(y_list), bias=True)
model.head.fc = nn.Sequential(label_logits)

In [None]:
# Wrap the network for multi-GPU data parallelism, then move it to the GPU.
model = torch.nn.DataParallel(model)
model = model.cuda()


In [None]:
# Number of label columns; passed to Model_Trainer.train as `nnClassCount`.
nnClassCount = len(y_list)

In [None]:

# Read the clock once and format both parts from the same instant. Two separate
# strftime() calls could straddle a second or midnight rollover and yield a
# date/time pair that never existed.
launch_moment = time.localtime()
timestampTime = time.strftime("%H%M%S", launch_moment)
timestampDate = time.strftime("%d%m%Y", launch_moment)
# Launch identifier in the form DDMMYYYY-HHMMSS.
timestampLaunch = timestampDate + '-' + timestampTime



In [None]:
# Prefix that Model_Trainer.train combines with the launch timestamp.
dir_str = "bilateral"

In [None]:
# Launch training from scratch (no checkpoint to resume from).
# NOTE(review): `trMaxEpoch` is not assigned in any cell visible here — confirm
# it is defined before this cell runs on a fresh kernel.
batch, losst = Model_Trainer.train(
    model=model,
    dataLoaderTrain=dataLoaderTrain,
    dataLoaderVal=dataLoaderVal,
    nnClassCount=nnClassCount,
    trMaxEpoch=trMaxEpoch,
    launchTimestamp=timestampLaunch,
    checkpoint=None,
)
