# 24FS\_I4DS27: Adversarial Attacks <br> Wie kann KI überlistet werden? <br> 03-Training

In [1]:
%load_ext autoreload
%autoreload 2

import os
import torch
import wandb
import torchvision

from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

# Move to the project root so that the `src` package imports below and the
# relative `data/...` paths used for the datamodules resolve. Assumes the
# notebook lives one directory below the root. The guard keeps re-runs of this
# cell from climbing one more directory each time.
if not os.path.isdir("src"):
    os.chdir("..")
from src.data.mri import MRIDataModule
from src.data.covidx import COVIDXDataModule
from src.models.imageclassifier import ImageClassifier

In [2]:
# Training constants shared by every sweep run.
BATCH_SIZE = 8   # single value fed into the sweep's `batch_size` axis
OUTPUT_SIZE = 1  # forwarded to ImageClassifier as `output_size`
NUM_WORKERS = 0  # NOTE(review): not referenced by any visible cell

# Weights & Biases destination for all sweep runs.
WANDB_ENTITY = "24FS_I4DS27"
WANDB_PROJECT = "baselines"

# Architecture names swept over; each is passed to ImageClassifier as `modelname`.
models = [
    "alexnet",
    "vgg11",
    "resnet18",
    "densenet121",
    "efficientnet_v2_m",
    "vit_l_32",
]

# Dataset keys understood by train(). "covidx_data" is currently excluded;
# add it to this list to include COVIDx in the sweep.
datasets = ["mri_data"]

# Image preprocessing handed to both datamodules: resize every image to a
# fixed 448x448, with antialiasing enabled explicitly.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(448, 448), antialias=True),
])

In [None]:
def _build_datamodule(dataset, batch_size):
    """Create the datamodule for a sweep ``dataset`` key and call its ``setup()``.

    Args:
        dataset: Dataset key from the sweep config, ``"covidx_data"`` or ``"mri_data"``.
        batch_size: Batch size forwarded to the datamodule constructor.

    Returns:
        The value returned by the datamodule's ``setup()``. ``train()`` calls
        ``train_dataloader()`` and ``val_dataloader()`` on it, so ``setup()`` is
        presumably returning the datamodule itself (TODO confirm in src.data).

    Raises:
        ValueError: If ``dataset`` is not a supported key.
    """
    if dataset == "covidx_data":
        return COVIDXDataModule(
            path="data/raw/COVIDX-CXR4", transform=transform, batch_size=batch_size
        ).setup()
    if dataset == "mri_data":
        return MRIDataModule(
            path="data/raw/Brain-Tumor-MRI",
            path_processed="data/processed/Brain-Tumor-MRI",
            transform=transform,
            batch_size=batch_size,
        ).setup()
    # Fail early. An unknown key would otherwise leave the datamodule unassigned
    # and surface later as a confusing UnboundLocalError.
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected 'covidx_data' or 'mri_data'."
    )


def train():
    """Run a single W&B sweep trial.

    Steps:
        1. Open a W&B run.
        2. Read this trial's hyperparameters from ``wandb.config``
           (``model``, ``dataset``, ``p_dropout_classifier``, ``batch_size``,
           ``epochs``).
        3. Build the classifier and the datamodule.
        4. Fit the model with a Lightning ``Trainer`` that logs to W&B.

    Raises:
        ValueError: If the sweep's ``dataset`` value is not supported.
    """
    # The context manager finishes the W&B run when the trial ends, even if
    # training raises.
    with wandb.init(project=WANDB_PROJECT, entity=WANDB_ENTITY):
        # Hyperparameters chosen by the sweep controller for this trial.
        config = wandb.config

        # NOTE(review): the sweep also varies `lr`, but config.lr is not passed
        # here. Runs that differ only in `lr` may be duplicates unless
        # ImageClassifier obtains the learning rate some other way; confirm.
        model = ImageClassifier(
            modelname=config.model,
            output_size=OUTPUT_SIZE,
            p_dropout_classifier=config.p_dropout_classifier,
        )

        datamodule = _build_datamodule(config.dataset, config.batch_size)

        trainer = Trainer(
            max_epochs=config.epochs,
            accelerator="auto",
            # No arguments given; intended to log into the run opened by wandb.init above.
            logger=WandbLogger(),
            callbacks=[
                # EarlyStopping(monitor="val_loss") is currently disabled.
                # Keep the checkpoint with the lowest val_loss (Lightning's default mode is "min").
                ModelCheckpoint(monitor="val_loss"),
            ],
            # Log metrics on every training step instead of Lightning's default interval.
            log_every_n_steps=1,
        )

        trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())

In [None]:
# W&B sweep definition. Grid search tries every combination of the values
# below; runs are ranked by validation loss (lower is better).
sweep_config = dict(
    method="grid",
    metric=dict(name="val_loss", goal="minimize"),
    parameters=dict(
        # Architectures and dataset keys come from the lists defined earlier.
        model=dict(values=models),
        dataset=dict(values=datasets),
        # NOTE(review): the visible train() never reads config.lr, so this axis
        # may only duplicate runs. Confirm that ImageClassifier picks it up.
        lr=dict(values=[1e-3, 1e-4]),
        p_dropout_classifier=dict(values=[0.0, 0.2]),
        # Single-value axes: fixed across runs but still recorded in each run's config.
        batch_size=dict(values=[BATCH_SIZE]),
        epochs=dict(values=[10]),
    ),
)

In [None]:
# Register the sweep with W&B.
sweep_id = wandb.sweep(
    sweep_config,
    entity=WANDB_ENTITY,
    project=WANDB_PROJECT,
)
# Start an agent in this process; it calls `train` for each run the sweep hands out.
wandb.agent(sweep_id, function=train)