In [1]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import random
import sys
import timm
import torch
from albumentations import (
    Compose,
    Normalize,
    ShiftScaleRotate,
    RandomBrightnessContrast,
    MotionBlur,
    CLAHE,
    HorizontalFlip
)
from copy import deepcopy
from torch.utils.data import Dataset
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score

### Constants

In [4]:
# Root folder of the resized (256x256 PNG) chest X-ray dataset. Every other
# path in this cell is derived from it.
dataset_path = "/home/jupyter/vinbigdata-chest-xray-resized-png-256x256"

# Checkpoint folder and the saved per-class weight vector both live under the root.
model_path = os.path.join(dataset_path, "save_models")
class_weights_path = os.path.join(dataset_path, "class_weights.npy")

# Label CSVs for the two splits.
train_csv_path, test_csv_path = (
    os.path.join(dataset_path, csv_name)
    for csv_name in ('vindrcxr_train.csv', 'vindrcxr_test.csv')
)
# Image folders for the two splits.
train_image_path, test_image_path = (
    os.path.join(dataset_path, split_dir) for split_dir in ('train', 'test')
)
# Joining with '' appends a trailing separator, so file names can be
# concatenated directly onto save_path.
save_path = os.path.join(model_path, '')

print(train_image_path, test_image_path, sep="\n")

/home/jupyter/vinbigdata-chest-xray-resized-png-256x256/train
/home/jupyter/vinbigdata-chest-xray-resized-png-256x256/test


In [7]:
# Training hyperparameters shared by the cells below.
NUM_CLASSES = 15  # number of label columns ("0" .. "14") read per image
N_EPOCHS = 10     # upper bound on the number of training epochs
lr = 0.001        # learning rate
bs = 2            # batch size used by the train/val DataLoaders

In [5]:
# List the dataset folder. The path is relative, so this depends on the notebook's working directory.
!ls vinbigdata-chest-xray-resized-png-256x256

28d11d6696649471d0af4e67f7126299.png  save_models  train_meta.csv
class_weights.npy		      test.csv	   vindrcxr_test.csv
saliency_map.png		      train.csv    vindrcxr_train.csv


### Dataset Object

In [6]:
### Code from https://github.com/Scu-sen/VinBigData-Chest-X-ray-Abnormalities-Detection

class Dataset(torch.utils.data.Dataset):
    """Multi-label image dataset backed by a DataFrame and a folder of PNG files.

    The base class is referenced by its fully qualified name so that re-running
    this cell does not make the class inherit from its own previous definition.

    Parameters:
        df (pd.DataFrame): One row per image. Must contain an ``image_id`` column
            and one label column per class, named by the stringified class index
            ("0" .. str(NUM_CLASSES - 1)).
        image_path (str): Folder containing ``<image_id>.png`` files.
        transform (callable, optional): Called as ``transform(image=img)`` and
            expected to return a dict with the transformed array under ``'image'``.
    """

    def __init__(self, df, image_path, transform=None):
        self.df = df
        self.image_path = image_path
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for the row at position ``idx``.

        ``image`` is a float tensor in channel-first layout and ``labels`` is a
        float tensor with one entry per class.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        # Positional lookup keeps indexing consistent with __len__ even when the
        # DataFrame's index was not reset after filtering.
        row = self.df.iloc[idx]

        label_columns = [str(class_idx) for class_idx in range(NUM_CLASSES)]
        labels = torch.from_numpy(row[label_columns].values.astype(float)).float()

        file_name = os.path.join(self.image_path, str(row['image_id']) + '.png')
        img = cv2.imread(file_name)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {file_name}")

        # OpenCV decodes to BGR; convert to the RGB order the transforms expect.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform:
            img = self.transform(image=img)['image']
        # HWC -> CHW.
        img = torch.from_numpy(img.transpose((2, 0, 1))).float()

        return img, labels

### Model Training Utility

In [9]:
### Modified from https://github.com/Scu-sen/VinBigData-Chest-X-ray-Abnormalities-Detection

def train_model(model, data_loader, optimizer, criterion, class_weights):
    """
    Trains the model for 1 epoch and reports training metrics.

    Parameters:
        model (torch.nn.Module): The model to be trained.
        data_loader (torch.utils.data.DataLoader): Yields (images, labels) batches.
        optimizer (A torch.optim class): The optimizer.
        criterion (A function in torch.nn.modules.loss): The loss function, applied to raw logits.
        class_weights (np.array): class_weights[i] represents the weight of class i (inversely prop. to class frequency).

    Return:
        avg_loss (float): The sample-weighted average loss over the epoch.
        aucs (np.ndarray): Per-class AUCs; NaN for classes with only one label value present.
        overall_auc (float): Weighted-average AUC over the classes whose AUC is defined.
        accs (np.ndarray): Per-class accuracies at a 0.5 threshold; NaN for the same classes as ``aucs``.
        overall_acc (float): Weighted-average accuracy over the classes whose accuracy is defined.

    Raises:
        ValueError: If the data loader yields no samples.
    """

    model.train()

    running_loss = 0.0
    running_n = 0
    preds_list, targets_list = [], []

    optimizer.zero_grad()

    tk = tqdm(data_loader, total=len(data_loader), position=0, leave=True)
    # Run model training on each batch from the train data_loader.
    for imgs, labels in tk:
        imgs, labels = imgs.cuda(), labels.cuda()
        output = model(imgs)

        loss = criterion(output, labels)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Weight the batch loss by batch size so a short final batch is not over-counted.
        batch_size = imgs.size(0)
        running_loss += loss.item() * batch_size
        running_n += batch_size
        tk.set_postfix(loss=running_loss / running_n)

        preds_list.append(torch.sigmoid(output).detach().cpu().numpy())
        targets_list.append(labels.detach().cpu().numpy().round().astype(int))

    if running_n == 0:
        raise ValueError("data_loader produced no samples; cannot compute metrics")

    avg_loss = running_loss / running_n

    # Transpose to (classes, samples) so each row holds one class's predictions.
    preds = np.concatenate(preds_list, axis=0).T
    targets = np.concatenate(targets_list, axis=0).T
    thresholded_preds = np.round(preds)

    # A class with a single label value has no defined AUC; mark it NaN and
    # treat its accuracy the same way so both metrics cover the same classes.
    aucs = np.array(
        [roc_auc_score(t, p) if len(set(t)) > 1 else np.nan for t, p in zip(targets, preds)]
    )
    accs = np.array(
        [accuracy_score(t, p) if len(set(t)) > 1 else np.nan for t, p in zip(targets, thresholded_preds)]
    )

    weights = np.asarray(class_weights, dtype=float)

    def weighted_nanmean(values):
        # Normalise by the weights of the classes that actually contribute.
        # Dividing by the sum of all weights would bias the average towards 0
        # whenever a class metric is NaN.
        valid = ~np.isnan(values) & ~np.isnan(weights)
        total_weight = weights[valid].sum()
        if total_weight <= 0:
            return np.nan
        return np.sum(weights[valid] * values[valid]) / total_weight

    overall_auc = weighted_nanmean(aucs)
    overall_acc = weighted_nanmean(accs)

    return avg_loss, aucs, overall_auc, accs, overall_acc

### Model Validation Utility

In [10]:
### Modified from https://github.com/Scu-sen/VinBigData-Chest-X-ray-Abnormalities-Detection

def val_model(model, data_loader, criterion, class_weights):
    """
    Evaluates the model on the validation set without updating its weights.

    Parameters:
        model (torch.nn.Module): The model to be validated.
        data_loader (torch.utils.data.DataLoader): Yields (images, labels) batches.
        criterion (A torch.nn.modules.loss class): The loss function, applied to raw logits.
        class_weights (np.array): class_weights[i] represents the weight of class i,
            used when averaging the per-class metrics.

    Return:
        avg_loss (float): The sample-weighted average loss.
        aucs (np.ndarray): Per-class AUCs; NaN for classes with only one label value present.
        overall_auc (float): Weighted-average AUC over the classes whose AUC is defined.
        accs (np.ndarray): Per-class accuracies at a 0.5 threshold; NaN for the same classes as ``aucs``.
        overall_acc (float): Weighted-average accuracy over the classes whose accuracy is defined.

    Raises:
        ValueError: If the data loader yields no samples.
    """
    model.eval()

    running_loss = 0.0
    running_n = 0
    preds_list, targets_list = [], []

    with torch.no_grad():
        tk = tqdm(data_loader, total=len(data_loader), position=0, leave=True)

        # Run model inference on each batch from the val data_loader.
        for imgs, labels in tk:
            imgs, labels = imgs.cuda(), labels.cuda()
            output = model(imgs)

            loss = criterion(output, labels)
            # Weight the batch loss by batch size so a short final batch is not over-counted.
            batch_size = imgs.size(0)
            running_loss += loss.item() * batch_size
            running_n += batch_size
            tk.set_postfix(loss=running_loss / running_n)

            preds_list.append(torch.sigmoid(output).detach().cpu().numpy())
            targets_list.append(labels.detach().cpu().numpy().round().astype(int))

    if running_n == 0:
        raise ValueError("data_loader produced no samples; cannot compute metrics")

    ## Compute Metrics ##
    avg_loss = running_loss / running_n

    # Transpose to (classes, samples) so each row holds one class's predictions.
    preds = np.concatenate(preds_list, axis=0).T
    targets = np.concatenate(targets_list, axis=0).T
    thresholded_preds = np.round(preds)

    # A class with a single label value has no defined AUC; mark it NaN and
    # treat its accuracy the same way so both metrics cover the same classes.
    aucs = np.array(
        [roc_auc_score(t, p) if len(set(t)) > 1 else np.nan for t, p in zip(targets, preds)]
    )
    accs = np.array(
        [accuracy_score(t, p) if len(set(t)) > 1 else np.nan for t, p in zip(targets, thresholded_preds)]
    )

    weights = np.asarray(class_weights, dtype=float)

    def weighted_nanmean(values):
        # Normalise by the weights of the classes that actually contribute.
        # Dividing by the sum of all weights would bias the average towards 0
        # whenever a class metric is NaN.
        valid = ~np.isnan(values) & ~np.isnan(weights)
        total_weight = weights[valid].sum()
        if total_weight <= 0:
            return np.nan
        return np.sum(weights[valid] * values[valid]) / total_weight

    overall_auc = weighted_nanmean(aucs)
    overall_acc = weighted_nanmean(accs)

    return avg_loss, aucs, overall_auc, accs, overall_acc

In [11]:
### Modified from https://github.com/Scu-sen/VinBigData-Chest-X-ray-Abnormalities-Detection

def main():
    """Train the multi-label EfficientNet-B4 classifier with one fold held out for validation.

    Reads the fold-annotated training CSV and the saved class weights, trains for
    up to ``N_EPOCHS`` epochs with early stopping on the weighted validation AUC,
    and writes to ``save_path``:
      * one checkpoint per class whenever that class's validation AUC improves,
      * the best overall weights whenever the weighted validation AUC or the
        validation loss improves,
      * a PNG of the training/validation loss curves.
    """
    ## Select fold {0, 1, 2, 3, 4} to use for the validation set.
    fold = 0
    # Stop after this many consecutive epochs without a validation-AUC improvement.
    early_stop_patience = 10

    # torch.save / plt.savefig do not create missing directories.
    os.makedirs(save_path, exist_ok=True)

    # Read the train data and saved class weights.
    train = pd.read_csv(train_csv_path)
    class_weights = np.load(class_weights_path)

    # Add Training Data Augmentations.
    train_transform = Compose([
        HorizontalFlip(p=0.5),
        ShiftScaleRotate(scale_limit = 0.15, rotate_limit = 10, p = 0.5),
        RandomBrightnessContrast(p=0.5),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0)
    ])
    # Validation Data transform.
    val_transform = Compose([
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0)
    ])

    # Create Train Dataset and DataLoader.
    trainset = Dataset(
        train.loc[train['fold'] != fold].reset_index(drop=True),
        image_path=train_image_path,
        transform=train_transform
    )
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=bs, num_workers=1,
        shuffle=True
    )

    # Create Val Dataset and DataLoader.
    # The validation rows are a fold of the *training* CSV, so their images are
    # read from the training image folder, not the test folder.
    valset = Dataset(
        train.loc[train['fold'] == fold].reset_index(drop=True),
        image_path=train_image_path,
        transform=val_transform
    )
    val_loader = torch.utils.data.DataLoader(
        valset, batch_size=bs, shuffle=False, num_workers=1
    )

    # Load EfficientNet B4 model.
    model = timm.create_model('tf_efficientnet_b4_ns', pretrained=True, num_classes=NUM_CLASSES).cuda()
    # Setup optimizer (Adam, or SGD with momentum).
    # Pass lr explicitly so the configured value is the one actually used.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Add loss function.
    criterion = torch.nn.BCEWithLogitsLoss(
        pos_weight = torch.FloatTensor(class_weights).cuda()
    )
    # Add LR scheduler. mode='max' because it is stepped on validation AUC.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1, factor=0.1, mode='max')

    # Keep track of best model's metrics.
    best_weights = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    best_auc = 0
    best_aucs = [0]*NUM_CLASSES
    best_val_loss = sys.float_info.max
    es = 0 # Early stopping counter.

    # Keep track of the train/val loss/accuracy history.
    train_loss_history, val_loss_history = [], []
    acc_train_history, acc_val_history = [], []
    for epoch in range(N_EPOCHS):
        # Read the learning rate from the optimizer so the report reflects any
        # reduction applied by the scheduler.
        current_lr = optimizer.param_groups[0]['lr']

        # Run training and validation for 1 epoch.
        avg_train_loss, aucs_train, auc_train, accs_train, acc_train = train_model(model, train_loader, optimizer, criterion, class_weights)
        avg_val_loss, aucs_val, auc_val, accs_val, acc_val = val_model(model, val_loader, criterion, class_weights)

        train_loss_history.append(avg_train_loss)
        val_loss_history.append(avg_val_loss)
        acc_train_history.append(acc_train)
        acc_val_history.append(acc_val)

        # Report metrics for each epoch.
        print('epoch:', epoch)
        print("Training Metrics")
        print('lr:', current_lr, 'train_loss:', avg_train_loss, 'weighted avg auc:',auc_train, 'weighted avg acc:', acc_train)
        print('aucs:',aucs_train)
        print('accs:', accs_train)
        print("Validation Metrics")
        print('lr:', current_lr, 'val_loss:',avg_val_loss, 'weighted avg auc:',auc_val, 'weighted avg acc:', acc_val)
        print('aucs:',aucs_val)
        print('accs:', accs_val)

        # Save the best weights if either overall AUC or val_loss improved.
        if auc_val > best_auc or avg_val_loss < best_val_loss:
            print('saving best weight...')
            best_weights = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            # Persist the snapshot; otherwise it is discarded when main() returns.
            torch.save(
                {'weight': best_weights, 'auc': auc_val, 'val_loss': avg_val_loss, 'epoch': epoch},
                save_path + 'multilabel_efnb4_v1_adam_best.pth'
            )

        # Save the model for class i if the per-class validation AUC of class i improved.
        # A NaN AUC compares False, so classes without a defined AUC never trigger a save.
        for i in range(len(best_aucs)):
            if aucs_val[i] > best_aucs[i]:
                best_aucs[i] = aucs_val[i]
                d = {'weight':model.state_dict(), 'auc':aucs_val[i], 'epoch':epoch}
                torch.save(d, save_path + f'multilabel_efnb4_v1_adam_cls{i}.pth')

        # Update best overall val_loss.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss

        # Update best val AUC
        if auc_val > best_auc:
            es = 0
            best_auc = auc_val
        else:
            # Early stopping implementation. Stop training after
            # `early_stop_patience` consecutive epochs without a val AUC improvement.
            es += 1
            if es >= early_stop_patience:
                break

        scheduler.step(auc_val)

    # Plot training and validation loss curves.
    # Use the number of epochs actually run; it is smaller than N_EPOCHS when
    # early stopping triggers, and a length mismatch would make plt.plot raise.
    epochs_run = range(len(train_loss_history))
    plt.plot(epochs_run, train_loss_history, label='Training Loss')
    plt.plot(epochs_run, val_loss_history, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(save_path + 'multilabel_efnb4_v1_loss_history_weighted_adam.png')
    plt.show()