In [4]:
from utils import LoginCredentials
import wandb

# Build the project's credential holder (imported from utils above) so the
# W&B API key is read from it instead of being written into the notebook.
authenticator = LoginCredentials()

# Log in to Weights & Biases with the key exposed as `wandb_key`.
wandb.login(key=authenticator.wandb_key)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/henrismidt/.netrc


True

In [5]:
import torch
from torchvision import transforms
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader
from transformers import MobileViTImageProcessor
import wandb

from dataset import MRIImageDataModule, MRIDataset
from models import MobileViTLightning, EfficientNetBaseline
from utils import get_best_device, LoginCredentials, set_reproducibility
from sampler import WeightedRandomSampler

from datetime import datetime
import lightning.pytorch as pl
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score, recall_score, classification_report, confusion_matrix
import os
import pickle

# Seed the run via the project helper (it reports "Seed set to 42").
set_reproducibility(42)

# Close any W&B run that an earlier execution of this notebook left open.
wandb.finish()

# Metadata CSV handed to MRIImageDataModule inside train().
csv_path = "Data/metadata_for_preprocessed_files.csv"

# Sweep definition. With method "grid", one trial is created per combination
# of the value lists below; "val_loss" (minimize) is declared as the metric.
sweep_config = dict(
    method="grid",
    metric=dict(name="val_loss", goal="minimize"),
    parameters=dict(
        # Other model names the author listed as options:
        # "efficientnet-b0", "efficientnet-b5".
        model_name=dict(values=["MobileVit", "efficientnet-b2"]),
        # Only slice "65" is active. Other candidate slices noted by the
        # author, in their original order:
        # 86, 56, 95, 62, 35, 59, 74, 80, 134, 41, 104, 101, 116, 68, 89,
        # 107, 92, 71, 77, 113, 23, 98, 110, 131, 128, 125, 122, 119, 20,
        # 83, 53, 50, 47, 44, 38, 32, 29, 26, 137.
        slice_number=dict(values=["65"]),
        learning_rate=dict(values=[1e-05]),
        batch_size=dict(values=[40]),
        epochs=dict(values=[60]),
        # Alternatives noted by the author: "inverse", "sqrt".
        sampling_strategy=dict(values=["log"]),
        # Author's note: 10 for MobileVit, 0 for EfficientNet. As written, this
        # single value is combined with every model_name in the grid.
        smoothing=dict(values=[10]),
        self_distillation_alpha=dict(values=[0.3, 0.5, 0.7]),
        self_distillation_temperature=dict(values=[1, 3, 7]),
    ),
)

# Register the sweep; the returned id is later passed to wandb.agent.
sweep_id = wandb.sweep(sweep=sweep_config, project="Alzheimer-Detection")

def load_soft_labels(slice_number, model_name):
    """Return the pickled soft labels for a slice/model pair, or None if absent.

    The file is looked up at
    ``soft_labels/<model_name>/soft_labels_slice_<slice_number>.pkl``,
    relative to the current working directory.

    Args:
        slice_number: Slice identifier interpolated into the file name.
        model_name: Model name used as the sub-directory.

    Returns:
        The unpickled object, or ``None`` when the file does not exist.
    """
    pickle_file = f"soft_labels/{model_name}/soft_labels_slice_{slice_number}.pkl"
    if not os.path.exists(pickle_file):
        return None

    # NOTE: unpickling can execute arbitrary code; only point this at
    # trusted, locally generated soft-label files.
    with open(pickle_file, 'rb') as handle:
        return pickle.load(handle)


# Define the training function

def train(config=None):
    """Run one sweep trial: train, reload the best checkpoint, log test metrics.

    Intended to be passed to ``wandb.agent`` as the trial function. The
    hyperparameters are read from ``wandb.config`` after ``wandb.init``.

    Steps:
        1. Build the model and its input transform from ``config.model_name``.
        2. Load optional soft labels and create the data loaders.
        3. Fit with a ``ModelCheckpoint`` that keeps the lowest ``val_loss``.
        4. Reload that checkpoint, predict on the test loader and log
           F1 / precision / recall / confusion matrix / classification report.

    Args:
        config: Optional configuration forwarded to ``wandb.init``.

    Raises:
        ValueError: If ``config.model_name`` is neither ``"MobileVit"`` nor a
            name starting with ``"efficientnet"``.
        RuntimeError: If fitting finished without a best checkpoint being
            recorded by the checkpoint callback.
    """
    mobilevit_ckpt = "apple/mobilevit-small"

    with wandb.init(config=config):
        config = wandb.config
        device = get_best_device()

        is_mobilevit = config.model_name == "MobileVit"
        is_efficientnet = (not is_mobilevit) and config.model_name.startswith(
            "efficientnet"
        )
        # Fail early with a clear message instead of hitting an unbound
        # `transform`/`model` further down when the name matches no branch.
        if not (is_mobilevit or is_efficientnet):
            raise ValueError(
                f"Unsupported model_name {config.model_name!r}; expected "
                "'MobileVit' or a name starting with 'efficientnet'."
            )

        if is_mobilevit:
            # Load the preprocessor
            processor = MobileViTImageProcessor.from_pretrained(mobilevit_ckpt)

            def transform(image):
                # Use MobileViTImageProcessor for preprocessing; drop the
                # leading batch dimension the processor adds.
                return processor(image, return_tensors="pt")["pixel_values"].squeeze(0)

            # NOTE(review): config.learning_rate is not passed here, unlike the
            # EfficientNet branch, so the swept learning rate presumably has no
            # effect for MobileVit - confirm against MobileViTLightning.
            model = MobileViTLightning(
                model_ckpt=mobilevit_ckpt,
                num_labels=4,
                self_distillation_alpha=config.self_distillation_alpha,
                self_distillation_temperature=config.self_distillation_temperature,
            )
        else:
            transform = transforms.Compose(
                [
                    transforms.ToPILImage(),
                    transforms.Resize(224),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                    ),
                ]
            )
            model = EfficientNetBaseline(
                model_name=config.model_name,
                lr=config.learning_rate,
                self_distillation_alpha=config.self_distillation_alpha,
                self_distillation_temperature=config.self_distillation_temperature,
            )

        # Load soft labels (None when no file exists for this slice/model).
        slice_number = int(config.slice_number)
        soft_labels = load_soft_labels(slice_number, config.model_name)

        data_module = MRIImageDataModule(
            csv_path,
            slice_number=slice_number,
            transform=transform,
            batch_size=config.batch_size,
            num_workers=0,
            soft_labels=soft_labels,
        )
        data_module.setup()
        train_loader = data_module.train_dataloader(
            sampling_strategy=config.sampling_strategy, smoothing=config.smoothing
        )
        val_loader = data_module.val_dataloader()
        test_loader = data_module.test_dataloader()

        wandb_logger = WandbLogger()

        # Keep only the checkpoint with the lowest validation loss.
        checkpoint_callback = ModelCheckpoint(
            dirpath=f"model_checkpoints/student_models/{config.model_name}",
            filename=f"slice_number_{config.slice_number}_sd_alpha_{config.self_distillation_alpha}_sd_temp_{config.self_distillation_temperature}",
            monitor="val_loss",
            mode="min",
            save_top_k=1,
        )

        trainer = L.Trainer(
            max_epochs=config.epochs,
            devices="auto",
            accelerator="auto",
            logger=wandb_logger,
            callbacks=[checkpoint_callback],
            log_every_n_steps=24,
        )

        trainer.fit(
            model=model, train_dataloaders=train_loader, val_dataloaders=val_loader
        )

        # Load best model for testing
        best_model_path = checkpoint_callback.best_model_path
        if not best_model_path:
            raise RuntimeError(
                "No best checkpoint was recorded during training; "
                "cannot evaluate on the test set."
            )

        if is_mobilevit:
            best_model = MobileViTLightning.load_from_checkpoint(
                best_model_path, model_ckpt=mobilevit_ckpt, num_labels=4
            )
        else:
            best_model = EfficientNetBaseline.load_from_checkpoint(
                best_model_path,
                model_name=config.model_name,
                num_classes=4,
            )

        # Evaluate on test set
        best_model = best_model.to(device)
        best_model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            # Each test batch unpacks into four items; only the first two are
            # needed for the metrics below.
            for inputs, labels, _age, _sample_ids in test_loader:
                inputs = inputs.to(device).float()
                outputs = best_model(inputs)
                # Predicted class = index of the largest score along dim 1.
                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        # Calculate metrics
        f1_weighted = f1_score(all_labels, all_preds, average="weighted")
        f1_individual = f1_score(all_labels, all_preds, average=None)
        precision = precision_score(all_labels, all_preds, average="weighted")
        recall = recall_score(all_labels, all_preds, average="weighted")
        conf_matrix = confusion_matrix(all_labels, all_preds)
        classification_rep = classification_report(
            all_labels, all_preds, output_dict=True
        )

        # Log metrics to wandb
        wandb.log({
            "test_f1_weighted": f1_weighted,
            "test_f1_individual": f1_individual,
            "test_precision": precision,
            "test_recall": recall,
            "confusion_matrix": conf_matrix,
            "classification_report": classification_rep
        })



# The sweep is started from the next cell with wandb.agent(sweep_id, function=train);
# this cell only registers the sweep and defines train().

Seed set to 42


Create sweep with ID: 0n5egokx
Sweep URL: https://wandb.ai/finnhenri-smidt/Alzheimer-Detection/sweeps/0n5egokx


In [6]:
# Start a sweep agent in this kernel: it fetches configurations for `sweep_id`
# and calls train() for each one. No run count is given, so it keeps going
# until the sweep has nothing left to run or the cell is interrupted.
wandb.agent(sweep_id, function=train)

[34m[1mwandb[0m: Agent Starting Run: vr4tjz1p with config:
[34m[1mwandb[0m: 	batch_size: 40
[34m[1mwandb[0m: 	epochs: 60
[34m[1mwandb[0m: 	learning_rate: 1e-05
[34m[1mwandb[0m: 	model_name: MobileVit
[34m[1mwandb[0m: 	sampling_strategy: log
[34m[1mwandb[0m: 	self_distillation_alpha: 0.3
[34m[1mwandb[0m: 	self_distillation_temperature: 1
[34m[1mwandb[0m: 	slice_number: 65
[34m[1mwandb[0m: 	smoothing: 10


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011169073611111142, max=1.0…

Some weights of MobileViTForImageClassification were not initialized from the model checkpoint at apple/mobilevit-small and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 640]) in the checkpoint and torch.Size([4, 640]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([4]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.

  | Name         

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (9) is smaller than the logging interval Trainer(log_every_n_steps=24). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Traceback (most recent call last):
  File "/var/folders/l5/4bf7qdjd6k9b38g7jlnrcf580000gn/T/ipykernel_51213/3298305049.py", line 142, in train
    trainer.fit(
  File "/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
  File "/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 1

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


[34m[1mwandb[0m: [32m[41mERROR[0m Problem finishing run
Traceback (most recent call last):
  File "/var/folders/l5/4bf7qdjd6k9b38g7jlnrcf580000gn/T/ipykernel_51213/3298305049.py", line 142, in train
    trainer.fit(
  File "/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
  File "/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3