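In diesem Notebook wird die Pipeline zur adversarialen Robustifizierung aufgebaut und ausgeführt: Für ein vortrainiertes Modell (hier ResNet50 auf COVIDx-CXR4) werden universelle adversariale Perturbationen (UAPs) erzeugt, mit denen das Modell anschließend schrittweise robustifiziert wird.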

In [1]:
%load_ext autoreload
%autoreload 2

import os
if os.getcwd() == '/home/jovyan/work': # jhub
    os.chdir("24FS_I4DS27/main/") 
    os.system("make reqs")
else: # local
    os.chdir("../")

In [2]:
import gc

import torch
import torchvision
import matplotlib.pyplot as plt
from lightning.pytorch import Trainer
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

from src.data.mri import MRIDataModule
from src.data.covidx import COVIDXDataModule
from src.utils.download import download_models
from src.utils.transform_perturbation import AddImagePerturbation
from src.utils.uap_helper import generate_adversarial_images_from_model_dataset, get_model
from src.utils.adv_training import pipeline, get_transform
from src.models.imageclassifier import ImageClassifier

# Large, high-resolution figures for every plot in this notebook.
plt.rcParams.update({"figure.dpi": 200, "figure.figsize": (16, 8)})

# Trade some float32 matmul precision for speed (see torch docs for "high").
torch.set_float32_matmul_precision("high")

In [3]:
ENTITY = "24FS_I4DS27"
PROJECT = "baselines"
NUM_WORKERS = 8

# Sentinel that distinguishes "transform not passed" from an explicit value
# (including an explicit ``transform=None``).
_DEFAULT_TRANSFORM = object()


def get_datamodule(dataset, transform=_DEFAULT_TRANSFORM, num_workers=0, batch_size=1, seed=42):
    """Build a data module for ``dataset`` and return the result of its ``setup()``.

    Args:
        dataset: ``"covidx_data"`` or ``"mri_data"``.
        transform: Transform handed to the data module. When omitted, a fresh
            ``get_transform()`` is created for this call. A default of
            ``get_transform()`` would run once at definition time and share one
            object across all calls.
        num_workers: Number of data-loader workers.
        batch_size: Batch size for the data loaders.
        seed: Seed forwarded to the data module.

    Returns:
        Whatever the data module's ``setup()`` returns.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Validate first so an unknown name fails before any transform or
    # data module is built.
    if dataset not in ("covidx_data", "mri_data"):
        raise ValueError(f"Invalid dataset: {dataset!r}")

    if transform is _DEFAULT_TRANSFORM:
        transform = get_transform()

    if dataset == "covidx_data":
        return COVIDXDataModule(
            path="data/raw/COVIDX-CXR4",
            transform=transform,
            num_workers=num_workers,
            batch_size=batch_size,
            # presumably the fraction of the training set used (5 %) — confirm in COVIDXDataModule
            train_sample_size=0.05,
            train_shuffle=True,
            seed=seed,
        ).setup()

    return MRIDataModule(
        path="data/raw/Brain-Tumor-MRI",
        path_processed="data/processed/Brain-Tumor-MRI",
        transform=transform,
        num_workers=num_workers,
        batch_size=batch_size,
        train_shuffle=True,
        seed=seed,
    ).setup()

# Prefer CUDA, then Apple MPS, then CPU.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

In [4]:
# Download the trained models for ENTITY/PROJECT (presumably a W&B entity and
# project — confirm in src.utils.download). Per the original note, models that
# are already present locally are not downloaded again; that logic lives in
# download_models. `models` is not used further in the visible part of this notebook.
models = download_models(ENTITY, PROJECT)

In [None]:
modelname, dataset = "resnet50", "covidx_data"
print(f"\n---\nModel: {modelname} - Dataset: {dataset}")
print(f"Device: {device}")


# Learning rate of the UAP optimisation, and how often the run is retried
# when it fails with an out-of-memory error.
lr = 1e-1
max_retries = 10
last_error_message = None
for attempt in range(max_retries):
    try:
        pipeline(
            modelname=modelname,
            dataset=dataset,
            n_robustifications=10,
            i=5,
            n=1000,
            t=30,
            p=2,
            lambda_norm=0.001,
            r=0.5,
            eps=1e-6,
            lr_uap=lr,
            seed=42,
            num_workers=NUM_WORKERS,
            device=device,
            verbose=False,
        )
    except RuntimeError as e:
        # Only out-of-memory errors are worth retrying; any other RuntimeError
        # (shape mismatch, bad config, ...) would fail identically every time.
        # CUDA OOM is raised as torch.cuda.OutOfMemoryError (a RuntimeError
        # subclass); the message check also covers other backends such as MPS.
        if not (isinstance(e, torch.cuda.OutOfMemoryError) or "out of memory" in str(e)):
            raise
        print(f"Attempt {attempt + 1} failed: {e}")
        # Keep only the text: holding the exception object would keep its
        # traceback frames (and the tensors they reference) alive.
        last_error_message = str(e)
    else:
        # Success: stop instead of re-running the whole pipeline again.
        break
    # Free memory left over from the failed attempt before retrying (done
    # outside the except clause so the traceback is already released). The
    # recorded output shows the same OOM repeating when retrying without it.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
else:
    # Every attempt ran out of memory: stop here so later cells don't run on
    # missing results.
    raise RuntimeError(
        f"pipeline failed after {max_retries} attempts; last error: {last_error_message}"
    )


---
Model: resnet50 - Dataset: covidx_data
Device: cuda


Universal Pertubation:   0%|          | 0/5 [00:00<?, ?it/s]

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return F.conv2d(input, weight, bias, self.stride,


Attempt 1 failed: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 


/opt/conda/lib/python3.11/site-packages/lightning/fabric/loggers/csv_logs.py:268: Experiment logs directory robustified_models/resnet50-covidx_data-n_1000-robustification_0/01_UAPs_pre_robustification exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!


Universal Pertubation:   0%|          | 0/5 [00:00<?, ?it/s]

Attempt 2 failed: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 


/opt/conda/lib/python3.11/site-packages/lightning/fabric/loggers/csv_logs.py:268: Experiment logs directory robustified_models/resnet50-covidx_data-n_1000-robustification_0/01_UAPs_pre_robustification exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!


Universal Pertubation:   0%|          | 0/5 [00:00<?, ?it/s]