In [None]:
!apt-get update

In [None]:
!apt-get install build-essential libatomic1 gfortran perl wget m4 cmake pkg-config curl -y

In [None]:
import torch
import pickle

from models.models import EquivariantVgg16
from torch.utils.data import DataLoader
from dataset.dataset import PlanetaryDataset
from utilities.training import VAL_TEST_TRANSFORM
from torch.utils.data import DataLoader

# Checkpoints evaluated in this chunk, each paired with a freshly constructed
# two-class EquivariantVgg16 that will receive its weights. Only four are
# grouped per chunk to keep memory usage bounded.
models = {
    checkpoint_path: EquivariantVgg16(num_classes=2)
    for checkpoint_path in (
        "/kaggle/input/vgg16-models/pytorch/default/2/model_channels_30_31_32_33_34_35.pt",
        "/kaggle/input/vgg16-models/pytorch/default/2/model_channels_36_37_38_39_40_41.pt",
        "/kaggle/input/vgg16-models/pytorch/default/2/model_channels_42_43_44_45_46_47.pt",
        "/kaggle/input/vgg16-models/pytorch/default/2/model_channels_48_49_50_51_52_53.pt",
    )
}


def evaluate_and_save_predictions(
    models,
    dataset,
    device,
    pred_file,
    prob_file,
    *,
    batch_size=1,
    num_workers=1,
    verbose=True,
):
    """Run every model over ``dataset`` and pickle its predictions and probabilities.

    For each ``checkpoint_path -> model`` entry in ``models``, the checkpoint's
    state dict is loaded into ``model``. The model is moved to ``device``, put
    in eval mode and run over the whole dataset (in order) without gradients.

    Args:
        models: Mapping from checkpoint path to a model instance whose
            ``load_state_dict`` accepts that checkpoint.
        dataset: Dataset whose items unpack as ``(images, labels)``; the labels
            are not used.
        device: ``torch.device`` (or device string) used for inference.
        pred_file: Output path for the pickled predictions dict.
        prob_file: Output path for the pickled probabilities dict.
        batch_size: Samples per forward pass (default 1, the original value).
        num_workers: DataLoader worker processes (default 1, the original value).
        verbose: When True (the default, matching the original behaviour),
            print the softmax probabilities of every batch.

    Both pickle files contain a dict keyed by checkpoint path. Each value is a
    list with one NumPy array per batch:
        * ``pred_file``: argmax class indices.
        * ``prob_file``: softmax over ``dim=1`` of the model output.

    Returns:
        None. Results are only written to ``pred_file`` and ``prob_file``.
    """
    device = torch.device(device)
    test_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # keep sample order so outputs line up across models
        num_workers=num_workers,
        # Pinned host memory only benefits CUDA transfers; on CPU it just warns.
        pin_memory=device.type == "cuda",
    )

    all_predictions = {}
    all_probabilities = {}

    with torch.no_grad():
        for model_path, model in models.items():
            # Load onto the CPU first so checkpoints saved from a GPU also open
            # on CPU-only machines; ``model.to(device)`` moves the weights after.
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            state_dict = torch.load(model_path, map_location="cpu")
            model.load_state_dict(state_dict)
            # NOTE(review): the model is left on ``device`` after this loop
            # iteration, so every model ends up occupying device memory; move it
            # back to the CPU afterwards if memory is tight.
            model.to(device)
            model.eval()

            predictions = []
            probabilities = []
            for images, _labels in test_loader:
                images = images.to(device)

                # Run inference
                logits = model(images)
                probs = torch.softmax(logits, dim=1)
                preds = torch.argmax(probs, dim=1)
                if verbose:
                    print(f"Probs: {probs}")

                predictions.append(preds.cpu().numpy())
                probabilities.append(probs.cpu().numpy())

            all_predictions[model_path] = predictions
            all_probabilities[model_path] = probabilities

    with open(pred_file, "wb") as f_pred, open(prob_file, "wb") as f_prob:
        pickle.dump(all_predictions, f_pred)
        pickle.dump(all_probabilities, f_prob)


# Test split, read with channels 30 through 100 inclusive and the shared
# validation/test transform.
test_dataset = PlanetaryDataset(
    transform=VAL_TEST_TRANSFORM,
    channels=[*range(30, 101)],
    csv_file="/kaggle/input/gsoc-protoplanetary-disks/test_info.csv",
    data_dir="/kaggle/input/gsoc-protoplanetary-disks/Test_Clean",
)

# Prefer the GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Score this chunk of models and write the per-model outputs to disk.
evaluate_and_save_predictions(
    models=models,
    dataset=test_dataset,
    device=device,
    pred_file="model_1_4_predictions.pkl",
    prob_file="model_1_4_probabilities.pkl",
)