# 24FS\_I4DS27: Adversarial Attacks \\ Wie kann KI überlistet werden? <br> 03-Training

In [None]:
%load_ext autoreload
%autoreload 2

import os
import torch
import wandb
import warnings
import torchvision

from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

# Step up from notebooks/ to the repository root so that `src` is importable and
# the relative data/ and models/ paths used later resolve. Guarded so that
# re-running this cell (or starting the kernel at the root) does not keep
# climbing the directory tree.
if not os.path.isdir("src"):
    os.chdir("..")
from src.data.mri import MRIDataModule
from src.data.covidx import COVIDXDataModule
from src.models.imageclassifier import ImageClassifier

# Session-wide settings for this notebook.
# Silence every UserWarning for the rest of the session (this also hides
# library hints such as DataLoader worker-count warnings).
warnings.filterwarnings(action="ignore", category=UserWarning)
# Allow PyTorch to trade float32 matmul precision for speed on supporting hardware.
torch.set_float32_matmul_precision("medium")
# Name of this notebook as reported to W&B.
os.environ.update(WANDB_NOTEBOOK_NAME="notebooks/03-training.ipynb")

In [None]:
BATCH_SIZE = 32
OUTPUT_SIZE = 1
# DataLoader worker processes per loader. Capped at the number of CPUs: more
# workers than cores only adds overhead, and PyTorch's warning about it is a
# UserWarning, which this notebook filters out.
NUM_WORKERS = min(24, os.cpu_count() or 1)

WANDB_ENTITY = "24FS_I4DS27"
WANDB_PROJECT = "baselines"

# Candidate backbones grouped by size. Note: "resnet18" appears in both
# medium_models and small_models, so sweeping both lists trains it twice.
medium_models = [
    "resnet18",
    "resnet50",
    "resnet152",
    "densenet121",
    "densenet169",
    "densenet201",
    "efficientnet_v2_s",
    "efficientnet_v2_m",
    "efficientnet_v2_l",
]

small_models = [
    "alexnet",
    "vgg11",
    "resnet18",
]

large_models = [
    "vit_b_16",
]

# Dataset keys understood by train(); each selects a data module there.
datasets = [
    "covidx_data",
    "mri_data",
]

# Shared preprocessing for both data modules: resize every image to a fixed
# 224x224 (antialiased) before it reaches the model.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(224, 224), antialias=True),
])

In [None]:
def _build_datamodule(dataset, batch_size):
    """Create and set up the data module selected by ``dataset``.

    Args:
        dataset: ``"covidx_data"`` or ``"mri_data"``.
        batch_size: Batch size forwarded to the data module.

    Returns:
        The value returned by the data module's ``setup()``. Callers use it
        directly as the data module (its dataloaders are requested from it).

    Raises:
        ValueError: If ``dataset`` is not one of the supported keys.
    """
    if dataset == "covidx_data":
        return COVIDXDataModule(
            path="data/raw/COVIDX-CXR4",
            transform=transform,
            num_workers=NUM_WORKERS,
            batch_size=batch_size,
            train_sample_size=0.05,  # presumably a fraction of the train split — confirm in COVIDXDataModule
            train_shuffle=True,
        ).setup()
    if dataset == "mri_data":
        return MRIDataModule(
            path="data/raw/Brain-Tumor-MRI",
            path_processed="data/processed/Brain-Tumor-MRI",
            transform=transform,
            num_workers=NUM_WORKERS,
            batch_size=batch_size,
            train_shuffle=True,
        ).setup()
    # Previously an unknown key left `datamodule` unbound and failed much later.
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected 'covidx_data' or 'mri_data'"
    )


def train():
    """Run a single W&B sweep trial: build model and data from ``wandb.config`` and fit.

    Intended as the ``function`` passed to ``wandb.agent``. The sweep config must
    provide ``model``, ``dataset``, ``lr``, ``p_dropout_classifier``,
    ``weight_decay``, ``batch_size`` and ``epochs``.
    """
    # The context manager finishes the W&B run even if training raises (and marks
    # it as failed); a trailing wandb.finish() would be skipped on error.
    with wandb.init():
        config = wandb.config
        wandb_logger = WandbLogger(log_model=True)

        model = ImageClassifier(
            modelname=config.model,
            output_size=OUTPUT_SIZE,
            # The sweep defines `lr`; without forwarding it the model silently
            # used its own default learning rate.
            lr=config.lr,
            p_dropout_classifier=config.p_dropout_classifier,
            weight_decay=config.weight_decay,
        )
        wandb_logger.watch(model, log_graph=False)

        datamodule = _build_datamodule(config.dataset, config.batch_size)

        # Both datasets are swept over the same models and hyperparameters into
        # the same directory, so the dataset must be part of the checkpoint name
        # or the two sets of checkpoints become indistinguishable.
        run_name = (
            f"{model.modelname}-{config.dataset}-lr{model.lr}"
            f"-pdrop{model.p_dropout_classifier}-wd{model.weight_decay}"
        )
        checkpoint = ModelCheckpoint(
            monitor="val_loss",
            mode="min",
            save_top_k=1,  # keep only the best model
            save_last=True,  # additionally keep the final epoch
            dirpath=f"models/{model.modelname}",
            filename=run_name,
        )
        # The last checkpoint is named "last" by default for every run; tie it to
        # this run's name instead.
        checkpoint.CHECKPOINT_NAME_LAST = f"{run_name}-last"

        trainer = Trainer(
            max_epochs=config.epochs,
            log_every_n_steps=1,
            gradient_clip_val=0.5,
            accelerator="auto",
            logger=wandb_logger,
            fast_dev_run=False,  # set to True for a quick smoke-test run
            enable_progress_bar=True,
            enable_model_summary=True,
            callbacks=[
                # EarlyStopping(monitor="val_loss", mode="min", patience=1),
                checkpoint,
            ],
        )

        trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())
        wandb_logger.experiment.unwatch(model)

In [4]:
def _make_sweep_config(model_name, dataset_name):
    """Return a W&B grid-sweep config for one (model, dataset) pair.

    The grid spans 3 dropout rates x 3 weight-decay values (9 runs) at a fixed
    learning rate, batch size and epoch budget, minimising ``val_loss``.
    """
    return {
        "method": "grid",
        "metric": {"name": "val_loss", "goal": "minimize"},
        "parameters": {
            "model": {"values": [model_name]},
            "dataset": {"values": [dataset_name]},
            "lr": {"values": [1e-5]},
            "p_dropout_classifier": {"values": [0.0, 0.2, 0.5]},
            "weight_decay": {"values": [0.0, 1e-4, 1e-5]},
            "batch_size": {"values": [BATCH_SIZE]},
            "epochs": {"values": [50]},
        },
    }


# Which model group to sweep; swap for small_models or large_models as needed.
models_to_train = medium_models

# One separate sweep per (dataset, model) pair.
for dataset_name in datasets:
    for model_name in models_to_train:
        sweep_id = wandb.sweep(
            _make_sweep_config(model_name, dataset_name),
            project=WANDB_PROJECT,
            entity=WANDB_ENTITY,
        )
        # Pass entity/project explicitly instead of relying on wandb.sweep having
        # set them as process-wide defaults.
        wandb.agent(sweep_id, function=train, entity=WANDB_ENTITY, project=WANDB_PROJECT)