In [1]:
import numpy as np
import tensorflow as tf
import torch
from scnn.bipolar_functions.operations import *
from scnn.convertors import F2S, S2F

In [2]:
# Length of every bit-stream used in this notebook (2**14 = 16384).
seq_len = 2 ** 14

# Converters from scnn.convertors: f2s is applied to float tensors before they
# are fed to a layer, s2f turns a layer's output back into a float tensor.
f2s = F2S(seq_len=seq_len)
s2f = S2F()

In [3]:
class MUXScaled_Linear():
    """Linear layer evaluated on encoded bit-streams via ``scaled_matmul``.

    Weights (and the optional bias) are encoded with ``F2S`` when loaded;
    ``forward`` hands the encoded inputs and weights to ``scaled_matmul``.
    The decoded result is a down-scaled product: for a bias-free layer the
    notebook recovers ``inputs @ weight`` by multiplying by ``in_features``.

    Attributes:
        in_features: number of input features.
        out_features: number of output features.
        seq_len: stream length used when encoding weights and the bias input.
        is_bias: whether a bias row is appended to the weight.
        down_scale: number of rows contracted in ``forward`` -- ``in_features``,
            plus one when a bias row is appended.  Presumably the factor by
            which ``scaled_matmul`` shrinks the result -- TODO confirm against
            its implementation.
        weight: encoded weight, with the encoded bias concatenated along dim 0
            when ``bias=True``; ``None`` until ``load_weight`` is called.
        bias: encoded bias, or ``None``.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 seq_len,
                 bias=False) -> None:
        self.in_features = in_features
        self.out_features = out_features
        self.seq_len = seq_len
        self.is_bias = bias
        # With a bias, forward() appends one constant row to the inputs, so the
        # contraction runs over in_features + 1 rows rather than in_features.
        self.down_scale = in_features + 1 if bias else in_features
        self.weight = None
        self.bias = None

    def load_weight(self, data):
        """Encode and store the layer parameters.

        Args:
            data: mapping with key ``'weight'`` (shape
                ``[in_features, out_features]``) and, when the layer was built
                with ``bias=True``, key ``'bias'`` (shape ``[out_features]``).

        Raises:
            AssertionError: if a parameter does not have the expected shape.
        """
        assert data['weight'].shape == torch.Size([self.in_features, self.out_features])

        encode = F2S(seq_len=self.seq_len)
        self.weight = encode(data['weight'])

        if self.is_bias:
            assert data['bias'].shape == torch.Size([self.out_features])
            self.bias = encode(data['bias'])
            # Give the bias a leading row axis so it can be stacked under the
            # weight along dim 0 (assumes F2S appends the stream axis last --
            # TODO confirm).
            if self.bias.dim() <= 2:
                self.bias = self.bias.unsqueeze(dim=0)
            self.weight = torch.cat([self.weight, self.bias], dim=0)

    def forward(self, inputs):
        """Apply the layer to encoded inputs.

        Args:
            inputs: encoded input tensor.  When a bias is used, a constant row
                is appended along dim -2, so dim -1 must have length
                ``seq_len``.

        Returns:
            Whatever ``scaled_matmul(inputs, weight)`` returns.

        Raises:
            AssertionError: if ``load_weight`` has not been called.
        """
        assert self.weight is not None
        if self.is_bias:
            # All-ones stream multiplying the bias row; created on the inputs'
            # device so the concatenation also works off-CPU.
            ones = torch.ones(1, self.seq_len, dtype=torch.bool, device=inputs.device)
            bias_input = ones.expand(*inputs.shape[:-2], 1, -1)
            inputs = torch.cat([inputs, bias_input], dim=-2)

        return scaled_matmul(inputs, self.weight)

In [4]:
# Bias-free 8 -> 2 layer operating on streams of length seq_len.
model = MUXScaled_Linear(in_features=8, out_features=2, seq_len=seq_len)

In [5]:
# Uniform samples in [-1, 1); dimensions are read from the model so they stay
# in sync with the layer built above (8 inputs, 2 outputs), batch of 3.
weight = 2 * torch.rand(model.in_features, model.out_features) - 1
inputs = 2 * torch.rand(3, model.in_features) - 1

model.load_weight({'weight': weight})

In [6]:
# Decode the layer output and undo the layer's down-scaling (8 for this
# bias-free model) so it can be compared with the float product below.
model.down_scale * s2f(model.forward(f2s(inputs)))

tensor([[-1.6377,  2.1631],
        [ 0.8271, -0.1143],
        [ 0.9004,  1.4941]])

In [7]:
# Float reference for the rescaled layer output above.
torch.matmul(inputs, weight)

tensor([[-1.7563,  2.1296],
        [ 0.7746, -0.1193],
        [ 0.9253,  1.6423]])

In [17]:
class MUXScaled_LinearAct():
    """Linear layer on encoded bit-streams followed by a ``stanh`` activation.

    ``forward`` computes ``scaled_matmul(inputs, weight)`` and passes the
    result through ``stanh`` with ``r = round(2 * scalar * down_scale)``.
    In the notebook, ``scalar=1`` with 8 bias-free inputs (``r = 16``) is
    compared against ``tanh(inputs @ weight)``, i.e. the layer is meant to
    approximate ``tanh(scalar * linear_output)``.

    Attributes:
        scalar: gain applied inside the activation.
        in_features: number of input features.
        out_features: number of output features.
        seq_len: stream length used when encoding weights and the bias input.
        is_bias: whether a bias row is appended to the weight.
        down_scale: number of rows contracted in ``forward`` -- ``in_features``,
            plus one when a bias row is appended.  Presumably the factor by
            which ``scaled_matmul`` shrinks the result -- TODO confirm against
            its implementation.
        weight: encoded weight, with the encoded bias concatenated along dim 0
            when ``bias=True``; ``None`` until ``load_weight`` is called.
        bias: encoded bias, or ``None``.
    """

    def __init__(self,
                 scalar,
                 in_features,
                 out_features,
                 seq_len,
                 bias=False) -> None:
        if scalar < 1 / in_features:
            raise NotImplementedError(
                "scalar smaller than 1/in_features is not supported"
            )

        self.scalar = scalar
        self.in_features = in_features
        self.out_features = out_features
        self.seq_len = seq_len
        self.is_bias = bias
        # With a bias, forward() appends one constant row to the inputs, so the
        # contraction runs over in_features + 1 rows rather than in_features.
        self.down_scale = in_features + 1 if bias else in_features
        self.weight = None
        self.bias = None

    def load_weight(self, data):
        """Encode and store the layer parameters.

        Prints the expected weight shape before validating it.

        Args:
            data: mapping with key ``'weight'`` (shape
                ``[in_features, out_features]``) and, when the layer was built
                with ``bias=True``, key ``'bias'`` (shape ``[out_features]``).

        Raises:
            AssertionError: if a parameter does not have the expected shape.
        """
        expected_shape = torch.Size([self.in_features, self.out_features])
        # Echo of the expected shape; the notebook's recorded output shows it.
        print(expected_shape)
        assert data['weight'].shape == expected_shape

        encode = F2S(seq_len=self.seq_len)
        self.weight = encode(data['weight'])

        if self.is_bias:
            assert data['bias'].shape == torch.Size([self.out_features])
            self.bias = encode(data['bias'])
            # Give the bias a leading row axis so it can be stacked under the
            # weight along dim 0 (assumes F2S appends the stream axis last --
            # TODO confirm).
            if self.bias.dim() <= 2:
                self.bias = self.bias.unsqueeze(dim=0)
            self.weight = torch.cat([self.weight, self.bias], dim=0)

    def forward(self, inputs):
        """Apply the linear map and the ``stanh`` activation to encoded inputs.

        Args:
            inputs: encoded input tensor.  When a bias is used, a constant row
                is appended along dim -2, so dim -1 must have length
                ``seq_len``.

        Returns:
            Whatever ``stanh`` returns for the scaled product.

        Raises:
            AssertionError: if ``load_weight`` has not been called.
        """
        assert self.weight is not None
        if self.is_bias:
            # All-ones stream multiplying the bias row; created on the inputs'
            # device so the concatenation also works off-CPU.
            ones = torch.ones(1, self.seq_len, dtype=torch.bool, device=inputs.device)
            bias_input = ones.expand(*inputs.shape[:-2], 1, -1)
            inputs = torch.cat([inputs, bias_input], dim=-2)

        out = scaled_matmul(inputs, self.weight)

        # r must compensate for the down-scaling of the product, which spans
        # in_features + 1 rows once the bias row is appended; using in_features
        # alone would shrink the activation gain whenever bias=True.
        return stanh(out, r=round(self.scalar * 2 * self.down_scale))

In [18]:
# Bias-free 8 -> 2 layer with activation gain 1.
model = MUXScaled_LinearAct(scalar=1, in_features=8, out_features=2, seq_len=seq_len)

In [19]:
# Fresh uniform samples in [-1, 1); dimensions are read from the model so they
# stay in sync with the layer built above (8 inputs, 2 outputs), batch of 3.
weight = 2 * torch.rand(model.in_features, model.out_features) - 1
inputs = 2 * torch.rand(3, model.in_features) - 1

model.load_weight({'weight': weight})

torch.Size([8, 2])


In [20]:
# Decoded output of the activated layer; no rescaling is applied here, the
# result is compared directly with the float tanh reference below.
s2f(
    model.forward(f2s(inputs))
)

tensor([[-0.7422,  0.8115],
        [-0.8237,  0.0460],
        [ 0.9768, -0.7980]])

In [21]:
# Float reference for the activated layer output above.  torch.tanh is used
# instead of np.tanh so the tensors never go through an implicit NumPy
# conversion (which fails for CUDA or grad-tracking tensors).
torch.tanh(inputs @ weight)

tensor([[-0.7024,  0.8127],
        [-0.7792,  0.1834],
        [ 0.9648, -0.8041]])