In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Low-rank pair: A projects 512 -> 2, B projects 2 -> 1024 (no biases).
loraA = nn.Linear(in_features=512, out_features=2, bias=False)
loraB = nn.Linear(in_features=2, out_features=1024, bias=False)

# nn.Linear stores its weight as (out_features, in_features).
print(*(layer.weight.size() for layer in (loraA, loraB)))

torch.Size([2, 512]) torch.Size([1024, 2])


In [5]:
# Router projection: one logit per LoRA rank component (512 -> 2, no bias).
moe_weight = nn.Linear(in_features=512, out_features=2, bias=False)

print(moe_weight.weight.size())

torch.Size([2, 512])


In [11]:
seq_in = torch.randn(4, 16, 512)  # (b, n, d)

# Per-token gate over the rank components, placed on the diagonal of an
# r x r matrix so it can sit between B and A in a matrix product.
router_weight = torch.diag_embed(F.softmax(moe_weight(seq_in), dim=-1))  # (b, n, r, r)

# Per-token effective weight B @ diag(gate) @ A, broadcast over (b, n).
lora_weight = torch.matmul(torch.matmul(loraB.weight, router_weight), loraA.weight)  # (b, n, k, d)
print(lora_weight.size())

torch.Size([4, 16, 1024, 512])


In [13]:
# Apply each token's own (k, d) weight to that token's d-vector: (b, n, k).
seq_out = torch.einsum(
    'bnkd,bnd->bnk',
    lora_weight,
    seq_in,
)
print(seq_out.size())

torch.Size([4, 16, 1024])


In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
class loraLinear(nn.Linear):
    """Linear layer with a frozen base projection plus a trainable low-rank update.

    The forward pass computes ``x @ W.T + b + (x @ A.T) @ B.T`` where ``W`` and
    ``b`` are the inherited ``nn.Linear`` parameters (frozen here) and ``A`` /
    ``B`` are the rank-``k`` adapter matrices held by ``loraA`` / ``loraB``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        k: Rank of the low-rank update.
    """

    def __init__(self, in_features, out_features, k):
        super().__init__(in_features, out_features)

        # Freeze the whole base layer (weight and bias); only the adapter
        # matrices below receive gradients.
        self.weight.requires_grad = False
        self.bias.requires_grad = False

        self.loraA = nn.Linear(in_features, k, bias=False)
        self.loraB = nn.Linear(k, out_features, bias=False)
        # Zero-initialise B so the low-rank update is exactly zero at
        # construction: the layer starts out identical to the base projection.
        nn.init.zeros_(self.loraB.weight)

    def forward(self, x):
        """Return the base projection of ``x`` plus the low-rank correction.

        Args:
            x: Tensor whose last dimension is ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``out_features``.
        """
        # Pass the bias explicitly: F.linear(x, weight) alone would drop it.
        org_result = F.linear(x, self.weight, self.bias)
        lora_result = self.loraB(self.loraA(x))
        return org_result + lora_result
    
# 4096 -> 4096 projection with a rank-16 adapter, moved to the selected device.
loralinear_layer = loraLinear(in_features=4096, out_features=4096, k=16)
loralinear_layer = loralinear_layer.to(device)

In [3]:
class loraLinear(nn.Linear):
    """Linear layer with a frozen base projection and a router-gated low-rank update.

    For every input vector ``x`` the layer computes

        ``x @ W.T + b + ((x @ A.T) * softmax(x @ R.T)) @ B.T``

    where ``W`` / ``b`` are the inherited ``nn.Linear`` parameters (frozen
    here), ``A`` / ``B`` are the rank-``rank`` adapter matrices and ``R`` is
    the router that produces one gate per rank component. This equals applying
    the per-input weight ``B @ diag(gate) @ A`` without materialising it.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        rank: Rank of the low-rank update (and number of router outputs).
    """

    def __init__(self, in_features, out_features, rank):
        super().__init__(in_features, out_features)

        # Freeze the whole base layer (weight and bias); only the adapter and
        # router matrices below receive gradients.
        self.weight.requires_grad = False
        self.bias.requires_grad = False

        self.loraA = nn.Linear(in_features, rank, bias=False)
        self.loraB = nn.Linear(rank, out_features, bias=False)
        self.router = nn.Linear(in_features, rank, bias=False)
        # Zero-initialise B so the gated update is exactly zero at
        # construction: the layer starts out identical to the base projection.
        nn.init.zeros_(self.loraB.weight)

    def forward(self, x):
        """Return the base projection of ``x`` plus the gated low-rank correction.

        Args:
            x: Tensor whose last dimension is ``in_features``; any number of
                leading dimensions is accepted.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``out_features``.
        """
        # Pass the bias explicitly: F.linear(x, weight) alone would drop it.
        org_result = F.linear(x, self.weight, self.bias)  # (..., k)

        # One softmax-normalised gate per rank component.
        gate = torch.softmax(self.router(x), dim=-1)  # (..., r)

        # Multiplying by diag(gate) is an elementwise scale of the rank
        # activations, so no r x r matrix or rank-specific einsum is needed.
        gated = self.loraA(x) * gate  # (..., r)

        return org_result + self.loraB(gated)
    

# 4096 -> 4096 projection with a rank-16 gated adapter, moved to the selected device.
loralinear_layer = loraLinear(in_features=4096, out_features=4096, rank=16)
loralinear_layer = loralinear_layer.to(device)

In [4]:
# Smoke test: push one random (batch=1, seq=512, dim=4096) batch through the layer.
input_tensor = torch.randn(1, 512, 4096)
input_tensor = input_tensor.to(device)

output_tensor = loralinear_layer(input_tensor)

print(output_tensor.size())

torch.Size([1, 512, 4096])


In [5]:
!nvidia-smi

Sat Sep  7 15:29:10 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:65:00.0 Off |                  N/A |
|  0%   29C    P2            105W /  370W |     472MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                