In [1]:
import warnings
warnings.filterwarnings('ignore')
import torch
from torchvision import transforms, models, datasets
import numpy as np
import torch.optim as optim
import torch.nn as nn
import pandas as pd
import os
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F
import cv2
import subprocess
from collections import OrderedDict
import timm
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import roc_auc_score, auc
from sklearn.metrics import precision_score,recall_score, f1_score
import glob

In [2]:
# Image-level label CSVs, keyed by split.
labels_csv = dict(
    train="/storage/home/akansh12/Vin-ChestXR-Abnormality-detection/Notebooks/weekly_supervised/new_image_labels_train.csv",
    test="/scratch/scratch6/akansh12/DeepEXrays/physionet.org/files/vindr-cxr/1.0.0/annotations/image_labels_test.csv",
)

# Per-annotation CSVs, keyed by split (presumably the bounding-box annotations, going by the "_bb" suffix).
labels_csv_bb = dict(
    train="/scratch/scratch6/akansh12/DeepEXrays/physionet.org/files/vindr-cxr/1.0.0/annotations/annotations_train.csv",
    test="/scratch/scratch6/akansh12/DeepEXrays/physionet.org/files/vindr-cxr/1.0.0/annotations/annotations_test.csv",
)

# Image folders, keyed by split. The trailing slash is part of the stored value.
data_dir = dict(
    train="/scratch/scratch6/akansh12/DeepEXrays/data/data_1024/train/",
    test="/scratch/scratch6/akansh12/DeepEXrays/data/data_1024/test/",
)

In [3]:
# Image-level ("global") label column names.
global_labels = [
    'COPD',
    'Lung tumor',
    'Pneumonia',
    'Tuberculosis',
    'Other diseases',
    'No finding',
]

# Finding-level ("local") label column names; 23 entries, the last one being 'No finding'.
local_labels = [
    'Aortic enlargement',
    'Atelectasis',
    'Calcification',
    'Cardiomegaly',
    'Clavicle fracture',
    'Consolidation',
    'Edema',
    'Emphysema',
    'Enlarged PA',
    'ILD',
    'Infiltration',
    'Lung Opacity',
    'Lung cavity',
    'Lung cyst',
    'Mediastinal shift',
    'Nodule/Mass',
    'Pleural effusion',
    'Pleural thickening',
    'Pneumothorax',
    'Pulmonary fibrosis',
    'Rib fracture',
    'Other lesion',
    'No finding',
]

In [4]:
# Read the label tables for both splits, in the order train then test.
train_label, test_label = (
    pd.read_csv(labels_csv[split]) for split in ('train', 'test')
)
# Same for the per-annotation tables referenced by labels_csv_bb.
train_label_bb, test_label_bb = (
    pd.read_csv(labels_csv_bb[split]) for split in ('train', 'test')
)

In [5]:
class Vin_big_dataset(Dataset):
    """Dataset pairing image files with multi-label targets read from a CSV.

    The CSV at ``label_loc`` must contain an ``image_id`` column plus one column
    per label name (see ``GLOBAL_LABELS`` / ``LOCAL_LABELS``).

    Parameters
    ----------
    image_loc : str
        Directory holding the images. For ``data_type='train'`` files are looked
        up as ``<image_id>.png``; for ``data_type='test'`` every entry returned by
        ``os.listdir(image_loc)`` is used.
    label_loc : str
        Path of the label CSV.
    transforms : callable
        Applied to the opened PIL image in ``__getitem__``.
    data_type : {'train', 'test'}
        Selects how file names and labels are assembled.
    selec_radio : str or None
        Only used for ``data_type='train'``. How several CSV rows that share one
        ``image_id`` are reduced to a single label vector:

        * ``'rand_one'``        - pick one of the rows at random (``np.random``).
        * ``'agree_two'``       - keep the image only if some label vector occurs
          in at least two rows; use that vector.
        * ``'agree_three'``     - same, but at least three rows must agree.
        * ``'radio_per_epoch'`` - use the row at position ``radio_id``.
        * ``'all'``             - one sample per CSV row, no reduction.
    radio_id : int, optional
        Row position used by ``'radio_per_epoch'``.
    label_type : {'global', 'local'}
        Which set of label columns to read.

    Raises
    ------
    ValueError
        If ``label_type``, ``data_type`` or ``selec_radio`` is not one of the
        supported values, or ``radio_id`` is missing for ``'radio_per_epoch'``.

    Attributes
    ----------
    full_filenames : list of str
    labels : torch.Tensor for 'train'; list of numpy arrays for 'test'.
    """

    # Label columns read from the CSV for each ``label_type``.
    GLOBAL_LABELS = ['COPD', 'Lung tumor', 'Pneumonia', 'Tuberculosis', 'Other diseases', 'No finding']
    LOCAL_LABELS = ['Aortic enlargement', 'Atelectasis',
                    'Calcification', 'Cardiomegaly', 'Clavicle fracture', 'Consolidation',
                    'Edema', 'Emphysema', 'Enlarged PA', 'ILD', 'Infiltration',
                    'Lung Opacity', 'Lung cavity', 'Lung cyst', 'Mediastinal shift',
                    'Nodule/Mass', 'Pleural effusion', 'Pleural thickening', 'Pneumothorax',
                    'Pulmonary fibrosis', 'Rib fracture', 'Other lesion', 'No finding']

    _TRAIN_MODES = ('rand_one', 'agree_two', 'agree_three', 'radio_per_epoch', 'all')
    # Minimum number of identical rows required by the agreement modes.
    _MIN_AGREEMENT = {'agree_two': 2, 'agree_three': 3}

    def __init__(self, image_loc, label_loc, transforms, data_type, selec_radio, radio_id = None, label_type = None):
        if label_type == 'global':
            label_cols = self.GLOBAL_LABELS
        elif label_type == 'local':
            label_cols = self.LOCAL_LABELS
        else:
            # Previously this fell through and crashed later with UnboundLocalError.
            raise ValueError(f"label_type must be 'global' or 'local', got {label_type!r}")

        label_df = pd.read_csv(label_loc)

        if data_type == 'train':
            image_ids, labels = self._select_train_labels(label_df, label_cols, selec_radio, radio_id)
            self.full_filenames = [os.path.join(image_loc, image_id + '.png') for image_id in image_ids]
            self.labels = torch.tensor(labels)
        elif data_type == 'test':
            # os.listdir order is arbitrary; sort so sample order is reproducible.
            filenames = sorted(os.listdir(image_loc))
            self.full_filenames = [os.path.join(image_loc, name) for name in filenames]
            label_table = label_df.set_index("image_id")[label_cols]
            # The image id is the file name without its extension.
            self.labels = [label_table.loc[os.path.splitext(name)[0]].values for name in filenames]
        else:
            raise ValueError(f"data_type must be 'train' or 'test', got {data_type!r}")

        self.transforms = transforms

    @classmethod
    def _select_train_labels(cls, label_df, label_cols, selec_radio, radio_id):
        """Return ``(image_ids, labels)`` for the training split under ``selec_radio``."""
        if selec_radio not in cls._TRAIN_MODES:
            raise ValueError(f"selec_radio must be one of {cls._TRAIN_MODES}, got {selec_radio!r}")

        if selec_radio == 'all':
            # One sample per CSV row, in file order.
            return label_df['image_id'].tolist(), label_df[label_cols].values.tolist()

        if selec_radio == 'radio_per_epoch' and radio_id is None:
            raise ValueError("radio_id is required when selec_radio='radio_per_epoch'")

        image_ids, labels = [], []
        # groupby(sort=True) visits image ids in sorted order and keeps CSV row order inside a group.
        for image_id, group in label_df.groupby('image_id', sort=True):
            rows = group[label_cols].values  # shape: (rows for this image, number of labels)
            if selec_radio == 'rand_one':
                chosen = rows[np.random.choice(len(rows))]
            elif selec_radio == 'radio_per_epoch':
                chosen = rows[radio_id]
            else:
                # np.unique returns rows in sorted order, so the majority vector is not
                # necessarily first; pick the most frequent one explicitly.
                unique_rows, counts = np.unique(rows, axis = 0, return_counts = True)
                top = counts.argmax()
                if counts[top] < cls._MIN_AGREEMENT[selec_radio]:
                    continue
                chosen = unique_rows[top]
            image_ids.append(image_id)
            labels.append(chosen.tolist())
        return image_ids, labels

    def __len__(self):
        """Number of samples (one per entry of ``full_filenames``)."""
        return len(self.full_filenames)

    def __getitem__(self, idx):
        """Open image ``idx``, apply ``self.transforms`` and return ``(image, label)``."""
        image = Image.open(self.full_filenames[idx])
        image = self.transforms(image)

        return image, self.labels[idx]

In [6]:
# Per-split preprocessing pipelines. Both end with the same normalisation; the mean/std
# values look like the usual ImageNet channel statistics.
data_transforms = {}

# Training: random flip, perspective warp and rotation before tensor conversion.
data_transforms["train"] = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomPerspective(distortion_scale=0.3),
    transforms.RandomRotation(degrees=(-30, 30)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Test: no augmentation, only tensor conversion and normalisation.
data_transforms["test"] = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [7]:
# Training set: one sample per row of the training CSV, using the 'local' label columns.
train_data = Vin_big_dataset(image_loc = data_dir['train'],
                          label_loc = labels_csv['train'],
                          transforms = data_transforms['train'],
                          data_type = 'train', selec_radio = 'all', label_type = 'local')

# Test set: every file in the test image folder, labelled from the test CSV.
test_data = Vin_big_dataset(image_loc = data_dir['test'],
                          label_loc = labels_csv['test'],
                          transforms = data_transforms['test'],
                          data_type = 'test', selec_radio = None, label_type = 'local')

# Shuffle the training samples each epoch so batches are not always drawn in CSV order;
# the evaluation loader keeps a fixed order.
trainloader = DataLoader(train_data,batch_size = 4,shuffle = True)
testloader = DataLoader(test_data,batch_size = 8,shuffle = False)


  0%|          | 0/15000 [00:00<?, ?it/s]

In [8]:
# The checkpoint file is the TensorFlow-ported ("tf_") EfficientNet-B6 release, so build the
# matching tf_ architecture: in timm the tf_ variants share parameter names with the plain
# ones but use TF-style "same" padding and BN epsilon, which the ported weights expect.
model = timm.create_model('tf_efficientnet_b6', pretrained=False)
# map_location='cpu' lets the checkpoint load regardless of the device it was saved from.
model.load_state_dict(torch.load("/storage/home/akansh12/Vin-ChestXR-Abnormality-detection/model/tf_efficientnet_b6_aa-80ba17e4.pth", map_location='cpu'))

<All keys matched successfully>

In [9]:
# Swap the classification head for a 2304 -> 23 linear layer followed by a sigmoid,
# so the network emits one probability per output.
head_layers = OrderedDict()
head_layers['fcl1'] = nn.Linear(2304, 23)
head_layers['out'] = nn.Sigmoid()
model.classifier = nn.Sequential(head_layers)

In [10]:
def train_one_epoch(model, optimizer, lr_scheduler,
                    dataloader, epoch, criterion, device, threshold = 0.5):
    """Run one optimisation pass over ``dataloader`` and report loss / F1.

    Parameters
    ----------
    model : callable returning per-label scores for a batch.
    optimizer : optimiser stepped once per batch.
    lr_scheduler : scheduler stepped once after the epoch, or ``None``.
    dataloader : iterable of ``(data, targets)`` batches.
    epoch : int, only used in the printed summary.
    criterion : loss called as ``criterion(outputs, targets)``.
    device : device the batches are moved to.
    threshold : float, cut-off that turns scores (and targets) into 0/1 labels
        for the F1 computation.

    Returns
    -------
    (mean batch loss, per-label F1 array, learning rate after this epoch)
    """
    print("Start Train ...")
    model.train()

    losses_train = []
    model_train_result = []
    train_target = []

    for data, targets in tqdm(dataloader):
        data = data.to(device)
        targets = targets.to(device).type(torch.float)

        outputs = model(data)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses_train.append(loss.item())
        model_train_result.extend(outputs.detach().cpu().numpy().tolist())
        train_target.extend(targets.cpu().numpy().tolist())

    # f1_score rejects continuous scores, so binarise predictions (and targets, in case
    # they are not stored as exact 0/1) before scoring.
    predicted = (np.array(model_train_result) >= threshold).astype(int)
    expected = (np.array(train_target) >= threshold).astype(int)
    F1_train = f1_score(expected, predicted, average=None)

    if lr_scheduler is not None:
        lr_scheduler.step()
        lr = lr_scheduler.get_last_lr()[0]
    else:
        # Without a scheduler, report the optimiser's current learning rate.
        lr = optimizer.param_groups[0]['lr']

    mean_loss = np.array(losses_train).mean()
    print("Epoch [%d]" % (epoch),
          "Mean loss on train:", mean_loss,
          "F1 score:", np.array(F1_train),
          "Learning Rate:", lr)

    return mean_loss, np.array(F1_train), lr


def val_epoch(model, dataloader, epoch, criterion, device, threshold = 0.5):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Parameters
    ----------
    model : callable returning per-label scores for a batch.
    dataloader : iterable of ``(data, targets)`` batches.
    epoch : int, only used in the printed summary.
    criterion : loss called as ``criterion(outputs, targets)``.
    device : device the batches are moved to.
    threshold : float, cut-off that turns scores (and targets) into 0/1 labels
        for the F1 computation.

    Returns
    -------
    (mean batch loss, per-label F1 array)
    """
    print("Start Validation ...")
    model.eval()

    model_val_result = []
    val_target = []
    losses_val = []

    with torch.no_grad():
        for data, targets in tqdm(dataloader):
            data = data.to(device)
            targets = targets.to(device).type(torch.float)

            outputs = model(data)

            loss = criterion(outputs, targets)
            losses_val.append(loss.item())

            model_val_result.extend(outputs.detach().cpu().numpy().tolist())
            val_target.extend(targets.cpu().numpy().tolist())

    # f1_score rejects continuous scores, so binarise before scoring.
    predicted = (np.array(model_val_result) >= threshold).astype(int)
    expected = (np.array(val_target) >= threshold).astype(int)
    F1_val = f1_score(expected, predicted, average=None)

    mean_loss = np.array(losses_val).mean()
    # The summary used to reference an undefined `val_auc`; report the mean loss instead.
    print("Epoch:  " + str(epoch) + " F1 valid Score:", np.array(F1_val),
          "Mean valid loss:", mean_loss)

    return mean_loss, np.array(F1_val)

In [11]:
# Default to the CPU; when CUDA is available, switch to it, move the model over
# and wrap it in nn.DataParallel.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
    model = nn.DataParallel(model.cuda())

In [12]:
# Make every parameter trainable, then collect the trainable ones for the optimisers.
for param in model.parameters():
    param.requires_grad = True

params = [weight for weight in model.parameters() if weight.requires_grad]

# Three training stages; entry k of each list below belongs to stage k.
stage_epoch = [20, 5, 20]  # epochs per stage (an earlier setting was [12, 8, 5])

stage_optimizer = [
    optim.Adamax(params, lr=2e-4),
    optim.SGD(params, lr=9e-5, momentum=0.9),
    optim.Adam(params, lr=5e-5),
]

stage_scheduler = [
    optim.lr_scheduler.CosineAnnealingLR(stage_optimizer[0], T_max=4, eta_min=1e-6),
    optim.lr_scheduler.CyclicLR(stage_optimizer[1], base_lr=1e-5, max_lr=2e-4),
    optim.lr_scheduler.CosineAnnealingLR(stage_optimizer[2], T_max=4, eta_min=1e-6),
]



In [13]:
# Binary cross-entropy computed on probabilities and averaged over all elements
# ('mean' is the default reduction, spelled out here).
criterion = nn.BCELoss(reduction='mean')

In [18]:
# Checkpoint directory. The cell used to assign ".../local_label/efficient-net/" first and
# then immediately overwrite it with this value, so this is the directory actually in effect.
weights_dir = "/scratch/scratch6/akansh12/DeepEXrays/model"
os.makedirs(weights_dir, exist_ok=True)

train_loss_history = []
val_loss_history = []
train_f1_history = []
val_f1_history = []
lr_history = []

# Best mean validation F1 seen so far (across all stages) and the checkpoint holding it.
best_f1 = -np.inf
best_weights = None

for k, (num_epochs, optimizer, lr_scheduler) in enumerate(zip(stage_epoch, stage_optimizer, stage_scheduler)):
    for epoch in range(num_epochs):

        train_loss, train_f1, lr = train_one_epoch(model, optimizer, lr_scheduler,trainloader, epoch, criterion, device = device)

        val_loss, val_f1 = val_epoch(model, testloader, epoch, criterion, device = device)

        # train history
        train_loss_history.append(train_loss)
        train_f1_history.append(train_f1)
        lr_history.append(lr)

        # val history
        val_loss_history.append(val_loss)
        val_f1_history.append(val_f1)

        # Save a checkpoint whenever the mean validation F1 matches or beats the best so far.
        # (The file name used to be built from an undefined `val_auc`.)
        mean_val_f1 = float(np.mean(val_f1))
        if mean_val_f1 >= best_f1:
            best_f1 = mean_val_f1
            best_weights = os.path.join(weights_dir, f"{mean_val_f1:0.6f}_.pth")
            torch.save({'state_dict': model.state_dict()}, best_weights)

    print("\nNext stage\n")
    # Start the next stage from the best checkpoint written by this run. The path is tracked
    # explicitly instead of sorting every file in the directory by a slice of its full path,
    # which could pick up unrelated files.
    if best_weights is not None:
        checkpoint = torch.load(best_weights, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])

        print(f'Loaded model: {os.path.basename(best_weights)}')


Start Train ...


  0%|          | 0/3750 [00:00<?, ?it/s]

KeyboardInterrupt: 