In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
%run ConV_VAE.ipynb
%run loader.ipynb
%run test.ipynb
%run train.ipynb

In [2]:
plt.rcParams['figure.dpi'] = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'
####### IMPORTANT ######
# set this flag to true if you want to load the model beforehand
load = False
# specify which model you want to load
model_name = None

In [3]:
normalize = transforms.Normalize(160, 50)
transform = transforms.Compose([
    transforms.ToTensor(),
    normalize
])
train_data = CustomImageDataset('../data/sign_mnist_train.csv', transform=transform)
test_data = CustomImageDataset('../data/sign_mnist_test.csv')

In [4]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=True)

In [5]:
model = ConvVarAutoencoder().to(device)
if (load):
    model.load_state_dict(torch.load(model_name, map_location=device))

criterion = F.mse_loss
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.1)

In [None]:
epochs = 15
train_losses = []
test_losses = []
for epoch in tqdm(range(epochs)):  # loop over the dataset multiple times
    # Train on data
    train_loss = train_vae(train_loader, model, optimizer, device)
    test_loss = test(epoch, model, test_loader)
    print('====> Training loss: {train} KL Divergence {kl}'.format(train=train_loss.item(),kl=model.encoder.kl.item()))
    print('====> Test set loss: {:.4f}'.format(test_loss))
    if epoch % 10 == 0:
        torch.save(model.state_dict(), "../model/model" + str(epoch) + ".pt")
        with torch.no_grad():
            sample = torch.randn(64, 2048).to(device)
            sample = model.decoder(sample).cpu()
            save_image(sample.view(64, 1, 28, 28),
                       '../results/' + str(epoch) + '.png')
    train_losses.append(train_loss.cpu().detach().numpy())
    test_losses.append(test_loss)

  7%|▋         | 1/15 [00:18<04:19, 18.53s/it]

====> Training loss: 146958.421875 KL Divergence 17318624.0
====> Test set loss: 3345436771.5088


 13%|█▎        | 2/15 [00:35<03:47, 17.53s/it]

====> Training loss: 139178.65625 KL Divergence 2030633.75
====> Test set loss: 2988848097.5439


 20%|██        | 3/15 [00:52<03:30, 17.55s/it]

====> Training loss: 124668.0390625 KL Divergence 1910223.875
====> Test set loss: 2937796903.0175


 27%|██▋       | 4/15 [01:10<03:14, 17.64s/it]

====> Training loss: 122220.6796875 KL Divergence 2227604.5
====> Test set loss: 2874950818.6667


 33%|███▎      | 5/15 [01:28<02:56, 17.66s/it]

====> Training loss: 120757.2578125 KL Divergence 2456443.0
====> Test set loss: 2864227690.9474


 40%|████      | 6/15 [01:47<02:43, 18.12s/it]

====> Training loss: 119842.0703125 KL Divergence 2469235.5
====> Test set loss: 2831117308.0702


 47%|████▋     | 7/15 [02:06<02:26, 18.37s/it]

====> Training loss: 119006.2109375 KL Divergence 2743234.5
====> Test set loss: 2792013974.8772


 53%|█████▎    | 8/15 [02:26<02:12, 18.94s/it]

====> Training loss: 118435.8359375 KL Divergence 2734568.5
====> Test set loss: 2734051389.6140


 60%|██████    | 9/15 [02:44<01:52, 18.78s/it]

====> Training loss: 117781.5078125 KL Divergence 2680895.5
====> Test set loss: 2586329314.6667


 67%|██████▋   | 10/15 [03:03<01:33, 18.75s/it]

====> Training loss: 117267.375 KL Divergence 2258320.25
====> Test set loss: 2635696689.8246


 73%|███████▎  | 11/15 [03:21<01:14, 18.60s/it]

====> Training loss: 117002.8125 KL Divergence 2646716.25
====> Test set loss: 2518869624.7018


 80%|████████  | 12/15 [03:40<00:56, 18.73s/it]

====> Training loss: 116591.4765625 KL Divergence 1966671.625
====> Test set loss: 2462845769.2632


In [None]:
plt.plot(train_losses)
plt.plot(test_losses)
plt.show()