In [1]:
import cv2
import numpy as np
import os
import pandas as pd
import random
import timm
import torch
from albumentations import (
    Compose,
    Normalize,
    ShiftScaleRotate,
    RandomBrightnessContrast,
    MotionBlur,
    CLAHE,
    HorizontalFlip
)
from copy import deepcopy
from torch.utils.data import Dataset
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score

In [2]:
# Source images (the folder name says they were resized to 256x256 PNGs);
# the cells below also read the original train.csv / test.csv from here.
dataset_path = "vinbigdata-chest-xray-resized-png-256x256"
# Project working directory: fold-annotated label CSV, extra inputs, checkpoints.
processed_path = "VinBigData-Chest-X-ray-Abnormalities-Detection/vinbigdata_classifierPP_2021"

# Both the multilabel training table and the (apparently per-class) positive
# weights live in the same input sub-folder.
csv_path, pos_weight_path = (
    os.path.join(processed_path, 'input/vinbigdata-chest-xray-abnormalities-detection/' + fname)
    for fname in ('multilabel_cls_train.csv', 'multilabel_pos_weight.npy')
)
# Folder with the converted training PNG files.
image_path = os.path.join(dataset_path, 'train')
# Checkpoint folder; keeps its trailing slash because main() appends file names with `+`.
save_path = os.path.join(processed_path, 'classifier_weights/1024/')

In [3]:
class TrainDataset(Dataset):
    """
    Map-style dataset yielding ``(image, labels)`` pairs for the multilabel classifier.

    Each row of ``df`` must provide an ``image_id`` column (the PNG file stem inside
    ``image_path``) and one label column per class, named ``"0"`` .. ``str(num_classes - 1)``.

    Parameters:
        df (pd.DataFrame): Label table. Rows are looked up with ``.loc[idx]``, so it
            should carry a default 0..n-1 index (callers use ``reset_index``).
        image_path (str): Folder containing the ``<image_id>.png`` files.
        transform (callable, optional): albumentations-style transform, called as
            ``transform(image=img)['image']``.
        num_classes (int): Number of label columns to read (default 15).
    """

    def __init__(self, df, image_path, transform=None, num_classes=15):
        self.df = df
        self.image_path = image_path
        self.transform = transform
        # Label columns are the stringified class ids; build the list once instead
        # of on every __getitem__ call.
        self.label_cols = [str(i) for i in range(num_classes)]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        labels = torch.from_numpy(
            self.df.loc[idx, self.label_cols].values.astype(float)
        ).float()

        path = os.path.join(self.image_path, str(self.df.loc[idx, 'image_id']) + '.png')
        # cv2.imread does not raise on a missing/unreadable file; it returns None,
        # which would otherwise surface later as an opaque cv2.cvtColor error.
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Image not found: {path}")
        img = cv2.imread(path)
        if img is None:
            raise ValueError(f"Could not decode image: {path}")

        # OpenCV loads BGR; convert to RGB to match the RGB Normalize statistics.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform:
            img = self.transform(image=img)['image']
        # HWC -> CHW, the layout torch conv models expect.
        img = torch.from_numpy(img.transpose((2, 0, 1))).float()

        return img, labels

In [4]:
# Folder holding the test PNGs (only echoed here for a quick sanity check).
test_path = os.path.join(dataset_path, 'test')
print(test_path)

# Peek at the original box-level annotations shipped with the resized dataset.
df = pd.read_csv(f"{dataset_path}/train.csv")
df.head()

vinbigdata-chest-xray-resized-png-256x256/test


Unnamed: 0,image_id,class_name,class_id,rad_id,x_min,y_min,x_max,y_max,width,height
0,50a418190bc3fb1ef1633bf9678929b3,No finding,14,R11,,,,,2332,2580
1,21a10246a5ec7af151081d0cd6d65dc9,No finding,14,R7,,,,,2954,3159
2,9a5094b2563a1ef3ff50dc5c7ff71345,Cardiomegaly,3,R10,691.0,1375.0,1653.0,1831.0,2080,2336
3,051132a778e61a86eb147c7c6f564dfe,Aortic enlargement,0,R10,1264.0,743.0,1611.0,1019.0,2304,2880
4,063319de25ce7edb9b1c6b8881290140,No finding,14,R10,,,,,2540,3072


In [5]:
# Look up the metadata row (original width/height) of a single test image.
# Only the cell's last expression is displayed, so no intermediate df.head() here.
df = pd.read_csv(dataset_path + "/test.csv")
image = "ea93703162b05a5c8acf2e18c2f69acf"
df.query("image_id == @image")

Unnamed: 0,image_id,width,height
2885,ea93703162b05a5c8acf2e18c2f69acf,2642,3170


In [6]:
# Training hyper-parameters consumed by main():
#   bs        -- mini-batch size for both the train and validation DataLoaders
#   lr        -- initial SGD learning rate (ReduceLROnPlateau later divides it by 10)
#   N_EPOCHS  -- maximum number of epochs per fold
#   IMG_SIZE  -- NOTE(review): not referenced by any visible cell; images are fed
#                at their stored size (folder name says 256x256) -- confirm whether
#                a Resize transform was intended.
bs, lr, N_EPOCHS, IMG_SIZE = 2, 1e-3, 100, 1024

In [7]:
def seed_everything(seed, deterministic=True):
    """
    Seed every RNG used by the notebook so runs are repeatable.

    Parameters:
        seed (int): Seed for Python's `random`, NumPy and torch (CPU and all CUDA devices).
        deterministic (bool): If True (default), force deterministic cuDNN kernels and
            disable cuDNN autotuning. ``cudnn.benchmark = True`` picks convolution
            algorithms by timing, which can change between runs and breaks
            reproducibility, so it is only enabled when ``deterministic`` is False.
    """
    random.seed(seed)
    # Only affects Python processes started after this point (e.g. subprocesses);
    # the current interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic

In [8]:
def train_model(model, train_loader, optimizer, criterion):
    """
    Train the model for one epoch.

    Parameters:
        model (torch.nn.Module): The model to be trained. Batches are moved to the
            device of its first parameter (CUDA in main()).
        train_loader (iterable of (imgs, labels)): Training batches, e.g. a DataLoader.
        optimizer (torch.optim.Optimizer): The optimizer.
        criterion (callable): Loss applied to ``(logits, labels)``; assumed to use
            ``reduction='mean'`` like the BCEWithLogitsLoss used in main().

    Return:
        avg_loss (float): Sample-weighted mean training loss over the epoch
            (0.0 if the loader yields no batches).
    """
    model.train()
    # Follow the model's device instead of hard-coding .cuda().
    device = next(model.parameters()).device

    running_loss = 0.0  # sum of per-sample losses seen so far
    running_n = 0       # number of samples seen so far

    tk = tqdm(train_loader, total=len(train_loader), position=0, leave=True)
    for imgs, labels in tk:
        imgs_train, labels_train = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        output_train = model(imgs_train)
        loss = criterion(output_train, labels_train)
        loss.backward()
        optimizer.step()

        batch_n = imgs_train.size(0)
        # The criterion returns a batch *mean*; weight it by the batch size so the
        # running figure is a true per-sample average (not the mean divided by bs).
        running_loss += loss.item() * batch_n
        running_n += batch_n
        tk.set_postfix(loss=running_loss / running_n)

    return running_loss / running_n if running_n else 0.0

In [9]:
def val_model(model, val_loader, criterion):
    """
    Evaluate the model on the validation set.

    Parameters:
        model (torch.nn.Module): The model to evaluate. Batches are moved to the
            device of its first parameter (CUDA in main()).
        val_loader (iterable of (imgs, labels)): Validation batches, e.g. a DataLoader.
        criterion (callable): Loss applied to ``(logits, labels)``; assumed to use
            ``reduction='mean'`` like the BCEWithLogitsLoss used in main().

    Return:
        avg_val_loss (float): Sample-weighted mean validation loss (0.0 for an empty loader).
        aucs (np.ndarray): ROC-AUC per class; NaN for a class whose targets contain
            only one value (AUC is undefined there).
        overall_auc (float): Unweighted mean of the non-NaN entries of ``aucs``
            (NaN when every entry is NaN or the loader is empty).
    """
    model.eval()
    device = next(model.parameters()).device

    running_loss = 0.0  # sum of per-sample losses seen so far
    running_n = 0       # number of samples seen so far
    valid_preds, valid_targets = [], []

    with torch.no_grad():
        tk = tqdm(val_loader, total=len(val_loader), position=0, leave=True)
        for imgs, labels in tk:
            imgs_valid, labels_valid = imgs.to(device), labels.to(device)
            output_valid = model(imgs_valid)

            loss = criterion(output_valid, labels_valid)
            batch_n = imgs_valid.size(0)
            # Weight the batch-mean loss by batch size for a true per-sample average.
            running_loss += loss.item() * batch_n
            running_n += batch_n
            tk.set_postfix(loss=running_loss / running_n)

            valid_preds.append(torch.sigmoid(output_valid).cpu().numpy())
            # Soft labels are rounded to {0, 1} because ROC-AUC needs binary targets.
            valid_targets.append(labels_valid.cpu().numpy().round().astype(int))

    if running_n == 0:
        return 0.0, np.array([]), np.nan

    # Transpose so each row holds one class across all samples
    # (assumes model outputs are (batch, num_classes)).
    valid_preds = np.concatenate(valid_preds, axis=0).T
    valid_targets = np.concatenate(valid_targets, axis=0).T

    aucs = np.array([
        roc_auc_score(y_true, y_score) if len(set(y_true)) > 1 else np.nan
        for y_true, y_score in zip(valid_targets, valid_preds)
    ])
    # Avoid np.nanmean's "Mean of empty slice" warning when no class has an AUC.
    overall_auc = np.nan if np.all(np.isnan(aucs)) else np.nanmean(aucs)

    return running_loss / running_n, aucs, overall_auc

In [10]:
def _cpu_state_dict(model):
    """Return a detached CPU copy of ``model.state_dict()`` that later training steps cannot mutate."""
    return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}


def _make_loader(frame, transform, shuffle):
    """Wrap the given label rows in a TrainDataset + DataLoader using the notebook-level settings."""
    dataset = TrainDataset(
        frame.reset_index(drop=True),  # drop=True: no stray 'index' column
        image_path=image_path,
        transform=transform,
    )
    return torch.utils.data.DataLoader(
        dataset, batch_size=bs, shuffle=shuffle, num_workers=1
    )


def main():
    """
    Run 5-fold cross-validated training of the 15-label EfficientNet-B4 classifier.

    For fold ``f``, rows with ``train['fold'] != f`` are used for training and rows
    with ``train['fold'] == f`` for validation. After every epoch:

    * the in-memory best weights are refreshed if the mean AUC or the validation loss improved;
    * ``multilabel_efnb4_v1_cls{i}_fold{f}.pth`` is written to ``save_path`` for every
      class ``i`` whose AUC improved;
    * training stops after more than 10 epochs without a mean-AUC improvement, or once
      the learning rate falls below ~1e-6.

    Whenever ReduceLROnPlateau lowers the learning rate, the model is rolled back to the
    best weights recorded so far.
    """
    seed_everything(42)

    train = pd.read_csv(csv_path)
    # torch.save does not create directories; make sure the checkpoint folder exists.
    os.makedirs(save_path, exist_ok=True)

    train_transform = Compose([
        HorizontalFlip(p=0.5),
        ShiftScaleRotate(scale_limit = 0.15, rotate_limit = 10, p = 0.5),
        RandomBrightnessContrast(p=0.5),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0)
    ])
    test_transform = Compose([
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0)
    ])

    for fold in range(5):
        train_loader = _make_loader(train.loc[train['fold'] != fold], train_transform, shuffle=True)
        val_loader = _make_loader(train.loc[train['fold'] == fold], test_transform, shuffle=False)

        model = timm.create_model('tf_efficientnet_b4_ns', pretrained=True, num_classes=15).cuda()
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        # NOTE(review): pos_weight_path is defined in the paths cell but never loaded;
        # confirm whether BCEWithLogitsLoss(pos_weight=...) was intended.
        criterion = torch.nn.BCEWithLogitsLoss()
        # mode='max': the scheduler tracks the mean validation AUC, not the loss.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1, factor=0.1, mode='max')

        # Keep the snapshot on the CPU so it does not occupy GPU memory.
        best_weights = _cpu_state_dict(model)
        previous_lr = lr
        best_auc = 0
        best_aucs = [0]*15
        best_val_loss = 100
        es = 0  # epochs since the last mean-AUC improvement

        for epoch in range(N_EPOCHS):
            avg_loss = train_model(model, train_loader, optimizer, criterion)
            avg_val_loss, aucs, auc = val_model(model, val_loader, criterion)

            # NOTE: `auc` is the unweighted mean of the per-class AUCs from val_model.
            print(
                'epoch:', epoch, 'lr:', previous_lr, 'val_loss:',avg_val_loss, 'weighted avg auc:',auc
            )
            print('aucs:',aucs)

            # Record the best weights if either of AUC or val_loss improved.
            if auc > best_auc or avg_val_loss < best_val_loss:
                print('saving best weight...')
                best_weights = _cpu_state_dict(model)

            # Save the model weight if the AUC of any class is improved.
            # (NaN AUCs compare False, so undefined classes are never saved.)
            for i in range(len(best_aucs)):
                if aucs[i] > best_aucs[i]:
                    best_aucs[i] = aucs[i]
                    d = {
                        'weight':model.state_dict(),
                        'auc':aucs[i],
                        'epoch':epoch,
                    }
                    torch.save(
                        d, save_path + f'multilabel_efnb4_v1_cls{i}_fold{fold}.pth'
                    )

            # Update best avg_val_loss
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss

            # Update best mean AUC and implement early stopping.
            if auc > best_auc:
                es = 0
                best_auc = auc
            else:
                es += 1
                if es > 10:
                    break

            scheduler.step(auc)

            # If the LR was just reduced, restart from the best weights seen so far.
            if optimizer.param_groups[0]['lr'] < previous_lr:
                print('restoring best weight...')
                model.load_state_dict(best_weights)
                previous_lr = optimizer.param_groups[0]['lr']

                if optimizer.param_groups[0]['lr'] < 0.99e-6:
                    break

        # NOTE(review): best_weights is never written to disk; only the per-class
        # checkpoints above persist after the fold ends -- confirm this is intended.
        # Release this fold's GPU tensors before the next fold builds a new model.
        del model, optimizer, scheduler, criterion, best_weights
        torch.cuda.empty_cache()

In [11]:
# Launch the full 5-fold training run defined by main(). This is long-running:
# the saved output below shows a 6002-step progress bar for the first epoch,
# and that run was stopped manually (KeyboardInterrupt).
main()

  0%|          | 0/6002 [00:00<?, ?it/s]

KeyboardInterrupt: 