In [1]:
import sys

# Put the parent directory at the front of the module search path
# (presumably where the project-local `dataset` and `capsule` packages
# imported later live — confirm). Any existing entry is dropped first so that
# re-running this cell keeps exactly one copy instead of growing sys.path.
PARENT_DIR = '../'
sys.path[:] = [entry for entry in sys.path if entry != PARENT_DIR]
sys.path.insert(0, PARENT_DIR)


In [2]:
import numpy as np
%matplotlib inline

import torch
import torch.optim as optim

from dataset import get_set_loaders, imshow
from capsule.net import CapsNet

In [3]:
# Fix the number of threads PyTorch uses for intra-op parallelism on CPU.
NUM_THREADS = 12
torch.set_num_threads(NUM_THREADS)

In [4]:
# Build the loaders once, then unpack the (train, test) pair.
set_loaders = get_set_loaders()
train_set_loader, test_set_loader = set_loaders

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../datasets/MNIST/raw/train-images-idx3-ubyte.gz


100.1%

Extracting ../datasets/MNIST/raw/train-images-idx3-ubyte.gz to ../datasets/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../datasets/MNIST/raw/train-labels-idx1-ubyte.gz


113.5%

Extracting ../datasets/MNIST/raw/train-labels-idx1-ubyte.gz to ../datasets/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../datasets/MNIST/raw/t10k-images-idx3-ubyte.gz


100.4%

Extracting ../datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to ../datasets/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz


180.4%

Extracting ../datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../datasets/MNIST/raw
Processing...
Done!


# Train

### Init model

In [5]:
# --- Run configuration ---
USE_CUDA = True           # request CUDA; falls back to CPU when it is unavailable
LOAD_PRE_TRAINED = False  # when True, try to restore weights from STATE_DICT_PATH
EPOCHS = 100              # number of epochs run by the training loop

STATE_DICT_PATH = './models/capsnet_state.pth'

# Resolve the device that is actually usable, not merely the one requested.
if USE_CUDA and torch.cuda.is_available():
    print('USING CUDA')
    device = torch.device('cuda')
else:
    print('USING CPU')
    device = torch.device('cpu')

# Pass the *resolved* device to the model rather than the raw USE_CUDA request,
# so the model is not told to use CUDA after we have fallen back to CPU.
# NOTE(review): assumes the `cuda` flag controls where CapsNet creates its
# internal tensors — confirm against capsule.net.
capsnet = CapsNet(cuda=(device.type == 'cuda'))

if LOAD_PRE_TRAINED:
    # Best-effort restore: a missing or incompatible checkpoint must not abort
    # the notebook, but the reason is reported instead of silently discarded.
    # `except Exception` (not a bare except) lets KeyboardInterrupt/SystemExit
    # propagate.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    try:
        with open(STATE_DICT_PATH, 'rb') as state_file:
            capsnet.load_state_dict(
                torch.load(state_file, map_location=device)
            )
        print(f'State dict loaded from \'{STATE_DICT_PATH}\'')
    except Exception as err:
        print(f'Could not load state dict from \'{STATE_DICT_PATH}\' ({err!r}); '
              'continuing with freshly initialised weights')

capsnet = capsnet.to(device)


USING CPU


### Optimizer and LR Scheduler

In [6]:
# Adam over all CapsNet parameters, with a scheduler that multiplies the
# learning rate by LR_DECAY_GAMMA each time `exponential_lr.step()` is called.
LEARNING_RATE = 0.01
LR_DECAY_GAMMA = 0.90

optimizer = optim.Adam(capsnet.parameters(), lr=LEARNING_RATE)
exponential_lr = optim.lr_scheduler.ExponentialLR(
    optimizer,
    gamma=LR_DECAY_GAMMA,
)

### Train loop

In [7]:
# One iteration per epoch: a full training pass, a full evaluation pass, then
# one learning-rate decay step.
#
# Epoch accuracies are computed as (correct predictions) / (samples seen)
# rather than as the mean of per-batch accuracies, so a smaller final batch is
# not over-weighted in the reported figure.
for epoch in range(1, EPOCHS+1):

    # ------------------------------ TRAIN ------------------------------
    print()
    print('='*10, 'TRAIN', '='*10)
    capsnet.train()
    running_loss = 0.0
    train_correct = 0   # correctly classified training samples this epoch
    train_seen = 0      # training samples processed this epoch

    for i, batch in enumerate(train_set_loader, 1):
        # Load the batch
        images, targets = batch
        images = images.to(device)
        targets = targets.to(device)

        # Zero grad
        optimizer.zero_grad()
        # Forward (targets are passed to the model only in training mode)
        output, norm, reconstruction = capsnet(images, targets)
        # Compute loss
        loss = CapsNet.loss(norm, targets, reconstruction, images)
        # Compute accuracy: the predicted class is the index with the largest norm
        most_active_idx = norm.argmax(dim=1)
        batch_correct = torch.sum((most_active_idx == targets)).item()
        accuracy = batch_correct / targets.size(0)
        # Backward
        loss.backward()
        # Optim step
        optimizer.step()

        # Accumulate epoch statistics
        running_loss += loss.item()
        train_correct += batch_correct
        train_seen += targets.size(0)

        if i % 100 == 0:
            print(f'==> EPOCH[{epoch}]({i}/{len(train_set_loader)}): LOSS: {loss.item()} ACCURACY: {accuracy}')

    print(f'=====> EPOCH[{epoch}] Completed: Avg. LOSS: {running_loss/len(train_set_loader)} Avg. ACCURACY {train_correct/train_seen}')

    # ------------------------------ EVAL -------------------------------
    print()
    print('='*10, 'EVAL', '='*10)
    capsnet.eval()
    test_correct = 0    # correctly classified test samples this epoch
    test_seen = 0       # test samples processed this epoch

    for i, batch in enumerate(test_set_loader, 1):

        # Load the batch
        images, targets = batch
        images = images.to(device)
        targets = targets.to(device)

        # No gradients are needed for evaluation
        with torch.no_grad():
            # Forward (no targets are given to the model here)
            output, norm, reconstruction = capsnet(images)
            most_active_idx = norm.argmax(dim=1)
            test_correct += torch.sum((most_active_idx == targets)).item()
            test_seen += targets.size(0)

    print(f'=====> EPOCH[{epoch}]: AVG. ACCURACY: {test_correct/test_seen}')

    # Reduce lr
    exponential_lr.step()




Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/Cellar/python/3.7.5/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/local/Cellar/python/3.7.5/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/Cellar/python/3.7.5/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/local/Cellar/python/3.7.5/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
  File "/usr/local/Cellar/python/3.7.5/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/local/Cellar/python/3.7.5/Frameworks/Pyth

KeyboardInterrupt: 