# Still under construction

In [1]:
# General
import torch
import torch.nn as nn

from torch.utils.data import DataLoader

# Configuration
from omegaconf import OmegaConf

# General
import torch

# Import src folder: put the project's `src/` directory (next to the notebook's
# folder) on sys.path so the local modules imported below can be found.
import sys
import os
from pathlib import Path

# Parent of the current working directory (os.path.abspath('') is the cwd).
home = Path(os.path.abspath('')).parent
SRC_DIR = os.path.join(home, "src")
# Only append once: re-running this cell would otherwise push a duplicate
# sys.path entry on every execution.
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

# Now import alva
from alva import generate_samples_with_iterative_epsilons

# Import to display result
from utils import plot_prediction_switch, set_random_seed

# Import to load config
from omegaconf import OmegaConf
#from classifiers.mnist import LeNet5
from data import mnist, PerturbatedMnist
from utils import set_random_seed, split_tensor_random, unnormalize_tensor
from training import training_loop

### Load config and hyperparameters

In [4]:
# Config file locations, relative to the notebook's working directory. Built with
# pathlib so the separator is valid on every OS: a raw r"config\..." literal is
# only a path on Windows; elsewhere the backslash becomes part of the file name.
CFG_PATH_HYPERPARAMS = Path("config") / "training_config.yaml"
CFG_PATH_DATA = Path("config") / "mnist_data.yaml"

cfg_hyperparams = OmegaConf.load(CFG_PATH_HYPERPARAMS)
cfg_data = OmegaConf.load(CFG_PATH_DATA)

# Set random seed for reproducibility.
# TODO: read this from a config entry (e.g. RANDOM_SEED) once one exists.
RANDOM_SEED = 0
set_random_seed(RANDOM_SEED)

# Get the device for cuda optimization
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set Hyperparameters
BATCH_SIZE = cfg_hyperparams.BATCH_SIZE
LEARNING_RATE = cfg_hyperparams.LEARNING_RATE
N_EPOCHS = cfg_hyperparams.N_EPOCHS
# NOTE(review): the number of epsilon epochs is read from the N_CONV_EPOCHS key
# - confirm this mapping is intended.
N_EPSILON_EPOCHS = cfg_hyperparams.N_CONV_EPOCHS

# TODO: DATA_DIM and CLASSES (general MNIST info) are not defined here. Their old
# definitions read `cfg.dataset.shape` / `cfg.dataset.classes`, but no `cfg` is
# loaded in this cell, and later cells of this notebook use both names.

# Load Data
per_training_data = PerturbatedMnist(cfg_data.ROOT, 'training', transform=mnist.get_standard_transformation())
per_test_data = PerturbatedMnist(cfg_data.ROOT, 'test', transform=mnist.get_standard_transformation())

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting data/MNIST\raw\train-images-idx3-ubyte.gz to data/MNIST\raw


100.0%


Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST\raw\train-labels-idx1-ubyte.gz
Extracting data/MNIST\raw\train-labels-idx1-ubyte.gz to data/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz



17.9%

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST\raw\t10k-images-idx3-ubyte.gz


100.0%
100.0%

Extracting data/MNIST\raw\t10k-images-idx3-ubyte.gz to data/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting data/MNIST\raw\t10k-labels-idx1-ubyte.gz to data/MNIST\raw

torch.Size([0]) torch.Size([0])
torch.Size([0]) torch.Size([0])





In [3]:
def generate_samples(cfg):
    """Generate adversarial samples and decode them to image space.

    Args:
        cfg: run configuration. NOTE(review): it is currently *not* forwarded to
            ``generate_samples_with_iterative_epsilons_by_config`` - confirm whether
            that helper reads the configuration on its own or should receive ``cfg``.

    Returns:
        Tuple ``(x, y, per_x, per_y, epsilons)``. ``x`` and ``per_x`` are the
        generator outputs for ``z`` and ``per_z``, detached and moved to CPU;
        ``y``, ``per_y`` and ``epsilons`` are passed through unchanged from the
        sample generator.
    """
    # NOTE(review): this helper is not among the notebook's imports (only
    # `generate_samples_with_iterative_epsilons` is imported from `alva`).
    _, generator, (z, y, per_z, per_y, epsilons) = generate_samples_with_iterative_epsilons_by_config()
    print(f"Generated {len(z)} adversarial samples with generator")
    # Save pictures for later. These forward passes are inference only, so skip
    # building an autograd graph; detach() is kept so the outputs are plain
    # tensors either way.
    with torch.no_grad():
        x = generator(z).detach().cpu()
        per_x = generator(per_z).detach().cpu()
    return x, y, per_x, per_y, epsilons

In [4]:
# Iterative experiment: each epsilon epoch retrains a fresh classifier on the
# reloaded perturbated MNIST data, then generates samples for every target class
# and saves them under cfg.save_paths.runs (presumably picked up by load_data()
# on the next iteration - confirm against PerturbatedMnist).
# NOTE(review): several names used here are not defined in the cells shown:
# `cfg`, `LeNet5` (its import is commented out), `CLASSES` and `DATA_DIM`
# (commented out in the config cell), `add_config_entry`, `generate_plots` and
# `original_training_data`. On a fresh kernel this cell raises NameError until
# they are defined/imported.
# One (training, test) pair of perturbated percentages per epsilon epoch.
perturbated_percentage = []
for epsilon_epoch in range(N_EPSILON_EPOCHS):
    # Declare experiment specific variables: zero-padded run id and output paths.
    run_name = f"{epsilon_epoch:03}"
    losses_path = f'{cfg.save_paths.losses}/loss_{run_name}.pt'
    epsilons_path = f'{cfg.save_paths.epsilons}/epsilons_{run_name}.pt'
    model_path = f'{cfg.save_paths.models}/lenet_{run_name}.pth'

    # Reload data and create Dataloader
    per_training_data.load_data()
    per_test_data.load_data()
    train_loader = DataLoader(per_training_data, batch_size=BATCH_SIZE, shuffle=True)
    # NOTE(review): shuffling the test loader is unnecessary if training_loop only
    # evaluates on it, and each shuffled pass consumes global RNG draws.
    test_loader = DataLoader(per_test_data, batch_size=BATCH_SIZE, shuffle=True)

    # NOTE(review): the unit of get_perturbated_percentage() (fraction vs. percent)
    # is not visible here; it is printed below with a '%' suffix.
    perturbated_percentage.append((per_training_data.get_perturbated_percentage(), per_test_data.get_perturbated_percentage()))

    # Start run
    print(f'\n--------------------------')
    print(f'Executing Experiment: #{run_name}')
    print(f'\nPerturbated images: {round(per_training_data.get_perturbated_percentage(), 3)}%')
    print(f'--------------------------\n')

    # Reinitialize model, optimizer and loss so every epsilon epoch trains from scratch
    model = LeNet5().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    # Train the model on the dataset
    print("\nTRAINING\n")
    model, optimizer, metrics = training_loop(model, criterion, optimizer, train_loader, test_loader, N_EPOCHS, DEVICE)

    # Save the whole model object (not a state_dict) and point the config at it,
    # presumably so the sample generator uses this freshly trained classifier.
    torch.save(model, model_path)
    cfg.models.classifier = model_path

    # Execute the pipeline to generate samples for every class
    per_xs = []
    targets = []
    # NOTE(review): figures are collected but never used or saved in this cell.
    all_target_figures = []
    epsilons  = []

    print("\nGENERATING\n")
    for target in CLASSES:
        # Record the current target in cfg (read back below as cfg.target).
        add_config_entry(cfg, 'target', target)
        print("\nGoing for " + str(target) + "\n")

        # Generating images and storing figures
        x, y, per_x, per_y, epsilon = generate_samples(cfg)
        figures = generate_plots(x, y, per_x, per_y, cfg.target, original_training_data, DATA_DIM, cfg.dataset.name)

        per_xs.append(unnormalize_tensor(per_x))
        # Label row of shape (1, n_samples) filled with the class id; rows are
        # concatenated along dim=1 and flattened below.
        targets.append(torch.full((1, per_x.shape[0]), target, dtype=int))
        all_target_figures.append((target, figures))
        epsilons.append((target, epsilon))

    # Process data: stack all generated images as 28x28 (hard-coded MNIST size)
    # and split them randomly. NOTE(review): confirm split_tensor_random really
    # returns the test split first.
    x_test, y_test, x_train, y_train = split_tensor_random(torch.cat(per_xs).view(-1, 28,28), torch.cat(targets, dim=1).view(-1))

    # Save tensors for this run
    torch.save(torch.Tensor(metrics), losses_path)
    torch.save(epsilons, epsilons_path)
    torch.save((x_train, y_train), f'{cfg.save_paths.runs}/{run_name}_training.pt')
    torch.save((x_test, y_test), f'{cfg.save_paths.runs}/{run_name}_test.pt')
    # Overwritten every iteration with all percentages collected so far.
    torch.save(torch.Tensor(perturbated_percentage), cfg.save_paths.general + "/perturbated_percentage.pt")



--------------------------
Executing Experiment: #000

Perturbated images: 0.0%
--------------------------


TRAINING



KeyboardInterrupt: 