In [1]:
from utils import LoginCredentials
import wandb

authenticator = LoginCredentials()

wandb.login(key=authenticator.wandb_key)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mfinnhenri-smidt[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/henrismidt/.netrc


True

In [2]:
import torch
from torchvision import transforms
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader
from transformers import MobileViTImageProcessor
import wandb

from dataset import MRIFeatureDataModule, MRIFeatureDataset
from models import SimpleEnsembleModel, AdvancedEnsembleModel
from utils import get_best_device, LoginCredentials

from datetime import datetime
import lightning.pytorch as pl
import torch
import numpy as np
import random
from sklearn.metrics import f1_score

wandb.finish()  # Make sure previous sessions are finished


def set_reproducibility(seed=42):
    # Set Python random seed
    random.seed(seed)

    # Set Numpy seed
    np.random.seed(seed)

    # Set PyTorch seed
    torch.manual_seed(seed)

    # If using CUDA:
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Control sources of nondeterminism
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # PyTorch Lightning utility to seed everything
    pl.seed_everything(seed, workers=True)


set_reproducibility(42)


# Define sweep configuration
sweep_config = {
    "method": "grid",
    "metric": {"name": "val_loss", "goal": "minimize"},
    "parameters": {
        "model_name": {"values": ["MobileVit-s", "efficientnet-b2"]},
        "ensemble_variant": {"values": ["simple", "advanced"]},
        "learning_rate": {"values": [0.0001, 0.00001, 0.000001]},
        "batch_size": {"values": [40]},
        "epochs": {"values": [60]},
    },
}

sweep_id = wandb.sweep(sweep=sweep_config, project="Alzheimer-Detection")

# Define the training function


def train(config=None):
    with wandb.init(config=config):
        config = wandb.config
        device = get_best_device()
        feature_folder = "extracted_features/"
        model_name = config.model_name
        train_pkl = f'{feature_folder}{model_name}_train_features_pooled.pkl'
        val_pkl = f'{feature_folder}{model_name}_val_features_pooled.pkl'
        test_pkl = f'{feature_folder}{model_name}_test_features_pooled.pkl'
        num_classes = 4 
        
        data_module = MRIFeatureDataModule(train_pkl=train_pkl, val_pkl=val_pkl, test_pkl=test_pkl, batch_size=config.batch_size)
        data_module.setup()
        train_loader = data_module.train_dataloader()
        val_loader = data_module.val_dataloader()
        test_loader = data_module.test_dataloader()
        featuremap_length = data_module.featuremap_length
        sequence_length = data_module.sequence_length

        if config.model_name == "MobileVit-s":
            if config.ensemble_variant == "simple":
                feature_size = featuremap_length * sequence_length
                model = SimpleEnsembleModel(feature_size=feature_size, num_classes=num_classes, lr=config.learning_rate)
            elif config.ensemble_variant == "advanced":
                feature_size = featuremap_length
                model = AdvancedEnsembleModel(feature_size=feature_size, num_classes=num_classes, lr=config.learning_rate)
        elif config.model_name == "efficientnet-b2":
            if config.ensemble_variant == "simple":
                feature_size = featuremap_length * sequence_length
                model = SimpleEnsembleModel(feature_size=feature_size, num_classes=num_classes, lr=config.learning_rate)
            elif config.ensemble_variant == "advanced":
                feature_size = featuremap_length
                model = AdvancedEnsembleModel(feature_size=feature_size, num_classes=num_classes, lr=config.learning_rate)

        wandb_logger = WandbLogger()

        checkpoint_callback = ModelCheckpoint(
            dirpath=f"model_checkpoints/ensemble_{config.ensemble_variant}_{config.model_name}",
            filename=f"lr_{config.learning_rate}",
            monitor="val_loss",
            mode="min",
            save_top_k=1,
        )

        trainer = L.Trainer(
            max_epochs=config.epochs,
            devices="auto",
            accelerator="auto",
            logger=wandb_logger,
            callbacks=[checkpoint_callback],
            log_every_n_steps=config.batch_size//2,
        )

        trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

        # Load best model for testing
        best_model_path = checkpoint_callback.best_model_path
        if config.ensemble_variant == "simple":
            best_model_class = SimpleEnsembleModel
            feature_size = featuremap_length * sequence_length
        else:
            best_model_class = AdvancedEnsembleModel
            feature_size = featuremap_length

        best_model = best_model_class.load_from_checkpoint(
            best_model_path,
            feature_size=feature_size,
            num_classes=num_classes,
            lr=config.learning_rate  # if lr is needed
        )
        best_model = best_model.to(device)
        best_model.eval()

        all_preds = []
        all_labels = []
        with torch.no_grad():
            for batch in test_loader:
                inputs, labels = batch
                inputs = inputs.to(device).float()
                outputs = best_model(inputs)
                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

            f1 = f1_score(all_labels, all_preds, average="weighted")
        wandb.log({"test_f1_score": f1})




# Run the sweep
# wandb.agent(sweep_id, function=train)


Seed set to 42


Create sweep with ID: 7mdfful6
Sweep URL: https://wandb.ai/finnhenri-smidt/Alzheimer-Detection/sweeps/7mdfful6


In [3]:
wandb.agent(sweep_id, function=train)

[34m[1mwandb[0m: Agent Starting Run: uorftf8a with config:
[34m[1mwandb[0m: 	batch_size: 40
[34m[1mwandb[0m: 	ensemble_variant: simple
[34m[1mwandb[0m: 	epochs: 60
[34m[1mwandb[0m: 	learning_rate: 0.0001
[34m[1mwandb[0m: 	model_name: MobileVit-s
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/henrismidt/Documents/Informatik/Master/Alzheimer_Detection/model_checkpoints/ensemble_simple_MobileVit-s exists and is not empty.

  | Name | Type   | Params
--------------------------------
0 | fc   | Linear | 6.4 K 
--------------------------------
6.4 K     Trainable params
0         Non-trainable params
6.4 K     Total params
0.026     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (9) is smaller than the logging interval Trainer(log_every_n_steps=20). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=60` reached.


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
test_f1_score,▁
train_loss,█▇▆█▅▄▄▄▆▅▄▃▃▃▁▂▅▃▃▃▃▃▄▂▄▁▅
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
val_loss,█▅▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,59.0
test_f1_score,0.84113
train_loss,0.4173
trainer/global_step,539.0
val_loss,0.4427


[34m[1mwandb[0m: Agent Starting Run: myi55xb9 with config:
[34m[1mwandb[0m: 	batch_size: 40
[34m[1mwandb[0m: 	ensemble_variant: simple
[34m[1mwandb[0m: 	epochs: 60
[34m[1mwandb[0m: 	learning_rate: 0.0001
[34m[1mwandb[0m: 	model_name: efficientnet-b2
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/henrismidt/Documents/Informatik/Master/Alzheimer_Detection/model_checkpoints/ensemble_simple_efficientnet-b2 exists and is not empty.

  | Name | Type   | Params
--------------------------------
0 | fc   | Linear | 56.3 K
--------------------------------
56.3 K    Trainable params
0         Non-trainable params
56.3 K    Total params
0.225     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (9) is smaller than the logging interval Trainer(log_every_n_steps=20). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=60` reached.


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

In [1]:
import torch
from torchvision import transforms
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader
from transformers import MobileViTImageProcessor
import wandb

from dataset import MRIFeatureDataModule, MRIFeatureDataset
from models import SimpleEnsembleModel, AdvancedEnsembleModel
from utils import get_best_device, LoginCredentials

from datetime import datetime
import lightning.pytorch as pl
import torch
import numpy as np
import random
from sklearn.metrics import f1_score
best_model_class = SimpleEnsembleModel
best_model = best_model_class.load_from_checkpoint('model_checkpoints/ensemble_simple_MobileVit-s/lr_0.0001.ckpt',
    feature_size=1600,
    num_classes=4)

