In [None]:
!pip install python-box timm pytorch-lightning==1.4.0 grad-cam ttach torchmetrics

In [None]:
import os
import warnings
from pprint import pprint
from glob import glob
from tqdm import tqdm
import torchmetrics

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torchvision.transforms as T
from box import Box
from timm import create_model
from sklearn.model_selection import StratifiedKFold
from torchvision.io import read_image
from torch.utils.data import DataLoader, Dataset
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image


import pytorch_lightning as pl
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning import callbacks
from pytorch_lightning.callbacks.progress import ProgressBarBase
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import LightningDataModule, LightningModule


warnings.filterwarnings("ignore")

In [None]:
# Record which torch build this run used (helps reproduce results later).
print(torch.__version__)

In [None]:
import timm
# List timm architectures that ship pretrained weights; config.model.name is
# created with pretrained=True below, so it should be one of these names.
timm.list_models(pretrained=True)

## config

In [None]:
# !rm ./convit_base -r

In [None]:
# Run configuration; wrapped in a Box so nested keys are attribute-accessible
# (config.trainer.gpus, config.model.name, ...).
config = {'seed': 42,
          'root': '/kaggle/input/emsig-spectrum/',
          'n_splits': 5,
          'epoch': 20,
          'trainer': {
              'gpus': 1,
              'accumulate_grad_batches': 4,
              'progress_bar_refresh_rate': 1,
              'fast_dev_run': False,
              'num_sanity_val_steps': 0,
              'resume_from_checkpoint': None,
          },
          'transform': {
              'name': 'get_default_transforms',
              'image_size': 384
          },
          'train_loader': {
              'batch_size': 4,
              'shuffle': True,
              'num_workers': 4,
              'pin_memory': False,
              'drop_last': True,
          },
          'val_loader': {
              'batch_size': 8,
              'shuffle': False,
              'num_workers': 4,
              'pin_memory': False,
              'drop_last': False
          },
          'model': {
              'name': 'vit_large_patch16_384',
              'output_dim': 6
          },
          # 'name' strings below are eval-ed inside the LightningModule.
          'optimizer': {
              'name': 'optim.AdamW',
              'params': {
                  'lr': 1e-5
              },
          },
          'scheduler': {
              'name': 'optim.lr_scheduler.CosineAnnealingWarmRestarts',
              'params': {
                  'T_0': 20,
                  # eta_min is the floor the cosine schedule anneals *down* to, so it
                  # must be below optimizer.params.lr; a larger value makes the LR rise
                  # over each cycle. 1e-7 is a small floor - tune as needed.
                  'eta_min': 1e-7,
              }
          },
          'loss': 'nn.CrossEntropyLoss',
}

config = Box(config)

In [None]:
# Show the resolved configuration used for this run.
pprint(config)

## dataset

In [None]:
class PetfinderDataset(Dataset):
    """Map-style dataset over the image paths stored in ``df["id"]``.

    Each item is the image loaded with ``read_image`` and resized to
    ``image_size`` x ``image_size``. If ``df`` has a ``"target"`` column, items
    are ``(image, label)`` pairs; otherwise only the image is returned.
    """

    def __init__(self, df, image_size=384):
        self._X = df["id"].values
        self._y = df["target"].values if "target" in df.keys() else None
        self._transform = T.Resize([image_size, image_size])

    def __len__(self):
        return len(self._X)

    def __getitem__(self, idx):
        resized = self._transform(read_image(self._X[idx]))
        if self._y is None:
            # Unlabelled data: image only.
            return resized
        return resized, self._y[idx]

class PetfinderDataModule(LightningDataModule):
    """Lightning data module over a train and a validation DataFrame.

    Builds a ``PetfinderDataset`` for each split (images resized to
    ``cfg.transform.image_size``) and wraps it in a ``DataLoader`` whose keyword
    arguments come from ``cfg.train_loader`` / ``cfg.val_loader``.
    """

    def __init__(self, train_df, val_df, cfg):
        super().__init__()
        self._train_df = train_df
        self._val_df = val_df
        self._cfg = cfg

    def __create_dataset(self, train=True):
        split_df = self._train_df if train else self._val_df
        return PetfinderDataset(split_df, self._cfg.transform.image_size)

    def train_dataloader(self):
        return DataLoader(self.__create_dataset(True), **self._cfg.train_loader)

    def val_dataloader(self):
        return DataLoader(self.__create_dataset(False), **self._cfg.val_loader)

## visualize data

In [None]:
# Anomaly detection adds per-operation checks to every backward pass; PyTorch
# documents it as a debugging aid that slows execution. Flip to True only when
# hunting NaN/inf gradients. Setting it explicitly also resets it on re-run.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)
seed_everything(config.seed)

# Label CSV: one row per image file name (column "id") with its class ("target").
df = pd.read_csv('../input/emsig-train-label/train_labels.csv')
df["id"] = df["id"].apply(lambda x: os.path.join(config.root, "crop_spectrum_train_pic", x))

In [None]:
# The CSV labels start at 1, while CrossEntropyLoss expects 0..output_dim-1.
# Keep the original column so re-running this cell cannot shift the labels twice.
if "target_raw" not in df.columns:
    df["target_raw"] = df["target"]
df["target"] = df["target_raw"] - 1
assert df["target"].between(0, config.model.output_dim - 1).all(), "labels outside [0, output_dim)"

In [None]:
# Plot one (un-normalized, resized) validation batch with its 0-based labels.
sample_dataloader = PetfinderDataModule(df, df, config).val_dataloader()
# Use the built-in next(): DataLoader iterators no longer expose a `.next()`
# method in recent PyTorch releases.
images, labels = next(iter(sample_dataloader))

plt.figure(figsize=(12, 12))
for it, (image, label) in enumerate(zip(images[:16], labels[:16])):
    plt.subplot(4, 4, it + 1)
    plt.imshow(image.permute(1, 2, 0))  # CHW -> HWC for matplotlib
    plt.axis('off')
    plt.title(f'target: {int(label)}')

## augmentation

In [None]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # RGB
IMAGENET_STD = [0.229, 0.224, 0.225]  # RGB


def get_default_transforms(mean=None, std=None):
    """Build the per-split tensor transforms applied inside the LightningModule.

    The images come from ``read_image`` as uint8 tensors. Both pipelines convert
    them to float (``ConvertImageDtype`` scales uint8 to [0, 1]) and normalize
    per channel.

    Args:
        mean, std: per-channel normalization statistics. They default to the
            ImageNet values above. NOTE(review): some timm ViT checkpoints
            (e.g. the original ``vit_*_patch16_384`` weights) may expect
            mean=std=0.5; check ``model.default_cfg`` for the backbone in use.

    Returns:
        dict with ``"train"`` and ``"val"`` ``T.Compose`` pipelines.
    """
    mean = IMAGENET_MEAN if mean is None else mean
    std = IMAGENET_STD if std is None else std
    transform = {
        "train": T.Compose(
            [
                # Augmentations currently disabled:
                # T.RandomHorizontalFlip(),
                # T.RandomVerticalFlip(),
                # T.RandomAffine(15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
                # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
                T.ConvertImageDtype(torch.float),
                T.Normalize(mean=mean, std=std),
            ]
        ),
        "val": T.Compose(
            [
                T.ConvertImageDtype(torch.float),
                T.Normalize(mean=mean, std=std),
            ]
        ),
    }
    return transform


## model

In [None]:
def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
    """Mixup a batch with a shuffled copy of itself.

    A single coefficient ``lam ~ Beta(alpha, alpha)`` is drawn with NumPy's
    global RNG, and a permutation of the batch is drawn with torch's RNG.

    Returns:
        ``(mixed_x, target_a, target_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``target_a`` is ``y`` and
        ``target_b`` is ``y[perm]``. The losses on the two targets are
        conventionally combined with weights ``lam`` and ``1 - lam``.
    """
    assert alpha > 0, "alpha should be larger than 0"
    assert x.size(0) > 1, "Mixup cannot be applied to a single instance."

    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    perm = torch.randperm(batch_size)
    shuffled = x[perm, :]
    return lam * x + (1 - lam) * shuffled, y, y[perm], lam

class Model(pl.LightningModule):
    """Image classifier: a timm backbone followed by a dropout + linear head.

    Reads from a Box config:
      * ``cfg.model.name`` / ``cfg.model.output_dim``: backbone and number of classes.
      * ``cfg.loss``, ``cfg.optimizer.name``, ``cfg.scheduler.name``: dotted names
        (e.g. "nn.CrossEntropyLoss") that are ``eval``-ed against this module's
        imports. They come from the notebook config, not from external input.

    Epoch-level metrics are logged as ``{mode}_loss`` and ``{mode}_acc`` with
    mode in {"train", "val"}.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.__build_model()
        self._criterion = eval(self.cfg.loss)()
        self.transform = get_default_transforms()
        self.save_hyperparameters(cfg)

    def __build_model(self):
        """Create the pretrained backbone (without classifier) and the linear head."""
        # num_classes=0 asks timm for pooled features instead of logits.
        self.backbone = create_model(
            self.cfg.model.name, pretrained=True, num_classes=0, in_chans=3
        )
        self.fc = nn.Sequential(
            nn.Dropout(0.5), nn.Linear(self.backbone.num_features, self.cfg.model.output_dim)
        )

    def forward(self, x):
        """Return raw class logits for a batch of already-normalized images."""
        features = self.backbone(x)
        return self.fc(features)

    def training_step(self, batch, batch_idx):
        loss, pred, labels = self.__share_step(batch, 'train')
        return {'loss': loss, 'pred': pred, 'labels': labels}

    def validation_step(self, batch, batch_idx):
        _, pred, labels = self.__share_step(batch, 'val')
        return {'pred': pred, 'labels': labels}

    def __share_step(self, batch, mode):
        """Normalize, forward and compute the loss.

        Returns ``(loss, logits, labels)``. ``loss`` stays on the graph for
        backprop; the logits and labels are detached CPU copies kept for the
        epoch-level metrics.
        """
        images, labels = batch
        images = self.transform[mode](images)

        pred = self.forward(images)
        loss = self._criterion(pred, labels)

        return loss, pred.detach().cpu(), labels.detach().cpu()

    def training_epoch_end(self, outputs):
        print('-----train-----')
        self.__share_epoch_end(outputs, 'train')

    def validation_epoch_end(self, outputs):
        print('-----valid-----')
        self.__share_epoch_end(outputs, 'val')

    def __share_epoch_end(self, outputs, mode):
        """Concatenate the epoch's step outputs, then print and log loss/accuracy."""
        preds = torch.cat([out['pred'] for out in outputs])
        labels = torch.cat([out['labels'] for out in outputs])
        loss = self._criterion(preds, labels)
        accuracy = torchmetrics.functional.accuracy(preds, labels)
        preds_prob = preds.softmax(dim=1)
        cm = torchmetrics.functional.confusion_matrix(
            preds_prob, labels, num_classes=self.cfg.model.output_dim,
            normalize=None, threshold=0.5, multilabel=False)
        print(cm)
        print('Accuracy： ', accuracy)
        print('epoch loss: ', loss)
        # The training cell's EarlyStopping/ModelCheckpoint monitor "val_loss" with
        # mode="min". This key must therefore carry the loss. Logging accuracy
        # under it would make both callbacks select the least accurate epoch.
        self.log(f'{mode}_loss', loss)
        self.log(f'{mode}_acc', accuracy)

    def check_gradcam(self, dataloader, target_layer, target_category=None, reshape_transform=None):
        """Run Grad-CAM++ on the first batch of ``dataloader``.

        Returns:
            ``(org_images, grayscale_cam, pred, labels)``:
              * ``org_images``: the un-normalized batch as NHWC floats (uint8 / 255).
              * ``grayscale_cam``: the CAM for each image.
              * ``pred``: softmax class probabilities in percent, shape (batch, output_dim).
              * ``labels``: the ground-truth class indices.

        ``target_category`` is currently unused; no explicit targets are passed to the CAM.
        """
        cam = GradCAMPlusPlus(
            model=self,
            target_layers=[target_layer],
            use_cuda=self.cfg.trainer.gpus,
            reshape_transform=reshape_transform)

        org_images, labels = next(iter(dataloader))
        cam.batch_size = len(org_images)
        images = self.transform['val'](org_images).to(self.device)
        # Predictions are for display only; the CAM runs its own forward/backward.
        with torch.no_grad():
            logits = self.forward(images)
        # Multi-class head trained with CrossEntropyLoss -> softmax, not sigmoid.
        pred = logits.softmax(dim=-1).cpu().numpy() * 100
        labels = labels.cpu().numpy()

        grayscale_cam = cam(input_tensor=images, eigen_smooth=True)
        org_images = org_images.detach().cpu().numpy().transpose(0, 2, 3, 1) / 255.
        return org_images, grayscale_cam, pred, labels

    def configure_optimizers(self):
        optimizer = eval(self.cfg.optimizer.name)(
            self.parameters(), **self.cfg.optimizer.params
        )
        scheduler = eval(self.cfg.scheduler.name)(
            optimizer,
            **self.cfg.scheduler.params
        )
        return [optimizer], [scheduler]

## train

In [None]:
# Stratified K-fold: one fresh model/trainer per fold. The final weights of each
# fold are written to ./<fold>.
skf = StratifiedKFold(
    n_splits=config.n_splits, shuffle=True, random_state=config.seed
)

for fold, (train_idx, val_idx) in enumerate(skf.split(df["id"], df["target"])):
    print('FOLD: ', fold)
    train_df = df.loc[train_idx].reset_index(drop=True)
    val_df = df.loc[val_idx].reset_index(drop=True)
    datamodule = PetfinderDataModule(train_df, val_df, config)
    model = Model(config)
    early_stopping = EarlyStopping(monitor="val_loss", patience=4)
    lr_monitor = callbacks.LearningRateMonitor()
    loss_checkpoint = callbacks.ModelCheckpoint(
        filename="best_loss",
        monitor="val_loss",
        save_top_k=1,
        mode="min",
        save_last=False,
    )
    # With no explicit version, TensorBoardLogger picks the next free
    # version_<n> directory under ./<model name>/default for each fold.
    logger = TensorBoardLogger(config.model.name)

    trainer = pl.Trainer(
        logger=logger,
        max_epochs=config.epoch,
        callbacks=[lr_monitor, loss_checkpoint, early_stopping],
        **config.trainer,
    )
    trainer.fit(model, datamodule=datamodule)
    # These are the end-of-training weights; the best epoch is in the
    # "best_loss" checkpoint written by loss_checkpoint.
    torch.save(model.state_dict(), os.path.join('.', str(fold)))

## Class activation map (Grad-CAM++)

In [None]:
# gradcam reshape_transform for vit
def reshape_transform(tensor, height=None, width=None):
    """Turn transformer token activations into a CNN-style feature map for Grad-CAM.

    Args:
        tensor: activations shaped (batch, tokens, channels).
        height, width: size of the patch grid.
            * If both are omitted, the grid is assumed square and inferred from
              the token count (49 tokens -> 7x7; 577 tokens -> 24x24 plus one
              extra token).
            * If only one is given, the other is ``tokens // given``.

    Returns:
        Tensor shaped (batch, channels, height, width).

    If there is exactly one token more than ``height * width``, the first token
    is dropped (the class token in ViT models). Swin-style outputs without such
    a token are reshaped as-is.
    """
    num_tokens = tensor.size(1)
    if height is None and width is None:
        height = width = round(num_tokens ** 0.5)
    elif height is None:
        height = num_tokens // width
    elif width is None:
        width = num_tokens // height

    if num_tokens == height * width + 1:
        tensor = tensor[:, 1:, :]

    result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
    # Channels first, like in CNNs.
    return result.permute(0, 3, 1, 2)

In [None]:
# `model` is the one built in the last iteration of the cross-validation loop.
print(model)

In [None]:
# # import torch, gc
# # gc.collect()
# # torch.cuda.empty_cache()

# model = Model(config) 
# model.load_state_dict(torch.load(f'{config.model.name}/default/version_0/checkpoints/best_loss.ckpt')['state_dict'])
# model = model.cuda().eval()
# config.val_loader.batch_size = 16
# datamodule = PetfinderDataModule(train_df, val_df, config)
# target_category=None
# images, grayscale_cams, preds, labels = model.check_gradcam(
#                                             datamodule.val_dataloader(), 
#                                             target_layer=model.backbone.layers[-1].blocks[-1].norm1,
# #                                             target_category=target_category,
#                                             reshape_transform=reshape_transform)

In [None]:
# plt.figure(figsize=(12, 12))
# for it, (image, grayscale_cam, pred, label) in enumerate(zip(images, grayscale_cams, preds, labels)):
#     plt.subplot(4, 4, it + 1)
#     visualization = show_cam_on_image(image, grayscale_cam)
#     plt.imshow(visualization)
#     plt.title(f'pred: {pred} label: {label}')
#     plt.axis('off')

## Visualize results

In [None]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# On a fresh log directory, version_0 is the first fold's run.
LOG_VERSION = 0
log_dir = f'./{config.model.name}/default/version_{LOG_VERSION}'
event_files = sorted(glob(f'{log_dir}/events*'))
if not event_files:
    raise FileNotFoundError(f'no TensorBoard event file found in {log_dir}/')
path = event_files[0]
# size_guidance 0 = keep every scalar event instead of a sampled subset.
event_acc = EventAccumulator(path, size_guidance={'scalars': 0})
event_acc.Reload()

scalars = {}
for tag in event_acc.Tags()['scalars']:
    events = event_acc.Scalars(tag)
    scalars[tag] = [event.value for event in events]