In [1]:
import warnings
warnings.filterwarnings('ignore')
import torch
from torchvision import transforms, models, datasets
import numpy as np
import torch.optim as optim
import torch.nn as nn
import pandas as pd
import os
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F
from collections import OrderedDict
import timm
import matplotlib.pyplot as plt
from PIL import Image
from collections import OrderedDict
from torch import nn

In [2]:
# Image folders, keyed by dataset split name.
data_dir = dict(
    train="/scratch/scratch6/akansh12/DeepEXrays/data/data_256/train/",
    test="/scratch/scratch6/akansh12/DeepEXrays/data/data_256/test/",
)

# Human-readable class names (15 entries, the last one being 'No finding').
# NOTE(review): presumably in the same order as the label columns of the
# CSV files used below -- confirm against the CSV headers.
labels = [
    'Aortic enlargement',
    'Atelectasis',
    'Calcification',
    'Cardiomegaly',
    'Consolidation',
    'ILD',
    'Infiltration',
    'Lung Opacity',
    'Nodule/Mass',
    'Other lesion',
    'Pleural effusion',
    'Pleural thickening',
    'Pneumothorax',
    'Pulmonary fibrosis',
    'No finding',
]

In [3]:
class fourteen_class(Dataset):
    """Dataset that pairs PNG images with label rows read from a CSV file.

    The CSV must contain an ``image_id`` column. Every id is mapped to the
    file ``<img_location>/<image_id>.png``; all remaining columns of the same
    row form the label vector returned for that image.

    Args:
        label_loc: Path of the label CSV.
        img_location: Directory containing the ``.png`` images.
        transform: Callable applied to the loaded PIL image; its result is
            returned as the first element of each sample.
        data_type: Unused. Kept only so that existing callers passing it
            continue to work.
    """

    def __init__(self, label_loc, img_location, transform, data_type='train'):
        # Index by image id without mutating the frame in place.
        label_dataframe = pd.read_csv(label_loc).set_index("image_id")
        self.full_filename = [
            os.path.join(img_location, f"{image_id}.png")
            for image_id in label_dataframe.index.values
        ]
        # One row of label values per image, aligned with ``full_filename``.
        self.labels = label_dataframe.to_numpy()
        self.transform = transform

    def __len__(self):
        """Return the number of images listed in the label CSV."""
        return len(self.full_filename)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label_row)`` for the sample at ``idx``."""
        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it as soon as the pixel data has been copied out.
        with Image.open(self.full_filename[idx]) as image_file:
            # Force three channels so a 3-channel Normalize in the transform
            # also works for grayscale / RGBA PNGs. For images that are
            # already RGB this is just a copy.
            image = image_file.convert('RGB')
        image = self.transform(image)
        return image, self.labels[idx]

# Preprocessing pipelines keyed by split name.
#   train: random flip / perspective / rotation augmentation, then tensor
#          conversion and per-channel normalisation.
#   test:  tensor conversion and the same normalisation, no augmentation.
# NOTE(review): the mean/std values look like the usual ImageNet channel
# statistics -- confirm they match what the checkpoint was trained with.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomPerspective(distortion_scale=0.3),
        transforms.RandomRotation((-30, 30)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]),
    test=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]),
)

In [5]:
# Training split: labels from exp_3.csv, images from the train folder,
# preprocessed with the augmenting 'train' pipeline.
train_data = fourteen_class(
    "/storage/home/akansh12/Vin-ChestXR-Abnormality-detection/Notebooks/14_class/labels/exp_3.csv",
    img_location="/scratch/scratch6/akansh12/DeepEXrays/data/data_256/train/",
    transform=data_transforms['train'],
)

# Test split: labels from test.csv, images from the test folder,
# preprocessed with the non-augmenting 'test' pipeline.
test_data = fourteen_class(
    "/storage/home/akansh12/Vin-ChestXR-Abnormality-detection/Notebooks/14_class/labels/test.csv",
    img_location="/scratch/scratch6/akansh12/DeepEXrays/data/data_256/test/",
    transform=data_transforms['test'],
)


In [6]:
# Mini-batch iterators over both splits (4 samples per batch).
# Only the training loader reshuffles; the test loader keeps CSV order.
trainloader = DataLoader(
    train_data,
    batch_size=4,
    shuffle=True,
)
testloader = DataLoader(
    test_data,
    batch_size=4,
    shuffle=False,
)

In [7]:
def exp_model(path):
    """Build an EfficientNet-B0 with a 15-output sigmoid head and load weights.

    The checkpoint at ``path`` must be a mapping with a ``'state_dict'``
    entry. Its tensors are assigned to the model **by position**: the i-th
    checkpoint entry is stored under the i-th key of the model's own
    ``state_dict``, whatever the checkpoint key was called.

    Args:
        path: Location of the checkpoint file passed to ``torch.load``.

    Returns:
        The model with the checkpoint weights loaded, on the CPU.

    Raises:
        RuntimeError: If the checkpoint and the model have a different number
            of entries, or if ``load_state_dict`` rejects the tensors
            (e.g. on a shape mismatch).
    """
    model = timm.models.efficientnet_b0(pretrained=False)
    model.classifier = nn.Sequential(OrderedDict([
        ('fcl1', nn.Linear(1280, 15)),
        ('out', nn.Sigmoid()),
    ]))

    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    checkpoint_state = torch.load(path, map_location='cpu')['state_dict']
    model_keys = list(model.state_dict().keys())

    # A positional mapping is only meaningful when both sides line up 1:1.
    if len(checkpoint_state) != len(model_keys):
        raise RuntimeError(
            "Checkpoint has {} state_dict entries but the model expects {}; "
            "cannot map keys by position.".format(len(checkpoint_state), len(model_keys))
        )

    # Rename every entry in a single pass. Renaming one key at a time on the
    # evolving dict is quadratic and drops tensors whenever a model key name
    # equals a checkpoint key that has not been renamed yet.
    renamed_state = OrderedDict(zip(model_keys, checkpoint_state.values()))
    model.load_state_dict(renamed_state)

    return model

# Use the GPU when torch reports CUDA as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [8]:
# Instantiate the network and load the weights stored in the exp_5 checkpoint file.
exp_5_model = exp_model(
    path="/scratch/scratch6/akansh12/DeepEXrays/radiologist_selection/exp_5/exp_5_eff_b00.183247_20_.pth",
)