In [None]:

import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchvision import  transforms
from pytorch_lightning.loggers import WandbLogger
import os
from Dataset import Nature12KDataModule
from cnn_model import CNN
import matplotlib.pyplot as plt
import torchvision.transforms as transforms


In [None]:
# Expose only the GPU with physical index 2 to CUDA in this process.
# This has to be set before CUDA is first initialised to have any effect.
os.environ.update({'CUDA_VISIBLE_DEVICES': '2'})

In [None]:
def visualize_predictions(model, dataloader, class_names, num_images=30,
                          cols=3, save_path="test_predictions_grid.png"):
    """Draw a grid of images titled with their true and predicted class names.

    The model is put into eval mode (and left there), batches from
    ``dataloader`` are run under ``torch.no_grad()``, and up to ``num_images``
    samples are drawn. The figure is written to ``save_path`` and then shown.

    Args:
        model: Module exposing ``eval()`` and a ``device`` attribute. Calling
            it on an image batch must return scores whose argmax over dim 1
            is the predicted class index.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        class_names: Lookup from integer class index to display name.
        num_images: Maximum number of images to draw. Must be >= 1.
        cols: Number of grid columns. Must be >= 1. The number of rows is
            derived from ``num_images`` so the grid always has enough cells.
        save_path: File the figure is saved to.

    Raises:
        ValueError: If ``num_images`` or ``cols`` is smaller than 1.
    """
    if num_images < 1 or cols < 1:
        raise ValueError("num_images and cols must both be >= 1")

    model.eval()
    images_shown = 0

    # Ceil-divide so the grid always has room for `num_images` cells; with the
    # defaults (30 images, 3 columns) this is the same 10x3 layout as before.
    rows = (num_images + cols - 1) // cols
    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid.
    fig, axs = plt.subplots(rows, cols, figsize=(cols * 4, rows * 2.5), squeeze=False)
    axs = axs.flatten()

    # NOTE(review): tensors are converted as-is; if the dataloader normalises
    # its images the displayed colours will be off -- confirm against the
    # data module's transforms.
    to_pil = transforms.ToPILImage()

    with torch.no_grad():
        for x, y in dataloader:
            logits = model(x.to(model.device))
            preds = logits.argmax(dim=1)

            for img, label, pred in zip(x, y, preds):
                if images_shown >= num_images:
                    break
                ax = axs[images_shown]
                ax.imshow(to_pil(img.cpu()))
                ax.axis('off')
                # int(...) turns the 0-d tensors into plain Python indices.
                ax.set_title(f"True: {class_names[int(label)]}\nPred: {class_names[int(pred)]}")
                images_shown += 1

            # Stop here so no further batch is loaded once the grid is full.
            if images_shown >= num_images:
                break

    # Hide cells that received no image (grid larger than `num_images`, or the
    # dataloader ran out early) instead of leaving empty framed axes.
    for ax in axs[images_shown:]:
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(save_path)
    plt.show()


In [None]:

# NOTE(review): this notebook defines `visualize_predictions` in more than one
# cell; whichever definition is executed last is the one in effect. Keep a
# single definition to avoid silent shadowing.
def visualize_predictions(model, dataloader, class_names, num_images=30,
                          cols=3, save_path="test_predictions_grid.png"):
    """Draw a grid of images titled with their true and predicted class names.

    The model is put into eval mode (and left there), batches from
    ``dataloader`` are run under ``torch.no_grad()``, and up to ``num_images``
    samples are drawn. The figure is written to ``save_path`` and then shown.

    Args:
        model: Module exposing ``eval()`` and a ``device`` attribute. Calling
            it on an image batch must return scores whose argmax over dim 1
            is the predicted class index.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        class_names: Lookup from integer class index to display name.
        num_images: Maximum number of images to draw. Must be >= 1.
        cols: Number of grid columns. Must be >= 1. The number of rows is
            derived from ``num_images`` so the grid always has enough cells.
        save_path: File the figure is saved to.

    Raises:
        ValueError: If ``num_images`` or ``cols`` is smaller than 1.
    """
    if num_images < 1 or cols < 1:
        raise ValueError("num_images and cols must both be >= 1")

    model.eval()
    images_shown = 0

    # Ceil-divide so the grid always has room for `num_images` cells; with the
    # defaults (30 images, 3 columns) this is the same 10x3 layout as before.
    rows = (num_images + cols - 1) // cols
    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid.
    fig, axs = plt.subplots(rows, cols, figsize=(cols * 4, rows * 2.5), squeeze=False)
    axs = axs.flatten()

    # NOTE(review): tensors are converted as-is; if the dataloader normalises
    # its images the displayed colours will be off -- confirm against the
    # data module's transforms.
    to_pil = transforms.ToPILImage()

    with torch.no_grad():
        for x, y in dataloader:
            logits = model(x.to(model.device))
            preds = logits.argmax(dim=1)

            for img, label, pred in zip(x, y, preds):
                if images_shown >= num_images:
                    break
                ax = axs[images_shown]
                ax.imshow(to_pil(img.cpu()))
                ax.axis('off')
                # int(...) turns the 0-d tensors into plain Python indices.
                ax.set_title(f"True: {class_names[int(label)]}\nPred: {class_names[int(pred)]}")
                images_shown += 1

            # Stop here so no further batch is loaded once the grid is full.
            if images_shown >= num_images:
                break

    # Hide cells that received no image (grid larger than `num_images`, or the
    # dataloader ran out early) instead of leaving empty framed axes.
    for ax in axs[images_shown:]:
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(save_path)
    plt.show()


In [None]:
def train_best_model(data_dir="../../inaturalist_12K", max_epochs=10,
                     use_wandb=True, seed=None):
    """Train the CNN with the manually chosen best configuration, then test it.

    Builds a ``Nature12KDataModule`` and a ``CNN`` from the hard-coded
    hyperparameter dictionary below, fits with a ``pl.Trainer`` on the train
    and validation dataloaders, and finally runs ``trainer.test`` on the test
    dataloader.

    Args:
        data_dir: Dataset location passed to ``Nature12KDataModule``.
        max_epochs: Number of epochs passed to ``pl.Trainer``.
        use_wandb: When True, log to Weights & Biases (project
            ``best_model_eval``, run ``Best_CNN_Model``). When False, fall
            back to the Trainer's default logger so no W&B run is created.
        seed: If not None, seed the random number generators through
            ``pl.seed_everything`` before anything is built, so runs are
            repeatable. None leaves seeding untouched.

    Returns:
        Tuple ``(model, data_module)``: the trained model and the data module
        it was trained on.
    """
    if seed is not None:
        pl.seed_everything(seed, workers=True)

    # Best configuration, defined manually.
    config = {
        "activation": "silu",
        "batch_norm": True,
        "batch_size": 32,
        "data_augmentation": True,
        "dense_neurons": 256,
        "dropout": 0.3,
        "input_channels": 3,
        "lr": 0.0004165458022262786,
        "num_classes": 10,
        "conv_filters": [32, 64, 128, 128, 128],
        "kernel_sizes": [3, 3, 3, 3, 3]
    }

    # Data module is prepared and set up by hand because its dataloaders are
    # handed to the trainer directly below.
    data_module = Nature12KDataModule(
        data_dir=data_dir,
        batch_size=config["batch_size"],
        image_size=(128, 128),
        data_aug=config["data_augmentation"]
    )
    data_module.prepare_data()
    data_module.setup()

    # Build the model from the best hyperparameters.
    model = CNN(
        input_channels=config["input_channels"],
        conv_filters=config["conv_filters"],
        kernel_sizes=config["kernel_sizes"],
        activation=config["activation"],
        dense_neurons=config["dense_neurons"],
        num_classes=config["num_classes"],
        lr=config["lr"],
        batch_norm=config["batch_norm"],
        dropout=config["dropout"]
    )

    # W&B logging is optional; `logger=True` selects the Trainer's default logger.
    if use_wandb:
        logger = WandbLogger(project="best_model_eval", name="Best_CNN_Model")
    else:
        logger = True

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices="auto",
        log_every_n_steps=10,
        logger=logger
    )

    print("🚀 Training best model...")
    trainer.fit(model, data_module.train_dataloader(), data_module.val_dataloader())

    print("🧪 Evaluating best model on test set...")
    trainer.test(model, data_module.test_dataloader())

    return model, data_module


In [None]:
# Train and test the best configuration, then plot predictions on the test split.
model, data_module = train_best_model()

visualize_predictions(
    model=model,
    dataloader=data_module.test_dataloader(),
    class_names=data_module.class_names,
)