In [1]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
# Alternative helper notebook, currently disabled in favour of
# bigger_latent_space.ipynb below.
# %run ConV_VAE.ipynb
# Execute the helper notebooks so that the names they define can be used by
# the cells below (none of those names are imported explicitly).
# NOTE(review): presumably bigger_latent_space.ipynb defines ConvVarAutoencoder,
# loader.ipynb defines CustomImageDataset, test.ipynb defines test() and
# train.ipynb defines train_vae() -- confirm against each notebook.
%run bigger_latent_space.ipynb
%run loader.ipynb
%run test.ipynb
%run train.ipynb

In [2]:
# Resolution used for every matplotlib figure in this notebook.
plt.rcParams['figure.dpi'] = 100
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed PyTorch's random number generator so that data shuffling, weight
# initialisation and latent sampling repeat after Restart Kernel -> Run All.
# (Some GPU kernels may still be non-deterministic.)
SEED = 0
torch.manual_seed(SEED)

####### IMPORTANT ######
# Set this flag to True to load previously saved weights before training.
load = False
# Path of the saved state_dict to load; only read when `load` is True.
model_name = None

# Size of the latent vector passed to the model constructor and used when
# sampling from the decoder.
latent_space = 2048
# Number of images per batch for both data loaders and for decoder sampling.
batch_size = 128

In [3]:
def _divide_by_255(x):
    """Return ``x`` divided by 255.0."""
    return x / 255.0


# NOTE(review): ToTensor already rescales uint8 image input to [0, 1]; if
# CustomImageDataset hands it uint8 data, the extra division shrinks the values
# a second time -- confirm the dtype the dataset produces.
normalize = transforms.Lambda(_divide_by_255)
pipeline_steps = [
    transforms.ToTensor(),  # TODO (original author): replace with to-PIL-image?
    normalize,
]
transform = transforms.Compose(pipeline_steps)

train_csv = '../data/sign_mnist_train.csv'
test_csv = '../data/sign_mnist_test.csv'

# Training split with the full ToTensor + /255 pipeline.
train_data = CustomImageDataset(train_csv, transform=transform)
# Same training CSV with only ToTensor applied (no extra division); it is not
# referenced by any later cell visible here.
train_data_raw = CustomImageDataset(train_csv, transform=transforms.ToTensor())
# Test split, preprocessed exactly like the training split.
test_data = CustomImageDataset(test_csv, transform=transform)

In [4]:
# Training batches are drawn in a new random order every epoch.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps per-batch
# results comparable from one epoch to the next.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [5]:
# Build the model with the configured latent size and move it to the target device.
model = ConvVarAutoencoder(latent_space).to(device)
if load:
    # Fail with a clear message instead of the obscure error torch.load(None) gives.
    if model_name is None:
        raise ValueError(
            "load is True but model_name is None; set model_name to the path "
            "of a saved state_dict"
        )
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    # map_location lets a checkpoint saved on another device load onto `device`.
    model.load_state_dict(torch.load(model_name, map_location=device))

# Kept for reference: train_vae() below is called without it, so the loss that
# is actually optimised is whatever train_vae computes internally.
criterion = F.mse_loss
# NOTE(review): weight_decay=0.1 is a strong L2 penalty for Adam -- confirm it is intended.
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.1)

In [6]:
epochs = 5
# A checkpoint and a sample grid are written every `save_every` epochs and
# always after the final epoch. With the previous `epoch % 5 == 0` test and
# epochs = 5, only epoch 0 was ever saved and the trained model was lost.
save_every = 5
model_dir = '../model/'
results_dir = '../results/'
# torch.save and save_image do not create missing directories; create them up
# front so a finished epoch is not lost to a missing folder.
os.makedirs(model_dir, exist_ok=True)
os.makedirs(results_dir, exist_ok=True)

# NOTE(review): the recorded run of this cell stopped with
#   RuntimeError: shape '[-1, 1, 28, 28]' is invalid for input of size 460800
# No call in this cell uses a -1 dimension, so the failing reshape is presumably
# inside train_vae(), test() or the model itself -- TODO confirm.
# 460800 = 128 * 3600, i.e. 3600 values per sample at batch_size = 128, which
# does not match a 1x28x28 image (784 values).

train_losses = []
test_losses = []
for epoch in tqdm(range(epochs)):  # loop over the dataset multiple times
    # One optimisation pass over the training set, then evaluate on the test set.
    train_loss = train_vae(train_loader, model, optimizer, device)
    test_loss = test(epoch, model, test_loader)
    print('====> Average Training loss per image: {:.4f}'.format(train_loss.item()))
    print('====>Average  Test set loss per image : {:.4f}'.format(test_loss))

    is_last_epoch = epoch == epochs - 1
    if epoch % save_every == 0 or is_last_epoch:
        torch.save(model.state_dict(), model_dir + "model" + str(epoch) + ".pt")
        with torch.no_grad():
            # Decode random latent vectors to inspect what the decoder generates.
            sample = torch.randn(batch_size, latent_space).to(device)
            sample = model.decoder(sample).cpu()
            # Assumes the decoder yields 784 values per sample -- TODO confirm.
            save_image(sample.view(batch_size, 1, 28, 28),
                       results_dir + str(epoch) + '.png')

    # Store plain Python floats so both histories have the same element type.
    train_losses.append(train_loss.item())
    test_losses.append(test_loss)

  0%|          | 0/5 [00:02<?, ?it/s]


RuntimeError: shape '[-1, 1, 28, 28]' is invalid for input of size 460800

In [None]:
# Plot the per-epoch loss histories collected by the training loop.
fig, ax = plt.subplots()
ax.plot(train_losses, label="training")
ax.plot(test_losses, label="validation")
# Label the figure so it can be read on its own when the notebook is skimmed.
ax.set_xlabel("epoch")
ax.set_ylabel("average loss per image")
ax.set_title("Loss per epoch")
ax.legend()
plt.show()