In [1]:
import warnings
warnings.filterwarnings('ignore')
import torch
from torchvision import transforms, models, datasets
import numpy as np
import torch.optim as optim
import torch.nn as nn
import pandas as pd
import os
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F

In [28]:
from collections import OrderedDict

In [2]:
import matplotlib.pyplot as plt
from PIL import Image

In [3]:
# Root of the DeepEXrays working tree. Set the DEEPEXRAYS_ROOT environment
# variable to run the notebook from another location; when it is unset the
# original cluster path is used, so existing runs are unaffected.
DEEPEXRAYS_ROOT = os.environ.get(
    "DEEPEXRAYS_ROOT", "/scratch/scratch6/akansh12/DeepEXrays"
).rstrip("/")

# Directory holding the VinDr-CXR annotation CSVs.
_ANNOTATION_DIR = DEEPEXRAYS_ROOT + "/physionet.org/files/vindr-cxr/1.0.0/annotations"

# Label CSV per split.
labels_csv = {
    split: "{}/image_labels_{}.csv".format(_ANNOTATION_DIR, split)
    for split in ("train", "test")
}

# Image folder per split (trailing slash kept, as in the original paths).
data_dir = {
    split: "{}/data/data_1024/{}/".format(DEEPEXRAYS_ROOT, split)
    for split in ("train", "test")
}

In [20]:
# Names of the six label columns (five findings plus 'No finding').
global_labels = [
    'Pleural effusion',
    'Lung tumor',
    'Pneumonia',
    'Tuberculosis',
    'Other diseases',
    'No finding',
]




In [5]:
# Dataset: every file in a folder paired with its label row(s) from a CSV.
class Vin_big_dataset(Dataset):
    """Pairs each image file in a directory with its label row(s) from a CSV.

    Parameters
    ----------
    image_loc : str
        Directory to list; every entry is treated as an image file. The file
        name without its extension is used as the ``image_id`` lookup key.
    label_loc : str
        Path to a CSV with an ``image_id`` column and the columns named in
        ``LABEL_COLUMNS``.
    transforms : callable
        Applied to the opened PIL image in ``__getitem__``.
    data_type : str
        ``'train'``: when an image has several label rows, one row is drawn
        at random on every access. Any other value: the stored label array is
        returned unchanged.

    Raises
    ------
    KeyError
        If a file in ``image_loc`` has no matching ``image_id`` in the CSV.

    Notes
    -----
    Presumably the training CSV holds one row per annotator for each image
    (hence the random draw) while the test CSV holds a single row per image
    -- TODO confirm against the annotation files.
    """

    # Label columns read from the CSV, in the order they appear in each label vector.
    LABEL_COLUMNS = ['Pleural effusion', 'Lung tumor', 'Pneumonia', 'Tuberculosis', 'Other diseases', 'No finding']

    def __init__(self, image_loc, label_loc, transforms, data_type = 'train'):
        # os.listdir order is arbitrary; sorting makes index -> sample stable across runs.
        filenames = sorted(os.listdir(image_loc))
        self.full_filenames = [os.path.join(image_loc, name) for name in filenames]

        # Slice the label columns once instead of once per file.
        label_table = pd.read_csv(label_loc).set_index("image_id")[self.LABEL_COLUMNS]
        # splitext copes with any extension length (".png", ".jpeg", ...).
        # .loc yields a 2-D array when the id has several rows, 1-D when it has one.
        self.labels = [
            label_table.loc[os.path.splitext(name)[0]].values for name in filenames
        ]

        self.transforms = transforms
        self.data_type = data_type

    def __len__(self):
        """Number of image files found in ``image_loc``."""
        return len(self.full_filenames)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for position ``idx``."""
        # NOTE(review): the image is passed on in whatever mode PIL opens it;
        # the 3-channel Normalize used in this notebook presumably relies on
        # the files being RGB -- confirm, or add .convert("RGB").
        image = Image.open(self.full_filenames[idx])
        image = self.transforms(image)

        label = self.labels[idx]
        if self.data_type == 'train' and label.ndim == 2:
            # Draw one of the available label rows; the row count comes from
            # the data rather than a hard-coded 3, and a single-row (1-D)
            # label is returned whole instead of being indexed into.
            label = label[np.random.choice(label.shape[0])]
        return image, label
    
    
            

In [6]:
# Per-channel mean / std on the [0, 1] scale.
# ToTensor() already rescales 8-bit images from [0, 255] to [0.0, 1.0], so
# Normalize must receive statistics on that same scale. The previous values
# (123.675, 116.28, 103.53 / 58.395, 57.12, 57.375) are these numbers times 255
# and pushed every input to roughly -2 with almost no variance.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

data_transforms = {
    # Training: resize, take the central 224x224 crop, then random flip and
    # rotation (up to +/-20 degrees) before tensor conversion and normalisation.
    "train": transforms.Compose([
        transforms.Resize((256,256)),
        transforms.CenterCrop((224,224)),
        transforms.RandomHorizontalFlip(p = 0.5),
        transforms.RandomRotation((-20,20)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    ]),

    # Evaluation: deterministic resize straight to 224x224, same normalisation.
    "test": transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    ])

}

In [7]:
def _build_split(split):
    """Create the dataset for one split ('train' or 'test').

    The split name selects the image folder, the label CSV and the transform
    pipeline, and is also passed on as the dataset's ``data_type``.
    """
    return Vin_big_dataset(
        image_loc = data_dir[split],
        label_loc = labels_csv[split],
        transforms = data_transforms[split],
        data_type = split,
    )


train_data = _build_split('train')
test_data = _build_split('test')

In [22]:
# Mini-batch size shared by both loaders.
BATCH_SIZE = 16

# Training batches are shuffled; evaluation batches keep the dataset order.
trainloader = DataLoader(dataset = train_data, batch_size = BATCH_SIZE, shuffle = True)
testloader = DataLoader(dataset = test_data, batch_size = BATCH_SIZE, shuffle = False)

In [25]:
# DenseNet-121 backbone from torchvision, initialised with its pretrained
# weights (fetched to the local torch cache on first use).
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` -- confirm the installed version before changing this call.
model = models.densenet121(pretrained=True)

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /storage/home/akansh12/.cache/torch/checkpoints/densenet121-a639ec97.pth


HBox(children=(FloatProgress(value=0.0, max=32342954.0), HTML(value='')))




In [29]:
# Swap the backbone's classifier for a small MLP head: 1024 -> 256 -> 32 -> 6.
# Each hidden layer is followed by dropout (p=0.3) and ReLU; the six outputs
# go through log-softmax over dim 1.
# NOTE(review): log-softmax produces one distribution across the six outputs,
# while the dataset presumably yields a 6-element label vector per image --
# confirm the loss used downstream matches this output.
head_layers = OrderedDict()
head_layers['fcl1'] = nn.Linear(1024, 256)
head_layers['dp1'] = nn.Dropout(0.3)
head_layers['r1'] = nn.ReLU()
head_layers['fcl2'] = nn.Linear(256, 32)
head_layers['dp2'] = nn.Dropout(0.3)
head_layers['r2'] = nn.ReLU()
head_layers['fcl3'] = nn.Linear(32, 6)
head_layers['out'] = nn.LogSoftmax(dim=1)

model.classifier = nn.Sequential(head_layers)