#### PyTorch Lightning: для воспроизведения экспериментов
#### TensorBoard + Optuna: для логирования экспериментов и подбора гиперпараметров 
#### Albumentations: для аугментаций
#### timm: для использования предобученных моделей для задачи классификации 
#### captum: для анализа того, какие области изображения сильнее всего влияют на предсказание модели (метод Integrated Gradients: градиенты, проинтегрированные вдоль пути от базового изображения к входному)
#### ONNX: для инференса 

In [None]:
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
import optuna
import pandas as pd
import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn
from albumentations import (
    Blur,
    CoarseDropout,
    Compose,
    GaussNoise,
    HorizontalFlip,
    HueSaturationValue,
    MedianBlur,
    MotionBlur,
    Normalize,
    RandomResizedCrop,
    Resize,
    ShiftScaleRotate,
)
from albumentations.pytorch import ToTensorV2
from captum.attr import IntegratedGradients
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.metrics import f1_score
from tqdm import tqdm

In [None]:
# Root directory of the dataset; it must contain the "train" and "test1" folders.
# Set the DOGS_VS_CATS_DATA environment variable to run the notebook on another
# machine; when it is unset the default location below is used.
PATH_DATA = os.environ.get(
    "DOGS_VS_CATS_DATA",
    "/home/a.makarchuk@rit.va/Desktop/kaggle-CV-best11/dogs_vs_cats/data/",
)
# Paths are built elsewhere as f"{PATH_DATA}train", so a trailing separator is
# required; os.path.join(path, "") appends one only when it is missing.
PATH_DATA = os.path.join(PATH_DATA, "")

IMAGE_SIZE = (256, 256)  # size passed to Resize in the training pipeline
IMAGE_SIZE_TEST = (320, 320)  # size passed to Resize for validation / test images
BATCH_SIZE = 10
EPOCHS = 1
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
TRAIN_RATIO = 0.9  # fraction of the training files kept for training; the rest is validation
SEED = 42
PERCENT = 0.05  # fraction of training images used to estimate channel mean / std

# Relative output paths (model.onnx, advanced_solution.csv) resolve against PATH_DATA.
os.chdir(PATH_DATA)

list_train_imgs = os.listdir(f"{PATH_DATA}train")
list_test_imgs = os.listdir(f"{PATH_DATA}test1")

In [None]:
# ESTIMATE PER-CHANNEL MEAN AND STD FROM A PERCENT-SIZED SAMPLE OF THE TRAINING IMAGES
#
# Pixel values are scaled by 1/255 before the statistics are taken, and the
# per-image statistics are then averaged over the sample (the std is therefore
# the mean of per-image stds, not the std of all sampled pixels pooled together).

# Always use at least one image so the averages below are never taken over an
# empty list (which would yield NaN).
num_imgs = max(1, int(len(list_train_imgs) * PERCENT))

# os.listdir returns names in arbitrary order, so "the first N names" is not a
# reproducible subset. Draw the sample from the sorted listing with a dedicated
# seeded generator instead; this leaves the global `random` state untouched.
sample_imgs = random.Random(SEED).sample(sorted(list_train_imgs), num_imgs)

per_image_means = []
per_image_stds = []
for path_data in sample_imgs:
    # Decode each file once and derive both statistics from the same array.
    pixels = plt.imread(f"{PATH_DATA}/train/{path_data}").astype(float) / 255.0
    per_image_means.append(np.mean(pixels, axis=(0, 1)))
    per_image_stds.append(np.std(pixels, axis=(0, 1)))

mean_by_percent = np.mean(per_image_means, axis=0)
std_by_percent = np.mean(per_image_stds, axis=0)

print(f"Mean for {PERCENT*100}% of images: {mean_by_percent}")
print(f"Std for {PERCENT*100}% of images: {std_by_percent}")

In [None]:
# AUGMENTATION PIPELINES (tuned)

# Training pipeline: resize to IMAGE_SIZE, apply the random augmentations in the
# listed order, then normalize with the sampled statistics and convert to a tensor.
_train_steps = [
    Resize(height=IMAGE_SIZE[0], width=IMAGE_SIZE[1], p=1),
    HorizontalFlip(p=0.5),
    ShiftScaleRotate(p=0.5),
    HueSaturationValue(
        hue_shift_limit=20,
        sat_shift_limit=30,
        val_shift_limit=20,
        p=0.5,
    ),
    CoarseDropout(max_holes=3, max_height=32, max_width=32, p=0.25),
    Blur(blur_limit=3, p=0.25),
    GaussNoise(p=0.25),
    Normalize(mean=mean_by_percent, std=std_by_percent),
    ToTensorV2(),
]
train_transforms = Compose(_train_steps)

# Validation / test pipeline: no random augmentation, only resize to
# IMAGE_SIZE_TEST, the same normalization, and tensor conversion.
_val_steps = [
    Resize(height=IMAGE_SIZE_TEST[0], width=IMAGE_SIZE_TEST[1], p=1),
    Normalize(mean=mean_by_percent, std=std_by_percent),
    ToTensorV2(),
]
val_transforms = Compose(_val_steps)

In [None]:
class CatsDogsDataset(torch.utils.data.Dataset):
    """Dataset of dogs-vs-cats training images addressed by file name.

    Each item is an ``(image, label)`` pair. The label is derived from the file
    name: 1 when the first dot-separated field of the name is ``dog``, else 0.
    Images are read from ``PATH_DATA + "train/"`` with OpenCV and converted from
    BGR to RGB before the optional transform is applied.
    """

    def __init__(self, list_imgs, transforms=None):
        """Store the file names and the optional transform.

        Args:
            list_imgs: Sequence of image file names inside the ``train`` folder.
            transforms: Optional callable invoked as ``transforms(image=img)``
                that returns a mapping with an ``"image"`` entry.
        """
        self.list_imgs = list_imgs
        self.transforms = transforms

    def __getitem__(self, index):
        """Load one image and its label.

        Returns:
            Tuple of the (optionally transformed) RGB image and the label as a
            scalar ``float32`` tensor.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        img_path = self.list_imgs[index]
        img_name = img_path.split("/")[-1]
        label = 1 if img_name.split(".")[0] == "dog" else 0

        full_path = PATH_DATA + "train/" + img_path
        img = cv2.imread(full_path)
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of inside cvtColor.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {full_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transforms:
            img = self.transforms(image=img)["image"]

        return img, torch.tensor(label, dtype=torch.float32)

    def __len__(self):
        """Return the number of image file names held by the dataset."""
        return len(self.list_imgs)

    def show_augmentated_img(self, index):
        """Return the item's image in a layout suitable for ``plt.imshow``.

        A tensor image is permuted from (C, H, W) to (H, W, C) and converted to
        a ``float32`` NumPy array. Any other image (no transform, or a transform
        that does not produce a tensor) is returned unchanged.
        """
        img, _ = self[index]
        # Decide by the actual type rather than by whether a transform is set,
        # so pipelines that return NumPy arrays do not fail on `.permute`.
        if isinstance(img, torch.Tensor):
            img = np.array(img.permute(1, 2, 0), dtype=np.float32)
        return img

In [None]:
# Seed every random number generator used by the notebook.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
np.random.seed(SEED)
random.seed(SEED)

# os.listdir returns names in arbitrary order, so a seeded shuffle alone does not
# give a reproducible split. Sorting first fixes the starting order; sorted()
# also returns a new list, so the original listing is not shuffled in place.
list_imgs = sorted(list_train_imgs)
random.shuffle(list_imgs)

# NOTE: list_train_imgs is rebound to the training part below, so re-running this
# cell without re-running the listing cell splits the already reduced list again.
split_idx = int(len(list_imgs) * TRAIN_RATIO)
list_train_imgs = list_imgs[:split_idx]
list_val_imgs = list_imgs[split_idx:]

train_dataset = CatsDogsDataset(list_train_imgs, train_transforms)
val_dataset = CatsDogsDataset(list_val_imgs, val_transforms)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4
)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=4)

# Visual sanity check of one augmented training sample.
plt.imshow(train_dataset.show_augmentated_img(0))

In [None]:
class LitModel(pl.LightningModule):
    """Lightning module wrapping a timm model trained with BCE-with-logits loss.

    The wrapped model's outputs are treated as logits: the training and
    validation losses use ``nn.BCEWithLogitsLoss`` and hard predictions are
    obtained by thresholding ``sigmoid(logits)`` at ``thr``.
    """

    def __init__(self, model_name, num_classes=1, pretrained=True, lr=5e-4, thr=0.5):
        """Build the timm model and store the training settings.

        Args:
            model_name: Model identifier passed to ``timm.create_model``.
            num_classes: Number of model outputs passed to ``timm.create_model``.
            pretrained: Whether timm should load pretrained weights.
            lr: Learning rate for the Adam optimizer.
            thr: Probability threshold used for hard predictions in validation.
        """
        super().__init__()
        # Store the constructor arguments in checkpoints so that
        # load_from_checkpoint can rebuild the module with the same settings;
        # explicitly passed keyword arguments still take precedence on load.
        self.save_hyperparameters()
        self.model = timm.create_model(model_name, num_classes=num_classes, pretrained=pretrained)
        self.lr = lr
        self.loss_fn = nn.BCEWithLogitsLoss()
        # Parameter-free, so it is created once here rather than on every step.
        self.loss_after_thr_fn = nn.BCELoss()
        self.thr = thr

    def forward(self, x):
        """Return the wrapped model's raw output for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        x, y_true = batch
        # Flatten the model output to one value per element to match y_true.
        y_pred = self.model(x).view(-1)
        loss = self.loss_fn(y_pred, y_true)
        # Show the running loss in the progress bar in addition to the logger.
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss, F1 and thresholded loss for one batch.

        NOTE: all three values are computed per batch and logged from this
        step, so the epoch-level ``val_f1`` is an aggregate of batch-level F1
        scores, not the F1 score of the whole validation set.
        """
        x, y_true = batch
        y_pred = self.model(x).view(-1)
        loss = self.loss_fn(y_pred, y_true)
        self.log("val_loss", loss)

        preds = (torch.sigmoid(y_pred) > self.thr).float()
        # zero_division=0 keeps the score at 0 for batches without positive
        # labels/predictions and avoids a warning on each such batch.
        f1 = f1_score(y_true.cpu().numpy(), preds.cpu().numpy(), zero_division=0)
        self.log("val_f1", f1)

        # BCE on hard 0/1 predictions: correct samples contribute 0, and since
        # nn.BCELoss clamps its log terms at -100, each wrong sample contributes
        # 100. The value therefore depends on `thr` and is what the Optuna
        # objective reads back.
        self.log("val_loss_after_thr", self.loss_after_thr_fn(preds, y_true))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:
def objective(trial):
    """Optuna objective: train one model and return its thresholded validation loss.

    Samples a learning rate (log scale) and a decision threshold from ``trial``,
    fits a ``LitModel`` on the notebook's train/validation dataloaders, and
    returns the ``val_loss_after_thr`` metric reported by the trainer after
    fitting.
    """
    sampled_lr = trial.suggest_float("lr", 5e-5, 1e-3, log=True)
    sampled_thr = trial.suggest_float("thr", 0.3, 0.8)

    lit_model = LitModel(
        "mobilenetv4_conv_large.e500_r256_in1k",
        thr=sampled_thr,
        lr=sampled_lr,
    )

    # Both callbacks watch the (unthresholded) validation loss.
    trial_callbacks = [
        EarlyStopping(monitor="val_loss", patience=3, mode="min"),
        ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min"),
    ]

    trial_trainer = pl.Trainer(
        logger=TensorBoardLogger(f"{PATH_DATA}/tb_logs", name="optuna"),
        max_epochs=EPOCHS,
        accelerator="auto",
        callbacks=trial_callbacks,
    )
    trial_trainer.fit(lit_model, train_dataloader, val_dataloader)

    final_metrics = trial_trainer.callback_metrics
    return final_metrics["val_loss_after_thr"].item()

In [None]:
# Run the hyperparameter search, minimizing the value returned by `objective`.
# The sampler is seeded so that the sequence of suggested hyperparameters is
# repeatable between notebook runs (given the same objective values).
study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=SEED),
)
study.optimize(objective, n_trials=10)

In [None]:
# Report the hyperparameters of the best trial; `trial` is reused when the
# final model is trained.
print("Best trial:")
trial = study.best_trial
print("  Params: ")
for param_line in (f"    {name}: {val}" for name, val in trial.params.items()):
    print(param_line)

In [None]:
# IPython magics: (re)load the TensorBoard notebook extension and start it on
# the directory the TensorBoardLogger instances write to.
%reload_ext tensorboard
%tensorboard --logdir {PATH_DATA}/tb_logs

In [None]:
# Train the final model with the hyperparameters of the best Optuna trial.
best_params = trial.params
model = LitModel(
    "mobilenetv4_conv_large.e500_r256_in1k",
    thr=best_params["thr"],
    lr=best_params["lr"],
)

# Runs are logged under the "best" experiment name, separate from the search runs.
logger = TensorBoardLogger(f"{PATH_DATA}/tb_logs", name="best")

# Early stopping and checkpointing both track the validation loss (lower is better).
monitor_kwargs = {"monitor": "val_loss", "mode": "min"}
early_stop = EarlyStopping(patience=3, **monitor_kwargs)
checkpoint = ModelCheckpoint(save_top_k=1, **monitor_kwargs)

trainer = pl.Trainer(
    logger=logger,
    max_epochs=EPOCHS,
    accelerator="auto",
    callbacks=[early_stop, checkpoint],
)
trainer.fit(model, train_dataloader, val_dataloader)

In [None]:
# LOAD THE BEST CHECKPOINT, EXPORT IT TO ONNX AND PREDICT ON THE TEST IMAGES

# Prefer the checkpoint that the ModelCheckpoint callback of the final training
# run reports as best. The fixed file name is kept as a fallback for sessions in
# which that training cell was not executed (no `checkpoint` variable, or an
# empty best_model_path).
fallback_ckpt_path = f"{PATH_DATA}tb_logs/best/version_1/checkpoints/epoch=0-step=2250.ckpt"
try:
    ckpt_path = checkpoint.best_model_path or fallback_ckpt_path
except NameError:
    ckpt_path = fallback_ckpt_path

# map_location="cpu" keeps the weights on the CPU, matching the CPU dummy input
# used for the export below.
model = LitModel.load_from_checkpoint(
    ckpt_path,
    map_location="cpu",
    model_name="mobilenetv4_conv_large.e500_r256_in1k",
)
model.eval()

dummy_input = torch.randn(1, 3, *(IMAGE_SIZE_TEST))
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=11)

ort_session = ort.InferenceSession("model.onnx")
# The input name does not change between runs, so look it up once.
input_name = ort_session.get_inputs()[0].name

test_preds = []
for path_img in tqdm(list_test_imgs):
    img_file = f"{PATH_DATA}test1/{path_img}"
    img = cv2.imread(img_file)
    # cv2.imread returns None for unreadable files; report which one it was.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_file}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = val_transforms(image=img)["image"]
    img = img.unsqueeze(0)  # add the batch dimension

    input_data = img.numpy()

    pred = ort_session.run(None, {input_name: input_data})[0]
    # The exported model outputs a logit; convert it to a probability.
    pred = torch.sigmoid(torch.from_numpy(pred)).item()

    test_preds.append(pred)

In [None]:
# Build the submission table: the id is the integer before the first dot of each
# test file name and the label is the predicted probability for that file.
test_ids = [int(path_img.partition(".")[0]) for path_img in list_test_imgs]

test_preds_df = pd.DataFrame({"id": test_ids, "label": test_preds})
test_preds_df = test_preds_df.sort_values(by="id")
test_preds_df = test_preds_df.reset_index(drop=True)

test_preds_df.to_csv("advanced_solution.csv", index=False)

In [None]:
# CHECK PREDICTS ON TEST

# Pick the file from the test listing itself so that the path always exists.
rand_img_train_path = f"{PATH_DATA}test1/{random.choice(list_test_imgs)}"

# Load exactly as in the batch inference loop: OpenCV decodes to BGR, which is
# then converted to the RGB order the model was trained on. Resizing is left to
# val_transforms so the prediction matches the one written to the submission.
img = cv2.imread(rand_img_train_path)
if img is None:
    raise FileNotFoundError(f"Could not read image: {rand_img_train_path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

print(rand_img_train_path)

plt.imshow(img)

img = val_transforms(image=img)["image"]
img = img.unsqueeze(0)  # add the batch dimension

input_name = ort_session.get_inputs()[0].name
input_data = img.numpy()

pred = ort_session.run(None, {input_name: input_data})[0]

# The ONNX model outputs a logit; convert it to a probability.
sigmoid_output = torch.sigmoid(torch.from_numpy(pred)).item()

print(sigmoid_output)
plt.show()

In [None]:
# INTEGRATED GRADIENTS FOR ONE RANDOM TEST IMAGE

# Pick the file from the test listing itself so that the path always exists.
rand_img_train_path = f"{PATH_DATA}test1/{random.choice(list_test_imgs)}"

# OpenCV decodes to BGR; convert to the RGB order the model was trained on.
img = cv2.imread(rand_img_train_path)
if img is None:
    raise FileNotFoundError(f"Could not read image: {rand_img_train_path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

print(rand_img_train_path)

# Display copy at the model's input resolution so both panels have the same
# size. cv2.resize expects (width, height), hence the reversed tuple.
img_original = cv2.resize(img, IMAGE_SIZE_TEST[::-1], interpolation=cv2.INTER_AREA)

img = val_transforms(image=img)["image"]
img = img.unsqueeze(0)  # add the batch dimension

input_name = ort_session.get_inputs()[0].name
input_data = img.numpy()

pred = ort_session.run(None, {input_name: input_data})[0]
sigmoid_output = torch.sigmoid(torch.from_numpy(pred)).item()

print(sigmoid_output)

ig = IntegratedGradients(model)

# Attribute on whatever device the Lightning module currently lives on.
# `delta` is not used below; it is kept for manual inspection.
attributions, delta = ig.attribute(
    img.to(model.device), target=0, return_convergence_delta=True
)

attributions = attributions.squeeze().cpu().detach().numpy()
attributions = np.transpose(attributions, (1, 2, 0))

# imshow ignores the colormap for 3-channel input and clips it as an RGB image,
# so collapse the channels into one magnitude per pixel before plotting.
attribution_map = np.abs(attributions).sum(axis=2)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

ax1.imshow(img_original)
ax1.set_title("Original Image")
ax1.axis("off")

# Draw the heatmap once and reuse its handle for the colorbar.
heatmap = ax2.imshow(attribution_map, cmap="viridis", alpha=0.7)
ax2.set_title("Integrated Gradients")
ax2.axis("off")

plt.colorbar(heatmap, ax=ax2, label="Integrated Gradients")

plt.tight_layout()
plt.show()