/
layers.py
808 lines (591 loc) · 29.3 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import Parameter
from torch import FloatTensor, LongTensor
import abc, itertools, math, types
from numpy import prod
import torch.nn.functional as F
import tensors
from sparse import util
from sparse.util import Bias, sparsemult, contains_nan, bmult, nduplicates, d
import sys
import random
import numpy as np
from enum import Enum
# added to the sigmas to prevent NaN
EPSILON = 10e-7
SIGMA_BOOST = 2.0
"""
Core implementation of the sparse (hyper)layer as an abstract class (SparseLayer).
"""
def densities(points, means, sigmas, epsilon=1e-6):
    """
    Compute the unnormalized Gaussian densities of a batch of points under a
    batch of axis-aligned multivariate normal distributions (MVNs).

    :param points: (..., c, i, rank) tensor of points for which to compute the
        densities: i index tuples of dimension rank, in each of c chunks.
    :param means: (..., c, k, rank) tensor containing the means of the MVNs
        (k continuous index tuples per chunk).
    :param sigmas: (..., c, k, rank) tensor containing the diagonals of the
        covariance matrices of the MVNs.
    :param epsilon: small constant added to the sigmas to prevent division by
        zero (default equals the module-level EPSILON, 10e-7).
    :return: (..., c, i, k) tensor containing the (unnormalized) density of
        every point under every MVN in the same chunk.
    """
    # c: number of chunks
    # i: number of integer index tuples per chunk
    # k: number of continuous index tuples per chunk
    # rank: dimensionality of each index tuple
    c, i, rank = points.size()[-3:]
    c, k, rank = means.size()[-3:]

    pref = points.size()[:-3]  # shared leading (batch) dimensions
    assert pref == means.size()[:-3]

    # Broadcast everything to a common (..., c, i, k, rank) shape so that every
    # point is paired with every mean/sigma within its chunk.
    points = points.unsqueeze(-2).expand(*(pref + (c, i, k, rank)))
    means = means.unsqueeze(-3).expand_as(points)
    sigmas = sigmas.unsqueeze(-3).expand_as(points)

    # 1/sqrt(sigma): scaling the differences by this turns the dot product
    # below into the exponent of the (diagonal-covariance) Gaussian density.
    inv_root_sigmas = torch.sqrt(1.0 / (epsilon + sigmas))

    diffs = (points - means) * inv_root_sigmas

    # Compute the exponents as batched dot products:
    # -- unroll all leading dimensions
    diffs = diffs.view(-1, 1, rank)
    products = torch.bmm(diffs, diffs.transpose(1, 2))
    # -- reconstruct the (..., c, i, k) shape
    products = products.view(*(pref + (c, i, k)))

    # the numerator of the Gaussian density; normalization is left to the caller
    return torch.exp(-0.5 * products)
def transform_means(means, size, method='sigmoid'):
    """
    Map raw index-tuple parameters (unbounded reals) onto the valid index range
    of a tensor with the given dimensions.

    In the case of a templated sparse layer, these parameters and the
    corresponding size tuple describe only the learned subtensor.

    :param means: (..., rank) tensor of raw parameter values
    :param size: Tuple describing the tensor dimensions.
    :param method: 'sigmoid' (default), 'modulo' or 'clamp'.
    :return: (..., rank) tensor with column j bounded by [0, size[j] - 1]
    """
    # Upper bound per column: the largest valid index along each dimension,
    # broadcast to the shape of the parameter tensor.
    upper = torch.tensor(list(size), dtype=torch.float, device=d(means)) - 1
    upper = util.unsqueezen(upper, len(means.size()) - 1).expand_as(means)

    if method == 'modulo':
        return means.remainder(upper)

    if method == 'clamp':
        floor = torch.zeros(means.size(), device=d(means))
        return torch.min(torch.max(means, floor), upper)

    # default: squash to (0, 1) with a sigmoid, then scale up to the index range
    return torch.sigmoid(means) * upper
def transform_sigmas(sigmas, size, min_sigma=EPSILON):
    """
    Map raw sigma parameters (unbounded reals) to positive values, scaled
    proportionally to the dimensions of the tensor. Note: each sigma is
    parametrized by a single value, which is expanded to a vector to fit the
    diagonal of the covariance matrix.

    In the case of a templated sparse layer, these parameters and the
    corresponding size tuple describe only the learned subtensor.

    :param sigmas: (..., ) tensor of raw sigma values
    :param size: Tuple describing the tensor dimensions.
    :param min_sigma: Minimal sigma value.
    :return: (..., rank) tensor of positive sigma values
    """
    rank = len(size)
    shape = sigmas.size()

    # softplus makes the values strictly positive; the SIGMA_BOOST shift makes
    # the sigmas start out reasonably large for raw parameters near zero
    positive = F.softplus(sigmas + SIGMA_BOOST) + min_sigma

    # one scalar per sigma, repeated along a new trailing 'rank' dimension
    positive = positive.unsqueeze(-1).expand(*(shape + (rank,)))

    # scale each column in proportion to the size of its tensor dimension
    dims = torch.tensor(list(size), dtype=torch.float,
                        device='cuda' if positive.is_cuda else 'cpu')
    dims = util.unsqueezen(dims, len(positive.size()) - 1).expand_as(positive)

    return positive * dims
class SparseLayer(nn.Module):
    """
    Abstract class for the (templated) hyperlayer. Implement by defining a hypernetwork, and returning it from the
    hyper() method. See NASLayer for an implementation without hypernetwork.

    The templated hyperlayer takes certain columns of its index-tuple matrix as fixed (the template), and others as
    learnable. Imagine a neural network layer where the connections to the output nodes are fixed, but the connections to
    the input nodes can be learned.

    For a non-templated hypernetwork (all columns learnable), just leave the template parameters None.
    """

    @abc.abstractmethod
    def hyper(self, input):
        """
        Applies the hypernetwork, and returns the continuous index tuples, with their associated sigmas and values.

        :param input: The input to the hyperlayer.
        :return: A triple: (means, sigmas, values)
        """
        raise NotImplementedError

    def __init__(self, in_rank, out_size,
                 temp_indices=None,
                 learn_cols=None,
                 chunk_size=None,
                 gadditional=0, radditional=0, region=None,
                 bias_type=Bias.DENSE):
        """
        :param in_rank: Nr of dimensions in the input. The specific size may vary between inputs.
        :param out_size: Tuple describing the size of the output.
        :param temp_indices: The template describing the fixed part of the index-tuple matrix. None for a
            non-templated hyperlayer.
        :param learn_cols: Which columns of the template are 'free' (to be learned). The rest are fixed. None for a
            non-templated hyperlayer.
        :param chunk_size: Size of the "context" for generating integer index tuples. Duplicates are removed within the
            same context. The list of continuous index tuples is chunked into contexts of this size. If None, the whole
            list counts as a single context. This is mostly useful in combination with templating.
        :param gadditional: Number of points to sample globally per index tuple
        :param radditional: Number of points to sample locally per index tuple
        :param region: Tuple describing the size of the region over which the local additional points are sampled (must
            be smaller than the size of the tensor).
        :param bias_type: The type of bias of the sparse layer (none, dense or sparse).
        """
        super().__init__()

        # total rank of the weight tensor: input dimensions plus output dimensions
        rank = in_rank + len(out_size)

        assert learn_cols is None or len(region) == len(learn_cols), "Region should span as many dimensions as there are learnable columns"

        self.in_rank = in_rank
        self.out_size = out_size # without batch dimension
        self.gadditional = gadditional
        self.radditional = radditional
        self.region = region
        self.chunk_size = chunk_size

        self.bias_type = bias_type
        # if no template is used, every column of the index matrix is learnable
        self.learn_cols = learn_cols if learn_cols is not None else range(rank)
        self.templated = temp_indices is not None

        # create a tensor with all binary sequences of length 'out_rank' as rows
        # (this will be used to compute the nearby integer-indices of a float-index).
        self.register_buffer('floor_mask', floor_mask(len(self.learn_cols)))

        if self.templated:
            # template for the index matrix containing the hardwired connections
            # The learned parts can be set to zero; they will be overriden.
            assert temp_indices.size(1) == in_rank + len(out_size)
            self.register_buffer('temp_indices', temp_indices)

    def is_cuda(self):
        # whether the module lives on the GPU (checks a single parameter;
        # assumes all parameters share a device)
        return next(self.parameters()).is_cuda

    def forward(self, input, mrange=None, seed=None, **kwargs):
        """
        :param input: Batch of inputs; the first dimension is the batch dimension.
        :param mrange: Specifies a subrange of index tuples to compute the gradient over. This is helpful for gradient
            accumulation methods. This doesn't work together with templating.
        :param seed: Random seed, passed on to generate_integer_tuples().
        :param kwargs: Passed on to self.hyper().
        :return: The output of the sparse layer (plus the dense bias, if configured).
        """

        assert mrange is None or not self.templated, "Templating and gradient accumulation do not work together"

        ### Compute and unpack output of hypernetwork

        bias = None

        if self.bias_type == Bias.NONE:
            means, sigmas, values = self.hyper(input, **kwargs)
        elif self.bias_type == Bias.DENSE:
            means, sigmas, values, bias = self.hyper(input, **kwargs)
        elif self.bias_type == Bias.SPARSE:
            raise Exception('Sparse bias not supported yet.')
        else:
            raise Exception('bias type {} not recognized.'.format(self.bias_type))

        b, n, r = means.size()
        dv = 'cuda' if self.is_cuda() else 'cpu'

        # We divide the list of index tuples into 'chunks'. Each chunk represents a kind of context:
        # - duplicate integer index tuples within the chunk are removed
        # - proportions are normalized over all index tuples within the chunk
        # This is useful in the templated setting. If no chunk size is requested, we just add a singleton dimension.
        k = self.chunk_size if self.chunk_size is not None else n # chunk size
        c = n // k                                                # number of chunks

        means, sigmas, values = means.view(b, c, k, r), sigmas.view(b, c, k, r), values.view(b, c, k)

        assert b == input.size(0), 'input batch size ({}) should match parameter batch size ({}).'.format(input.size(0), b)

        # max values allowed for each column in the index matrix
        fullrange = self.out_size + input.size()[1:]
        subrange = [fullrange[r] for r in self.learn_cols] # submatrix for the learnable dimensions

        if not self.training:
            # during evaluation we simply round the means to the nearest integer index
            indices = means.view(b, c*k, r).round().long()
        else:
            if mrange is not None: # only compute the gradient for a subset of index tuples
                fr, to = mrange

                # Fix: this mask was previously created with dtype=torch.uint8, for which `~ids`
                # is a bitwise not (producing 254/255), so the complement selection below would
                # select *all* tuples instead of the complement. A bool mask gives a logical not.
                ids = torch.zeros((k,), dtype=torch.bool, device=dv)
                ids[fr:to] = 1

                means, means_out = means[:, :, ids, :], means[:, :, ~ids, :]
                sigmas, sigmas_out = sigmas[:, :, ids, :], sigmas[:, :, ~ids, :]
                values, values_out = values[:, :, ids], values[:, :, ~ids]

                # These should not get a gradient, since their means aren't being sampled for
                # (their gradient will be computed in other passes)
                means_out, sigmas_out, values_out = means_out.detach(), sigmas_out.detach(), values_out.detach()

            indices = generate_integer_tuples(means, self.gadditional, self.radditional, rng=subrange, relative_range=self.region, seed=seed, cuda=self.is_cuda())
            indfl = indices.float()

            # Mask for duplicate indices
            dups = nduplicates(indices)

            # compute (unnormalized) densities under the given MVNs (proportions)
            props = densities(indfl, means, sigmas).clone() # result has size (b, c, i, k), i = indices[2]
            props[dups, :] = 0  # duplicates contribute nothing
            props = props / props.sum(dim=2, keepdim=True) # normalize over all points of a given index tuple

            # Weight the values by the proportions
            values = values[:, :, None, :].expand_as(props)

            values = props * values
            values = values.sum(dim=3)

            if mrange is not None:
                # re-attach the detached (non-sampled) tuples so the output covers all indices
                indices_out = means_out.data.round().long()

                indices = torch.cat([indices, indices_out], dim=2)
                values = torch.cat([values, values_out], dim=2)

        # remove the chunk dimensions
        indices, values = indices.view(b, -1 , r), values.view(b, -1)

        if self.templated:
            # stitch the generated indices into the template
            b, l, r = indices.size()
            h, w = self.temp_indices.size()

            template = self.temp_indices[None, :, None, :].expand(b, h, l//h, w)
            template = template.contiguous().view(b, l, w)

            # overwrite the learnable columns with the generated indices
            template[:, :, self.learn_cols] = indices

            indices = template

        size = self.out_size + input.size()[1:]
        output = tensors.contract(indices, values, size, input)

        if self.bias_type == Bias.DENSE:
            return output + bias
        return output
class NASLayer(SparseLayer):
    """
    Sparse layer with free sparse parameters, no hypernetwork, no template.
    """

    def __init__(self, in_size, out_size, k,
                 sigma_scale=0.2,
                 fix_values=False, has_bias=False,
                 min_sigma=0.0,
                 gadditional=0,
                 region=None,
                 radditional=None,
                 template=None,
                 learn_cols=None,
                 chunk_size=None):
        """
        :param in_size: Tuple describing the size of the input.
        :param out_size: Tuple describing the size of the output.
        :param k: Number of index tuples (sparse weights) in the layer.
        :param sigma_scale: Scale factor applied to the transformed sigmas.
        :param fix_values: If True, all values are fixed to 1 rather than learned.
        :param has_bias: If True, a dense bias over the output is learned.
        :param min_sigma: Minimal sigma value.
        :param gadditional: Number of points to sample globally per index tuple.
        :param region: Size of the region over which the local additional points are sampled.
        :param radditional: Number of points to sample locally per index tuple.
        :param template: LongTensor Template for the matrix of index tuples. Learnable columns are updated through backprop
            other values are taken from the template.
        :param learn_cols: tuple of integers. Learnable columns of the template.
        :param chunk_size: Size of the chunk/context used for duplicate removal and normalization (see SparseLayer).
        """
        super().__init__(in_rank=len(in_size),
                         out_size=out_size,
                         bias_type=Bias.DENSE if has_bias else Bias.NONE,
                         gadditional=gadditional,
                         radditional=radditional,
                         region=region,
                         temp_indices=template,
                         learn_cols=learn_cols,
                         chunk_size=chunk_size)

        self.k = k
        self.in_size = in_size
        self.out_size = out_size
        self.sigma_scale = sigma_scale
        self.fix_values = fix_values
        self.has_bias = has_bias
        self.min_sigma = min_sigma

        # rank of the full weight tensor
        self.rank = len(in_size) + len(out_size)

        # raw (untransformed) parameters; when a template is used, only the
        # learnable columns get free parameters
        imeans = torch.randn(k, self.rank if template is None else len(learn_cols))
        isigmas = torch.randn(k)

        self.pmeans = Parameter(imeans)
        self.psigmas = Parameter(isigmas)

        if fix_values:
            # fixed values: a non-learnable buffer of ones
            self.register_buffer('pvalues', torch.ones(k))
        else:
            self.pvalues = Parameter(torch.randn(k))

        if self.has_bias:
            self.bias = Parameter(torch.zeros(*out_size))

    def hyper(self, input, **kwargs):
        """
        Evaluates hypernetwork. Since this layer has no hypernetwork, this just expands
        the free parameters along the batch dimension and transforms them into valid ranges.
        """
        b = input.size(0)
        size = self.out_size + input.size()[1:] # total dimensions of the weight tensor

        if self.learn_cols is not None:
            # restrict to the dimensions whose indices are actually learned
            size = [size[l] for l in self.learn_cols]

        k, r = self.pmeans.size()

        # Expand parameters along batch dimension
        means = self.pmeans[None, :, :].expand(b, k, r)
        sigmas = self.psigmas[None, :].expand(b, k)
        values = self.pvalues[None, :].expand(b, k)

        # map raw parameters into valid index/sigma ranges
        means, sigmas = transform_means(means, size), transform_sigmas(sigmas, size, min_sigma=self.min_sigma) * self.sigma_scale

        if self.has_bias:
            return means, sigmas, values, self.bias
        return means, sigmas, values
class Convolution(nn.Module):
    """
    A non-adaptive hyperlayer that mimics a convolution. That is, the basic structure of the layer is a convolution, but
    instead of connecting every input in the patch to every output channel, we connect them sparsely, with parameters
    learned by the hyperlayer.

    The parameters are the same for each instance of the convolution kernel, but they are sampled separately for each.

    The hyperlayer is _templated_ that is, each connection is fixed to one output node. There are k connections per
    output node.

    The stride is always 1, padding is always added to ensure that the output resolution is the same as the input
    resolution.
    """

    def __init__(self, in_size, out_channels, k, kernel_size=3,
                 gadditional=2, radditional=2, rprop=0.2,
                 min_sigma=0.0,
                 sigma_scale=0.1,
                 fix_values=False,
                 has_bias=True):
        """
        :param in_size: Channels and resolution of the input (c, h, w).
        :param out_channels: Number of output channels.
        :param k: Number of points sampled per instance of the kernel.
        :param kernel_size: Size of the (square) kernel.
        :param gadditional: Number of points to sample globally per index tuple.
        :param radditional: Number of points to sample locally per index tuple.
        :param rprop: Describes the region over which the local samples are taken, as a proportion of the channels.
        :param min_sigma: Minimal sigma value.
        :param sigma_scale: Scale factor applied to the transformed sigmas.
        :param fix_values: If True, the values are fixed to 1 rather than learned.
        :param has_bias: If True, a dense bias over the output is learned.
        """
        super().__init__()

        self.in_size = in_size
        self.out_size = (out_channels,) + in_size[1:]
        self.kernel_size = kernel_size
        self.gadditional = gadditional
        self.radditional = radditional
        # local sampling region: a proportion of the input channels, plus the
        # spatial extent of the kernel
        self.region = (max(1, math.floor(rprop * in_size[0])), kernel_size-1, kernel_size-1)
        self.min_sigma = min_sigma
        self.sigma_scale = sigma_scale
        self.has_bias = has_bias

        # zero padding keeps the output resolution equal to the input resolution
        self.pad = nn.ZeroPad2d(kernel_size // 2)

        # raw parameters of a single kernel: per output channel, k index tuples
        # of rank 3 (input channel, row and column within the kernel patch)
        self.means = nn.Parameter(torch.randn(out_channels, k, 3))
        self.sigmas = nn.Parameter(torch.randn(out_channels, k))

        # Fix: fix_values previously set self.values to None, which made hyper()
        # crash on `self.values[None, :]`. Register a constant buffer of ones
        # instead, consistent with NASLayer.
        if fix_values:
            self.register_buffer('values', torch.ones(out_channels, k))
        else:
            self.values = nn.Parameter(torch.randn(out_channels, k))

        # fixed part of the index tuples: one (out_channel, out_row, out_col)
        # triple per output element, padded with three zero columns that will
        # receive the sampled input coordinates in forward()
        template = torch.LongTensor(list(np.ndindex( (out_channels, in_size[1], in_size[2]) )))
        assert template.size() == (prod((out_channels, in_size[1], in_size[2])), 3)

        template = F.pad(template, (0, 3))
        self.register_buffer('template', template)

        if self.has_bias:
            self.bias = Parameter(torch.randn(*self.out_size))

    def hyper(self, x):
        """
        Returns the means, sigmas and values for a _single_ kernel. The same kernel is applied at every position (but
        with fresh samples).

        :param x: Input batch (b, c, h, w).
        :return: A triple (means, sigmas, values), each with a batch dimension prepended.
        """
        b = x.size(0)
        size = (self.in_size[0], self.kernel_size, self.kernel_size)

        o, k, r = self.means.size()

        # Expand parameters along batch dimension
        means = self.means[None, :, :].expand(b, o, k, r)
        sigmas = self.sigmas[None, :].expand(b, o, k)
        values = self.values[None, :].expand(b, o, k)

        # map raw parameters into the valid index/sigma ranges of the kernel
        means, sigmas = transform_means(means, size), transform_sigmas(sigmas, size, min_sigma=self.min_sigma) * self.sigma_scale

        return means, sigmas, values

    def forward(self, x):

        dv = 'cuda' if self.template.is_cuda else 'cpu'

        # get continuous parameters
        means, sigmas, values = self.hyper(x)

        # zero pad
        x = self.pad(x)

        b, o, k, r = means.size()
        assert sigmas.size() == (b, o, k, r)
        assert values.size() == (b, o, k)

        # number of instances of the convolution kernel
        nk = self.in_size[1] * self.in_size[2]

        # expand for all kernel instances (shared parameters, fresh samples each)
        means  = means [:, :, None, :, :].expand(b, o, nk, k, r)
        sigmas = sigmas[:, :, None, :, :].expand(b, o, nk, k, r)
        values = values[:, :, None, :]   .expand(b, o, nk, k)

        if not self.training:
            # evaluation: round the means to the nearest integer indices
            indices = means.round().long()
            l = k
        else:
            # sample integer index tuples around the continuous means
            indices = ngenerate(means,
                                self.gadditional, self.radditional,
                                relative_range=self.region,
                                rng=(self.in_size[0], self.kernel_size, self.kernel_size),
                                cuda=means.is_cuda)

            indfl = indices.float()

            b, o, nk, l, r = indices.size()
            # l = integer neighbors + global + local samples per continuous tuple
            assert l == k * (2 ** r + self.gadditional + self.radditional)
            assert nk == self.in_size[1] * self.in_size[2]

            # mask for duplicate indices
            dups = nduplicates(indices)

            # compute unnormalized densities (proportions) under the given MVNs
            props = densities(indfl, means, sigmas).clone()
            props[dups, :] = 0  # duplicates contribute nothing
            props = props / props.sum(dim=-2, keepdim=True) # normalize over all points of a given index tuple

            # Weight the values by the proportions
            values = values[:, :, :, None, :].expand_as(props)
            values = props * values
            values = values.sum(dim=4)

        # stitch the sampled input coordinates into the fixed (output) part of the template
        template = self.template[None, :, None, :].expand(b, self.out_size[0]*self.in_size[1]*self.in_size[2], l, 6)
        template = template.view(b, self.out_size[0], nk, l, 6)

        template[:, :, :, :, 3:] = indices
        indices = template.contiguous().view(b, self.out_size[0] * nk * l, 6)

        # translate the kernel-relative spatial coordinates (columns 4:6) into
        # absolute coordinates in the padded input by adding the output position
        offsets = indices[:, :, 1:3]
        indices[:, :, 4:] = indices[:, :, 4:] + offsets

        values = values.contiguous().view(b, self.out_size[0] * nk * l)

        # apply tensor
        size = self.out_size + x.size()[1:]
        assert (indices.view(-1, 6).max(dim=0)[0] >= torch.tensor(size, device=dv)).sum() == 0, "Max values of indices ({}) out of bounds ({})".format(indices.view(-1, 6).max(dim=0)[0], size)

        output = tensors.contract(indices, values, size, x)

        if self.has_bias:
            return output + self.bias
        return output
# cache of floor masks, keyed by the number of columns (the rank)
FLOOR_MASKS = {}
def floor_mask(num_cols, cuda=False):
    """
    Return a (2^num_cols, num_cols) boolean tensor whose rows enumerate all
    binary patterns over the columns: True marks a coordinate to be floored,
    False one to be ceiled, when enumerating the integer neighbors of a
    continuous index tuple.

    Masks are cached per width on the CPU; a fresh CUDA copy is made on request.

    :param num_cols: number of columns (the rank of the index tuples)
    :param cuda: if True, return the mask on the GPU
    :return: (2^num_cols, num_cols) boolean tensor
    """
    if num_cols not in FLOOR_MASKS:
        lsts = [[int(b) for b in bools] for bools in itertools.product([True, False], repeat=num_cols)]
        # torch.tensor instead of the legacy BoolTensor constructor, which does
        # not reliably accept a device keyword across torch versions
        FLOOR_MASKS[num_cols] = torch.tensor(lsts, dtype=torch.bool)

    if cuda:
        return FLOOR_MASKS[num_cols].cuda()
    return FLOOR_MASKS[num_cols]
def generate_integer_tuples(means, gadditional, ladditional, rng=None, relative_range=None, seed=None, cuda=False, fm=None):
    """
    Takes continuous-valued index tuples and generates integer-valued index tuples around them.

    The returned tensor is a plain LongTensor: autograd of the real-valued indices passes through
    the proportions/values computed from them, not through the integer indices themselves.

    Three groups of integer tuples are generated for every continuous tuple:
      - its 2^rank integer neighbors (all floor/ceil combinations of its coordinates),
      - `gadditional` tuples sampled uniformly from the whole index space,
      - `ladditional` tuples sampled uniformly from a small region around the tuple.

    :param means: (b, k, c, rank) tensor of continuous index tuples. (NOTE(review): the caller
        SparseLayer.forward passes a (b, c, k, r) tensor; the middle two dimensions are treated
        uniformly here, so only the naming differs.)
    :param gadditional: number of tuples to sample globally, per chunk.
    :param ladditional: number of tuples to sample locally, per chunk.
    :param rng: upper bound of the index space per dimension.
    :param relative_range: size of the local sampling region per dimension.
    :param seed: if not None, the *global* torch RNG is seeded (side effect).
    :param cuda: whether to generate the tuples on the GPU.
    :param fm: optional precomputed floor mask (see floor_mask()).
    :return: (b, k, -1, rank) LongTensor: all integer tuples generated for a chunk,
        concatenated along the third dimension.
    """
    b, k, c, rank = means.size()

    # legacy-style tensor constructor for the requested device
    FT = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    if seed is not None:
        torch.manual_seed(seed)

    """
    Generate neighbor tuples
    """
    if fm is None:
        fm = floor_mask(rank, cuda)
    fm = fm[None, None, None, :].expand(b, k, c, 2 ** rank, rank)

    neighbor_ints = means.data[:, :, :, None, :].expand(b, k, c, 2 ** rank, rank).contiguous()

    # floor the coordinates selected by the mask and ceil the rest; each of the
    # 2^rank mask rows produces one integer neighbor of the continuous tuple
    neighbor_ints[fm] = neighbor_ints[fm].floor()
    neighbor_ints[~fm] = neighbor_ints[~fm].ceil()

    neighbor_ints = neighbor_ints.long()

    """
    Sample uniformly from all integer tuples
    """
    global_ints = FT(b, k, c, gadditional, rank)

    global_ints.uniform_()
    global_ints *= (1.0 - EPSILON)  # keep samples strictly below 1 so flooring stays in bounds

    rng = FT(rng)
    rngxp = rng.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand_as(global_ints)

    global_ints = torch.floor(global_ints * rngxp).long()

    """
    Sample uniformly from a small range around the given index tuple
    """
    local_ints = FT(b, k, c, ladditional, rank)

    local_ints.uniform_()
    local_ints *= (1.0 - EPSILON)

    rngxp = rng[None, None, None, :].expand_as(local_ints) # bounds of the tensor
    rrng = FT(relative_range) # bounds of the range from which to sample
    rrng = rrng[None, None, None, :].expand_as(local_ints)

    # center the sampling region on the rounded continuous tuple
    mns_expand = means.round().unsqueeze(3).expand_as(local_ints)

    # upper and lower bounds of the region
    lower = mns_expand - rrng * 0.5
    upper = mns_expand + rrng * 0.5

    # shift any regions that stick out of the tensor bounds back inside
    idxs = lower < 0.0
    lower[idxs] = 0.0

    idxs = upper > rngxp
    lower[idxs] = rngxp[idxs] - rrng[idxs]

    local_ints = (local_ints * rrng + lower).long()

    all = torch.cat([neighbor_ints, global_ints, local_ints] , dim=3)

    return all.view(b, k, -1, rank) # combine all indices sampled within a chunk
def ngenerate(means, gadditional, ladditional, rng=None, relative_range=None, seed=None, cuda=False, fm=None, epsilon=EPSILON):
    """
    Generates random integer index tuples based on continuous parameters. N-dimensional
    version of generate_integer_tuples(): accepts any number of leading batch dimensions
    before the last three (..., k, c, rank).

    Per continuous tuple, three groups of integer tuples are produced: its 2^rank
    floor/ceil neighbors, `gadditional` globally sampled tuples, and `ladditional`
    tuples sampled from a region around it.

    :param means: (..., k, c, rank) tensor of continuous index tuples.
    :param gadditional: number of tuples sampled uniformly over the whole index space.
    :param ladditional: number of tuples sampled uniformly from a local region.
    :param rng: upper bound of the index space per dimension (tuple or torch.Size).
    :param relative_range: size of the local sampling region per dimension.
    :param seed: if not None, the *global* torch RNG is seeded (side effect).
    :param cuda: whether to sample on the GPU.
    :param fm: optional precomputed floor mask (see floor_mask()).
    :param epsilon: The random numbers are based on uniform samples in (0, 1-epsilon). Note that
        in some cases epsilon needs to be relatively big (e.g. 10e-5).
    :return: LongTensor of integer index tuples; all samples of a chunk are combined
        along the second-to-last dimension.
    """
    b = means.size(0)
    k, c, rank = means.size()[-3:]
    pref = means.size()[:-1]

    # legacy-style tensor constructor for the requested device
    FT = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    rng = FT(tuple(rng))
    # - the tuple() is there in case a torch.Size() object is passed (which causes torch to
    #   interpret the argument as the size of the tensor rather than its content).

    bounds = util.unsqueezen(rng, len(pref) + 1).long() # index bound with unsqueezed dims for broadcasting

    if seed is not None:
        torch.manual_seed(seed)

    """
    Generate neighbor tuples
    """
    if fm is None:
        fm = floor_mask(rank, cuda)

    size = pref + (2**rank, rank)
    fm = util.unsqueezen(fm, len(size) - 2).expand(size)

    neighbor_ints = means.data.unsqueeze(-2).expand(*size).contiguous()

    # floor the coordinates selected by the mask and ceil the rest; each of the
    # 2^rank mask rows produces one integer neighbor of the continuous tuple
    neighbor_ints[fm] = neighbor_ints[fm].floor()
    neighbor_ints[~fm] = neighbor_ints[~fm].ceil()

    neighbor_ints = neighbor_ints.long()

    assert (neighbor_ints >= bounds).sum() == 0, 'One of the neighbor indices is outside the tensor bounds'

    """
    Sample uniformly from all integer tuples
    """
    gsize = pref + (gadditional, rank)
    global_ints = FT(*gsize)

    global_ints.uniform_()
    global_ints *= (1.0 - epsilon)  # keep samples strictly below 1 so flooring stays in bounds

    rngxp = util.unsqueezen(rng, len(gsize) - 1).expand_as(global_ints)

    global_ints = torch.floor(global_ints * rngxp).long()

    assert (global_ints >= bounds).sum() == 0, 'One of the global sampled indices is outside the tensor bounds'

    """
    Sample uniformly from a small range around the given index tuple
    """
    lsize = pref + (ladditional, rank)
    local_ints = FT(*lsize)

    local_ints.uniform_()
    local_ints *= (1.0 - epsilon)

    rngxp = util.unsqueezen(rng, len(lsize) - 1).expand_as(local_ints) # bounds of the tensor
    rrng = FT(relative_range) # bounds of the range from which to sample
    rrng = util.unsqueezen(rrng, len(lsize) - 1).expand_as(local_ints)

    # center the sampling region on the rounded continuous tuple
    mns_expand = means.round().unsqueeze(-2).expand_as(local_ints)

    # upper and lower bounds of the region
    lower = mns_expand - rrng * 0.5
    upper = mns_expand + rrng * 0.5

    # shift any regions that stick out of the tensor bounds back inside
    idxs = lower < 0.0
    lower[idxs] = 0.0

    idxs = upper > rngxp
    lower[idxs] = rngxp[idxs] - rrng[idxs]

    # kept for the assert's diagnostic message below
    cached = local_ints.clone()
    local_ints = (local_ints * rrng + lower).long()

    assert (local_ints >= bounds).sum() == 0, f'One of the local sampled indices is outside the tensor bounds (this may mean the epsilon is too small)' \
        f'\n max sampled {(cached * rrng).max().item()}, rounded {(cached * rrng).max().long().item()} max lower limit {lower.max().item()}' \
        f'\n sum {((cached * rrng).max() + lower.max()).item()}' \
        f'\n rounds to {((cached * rrng).max() + lower.max()).long().item()}'

    all = torch.cat([neighbor_ints, global_ints, local_ints] , dim=-2)

    # collapse the chunk dimension into the sample dimension
    fsize = pref[:-1] + (-1, rank)

    return all.view(*fsize) # combine all indices sampled within a chunk