In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%run imports.py

In [159]:
class DimLinear(nn.Conv1d):
    """
    This module is simply a linear layer that is applied to an arbritrary dimension rather than the last dimension.
    """
    def __init__(self, shape, dim_to_mix, out_features=None, bias=True):
        self.in_features = shape[dim_to_mix]
        if out_features is None:
            out_features = self.in_features
        self.out_features = out_features
        super().__init__(self.in_features, self.out_features, kernel_size=1, bias=bias)
        
        self.input_shape = list(shape)
        self.output_shape = copy.copy(self.input_shape)
        self.output_shape[dim_to_mix] = out_features
        self.dim_to_mix = dim_to_mix 
        
    def forward(self, x, verbose=False):
        bs, is_ = util.bs_is_split(x.shape, np.prod(self.input_shape, dtype=int))
        if list(is_) != self.input_shape:
            raise Exception(f'Input shape {is_} does not match expected input shape {self.input_shape}')
            
        bsis = [*bs, *self.input_shape]
        bsos = [*bs, *self.output_shape]
        left_collapse = np.prod([*bs, *is_[:self.dim_to_mix]], dtype=int)
        right_collapse = np.prod([*is_[self.dim_to_mix+1:]], dtype=int)
        proc_shape = [left_collapse, self.in_features, right_collapse]
        
        if verbose:
            print(f'collapsing shape {bsis}->{proc_shape}->{bsos}')
        
        x = x.reshape(*proc_shape)
        x = super().forward(x)
        x = x.reshape(*bsos)
        return x



collapsing shape [10, 1, 2, 3, 4, 5]->[240, 5, 1]->[10, 1, 2, 3, 4, 100]


torch.Size([10, 1, 2, 3, 4, 100])

In [189]:
shape = [10, 10, 10, 300]
X = torch.randn(10, *shape)
lin1 = DimLinear(shape, 3, out_features=301)
lin2 = dim_models.DimLinear(300, 301, shape=shape, dim_to_mix=3)
lin3 = nn.Linear(300, 301)

In [204]:
%%time
a = lin1(X)

collapsing shape [10, 10, 10, 10, 300]->[10000, 300, 1]->[10, 10, 10, 10, 301]
CPU times: user 79.5 ms, sys: 27.1 ms, total: 107 ms
Wall time: 15 ms


In [205]:
%%time
a = lin2(X)

CPU times: user 1.32 s, sys: 1.58 s, total: 2.9 s
Wall time: 352 ms


In [206]:
%%time
a = lin3(X)

CPU times: user 39.6 ms, sys: 0 ns, total: 39.6 ms
Wall time: 2.95 ms


In [5]:
import dim_models

In [126]:
[p.shape for p in d3.parameters()]

[torch.Size([100, 99, 1, 1, 1]), torch.Size([100])]

In [92]:
M = torch.randn(99, 100)
d0 = nn.Linear(99, 100)
d1 = nn.Conv1d(99, 100, 1)
d2 = nn.Conv2d(99, 100, 1)
d3 = nn.Conv3d(99, 100, 1)

In [113]:
X1 = torch.randn(50, 60, 70, 99, 1)
X2 = torch.randn(50, 60, 99, 1, 70)
X3 = torch.randn(50, 99, 1, 60, 70)
X4 = torch.randn(99, 1, 50, 60, 70)

In [114]:
%%time
(M.T@X1).shape

CPU times: user 106 ms, sys: 71.6 ms, total: 177 ms
Wall time: 19.5 ms


torch.Size([50, 60, 70, 100, 1])

In [115]:
%%time
d0(X1[..., 0]).shape

CPU times: user 173 ms, sys: 66.1 ms, total: 239 ms
Wall time: 25.9 ms


torch.Size([50, 60, 70, 100])

In [116]:
%%time
(M*X1).sum(dim=-2, keepdim=True).shape

CPU times: user 2.98 s, sys: 3.45 s, total: 6.43 s
Wall time: 789 ms


torch.Size([50, 60, 70, 1, 100])

In [117]:
%%time
(M[:, :, None]*X2).sum(dim=-3, keepdim=True).shape

CPU times: user 3.69 s, sys: 2.86 s, total: 6.56 s
Wall time: 804 ms


torch.Size([50, 60, 1, 100, 70])

In [118]:
%%time
d1(X2[..., 0, :].reshape(-1, 99, 70)).reshape(50, 60, 100, 70).shape

CPU times: user 204 ms, sys: 177 ms, total: 381 ms
Wall time: 39 ms


torch.Size([50, 60, 100, 70])

In [119]:
%%time
d2(X3[..., 0, :, :].reshape(-1, 99, 60, 70)).reshape(50, 100, 60, 70).shape

CPU times: user 218 ms, sys: 183 ms, total: 400 ms
Wall time: 42.7 ms


torch.Size([50, 100, 60, 70])

In [120]:
%%time
d3(X4[..., 0, :, :, :].reshape(-1, 99, 50, 60, 70)).reshape(100, 50, 60, 70).shape

CPU times: user 108 ms, sys: 102 ms, total: 210 ms
Wall time: 24 ms


torch.Size([100, 50, 60, 70])

In [77]:
x = torch.randn(10, 10, 10, 10, 50)

In [72]:
conv = nn.Conv1d(99, 100, 1)

In [75]:
conv(torch.randn(1,77, 99, 30)).shape

RuntimeError: Expected 3-dimensional input for 3-dimensional weight [100, 99, 1], but got 4-dimensional input of size [1, 77, 99, 30] instead