# 24FS\_I4DS27: Adversarial Attacks \\ Wie kann KI überlistet werden? <br> 03-Training

In [None]:
%load_ext autoreload
%autoreload 2

import os
import torch
import wandb
import torchvision

from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

# Move from notebooks/ to the repository root so that the ``src`` package is
# importable and the relative data/ and models/ paths used below resolve.
# Guarded so that re-running this cell (or starting the kernel at the repo
# root) does not climb one directory higher each time.
if not os.path.isdir("src"):
    os.chdir("..")
from src.data.mri import MRIDataModule
from src.data.covidx import COVIDXDataModule
from src.models.imageclassifier import ImageClassifier

os.environ["WANDB_NOTEBOOK_NAME"] = "notebooks/03-training.ipynb"

In [None]:
# Fixed settings shared by every sweep run.
BATCH_SIZE = 8   # forwarded to each sweep as the single "batch_size" value
OUTPUT_SIZE = 1  # passed to ImageClassifier as ``output_size``
NUM_WORKERS = 0  # NOTE(review): defined but not passed to the data modules in this notebook

# Weights & Biases destination for all sweeps and runs.
WANDB_ENTITY = "24FS_I4DS27"
WANDB_PROJECT = "baselines"

# Model names (passed to ImageClassifier as ``modelname``) and dataset keys
# (dispatched on in ``train``); one sweep is launched per (model, dataset) pair.
models = ["alexnet", "vgg11", "resnet18", "densenet121", "efficientnet_v2_m", "vit_l_32"]
datasets = ["mri_data", "covidx_data"]

# Image preprocessing handed to both data modules: an antialiased resize to 448x448.
transform = torchvision.transforms.Compose(
    [torchvision.transforms.Resize(size=(448, 448), antialias=True)]
)

In [None]:
def _build_datamodule(dataset_name, batch_size):
    """Create and set up the data module selected by a sweep's ``dataset`` value.

    Args:
        dataset_name: Either ``"covidx_data"`` or ``"mri_data"``.
        batch_size: Batch size forwarded to the data module.

    Returns:
        The result of the data module's ``setup()`` call, which the caller uses
        as the data module (assumes ``setup()`` returns ``self`` — TODO confirm).

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset key.
    """
    if dataset_name == "covidx_data":
        return COVIDXDataModule(
            path="data/raw/COVIDX-CXR4", transform=transform, batch_size=batch_size
        ).setup()
    if dataset_name == "mri_data":
        return MRIDataModule(
            path="data/raw/Brain-Tumor-MRI",
            path_processed="data/processed/Brain-Tumor-MRI",
            transform=transform,
            batch_size=batch_size,
        ).setup()
    # Fail loudly instead of leaving the data module unbound and crashing later.
    raise ValueError(
        f"Unknown dataset {dataset_name!r}; expected 'covidx_data' or 'mri_data'"
    )


def train():
    """Run a single W&B sweep trial: build model and data from ``wandb.config`` and fit.

    Intended to be passed to ``wandb.agent`` as the trial function. It reads
    ``model``, ``dataset``, ``lr``, ``p_dropout_classifier``, ``weight_decay``,
    ``batch_size`` and ``epochs`` from the sweep config. Checkpoints go to
    ``models/<modelname>/``. The run is finished when the ``with`` block exits.

    Raises:
        ValueError: If ``config.dataset`` is not a supported dataset key.
    """
    with wandb.init(project=WANDB_PROJECT, entity=WANDB_ENTITY):
        config = wandb.config

        # ``lr`` is a swept parameter, so it must reach the model; otherwise the
        # sweep value is silently ignored and the model's default lr is used.
        model = ImageClassifier(
            modelname=config.model,
            output_size=OUTPUT_SIZE,
            lr=config.lr,
            p_dropout_classifier=config.p_dropout_classifier,
            weight_decay=config.weight_decay,
        )

        datamodule = _build_datamodule(config.dataset, config.batch_size)

        # Include the dataset so that runs on different datasets with identical
        # hyperparameters do not share (and version-suffix) the same file names.
        run_name = (
            f"{model.modelname}-{config.dataset}-lr{model.lr}"
            f"-pdrop{model.p_dropout_classifier}-wd{model.weight_decay}"
        )
        checkpoint = ModelCheckpoint(
            monitor="val_loss",
            mode="min",
            save_top_k=1,  # keep only the checkpoint with the lowest val_loss
            save_last=True,  # additionally keep the most recent checkpoint
            dirpath=f"models/{model.modelname}",
            filename=run_name,
        )
        # Give the "last" checkpoint a per-run name instead of the shared last.ckpt.
        checkpoint.CHECKPOINT_NAME_LAST = f"{run_name}-last"

        trainer = Trainer(
            max_epochs=config.epochs,
            log_every_n_steps=1,
            gradient_clip_val=0.5,
            accelerator="auto",
            # "all": upload every checkpoint ModelCheckpoint saves during training
            # (True would upload only at the end of training).
            logger=WandbLogger(log_model="all"),
            fast_dev_run=False,  # set to True for a quick smoke-test run
            enable_progress_bar=False,
            enable_model_summary=False,
            # NOTE: EarlyStopping(monitor="val_loss", mode="min", patience=1) is
            # currently disabled; add it to this list to stop on a val_loss plateau.
            callbacks=[checkpoint],
        )

        trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())

In [None]:
# Launch one W&B grid sweep per (model, dataset) pair, in model-major order.
# Every sweep pins the architecture, dataset, learning rate, batch size and
# epoch count, and searches the 2 x 2 grid of classifier dropout x weight decay.
for model_name, dataset_name in [(m, d) for m in models for d in datasets]:
    # Candidate values per hyperparameter; expanded into W&B's
    # {"name": {"values": [...]}} schema below (insertion order preserved).
    grid = {
        "model": [model_name],
        "dataset": [dataset_name],
        "lr": [1e-5],
        "p_dropout_classifier": [0.0, 0.2],
        "weight_decay": [0.0, 0.0001],
        "batch_size": [BATCH_SIZE],
        "epochs": [2],
    }
    sweep_config = {
        "method": "grid",
        "metric": {"name": "val_loss", "goal": "minimize"},
        "parameters": {name: {"values": candidates} for name, candidates in grid.items()},
    }

    sweep_id = wandb.sweep(sweep_config, project=WANDB_PROJECT, entity=WANDB_ENTITY)
    # Blocks until the agent has run every grid combination through ``train``.
    wandb.agent(sweep_id, function=train)