In [None]:
import time
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader

from helper.datasets import get_dataset_loader
from helper.general import empty_gpu_cache
from helper.models import evaluate_model, get_model
from helper.config import TrainConfig
from helper.global_variables import DEVICE, TRAIN_YAML_PATH

# Indices of the test-split samples that each model is evaluated on, both clean
# and after the FGSM attack: the first 1000 examples of the test dataset.
samples_range = range(0, 1_000)

def fgsm_attack(model, loss_fn, dataset_loader, epsilon, *, clip_min=0.0, clip_max=1.0):
    """Generate FGSM adversarial examples for every batch yielded by ``dataset_loader``.

    Each input ``x`` is moved one step along the sign of the loss gradient:
    ``x_adv = clamp(x + epsilon * sign(d loss(model(x), y) / dx), clip_min, clip_max)``.

    Args:
        model: Classifier to attack. It is switched to eval mode as a side effect.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar loss.
        dataset_loader: Iterable yielding ``(images, labels)`` batches.
        epsilon: Size of the perturbation step.
        clip_min: Lower bound the adversarial images are clamped to.
        clip_max: Upper bound the adversarial images are clamped to. The defaults
            (0.0 / 1.0) assume un-normalized inputs in [0, 1]; TODO confirm against
            the dataset transforms.

    Returns:
        Tuple ``(adversarial_images, labels)``. Both are concatenated along dim 0
        and live on ``DEVICE``. The images are detached from the autograd graph.

    Raises:
        RuntimeError: raised by ``torch.cat`` if ``dataset_loader`` yields no batches.
    """
    adversarial_batches = []
    label_batches = []
    model.eval()

    for images, labels in tqdm(dataset_loader, "FGSM Attack"):
        # Make the input a fresh autograd leaf so the loader's tensor is never
        # itself marked as requiring grad.
        images = images.to(DEVICE).detach().requires_grad_(True)
        labels = labels.to(DEVICE)

        loss = loss_fn(model(images), labels)
        model.zero_grad()
        loss.backward()

        # Build the perturbation outside autograd. Otherwise every stored batch
        # keeps a grad_fn that references `images` (and its .grad), which pins
        # memory on DEVICE and leaves the result with requires_grad=True.
        with torch.no_grad():
            perturbed = images + epsilon * images.grad.sign()
            perturbed = torch.clamp(perturbed, clip_min, clip_max)

        adversarial_batches.append(perturbed)
        label_batches.append(labels)
        empty_gpu_cache()
    return torch.cat(adversarial_batches), torch.cat(label_batches)

# For every configured model: measure accuracy on a fixed subset of its test split,
# then measure accuracy again on FGSM-perturbed versions of the same samples.
train_config = TrainConfig(TRAIN_YAML_PATH)
print(train_config.models)

loss_fn = nn.CrossEntropyLoss()
# Single place for the evaluation batch size (clean and perturbed loaders).
eval_batch_size = 64
epsilon = 0.1  # FGSM step size, shared by all models

for model_name, model_path in train_config.models:
    model = get_model(model_name)
    # map_location lets a checkpoint saved on a different device load onto DEVICE.
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    # Put the model in eval mode before the clean evaluation too, not only before the attack.
    model.eval()

    # get_dataset_loader receives the prefix of model_name before the first "-"
    # (presumably the dataset name). Only its test dataset is used below.
    _, test_loader = get_dataset_loader(model_name.split("-")[0], 32)
    subset = torch.utils.data.Subset(
        test_loader.dataset,
        samples_range,  # first len(samples_range) samples of the test split
    )
    subset_loader = DataLoader(subset, batch_size=eval_batch_size, shuffle=False)

    empty_gpu_cache()
    accuracy_before = evaluate_model(model, subset_loader)
    print(f"accuracy of {model_name} (before): {accuracy_before:.4f}")

    time_pre = time.perf_counter()
    try:
        perturbed_images, perturbed_labels = fgsm_attack(
            model, loss_fn, subset_loader, epsilon
        )
    except Exception as err:
        # Skip this model instead of going on with undefined (first model) or
        # stale (previous model's) adversarial tensors.
        print(f"Error in fgsm_attack for {model_name}: {err!r}")
        empty_gpu_cache()
        continue
    time_post = time.perf_counter() - time_pre
    print(f"FGSM time: {time_post:.2f} s")

    perturbed_dataset = torch.utils.data.TensorDataset(perturbed_images, perturbed_labels)
    perturbed_loader = DataLoader(perturbed_dataset, batch_size=eval_batch_size, shuffle=False)
    accuracy_after = evaluate_model(model, perturbed_loader)
    print(f"accuracy of {model_name} (after FGSM, epsilon={epsilon}): {accuracy_after:.4f}")

    # Drop the references to the adversarial tensors before flushing the cache, so
    # their memory can actually be reclaimed. Assumes empty_gpu_cache wraps
    # torch.cuda.empty_cache; confirm.
    del perturbed_loader, perturbed_dataset, perturbed_images, perturbed_labels
    empty_gpu_cache()

