In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [4]:
class ShuffleLinear(nn.Module):
    """Linear layer that also builds a randomly shuffled copy of its weights.

    Holds a weight ``W`` of shape ``(out_features, in_features)`` and a bias
    ``b`` of shape ``(out_features,)``. ``forward`` evaluates the ordinary
    affine map ``x @ W.T + b`` as well as the same map with every row of
    ``W.T`` independently permuted, and returns the *unshuffled* result.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # torch.empty replaces the legacy torch.Tensor(*sizes) constructor.
        # Either way the storage is uninitialised until init_parameters()
        # overwrites it just below.
        self.W = nn.Parameter(torch.empty(out_features, in_features))
        self.b = nn.Parameter(torch.empty(out_features))
        self.init_parameters()

    def init_parameters(self):
        """Fill ``W`` with normal samples (mean 0, std 0.1) and zero ``b``."""
        nn.init.normal_(self.W, std=0.1)
        nn.init.constant_(self.b, 0)

    def forward(self, x, verbose=True):
        """Apply the layer and optionally print the shuffled variant.

        Args:
            x: Tensor whose last dimension equals ``in_features``.
            verbose: When True (the default, matching the original
                behaviour) the sizes, weights, shuffle indices and both
                outputs are printed. Pass False to silence the dump.

        Returns:
            ``x @ W.T + b`` computed with the unshuffled weights.
        """
        if verbose:
            print(f"Size of input tensor: {x.size()}\nSize of W tensor {self.W.size()}")

        # Sorting i.i.d. uniform noise along the last axis yields one random
        # permutation per row of W.T; gather then applies it row by row.
        indices = torch.argsort(torch.rand_like(self.W.T), dim=-1)
        shuffled_W_T = torch.gather(self.W.T, dim=-1, index=indices)

        if verbose:
            print(f"Input: {x}\nWeights:\n{self.W}\nTransposed Weight:\n{self.W.T}\nIndices:\n{indices}\nShuffled T Weights:\n{shuffled_W_T}\nShuffled Weights:\n{shuffled_W_T.T}")

        # Same affine map evaluated with the original and shuffled weights.
        output = torch.matmul(x, self.W.t()) + self.b
        shuffled_output = torch.matmul(x, shuffled_W_T) + self.b

        if verbose:
            print(f"Normal Output:\n{output}\nShuffled output:\n{shuffled_output}")

        # NOTE(review): the shuffled result is only printed, never returned.
        # Kept as in the original; confirm whether the layer is meant to
        # return shuffled_output instead.
        return output
    

In [5]:
# Demo input: the vector [1., 2., 3.] as a float32 tensor.
inputX = torch.from_numpy(np.arange(1, 4, dtype=float)).float()
print(inputX)

tensor([1., 2., 3.])


In [6]:
# Single 3 -> 4 layer applied to the demo vector. The module is called
# directly instead of via .forward() so that nn.Module.__call__ (and any
# registered hooks) is used, as PyTorch recommends.
shuffle = ShuffleLinear(3, 4)
shuffle(inputX)

Size of input tensor: torch.Size([3])
Size of W tensor torch.Size([4, 3])
Input: tensor([1., 2., 3.])
Weights:
Parameter containing:
tensor([[ 0.0476, -0.0828, -0.1888],
        [-0.1206, -0.0245,  0.2130],
        [-0.0552,  0.0270, -0.0087],
        [-0.1820,  0.1950,  0.0184]], requires_grad=True)
Transposed Weight:
tensor([[ 0.0476, -0.1206, -0.0552, -0.1820],
        [-0.0828, -0.0245,  0.0270,  0.1950],
        [-0.1888,  0.2130, -0.0087,  0.0184]], grad_fn=<PermuteBackward0>)
Indices:
tensor([[1, 0, 3, 2],
        [1, 0, 3, 2],
        [1, 2, 0, 3]])
Shuffled T Weights:
tensor([[-0.1206,  0.0476, -0.1820, -0.0552],
        [-0.0245, -0.0828,  0.1950,  0.0270],
        [ 0.2130, -0.0087, -0.1888,  0.0184]], grad_fn=<GatherBackward0>)
Shuffled Weights:
tensor([[-0.1206, -0.0245,  0.2130],
        [ 0.0476, -0.0828, -0.0087],
        [-0.1820,  0.1950, -0.1888],
        [-0.0552,  0.0270,  0.0184]], grad_fn=<PermuteBackward0>)
Normal Output:
tensor([-0.6842,  0.4694, -0.0271,  0.26

tensor([-0.6842,  0.4694, -0.0271,  0.2633], grad_fn=<AddBackward0>)

In [35]:
# Two stacked layers (3 -> 4 -> 2) applied to the demo vector. The modules
# are called directly instead of via .forward() so that nn.Module.__call__
# (and any registered hooks) is used, as PyTorch recommends.
sm1 = ShuffleLinear(3, 4)
sm2 = ShuffleLinear(4, 2)
sm2(sm1(inputX))

Size of input tensor: torch.Size([3])
Size of W tensor torch.Size([4, 3])
Input: tensor([1., 2., 3.])
Weights:
Parameter containing:
tensor([[ 0.0170,  0.1478,  0.1312],
        [-0.1592,  0.0984,  0.0042],
        [ 0.0364,  0.1263, -0.1015],
        [ 0.0037, -0.0235, -0.0238]], requires_grad=True)
Transposed Weight:
tensor([[ 0.0170, -0.1592,  0.0364,  0.0037],
        [ 0.1478,  0.0984,  0.1263, -0.0235],
        [ 0.1312,  0.0042, -0.1015, -0.0238]], grad_fn=<PermuteBackward0>)
Indices:
tensor([[2, 0, 3, 1],
        [3, 1, 2, 0],
        [1, 2, 3, 0]])
Shuffled T Weights:
tensor([[ 0.0364,  0.0170,  0.0037, -0.1592],
        [-0.0235,  0.0984,  0.1263,  0.1478],
        [ 0.0042, -0.1015, -0.0238,  0.1312]], grad_fn=<GatherBackward0>)
Shuffled Weights:
tensor([[ 0.0364, -0.0235,  0.0042],
        [ 0.0170,  0.0984, -0.1015],
        [ 0.0037,  0.1263, -0.0238],
        [-0.1592,  0.1478,  0.1312]], grad_fn=<PermuteBackward0>)
Normal Output:
tensor([ 0.7062,  0.0503, -0.0155, -0.11

tensor([0.0355, 0.0230], grad_fn=<AddBackward0>)