#### Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
sys.path.append("..")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

In [3]:
import vnn
import vec_models
import nonvec_models
import init_methods
import dfa_util

#### Test for CUDA

In [4]:
# Pick the compute device once; later cells move tensors and models to `dev`.
train_on_gpu = torch.cuda.is_available()

if train_on_gpu:
    print('GPU found, training on GPU')
    dev = torch.device('cuda')
else:
    print('No GPU, training on CPU')
    dev = torch.device('cpu')

No GPU, training on CPU


#### Load MNIST

In [5]:
def load_mnist(batch_size=128, shuffle_train=True, data_dir="../data"):
    """Build train and test ``DataLoader`` objects for MNIST.

    Images are converted with ``ToTensor`` and then normalised with mean 0.5
    and std 0.5. The dataset is downloaded into ``data_dir`` if it is not
    already there.

    Args:
        batch_size: number of examples per batch for both loaders.
        shuffle_train: whether the training loader shuffles its data.
            The test loader never shuffles.
        data_dir: root directory where the MNIST files are stored or
            downloaded to. Defaults to ``"../data"``, the location that was
            previously hard-coded.

    Returns:
        ``(train_loader, test_loader)``.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ])
    train_set = torchvision.datasets.MNIST(data_dir, train=True, download=True, transform=transform)
    test_set = torchvision.datasets.MNIST(data_dir, train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=shuffle_train)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [6]:
# Build the MNIST loaders using load_mnist's defaults (batch_size=128, shuffled training set).
mnist_train_loader, mnist_test_loader = load_mnist()

In [7]:
# Run through the whole training loader once without doing any work. The only
# effect is that image_batch / label_batch stay bound to the LAST batch yielded,
# which the forward-pass cell further down uses as a sample input.
# NOTE(review): this iterates the full dataset just to obtain one batch.
for image_batch, label_batch in mnist_train_loader:
    pass

#### Define architecture

In [8]:
# Widths (rnn_dim) of the first and second vnn.Recurrent layers built by make_mnist_recurrent.
n_latent_1 = 100
n_latent_2 = 50

In [9]:
# Number of target classes; also the default n_steps (category_dim) in make_mnist_recurrent.
n_classes = 10

In [10]:
def make_mnist_recurrent(n_steps=n_classes, input_size=784, mono=True):
    """Assemble a three-layer stack of ``vnn.Recurrent`` modules.

    Layer widths run ``input_size -> n_latent_1 -> n_latent_2 -> 1``, with the
    two hidden widths read from the notebook-level constants. Only the first
    layer is flagged ``first_rec_layer`` and only the last is flagged
    ``last_layer``; ``first_layer`` is False for all three.

    Args:
        n_steps: passed to every layer as ``category_dim``.
        input_size: input width of the first layer.
        mono: passed unchanged to every layer.

    Returns:
        An ``nn.Sequential`` holding the three layers in order.
    """
    # One row per layer: (input_size, rnn_dim, first_rec_layer, last_layer).
    layer_specs = [
        (input_size, n_latent_1, True, False),
        (n_latent_1, n_latent_2, False, False),
        (n_latent_2, 1, False, True),
    ]
    layers = [
        vnn.Recurrent(category_dim=n_steps, input_size=in_dim, rnn_dim=out_dim,
                      first_layer=False, first_rec_layer=is_first_rec,
                      last_layer=is_last, mono=mono)
        for in_dim, out_dim, is_first_rec, is_last in layer_specs
    ]
    return nn.Sequential(*layers)

In [11]:
# Instantiate the recurrent stack with default arguments and display its module summary.
recc_mnist = make_mnist_recurrent()
recc_mnist

Sequential(
  (0): Recurrent()
  (1): Recurrent()
  (2): Recurrent()
)

In [12]:
# recc_mnist._modules['1']._parameters['weight_hh'].shape

In [13]:
# Flatten every example in the batch to a single vector: (batch, ...) -> (batch, features).
image_batch = image_batch.view(image_batch.shape[0], -1)
# Put the model on the same device as its input. nn.Module.to() moves the
# parameters and returns the module itself. Without this the forward pass fails
# on a CUDA machine, because the input is on the GPU while the weights are still
# on the CPU. train_model does the same thing before training.
recc_mnist = recc_mnist.to(dev)
output = recc_mnist(image_batch.to(dev))

Projected shape torch.Size([96, 100])
Inter shape torch.Size([96, 100])
Inter done
Mask shape torch.Size([1, 100])
Op shape torch.Size([10, 100])
Timestep op calculated 0
Inter shape torch.Size([96, 100])
Inter done
Mask shape torch.Size([1, 100])
Op shape torch.Size([10, 100])
Timestep op calculated 1
Inter shape torch.Size([96, 100])
Inter done
Mask shape torch.Size([1, 100])
Op shape torch.Size([10, 100])
Timestep op calculated 2
Inter shape torch.Size([96, 100])
Inter done
Mask shape torch.Size([1, 100])
Op shape torch.Size([10, 100])
Timestep op calculated 3
Inter shape torch.Size([96, 100])
Inter done
Mask shape torch.Size([1, 100])
Op shape torch.Size([10, 100])
Timestep op calculated 4
Inter shape torch.Size([96, 100])
Inter done
Mask shape torch.Size([1, 100])
Op shape torch.Size([10, 100])
Timestep op calculated 5
Inter shape torch.Size([96, 100])
Inter done
Mask shape torch.Size([1, 100])
Op shape torch.Size([10, 100])
Timestep op calculated 6
Inter shape torch.Size([96, 100

In [15]:
# Inspect the shape of the forward-pass result computed above.
output.shape

torch.Size([96, 10, 1])

In [None]:
# recc_mnist._modules['0']._parameters['weight_ih'].shape

In [None]:
# recc_mnist._modules['1']._parameters['weight_ih'].shape

In [None]:
# Display the 'weight_ih' parameter of the first layer (module '0'), read
# directly from the module's private _modules / _parameters registries.
recc_mnist._modules['0']._parameters['weight_ih']

In [None]:
### Observation: weight_ih in layer 1 also appears to be non-negative.
### Observation: weight_ih and weight_hh in layer 2 appear to be all zeros.

#### Train recurrent model using GEB

In [None]:
def train_epoch(model, opt, train_loader, flatten, vectorized, learning_rule, device):
    """Run one training pass over ``train_loader`` and print loss and accuracy.

    Args:
        model: network to train. On the vectorized path its output is indexed
            with ``[..., 0]`` before the loss is computed.
        opt: optimizer; ``zero_grad()`` and ``step()`` are called once per batch.
        train_loader: iterable yielding ``(data, labels)`` batches.
        flatten: forwarded to ``format_input``.
        vectorized: forwarded to ``format_input``; also selects between the
            vectorized path (gradients set by ``vnn.set_model_grads``) and the
            non-vectorized path (gradients from ``loss.backward()``).
        learning_rule: forwarded to ``vnn.set_model_grads`` (vectorized) or to
            the model's forward call (non-vectorized).
        device: device the inputs and labels are moved to.

    Returns:
        None. The mean of the per-batch losses and the overall accuracy are
        printed. If ``train_loader`` yields no batches, a notice is printed
        instead (previously this case raised ``UnboundLocalError`` on
        ``batch_idx``).
    """
    loss_fn = nn.CrossEntropyLoss(reduction="mean")
    avg_loss_sum = 0.  # sum of batch-averaged loss values
    num_correct = 0    # running count of correct predictions
    num_examples = 0   # running count of examples seen
    num_batches = 0    # counted explicitly so an empty loader is detectable
    for data, labels in train_loader:
        # Named batch_input rather than `input` to avoid shadowing the builtin.
        batch_input = format_input(data, flatten, vectorized)
        opt.zero_grad()
        if vectorized:
            # Vectorized BP or DF: the forward pass runs without autograd;
            # gradients are written onto the model by vnn.set_model_grads.
            with torch.no_grad():
                output = model(batch_input.to(device))[..., 0]
            vnn.set_model_grads(model, output, labels, learning_rule=learning_rule, reduction="mean")
            loss = loss_fn(output, labels.to(device))
        else:
            # Unvectorized BP or DF: ordinary autograd backward pass.
            output = model(batch_input.to(device), learning_rule=learning_rule)
            loss = loss_fn(output, labels.to(device))
            loss.backward()
        opt.step()
        if vectorized:
            vnn.post_step_callback(model)
        else:
            dfa_util.post_step_callback(model)
        avg_loss_sum += loss.item()
        num_correct += (output.detach().argmax(dim=1).cpu() == labels).int().sum().item()
        num_examples += len(data)
        num_batches += 1
    if num_batches == 0:
        # Nothing was trained on; there are no statistics to average.
        print("train_loader yielded no batches; skipping epoch")
        return
    epoch_loss = avg_loss_sum / num_batches
    epoch_accuracy = num_correct / num_examples
    print("loss: {}, accuracy: {}".format(epoch_loss, epoch_accuracy))

In [None]:
def train_model(snapshot_dir, model, train_loader, test_loader, eval_iter, lr, num_epochs,
    flatten, vectorized, learning_rule, device):
    """Train ``model`` with Adam for up to ``num_epochs`` epochs, snapshotting as it goes.

    The model is moved to ``device`` and an Adam optimizer with learning rate
    ``lr`` is created. ``restart_from_snapshot`` then supplies the starting
    epoch plus two flags. NOTE(review): judging by how they are used here,
    those flags mean "state was just restored" and "training already
    finished"; the helper is not visible in this notebook, so confirm.

    At the start of every epoch whose index is a multiple of ``eval_iter``,
    train and test sets are evaluated with ``eval_accuracy`` and a snapshot is
    saved. This is skipped on the first iteration after a restart. Training
    stops early once train accuracy equals 1.0.

    Returns:
        None.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    start_epoch, resumed, finished = restart_from_snapshot(snapshot_dir, model, optimizer)

    if finished or start_epoch >= num_epochs:
        print("Loaded model already done training")
        return

    for epoch in range(start_epoch, num_epochs):
        if epoch % eval_iter == 0 and not resumed:
            train_accuracy, train_loss = eval_accuracy(
                model, train_loader, flatten, vectorized, device)
            test_accuracy, test_loss = eval_accuracy(
                model, test_loader, flatten, vectorized, device)
            save_snapshot(
                snapshot_dir, model, optimizer, epoch,
                train_loss, train_accuracy, test_loss, test_accuracy,
                flatten, vectorized, learning_rule)
            if train_accuracy == 1.0:
                print("Perfect train accuracy achieved, ending training at epoch {}".format(epoch))
                break
        # Only the first iteration after a restart skips evaluation.
        resumed = False
        train_epoch(model, optimizer, train_loader, flatten, vectorized, learning_rule, device)

In [None]:
# Settings for the training run; passed positionally to train_model in the next cell.
snap_dir = '../mnist_recc_trials_v1/'   # snapshot directory
# Build a fresh model. recc_mnist is already an nn.Sequential instance at this
# point, so the previous `recc_mnist()` called forward() without an input and
# raised a TypeError.
recc_mnist = make_mnist_recurrent()
mnist_train_loader, mnist_test_loader = load_mnist()
eval_iter = 1        # evaluate and snapshot every `eval_iter` epochs
eta = 0.001          # learning rate handed to Adam
n_epochs = 10        # maximum number of epochs
flatten_bool = True  # train_model's `flatten` argument
vec_bool = True      # train_model's `vectorized` argument
rule="df"            # train_model's `learning_rule` argument

In [None]:
# Train the recurrent model. The call previously passed `vnn_fc_mnist`, which is
# not defined anywhere in the visible notebook and would raise a NameError.
train_model(snap_dir, recc_mnist, mnist_train_loader, mnist_test_loader, eval_iter, eta, n_epochs, flatten_bool, vec_bool, rule, dev)

In [None]:
# Scratch: a small 1x6 matrix of uniform random values standing in for a layer weight.
weight = torch.rand(1,6)
# Unpack the two dimensions as (out_features, in_features).
out_features, in_features = weight.shape

In [None]:
# Zero the matrix in place, then draw a Gaussian block W of shape
# (max(out_features//2, 1), in_features//2), scaled by 1/sqrt(0.25 * in_features).
weight[:] = 0.
W = torch.randn(max(out_features//2, 1), in_features//2, device=weight.device) / np.sqrt(0.25 * in_features)

# Disabled: would write relu(W) / relu(-W) into the even/odd rows and columns of weight.
# weight[::2, ::2] = F.relu(W)
# weight[::2, 1::2] = F.relu(-W)

# if out_features > 1:
#     weight[1::2, ::2] = F.relu(-W)
#     weight[1::2, 1::2] = F.relu(W)

In [None]:
# Display weight (zeroed in place by the previous cell).
weight

In [None]:
# Display the scaled Gaussian block W.
W