# 24FS\_I4DS27: Adversarial Attacks \\ Wie kann KI überlistet werden? <br> 02-Models

In [None]:
%load_ext autoreload
%autoreload 2

import os
os.chdir("..")

In [1]:
import torch
import torchvision
import wandb
import pytorch_lightning as pl 

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from src.data.covidx import COVIDXDataModule
from src.data.mri import MRIDataModule
from src.models.imageclassifier import ImageClassifier

In [None]:
BATCH_SIZE = 8
OUTPUT_SIZE = 1
NUM_WORKERS = 0

WANDB_ENTITY = "7ben18" # "gabrieltorresgamez" 
WANDB_PROJECT = "24FS_I4DS27"

models = ["alexnet", "vgg11", "resnet18", "densenet121", "efficientnet_v2_m", "vit_l_32"]
datasets = ["covidx_data", "mri_data"]

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((448, 448), antialias=True),
    ]
)


In [None]:
def train():
    # Initialize a new wandb run
    with wandb.init(project=WANDB_PROJECT, entity=WANDB_ENTITY) as run:
        # Config is a variable that holds and saves hyperparameters and inputs
        config = wandb.config

        # Initialize a new model
        model = ImageClassifier(
            modelname=config.model, output_size=OUTPUT_SIZE, p_dropout_classifier=0.2
        )

        # Initialize a new datamodule
        if config.dataset == "covidx_data":
            datamodule = COVIDXDataModule(
                path="data/raw/COVIDX-CXR4", transform=transform, batch_size=config.batch_size
            ).setup()
        elif config.dataset == "mri_data":
            datamodule = MRIDataModule(
                path="data/processed/Brain-Tumor-MRI", transform=transform, batch_size=config.batch_size
            ).setup()

        # Initialize a new trainer
        trainer = Trainer(
            max_epochs=config.epochs,
            gpus=1 if torch.cuda.is_available() else None,
            logger=WandbLogger(),
            callbacks=[EarlyStopping(monitor="val_loss"), ModelCheckpoint(monitor="val_loss")],
        )

        # Train the model
        trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())

        # Evaluate the model
        trainer.test(model, datamodule.test_dataloader())

        # Log the final test accuracy
        run.log({"test_loss": trainer.callback_metrics["test_loss"]})

In [None]:
sweep_config = {
    "method": "grid",
    "metric": {"name": "val_loss", "goal": "minimize"},
    "parameters": {
        "model": {"values": models}, # replace with alexnet, vgg11, resnet18, densenet121, efficientnet_v2_m, vit_l_32
        "dataset": {"values": datasets}, # replace with one dataset 
        "lr": {"values": [1e-3, 1e-4]},
        "batch_size": {"values": [BATCH_SIZE]},
        "epochs": {"values": [10]},
    },
}

In [None]:
sweep_id = wandb.sweep(sweep_config, project=WANDB_PROJECT, entity=WANDB_ENTITY)
wandb.agent(sweep_id, function=train) 

---

In [None]:
# Get AlexNet
alexnet = ImageClassifier(modelname="alexnet", output_size=OUTPUT_SIZE, p_dropout_classifier=0.2)

# Get MRI Data
mri_datamodule = MRIDataModule(path="data/processed/Brain-Tumor-MRI", transform=transform, batch_size=BATCH_SIZE).setup()

# Get all loaders
mri_train_loader = mri_datamodule.train_dataloader()
mri_val_loader = mri_datamodule.val_dataloader()
mri_test_loader = mri_datamodule.test_dataloader()

image, label = next(iter(mri_datamodule.train_dataloader()))
print(f"Input shape: {image.shape}")
print(f"Label shape: {label.shape}")

----

In [2]:
model = ImageClassifier(modelname="vit_l_32", output_size=1, p_dropout_classifier=0.2)

In [3]:
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((448, 448), antialias=True),
    ]
)

covidx_datamodule = COVIDXDataModule(path="data/raw/COVIDX-CXR4", transform=transform, batch_size=8).setup()

for batch in covidx_datamodule.train_dataloader():
    image, label = batch
    print(model(image))
    break

tensor([[ 0.3510],
        [ 0.0796],
        [ 0.1362],
        [-0.1810],
        [ 0.0261],
        [ 0.4587],
        [ 0.4103],
        [ 0.7684]], grad_fn=<AddmmBackward0>)


---