In [2]:
import warnings
warnings.filterwarnings('ignore')
import torch
from torchvision import transforms, models, datasets
import numpy as np
import torch.optim as optim
import torch.nn as nn
import pandas as pd
import os
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F

In [3]:
from collections import OrderedDict

In [4]:
import timm

In [5]:
import matplotlib.pyplot as plt
from PIL import Image

In [6]:
# Per-split location of the image-level label CSVs.
labels_csv = dict(
    train="/scratch/scratch6/akansh12/DeepEXrays/physionet.org/files/vindr-cxr/1.0.0/annotations/image_labels_train.csv",
    test="/scratch/scratch6/akansh12/DeepEXrays/physionet.org/files/vindr-cxr/1.0.0/annotations/image_labels_test.csv",
)

# Per-split directory holding the image files.
data_dir = dict(
    train="/scratch/scratch6/akansh12/DeepEXrays/data/data_1024/train/",
    test="/scratch/scratch6/akansh12/DeepEXrays/data/data_1024/test/",
)

In [7]:
# Names of the image-level label columns used as the multi-label target.
global_labels = [
    'Pleural effusion',
    'Lung tumor',
    'Pneumonia',
    'Tuberculosis',
    'Other diseases',
    'No finding',
]

In [15]:
#dataset
class Vin_big_dataset(Dataset):
    """Multi-label image dataset built from an image directory and a label CSV.

    Parameters
    ----------
    image_loc : str
        Directory containing the image files.
    label_loc : str
        Path to the label CSV. It must contain an ``image_id`` column and one
        column per entry of ``GLOBAL_LABELS``.
    transforms : callable or None
        Applied to the opened PIL image in ``__getitem__``. ``None`` returns
        the PIL image unchanged.
    data_type : {'train', 'test'}
        ``'train'``: one sample per CSV row (an image that appears in several
        rows yields several samples); the file name is ``<image_id>.png``.
        ``'test'``: one sample per file found in ``image_loc``; labels are
        looked up by the file name without its extension.

    Attributes
    ----------
    full_filenames : list of str
        Image path of every sample.
    labels : torch.Tensor
        One row per sample, one column per entry of ``GLOBAL_LABELS``.

    Raises
    ------
    ValueError
        If ``data_type`` is not 'train' or 'test', or if the test CSV does not
        hold exactly one row per image file.
    KeyError
        If a test image has no row in the label CSV.
    """

    # Label columns read from the CSV, in the order they appear in each target row.
    GLOBAL_LABELS = ['Pleural effusion', 'Lung tumor', 'Pneumonia', 'Tuberculosis', 'Other diseases', 'No finding']

    def __init__(self, image_loc, label_loc, transforms, data_type = 'train'):
        global_labels = self.GLOBAL_LABELS
        label_df = pd.read_csv(label_loc)

        if data_type == 'train':
            # Use the image_id column directly instead of re-parsing it out of
            # a synthetic "<image_id>_<rad_id>" key, which breaks on ids that
            # contain an underscore.
            self.full_filenames = [os.path.join(image_loc, f"{image_id}.png")
                                   for image_id in label_df['image_id']]
            # Rows are taken in CSV order, so they line up with full_filenames.
            label_rows = label_df[global_labels]
        elif data_type == 'test':
            # Sorted so that the sample order does not depend on the
            # filesystem's directory listing order.
            filenames = sorted(os.listdir(image_loc))
            self.full_filenames = [os.path.join(image_loc, name) for name in filenames]
            image_ids = [os.path.splitext(name)[0] for name in filenames]
            label_rows = label_df.set_index("image_id")[global_labels].loc[image_ids]
            if len(label_rows) != len(filenames):
                # Duplicate image_id rows would silently misalign labels and files.
                raise ValueError(
                    "label CSV must contain exactly one row per test image: "
                    f"{len(filenames)} files vs {len(label_rows)} label rows"
                )
        else:
            raise ValueError(f"data_type must be 'train' or 'test', got {data_type!r}")

        # Single vectorised conversion; both splits expose labels as a tensor.
        self.labels = torch.tensor(label_rows.values.tolist())
        self.transforms = transforms

    def __len__(self):
        """Return the number of samples."""
        return len(self.full_filenames)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for sample ``idx``."""
        image = Image.open(self.full_filenames[idx])
        if self.transforms is not None:
            image = self.transforms(image)

        return image, self.labels[idx]
    
            

### Get mean and STD

In [16]:
# Training split with only ToTensor applied, so the statistics computed from
# it describe the un-normalised pixel values.
train_data = Vin_big_dataset(
    data_dir['train'],
    labels_csv['train'],
    transforms.ToTensor(),
    data_type='train',
)

def get_mean_std(loader):
    """Compute the per-channel mean and standard deviation over a loader.

    Parameters
    ----------
    loader : iterable
        Yields ``(data, _)`` pairs where ``data`` is a 4-D tensor whose
        dimension 1 is the channel axis (dimensions 0, 2 and 3 are reduced).

    Returns
    -------
    (mean, std) : tuple of torch.Tensor
        One value per channel, in the dtype of the input batches.

    Raises
    ------
    ValueError
        If the loader yields no data.
    """
    channels_sum, channels_squared_sum, num_pixels = 0, 0, 0
    out_dtype = None

    for data, _ in tqdm(loader):
        out_dtype = data.dtype
        # Accumulate raw sums (not per-batch means) so that a smaller final
        # batch is weighted by its actual size. float64 keeps the running
        # totals accurate over a large number of pixels.
        channels_sum += torch.sum(data, dim=[0, 2, 3], dtype=torch.float64)
        channels_squared_sum += torch.sum(data ** 2, dim=[0, 2, 3], dtype=torch.float64)
        num_pixels += data.size(0) * data.size(2) * data.size(3)

    if num_pixels == 0:
        raise ValueError("cannot compute statistics: loader yielded no data")

    mean = channels_sum / num_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp so rounding cannot produce sqrt(<0) = NaN.
    std = (channels_squared_sum / num_pixels - mean ** 2).clamp(min=0) ** 0.5

    return mean.to(out_dtype), std.to(out_dtype)

# Per-channel statistics of the training images, later used for Normalize.
stats_loader = DataLoader(train_data, batch_size=16, shuffle=True)
mean, std = get_mean_std(stats_loader)
print(mean, std)

HBox(children=(FloatProgress(value=0.0, max=45000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2813.0), HTML(value='')))


tensor([0.5490, 0.5490, 0.5490]) tensor([0.2684, 0.2684, 0.2684])


In [17]:
# Per-split preprocessing pipelines. Both end with the same Normalize call.
# The Resize / CenterCrop steps that used to precede these are left disabled.
data_transforms = dict(
    # Training: random flip, perspective warp and rotation, then tensor + normalise.
    train=transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomPerspective(distortion_scale=0.3),
        transforms.RandomRotation((-20, 20)),
        transforms.ToTensor(),
        transforms.Normalize([0.5490, 0.5490, 0.5490], [0.2684, 0.2684, 0.2684]),
    ]),
    # Test: no augmentation, only tensor conversion + normalise.
    test=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5490, 0.5490, 0.5490], [0.2684, 0.2684, 0.2684]),
    ]),
)

In [18]:
# Build the train and test datasets (train first), each with its own paths
# and transform pipeline.
train_data, test_data = (
    Vin_big_dataset(
        image_loc=data_dir[split],
        label_loc=labels_csv[split],
        transforms=data_transforms[split],
        data_type=split,
    )
    for split in ('train', 'test')
)

HBox(children=(FloatProgress(value=0.0, max=45000.0), HTML(value='')))




In [19]:
# Batches of 4; only the training data is reshuffled each epoch.
trainloader = DataLoader(dataset=train_data, batch_size=4, shuffle=True)
testloader = DataLoader(dataset=test_data, batch_size=4, shuffle=False)

In [20]:
# The checkpoint loaded below (tf_efficientnet_b6_aa-*.pth) holds the
# TensorFlow-ported weights. Build timm's matching `tf_` variant rather than
# plain `efficientnet_b6`: the state-dict keys are the same for both, so a
# mismatch loads without error, but only the `tf_` variant is configured for
# the ported weights.
model = timm.create_model('tf_efficientnet_b6', pretrained=False)

In [21]:
# torch.save(model.state_dict(), "/storage/home/akansh12/Vin-ChestXR-Abnormality-detection/model/imageNet_DenseNet201.pt")

In [22]:
# map_location='cpu' makes loading independent of the device the checkpoint
# was saved from; the model is moved to the GPU in a later cell.
model.load_state_dict(torch.load("/storage/home/akansh12/Vin-ChestXR-Abnormality-detection/model/tf_efficientnet_b6_aa-80ba17e4.pth", map_location='cpu'))

<All keys matched successfully>

In [23]:
# Replace the classification head: 2304 -> 512 -> 6, with a sigmoid output so
# the model emits one independent probability per label.
model.classifier = nn.Sequential(OrderedDict(
    fcl1=nn.Linear(2304, 512),
    dp1=nn.Dropout(0.3),
    r1=nn.ReLU(),
    fcl2=nn.Linear(512, 6),
    out=nn.Sigmoid(),
))

In [24]:
#metric
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_score,recall_score, f1_score
# roc_auc_score(y, clf.decision_function(X), average=None)
def calculate_metrics(pred, target, threshold=0.5):
    """Binarise `pred` at `threshold` and score it against `target`.

    Returns a dict keyed ``'<average>/<metric>'`` for every average in
    (micro, macro, samples) and metric in (precision, recall, f1). Each value
    is the corresponding score multiplied by 100.
    """
    binarised = np.array(pred > threshold, dtype=float)
    scorers = (
        ('precision', precision_score),
        ('recall', recall_score),
        ('f1', f1_score),
    )

    results = {}
    for averaging in ('micro', 'macro', 'samples'):
        for metric_name, scorer in scorers:
            key = averaging + '/' + metric_name
            results[key] = scorer(y_true=target, y_pred=binarised, average=averaging)*100
    return results


In [25]:
# BCELoss expects probabilities, which the sigmoid at the end of the
# classifier head provides.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(),lr = 0.001)
# Divide the learning rate by 10 when the monitored loss has not improved for 5 epochs.
schedular = optim.lr_scheduler.ReduceLROnPlateau(optimizer,factor = 0.1,patience = 5, verbose= True)
epochs = 40
# Best (lowest) average test loss seen so far. `np.inf` is used because the
# capitalised alias `np.Inf` was removed in NumPy 2.0.
test_loss_min = np.inf

In [26]:
# Device string used for moving batches: GPU when available, CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [27]:
# Move the model to the GPU when one is present; otherwise leave it as is.
model = model.cuda() if torch.cuda.is_available() else model

In [None]:
# Train for `epochs` epochs, evaluating on the test loader after each one.
# The checkpoint is written only when the average test loss improves on the
# best value seen so far.

# Per-epoch history of the average losses (one entry appended per epoch).
train_loss = []
test_loss = []

for epoch in range(0,epochs):
    # Running sums for this epoch. They use their own names so they do not
    # overwrite the history lists above (doing so made `.append` fail).
    running_train_loss = 0.0
    running_test_loss = 0.0
    model_train_result = []
    train_target = []
    model_test_result = []
    test_target = []


    model.train()
    for images,labels in tqdm(trainloader):
        images = images.to(device)
        labels = labels.to(device)
        ps = model(images)

        # Keep predictions and targets for the epoch-level metrics.
        model_train_result.extend(ps.detach().cpu().numpy().tolist())
        train_target.extend(labels.cpu().numpy())


        loss = criterion(ps,labels.type(torch.float))

        optimizer.zero_grad()
        running_train_loss += loss.item()
        loss.backward()
        optimizer.step()

    avg_train_loss = running_train_loss / len(trainloader)
    train_loss.append(avg_train_loss)

    train_result = calculate_metrics(np.array(model_train_result), train_target)
    train_auc = roc_auc_score(train_target, np.array(model_train_result), average=None)

    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(testloader):
            images = images.to(device)
            # Labels must be on the same device as the predictions for the loss.
            labels = labels.to(device)
            ps = model(images)

            model_test_result.extend(ps.cpu().numpy().tolist())
            test_target.extend(labels.cpu().numpy())


            loss = criterion(ps,labels.type(torch.float))
            running_test_loss += loss.item()

        avg_test_loss = running_test_loss / len(testloader)
        test_loss.append(avg_test_loss)

        # ReduceLROnPlateau is driven by the average test loss.
        schedular.step(avg_test_loss)

        test_result = calculate_metrics(np.array(model_test_result), test_target)
        test_auc = roc_auc_score(test_target, np.array(model_test_result), average=None)


        if avg_test_loss <= test_loss_min:
            print('testation loss decreased ({:.6f} --> {:.6f}).   Saving model ...'.format(test_loss_min,avg_test_loss))
            # NOTE(review): the file name says DenseNet but the model built in
            # this notebook is an EfficientNet; kept unchanged so anything that
            # loads this path keeps working.
            torch.save({
                'epoch' : epoch,
                'model_state_dict' : model.state_dict(),
                'optimizer_state_dict' : optimizer.state_dict(),
                'test_loss_min' : avg_test_loss
            },'DenseNet_size224.pt')
            # Remember the new best so later epochs only save on a real improvement.
            test_loss_min = avg_test_loss



    print(f"Train Loss: {avg_train_loss}")

    print("epoch:{:2d} iter:{:3d} Train: "
                  "micro f1: {:.3f} "
                  "macro f1: {:.3f} "
                  "samples f1: {:.3f}".format(epoch,0,
                                              train_result['micro/f1'],
                                              train_result['macro/f1'],
                                              train_result['samples/f1']))
    print(f"Train AUC{train_auc}")


    print(f"Test Loss: {avg_test_loss}")

    print("epoch:{:2d} iter:{:3d} test: "
                  "micro f1: {:.3f} "
                  "macro f1: {:.3f} "
                  "samples f1: {:.3f}".format(epoch,0,
                                              test_result['micro/f1'],
                                              test_result['macro/f1'],
                                              test_result['samples/f1']))
    print(f"Test AUC{test_auc}")
    

HBox(children=(FloatProgress(value=0.0, max=11250.0), HTML(value='')))