In [1]:
import torch
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
%matplotlib inline

from neural import VQVAE


In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)

Device: cpu


## Data Loaders

In [3]:
BATCH_SIZE = 128
SHUFFLE = True      # passed to DataLoader(shuffle=...)
NUM_WORKERS = 12    # DataLoader worker processes; lower this on machines with fewer cores
SEED = 0            # global torch seed so weight init and batch shuffling are reproducible

# Seed before the model is built and the loaders are iterated, so Restart & Run All
# reproduces the same initial weights and batch order.
torch.manual_seed(SEED)

In [4]:
# Training batches follow SHUFFLE (reshuffled every epoch). The test split keeps a fixed
# order: evaluation gains nothing from shuffling, and a fixed order makes results repeatable.
train_set_loader = data.DataLoader(
    datasets.MNIST('./data', train=True, transform=transforms.ToTensor(), download=True),
    batch_size=BATCH_SIZE,
    shuffle=SHUFFLE,
    num_workers=NUM_WORKERS,
)

test_set_loader = data.DataLoader(
    datasets.MNIST('./data', train=False, transform=transforms.ToTensor(), download=True),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

## Net and Optimizer

In [5]:
# VQ-VAE over single-channel inputs (MNIST tensors from ToTensor have one channel), with no
# residual layers. `decay` is presumably an EMA rate for the codebook — confirm in neural.VQVAE.
net = VQVAE(
    in_channels=1, num_hiddens=32,
    num_res_hiddens=0, num_res_layers=0,
    decay=0.99, rgb_out=False,
)
net = net.to(device)  # Module.to returns the same module, moved to `device`
optimizer = optim.Adam(params=net.parameters(), lr=1e-4)

In [6]:
# Resume from the checkpoint that the save cell at the end of this notebook writes, if it exists.
# - map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa).
# - Only a missing file is treated as "start fresh"; a checkpoint that doesn't match the model
#   raises, instead of being silently ignored.
# - torch.load unpickles its input: only load checkpoints you trust.
try:
    net.load_state_dict(torch.load('vqa_state_dict.pth', map_location=device))
    print('State Dict loaded from \'vqa_state_dict.pth\'')
except FileNotFoundError:
    print('No checkpoint found at \'vqa_state_dict.pth\'; training from scratch')

## Train Loop

In [7]:
def train(epochs=10):
    """Train the module-level ``net`` on ``train_set_loader`` for ``epochs`` passes.

    Uses the module-level ``optimizer`` and ``device``. Every 100 batches it prints that
    batch's loss. At the end of each epoch it prints the mean batch loss. Returns None.
    """
    print('=' * 10 + ' TRAIN ' + '=' * 10, end='\n\n')
    net.train()

    num_batches = len(train_set_loader)
    for epoch in range(1, epochs + 1):
        epoch_loss_sum = 0.0

        for step, (images, _labels) in enumerate(train_set_loader, start=1):
            images = images.to(device)
            optimizer.zero_grad()

            # The model returns (encoder output, quantized codes, reconstruction);
            # its own loss_function combines them.
            encoded, quantized, recon_x = net(images)
            loss_value = net.loss_function(images, recon_x, encoded, quantized)
            batch_loss = loss_value.item()
            epoch_loss_sum += batch_loss

            loss_value.backward()
            optimizer.step()

            if step % 100 == 0:
                print(f'==> EPOCH[{epoch}]({step}/{num_batches}): LOSS: {batch_loss}')

        print(f'=====> EPOCH[{epoch}] Completed: Avg. LOSS: {epoch_loss_sum/num_batches}')
        print()


In [8]:
train(epochs=10)  # 10 epochs, the same as train()'s default


==> EPOCH[1](100/469): LOSS: 0.6023479700088501
==> EPOCH[1](200/469): LOSS: 0.49819499254226685
==> EPOCH[1](300/469): LOSS: 0.5997454524040222
==> EPOCH[1](400/469): LOSS: 0.5790199041366577
=====> EPOCH[1] Completed: Avg. LOSS: 0.5773290154903428

==> EPOCH[2](100/469): LOSS: 0.2756499946117401
==> EPOCH[2](200/469): LOSS: 0.21680065989494324
==> EPOCH[2](300/469): LOSS: 0.17918626964092255
==> EPOCH[2](400/469): LOSS: 0.16688381135463715
=====> EPOCH[2] Completed: Avg. LOSS: 0.24800729020826343

==> EPOCH[3](100/469): LOSS: 0.13846659660339355
==> EPOCH[3](200/469): LOSS: 0.13659034669399261
==> EPOCH[3](300/469): LOSS: 0.1302890032529831
==> EPOCH[3](400/469): LOSS: 0.12322765588760376
=====> EPOCH[3] Completed: Avg. LOSS: 0.13492631440414293

==> EPOCH[4](100/469): LOSS: 0.12125618755817413
==> EPOCH[4](200/469): LOSS: 0.11437968909740448
==> EPOCH[4](300/469): LOSS: 0.11170611530542374
==> EPOCH[4](400/469): LOSS: 0.1121169775724411
=====> EPOCH[4] Completed: Avg. LOSS: 0.11518

KeyboardInterrupt: 

In [None]:
# NOTE(review): torch.nn.Module has no test() method, so this only works if VQVAE defines one.
# If the intent is to switch off training-mode behaviour before inference, net.eval() is the standard call.
net.test()

In [None]:
# Reconstruct a single test digit.
# - eval() switches off training-mode behaviour inside the model (presumably including any
#   decay-based codebook update in VQVAE — confirm in neural.VQVAE).
# - no_grad() skips building an autograd graph.
image, label = test_set_loader.dataset[1]
net.eval()
with torch.no_grad():
    # Add a batch dimension and move the input to the model's device (needed on CUDA).
    encoded, quantized, recon = net(image.unsqueeze(0).to(device))

print(label)
plt.imshow(image[0], cmap='Greys');

In [None]:
# Take the first item of the batch, drop singleton dims, and copy it to the CPU for matplotlib.
# The result gets a new name so the cell can be re-run: rebinding `recon` to itself would index
# into the already-reduced tensor on a second run.
recon_image = recon[0].squeeze().detach().cpu().numpy()
plt.imshow(recon_image, cmap='Greys');

In [None]:
# Presumably the quantizer's codebook size. NOTE(review): the one-hot probe below hard-codes
# 16 codes; it should probably equal this value — confirm.
net.quantizer.num_embeddings

In [None]:
# Probe the decoder with a hand-built code map.
# - All 28*28 positions get an all-zero 16-way encoding, except position 0, which selects code 15.
# - NOTE(review): net.gm_embedding is not referenced anywhere else here; the codebook cell above
#   uses net.quantizer. Confirm this attribute exists on VQVAE.
# - The 28*28 grid and the 16 codes are hard-coded; presumably they should match the encoder's
#   output size and the codebook size.
encodings = torch.zeros((28 * 28, 16))
encodings[0][15] = 1
quantized = torch.matmul(encodings, net.gm_embedding.embeddings_mu)

In [None]:
# Unflatten the (positions, channels) code map into a 1x28x28x8 grid, then move channels to
# dim 1 (N, C, H, W) — presumably the layout net.decoder expects.
# NOTE(review): this rebinds `quantized` from itself. Re-running this cell without first
# re-running the one above does not error (the reshape still succeeds), but it scrambles the layout.
quantized = torch.reshape(quantized, (1, 28, 28, 8)).permute(0, 3, 1, 2)

In [None]:
# Decode the hand-built code map and show the first image of the batch.
# The probe is built from a CPU torch.zeros, so move it to the model's device for decoding,
# then copy the result back to the CPU for matplotlib.
with torch.no_grad():
    recon = net.decoder(quantized.to(device))[0].squeeze()
plt.imshow(recon.cpu().numpy(), cmap='Greys');

In [None]:
# Encode test digit 20, sample a latent with the reparameterisation trick, and decode it.
# NOTE(review): encode()/decode() returning (mu, sigma) looks like a Gaussian-VAE API; confirm
# that VQVAE defines them. This assumes `sigma` is a standard deviation, not a log-variance.
image, label = test_set_loader.dataset[20]
image = image.view(-1, 28*28).unsqueeze(0).to(device)

with torch.no_grad():
    mu, sigma = net.encode(image)
    # eps must be standard normal. torch.rand_like would draw from U[0, 1): mean 0.5 and never
    # negative, which biases every sample.
    z_2 = mu + sigma * torch.randn_like(sigma)
    recon = net.decode(z_2)

print(label)
plt.imshow(recon.view(28, 28).cpu().numpy(), cmap='Greys');

In [None]:
# Latent arithmetic: decode the difference between two latents.
# NOTE(review): z_1 is not defined in any visible cell, so this raises NameError on a fresh
# kernel. Define it first, e.g. by encoding another test digit the same way as z_2.
z_3 = z_2 - z_1
with torch.no_grad():
    recon = net.decode(z_3)
plt.imshow(recon.view(28, 28).detach().cpu().numpy(), cmap='Greys');

In [None]:
# Save the trained weights. Passing a path (rather than an unclosed open() handle) lets
# torch.save open and close the file itself, so the checkpoint is fully flushed to disk.
torch.save(net.state_dict(), 'vqa_state_dict.pth')