In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

In [6]:
class ShuffleLinear(nn.Module):
    """Affine layer ``y = x @ W.T + b`` that randomly permutes input features first.

    On every forward pass the last dimension of ``x`` and the matching columns of
    ``W`` are permuted with the same random permutation. The matmul sums over that
    dimension, so the result equals the un-shuffled ``x @ W.T`` up to
    floating-point summation order. The layer demonstrates that invariance.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        verbose: if True (default, matching the original notebook behaviour),
            print the intermediate tensors and an un-shuffled reference result on
            every forward call.
    """

    def __init__(self, in_features, out_features, verbose=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.verbose = verbose
        # Allocated uninitialised here; filled in by init_parameters() below.
        self.W = nn.Parameter(torch.empty(out_features, in_features))
        self.b = nn.Parameter(torch.empty(out_features))
        self.init_parameters()

    def init_parameters(self):
        """Draw W from N(0, 0.1^2) and set the bias to zero."""
        nn.init.normal_(self.W, std=0.1)
        nn.init.constant_(self.b, 0)

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(..., in_features)``.

        Returns:
            Tensor of shape ``(..., out_features)``.

        Raises:
            RuntimeError: if the last dimension of ``x`` is not ``in_features``.
        """
        n_features = x.size(-1)
        if n_features != self.in_features:
            # Without this check a smaller input would silently select a subset
            # of W's columns and produce a wrong-but-valid-looking result.
            raise RuntimeError(
                f"ShuffleLinear expected last input dim {self.in_features}, got {n_features}"
            )

        # Permute the feature (last) axis. Plain ``x[indices]`` would permute
        # axis 0, which is only the feature axis for 1-D input. Build the
        # permutation on x's device so non-CPU inputs work.
        indices = torch.randperm(n_features, device=x.device)
        shuffled_x = x[..., indices]
        shuffled_weight = self.W[:, indices]

        output = torch.matmul(shuffled_x, shuffled_weight.t())

        if self.verbose:
            # Un-shuffled reference result; only needed for the printed comparison.
            output_not_shuffled = torch.matmul(x, self.W.t())
            print(f"Size of input tensor: {x.size()}\nSize of W tensor {self.W.size()}")
            print(f"Input tensor: {x}\nWeight tensor: {self.W}")
            print(f"Shuffled indices: {indices}")
            print(f"Shuffled input: {shuffled_x}\nShuffled weights: {shuffled_weight}")
            print(f"Shuffled output: {output}")
            print(f"output: {output_not_shuffled}")

        return output + self.b
    

In [10]:
# Toy feature vector [1., 2., 3.]; a float start makes arange use the default float dtype.
inputX = torch.arange(1.0, 4.0)
print(inputX)

tensor([1., 2., 3.])


In [12]:
# Seed so the random weight init and the feature permutation are reproducible on re-run.
torch.manual_seed(0)
shuffle = ShuffleLinear(3, 2)
# Call the module itself rather than .forward() so nn.Module hooks are honoured.
print(shuffle(inputX))

Size of input tensor: torch.Size([3])
Size of W tensor torch.Size([2, 3])
Input tensor: tensor([1., 2., 3.])
Weight tensor: Parameter containing:
tensor([[-0.0211,  0.0158,  0.0458],
        [ 0.1995, -0.0278,  0.1617]], requires_grad=True)
Shuffled indices: tensor([1, 0, 2])
Shuffled input: tensor([2., 1., 3.])
Shuffled weighs: tensor([[ 0.0158, -0.0211,  0.0458],
        [-0.0278,  0.1995,  0.1617]], grad_fn=<IndexBackward0>)
tensor([0.1479, 0.6290], grad_fn=<SqueezeBackward3>)
tensor([0.1479, 0.6290], grad_fn=<SqueezeBackward3>)
tensor([0.1479, 0.6290], grad_fn=<AddBackward0>)
