In [1]:
import os
import shutil
import multiprocessing

import numpy as np
import pandas as pd
from PIL import Image

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

from sklearn.metrics import roc_auc_score

from typing import Tuple, Union, List, Dict

from tqdm import tqdm

In [2]:
# Checkpoint whose state_dict is loaded into the ConvNeXt model in the inference cell.
MODEL_PATH = './models/final_models/convnext_base-batch_size=128-epochs=20-lr=0.003-1660509636.671306.pt'

# Name of the split to score; it selects which CSV under ./data/ is read.
TRAIN_TEST_SET = 'train'
TABULAR_DATA_PATH = f'./data/{TRAIN_TEST_SET}.csv'
test_df = pd.read_csv(TABULAR_DATA_PATH)
# Plain string literal: there is nothing to interpolate, so no f-prefix is needed.
IMAGE_PATH = './data/images/'

# When True, the preprocessing cell converts the source .tif files into .png files first.
PREPROCESS_IMAGES = False
MAX_IMAGE_SIZE = 1_000_000_000 # If image is larger than 1.0GB then a black square will be used
# 1 = sequential preprocessing, -1 = one process per CPU, any other value = that many processes.
N_JOBS = 1

# Number of images per batch for both the base and the augmented dataloader.
BATCH_SIZE = 128

# Device used for inference and as map_location when loading the checkpoint.
DEVICE = 'cpu'

## Preprocess Images

In [3]:
def mask_median(im, val=255):
    '''
    Flags the background pixels of an image and paints them with `val`.

    A pixel counts as background when, in every one of the first three channels,
    its value is at least that channel's median minus 5. To work properly this
    assumes the background is the dominant, bright colour of the image
    (white, i.e. (255, 255, 255)).

    Args:
        im: Image array indexed as [row, column, channel] with at least three
            channels. It is modified in place.
        val: Value written to every channel of the background pixels.

    Returns:
        Tuple (im, mask): the same array object after painting, and the boolean
        [row, column] background mask (consumed by prune_image_rows_cols).
    '''
    channel_masks = [im[..., c] >= np.median(im[:, :, c]) - 5 for c in range(3)]

    # np.logical_and is a *binary* ufunc: unpacking three masks into it passes the
    # third one as the `out` argument, which silently drops that channel from the
    # test. Reducing over the list combines all three channels.
    mask = np.logical_and.reduce(channel_masks)
    im[mask, :] = val

    return im, mask


def prune_image_rows_cols(im, mask, thr=0.990):
    '''
    Deletes the columns and rows of `im` whose fraction of masked pixels is greater than `thr`.

    Both fractions are computed on the full, unpruned mask, so removing a column
    does not change whether a row is considered empty.

    Args:
        im: Image array indexed as [row, column, ...].
        mask: Boolean [row, column] array with the same first two dimensions as `im`.
        thr: A row/column is dropped when its masked fraction is strictly above this value.

    Returns:
        A new array holding only the kept rows and columns (the input is not modified).
    '''
    # One vectorised pass instead of one np.delete (a full image copy) per row/column.
    col_fraction = np.sum(mask, axis=0) / float(mask.shape[0])
    row_fraction = np.sum(mask, axis=1) / float(mask.shape[1])

    # Written as "not (fraction > thr)" so that a NaN fraction keeps the line,
    # exactly like the comparison it replaces.
    keep_cols = ~(col_fraction > thr)
    keep_rows = ~(row_fraction > thr)

    return im[keep_rows][:, keep_cols]


def image_load_scale_norm(img_path, prune_thr=0.990, bg_val=255):
    '''
    Loads an image, prunes its background rows/columns and resizes it if it is still too big.

    Steps:
        1. If both sides exceed 2000 px, shrink the image so its shorter side is about 2000 px.
        2. Paint the background with `bg_val` (mask_median) and drop the rows/columns that
           are almost entirely background (prune_image_rows_cols).
        3. If both sides of the pruned image exceed 1000 px, resize it so its shorter
           side is about 1000 px.

    Args:
        img_path: Path of the image file to open.
        prune_thr: Threshold forwarded to prune_image_rows_cols.
        bg_val: Background value forwarded to mask_median.

    Returns:
        The preprocessed PIL image.
    '''
    # Open inside a context manager so the underlying file is closed as soon as the
    # pixels have been copied into a NumPy array, instead of lingering until garbage
    # collection.
    with Image.open(img_path) as img:
        scale = min(img.height / 2e3, img.width / 2e3)

        if scale > 1:
            tmp_size = int(img.width / scale), int(img.height / scale)
            img.thumbnail(tmp_size, resample=Image.Resampling.BILINEAR, reducing_gap=scale)

        pixels = np.array(img)

    im, mask = mask_median(pixels, val=bg_val)
    im = prune_image_rows_cols(im, mask, thr=prune_thr)
    img = Image.fromarray(im)
    scale = min(img.height / 1e3, img.width / 1e3)

    if scale > 1:
        img = img.resize((int(img.width / scale), int(img.height / scale)), Image.Resampling.LANCZOS)

    return img


def preprocess_image(input_dir, target_dir, image_id, max_image_size=1_000_000_000):
    '''
    Turns "<input_dir>/<image_id>.tif" into a preprocessed "<target_dir>/<image_id>.png".

    Source files bigger than `max_image_size` bytes are not opened at all; a black
    512x512 RGB image is written in their place.
    '''
    source_file = os.path.join(input_dir, f"{image_id}.tif")
    output_file = os.path.join(target_dir, f"{image_id}.png")

    fits_size_limit = not (os.path.getsize(source_file) > max_image_size)

    if fits_size_limit:
        result = image_load_scale_norm(source_file)
    else:
        # Placeholder for files that are too large to process.
        result = Image.fromarray(np.zeros((512,512,3), np.uint8))

    result.save(output_file)

## Dataset

In [4]:
class MayoClinicDataset(Dataset):
    """
    Image dataset driven by a table that has an 'image_id' column and, optionally, a 'label' column.

    Each item is the PNG "<root_dir>/<image_id>.png" (passed through `transform` when one
    is given) together with its class index. When the table has no 'label' column the
    dataset is considered unlabeled: every item gets the label -1 and the classes
    default to ['CE', 'LAA'].
    """

    def __init__(self, csv_file:Union[str, pd.DataFrame], root_dir:str, transform:transforms.Compose=None) -> None:
        super().__init__()

        # Accept either a path to a CSV file or an already loaded DataFrame.
        if isinstance(csv_file, str):
            self.tabular_data = pd.read_csv(csv_file)
        else:
            self.tabular_data = csv_file

        self.root_dir = root_dir
        self.transform = transform

        # Labels are only available when the table carries a 'label' column.
        self.train = 'label' in self.tabular_data.columns

        if self.train:
            self.classes, self.class_to_idx = self._find_classes(self.tabular_data)
        else:
            self.classes, self.class_to_idx = ['CE', 'LAA'], {'CE':0, 'LAA':1}

    def __len__(self) -> int:
        return len(self.tabular_data)

    def __getitem__(self, index:int) -> Tuple[torch.Tensor, int]:
        image_id = self.tabular_data['image_id'].iloc[index]
        path_without_ext = os.path.join(self.root_dir, image_id)
        image = Image.open(f'{path_without_ext}.png')

        if self.train:
            label = self.class_to_idx[self.tabular_data['label'].iloc[index]]
        else:
            label = -1

        if self.transform:
            image = self.transform(image)

        return (image, label)

    def _find_classes(self, tabular_data:pd.DataFrame) -> Tuple[List[str], Dict[str, int]]:
        """Returns the sorted distinct labels and a mapping from each label to its position."""
        classes = list(sorted(tabular_data['label'].unique()))
        class_to_idx = {name: position for position, name in enumerate(classes)}

        return classes, class_to_idx

## DataLoader

In [5]:
def create_TTA_dataloader(
    tabular_data:Union[str, pd.DataFrame],
    img_root_dir:str,
    base_transforms:transforms.Compose,
    aug_transforms:transforms.Compose,
    batch_size:int=1,
    num_workers:int=0
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, List[str]]:
    """
    Builds the pair of DataLoaders needed for Test Time Augmentation (TTA) predictions.

    Both loaders wrap a MayoClinicDataset over the same table and image folder and are
    never shuffled, so they yield the samples in the same order; they differ only in
    the transforms applied to the images.

    Args:
        tabular_data (Union[str, pd.DataFrame]): Path to the tabular data, or the DataFrame itself.
        img_root_dir (str): Folder that contains the images.
        base_transforms (transforms.Compose): Transforms without data augmentation
        (e.g. Resize, ToTensor, ...).
        aug_transforms (transforms.Compose): Transforms that include data augmentation steps.
        batch_size (int, optional): Samples per batch for both DataLoaders. Defaults to 1.
        num_workers (int, optional): Worker processes for both DataLoaders. Defaults to 0.

    Returns:
        Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, List[str]]: The base
        DataLoader, the augmented DataLoader and the class names of the base dataset.
    """
    base_dataset = MayoClinicDataset(tabular_data, img_root_dir, base_transforms)
    augment_dataset = MayoClinicDataset(tabular_data, img_root_dir, aug_transforms)

    # Identical loader settings for both datasets; shuffle stays off so that the two
    # loaders remain aligned sample by sample.
    loader_options = dict(batch_size=batch_size, shuffle=False, num_workers=num_workers)

    base_dataloader = DataLoader(dataset=base_dataset, **loader_options)
    augment_dataloader = DataLoader(dataset=augment_dataset, **loader_options)

    return base_dataloader, augment_dataloader, base_dataset.classes


## Models

In [6]:
def create_efficientnet_b0_model():
    """
    Builds an EfficientNet-B0 with the default pretrained weights for binary classification.

    The parameters of the feature extractor are frozen and the classifier is replaced
    by Dropout(0.5) followed by a Linear layer (1280 -> 1) that outputs a single logit.

    Returns:
        The modified torchvision model.
    """
    backbone = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)

    # Freeze the feature extractor so only the new head has trainable parameters.
    for frozen_param in backbone.features.parameters():
        frozen_param.requires_grad = False

    head = nn.Sequential(
        nn.Dropout(p=0.50),
        nn.Linear(in_features=1280, out_features=1, bias=True)
    )
    backbone.classifier = head

    return backbone


def create_efficientnet_b4_model():
    """
    Builds an EfficientNet-B4 with the default pretrained weights for binary classification.

    The parameters of the feature extractor are frozen and the classifier is replaced
    by Dropout(0.5) followed by a Linear layer (1792 -> 1) that outputs a single logit.

    Returns:
        The modified torchvision model.
    """
    backbone = models.efficientnet_b4(weights=models.EfficientNet_B4_Weights.DEFAULT)

    # Freeze the feature extractor so only the new head has trainable parameters.
    for frozen_param in backbone.features.parameters():
        frozen_param.requires_grad = False

    head = nn.Sequential(
        nn.Dropout(p=0.50),
        nn.Linear(in_features=1792, out_features=1, bias=True),
    )
    backbone.classifier = head

    return backbone


def create_efficientnet_v2_model(size):
    """
    Builds an EfficientNetV2 with the default pretrained weights for binary classification.

    The parameters of the feature extractor are frozen and the classifier is replaced
    by Dropout(0.5) followed by a Linear layer (1280 -> 1) that outputs a single logit.

    Args:
        size: Model variant: 's'/'small', 'm'/'medium' or 'l'/'large'.

    Returns:
        The modified torchvision model.

    Raises:
        ValueError: If `size` is not one of the supported variants.
    """
    if size in ('s', 'small'):
        builder, weights = models.efficientnet_v2_s, models.EfficientNet_V2_S_Weights.DEFAULT
    elif size in ('m', 'medium'):
        builder, weights = models.efficientnet_v2_m, models.EfficientNet_V2_M_Weights.DEFAULT
    elif size in ('l', 'large'):
        builder, weights = models.efficientnet_v2_l, models.EfficientNet_V2_L_Weights.DEFAULT
    else:
        # Fail early with a clear message instead of hitting an unbound `model` below.
        raise ValueError(
            f"Unsupported EfficientNetV2 size {size!r}; "
            "expected 's'/'small', 'm'/'medium' or 'l'/'large'."
        )

    model = builder(weights=weights)

    # Freeze the feature extractor so only the new head has trainable parameters.
    for param in model.features.parameters():
        param.requires_grad = False

    model.classifier = nn.Sequential(
        nn.Dropout(p=0.50),
        nn.Linear(in_features=1280, out_features=1, bias=True),
    )

    return model


def create_convnext_model(size):
    """
    Builds a ConvNeXt with the default pretrained weights for binary classification.

    The parameters of the feature extractor are frozen and the classifier is replaced
    by a channel-wise layer norm, a Flatten and a Linear layer (in_features -> 1) that
    outputs a single logit.

    Args:
        size: Model variant: 't'/'tiny', 's'/'small', 'b'/'base' or 'l'/'large'.

    Returns:
        The modified torchvision model.

    Raises:
        ValueError: If `size` is not one of the supported variants.
    """
    if size in ('t', 'tiny'):
        builder, weights, in_features = models.convnext_tiny, models.ConvNeXt_Tiny_Weights.DEFAULT, 768
    elif size in ('s', 'small'):
        builder, weights, in_features = models.convnext_small, models.ConvNeXt_Small_Weights.DEFAULT, 768
    elif size in ('b', 'base'):
        builder, weights, in_features = models.convnext_base, models.ConvNeXt_Base_Weights.DEFAULT, 1024
    elif size in ('l', 'large'):
        builder, weights, in_features = models.convnext_large, models.ConvNeXt_Large_Weights.DEFAULT, 1536
    else:
        # Fail early with a clear message instead of hitting an unbound `model` below.
        raise ValueError(
            f"Unsupported ConvNeXt size {size!r}; "
            "expected 't'/'tiny', 's'/'small', 'b'/'base' or 'l'/'large'."
        )

    model = builder(weights=weights)

    # Freeze the feature extractor so only the new head has trainable parameters.
    for param in model.features.parameters():
        param.requires_grad = False

    class LayerNorm2d(nn.LayerNorm):
        """Layer norm over dimension 1 of a 4-D input: moves it last, normalises, moves it back."""
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = x.permute(0, 2, 3, 1)
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
            x = x.permute(0, 3, 1, 2)
            return x

    model.classifier = nn.Sequential(
        LayerNorm2d((in_features,), eps=1e-06, elementwise_affine=True),
        nn.Flatten(start_dim=1, end_dim=-1),
        nn.Linear(in_features=in_features, out_features=1, bias=True)
    )

    return model

## Predict functions

In [7]:
def predict(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Runs the model over every batch of the DataLoader and collects the results.

    The model is moved to `device` and switched to eval mode. Its output for each
    batch is flattened, cast to float64 and passed through a sigmoid.

    Args:
        model (torch.nn.Module): Trained model producing one logit per sample.
        dataloader (torch.utils.data.DataLoader): Yields (inputs, targets) batches
        (data should not be shuffled).
        device (torch.device): Device on which inference runs.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Predicted probabilities and observed classes,
        each concatenated over all batches in iteration order.
    """
    model.to(device)
    model.eval()

    batch_probabilities, batch_targets = [], []

    with torch.inference_mode():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs).flatten().type(torch.float64)

            batch_targets.append(targets)
            batch_probabilities.append(torch.sigmoid(logits))

    all_probabilities = torch.cat(batch_probabilities)
    all_targets = torch.cat(batch_targets)

    return all_probabilities, all_targets


def predict_TTA(
    model:torch.nn.Module,
    base_dataloader:torch.utils.data.DataLoader,
    aug_dataloader:torch.utils.data.DataLoader,
    device:torch.device,
    n_aug_samples:int=4,
    beta:float=0.25,
    use_max:bool=False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Uses Test Time Augmentation (TTA) to make predictions.

    The model predicts once on the base DataLoader and `n_aug_samples` times on the
    augmented DataLoader. The augmented runs are reduced to one prediction per sample
    (mean, or max when `use_max` is set) and then combined with the base prediction.

    Args:
        model (torch.nn.Module): Trained model.
        base_dataloader (torch.utils.data.DataLoader): DataLoader with data to predict, without the
        data augmentation steps (data should not be shuffled).
        aug_dataloader (torch.utils.data.DataLoader): DataLoader with data to predict, with the
        data augmentation steps (data should not be shuffled).
        device (torch.device): Device.
        n_aug_samples (int, optional): Number of passes over the augmented DataLoader. With 0
        the base predictions are returned unchanged. Defaults to 4.
        beta (float, optional): Weight of the base (non-augmented) predictions in the blend:
        result = beta * base + (1 - beta) * mean(augmented). Ignored when `use_max` is True.
        Defaults to 0.25.
        use_max (bool, optional): When True, the result is the element-wise maximum over the
        base and all augmented predictions instead of the weighted blend. Defaults to False.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Predictions of the probabilities and observed classes,
        always in that order.
    """
    model.to(device)
    model.eval()

    aug_predictions = []

    with tqdm(total=n_aug_samples+1) as pbar:
        predictions, targets = predict(model, base_dataloader, device)
        pbar.update()

        for _ in range(n_aug_samples):
            aug_predictions.append(predict(model, aug_dataloader, device)[0])
            pbar.update()

    # Without augmented passes there is nothing to combine (torch.stack([]) would raise).
    if not aug_predictions:
        return predictions, targets

    aug_predictions = torch.stack(aug_predictions)

    if use_max:
        aug_predictions = aug_predictions.max(0)[0]
        predictions = torch.stack([predictions, aug_predictions], 0).max(0)[0]
    else:
        # lerp(start, end, w) = start + w * (end - start): beta weighs the base predictions.
        predictions = torch.lerp(aug_predictions.mean(0), predictions, beta)

    # Same (predictions, targets) order for both modes; the use_max path used to
    # return them swapped.
    return predictions, targets


## Score functions

In [8]:
def multi_class_logarithmic_loss(y_true:List[int], y_preds:Tuple[List])->float:
    """
    Calculates the weighted multi-class logarithmic loss.

    Every class observed in `y_true` gets the same weight (1/number of observed classes),
    and the contribution of each class is averaged over its own number of samples:

        loss = - sum_i( w_i * sum_j( ln p_ij ) / N_i ) / sum_i( w_i )

    where j runs over the samples whose true class is i. Each row of predictions is
    normalised to sum to 1 and clipped to [1e-15, 1 - 1e-15] before taking the logarithm.

    Args:
        y_true (List[int]): True class of every sample.
        y_preds (Tuple[List]): Predicted values, one row per sample and one column per class.

    Returns:
        float: Loss value.

    Raises:
        ValueError: If `y_true` is empty.
    """
    y_true = np.asarray(y_true)
    y_preds = np.asarray(y_preds)

    if y_true.size == 0:
        raise ValueError("y_true is empty; the loss is undefined without samples.")

    classes, counts = np.unique(y_true, return_counts=True)
    nr_classes = len(classes)

    # Gives same weight to every class 1/number of class
    w = np.full(nr_classes, 1/nr_classes)

    # Normalize predictions
    y_preds = y_preds/np.expand_dims(np.sum(y_preds, axis=1), axis=-1)

    # Clip predicted probabilities
    y_preds = np.clip(y_preds, 10**-15, 1-10**-15)

    # Pick the probability column of each class. Integer labels that are valid column
    # indices are used directly, so a class that is absent from y_true does not shift
    # the remaining classes onto the wrong column. Any other kind of label falls back
    # to its position among the sorted observed classes.
    labels_are_columns = (
        np.issubdtype(classes.dtype, np.integer)
        and classes.min() >= 0
        and classes.max() < y_preds.shape[1]
    )
    columns = classes if labels_are_columns else np.arange(nr_classes)

    total_weight = np.sum(w)
    loss = 0.0

    for weight, cls, column, count in zip(w, classes, columns, counts):
        # Only the samples of this class contribute; all other terms are zero.
        log_probs = np.log(y_preds[y_true == cls, column])
        loss += -((weight * np.sum(log_probs) / count) / total_weight)

    return float(loss)

## Main

In [9]:
# Optional step: create a preprocessed .png next to every source .tif listed in the table.
if PREPROCESS_IMAGES:
    print('[INFO] Preprocessing Images...')

    df = pd.read_csv(TABULAR_DATA_PATH)

    if N_JOBS != 1:
        # Original note: the parallel path "has problems with space"
        # (presumably memory or disk usage; not confirmed here).
        n_processes = os.cpu_count() if N_JOBS == -1 else N_JOBS
        jobs = [(IMAGE_PATH, IMAGE_PATH, name, MAX_IMAGE_SIZE) for name in df["image_id"]]

        with multiprocessing.Pool(processes=n_processes) as pool:
            pool.starmap(func=preprocess_image, iterable=jobs)

            pool.close()
    else:
        # Sequential path, with a progress bar over the image ids.
        for name in tqdm(df["image_id"]):
            preprocess_image(IMAGE_PATH, IMAGE_PATH, name, MAX_IMAGE_SIZE)

In [10]:
# Input size and per-channel normalisation statistics shared by both pipelines.
INPUT_SIZE = (512, 512)
NORM_MEAN = (0.9454, 0.8770, 0.8563)
NORM_STD = (0.1034, 0.2154, 0.2716)

# Deterministic pipeline used for the base prediction.
base_transform = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
])

# Same pipeline plus a random augmentation, used for the TTA passes.
aug_transform = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
])

base_test_dataloader, aug_test_dataloader, classes = create_TTA_dataloader(
    tabular_data=TABULAR_DATA_PATH,
    img_root_dir=IMAGE_PATH,
    base_transforms=base_transform,
    aug_transforms=aug_transform,
    batch_size=BATCH_SIZE
)

model = create_convnext_model('b')

# NOTE(review): depending on the torch version and its weights_only setting, torch.load
# may unpickle arbitrary objects; only load checkpoints from a trusted source.
state_dict = torch.load(
    f=MODEL_PATH,
    map_location=torch.device(DEVICE)  # load to CPU
)
model.load_state_dict(state_dict)

predictions, y = predict_TTA(
    model=model,
    base_dataloader=base_test_dataloader,
    aug_dataloader=aug_test_dataloader,
    device=DEVICE
)

'predictions, y = predict_TTA(\n    model=model,\n    base_dataloader=base_test_dataloader,\n    aug_dataloader=aug_test_dataloader,\n    device=DEVICE\n)'

In [None]:
# True class indices taken from the table, in the same row order as the unshuffled dataloaders.
y = torch.tensor([base_test_dataloader.dataset.class_to_idx[label] for label in test_df['label']])

# `prob` is not defined anywhere else in this notebook, so this cell raised a NameError
# on a fresh kernel. Build it from the TTA predictions instead.
# NOTE(review): assumes the model's single output is P(LAA), as the following cells do.
prob = pd.DataFrame({
    'CE': 1 - predictions.numpy(),
    'LAA': predictions.numpy()
})

auc = roc_auc_score(y.numpy(), prob['LAA'].to_numpy())
multi_class_log_loss = multi_class_logarithmic_loss(y.numpy(), prob.to_numpy())

#print(f'Accuracy = {torch.sum(y == torch.round(predictions).type(torch.int))/len(predictions)*100:.2f}')
print(f'AUC = {auc:.3f}')
print(f'Multi-class logarithmic loss = {multi_class_log_loss:.5f}')

In [None]:
# True class indices taken from the table, in the same row order as the unshuffled dataloaders.
class_to_idx = base_test_dataloader.dataset.class_to_idx
y = torch.tensor([class_to_idx[label] for label in test_df['label']])

# The model outputs a single probability per sample; it is used as the 'LAA' column
# and its complement as the 'CE' column.
laa_probability = predictions.numpy()

auc = roc_auc_score(y.numpy(), laa_probability)
predictions_df = pd.DataFrame({
    'CE':1-laa_probability,
    'LAA':laa_probability
})
multi_class_log_loss = multi_class_logarithmic_loss(y.numpy(), predictions_df.to_numpy())

# Share of samples whose rounded probability equals the true class index, in percent.
accuracy = torch.sum(y == torch.round(predictions).type(torch.int))/len(predictions)*100

print(f'Accuracy = {accuracy:.2f}')
print(f'AUC = {auc:.3f}')
print(f'Multi-class logarithmic loss = {multi_class_log_loss:.5f}')

In [None]:
# Per-row class probabilities alongside the patient id of the corresponding table row.
laa_probability = predictions.numpy()

predictions_df = pd.DataFrame({
    'patient_id':test_df['patient_id'],
    'CE':1-laa_probability,
    'LAA':laa_probability
})

predictions_df.head()