In [16]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F

In [3]:
# Low-rank factors: A maps 512 -> 2 features, B maps 2 -> 1024 features (no biases).
loraA = nn.Linear(in_features=512, out_features=2, bias=False)
loraB = nn.Linear(in_features=2, out_features=1024, bias=False)

print(*(factor.weight.shape for factor in (loraA, loraB)))

torch.Size([2, 512]) torch.Size([1024, 2])


In [5]:
# Router: one logit per low-rank component (2 here, the same rank as loraA/loraB).
moe_weight = nn.Linear(in_features=512, out_features=2, bias=False)

print(moe_weight.weight.size())

torch.Size([2, 512])


In [11]:
# Token-wise gated LoRA, dense version: every token gets its own update
# delta_W = B @ diag(g) @ A with g = softmax(router(x)).
seq_in = torch.randn(4, 16, 512)  # b n d

router_weight = moe_weight(seq_in)  # b n r
router_weight = torch.softmax(router_weight, dim=-1)  # b n r
router_weight = torch.diag_embed(router_weight)  # b n r r

# (k, r) @ (b, n, r, r) @ (r, d) broadcasts to (b, n, k, d): one (out x in) matrix per token.
# NOTE: this materialises b*n*k*d elements (~33.5M for these sizes) -- fine for a demo only.
lora_weight = loraB.weight @ router_weight @ loraA.weight  # b n k d
assert lora_weight.shape == (*seq_in.shape[:2], loraB.weight.shape[0], loraA.weight.shape[1])
print(lora_weight.shape)

torch.Size([4, 16, 1024, 512])


In [13]:
# Apply each token's dense (k x d) update matrix to that token's d-dim input vector.
seq_out = torch.einsum('...kd,...d->...k', lora_weight, seq_in)
print(seq_out.size())

torch.Size([4, 16, 1024])


In [21]:
class loraLinear(nn.Linear):
    """Frozen linear layer plus a token-wise gated low-rank (LoRA) update.

    For an input ``x`` of shape ``(..., in_features)`` the output is::

        W x + b + B (g * (A x)),    g = softmax(R x)

    ``W``/``b`` are the frozen base parameters, ``A``/``B`` are the trainable
    low-rank factors and ``R`` is a trainable router producing one gate per
    rank component. This is exactly the per-token dense update
    ``B @ diag(g) @ A`` applied to ``x``, computed without materialising a
    ``(..., out_features, in_features)`` weight tensor.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        rank: rank of the low-rank update, which is also the number of router gates.
    """

    def __init__(self, in_features, out_features, rank):
        super().__init__(in_features, out_features)

        # The base projection is treated as pretrained: freeze weight AND bias.
        self.weight.requires_grad = False
        self.bias.requires_grad = False

        self.loraA = nn.Linear(in_features, rank, bias=False)
        # NOTE(review): loraB keeps nn.Linear's default (non-zero) init, so the adapter
        # already changes the output before training; classic LoRA zero-initialises B.
        # Confirm this is intended.
        self.loraB = nn.Linear(rank, out_features, bias=False)
        # One gate per rank component. Must equal `rank`, otherwise gating cannot broadcast.
        self.router = nn.Linear(in_features, rank, bias=False)

    def forward(self, x):
        """Return the frozen projection of ``x`` plus its gated low-rank update.

        Works for any input whose last dimension is ``in_features`` (1-D, 2-D, 3-D, ...).
        The leading dimensions are preserved and ``out_features`` is last.
        """
        org_result = F.linear(x, self.weight, self.bias)

        # Per-token gates, normalised across the rank components.
        gate = torch.softmax(self.router(x), dim=-1)  # (..., rank)

        # B @ diag(g) @ A @ x == B @ (g * (A @ x)): scale the rank-space activations
        # instead of building a dense per-token (out x in) matrix.
        lora_result = self.loraB(self.loraA(x) * gate)  # (..., out_features)

        return org_result + lora_result


loralinear_layer = loraLinear(512, 1024, 2)

In [24]:
# Smoke check: a (4, 16, 512) batch should come out as (4, 16, 1024).
input_tensor = torch.randn((4, 16, 512))
output_tensor = loralinear_layer(input_tensor)

print(output_tensor.size())

torch.Size([4, 16, 1024])
