In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class FactorizedLinear(nn.Module):
    """Linear layer whose weight is stored in factorized form.

    The effective weight is ``W = exp(S)[:, None] * V``: every output row of
    ``V`` is multiplied by its own strictly positive scale ``exp(S[i])``.
    Only ``S``, ``V`` and (optionally) ``bias`` are trainable parameters;
    ``W`` is a read-only property derived from them on access.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if ``False``, the layer has no additive bias (``self.bias`` is
            ``None``).
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # torch.empty only allocates storage; reset_parameters() below fills
        # in the actual values.
        self.S = nn.Parameter(torch.empty(out_features))
        self.V = nn.Parameter(torch.empty(out_features, in_features))

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    @property
    def W(self):
        """Effective weight of shape ``(out_features, in_features)``.

        Computed from ``S`` and ``V`` on every access, so it always reflects
        the current parameter values and stays differentiable w.r.t. both.
        """
        # Row-wise broadcast; gives the same matrix as diag(exp(S)) @ V
        # without materialising an (out_features x out_features) diagonal.
        return torch.exp(self.S).unsqueeze(1) * self.V

    def reset_parameters(self):
        """Initialise ``S``, ``V`` and ``bias`` in place.

        The effective weight is Glorot-uniform initialised, ``S`` is drawn
        from ``N(0.5, 0.01)`` and ``V`` is set to ``W / exp(S)`` so that the
        factorization reproduces the Glorot weight. The bias, if present, is
        zeroed. Parameters are filled in place (never re-created), so
        references held elsewhere, e.g. by an optimizer, remain valid.
        """
        # Glorot target for the effective weight. This is a temporary tensor,
        # not a parameter: only S and V are ever trained.
        target = torch.empty(
            self.out_features,
            self.in_features,
            dtype=self.V.dtype,
            device=self.V.device,
        )
        nn.init.xavier_uniform_(target)

        # Log-scales start tightly around 0.5, i.e. exp(S) is roughly 1.65.
        nn.init.normal_(self.S, mean=0.5, std=0.01)

        # Solve exp(S)[:, None] * V = target for V without recording autograd
        # history for the initialisation.
        with torch.no_grad():
            self.V.copy_(target / torch.exp(self.S).unsqueeze(1))

        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input):
        """Apply ``input @ W.T + bias`` using the factorized weight."""
        return F.linear(input, self.W, self.bias)

    def extra_repr(self):
        """Return the summary string shown inside the module's ``repr``."""
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

# Example usage:
# layer = FactorizedLinear(100, 50)
# input_tensor = torch.randn(64, 100) # Batch size 64, input features 100
# output_tensor = layer(input_tensor)
# print(output_tensor.shape)
# print(f"Number of parameters in full layer: {100 * 50}")

In [None]:
# Smoke test: run one random batch through the layer and inspect the result.
batch_size, in_features, out_features = 64, 50, 150

layer = FactorizedLinear(in_features=in_features, out_features=out_features)
# Random batch of shape (batch_size, in_features) = (64, 50).
input_tensor = torch.randn(batch_size, in_features)
output_tensor = layer(input_tensor)

print(output_tensor.shape)
print(layer)

torch.Size([64, 150])
FactorizedLinear(in_features=50, out_features=150, bias=True)


In [None]:
# Scratch check of the broadcasting used by the factorization: dividing each
# row of `a` by its entry in the column vector `b`, then multiplying back by
# diag(b), keeps the (2, 3) shape.
a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
b = torch.tensor([[2.0], [2.0]])
print(b.shape)

c = a / b  # (2, 3) / (2, 1): b is broadcast across the columns of a
d = torch.diag(b.reshape(-1)) @ c
print(c.shape)
print(d.shape)

torch.Size([2, 1])
torch.Size([2, 3])
torch.Size([2, 3])
