In [1]:
import numpy as np
%matplotlib inline

import torch
import torch.optim as optim

from dataset import get_set_loaders, imshow
from capsulenet import CapsNet

In [2]:
torch.set_num_threads(12)

In [3]:
train_set_loader, test_set_loader = get_set_loaders()

# Train

### Init model

In [4]:
USE_CUDA = True
EPOCHS = 100

if USE_CUDA and torch.cuda.is_available():
    print('USING CUDA')
    device = torch.device('cuda')
else:
    print('USING CPU')
    device = torch.device('cpu')

capsnet = CapsNet(cuda=USE_CUDA)
try:
    capsnet.load_state_dict(
        torch.load(open('./models/capsnet_state.pth', 'rb'), map_location=device)
    )
    print('State dict loaded from \'./models/capsnet_state.pth\'')
except:
    pass

capsnet = capsnet.to(device)


USING CPU
State dict loaded from './models/capsnet_state.pth'


### Optimizer and LR Scheduler

In [5]:
optimizer = optim.Adam(capsnet.parameters(), lr=0.01*(0.90**35))
exponential_lr = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.90)

### Train loop

In [6]:
for epoch in range(1, EPOCHS+1):
    
    print()
    print('='*10, 'TRAIN', '='*10)
    capsnet.train()
    running_loss = 0.0
    running_accuracy = 0.0
    
    for i, batch in enumerate(train_set_loader, 1):        
        # Load the batch
        images, targets = batch
        images = images.to(device)
        targets = targets.to(device)
        
        # Zero grad
        optimizer.zero_grad()
        # Forward
        output, norm, reconstruction = capsnet(images, targets)
        # Compute loss
        loss = CapsNet.loss(norm, targets, reconstruction, images)
        # Compute accuracy
        most_active_idx = norm.argmax(dim=1)
        accuracy = torch.sum((most_active_idx == targets)).item() / targets.size(0)
        # Backward
        loss.backward()
        # Optim step
        optimizer.step()
        
        # Add loss
        running_loss += loss.item()
        # Add accuracy
        running_accuracy += accuracy
        
        if i % 100 == 0:
            print(f'==> EPOCH[{epoch}]({i}/{len(train_set_loader)}): LOSS: {loss.item()} ACCURACY: {accuracy}')
            
    print(f'=====> EPOCH[{epoch}] Completed: Avg. LOSS: {running_loss/len(train_set_loader)} Avg. ACCURACY {running_accuracy/len(train_set_loader)}')
    
    print()
    print('='*10, 'EVAL', '='*10)
    capsnet.eval()
    running_accuracy = 0.0
    
    for i, batch in enumerate(test_set_loader, 1):

        # Load the batch
        images, targets = batch
        images = images.to(device)
        targets = targets.to(device)
        
        with torch.no_grad():
            # Forward
            output, norm, reconstruction = capsnet(images)
            most_active_idx = norm.argmax(dim=1)
            running_accuracy += torch.sum((most_active_idx == targets)).item() / targets.size(0)

    print(f'=====> EPOCH[{epoch}]: AVG. ACCURACY: {running_accuracy/len(test_set_loader)}')
    
    # Reduce lr
    exponential_lr.step()


==> EPOCH[1](100/469): LOSS: 0.00034670671448111534 ACCURACY: 1.0
==> EPOCH[1](200/469): LOSS: 2.797026354528498e-05 ACCURACY: 1.0
==> EPOCH[1](300/469): LOSS: 5.710797995561734e-05 ACCURACY: 1.0
==> EPOCH[1](400/469): LOSS: 3.8094229239504784e-05 ACCURACY: 1.0
=====> EPOCH[1] Completed: Avg. LOSS: 4.545670733676661e-05 Avg. ACCURACY 1.0

=====> EPOCH[1]: AVG. ACCURACY: 0.99521484375

==> EPOCH[2](100/469): LOSS: 4.553144390229136e-05 ACCURACY: 1.0
==> EPOCH[2](200/469): LOSS: 9.55496943788603e-05 ACCURACY: 1.0
==> EPOCH[2](300/469): LOSS: 4.4865431846119463e-05 ACCURACY: 1.0
==> EPOCH[2](400/469): LOSS: 3.21540683216881e-05 ACCURACY: 1.0
=====> EPOCH[2] Completed: Avg. LOSS: 4.857826933924626e-05 Avg. ACCURACY 1.0

=====> EPOCH[2]: AVG. ACCURACY: 0.99521484375

==> EPOCH[3](100/469): LOSS: 2.4436070816591382e-05 ACCURACY: 1.0
==> EPOCH[3](200/469): LOSS: 5.281090852804482e-05 ACCURACY: 1.0
==> EPOCH[3](300/469): LOSS: 2.6276162316207774e-05 ACCURACY: 1.0
==> EPOCH[3](400/469): LOSS: 

Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/Cellar/python/3.7.4_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/local/Cellar/python/3.7.4_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/Cellar/python/3.7.4_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/local/Cellar/python/3.7.4_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
  File "/usr/local/Cellar/python/3.7.4_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  Fi

KeyboardInterrupt: 