In [None]:
# The `src` package and the relative `data/` and `logs/` paths below are
# resolved from the project root, which the original `%cd ..` assumed is one
# level above the notebook's starting directory. Guarded so that re-running
# this cell (or starting the kernel at the root) does not climb further up.
import os

if not os.path.isdir("src"):
    os.chdir("..")
os.getcwd()

In [None]:
from datetime import datetime

import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning.loggers import TensorBoardLogger
from tqdm import tqdm

from src.data.ConsPracDataModule import ConsPracDataModule
from src.models.EfficientNetV2S import EfficientNetV2S
from src.models.ResNet50 import ResNet50

In [None]:
# Checkpoint used for inference; the file name suggests the best-val-loss
# checkpoint (epoch 4) of training run version_13.
MODEL_PATH = (
    "logs/lightning_logs/version_13/checkpoints/"
    "best-checkpoint-epoch=04-val_loss=1.42.ckpt"
)

# Earlier submission that the new predictions are compared against at the end.
PREVIOUS_BEST_SUBMISSION = "data/submissions/0830_2220_submission_df.csv"

In [None]:
dm = ConsPracDataModule()
dm.setup()

In [None]:
dl_test = dm.test_dataloader()

In [None]:
model = EfficientNetV2S.load_from_checkpoint(MODEL_PATH)

In [None]:
# Let float32 matmuls use faster reduced-precision internals where supported.
torch.set_float32_matmul_precision("medium")

# NOTE(review): the TensorBoard logger may create a new `version_N` directory
# under `logs/` even for this predict-only run; consider `logger=False` if that
# is unwanted.
tb_logger = TensorBoardLogger(save_dir="logs/")
trainer = pl.Trainer(logger=tb_logger)

# One output per test batch, as returned by the model's predict step.
predictions = trainer.predict(model, dl_test)

In [None]:
# Class order of the model outputs; must agree with the submission columns,
# which the assertion cell below checks against data/submission_format.csv.
species_labels = [
    "antelope_duiker",
    "bird",
    "blank",
    "civet_genet",
    "hog",
    "leopard",
    "monkey_prosimian",
    "rodent",
]
preds_collector = []

# Each prediction batch is presumably (image_ids, logits) as emitted by the
# model's predict step — TODO confirm against the model code.
for batch in tqdm(predictions):
    idxs, logits = batch[0], batch[1]
    # Logits -> per-class probabilities (rows sum to 1).
    preds = F.softmax(logits, dim=1)

    preds_df = pd.DataFrame(
        # `.cpu()` keeps `.numpy()` safe even if outputs stay on an accelerator.
        preds.detach().cpu().numpy(),
        index=idxs,
        columns=species_labels,
    )
    preds_collector.append(preds_df)

# Concatenate once; name the index "id" so the written CSV header matches
# data/submission_format.csv (which is indexed by an "id" column).
submission_df = pd.concat(preds_collector).rename_axis("id")
submission_df

In [None]:
submission_format = pd.read_csv("data/submission_format.csv", index_col="id")

# `Index.equals` checks values *and* order, and returns False on a length
# mismatch instead of raising like element-wise `==` does, so a bad
# submission fails here with a readable message.
assert submission_df.index.equals(submission_format.index), (
    "submission ids differ from submission_format (values or order)"
)
assert submission_df.columns.equals(submission_format.columns), (
    "submission columns differ from submission_format (values or order)"
)

In [None]:
# Month-day_hour-minute stamp so each run writes a new file instead of
# overwriting an earlier submission.
mmdd_hhmm = datetime.now().strftime("%m%d_%H%M")
submission_path = f"data/submissions/{mmdd_hhmm}_submission_df.csv"

submission_df.to_csv(submission_path)

### Compare with previous best prediction

In [None]:
# Earlier submission; its first CSV column (the saved index) becomes the index.
previous_best_submission = pd.read_csv(
    PREVIOUS_BEST_SUBMISSION,
    index_col=0,
)

previous_best_submission.head(5)

In [None]:
# Share of images whose top-1 predicted class changed versus the previous best
# submission. Align on image id first so the comparison does not depend on
# row order; ids missing from the old file yield NaN and count as changed.
previous_top_class = previous_best_submission.idxmax(axis=1)
current_top_class = submission_df.idxmax(axis=1)
changed = previous_top_class.reindex(current_top_class.index) != current_top_class
f"Percentage of different entries: {changed.mean():.2%}"