Commit

Fix bug

pbloem committed Sep 15, 2018
1 parent e1f85c0 commit 5bf24c5
Showing 4 changed files with 193 additions and 19 deletions.
41 changes: 28 additions & 13 deletions gaussian.py
@@ -1,19 +1,12 @@
import torch
-from numpy.core.multiarray import dtype
from torch.autograd import Variable
from torch.nn import Parameter
from torch import FloatTensor, LongTensor

import abc, itertools, math, types
from numpy import prod

import torch.nn as nn
import torch.nn.functional as F

-import torch.optim as optim

-import torchvision
-import torchvision.transforms as transforms

-from util import *
import util
@@ -55,7 +48,7 @@ def fi_matrix(indices, shape):

    batchsize, rows, rank = indices.size()

-    prod = LongTensor(rank).fill_(1)
+    prod = torch.LongTensor(rank).fill_(1)

    if indices.is_cuda:
        prod = prod.cuda()
@@ -106,7 +99,7 @@ def tup(index, shape, use_cuda=False):
    result = torch.cuda.LongTensor(num, len(shape)) if use_cuda else LongTensor(num, len(shape))

    for dim in range(len(shape) - 1):
-        per_inc = hyper.prod(shape[dim+1:])
+        per_inc = util.prod(shape[dim+1:])
        result[:, dim] = index / per_inc
        index = index % per_inc
    result[:, -1] = index
@@ -510,7 +503,7 @@ def discretize(self, means, sigmas, values, rng=None, additional=16, use_cuda=Fa
        # Sample additional points
        if rng is not None:
            t0 = time.time()
-            total = hyper.prod(rng)
+            total = util.prod(rng)

            if PROPER_SAMPLING:

@@ -717,6 +710,28 @@ def forward_inner(self, input, means, sigmas, values, bias):

        return y

    def forward_sample(self, input):
        """
        Samples a single sparse matrix and computes the transformation with it, in a non-differentiable manner.
        :param input:
        :return:
        """

        # Sample k indices
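        # What follows is a hedged sketch of this step, not the committed
        # implementation: draw one integer index tuple per Gaussian
        # component, scatter the values into a dense stand-in for the
        # sampled sparse weight matrix, and apply it. `self.out_size` and
        # `self.in_size` are hypothetical names for the flattened
        # output/input sizes; `self.hyper` returns (means, sigmas, values)
        # as elsewhere in this file.
        means, sigmas, values = self.hyper(input)

        # Sampling is the non-differentiable step; we keep log q(sample) so
        # that backward_sample can apply the REINFORCE correction.
        dist = torch.distributions.Normal(means, sigmas)
        samples = dist.sample()
        indices = samples.round().long()
        q_prob = dist.log_prob(samples).sum(-1).sum(-1)

        b, k, _ = indices.size()
        rows = indices[:, :, 0].clamp(0, self.out_size - 1)
        cols = indices[:, :, 1].clamp(0, self.in_size - 1)

        w = input.data.new(b, self.out_size, self.in_size).zero_()
        for bi in range(b):  # simple but slow; a real version stays sparse
            w[bi][rows[bi], cols[bi]] = values.data[bi]

        y = torch.bmm(w, input.data.view(b, -1, 1)).squeeze(2)

        # Here the proposal coincides with the model distribution (p == q);
        # with additional/proper sampling the two would differ.
        return y, q_prob, q_prob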
    def backward_sample(self, batch_loss, q_prob, p_prob):
        """
        Computes the gradient by REINFORCE, using the given batch loss and the probabilities of the sample (as returned by forward_sample).
        :param batch_loss:
        :param q_prob:
        :param p_prob:
        :return:
        """
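        # A hedged sketch of the REINFORCE step, not the committed
        # implementation. The sample itself is not differentiable, so we
        # backpropagate through log q(sample) instead, scaled by the
        # (constant) batch loss. If the sample came from a proposal q
        # rather than the model distribution p, the importance weight p/q
        # corrects for this; it is exactly 1 in the forward_sample sketch
        # above, where p == q.
        weight = torch.exp(p_prob - q_prob).detach()

        # The loss and the weight are treated as constants; gradient flows
        # only through log q, back into the mixture parameters.
        surrogate = (batch_loss.detach() * weight * q_prob).mean()
        surrogate.backward()

        # Hypothetical usage, mirroring the training loop in
        # identity_reinforce.py:
        #
        #   y, q_prob, p_prob = model.forward_sample(x)
        #   loss = criterion(Variable(y), x)
        #   model.backward_sample(loss, q_prob, p_prob)
        #   optimizer.step()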

class ParamASHLayer(HyperLayer):
    """
    Hyperlayer with free sparse parameters, no hypernetwork.
@@ -860,7 +875,7 @@ def __init__(self, in_shape, out_shape, k, additional=0, poolsize=4, subsample=N
        self.conv2d = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=(2, self.w_rank+2), stride=2)

        self.bias = nn.Sequential(
-            nn.Linear(int(k/rep), hyper.prod(out_shape)),
+            nn.Linear(int(k/rep), util.prod(out_shape)),
        )

    def hyper(self, input):
@@ -975,7 +990,7 @@ def forward(self, input):

        if self.adaptive_bias:
            self.bias = nn.Sequential(
-                nn.Linear(self.ha * self.hb, hyper.prod(out_shape)),
+                nn.Linear(self.ha * self.hb, util.prod(out_shape)),
                self.activation
            )
        else:
@@ -1086,7 +1101,7 @@ def forward(self, input):
                nn.ConvTranspose1d(in_channels=width, out_channels=width, kernel_size=2, stride=2))

        self.bias = nn.Sequential(
-            nn.Linear(self.ha * self.hb, hyper.prod(out_shape)),
+            nn.Linear(self.ha * self.hb, util.prod(out_shape)),
            self.activation
        )

2 changes: 0 additions & 2 deletions identity_experiment.py
@@ -13,8 +13,6 @@

from argparse import ArgumentParser

-import psutil, os

logging.basicConfig(filename='run.log',level=logging.INFO)
LOG = logging.getLogger()

151 changes: 151 additions & 0 deletions identity_reinforce.py
@@ -0,0 +1,151 @@
import hyper, gaussian
import torch, random, sys
from torch.autograd import Variable
from torch.nn import Parameter
from torch.nn.functional import sigmoid
from torch import nn, optim
from tqdm import trange
from tensorboardX import SummaryWriter

import matplotlib.pyplot as plt
import util, logging, time, gc
import numpy as np

from argparse import ArgumentParser

logging.basicConfig(filename='run.log',level=logging.INFO)
LOG = logging.getLogger()

"""
Simple experiment: learn the identity function from one tensor to another
"""
w = SummaryWriter()

def go(iterations=30000, additional=64, batch=4, size=32, cuda=False, plot_every=50,
       lr=0.01, fv=False, sigma_scale=0.1, min_sigma=0.0, seed=0):

    SHAPE = (size,)
    MARGIN = 0.1

    torch.manual_seed(seed)

    nzs = hyper.prod(SHAPE)

    util.makedirs('./identity/')

    params = None

    gaussian.PROPER_SAMPLING = False
    model = gaussian.ParamASHLayer(SHAPE, SHAPE, k=size, additional=additional, sigma_scale=sigma_scale, has_bias=False, fix_values=fv, min_sigma=min_sigma)

    if cuda:
        model.cuda()

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for i in trange(iterations):

        x = torch.zeros((batch,) + SHAPE) + (1.0/16.0)
        x = torch.bernoulli(x)
        if cuda:
            x = x.cuda()
        x = Variable(x)

        optimizer.zero_grad()

        y = model(x)

        loss = criterion(y, x)

        t0 = time.time()
        loss.backward()  # compute the gradients

        optimizer.step()

        w.add_scalar('identity32/loss', loss.data[0], i*batch)

        if plot_every > 0 and i % plot_every == 0:
            plt.figure(figsize=(7, 7))

            print(loss)

            means, sigmas, values = model.hyper(x)

            plt.cla()
            util.plot(means, sigmas, values, shape=(SHAPE[0], SHAPE[0]))
            plt.xlim((-MARGIN*(SHAPE[0]-1), (SHAPE[0]-1) * (1.0+MARGIN)))
            plt.ylim((-MARGIN*(SHAPE[0]-1), (SHAPE[0]-1) * (1.0+MARGIN)))

            plt.savefig('./identity/means{:04}.png'.format(i))

    return float(loss.data[0])

if __name__ == "__main__":

    ## Parse the command line options
    parser = ArgumentParser()

    parser.add_argument("-s", "--size",
                        dest="size",
                        help="Size (nr of dimensions) of the input.",
                        default=32, type=int)

    parser.add_argument("-b", "--batch-size",
                        dest="batch_size",
                        help="The batch size.",
                        default=64, type=int)

    parser.add_argument("-i", "--iterations",
                        dest="iterations",
                        help="The number of iterations (i.e. the number of batches).",
                        default=3000, type=int)

    parser.add_argument("-a", "--additional",
                        dest="additional",
                        help="Number of additional points sampled.",
                        default=512, type=int)

    parser.add_argument("-c", "--cuda", dest="cuda",
                        help="Whether to use cuda.",
                        action="store_true")

    parser.add_argument("-F", "--fix_values", dest="fix_values",
                        help="Whether to fix the values to 1.",
                        action="store_true")

    parser.add_argument("-l", "--learn-rate",
                        dest="lr",
                        help="Learning rate.",
                        default=0.005, type=float)

    parser.add_argument("-S", "--sigma-scale",
                        dest="sigma_scale",
                        help="Sigma scale.",
                        default=0.1, type=float)

    parser.add_argument("-M", "--min_sigma",
                        dest="min_sigma",
                        help="Minimum variance for the components.",
                        default=0.0, type=float)

    parser.add_argument("-p", "--plot-every",
                        dest="plot_every",
                        help="Plot every x iterations.",
                        default=50, type=int)

    parser.add_argument("-r", "--random-seed",
                        dest="seed",
                        help="Random seed.",
                        default=32, type=int)

    options = parser.parse_args()

    print('OPTIONS ', options)
    LOG.info('OPTIONS ' + str(options))

    go(batch=options.batch_size, size=options.size,
       additional=options.additional, iterations=options.iterations, cuda=options.cuda,
       lr=options.lr, plot_every=options.plot_every, fv=options.fix_values,
       sigma_scale=options.sigma_scale, min_sigma=options.min_sigma, seed=options.seed)
18 changes: 14 additions & 4 deletions images.experiment.py
@@ -364,6 +364,7 @@ def hyper(self, input, prep=None):
"""
Evaluates hypernetwork.
"""
# print('!!!', prep.size())

b, c, h, w = input.size()
l = self.pixel_indices.size(0)
@@ -423,7 +424,7 @@ def hyper(self, input, prep=None):

    def forward(self, input, prep=None):

-        self.last_out = super().forward(input)
+        self.last_out = super().forward(input, prep=prep)

        return self.last_out

@@ -519,16 +520,21 @@ def forward(self, image):

        glimpses = []
        for i, hyper in enumerate(self.hyperlayers):
-            glimpses.append(hyper(image, prep[i*4 : (i+1)*4]))
+            glimpses.append(hyper(image, prep=prep[:, i*4 : (i+1)*4]))

        x = torch.cat(glimpses, dim=1).view(b, -1)
        x = F.relu(self.lin1(x))
        x = F.softmax(self.lin2(x), dim=1)

        return x

    def debug(self):
        print(list(self.preprocess.parameters())[0].grad)

    def plot(self, images):

        prep = self.preprocess(images)

        perrow = 5

        num, c, w, h = images.size()
@@ -545,15 +551,17 @@ def plot(self, images):

            ax.imshow(im, interpolation='nearest', extent=(-0.5, w - 0.5, -0.5, h - 0.5), cmap='gray_r')

-            for hyper in self.hyperlayers:
-                means, sigmas, values, _ = hyper.hyper(images)
+            for i, hyper in enumerate(self.hyperlayers):
+                means, sigmas, values, _ = hyper.hyper(images, prep=prep[:, i*4 : (i+1)*4])

                util.plot(means[i, :].unsqueeze(0), sigmas[i, :].unsqueeze(0), values[i, :].unsqueeze(0),
                          axes=ax, flip_y=h, alpha_global=0.3)

            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

            ax.set_xlim(-0.5, w - 0.5)
            ax.set_ylim(-0.5, h - 0.5)

        plt.gcf()

@@ -987,6 +995,8 @@ def go(args, batch=64, epochs=350, k=3, additional=64, modelname='baseline', cud
        loss.backward()  # compute the gradients
        logging.info('backward: {} seconds'.format(time.time() - t0))

        # model.debug()

        # print(hyperlayer.values, hyperlayer.values.grad)

        optimizer.step()