In [None]:
from utils import LoginCredentials
import wandb

# Authenticate with Weights & Biases using the API key exposed by the
# project's LoginCredentials helper (how the helper obtains the key is
# defined in `utils`, not in this notebook).
authenticator = LoginCredentials()

wandb.login(key=authenticator.wandb_key)

In [None]:
import torch
from torchvision import transforms
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader
from transformers import MobileViTImageProcessor
import wandb

from dataset import MRIFeatureDataModule, MRIFeatureDataset
from models import SimpleEnsembleModel, AdvancedEnsembleModel
from utils import get_best_device, LoginCredentials

from datetime import datetime
import lightning.pytorch as pl
import torch
import numpy as np
import random
from sklearn.metrics import f1_score

# Finish any W&B run still open from an earlier execution in this kernel
# before a new sweep is created below.
wandb.finish()


def set_reproducibility(seed=42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (including all CUDA
    devices when CUDA is available), switches cuDNN to deterministic mode
    with benchmarking disabled, and finally calls Lightning's
    ``seed_everything`` with ``workers=True``.

    Args:
        seed: Integer seed applied to all generators. Defaults to 42.
    """
    # CPU-side generators: Python, NumPy, PyTorch.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    # GPU generators only exist when CUDA is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Control sources of nondeterminism in cuDNN.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # Lightning's own seeding utility, called last as in the original order.
    pl.seed_everything(seed, workers=True)


set_reproducibility(42)


# Sweep definition: an exhaustive grid over the two feature extractors and
# the two ensemble heads, with learning rate, batch size and epoch count each
# pinned to a single value. Runs are compared on the lowest "val_loss".
sweep_config = {
    "method": "grid",
    "metric": {"name": "val_loss", "goal": "minimize"},
    "parameters": {
        hyperparameter: {"values": candidates}
        for hyperparameter, candidates in [
            ("model_name", ["MobileVit-s", "efficientnet-b2"]),
            ("ensemble_variant", ["simple", "advanced"]),
            ("learning_rate", [0.00001]),
            ("batch_size", [40]),
            ("epochs", [60]),
        ]
    },
}

# Register the sweep with W&B; the returned id is what the agent attaches to.
sweep_id = wandb.sweep(sweep=sweep_config, project="Alzheimer-Detection")

# Define the training function


def train(config=None):
    """Run one sweep trial: fit an ensemble head on pre-extracted features and log test F1.

    The function opens a W&B run, loads the pickled train/val/test features
    for ``config.model_name``, builds the ensemble model selected by
    ``config.ensemble_variant``, trains it with Lightning while checkpointing
    on the lowest ``val_loss``, reloads the best checkpoint and logs the
    weighted F1 score on the test split as ``test_f1_score``.

    Args:
        config: Optional configuration forwarded to ``wandb.init``. After
            initialisation the values are read back from ``wandb.config``;
            the keys used are ``model_name``, ``ensemble_variant``,
            ``learning_rate``, ``batch_size`` and ``epochs``.

    Raises:
        ValueError: If ``config.ensemble_variant`` is not ``"simple"`` or
            ``"advanced"``.
    """
    with wandb.init(config=config):
        config = wandb.config
        device = get_best_device()
        num_labels = 4

        # Feature files are named after the backbone that produced them.
        feature_folder = "extracted_features/"
        model_name = config.model_name
        train_pkl = f"{feature_folder}{model_name}_train_features_pooled.pkl"
        val_pkl = f"{feature_folder}{model_name}_val_features_pooled.pkl"
        test_pkl = f"{feature_folder}{model_name}_test_features_pooled.pkl"

        data_module = MRIFeatureDataModule(
            train_pkl=train_pkl,
            val_pkl=val_pkl,
            test_pkl=test_pkl,
            batch_size=config.batch_size,
        )
        data_module.setup()
        train_loader = data_module.train_dataloader()
        val_loader = data_module.val_dataloader()
        test_loader = data_module.test_dataloader()

        # The ensemble head depends only on the variant; the backbone choice
        # only changes which feature files (and feature width) are used.
        ensemble_classes = {
            "simple": SimpleEnsembleModel,
            "advanced": AdvancedEnsembleModel,
        }
        if config.ensemble_variant not in ensemble_classes:
            raise ValueError(
                f"Unknown ensemble_variant {config.ensemble_variant!r}; "
                f"expected one of {sorted(ensemble_classes)}"
            )
        model_cls = ensemble_classes[config.ensemble_variant]

        # Read the feature width from the data instead of hard-coding it per
        # backbone. The validation loader is used so the training loader's
        # iteration state is left untouched.
        # Assumes the first batch element is the feature tensor and its last
        # dimension is the feature size the model expects -- TODO confirm
        # against MRIFeatureDataset and the ensemble models.
        sample_inputs = next(iter(val_loader))[0]
        feature_size = sample_inputs.shape[-1]

        # TODO confirm AdvancedEnsembleModel accepts the same constructor
        # arguments as SimpleEnsembleModel.
        # NOTE(review): config.learning_rate is swept but not forwarded here
        # because the models' optimizer arguments are not visible -- verify
        # how the learning rate should be supplied.
        model = model_cls(feature_size=feature_size, num_labels=num_labels)

        wandb_logger = WandbLogger()

        checkpoint_callback = ModelCheckpoint(
            dirpath=f"model_checkpoints/ensemble_{config.ensemble_variant}_{config.model_name}",
            # Only keys that exist in the sweep configuration may be used here.
            filename=f"lr_{config.learning_rate}",
            monitor="val_loss",
            mode="min",
            save_top_k=1,
        )

        trainer = L.Trainer(
            max_epochs=config.epochs,
            devices="auto",
            accelerator="auto",
            logger=wandb_logger,
            callbacks=[checkpoint_callback],
            log_every_n_steps=24,
        )

        trainer.fit(
            model=model, train_dataloaders=train_loader, val_dataloaders=val_loader
        )

        # Reload the best checkpoint with the same class that was trained;
        # fall back to the in-memory model if no checkpoint was written.
        best_model_path = checkpoint_callback.best_model_path
        if best_model_path:
            best_model = model_cls.load_from_checkpoint(
                best_model_path, feature_size=feature_size, num_labels=num_labels
            )
        else:
            best_model = model

        # Evaluate on test set
        best_model = best_model.to(device)
        best_model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            # Each batch is a 4-tuple; only the features and labels are used.
            for inputs, labels, _age, _sample_id in test_loader:
                inputs = inputs.to(device).float()
                outputs = best_model(inputs)
                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        f1 = f1_score(all_labels, all_preds, average="weighted")
        wandb.log({"test_f1_score": f1})


# The sweep agent is started from its own cell below, so re-running this cell
# only redefines `train` and does not launch any training runs.

In [None]:
# Run the sweep: start a W&B agent on `sweep_id` that calls `train` for the
# configurations it receives.
wandb.agent(sweep_id, function=train)