In [None]:

import pytorch_lightning as pl
import wandb
from pytorch_lightning.loggers import WandbLogger
from Dataset import Nature12KDataModule
from ResNet import ResNetFinetune
import torch
from tqdm import tqdm

In [None]:
def train(data_dir="../../inaturalist_12K", seed=42):
    """Fine-tune a ResNet on the iNaturalist 12K data and evaluate it on the test split.

    Starts a W&B run with the hyperparameters below as its config. It then builds
    the data module and model from ``wandb.config``, trains with PyTorch
    Lightning, and runs the test loop.

    Args:
        data_dir: Root directory of the dataset, passed to
            ``Nature12KDataModule``.
        seed: Global random seed. It is stored in the W&B config and applied
            with ``pl.seed_everything`` before the data module and model are
            created, so the run can be reproduced from its logged config.

    Returns:
        Tuple ``(model, data_module)``: the trained ``ResNetFinetune`` and the
        data module it was trained and tested on. The W&B run is NOT finished
        here, so callers can keep logging to it (e.g. a confusion matrix).
    """
    wandb.init(project="iNaturalist_ResNet_Eval", config={
        "batch_size": 64,
        "data_augmentation": True,
        "freeze_type": "upto",
        "freeze_upto_layer": 6,
        "image_size": 384,
        "lr": 0.0023958851762794424,
        "max_epochs": 20,
        "momentum": 0.8069106650235346,
        "optimizer": "sgd",
        "resnet_variant": "resnet101",
        "scheduler": True,
        "seed": seed,
        "weight_decay": 0.00481366621399702,
    })

    config = wandb.config

    # Seed Python/NumPy/torch (and dataloader workers) before anything
    # stochastic is constructed, so the run is reproducible from its config.
    pl.seed_everything(config.seed, workers=True)

    data_module = Nature12KDataModule(
        data_dir=data_dir,
        batch_size=config.batch_size,
        image_size=(config.image_size, config.image_size),
        data_aug=config.data_augmentation
    )

    data_module.prepare_data()
    data_module.setup()

    model = ResNetFinetune(
        # NOTE(review): presumably the number of classes in the dataset; keep
        # in sync with data_module.class_names.
        num_classes=10,
        lr=config.lr,
        optimizer=config.optimizer,
        weight_decay=config.weight_decay,
        scheduler=config.scheduler,
        freeze_type=config.freeze_type,
        freeze_upto_layer=config.freeze_upto_layer,
        resnet_variant=config.resnet_variant,
        momentum=config.momentum
    )

    # Attach Lightning's logger to the run opened by wandb.init above.
    wandb_logger = WandbLogger(project=wandb.run.project, name=wandb.run.name)

    trainer = pl.Trainer(
        max_epochs=config.max_epochs,
        accelerator="auto",
        devices="auto",
        log_every_n_steps=10,
        logger=wandb_logger
    )

    trainer.fit(model, data_module.train_dataloader(), data_module.val_dataloader())
    trainer.test(model, datamodule=data_module)

    return model, data_module


In [None]:


def predict_and_log_confmat(model, data_module):
    model.eval()
    test_loader = data_module.test_dataloader()

    ground_truth = []
    predictions = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="🔍 Predicting"):
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            ground_truth.extend(labels.cpu().numpy())
            predictions.extend(preds.cpu().numpy())

    class_names = data_module.class_names

    wandb.log({
        "conf_mat": wandb.plot.confusion_matrix(
            probs=None,
            y_true=ground_truth,
            preds=predictions,
            class_names=class_names
        )
    })

    print("✅ Confusion matrix logged to W&B.")


In [None]:
# Train and evaluate, log the test-set confusion matrix, then close the W&B run.
model, data_module = train()
predict_and_log_confmat(model, data_module)
# W&B recommends an explicit finish in notebooks, so the run is marked
# finished and its logged data is flushed before the next run starts.
wandb.finish()