In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [22]:
class ShuffleLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(ShuffleLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.W = nn.Parameter(torch.Tensor(out_features, in_features))
        self.b = nn.Parameter(torch.Tensor(out_features))
        self.init_parameters()

    def init_parameters(self):
        nn.init.normal_(self.W, std=0.1)
        nn.init.constant_(self.b, 0)

    def forward(self, x):
        print(f"Size of input tensor: {x.size()}\nSize of W tensor {self.W.size()}")
        
        shuffle_arr_T = self.W.T
        indices = torch.argsort(torch.rand_like(shuffle_arr_T), dim=-1)
        result = torch.gather(shuffle_arr_T, dim=-1, index=indices)

        print(f"Input: {x}\nWeights:\n{self.W}\nTransposed Weight:\n{shuffle_arr_T}\nIndices:\n{indices}\nShuffled T Weights:\n{result}\nShuffled Weights:\n{result.T}")

        # Perform the matrix multiplication
        output = torch.matmul(x, self.W.t())
        output = output + self.b
        return output
    

In [6]:
inputX = torch.Tensor(np.arange(1, 4, dtype=float))
print(inputX)

tensor([1., 2., 3.])


In [23]:
shuffle = ShuffleLinear(3, 4)
print(shuffle.forward(inputX))

Size of input tensor: torch.Size([3])
Size of W tensor torch.Size([4, 3])
Input: tensor([1., 2., 3.])
Weights:
Parameter containing:
tensor([[-0.0688, -0.1161, -0.0816],
        [ 0.0527,  0.0680,  0.0577],
        [ 0.1203,  0.0084,  0.1173],
        [-0.0609,  0.1054,  0.2306]], requires_grad=True)
Transposed Weight:
tensor([[-0.0688,  0.0527,  0.1203, -0.0609],
        [-0.1161,  0.0680,  0.0084,  0.1054],
        [-0.0816,  0.0577,  0.1173,  0.2306]], grad_fn=<PermuteBackward0>)
Indices:
tensor([[2, 1, 3, 0],
        [0, 1, 2, 3],
        [2, 3, 1, 0]])
Shuffled T Weights:
tensor([[ 0.1203,  0.0527, -0.0609, -0.0688],
        [-0.1161,  0.0680,  0.0084,  0.1054],
        [ 0.1173,  0.2306,  0.0577, -0.0816]], grad_fn=<GatherBackward0>)
Shuffled Weights:
tensor([[ 0.1203, -0.1161,  0.1173],
        [ 0.0527,  0.0680,  0.2306],
        [-0.0609,  0.0084,  0.0577],
        [-0.0688,  0.1054, -0.0816]], grad_fn=<PermuteBackward0>)
tensor([-0.5459,  0.3616,  0.4890,  0.8418], grad_fn=<A