In [None]:
%load_ext nb_black
%load_ext autoreload
%autoreload 2

import os
import glob
import pandas as pd
import numpy as np
import random

from sklearn.metrics import roc_auc_score, roc_curve, auc
import seaborn as sns
from matplotlib import pyplot as plt

import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import torchvision

import pytorch_lightning as pl

import albumentations
from albumentations.pytorch import ToTensorV2

import timm

In [None]:
class Config:
    """Notebook-wide hyperparameters, file paths and label definitions.

    Used as a plain namespace (attributes are read as ``Config.xxx``); it is never instantiated.
    """

    use_amp = False  # NOTE(review): the Trainer cell hard-codes precision=16 independently of this flag
    debug = False  # also enables one sanity validation step in the Trainer cell
    train_df_fp = "data/train_folds.csv"  # training table with a "fold" column
    test_df_fp = "data/sample_submission.csv"  # submission template
    num_workers = 8
    model_name = "resnet200d_320"  # timm model identifier
    image_size = 512
    batch_size = 8
    seed = 1710
    # Multi-label targets: one sigmoid output per column.
    target_cols = [
        "ETT - Abnormal",
        "ETT - Borderline",
        "ETT - Normal",
        "NGT - Abnormal",
        "NGT - Borderline",
        "NGT - Incompletely Imaged",
        "NGT - Normal",
        "CVC - Abnormal",
        "CVC - Borderline",
        "CVC - Normal",
        "Swan Ganz Catheter Present",
    ]
    # Derived from target_cols so the model head size can never disagree with the label list.
    target_size = len(target_cols)
    fold = 0  # validation fold index
    output_dir = os.path.join("outputs", "checkpoints", model_name)
    submission_dir = os.path.join("outputs", "results", model_name)

In [None]:
# Select GPU index 1, numbered in PCI bus order (see issue #152).
# These variables only take effect if CUDA has not been initialised yet in this kernel.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "1"})

In [None]:
# Quick look at the files available under data/ before loading anything.
os.listdir("data")

In [None]:
# Load the fold-annotated training table and the submission template from the paths
# declared in Config (single source of truth for file locations).
train_df = pd.read_csv(Config.train_df_fp)
test_df = pd.read_csv(Config.test_df_fp)

In [None]:
# Preview the first rows of both tables (rendered one after the other).
display(train_df.head(), test_df.head())

In [None]:
# Image folders; the dataset reads <dir>/<StudyInstanceUID>.jpg from these.
train_img_dir, test_img_dir = "data/train", "data/test"

In [None]:
def seed_everything(seed=1710, deterministic=True):
    """Seed every RNG used by the notebook and configure cuDNN.

    Args:
        seed: value applied to Python's ``random``, NumPy, and torch (CPU and all
            CUDA devices); also exported as ``PYTHONHASHSEED``.
        deterministic: when True (default), force deterministic cuDNN kernels and
            disable the cuDNN autotuner. When False, enable the autotuner for speed
            at the cost of run-to-run reproducibility.
    """
    # Only affects Python interpreters started after this point; the current
    # process's hash seed was fixed at start-up.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current device.
    torch.cuda.manual_seed_all(seed)
    # The previous `cudnn.bench_mark` was a misspelling that silently set nothing.
    # The autotuner and deterministic mode are mutually exclusive, so derive both from one flag.
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = deterministic

In [None]:
# Seed all RNGs once, before any dataset or model is constructed below.
seed_everything(Config.seed)

In [None]:
class RANZCRDataset(Dataset):
    """Image dataset for the RANZCR multi-label classification task.

    Each row of ``df`` must provide a ``StudyInstanceUID``. The image is read from
    ``<img_dir>/<StudyInstanceUID>.jpg``, converted to RGB, transformed, and returned as
    a float32 tensor of shape (3, H, W).

    Modes:
        * ``"train"``: heavy augmentation; ``__getitem__`` returns ``(image, label)``.
        * ``"val"``: resize + normalize; returns ``(image, label)``.
        * ``"test"``: resize + normalize; returns ``image`` only.

    Labels are float32 tensors with one entry per column in ``target_cols``.

    Raises:
        ValueError: if ``mode`` is not one of the three above.
    """

    def __init__(
        self,
        df,
        img_dir,
        mode,
        image_size=Config.image_size,
        target_cols=Config.target_cols,
    ):
        # Fail fast on a bad mode before doing any dataframe work.
        if mode not in ("train", "val", "test"):
            raise ValueError(f"Invalid mode {mode}")

        self.df = self._get_df(df.copy(), img_dir)
        # Labels are never returned in test mode, so a test frame without label
        # columns is allowed; train/val still raise KeyError if they are missing.
        if mode == "test" and not set(target_cols).issubset(self.df.columns):
            self.labels = None
        else:
            self.labels = self.df[target_cols].values
        self.mode = mode
        self.image_size = image_size

        self._setup_transform()
        self.transform = {
            "train": self.transform_train,
            "val": self.transform_val,
            "test": self.transform_test,
        }[mode]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_path = self.df.iloc[idx]["file_path"]
        img = cv2.imread(file_path)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None, not raising.
            raise FileNotFoundError(f"Could not read image: {file_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # All transforms end in Normalize, so the result is a float HWC numpy array.
        img = self.transform(image=img)["image"]
        # HWC -> CHW; make it contiguous so torch can wrap it without surprises.
        img = np.ascontiguousarray(img.astype(np.float32).transpose(2, 0, 1))
        image = torch.from_numpy(img)
        if self.mode == "test":
            return image
        label = torch.tensor(self.labels[idx]).float()
        return image, label

    def _setup_transform(self):
        """Build the train/val/test albumentations pipelines for ``self.image_size``."""
        # NOTE(review): IAAPiecewiseAffine/IAASharpen need imgaug and were deprecated or
        # relocated in newer albumentations releases; JpegCompression/Cutout are deprecated
        # aliases there too. Pin the albumentations version or migrate these transforms.
        self.transform_train = albumentations.Compose(
            [
                albumentations.RandomResizedCrop(
                    self.image_size, self.image_size, scale=(0.9, 1), p=1
                ),
                albumentations.HorizontalFlip(p=0.5),
                albumentations.ShiftScaleRotate(p=0.5),
                albumentations.HueSaturationValue(
                    hue_shift_limit=10, sat_shift_limit=10, val_shift_limit=10, p=0.7
                ),
                albumentations.RandomBrightnessContrast(
                    brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=0.7
                ),
                albumentations.CLAHE(clip_limit=(1, 4), p=0.5),
                albumentations.OneOf(
                    [
                        albumentations.OpticalDistortion(distort_limit=1.0),
                        albumentations.GridDistortion(num_steps=5, distort_limit=1.0),
                        albumentations.ElasticTransform(alpha=3),
                    ],
                    p=0.2,
                ),
                albumentations.OneOf(
                    [
                        albumentations.GaussNoise(var_limit=[10, 50]),
                        albumentations.GaussianBlur(),
                        albumentations.MotionBlur(),
                        albumentations.MedianBlur(),
                    ],
                    p=0.2,
                ),
                albumentations.Resize(self.image_size, self.image_size),
                albumentations.OneOf(
                    [
                        albumentations.JpegCompression(),
                        albumentations.Downscale(scale_min=0.1, scale_max=0.15),
                    ],
                    p=0.2,
                ),
                albumentations.IAAPiecewiseAffine(p=0.2),
                albumentations.IAASharpen(p=0.2),
                albumentations.Cutout(
                    max_h_size=int(self.image_size * 0.1),
                    max_w_size=int(self.image_size * 0.1),
                    num_holes=5,
                    p=0.5,
                ),
                albumentations.Normalize(),
            ]
        )

        self.transform_val = albumentations.Compose(
            [
                albumentations.Resize(self.image_size, self.image_size),
                albumentations.Normalize(),
            ]
        )
        # Same preprocessing as validation. The previous version ended with ToTensorV2,
        # which returned a torch tensor and broke the numpy `.astype` in __getitem__.
        # Normalize() defaults to the ImageNet mean/std that were spelled out before.
        self.transform_test = albumentations.Compose(
            [
                albumentations.Resize(self.image_size, self.image_size),
                albumentations.Normalize(),
            ]
        )

    def _get_df(self, df, img_dir):
        """Add a ``file_path`` column (``<img_dir>/<StudyInstanceUID>.jpg``) and reset the index."""
        df["file_path"] = [
            os.path.join(img_dir, uid + ".jpg") for uid in df["StudyInstanceUID"]
        ]
        return df.reset_index(drop=True)

In [None]:
class RANZCRDataModule(pl.LightningDataModule):
    """Lightning data module that splits ``train_df`` by fold and serves train/val/test loaders.

    Rows whose ``fold`` equals ``fold`` become the validation set; all others are used
    for training. When ``debug`` is True, every split is truncated to its first
    ``debug_rows`` rows for quick smoke runs.
    """

    def __init__(
        self,
        train_df,
        test_df,
        train_img_dir,
        test_img_dir,
        fold,
        batch_size=Config.batch_size,
        image_size=Config.image_size,
        debug=Config.debug,
        num_workers=Config.num_workers,
        debug_rows=100,
    ):
        super().__init__()
        # Split train/val by fold; copies keep the caller's frame untouched.
        in_val_fold = train_df["fold"] == fold
        self.val_df = train_df[in_val_fold].copy()
        self.train_df = train_df[~in_val_fold].copy()
        self.test_df = test_df.copy()

        self.train_img_dir = train_img_dir
        self.test_img_dir = test_img_dir
        self.batch_size = batch_size
        self.image_size = image_size
        self.num_workers = num_workers

        # Previously this truncation ran unconditionally, so full training silently
        # used only 100 images per split. It is now gated on the debug flag.
        if debug:
            self.train_df = self.train_df.iloc[:debug_rows]
            self.val_df = self.val_df.iloc[:debug_rows]
            self.test_df = self.test_df.iloc[:debug_rows]

    def setup(self, stage=None):
        """Build the three datasets (all stages share the same construction)."""
        self.RANZCR_train = RANZCRDataset(
            df=self.train_df,
            img_dir=self.train_img_dir,
            mode="train",
            image_size=self.image_size,
        )
        self.RANZCR_val = RANZCRDataset(
            df=self.val_df,
            img_dir=self.train_img_dir,
            mode="val",
            image_size=self.image_size,
        )
        self.RANZCR_test = RANZCRDataset(
            df=self.test_df,
            img_dir=self.test_img_dir,
            mode="test",
            image_size=self.image_size,
        )

    def _make_loader(self, dataset, shuffle, drop_last, pin_memory):
        """Create a DataLoader with the module's batch size and worker count."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            drop_last=drop_last,
            shuffle=shuffle,
            pin_memory=pin_memory,
        )

    def train_dataloader(self) -> DataLoader:
        # Shuffled; drop the last incomplete batch so every step has batch_size samples.
        return self._make_loader(
            self.RANZCR_train, shuffle=True, drop_last=True, pin_memory=True
        )

    def val_dataloader(self) -> DataLoader:
        return self._make_loader(
            self.RANZCR_val, shuffle=False, drop_last=False, pin_memory=True
        )

    def test_dataloader(self) -> DataLoader:
        # Order must be preserved so predictions line up with test_df rows.
        return self._make_loader(
            self.RANZCR_test, shuffle=False, drop_last=False, pin_memory=False
        )

In [None]:
# x = RANZCRDataModule(
#     train_df=train_df,
#     test_df=test_df,
#     train_img_dir=train_img_dir,
#     test_img_dir=test_img_dir,
#     fold=0,
# )
# x.setup()
# x.train_dataloader()

In [None]:
# F.binary_cross_en

In [None]:
class RANZCRModel(pl.LightningModule):
    """timm backbone + linear head for multi-label RANZCR classification.

    The backbone's ``fc`` and ``global_pool`` are replaced by ``nn.Identity`` so it emits
    a spatial feature map. That map is average-pooled to one vector per image and mapped
    to ``output_dim`` logits by a single ``nn.Linear``.

    Args:
        model_name: timm model identifier.
        output_dim: number of outputs (one logit per target column).
        pretrained: load pretrained backbone weights.
        test_df: optional submission template. When given, its ``StudyInstanceUID``
            column is copied into the CSV written by ``test_epoch_end``.
    """

    def __init__(
        self,
        model_name=Config.model_name,
        output_dim=Config.target_size,
        pretrained=True,
        test_df=None,
    ):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained)

        # Assumes a ResNet-style timm model exposing `.fc` and `.global_pool`.
        n_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.backbone.global_pool = nn.Identity()
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(n_features, output_dim)
        self.test_df = test_df

    def binary_loss(self, logits, labels):
        """Mean binary cross-entropy over all samples and targets (expects raw logits)."""
        return F.binary_cross_entropy_with_logits(logits, labels)

    def macro_auc(self, labels, pred, plot=True):
        """Return the mean per-column ROC AUC over ``Config.target_cols``.

        Args:
            labels: (n_samples, n_targets) array of 0/1 ground truth.
            pred: array of the same shape with scores (probabilities or logits).
            plot: when True, draw each column's ROC curve plus a summary legend.

        Columns whose labels contain a single class have an undefined AUC. They are
        skipped; the result is NaN only if every column is skipped.
        """
        labels = np.asarray(labels)
        pred = np.asarray(pred, dtype=np.float64)  # fp16 outputs -> float64 for sklearn
        if plot:
            fig, ax = plt.subplots(figsize=(8, 5))

        aucs = []
        for i, col in enumerate(Config.target_cols):
            if np.unique(labels[:, i]).size < 2:
                continue
            fpr, tpr, _ = roc_curve(labels[:, i], pred[:, i])
            roc_auc = auc(fpr, tpr)
            aucs.append(roc_auc)
            if plot:
                ax.plot(fpr, tpr, label=f"Field {col} (AUC = {roc_auc:.4f})")

        mean_auc = float(np.mean(aucs)) if aucs else float("nan")
        std_auc = float(np.std(aucs)) if aucs else float("nan")

        if plot:
            ax.plot([0, 1], [0, 1], label="Luck", linestyle="--", color="r")
            # Invisible, legend-only entry carrying the summary. The raw string keeps
            # the LaTeX backslash from being read as an escape sequence.
            ax.plot(
                [],
                [],
                " ",
                label=rf"Average AUC score: {mean_auc:.4f} $\pm$ {std_auc:.4f}",
            )
            ax.legend(loc="lower right")
            ax.set(
                xlim=[-0.1, 1.1],
                ylim=[-0.1, 1.1],
                title=f"Average AUC over {len(Config.target_cols)} fields",
            )
            plt.show()
            plt.close(fig)  # avoid accumulating one open figure per validation epoch
        return mean_auc

    def forward(self, x):
        """Map a batch of images (N, C, H, W) to logits of shape (N, output_dim)."""
        bs = x.size(0)
        features = self.backbone(x)
        pooled_features = self.pooling(features).view(bs, -1)
        return self.fc(pooled_features)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        loss = self.binary_loss(self(x), y)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self(x)
        loss = self.binary_loss(logits, y)
        # AUC is undefined for small batches where a column has one class, so only
        # collect predictions here and score the full validation set at epoch end.
        return {
            "val_loss": loss.detach(),
            "probs": torch.sigmoid(logits).detach().float().cpu(),
            "labels": y.detach().cpu(),
        }

    def validation_epoch_end(self, outputs):
        if not outputs:
            return
        avg_loss = torch.stack([out["val_loss"].float() for out in outputs]).mean()
        probs = torch.cat([out["probs"] for out in outputs]).numpy()
        labels = torch.cat([out["labels"] for out in outputs]).numpy()
        # Argument order is (labels, pred); this was previously swapped.
        avg_auc = self.macro_auc(labels, probs)
        print(f"EPOCH: {self.current_epoch} AUC:{avg_auc:.4f}")
        # Logged so ModelCheckpoint(monitor="val_auc") can see the metric.
        self.log("val_loss", avg_loss)
        self.log("val_auc", avg_auc)

    def test_step(self, batch, batch_idx):
        # The test dataset yields bare image tensors; also accept (images, labels).
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        probs = torch.sigmoid(self(x))
        return {"probs": probs.detach().float().cpu()}

    def test_epoch_end(self, outputs):
        """Write test probabilities (one column per target) to a numbered CSV.

        Returns:
            dict with ``"tta"``: the index N used in ``submission{N}.csv``.

        Raises:
            ValueError: if ``test_df`` was given but its length differs from the predictions.
        """
        probs = torch.cat([out["probs"] for out in outputs]).numpy()
        submission = pd.DataFrame(probs, columns=Config.target_cols)
        if self.test_df is not None:
            if len(self.test_df) != len(submission):
                raise ValueError(
                    f"test_df has {len(self.test_df)} rows but got {len(submission)} predictions"
                )
            submission.insert(
                0, "StudyInstanceUID", self.test_df["StudyInstanceUID"].to_numpy()
            )

        os.makedirs(Config.submission_dir, exist_ok=True)
        # Number submissions by how many already exist in the submission directory.
        n = len(os.listdir(Config.submission_dir))
        submission.to_csv(
            os.path.join(Config.submission_dir, f"submission{n}.csv"), index=False
        )
        return {"tta": n}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

In [None]:
# x = RANZCRModel()

In [None]:
# Keep the checkpoint with the highest "val_auc" in Config.output_dir.
# The template was previously passed positionally, which Lightning >=1.0 treats as
# `dirpath`, creating a literal "{epoch:02d}_{val_auc:.4f}" directory.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=Config.output_dir,
    filename="{epoch:02d}_{val_auc:.4f}",
    monitor="val_auc",
    mode="max",
)
trainer = pl.Trainer(
    gpus=1,
    precision=16,  # NOTE(review): mixed precision is on although Config.use_amp is False — reconcile
    max_epochs=30,
    num_sanity_val_steps=1 if Config.debug else 0,
    # Callbacks go through `callbacks`; passing one via `checkpoint_callback=` is deprecated.
    callbacks=[checkpoint_callback],
)

In [None]:
# Train on every fold except Config.fold, validating on Config.fold.
model = RANZCRModel()
data_module = RANZCRDataModule(
    train_df=train_df,
    test_df=test_df,
    train_img_dir=train_img_dir,
    test_img_dir=test_img_dir,
    fold=Config.fold,
)
trainer.fit(model, data_module)