In [None]:
from utils import LoginCredentials
import wandb

# Instantiate the project's credential helper (imported from `utils`); its
# `wandb_key` attribute is read below.
authenticator = LoginCredentials()

# Authenticate this kernel with Weights & Biases using that key.
# NOTE(review): the following cell performs the same login again, so this cell
# looks redundant — confirm before removing it.
wandb.login(key=authenticator.wandb_key)

In [1]:
import torch
from torchvision import transforms
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader
from transformers import MobileViTImageProcessor
import wandb

from dataset import MRIImageDataModule, MRIDataset
from models import MobileViTLightning, EfficientNetBaseline
from utils import get_best_device, LoginCredentials

from datetime import datetime
import lightning.pytorch as pl
import torch
import numpy as np
import random
from sklearn.metrics import f1_score

wandb.finish() # Close any W&B run left open by an earlier execution in this kernel

# Log in to Weights & Biases with the key exposed by the project's
# `LoginCredentials` helper (attribute `wandb_key`).
authenticator = LoginCredentials()
wandb.login(key=authenticator.wandb_key)

def set_reproducibility(seed=42):
    """Seed the random number generators used by this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's default
    generator, seeds every CUDA device when CUDA is available, sets the cuDNN
    ``deterministic``/``benchmark`` flags, and finally delegates to Lightning's
    ``seed_everything`` (with ``workers=True``).

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.

    Returns:
        None.
    """
    # Python, NumPy and PyTorch (CPU) generators all receive the same seed.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators only exist when a CUDA device is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # cuDNN: request deterministic behaviour and turn benchmark mode off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

    # Lightning's own seeding helper, including dataloader workers.
    pl.seed_everything(seed, workers=True)

# Seed every RNG before anything stochastic runs.
set_reproducibility(42)

# Metadata CSV handed to MRIImageDataModule inside train().
csv_path = 'Data/metadata_for_preprocessed_files.csv'

# Grid sweep over model name and MRI slice number. Slice numbers are given as
# strings here; train() converts them with int().
# Only the first five entries of the full candidate list are swept right now:
#   ['65', '86', '56', '95', '62', '35', '59', '74', '80', '134',
#    '41', '104', '101', '116', '68', '89', '107', '92', '71', '77',
#    '113', '23', '98', '110', '131', '128', '125', '122', '119', '20',
#    '83', '53', '50', '47', '44', '38', '32', '29', '26', '137']
sweep_config = dict(
    method='grid',
    parameters=dict(
        model_name=dict(values=['EfficientNetb0']),
        slice_number=dict(values=['65', '86', '56', '95', '62']),
    ),
)

# Register the sweep with W&B; the returned id is what the agent is started with.
sweep_id = wandb.sweep(sweep=sweep_config, project="Alzheimer-Detection")

# Define the training function
def train(config=None):
    """Run one sweep trial: build model and data for the sampled config, then fit.

    Meant to be handed to ``wandb.agent``. A W&B run is opened with
    ``wandb.init(config=config)`` and the effective hyperparameters are read
    back from ``wandb.config``. The trial uses two of them:

    * ``model_name`` -- ``"MobileVit"`` or ``"EfficientNetb0"``.
    * ``slice_number`` -- MRI slice index; converted with ``int()`` before it
      is passed to ``MRIImageDataModule``.

    The best checkpoint (lowest ``val_loss``) is written to
    ``model_checkpoints/<model_name>/``.

    Args:
        config: Optional configuration forwarded to ``wandb.init``. Defaults to
            ``None``.

    Raises:
        ValueError: If ``config.model_name`` is not a supported model name.
    """
    with wandb.init(config=config):
        config = wandb.config

        # All trials of the sweep execute in the same kernel, so the global RNG
        # state would otherwise carry over from one trial to the next. Re-seed
        # with the notebook's seed so each trial starts from the same state.
        set_reproducibility(42)

        if config.model_name == "MobileVit":
            # Load the preprocessor
            model_ckpt = "apple/mobilevit-x-small"
            processor = MobileViTImageProcessor.from_pretrained(model_ckpt)

            def transform(image):
                # Use MobileViTImageProcessor for preprocessing; squeeze(0)
                # drops the leading dimension of the returned "pixel_values".
                return processor(image, return_tensors="pt")["pixel_values"].squeeze(0)

            model = MobileViTLightning(model_ckpt=model_ckpt, num_labels=4)

        elif config.model_name == 'EfficientNetb0':
            # No transform is passed to the data module for this model.
            transform = None
            model = EfficientNetBaseline()

        else:
            # Previously an unknown name fell through and failed later with an
            # UnboundLocalError on `transform`; fail early with a clear message.
            raise ValueError(
                f"Unsupported model_name {config.model_name!r}; "
                "expected 'MobileVit' or 'EfficientNetb0'."
            )

        data_module = MRIImageDataModule(
            csv_path,
            slice_number=int(config.slice_number),
            transform=transform,
            batch_size=16,
            num_workers=0,
        )
        data_module.setup()
        train_loader = data_module.train_dataloader()
        val_loader = data_module.val_dataloader()
        # Only needed by the (currently disabled) test-set evaluation below.
        test_loader = data_module.test_dataloader()

        # A W&B run is already active here (opened by wandb.init above).
        wandb_logger = WandbLogger()

        # Keep only the single checkpoint with the lowest validation loss.
        # NOTE(review): "slice_numer_" looks like a typo for "slice_number_";
        # left unchanged because existing checkpoints/consumers may rely on it.
        checkpoint_callback = ModelCheckpoint(
            dirpath=f'model_checkpoints/{config.model_name}',
            filename=f'slice_numer_{config.slice_number}',
            monitor='val_loss',
            mode='min',
            save_top_k=1
        )

        trainer = L.Trainer(
            max_epochs=30,
            devices='auto',
            accelerator='auto',
            logger=wandb_logger,
            callbacks=[checkpoint_callback],
            log_every_n_steps=24
        )

        trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

        # Test-set evaluation is currently disabled.
        # NOTE(review): the snippet below reloads the checkpoint as
        # MobileViTLightning and uses `model_ckpt`, which is only defined in
        # the "MobileVit" branch -- it must be adapted before being enabled
        # for EfficientNetb0.
        #
        # Load best model for testing
        # best_model_path = checkpoint_callback.best_model_path
        # best_model = MobileViTLightning.load_from_checkpoint(best_model_path, model_ckpt=model_ckpt, num_labels=4)

        # # Evaluate on test set
        # best_model.eval()
        # all_preds = []
        # all_labels = []
        # for batch in test_loader:
        #     inputs, labels = batch
        #     outputs = best_model(inputs)
        #     preds = torch.argmax(outputs, dim=1)
        #     all_preds.extend(preds.cpu().numpy())
        #     all_labels.extend(labels.cpu().numpy())

        # f1 = f1_score(all_labels, all_preds, average='weighted')
        # wandb.log({'test_f1_score': f1})

# Run the sweep: start a W&B agent in this kernel that calls train() for the
# runs of the sweep registered above (`sweep_id`).
# NOTE(review): no `count` is passed, so the number of runs is left to the
# sweep itself — confirm this is intended.
wandb.agent(sweep_id, function=train)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mfinnhenri-smidt[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/henrismidt/.netrc
Seed set to 42


Create sweep with ID: dba53nlg
Sweep URL: https://wandb.ai/finnhenri-smidt/Alzheimer-Detection/sweeps/dba53nlg


[34m[1mwandb[0m: Agent Starting Run: goa2q277 with config:
[34m[1mwandb[0m: 	model_name: EfficientNetb0
[34m[1mwandb[0m: 	slice_number: 65
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011168671288888888, max=1.0…

Loaded pretrained weights for efficientnet-b0


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | EfficientNet     | 4.0 M 
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
4.0 M     Trainable params
0         Non-trainable params
4.0 M     Total params
16.051    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/henrismidt/anaconda3/envs/alzheimer/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (21) is smaller than the logging interval Trainer(log_every_n_steps=24). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]