In [1]:
import numpy as np
import torch
from PIL import Image
import pandas as pd
import os
import matplotlib.pyplot as plt
torch.manual_seed(0)
from glob import glob
import numpy as np
import pandas as pd
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision import transforms, models, datasets
from PIL import Image
from torch import nn
from torch import optim
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score, auc
from sklearn.metrics import precision_score,recall_score, f1_score
import timm

class fourteen_class(Dataset):
    """Dataset pairing PNG images with the label rows of a CSV file.

    The CSV at ``label_loc`` must contain an ``image_id`` column. That column
    becomes the index and every remaining column forms the label vector of
    the corresponding image. Images are looked up as
    ``<img_location>/<image_id>.png``.

    Args:
        label_loc: Path of the label CSV.
        img_location: Directory holding the ``.png`` images.
        transform: Callable applied to the loaded PIL image, or ``None`` to
            return the PIL image unchanged.
        data_type: Accepted for backward compatibility; currently unused.
    """

    def __init__(self, label_loc, img_location, transform,  data_type= 'train'):
        label_dataframe = pd.read_csv(label_loc)
        label_dataframe.set_index("image_id", inplace=True)
        # Format the id instead of concatenating so ids that pandas parsed
        # as numbers still yield a valid file name.
        self.full_filename = [
            os.path.join(img_location, f"{image_id}.png")
            for image_id in label_dataframe.index.values
        ]
        self.labels = label_dataframe.values
        self.transform = transform

    def __len__(self):
        """Return the number of labelled images."""
        return len(self.full_filename)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for the sample at position ``idx``."""
        # The context manager releases the file handle deterministically.
        # Converting to RGB guarantees three channels, which the 3-channel
        # Normalize used by this script requires even for grayscale PNGs.
        with Image.open(self.full_filename[idx]) as raw_image:
            image = raw_image.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]
        

# Per-channel statistics used to normalize every image (presumably the
# standard ImageNet values -- confirm against the checkpoint's training setup).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Training pipeline: random geometric augmentation, then tensor conversion
# and normalization.
_train_steps = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomPerspective(distortion_scale=0.3),
    transforms.RandomRotation((-30, 30)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]

# Evaluation pipeline: no augmentation, only tensor conversion and
# normalization.
_test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]

data_transforms = dict(
    train=transforms.Compose(_train_steps),
    test=transforms.Compose(_test_steps),
)

# Training split: labels from exp_3.csv, augmented with the "train" pipeline.
train_data = fourteen_class(
    label_loc="/storage/home/akansh12/Vin-ChestXR-Abnormality-detection/Notebooks/14_class/labels/exp_3.csv",
    img_location="/scratch/scratch6/akansh12/DeepEXrays/data/data_256/train/",
    transform=data_transforms['train'],
)

# Held-out split: labels from test.csv, no augmentation.
test_data = fourteen_class(
    label_loc="/storage/home/akansh12/Vin-ChestXR-Abnormality-detection/Notebooks/14_class/labels/test.csv",
    img_location="/scratch/scratch6/akansh12/DeepEXrays/data/data_256/test/",
    transform=data_transforms['test'],
)

# Only the training loader shuffles; evaluation keeps the CSV order.
trainloader = DataLoader(dataset=train_data, batch_size=4, shuffle=True)
testloader = DataLoader(dataset=test_data, batch_size=8, shuffle=False)

from collections import OrderedDict
from torch import nn

# Build the EfficientNet-B0 architecture without timm's pretrained weights;
# the weights come from the local checkpoint loaded below.
model = timm.models.efficientnet_b0(pretrained=False)

# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only host as well; the model is moved to the GPU afterwards when one
# is available, so the result on GPU machines is unchanged.
model.load_state_dict(
    torch.load("/scratch/scratch6/akansh12/DeepEXrays/radiologist_selection/eff_bo.pt",
               map_location="cpu")
)

# Replace the classification head (after the checkpoint is loaded, so the
# new head starts from fresh weights): 1280 backbone features -> 15 sigmoid
# outputs, i.e. independent per-label probabilities as nn.BCELoss expects.
model.classifier = nn.Sequential(OrderedDict([
    ('fcl1', nn.Linear(1280,15)),
    ('out', nn.Sigmoid()),
]))

device = 'cuda' if torch.cuda.is_available() else 'cpu'
if torch.cuda.is_available():
    model = model.cuda()
    print("cuda")
    # Wrap for multi-GPU data parallelism; note this prefixes state_dict
    # keys with "module.".
    model = nn.DataParallel(model)

def train_one_epoch(model, optimizer, lr_scheduler,
                    dataloader, epoch, criterion, device):
    """Run one training pass over ``dataloader`` and report the mean loss.

    Args:
        model: Network to train; switched to train mode.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Scheduler stepped once at the end of the epoch, or
            ``None`` to train with the optimizer's current learning rate.
        dataloader: Iterable yielding ``(data, targets)`` batches.
        epoch: Epoch index, used only in the progress message.
        criterion: Loss called as ``criterion(outputs, targets)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, lr)``: the mean batch loss of the epoch and the
        learning rate in effect after the scheduler step.
    """
    print("Start Train ...")
    model.train()

    losses_train = []

    for data, targets in tqdm(dataloader):
        data = data.to(device)
        # Cast the labels to float so they match the model outputs in the loss.
        targets = targets.to(device).type(torch.float)

        outputs = model(data)
        loss = criterion(outputs, targets)
        losses_train.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # NOTE: the scheduler advances once per epoch, not once per batch, so a
    # scheduler's horizon must be expressed in epochs when used here.
    if lr_scheduler is not None:
        lr_scheduler.step()
        lr = lr_scheduler.get_last_lr()[0]
    else:
        # Without a scheduler, report the optimizer's own learning rate
        # instead of failing on ``None.get_last_lr()``.
        lr = optimizer.param_groups[0]["lr"]

    mean_loss = np.array(losses_train).mean()
    print("Epoch [%d]" % (epoch),
          "Mean loss on train:", mean_loss,
          "Learning Rate:", lr)

    return mean_loss, lr


def val_epoch(model, dataloader, epoch, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Collects the loss of every batch together with all predictions and
    targets, then computes the ROC AUC over the whole set.

    Returns:
        Tuple ``(mean_loss, auc)``: the mean batch loss and the ROC AUC score
        wrapped in a NumPy array.
    """
    print("Start Validation ...")
    model.eval()

    predictions, ground_truth, batch_losses = [], [], []

    with torch.no_grad():
        for batch, labels in tqdm(dataloader):
            batch = batch.to(device)
            labels = labels.to(device).type(torch.float)

            scores = model(batch)
            batch_losses.append(criterion(scores, labels).item())

            # Move results to the CPU so the metric can be computed with
            # scikit-learn once the loop is done.
            predictions.extend(scores.detach().cpu().numpy().tolist())
            ground_truth.extend(labels.cpu().numpy())

    auc_score_valid = np.array(roc_auc_score(ground_truth, predictions))
    mean_loss = np.array(batch_losses).mean()

    header = "Epoch:  " + str(epoch) + " AUC valid Score:"
    print(header, auc_score_valid, "Mean Loss", mean_loss)

    return mean_loss, auc_score_valid



# Fine-tune the whole network: make sure no parameter is frozen.
for param in model.parameters():
    param.requires_grad = True

num_epochs = 25
# NOTE: OneCycleLR takes over the learning rate, so the ``lr`` given here is
# replaced by the schedule derived from ``max_lr``.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
# train_one_epoch() steps the scheduler once per epoch, so the cycle length
# must be counted in epochs. The previous horizon of len(trainloader) * 5
# steps was never traversed and left the learning rate near its start value.
# NOTE(review): OneCycleLR is designed for per-batch stepping; switching to
# that would require train_one_epoch to step the scheduler inside its loop.
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.00002, total_steps=num_epochs)
# BCELoss expects probabilities, which the Sigmoid output layer provides.
criterion = nn.BCELoss()

train_loss_history = []
val_loss_history = []
val_auc_history = []
lr_history = []
weights_dir = "/scratch/scratch6/akansh12/DeepEXrays/radiologist_selection/exp_3/"
history_dir = "/storage/home/akansh12/Vin-ChestXR-Abnormality-detection/model/"

# Create the output directories up front so a missing folder cannot abort
# the run after an epoch of training has already been spent.
os.makedirs(weights_dir, exist_ok=True)
os.makedirs(history_dir, exist_ok=True)


for epoch in range(num_epochs):

    train_loss, lr = train_one_epoch(model, optimizer, lr_scheduler, trainloader, epoch, criterion, device=device)
    val_loss, val_auc = val_epoch(model, testloader, epoch, criterion, device=device)

    # train history
    train_loss_history.append(train_loss)
    lr_history.append(lr)

    # val history
    val_loss_history.append(val_loss)
    val_auc_history.append(val_auc)

    # Save a checkpoint whenever the validation loss matches the best seen so
    # far (the current value is already in the history) and on every 10th epoch.
    best_loss = min(val_loss_history)
    if (val_loss <= best_loss) or (epoch % 10 == 0):
        print('saving model')
        torch.save({'state_dict': model.state_dict()},
                    os.path.join(weights_dir, f"exp_3_eff_b0{val_loss:0.6f}_{epoch}_.pth"))


# Persist the per-epoch curves for later plotting.
np.save(os.path.join(history_dir, "exp_3_train_loss_hist_label.npy"), np.array(train_loss_history))
np.save(os.path.join(history_dir, "exp_3_test_loss_hist_label.npy"), np.array(val_loss_history))
np.save(os.path.join(history_dir, "exp_3_test_auc_hist_label.npy"), np.array(val_auc_history))


Start Train ...


  0%|          | 0/2798 [00:00<?, ?it/s]

KeyboardInterrupt: 