In [1]:
import torch
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
%matplotlib inline

from neural import QVAE


In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)

Device: cpu


## Data Loaders

In [3]:
import os

BATCH_SIZE = 128
SHUFFLE = True
# Never ask for more loader workers than the machine has CPUs; PyTorch warns
# (and performance degrades) when num_workers exceeds the available cores.
NUM_WORKERS = min(12, os.cpu_count() or 1)

# Seed torch so weight init and DataLoader shuffling are reproducible across
# "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

In [4]:
def _make_fashion_mnist_loader(train, shuffle):
    """Build a DataLoader over the FashionMNIST train (``train=True``) or test split.

    Images are converted with ``transforms.ToTensor()``; batch size and worker
    count come from the config cell (``BATCH_SIZE``, ``NUM_WORKERS``).
    """
    dataset = datasets.FashionMNIST(
        './data', train=train, transform=transforms.ToTensor(), download=True
    )
    return data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
    )


train_set_loader = _make_fashion_mnist_loader(train=True, shuffle=SHUFFLE)
# Evaluation does not need a random order; a fixed order keeps test-set
# iteration reproducible between runs.
test_set_loader = _make_fashion_mnist_loader(train=False, shuffle=False)

## Net

In [7]:
# Build the quantised VAE for single-channel (grayscale) input with a
# single-channel output (rgb_out=False). The remaining hyperparameters are
# forwarded to `neural.QVAE`; see that class for their exact meaning.
net = QVAE(
    in_channels=1,
    rgb_out=False,
    num_hiddens=64,
    num_res_hiddens=16,
    num_res_layers=1,
    embedding_dim=32,
    init_num_embeddings=128,
    min_cluster_size=10,
)
# Module.to() moves parameters in place and returns the same module.
net = net.to(device)

In [8]:
# Resume from a previous checkpoint if one exists; otherwise keep the fresh weights.
# Passing the path (not an open file) lets torch close the file itself, and
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
try:
    net.load_state_dict(torch.load('state_dict.pth', map_location=device))
    print('State Dict loaded from \'state_dict.pth\'')
except FileNotFoundError:
    print("No 'state_dict.pth' found; starting from freshly initialised weights")
except RuntimeError as err:
    # e.g. a parameter size mismatch when the checkpoint was produced with
    # different hyperparameters -- report it instead of silently ignoring it.
    print(f"Could not load 'state_dict.pth' ({err}); starting from freshly initialised weights")

## Train Loop

In [12]:
# Adam over every trainable parameter of the network, learning rate 1e-4.
optimizer = optim.Adam(params=net.parameters(), lr=0.0001)

In [13]:
def train(epochs=10, log_every=100):
    """Train the notebook-level ``net`` on ``train_set_loader`` with ``optimizer``.

    Args:
        epochs: number of full passes over the training loader.
        log_every: print the current batch loss every ``log_every`` batches;
            use 0 (or a negative value) to disable per-batch logging.

    Reads the globals ``net``, ``optimizer``, ``train_set_loader`` and
    ``device``. After each epoch it prints the mean per-batch loss. Returns None.
    """
    print('=' * 10 + ' TRAIN ' + '=' * 10, end='\n\n')
    net.train()
    num_batches = len(train_set_loader)

    for epoch in range(1, epochs + 1):
        running_loss = 0.0

        for i, (images, _) in enumerate(train_set_loader, 1):
            images = images.to(device)

            optimizer.zero_grad()
            encoded, quantized, recon_x = net(images)
            loss_value = net.loss_function(images, recon_x, encoded, quantized)
            loss_value.backward()
            optimizer.step()

            # Read the scalar once; it is used both for the average and the log line.
            batch_loss = loss_value.item()
            running_loss += batch_loss

            if log_every > 0 and i % log_every == 0:
                print(f'==> EPOCH[{epoch}]({i}/{num_batches}): LOSS: {batch_loss}')

        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        avg_loss = running_loss / max(num_batches, 1)
        print(f'=====> EPOCH[{epoch}] Completed: Avg. LOSS: {avg_loss}')
        print()


In [14]:
# Run the training loop for its default 10 epochs.
train(epochs=10)


==> EPOCH[1](100/469): LOSS: 0.32571345567703247
==> EPOCH[1](200/469): LOSS: 0.3316141366958618
==> EPOCH[1](300/469): LOSS: 0.6385176181793213
==> EPOCH[1](400/469): LOSS: 0.3401660919189453
=====> EPOCH[1] Completed: Avg. LOSS: 0.5027683224759376

==> EPOCH[2](100/469): LOSS: 0.693558931350708
==> EPOCH[2](200/469): LOSS: 0.3210856020450592
==> EPOCH[2](300/469): LOSS: 0.6305079460144043
==> EPOCH[2](400/469): LOSS: 0.7029139399528503
=====> EPOCH[2] Completed: Avg. LOSS: 0.5528024963732722

==> EPOCH[3](100/469): LOSS: 0.6502506136894226


Traceback (most recent call last):
  File "/usr/local/Cellar/python/3.7.6_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/local/Cellar/python/3.7.6_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/Cellar/python/3.7.6_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/local/Cellar/python/3.7.6_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


KeyboardInterrupt: 

In [None]:
# Display the quantizer's current number of embeddings.
# NOTE(review): QVAE is built with `init_num_embeddings` and `min_cluster_size`,
# which suggests the codebook size may change during training -- confirm in
# `neural.QVAE`.
net.quantizer.num_embeddings

In [None]:
# Switch the network to inference mode; Module.eval() is train(False).
net.train(mode=False)

In [None]:
# Reconstruct one fixed test image. This is inference only, so no autograd graph
# is needed.
image, label = test_set_loader.dataset[420]
with torch.no_grad():
    # The model lives on `device`, so the (batched) input must be moved there too.
    encoded, quantized, recon = net(image.unsqueeze(0).to(device))

print(label)
plt.imshow(image[0], cmap='Greys')
plt.title(f'Original (label {label})');

In [None]:
# Take the first item of the batch, drop singleton dims, and move it to the CPU
# so matplotlib can render it. A new name keeps `recon` intact, so this cell
# can be re-run without corrupting its own input.
recon_image = recon[0].squeeze().detach().cpu().numpy()
plt.imshow(recon_image, cmap='Greys')
plt.title('Reconstruction');

In [None]:
# Persist the trained weights. Passing a path lets torch open, flush and close
# the file, rather than leaving an unclosed write handle.
torch.save(net.state_dict(), 'state_dict.pth')