In [1]:
# Basic Imports
import pathlib
from pathlib import Path
import os
import sys
from typing import Any, Sequence, Optional, Tuple, Iterator, Dict, Callable, Union
import json
import time
from tqdm.auto import tqdm
import numpy as np
from copy import copy
from glob import glob
from collections import defaultdict
import matplotlib.pyplot as plt

# Plot styling: serif text plus the DejaVu Serif mathtext font set, which gives
# figures a LaTeX-like look. (Real LaTeX rendering via `text.usetex` is not
# switched on here.)
from matplotlib import rcParams
rcParams.update({
    'font.family': 'serif',
    'mathtext.fontset': 'dejavuserif',
})

# JAX/Flax
import jax
from jax import jit
import jax.numpy as jnp
from jax import random

# PyTorch for Dataloaders
import torch
import torch.utils.data as data
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms

# Import created functions
import make_dataset as mkds
import visualization as vis

from numpy.random import default_rng
# One integer seed shared by the JAX and NumPy random streams, so changing it
# in a single place re-seeds both.
SEED = 42

# JAX PRNG key built from the seed (`random` is `jax.random`, imported above).
key = random.PRNGKey(SEED)

# NumPy Generator seeded from the key converted to a NumPy array, which ties
# the NumPy stream to the same seed as the JAX key.
rng = default_rng(np.asarray(key))

# **Goal of this Notebook**: Perturb MNIST dataset and put it into a dataloader

TODO:
- Create a custom dataset that perturbs the PyTorch dataset and outputs the perturbed MNIST images together with the empirical field

**NOTE:** The perturbation hyperparameters have to be larger than usual when running the `process_perturbed_data` function.

- Further testing needs to be done to figure out why, but as of right now, I'll continue on and come back to this once the model is working.

In [2]:
# Call the project helper `mkds.create_perturbed_dataset` and unpack its two
# return values as the training and test objects.
# NOTE(review): `download=False` presumably means an existing local copy of the
# data is reused rather than fetched — confirm in `make_dataset`.
turb_training, turb_test = mkds.create_perturbed_dataset(download=False)

In [3]:
# perturbed_training = mkds.load_data(data_dir='/pscratch/sd/m/mdowicz/PFGM_MNIST/saved_data/MNIST/perturbed/',
#                                      data_file='perturbed_training.pkl')

# perturbed_test = mkds.load_data(data_dir='/pscratch/sd/m/mdowicz/PFGM_MNIST/saved_data/MNIST/perturbed/',
#                                      data_file='perturbed_test.pkl')

In [4]:
# Call the project helper `mkds.partition_MNIST` and unpack its three return
# values as the train / validation / test splits.
# NOTE(review): `download=False` presumably means an existing local copy of the
# data is reused rather than fetched — confirm in `make_dataset`.
train, val, test = mkds.partition_MNIST(download=False)

In [5]:
# perturbed_training = mkds.load_data(data_dir='/pscratch/sd/m/mdowicz/PFGM_MNIST/saved_data/MNIST/perturbed/partitioned',
#                                      data_file='partitioned_training_set.pkl')

# perturbed_val = mkds.load_data(data_dir='/pscratch/sd/m/mdowicz/PFGM_MNIST/saved_data/MNIST/perturbed/partitioned',
#                                      data_file='partitioned_val_set.pkl')

# perturbed_test = mkds.load_data(data_dir='/pscratch/sd/m/mdowicz/PFGM_MNIST/saved_data/MNIST/perturbed/partitioned',
#                                      data_file='partitioned_test_set.pkl')

In [6]:
# Request the three dataloaders from the project helper with a batch size of 128.
train_dl, val_dl, test_dl = mkds.load_dataloaders(batch_size=128)
# Pull only the first batch from a fresh iterator over the training loader.
train_batch = next(iter(train_dl))
# Bare last expression: the notebook echoes the whole batch as the cell output.
# NOTE(review): the saved output is a tuple of full arrays; a shape/dtype
# summary would be easier to read — confirm nobody relies on the raw dump.
train_batch

(array([[ 2.9594457e+04,  3.0179191e+04,  8.7360898e+04, ...,
         -3.7994027e+04,  3.9512625e+04,  3.6402961e+04],
        [-1.4059526e+01, -5.6157249e+01, -1.5860073e+01, ...,
         -3.3219299e+01,  9.5528245e-01,  2.9137549e+00],
        [-3.0235529e-02,  2.0162916e-01, -1.2602958e-01, ...,
         -4.6595004e-01, -1.3841748e-01,  4.7488004e-02],
        ...,
        [-1.5022976e+01,  3.8047273e+00,  1.3766521e+01, ...,
          4.1629248e+00,  2.8523464e+00,  1.0300952e+01],
        [-3.5767760e+00, -3.0446305e+00, -3.5145929e+00, ...,
         -1.2841077e-01,  2.2135780e+00,  1.9064258e+00],
        [ 2.3253174e+07, -1.1343905e+08,  1.3110527e+08, ...,
          2.3191926e+08, -2.4705858e+07,  4.0102307e+08]], dtype=float32),
 array([[-0.29568702, -0.3015293 , -0.8728488 , ...,  0.37960964,
         -0.3947824 , -0.36371282],
        [ 0.8431644 ,  3.3678083 ,  0.95114505, ...,  1.992196  ,
         -0.05728928, -0.17474091],
        [ 0.14380585, -0.9589861 ,  0.5994203 