In [15]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
%run ConV_VAE.ipynb
%run loader.ipynb
%run test.ipynb
%run train.ipynb

In [16]:
# Notebook-wide settings: figure resolution, compute device, checkpoint loading.
plt.rcParams.update({'figure.dpi': 100})

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

####### IMPORTANT ######
# Set `load` to True to restore saved weights into the model right after it
# is constructed.
load = False
# Checkpoint handed to torch.load(); it is only read when `load` is True.
model_name = None

In [17]:
# Per-image preprocessing: convert to a tensor, then normalise with
# mean 160 and std 50.
# NOTE(review): ToTensor rescales uint8 image input to [0, 1]; if that applies
# to what CustomImageDataset passes in, mean=160 / std=50 would not match the
# value range any more -- confirm against the dataset implementation.
normalize = transforms.Normalize(mean=160, std=50)
transform = transforms.Compose([transforms.ToTensor(), normalize])

# Build the train and test datasets from their CSV files with the same
# transform (CustomImageDataset is presumably defined by one of the notebooks
# pulled in via %run -- verify).
train_data, test_data = (
    CustomImageDataset(csv_path, transform=transform)
    for csv_path in ('../data/sign_mnist_train.csv', '../data/sign_mnist_test.csv')
)

In [18]:
# Mini-batch loaders (128 samples per batch) for both datasets.
# Training batches are reshuffled at every epoch.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps the batches
# passed to test() identical from one epoch to the next.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=False)

In [19]:
# Instantiate the VAE and move its parameters to the selected device.
model = ConvVarAutoencoder().to(device)
if load:
    # Fail early with an actionable message instead of the opaque error
    # torch.load() raises when it is handed None.
    if model_name is None:
        raise ValueError(
            "load is True but model_name is None; "
            "set model_name to the checkpoint file to restore"
        )
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine can be restored on CPU.
    # NOTE: torch.load() unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(model_name, map_location=device))

# `criterion` is not referenced by the cells visible in this notebook; it is
# kept because helpers brought in via %run may read it -- TODO confirm.
criterion = F.mse_loss
# NOTE(review): weight_decay=0.1 is a strong L2 penalty for Adam; confirm it
# is intentional.
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.1)

In [None]:
# Train for a fixed number of epochs, evaluating after each one and
# periodically writing a checkpoint plus a grid of generated samples.
epochs = 10
# Checkpoint / sample interval in epochs. The last epoch is always saved as
# well: with epochs == save_every the modulo test alone only fires at epoch 0,
# which would leave the fully trained model unsaved.
save_every = 10
train_losses = []
test_losses = []
for epoch in tqdm(range(epochs)):  # one pass over the training set per iteration
    # One optimisation pass, followed by evaluation on the test loader.
    train_loss = train_vae(train_loader, model, optimizer, device)
    test_loss = test(epoch, model, test_loader)
    print('====> Training loss: {train} KL Divergence {kl}'.format(train=train_loss.item(),kl=model.encoder.kl.item()))
    print('====> Test set loss: {:.4f}'.format(test_loss))

    is_last_epoch = epoch == epochs - 1
    if epoch % save_every == 0 or is_last_epoch:
        torch.save(model.state_dict(), "../model/model" + str(epoch) + ".pt")
        with torch.no_grad():
            # Decode 64 random latent vectors into images.
            # assumes the decoder takes 2048-dimensional inputs -- TODO confirm
            sample = torch.randn(64, 2048).to(device)
            sample = model.decoder(sample).cpu()
            # Saved as a grid of 64 single-channel 28x28 images.
            save_image(sample.view(64, 1, 28, 28),
                       '../results/' + str(epoch) + '.png')

    # Store plain Python floats so the history holds no tensors or graph
    # references; the same .item() accessor is used in the print above.
    train_losses.append(train_loss.item())
    test_losses.append(test_loss)

 10%|█         | 1/10 [00:21<03:11, 21.27s/it]

====> Training loss: 159672.71875 KL Divergence 3041.65087890625
====> Test set loss: 131668.4698


 20%|██        | 2/10 [00:39<02:35, 19.45s/it]

====> Training loss: 126780.8984375 KL Divergence 3025.66064453125
====> Test set loss: 121882.3455


 30%|███       | 3/10 [00:56<02:09, 18.44s/it]

====> Training loss: 121715.4453125 KL Divergence 3015.67626953125
====> Test set loss: 120033.4553


 40%|████      | 4/10 [01:15<01:51, 18.61s/it]

====> Training loss: 119795.125 KL Divergence 2989.90576171875
====> Test set loss: 119037.1834


In [None]:
# Plot the per-epoch loss histories collected by the training loop.
# The explicit Figure/Axes interface is used so the title and axis labels are
# attached to this specific plot and the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(train_losses, label="training")
ax.plot(test_losses, label="validation")
ax.set(title="VAE loss per epoch", xlabel="Epoch", ylabel="Loss")
ax.legend()
plt.show()