Commit: clean
lezcano committed Jun 13, 2019
1 parent b8c69b7 commit 19a3ca1
Showing 7 changed files with 673 additions and 226 deletions.
62 changes: 21 additions & 41 deletions 1_copying.py
@@ -3,8 +3,7 @@
import numpy as np
import argparse

from exprnn import ExpRNN
from approximants import exp_pade, taylor, scale_square, cayley
from exprnn import ExpRNN, get_parameters, orthogonal_step
from initialization import henaff_init, cayley_init

parser = argparse.ArgumentParser(description='Exponential Layer Copy Task')
@@ -14,12 +13,10 @@
parser.add_argument('--L', type=int, default=1000)
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--lr_orth', type=float, default=2e-5)
parser.add_argument("--rescale", action="store_true", help='Apply scale-squaring trick')
parser.add_argument("-m", "--mode",
choices=["exact", "cayley", "pade", "taylor20", "lstm"],
default="exact",
type=str,
help="LSTM or Approximant to approximate the exponential of matrices")
choices=["exprnn", "lstm"],
default="exprnn",
type=str)
parser.add_argument("--init",
choices=["cayley", "henaff"],
default="henaff",
@@ -46,24 +43,6 @@
elif args.init == "henaff":
init = henaff_init

if args.mode == "cayley":
exp_func = cayley
elif args.mode == "pade":
exp_func = exp_pade
elif args.mode == "taylor20":
exp_func = lambda X: taylor(X, 20)

if args.mode != "lstm":
if args.mode == "exact":
# The exact implementation already implements a more advanced form of scale-squaring trick
exp = "exact"
else:
if args.rescale:
exp = lambda X: scale_square(X, exp_func)
else:
exp = exp_func


def copying_data(L, K, batch_size):
seq = np.random.randint(1, high=9, size=(batch_size, K))
zeros1 = np.zeros((batch_size, L))
@@ -77,13 +56,13 @@ def copying_data(L, K, batch_size):
return x, y


class Model(nn.Module):
class Model(torch.jit.ScriptModule):
def __init__(self, n_classes, hidden_size):
super(Model, self).__init__()
if args.mode == "lstm":
self.rnn = nn.LSTMCell(n_classes + 1, hidden_size)
else:
self.rnn = ExpRNN(n_classes + 1, hidden_size, exponential=exp, skew_initializer=init)
self.rnn = ExpRNN(n_classes + 1, hidden_size, skew_initializer=init)
self.lin = nn.Linear(hidden_size, n_classes)
self.loss_func = nn.CrossEntropyLoss()
self.reset_parameters()
@@ -92,17 +71,14 @@ def reset_parameters(self):
nn.init.kaiming_normal_(self.lin.weight.data, nonlinearity="relu")
nn.init.constant_(self.lin.bias.data, 0)

@torch.jit.script_method
def forward(self, inputs):
h = None
out = []
state = self.rnn.default_hidden(inputs[:, 0, ...])
outputs = torch.jit.annotate(List[Tensor], [])
for input in torch.unbind(inputs, dim=1):
h = self.rnn(input, h)
if isinstance(h, tuple):
out_rnn = h[0]
else:
out_rnn = h
out.append(self.lin(out_rnn))
return torch.stack(out, dim=1)
out_rnn, state = self.rnn(input, state)
outputs += [self.lin(out_rnn)]
return torch.stack(outputs, dim=1)

def loss(self, logits, y):
return self.loss_func(logits.view(-1, 9), y.view(-1))
@@ -134,10 +110,9 @@ def main():
optim = torch.optim.RMSprop(model.parameters(), lr=args.lr)
optim_orth = None
else:
optim = torch.optim.RMSprop((param for param in model.parameters()
if param is not model.rnn.log_recurrent_kernel and
param is not model.rnn.recurrent_kernel), lr=args.lr)
optim_orth = torch.optim.RMSprop([model.rnn.log_recurrent_kernel], lr=args.lr_orth)
non_orth_params, log_orth_params = get_parameters(model)
optim = torch.optim.RMSprop(non_orth_params, args.lr)
optim_orth = torch.optim.RMSprop(log_orth_params, lr=args.lr_orth)

x_onehot = torch.FloatTensor(batch_size, n_len, n_characters).to(device)

@@ -149,10 +124,15 @@ def main():
logits = model(x_onehot)
loss = model.loss(logits, batch_y)

# Zeroing out the optim_orth is not really necessary, but we do it for consistency
if optim_orth:
optim_orth.zero_grad()
optim.zero_grad()

loss.backward()

if optim_orth:
model.rnn.orthogonal_step(optim_orth)
model.apply(orthogonal_step(optim_orth))
optim.step()

with torch.no_grad():
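Not part of the commit: a minimal, self-contained sketch of the training pattern this diff switches 1_copying.py to, assuming the repository-local signatures used above (get_parameters(model) returning a (non-orthogonal, log-orthogonal) parameter split, orthogonal_step(optimizer) returning a callable suitable for model.apply, and an ExpRNN cell whose default_hidden builds the initial state and whose call returns (output, state)). The class name TinyModel and the toy shapes are illustrative only.

import torch
import torch.nn as nn

from exprnn import ExpRNN, get_parameters, orthogonal_step
from initialization import henaff_init


class TinyModel(nn.Module):
    def __init__(self, input_size, hidden_size, n_classes):
        super(TinyModel, self).__init__()
        self.rnn = ExpRNN(input_size, hidden_size, skew_initializer=henaff_init)
        self.lin = nn.Linear(hidden_size, n_classes)

    def forward(self, inputs):
        # Thread the recurrent state through the sequence, as in the diff above.
        state = self.rnn.default_hidden(inputs[:, 0, ...])
        outputs = []
        for step in torch.unbind(inputs, dim=1):
            out, state = self.rnn(step, state)
            outputs.append(self.lin(out))
        return torch.stack(outputs, dim=1)


model = TinyModel(input_size=10, hidden_size=64, n_classes=9)

# Two optimizers: one for the ordinary parameters, one for the skew-symmetric
# (log-orthogonal) parameters that parametrise the recurrent kernel.
non_orth_params, log_orth_params = get_parameters(model)
optim = torch.optim.RMSprop(non_orth_params, lr=2e-4)
optim_orth = torch.optim.RMSprop(log_orth_params, lr=2e-5)

x = torch.randn(32, 20, 10)            # (batch, seq_len, input_size)
y = torch.randint(0, 9, (32, 20))      # one class label per time step

loss = nn.CrossEntropyLoss()(model(x).view(-1, 9), y.view(-1))

optim_orth.zero_grad()
optim.zero_grad()
loss.backward()
# The orthogonal update runs first, applied to each submodule, then the usual step.
model.apply(orthogonal_step(optim_orth))
optim.step()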
87 changes: 39 additions & 48 deletions 2_mnist.py
@@ -2,10 +2,10 @@
import torch.nn as nn
import numpy as np
import argparse
import sys
from torchvision import datasets, transforms

from exprnn import ExpRNN
from approximants import exp_pade, taylor, scale_square, cayley
from exprnn import ExpRNN, get_parameters, orthogonal_step
from initialization import henaff_init, cayley_init

parser = argparse.ArgumentParser(description='Exponential Layer MNIST Task')
@@ -15,12 +15,10 @@
parser.add_argument('--lr', type=float, default=7e-4)
parser.add_argument('--lr_orth', type=float, default=7e-5)
parser.add_argument("--permute", action="store_true")
parser.add_argument("--rescale", action="store_true")
parser.add_argument("-m", "--mode",
choices=["exact", "cayley", "pade", "taylor20", "lstm"],
default="exact",
type=str,
help="LSTM or Approximant to approximate the exponential of matrices")
choices=["exprnn", "lstm"],
default="exprnn",
type=str)
parser.add_argument("--init",
choices=["cayley", "henaff"],
default="cayley",
@@ -49,45 +47,30 @@
init = henaff_init


if args.mode == "cayley":
exp_func = cayley
elif args.mode == "pade":
exp_func = exp_pade
elif args.mode == "taylor20":
exp_func = lambda X: taylor(X, 20)

if args.mode != "lstm":
if args.mode == "exact":
# The exact implementation already implements a more advanced form of scale-squaring trick
exp = "exact"
else:
if args.rescale:
exp = lambda X: scale_square(X, exp_func)
else:
exp = exp_func

if args.permute:
permute = np.random.RandomState(92916)
permutation = torch.LongTensor(permute.permutation(784))


class Model(nn.Module):
def __init__(self, hidden_size):
class Model(torch.jit.ScriptModule):
__constants__ = ["permute", "permutation"]
def __init__(self, hidden_size, permute):
super(Model, self).__init__()
self.permute = permute
permute = np.random.RandomState(92916)
self.register_buffer("permutation", torch.LongTensor(permute.permutation(784)))
if args.mode == "lstm":
self.rnn = nn.LSTMCell(1, hidden_size)
else:
self.rnn = ExpRNN(1, hidden_size, exponential=exp, skew_initializer=init)
self.rnn = ExpRNN(1, hidden_size, skew_initializer=init)

self.lin = nn.Linear(hidden_size, n_classes)
self.loss_func = nn.CrossEntropyLoss()

@torch.jit.script_method
def forward(self, inputs):
if args.permute:
inputs = inputs[:, permutation]
h = None
if self.permute:
inputs = inputs[:, self.permutation]

state = self.rnn.default_hidden(inputs[:,0,...])
for input in torch.unbind(inputs, dim=1):
h = self.rnn(input.unsqueeze(dim=1), h)
return self.lin(h)
_, state = self.rnn(input.unsqueeze(dim=1), state)
return self.lin(state)

def loss(self, logits, y):
return self.loss_func(logits, y)
@@ -107,17 +90,16 @@ def main():
batch_size=batch_size, shuffle=True, **kwargs)

# Model and optimizers
model = Model(hidden_size).to(device)
model = Model(hidden_size, args.permute).to(device)
model.train()

if args.mode == "lstm":
optim = torch.optim.RMSprop(model.parameters(), lr=args.lr)
optim_orth = None
else:
optim = torch.optim.RMSprop((param for param in model.parameters()
if param is not model.rnn.log_recurrent_kernel and
param is not model.rnn.recurrent_kernel), lr=args.lr)
optim_orth = torch.optim.RMSprop([model.rnn.log_recurrent_kernel], lr=args.lr_orth)
non_orth_params, log_orth_params = get_parameters(model)
optim = torch.optim.RMSprop(non_orth_params, args.lr)
optim_orth = torch.optim.RMSprop(log_orth_params, lr=args.lr_orth)

best_test_acc = 0.
for epoch in range(epochs):
@@ -128,20 +110,27 @@ def main():
logits = model(batch_x)
loss = model.loss(logits, batch_y)

# Zeroing out the optim_orth is not really necessary, but we do it for consistency
if optim_orth:
optim_orth.zero_grad()
optim.zero_grad()

loss.backward()

if optim_orth:
model.rnn.orthogonal_step(optim_orth)
model.apply(orthogonal_step(optim_orth))
optim.step()

with torch.no_grad():
correct = model.correct(logits, batch_y)

processed += len(batch_x)
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.6f}%'.format(
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.2f}%\tBest: {:.2f}%'.format(
epoch, processed, len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item(), correct/len(batch_x)))
100. * batch_idx / len(train_loader), loss.item(), 100 * correct/len(batch_x), best_test_acc))


model.rnn.recurrent_kernel.orthogonalise()
model.eval()
with torch.no_grad():
test_loss = 0.
@@ -153,10 +142,12 @@ def main():
correct += model.correct(logits, batch_y).float()

test_loss /= len(test_loader)
test_acc = correct / len(test_loader.dataset)
test_acc = 100 * correct / len(test_loader.dataset)
best_test_acc = max(test_acc, best_test_acc)
print("\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%), Best Accuracy: {:.3f}\n"
.format(test_loss, correct, len(test_loader.dataset), 100 * test_acc, best_test_acc))
print("\n")
print(args)
print("Test set: Average loss: {:.4f}, Accuracy: {:.2f}%, Best Accuracy: {:.2f}%\n"
.format(test_loss, test_acc, best_test_acc))

model.train()

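Likewise not part of the commit: a small sketch of the TorchScript pattern 2_mnist.py adopts in this diff, using only standard PyTorch 1.x APIs. The fixed pixel permutation is registered as a buffer so it is saved with the state_dict and follows the model under .to(device), and the permute flag is declared in __constants__ so the script_method can branch on it. The class name PermuteInput and the toy sizes are assumptions for illustration.

import numpy as np
import torch


class PermuteInput(torch.jit.ScriptModule):
    __constants__ = ["permute"]

    def __init__(self, permute, n_pixels=784, seed=92916):
        super(PermuteInput, self).__init__()
        self.permute = permute
        rng = np.random.RandomState(seed)
        # Buffer, not Parameter: fixed data that still moves with the model.
        self.register_buffer("permutation", torch.LongTensor(rng.permutation(n_pixels)))

    @torch.jit.script_method
    def forward(self, inputs):
        # inputs: (batch, n_pixels); apply the fixed permutation if requested.
        if self.permute:
            inputs = inputs[:, self.permutation]
        return inputs


layer = PermuteInput(permute=True)
x = torch.randn(8, 784)
print(layer(x).shape)  # torch.Size([8, 784])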
