#### Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
sys.path.append("..")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

In [3]:
import vnn
import vec_models
import nonvec_models
import init_methods
import dfa_util

#### Test for CUDA

In [4]:
# Select the compute device once; `dev` is what later cells hand to train_model.
train_on_gpu = torch.cuda.is_available()
dev = torch.device('cuda' if train_on_gpu else 'cpu')
print('GPU found, training on GPU' if train_on_gpu else 'No GPU, training on CPU')

No GPU, training on CPU


#### Load MNIST 

In [5]:
def load_mnist(batch_size=128, shuffle_train=True):
    """Build DataLoaders for the MNIST train and test splits.

    Both splits are read from "../data" (requested with download=True), converted
    to tensors and normalized with mean 0.5 and std 0.5.

    Args:
        batch_size: batch size used by both loaders.
        shuffle_train: whether the training loader shuffles; the test loader
            never shuffles.

    Returns:
        (train_loader, test_loader)
    """
    tv_transforms = torchvision.transforms
    preprocess = tv_transforms.Compose([
        tv_transforms.ToTensor(),
        tv_transforms.Normalize((0.5,), (0.5,)),
    ])
    # Train split first, then test split (same construction order as before).
    datasets = [
        torchvision.datasets.MNIST("../data", train=is_train, download=True, transform=preprocess)
        for is_train in (True, False)
    ]
    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(datasets[0], batch_size=batch_size, shuffle=shuffle_train)
    test_loader = make_loader(datasets[1], batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

#### Define architecture

In [6]:
def make_mnist_vec_fc(mono=True):
    """Assemble the fully connected vectorized network used for MNIST.

    Layer stack: vnn.Linear -> vnn.tReLU -> vnn.Linear -> vnn.tReLU -> vnn.Linear,
    with hidden widths 1024 and 512 and a final width of 1.

    Args:
        mono: forwarded unchanged to every vnn.Linear layer.

    Returns:
        nn.Sequential containing the five layers in order.
    """
    # Leading argument shared by every vnn layer. Presumably the vector
    # dimension (one slot per MNIST class) -- TODO confirm against vnn.
    k = 10
    n_inputs = 28 * 28 * k
    layers = [
        # Original author's note: the 1st layer doesn't need weight sharing.
        vnn.Linear(k, n_inputs, 1024, first_layer=True, mono=mono),
        vnn.tReLU(k, 1024),
        vnn.Linear(k, 1024, 512, mono=mono),
        vnn.tReLU(k, 512),
        vnn.Linear(k, 512, 1, mono=mono),
    ]
    return nn.Sequential(*layers)

#### Train model using GEVB

In [7]:
def format_input(data, flatten, vectorized):
    """Prepare a raw batch for the model.

    Args:
        data: batch tensor coming from the data loader.
        flatten: if True, collapse every dimension after the first with
            torch.flatten(data, 1).
        vectorized: if True, additionally pass the batch through the vnn
            expansion helper (vnn.expand_input for flattened input,
            vnn.expand_input_conv otherwise), with 10 as its second argument.

    Returns:
        The formatted batch. When both flags are False, `data` itself is returned.
    """
    if flatten:
        flat = torch.flatten(data, 1)
        return vnn.expand_input(flat, 10) if vectorized else flat
    return vnn.expand_input_conv(data, 10) if vectorized else data

In [8]:
def make_dir(dir_name):
    """Create `dir_name` (including missing parents) unless it already exists.

    `exist_ok=True` replaces the previous check-then-create sequence, so there
    is no window in which the directory can appear between the existence check
    and the `os.makedirs` call and make the latter raise.

    Args:
        dir_name: path of the directory to create.

    Raises:
        FileExistsError: if `dir_name` exists but is not a directory
            (previously this case was silently accepted).
    """
    os.makedirs(dir_name, exist_ok=True)

In [9]:
def files_in_dir(dir_name):
    """Return the snapshot files in `dir_name`, ordered by ascending epoch.

    The epoch is parsed from names shaped like "<prefix>_<epoch>.<ext>", e.g.
    the "epoch_12.pt" files written by save_snapshot. Regular files whose name
    does not parse that way (".DS_Store", notes, ...) are skipped instead of
    aborting the scan with IndexError/ValueError. Sub-directories are ignored.

    Args:
        dir_name: directory to scan (not recursive).

    Returns:
        List of paths (`dir_name` joined with each file name), sorted
        numerically by epoch so that "epoch_10.pt" comes after "epoch_2.pt".
        Empty when nothing matches.
    """
    by_epoch = []
    for name in os.listdir(dir_name):
        path = os.path.join(dir_name, name)
        if not os.path.isfile(path):
            continue
        try:
            epoch = int(name.split("_")[1].split(".")[0])
        except (IndexError, ValueError):
            # Not a snapshot file name; leave it out rather than crash.
            continue
        by_epoch.append((epoch, path))
    # Sort on the epoch only; list.sort is stable, so ties keep listing order.
    by_epoch.sort(key=lambda pair: pair[0])
    return [path for _, path in by_epoch]

In [10]:
def eval_accuracy(model, loader, flatten, vectorized, device):
    """Measure accuracy and mean cross-entropy loss of `model` over `loader`.

    Args:
        model: network to evaluate; the forward pass runs under torch.no_grad().
        loader: iterable yielding (data, labels) batches.
        flatten: forwarded to format_input.
        vectorized: forwarded to format_input; when True, index 0 of the
            output's last axis is used as the class scores.
        device: device the inputs and labels are moved to.

    Returns:
        (accuracy, loss): the fraction of correctly classified examples, and
        the summed cross-entropy divided by the number of examples.
    """
    # Sum (not mean) per batch; the division by the example count happens once at the end.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    total_loss = 0.
    n_correct = 0
    n_seen = 0
    for data, labels in loader:
        batch = format_input(data, flatten, vectorized).to(device)
        with torch.no_grad():
            scores = model(batch)
        if vectorized:
            scores = scores[..., 0]
        total_loss += criterion(scores, labels.to(device)).item()
        predicted = scores.argmax(dim=1).cpu()
        n_correct += (predicted == labels).int().sum().item()
        n_seen += len(data)
    return n_correct / n_seen, total_loss / n_seen

In [11]:
def save_snapshot(snapshot_dir, model, opt, epoch, train_loss, train_accuracy, test_loss, test_accuracy,
    flatten, vectorized, learning_rule):
    """Write a checkpoint to "<snapshot_dir>/epoch_<epoch>.pt" and report it.

    The stored dict holds the epoch, the model and optimizer state dicts, the
    four train/test metrics, the flatten/vectorized/learning_rule settings and
    the device of the model's first parameter.

    Args:
        snapshot_dir: directory that receives the file.
        model: module whose state_dict is stored.
        opt: optimizer whose state_dict is stored.
        epoch: epoch number; also used in the file name.
        train_loss, train_accuracy, test_loss, test_accuracy: metrics stored
            as given; the two accuracies are also printed.
        flatten, vectorized, learning_rule: run settings stored as given.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
    }
    checkpoint.update({
        'train_loss': train_loss,
        'train_accuracy': train_accuracy,
        'test_loss': test_loss,
        'test_accuracy': test_accuracy,
    })
    checkpoint.update({
        'flatten': flatten,
        'vectorized': vectorized,
        'learning_rule': learning_rule,
        'device': list(model.parameters())[0].device,
    })
    target = "{}/epoch_{}.pt".format(snapshot_dir, epoch)
    torch.save(checkpoint, target)
    print("saved snapshot at epoch {}".format(epoch))
    print("train/test accuracy: {}/{}".format(train_accuracy, test_accuracy))

In [12]:
def restart_from_snapshot(snapshot_dir, model, opt):
    """Resume from the most recent snapshot in `snapshot_dir`, if one exists.

    The directory is created when missing. The snapshot with the highest epoch
    (last entry of files_in_dir) is loaded; its model and optimizer states are
    restored into `model` and `opt` unless it already recorded a train
    accuracy of exactly 1.0.

    Args:
        snapshot_dir: directory holding the "epoch_<n>.pt" files.
        model: module to restore in place.
        opt: optimizer to restore in place.

    Returns:
        (epoch, restarted, done_training):
            epoch: epoch stored in the snapshot, or 0 when none was found.
            restarted: True when a snapshot file was found.
            done_training: True when the snapshot's train accuracy is 1.0;
                in that case `model` and `opt` are left untouched.
    """
    make_dir(snapshot_dir)
    filenames = files_in_dir(snapshot_dir)
    if not filenames:
        return 0, False, False

    # Deserialize onto the CPU. Without map_location, a snapshot written on a
    # GPU machine cannot be loaded where CUDA is unavailable. The
    # load_state_dict calls below copy the values into the already-placed
    # parameters / optimizer state, so the model's device is unaffected.
    snapshot = torch.load(filenames[-1], map_location="cpu")
    done_training = snapshot['train_accuracy'] == 1.
    if not done_training:
        model.load_state_dict(snapshot['model_state_dict'])
        opt.load_state_dict(snapshot['optimizer_state_dict'])
    epoch = snapshot['epoch']
    print("loaded model from snapshot at epoch {}".format(epoch))
    return epoch, True, done_training

In [13]:
def train_epoch(model, opt, train_loader, flatten, vectorized, learning_rule, device):
    """Run one pass over `train_loader`, updating `model` through `opt`.

    For every batch: zero the gradients, fill them in, step the optimizer,
    then run the post-step callback (vnn's when `vectorized`, dfa_util's
    otherwise). Gradients come from vnn.set_model_grads in the vectorized case
    and from loss.backward() in the unvectorized case.

    Prints "loss: <mean of the per-batch mean losses>, accuracy: <correct / total>".
    Returns None.

    Args:
        model: network being trained.
        opt: optimizer stepping the model's parameters.
        train_loader: iterable yielding (data, labels) batches.
        flatten, vectorized: forwarded to format_input; `vectorized` also
            selects the training path.
        learning_rule: forwarded to vnn.set_model_grads (vectorized) or to the
            model's forward call (unvectorized).
        device: device the inputs and labels are moved to.
    """
    criterion = nn.CrossEntropyLoss(reduction="mean")
    running_loss = 0.  # sum of batch-mean losses
    n_correct = 0      # correctly classified examples so far
    n_seen = 0         # examples processed so far
    for n_batches, (data, labels) in enumerate(train_loader, start=1):
        batch = format_input(data, flatten, vectorized)
        opt.zero_grad()
        if not vectorized:
            # Unvectorized BP or DF: gradients come from autograd.
            output = model(batch.to(device), learning_rule=learning_rule)
            loss = criterion(output, labels.to(device))
            loss.backward()
        else:
            # Vectorized BP or DF: gradients are written by vnn.set_model_grads.
            # The no_grad forward makes no difference, but shows that autograd
            # is not involved on this path.
            with torch.no_grad():
                output = model(batch.to(device))[..., 0]
            vnn.set_model_grads(model, output, labels, learning_rule=learning_rule, reduction="mean")
            loss = criterion(output, labels.to(device))
        opt.step()
        (vnn if vectorized else dfa_util).post_step_callback(model)
        running_loss += loss.item()
        predicted = output.detach().argmax(dim=1).cpu()
        n_correct += (predicted == labels).int().sum().item()
        n_seen += len(data)
    epoch_loss = running_loss / n_batches
    epoch_accuracy = n_correct / n_seen
    print("loss: {}, accuracy: {}".format(epoch_loss, epoch_accuracy))

In [14]:
def train_model(snapshot_dir, model, train_loader, test_loader, eval_iter, lr, num_epochs,
    flatten, vectorized, learning_rule, device):
    """Train `model` with Adam for up to `num_epochs` epochs, snapshotting as it goes.

    Training resumes from the newest snapshot in `snapshot_dir` when there is
    one. A snapshot labelled e holds the state after e completed epochs
    (snapshot 0 is the untrained model). Before each epoch whose index is a
    multiple of `eval_iter`, the model is evaluated on both loaders and a
    snapshot is written. Training stops early once train accuracy reaches
    exactly 1.0.

    When the loop runs to completion, one more evaluation + snapshot labelled
    `num_epochs` is written. Previously the result of the last epoch was never
    evaluated or saved, so re-running a finished experiment retrained the last
    epoch instead of hitting the "already done training" check.

    Args:
        snapshot_dir: directory for the "epoch_<n>.pt" snapshots.
        model: network to train; moved to `device`.
        train_loader, test_loader: iterables yielding (data, labels) batches.
        eval_iter: evaluate and snapshot every `eval_iter` epochs.
        lr: Adam learning rate.
        num_epochs: total number of training epochs.
        flatten, vectorized, learning_rule: forwarded to the evaluation,
            snapshot and training helpers.
        device: device to train on.
    """
    model = model.to(device)
    opt = optim.Adam(model.parameters(), lr=lr)
    snapshot_epoch, just_restarted, done_training = restart_from_snapshot(snapshot_dir, model, opt)
    if done_training or snapshot_epoch >= num_epochs:
        print("Loaded model already done training")
        return

    def evaluate_and_save(epoch):
        # Evaluate on both splits, write the snapshot, return the train accuracy.
        train_accuracy, train_loss = eval_accuracy(model, train_loader, flatten, vectorized, device)
        test_accuracy, test_loss = eval_accuracy(model, test_loader, flatten, vectorized, device)
        save_snapshot(snapshot_dir, model, opt, epoch, train_loss, train_accuracy, test_loss, test_accuracy,
            flatten, vectorized, learning_rule)
        return train_accuracy

    for epoch in range(snapshot_epoch, num_epochs):
        # Right after a restart the state equals the snapshot just loaded, so
        # re-evaluating and re-saving it is skipped.
        if epoch % eval_iter == 0 and not just_restarted:
            if evaluate_and_save(epoch) == 1.0:
                print("Perfect train accuracy achieved, ending training at epoch {}".format(epoch))
                break
        just_restarted = False
        train_epoch(model, opt, train_loader, flatten, vectorized, learning_rule, device)
    else:
        # No early stop: the weights produced by the final train_epoch call
        # have not been evaluated or stored yet.
        evaluate_and_save(num_epochs)

In [19]:
# Settings for the fully connected GEVB run; all of them are passed to train_model.
snap_dir = '../mnist_gevb_fc_trials_v2/'  # snapshot directory
eval_iter = 1                             # evaluate / snapshot every epoch
eta = 0.001                               # learning rate
n_epochs = 10
flatten_bool = True
vec_bool = True
rule = "df"

# Model and data loaders, both built with their default arguments.
vnn_fc_mnist = make_mnist_vec_fc()
mnist_train_loader, mnist_test_loader = load_mnist()

In [21]:
train_model(
    snapshot_dir=snap_dir,
    model=vnn_fc_mnist,
    train_loader=mnist_train_loader,
    test_loader=mnist_test_loader,
    eval_iter=eval_iter,
    lr=eta,
    num_epochs=n_epochs,
    flatten=flatten_bool,
    vectorized=vec_bool,
    learning_rule=rule,
    device=dev,
)

saved snapshot at epoch 0
train/test accuracy: 0.1076/0.1085
loss: 1.2415843262061128, accuracy: 0.8846833333333334
saved snapshot at epoch 1
train/test accuracy: 0.9458/0.9404
loss: 0.17309663021392913, accuracy: 0.94855
saved snapshot at epoch 2
train/test accuracy: 0.9621333333333333/0.9529
loss: 0.13009815968906702, accuracy: 0.9594166666666667
saved snapshot at epoch 3
train/test accuracy: 0.9679/0.9606
loss: 0.1103716528793769, accuracy: 0.9667166666666667
saved snapshot at epoch 4
train/test accuracy: 0.9760166666666666/0.9674
loss: 0.09704019494656561, accuracy: 0.9696
saved snapshot at epoch 5
train/test accuracy: 0.97905/0.9685
loss: 0.09081464730647963, accuracy: 0.9711833333333333
saved snapshot at epoch 6
train/test accuracy: 0.9809833333333333/0.9708
loss: 0.08125746941297197, accuracy: 0.9741166666666666
saved snapshot at epoch 7
train/test accuracy: 0.9751333333333333/0.9632
loss: 0.07072647430423672, accuracy: 0.9771166666666666
saved snapshot at epoch 8
train/test acc

In [23]:
# Raw attribute dictionary of the Sequential container (debug view).
vars(vnn_fc_mnist)

{'training': True,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_modules': OrderedDict([('0', Linear()),
              ('1', tReLU()),
              ('2', Linear()),
              ('3', tReLU()),
              ('4', Linear())])}

In [32]:
# Weight parameter of the first layer (module '0' of the Sequential).
vnn_fc_mnist[0].weight

Parameter containing:
tensor([[-0.0318, -0.0123,  0.0246,  ...,  0.0165,  0.0091, -0.0493],
        [ 0.0158, -0.0037, -0.0406,  ..., -0.0016,  0.0057,  0.0642],
        [-0.0079,  0.0088,  0.0490,  ..., -0.0590, -0.0293,  0.0261],
        ...,
        [-0.0137, -0.0179,  0.0606,  ..., -0.0118, -0.0365, -0.0008],
        [-0.0691, -0.0120, -0.0184,  ..., -0.0679,  0.0233,  0.0086],
        [ 0.0751,  0.0180,  0.0244,  ...,  0.0499, -0.0412, -0.0266]])

In [51]:
# Difference between entries 2 and 5 of the first tReLU layer's `t` parameter.
vnn_fc_mnist[1].t[2] - vnn_fc_mnist[1].t[5]

tensor([-2.,  0.,  2.])

In [45]:
# Weight shape of the second vnn.Linear layer (module '2').
vnn_fc_mnist[2].weight.shape

torch.Size([512, 1024])

In [58]:
# Weight shape of the first vnn.Linear layer (module '0').
vnn_fc_mnist[0].weight.shape

torch.Size([1024, 7840])