
Added hard weight sharing.
pbloem committed Nov 19, 2017
1 parent ddb1697 commit 52f0087
Showing 7 changed files with 488 additions and 75 deletions.
180 changes: 150 additions & 30 deletions gaussian.py
@@ -164,6 +164,46 @@ def densities(points, means, sigmas):
"""
Compute the unnormalized PDFs of the points under the given MVNs
(with sigma a diagonal matrix per MVN)
:param points: (batch, n, d, rank) tensor of points, d per MVN
:param means: (batch, n, rank) means of the n MVNs
:param sigmas: (batch, n, rank) diagonal entries of each MVN's covariance
:return: (batch, n, d) unnormalized density of each point under its MVN
"""

# n: number of MVNs
# d: number of points per MVN
# rank: dim of points

batchsize, n, d, rank = points.size()

means = means.unsqueeze(2).expand_as(points)

sigmas = sigmas.unsqueeze(2).expand_as(points)
sigmas_root = torch.sqrt(sigmas)  # std dev (sigmas holds the diagonal variances)

points = points - means
points = points * (1.0/(EPSILON+sigmas_root))

# Compute dot products for all points
# -- unroll the batch/n dimensions
points = points.view(-1, 1, rank, 1).squeeze(3)
# -- dot prod
products = torch.bmm(points, points.transpose(1,2))
# -- reconstruct shape
products = products.view(batchsize, n, d)

num = torch.exp(- 0.5 * products)

return num
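
For reference, the quantity returned here is the unnormalized diagonal Gaussian density exp(-0.5 * sum_r (x_r - mu_r)^2 / sigma_r) for every point under every MVN, with sigmas holding the diagonal (variance) entries of each covariance. A minimal reference sketch of the same computation in current PyTorch, assuming the shapes documented above:

import torch

def densities_reference(points, means, sigmas, eps=1e-7):
    # points: (batch, n, d, rank); means, sigmas: (batch, n, rank)
    # sigmas holds the diagonal (variance) entries of each MVN's covariance
    diff = points - means.unsqueeze(2)                       # (batch, n, d, rank)
    scaled = diff / (torch.sqrt(sigmas).unsqueeze(2) + eps)  # divide by the std dev
    return torch.exp(-0.5 * (scaled ** 2).sum(dim=-1))       # (batch, n, d)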

def densities_single(points, means, sigmas):
"""
Compute the unnormalized PDFs of the points under the given MVNs
(with sigma a single number per MVN)
:param means: means of the MVNs
:param sigmas: a single scalar sigma per MVN
:param points: the points to evaluate
@@ -332,7 +372,7 @@ def __init__(self, in_rank, out_shape, additional=0, bias_type=Bias.DENSE):

self.bias_type = bias_type

def split_out(self, res, input_size, output_size, gain=5.0):
def split_out(self, res, input_size, output_size):
"""
Utility function. res is a B x K x (w_rank + 2) tensor with entries ranging from
-inf to inf; this function splits out the means, sigmas and values, and
@@ -348,26 +388,83 @@ def split_out(self, res, input_size, output_size, gain=5.0):
b, k, width = res.size()
w_rank = width - 2

means = nn.functional.sigmoid(res[:, :, 0:w_rank] * gain)
means = nn.functional.sigmoid(res[:, :, 0:w_rank])
means = means.unsqueeze(2).contiguous().view(-1, k, w_rank)

## expand the indices to the range [0, max]

# Limits for each of the w_rank indices
# and scales for the sigmas
ws = list(output_size) + list(input_size)
s = torch.cuda.FloatTensor(ws) if self.use_cuda else FloatTensor(ws)
s = Variable(s.contiguous())
s = s - 1
s = s.unsqueeze(0).unsqueeze(0)
s = s.expand_as(means)

means = means * s
ss = s.unsqueeze(0).unsqueeze(0)
sm = s - 1
sm = sm.unsqueeze(0).unsqueeze(0)

sigmas = nn.functional.softplus(res[:, :, w_rank:w_rank + 1]).squeeze(2) + 0.0
means = means * sm.expand_as(means)

sigmas = nn.functional.softplus(res[:, :, w_rank:w_rank + 1]).squeeze(2)
values = nn.functional.softplus(res[:, :, w_rank + 1:].squeeze(2))

sigmas = sigmas.unsqueeze(2).expand_as(means)

sigmas = sigmas * ss.expand_as(sigmas)

return means, sigmas, values


def split_shared(self, res, input_size, output_size, values):
"""
Splits res into means and sigmas, samples values according to multinomial parameters
in res
:param res:
:param input_size:
:param output_size:
:param gain:
:return:
"""

b, k, width = res.size()
w_rank = len(input_size) + len(output_size)

means = nn.functional.sigmoid(res[:, :, 0:w_rank])
means = means.unsqueeze(2).contiguous().view(-1, k, w_rank)

## expand the indices to the range [0, max]

# Limits for each of the w_rank indices
# and scales for the sigmas
ws = list(output_size) + list(input_size)
s = torch.cuda.FloatTensor(ws) if self.use_cuda else FloatTensor(ws)
s = Variable(s.contiguous())

ss = s.unsqueeze(0).unsqueeze(0)
sm = s - 1
sm = sm.unsqueeze(0).unsqueeze(0)

means = means * sm.expand_as(means)

sigmas = nn.functional.softplus(res[:, :, w_rank:w_rank+1])

sigmas = sigmas.expand_as(means)
sigmas = sigmas * ss.expand_as(sigmas)

# extract the values
vweights = res[:, :, w_rank+1:].contiguous()

assert vweights.size()[2] == values.size()[0]

vweights = util.bsoftmax(vweights)

samples, snode = util.bmultinomial(vweights, num_samples=1)

weights = values[samples.data.view(-1)].view(b, k)

return means, sigmas, weights, snode
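
The new split_shared is the core of the hard weight sharing: the last num_values columns of each tuple's row act as multinomial weights over a small pool of shared scalar values, one index is sampled per tuple, and that index is looked up in the shared values vector. A hedged sketch of this selection step using torch.distributions instead of the repo's util.bsoftmax / util.bmultinomial helpers and the stochastic-Variable API (the names and the Categorical formulation are assumptions, not the commit's code):

import torch
from torch.distributions import Categorical

def select_shared_values(vweights, shared_values):
    # vweights: (batch, k, num_values) raw scores per sparse tuple
    # shared_values: (num_values,) pool of shared scalar weights
    dist = Categorical(logits=vweights)       # softmax over the last dimension
    samples = dist.sample()                   # (batch, k) index of the chosen weight
    weights = shared_values[samples]          # (batch, k) hard-shared weights
    return weights, dist.log_prob(samples)    # keep the log-probs for REINFORCE

# e.g. weights, logp = select_shared_values(torch.randn(4, 8, 2), torch.randn(2))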

def forward(self, input):

t0total = time.time()
@@ -553,62 +650,73 @@ def reset(t):
# plt.savefig('./init/means.{:06}.png'.format(i))


class DenseASHLayer(HyperLayer):
class ParamASHLayer(HyperLayer):
"""
Hyperlayer with arbitrary (fixed) in/out shape. Uses simple dense hypernetwork
Hyperlayer with free sparse parameters, no hypernetwork.
"""

def __init__(self, in_shape, out_shape, k, additional=0, hidden=256):
super().__init__(in_rank=len(in_shape), out_shape=out_shape, additional=additional, bias_type=Bias.NONE)
def __init__(self, in_shape, out_shape, k, additional=0, sigma_scale=0.1, fix_values=False):
super().__init__(in_rank=len(in_shape), additional=additional, out_shape=out_shape, bias_type=Bias.NONE)

self.k = k
self.in_shape = in_shape
self.out_shape = out_shape
self.sigma_scale = sigma_scale
self.fix_values = fix_values

self.w_rank = len(in_shape) + len(out_shape)

# hypernetwork
self.hyp = nn.Sequential(
Flatten(),
nn.Linear(prod(in_shape), hidden),
nn.ReLU(),
nn.Linear(hidden, (self.w_rank + 2) * k),
)

# self.bias = Parameter(torch.zeros(out_shape))
p = torch.randn(k, self.w_rank + 2)
p[:, self.w_rank:self.w_rank + 1] = p[:, self.w_rank:self.w_rank + 1]
self.params = Parameter(p)

def hyper(self, input):
"""
Returns the means, sigmas and values of the sparse weight tensor (no hypernetwork; the parameters are free).
"""
res = self.hyp.forward(input)

res = res.unsqueeze(1).view(-1, self.k, self.w_rank + 2)
batch_size = input.size()[0]

# Replicate the parameters along the batch dimension
res = self.params.unsqueeze(0).expand(batch_size, self.k, self.w_rank+2)

means, sigmas, values = self.split_out(res, input.size()[1:], self.out_shape)

sigmas = sigmas * self.sigma_scale
if self.fix_values:
values = values * 0.0 + 1.0

return means, sigmas, values

def clone(self):
result = ParamASHLayer(self.in_shape, self.out_shape, self.k, self.additional, self.sigma_scale, self.fix_values)

class ParamASHLayer(HyperLayer):
result.params = Parameter(self.params.data.clone())

return result

class WeightSharingASHLayer(HyperLayer):
"""
Hyperlayer with arbitrary (fixed) in/out shape. Uses simple dense hypernetwork
Hyperlayer with free sparse parameters, no hypernetwork, and a limited number of weights with hard sharing
"""

def __init__(self, in_shape, out_shape, k, additional=0, gain=5.0):
def __init__(self, in_shape, out_shape, k, additional=0, sigma_scale=0.1, num_values=2):
super().__init__(in_rank=len(in_shape), additional=additional, out_shape=out_shape, bias_type=Bias.NONE)

self.k = k
self.in_shape = in_shape
self.out_shape = out_shape
self.gain = gain
self.sigma_scale = sigma_scale

self.w_rank = len(in_shape) + len(out_shape)

p = torch.randn(k, self.w_rank + 2)
p[:, self.w_rank:self.w_rank + 1] = p[:, self.w_rank:self.w_rank + 1] * 0.0 + 0.6
p = torch.randn(k, self.w_rank + 1 + num_values)
p[:, self.w_rank:self.w_rank + 1] = p[:, self.w_rank:self.w_rank + 1]
self.params = Parameter(p)

self.sources = Parameter(torch.randn(num_values))
# self.sources = Variable(FloatTensor([-1.0, 1.0]))

def hyper(self, input):
"""
Evaluates hypernetwork.
@@ -617,13 +725,25 @@ def hyper(self, input):
batch_size = input.size()[0]

# Replicate the parameters along the batch dimension
res = self.params.unsqueeze(0).expand(batch_size, self.k, self.w_rank+2)
rows, columns = self.params.size()
res = self.params.unsqueeze(0).expand(batch_size, rows, columns)

means, sigmas, values = self.split_out(res, input.size()[1:], self.out_shape, self.gain)
means, sigmas, values, self.samples = self.split_shared(res, input.size()[1:], self.out_shape, self.sources)
sigmas = sigmas * self.sigma_scale

return means, sigmas, values

def call_reinforce(self, downstream_reward):
b, = downstream_reward.size()

rew = downstream_reward.unsqueeze(1).expand(b, self.k)
rew = rew.contiguous().view(-1, 1)

self.samples.reinforce(rew)
self.samples.backward()

def clone(self):

result = WeightSharingASHLayer(self.in_shape, self.out_shape, self.k, self.additional, self.sigma_scale, num_values=self.sources.size()[0])

result.params = Parameter(self.params.data.clone())
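
call_reinforce above relies on the old stochastic-Variable API (samples.reinforce(reward) followed by backward()), which later PyTorch releases removed. The same score-function (REINFORCE) gradient can be written as an explicit surrogate loss over the log-probabilities of the sampled value indices; a sketch under the assumption that those log-probs are available (as in the Categorical sketch after split_shared):

import torch

def reinforce_surrogate(log_probs, downstream_reward):
    # log_probs: (batch, k) log-probability of each sampled shared-weight index
    # downstream_reward: (batch,) reward per example, e.g. the negative task loss
    rew = downstream_reward.detach().unsqueeze(1).expand_as(log_probs)
    # minimising -E[log pi * R] moves probability mass toward well-rewarded choices
    return -(log_probs * rew).sum()

# typical use: (task_loss + reinforce_surrogate(logp, -per_example_loss)).backward()
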
19 changes: 10 additions & 9 deletions identity.experiment.py
@@ -19,9 +19,10 @@
"""
w = SummaryWriter()

BATCH = 256
SHAPE = (32, )
CUDA = True
BATCH = 64
SHAPE = (16, )
CUDA = False
MARGIN = 0.1

torch.manual_seed(2)

@@ -34,14 +35,14 @@

params = None

model = gaussian.ParamASHLayer(SHAPE, SHAPE, additional=256, k=nzs, gain=1.0)
model = gaussian.WeightSharingASHLayer(SHAPE, SHAPE, additional=64, k=nzs, sigma_scale=0.2, num_values=2)
# model.initialize(SHAPE, batch_size=64, iterations=100, lr=0.05)

if CUDA:
model.cuda()

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
optimizer = optim.Adam(model.parameters(), lr=0.01)

for i in trange(N):

@@ -63,13 +64,13 @@

w.add_scalar('identity32/loss', loss.data[0], i*BATCH)

if i % (N//50) == 0:
if i % (N//2500) == 0:
means, sigmas, values = model.hyper(x)

plt.clf()
util.plot(means, sigmas, values)
plt.xlim((-1, SHAPE[0]))
plt.ylim((-1, SHAPE[0]))
util.plot(means, sigmas, values, shape=(SHAPE[0], SHAPE[0]))
plt.xlim((-MARGIN*(SHAPE[0]-1), (SHAPE[0]-1) * (1.0+MARGIN)))
plt.ylim((-MARGIN*(SHAPE[0]-1), (SHAPE[0]-1) * (1.0+MARGIN)))
plt.savefig('./spread/means{:04}.png'.format(i))

print('LOSS', torch.sqrt(loss))
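
The solution this experiment should converge to is the 16 x 16 identity: one tuple per diagonal cell (i, i), with the selected shared value near 1.0. A small illustration of that dense equivalent (assuming k equals the width; nzs is set outside the hunks shown here):

import torch

size = 16                                  # SHAPE = (16, )
idx = torch.arange(size)
means = torch.stack([idx, idx], dim=1)     # k tuples on the diagonal: (i, i)
values = torch.ones(size)                  # the shared weight the sampler should pick

dense = torch.zeros(size, size)
dense[means[:, 0], means[:, 1]] = values
assert torch.equal(dense, torch.eye(size))
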
70 changes: 70 additions & 0 deletions orthogonal.experiment.py
@@ -0,0 +1,70 @@
import hyper, gaussian
import torch, random, sys
from torch.autograd import Variable
from torch.nn import Parameter
from torch import nn, optim
from tqdm import trange
from tensorboardX import SummaryWriter

import matplotlib.pyplot as plt
import util, logging, time, gc

import psutil, os

logging.basicConfig(filename='run.log',level=logging.INFO)
LOG = logging.getLogger()

"""
Learn any orthogonal mapping
"""
w = SummaryWriter()

BATCH = 256
INSHAPE = (4, )
OUTSHAPE = (4, )

CUDA = False

torch.manual_seed(2)

nzs = 4

N = 300000 // BATCH

scale = 2
# plt.figure(figsize=(INSHAPE[0]*scale,OUTSHAPE[0]*scale))
plt.figure(figsize=(5,5))

MARGIN = 0.1
util.makedirs('./spread/')

params = None

model = gaussian.ParamASHLayer(INSHAPE, OUTSHAPE, additional=6, k=nzs, sigma_scale=0.4, fix_values=True)

if CUDA:
model.cuda()

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

for i in trange(N):

loss, x1, _ = util.orth_loss(BATCH, INSHAPE, model, CUDA)

optimizer.zero_grad()
loss.backward() # compute the gradients

optimizer.step()

w.add_scalar('orthogonal/loss', loss.data[0], i*BATCH)

if i % (N//250) == 0:
means, sigmas, values = model.hyper(x1)

plt.clf()
util.plot(means, sigmas, values, shape=(INSHAPE[0], OUTSHAPE[0]))
plt.xlim((-MARGIN*(INSHAPE[0]-1), (INSHAPE[0]-1) * (1.0+MARGIN)))
plt.ylim((-MARGIN*(OUTSHAPE[0]-1), (OUTSHAPE[0]-1) * (1.0+MARGIN)))
plt.savefig('./spread/means{:04}.png'.format(i))

print('LOSS', torch.sqrt(loss))
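
With fix_values=True every value is pinned to 1.0 (see ParamASHLayer.hyper above), so a 4 -> 4 map built from k = 4 unit entries can only be orthogonal if it is a permutation matrix. A small illustration of such a target (purely illustrative; util.orth_loss, which produces the actual targets and loss, is not shown in this diff):

import torch

# a random 4x4 permutation matrix: four unit entries, one per row and column
perm = torch.randperm(4)
P = torch.eye(4)[perm]

# permutation matrices are orthogonal: P @ P.T = I
assert torch.allclose(P @ P.t(), torch.eye(4))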
