MOE

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Linear(nn.Module):
    """A single expert: a thin wrapper around one ``nn.Linear`` projection."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Kept under the name ``fc`` so the parameters are exposed as
        # ``fc.weight`` and ``fc.bias``.
        self.fc = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x):
        """Apply the wrapped projection to ``x`` and return the result."""
        projected = self.fc(x)
        return projected
    
class MoELayer(nn.Module):
    """Dense (soft-routed) mixture of linear experts.

    Every expert processes every input. A linear gate followed by a softmax
    produces one weight per expert, and the layer returns the gate-weighted
    sum of the expert outputs.

    Args:
        num_experts: Number of expert projections.
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
    """

    def __init__(self, num_experts, in_features, out_features):
        super().__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList(
            [Linear(in_features, out_features) for _ in range(num_experts)]
        )
        self.gate = nn.Linear(in_features, num_experts)

    def forward(self, x):
        """Mix the expert outputs for ``x``.

        Args:
            x: Tensor of shape ``(..., in_features)``. Any number of leading
                dimensions is accepted, e.g. ``(batch, in_features)`` or
                ``(batch, seq, in_features)``.

        Returns:
            Tensor of shape ``(..., out_features)``.
        """
        # (..., num_experts): mixing weights that sum to 1 over the experts.
        gate_score = F.softmax(self.gate(x), dim=-1)
        # (..., num_experts, out_features): stack on the second-to-last axis so
        # the expert axis lines up with the gate axis whatever the input rank.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-2)
        # (..., 1, E) @ (..., E, out) -> (..., 1, out). matmul broadcasts over
        # the leading dimensions, unlike bmm which only accepts 3-D operands.
        output = torch.matmul(gate_score.unsqueeze(-2), expert_outputs).squeeze(-2)
        return output

# Example usage 
# num_experts = 4
# in_features = 5
# out_features = 3
# batch_size = 10
# model = MoELayer(num_experts, in_features, out_features)

# demo = torch.randn(batch_size, in_features)
# output = model(demo)
# print(output.shape)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class LoRALinear(nn.Module):
    """Linear layer with an optional low-rank (LoRA) weight update.

    The layer owns a regular ``nn.Linear``. When ``rank > 0`` it also owns two
    low-rank factors, ``lora_b`` (out_features x rank) and ``lora_a``
    (rank x in_features), and freezes the base weight. When additionally
    ``merge`` is true, the forward pass uses
    ``weight + (lora_b @ lora_a) * (lora_alpha / rank)`` as its weight.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        merge: If true (and ``rank > 0``), add the scaled low-rank update to
            the base weight in ``forward``. If false, the low-rank factors
            are still created but are not used in ``forward``.
        rank: Rank of the low-rank update. ``0`` disables LoRA entirely; no
            extra parameters are created and the base weight stays trainable.
        lora_alpha: Numerator of the scaling factor ``lora_alpha / rank``.
        dropout: Dropout probability applied to the layer *output*.
            ``0`` (or less) replaces dropout with an identity.
    """

    def __init__(self, in_features=32, out_features=64, merge=True, rank=16, lora_alpha=16, dropout=0.5):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.merge = merge
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.dropout_rate = dropout

        self.linear = nn.Linear(in_features, out_features)

        if rank > 0:
            self.lora_b = nn.Parameter(torch.zeros(out_features, rank))
            self.lora_a = nn.Parameter(torch.zeros(rank, in_features))
            self.scale = self.lora_alpha / self.rank
            # Freeze the base weight so only the low-rank factors learn.
            # Only the weight is frozen here; the bias stays trainable.
            self.linear.weight.requires_grad = False

        if self.dropout_rate > 0:
            self.dropout = nn.Dropout(self.dropout_rate)
        else:
            self.dropout = nn.Identity()

        self.initial_weights()

    def initial_weights(self):
        """Initialise the low-rank factors; does nothing when ``rank == 0``.

        ``lora_a`` gets a Kaiming-uniform init and ``lora_b`` is zeroed, so
        the product ``lora_b @ lora_a`` is zero right after construction.
        """
        if self.rank <= 0:
            # No LoRA parameters were created, so there is nothing to init.
            return
        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b)

    def forward(self, x):
        """Project ``x`` and apply dropout to the result.

        Args:
            x: Tensor whose last dimension is ``in_features``.

        Returns:
            Tensor whose last dimension is ``out_features``.
        """
        if self.rank > 0 and self.merge:
            # Low-rank update to the weight, scaled by lora_alpha / rank.
            delta_weight = (self.lora_b @ self.lora_a) * self.scale
            output = F.linear(x, self.linear.weight + delta_weight, self.linear.bias)
            return self.dropout(output)
        return self.dropout(self.linear(x))

class LoRAMoELayer(nn.Module):
    """Dense (soft-routed) mixture of ``LoRALinear`` experts.

    Every expert processes every input. A gate followed by a softmax produces
    one weight per expert, and the layer returns the gate-weighted sum of the
    expert outputs.

    Args:
        num_experts: Number of ``LoRALinear`` experts.
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        merge: Forwarded to every expert's ``merge`` argument.
        rank: LoRA rank forwarded to the experts and to the gate.
        lora_alpha: LoRA scaling numerator forwarded to the experts and gate.
        dropout: Dropout probability forwarded to the experts and to the gate.
    """

    def __init__(self, num_experts=4, in_features=16, out_features=32, merge=True, rank=16, lora_alpha=16, dropout=0.5):
        super().__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList(
            [
                LoRALinear(in_features, out_features, merge, rank, lora_alpha, dropout)
                for _ in range(num_experts)
            ]
        )
        # NOTE(review): the gate is always built with merge=False, so its
        # forward pass is dropout(linear(x)): the LoRA factors are unused and
        # dropout is applied to the gate logits. Confirm this is intended.
        self.gate = LoRALinear(in_features, num_experts, False, rank, lora_alpha, dropout)

    def forward(self, x):
        """Mix the expert outputs for ``x``.

        Args:
            x: Tensor of shape ``(..., in_features)``. Any number of leading
                dimensions is accepted, e.g. ``(batch, in_features)`` or
                ``(batch, seq, in_features)``.

        Returns:
            Tensor of shape ``(..., out_features)``.
        """
        # (..., num_experts): mixing weights that sum to 1 over the experts.
        gate_score = F.softmax(self.gate(x), dim=-1)
        # (..., num_experts, out_features): stack on the second-to-last axis so
        # the expert axis lines up with the gate axis whatever the input rank.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-2)
        # (..., 1, E) @ (..., E, out) -> (..., 1, out). matmul broadcasts over
        # the leading dimensions, unlike bmm which only accepts 3-D operands.
        output = torch.matmul(gate_score.unsqueeze(-2), expert_outputs).squeeze(-2)
        return output
    
# Example usage: build a LoRA-MoE layer and push one random batch through it.
num_experts = 4
in_features = 32
out_features = 32
rank = 16
batch_size = 10
# Pass `rank` explicitly so that editing the value above actually takes effect
# (it equals the constructor default, so the result is unchanged).
model = LoRAMoELayer(num_experts, in_features, out_features, merge=True, rank=rank)

demo = torch.randn(batch_size, in_features)
output = model(demo)
print(output.shape)  # torch.Size([batch_size, out_features])

# Build a second layer with the default hyper-parameters, move it to the GPU
# when one is available, and print its module tree.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
t = LoRAMoELayer().to(device)
print(t)


torch.Size([10, 32])
LoRAMoELayer(
  (experts): ModuleList(
    (0-3): 4 x LoRALinear(
      (linear): Linear(in_features=16, out_features=32, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
  )
  (gate): LoRALinear(
    (linear): Linear(in_features=16, out_features=4, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)
