In [None]:
# Re-import edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Move to the parent directory (quietly); presumably so the project packages
# imported below (base, models, data_loaders, ...) resolve — confirm layout.
%cd -q ..

In [None]:
import os
from datetime import datetime

# DATA SCIENCE
import numpy as np
import pandas as pd

# PYTORCH
import torch
import torchvision
import torchvision.datasets as datasets

# PLOTLY
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# PROJECT
import base
from plots.images import plot_images
from tensorutils import describe
from experiment import Experiment
from models.linear_regression import LinearRegression
from models.triple_dense import TripleDense
from data_loaders.transformations import ReshapeTransform

In [None]:
data_dir = base.DATA_DIR_NAME
batch_size = 256
loss_fn = torch.nn.CrossEntropyLoss()

# Normalization constants. ToTensor scales pixels to [0, 1]; with
# MEAN = STD = 0.5, Normalize maps them to [-1, 1].
# Alternative: MEAN = 0.1307, STD = 0.3081 (the commonly used MNIST
# mean/std) standardizes the inputs instead.
MEAN = 0.5
STD = 0.5

transformations = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((MEAN,), (STD,)),
        # Reshape to (-1,): presumably flattens each image into a vector of
        # length x_dim for the dense model (matches x_shape=(x_dim,) used later).
        ReshapeTransform((-1,)),
    ]
)

# download=True only fetches the files when they are not already present
# under data_dir, so a fresh checkout can run the notebook top to bottom.
train_dataset = datasets.MNIST(root=data_dir, train=True, transform=transformations, download=True)
test_dataset = datasets.MNIST(root=data_dir, train=False, transform=transformations, download=True)

# Per-image shape (height, width) and the flattened input size. numel() uses
# every dimension instead of assuming square images.
data_shape = train_dataset.data.shape[1:]
x_dim = data_shape.numel()
n_labels = len(train_dataset.classes)

# For quicker experiments, train on a random 10k-sample subset instead:
# train_dataset, _ = torch.utils.data.random_split(train_dataset, [10000, 50000])

num_workers = 3
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
# Aggregate validation metrics should not depend on sample order, so the
# evaluation loader is not shuffled (keeps validation passes deterministic).
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
)

In [None]:
# Hidden-layer widths passed to TripleDense (earlier runs used 300 and 100).
hidden1, hidden2 = 500, 300
model = TripleDense(x_dim, hidden1, hidden2, n_labels)

In [None]:
# Device handed to Experiment. None is passed through unchanged (presumably
# Experiment then picks one itself — verify); set 'cpu' or 'cuda' to force it.
device = None
optimizer_class = torch.optim.Adam
experiment = Experiment(
    model,
    loss_fn,
    device=device,
    optimizer_class=optimizer_class,
)

In [None]:
# Train the classifier (10 epochs; test split passed as the validation loader,
# validate_every=2) and report the wall-clock duration.
train_start = datetime.now()
experiment.train_model(
    train_loader=train_loader,
    num_epochs=10,
    valid_loader=test_loader,
    validate_every=2,
)

train_elapsed = datetime.now() - train_start
print(f"Lasted = {train_elapsed}")

In [None]:
# One adversarial target per class label. The targets are placed on the
# device the model's parameters live on, instead of hard-coding 'cuda', which
# crashed on CPU-only machines and ignored the `device` given to Experiment.
# NOTE(review): assumes Experiment moves `model` in place (nn.Module.to does)
# — confirm.
model_device = next(model.parameters()).device
targets = torch.arange(n_labels, device=model_device)
start = datetime.now()
x_adv = experiment.train_adversarial(x_shape=(x_dim,), targets=targets, num_epochs=20, info_every=5)
last = datetime.now() - start
print("Lasted = {}".format(last))

In [None]:
# Convert the adversarial inputs back into image layout for plotting. The
# dataset's own image shape is used instead of a hard-coded 28x28. detach() is
# a no-op when x_adv is already detached, and otherwise prevents .numpy() from
# raising on a tensor that requires grad.
data = x_adv.detach().cpu().numpy().reshape(-1, *data_shape)
# Display clip range. With the loader's Normalize(MEAN, STD), valid pixels
# fall in [(0 - MEAN) / STD, (1 - MEAN) / STD]. Clipping at +/-vmax
# presumably bounds adversarial outliers for the plot.
vmax = 2
data = data.clip(-vmax, vmax)
labels = targets.cpu().numpy()

In [None]:
# Show each adversarial image alongside the target label it was generated for.
plot_images(labels=labels, images=data)

In [None]:
# Shape of the adversarial batch — presumably (len(targets), x_dim).
x_adv.size()

In [None]:
# Model outputs on the adversarial inputs (presumably class probabilities —
# confirm what Experiment.predict returns), detached from autograd and copied
# to a host NumPy array.
probs = experiment.predict(x_adv).detach().cpu().numpy()
probs

In [None]:
# The same predictions rounded to two decimals for easier reading.
np.round(probs, 2)