In [7]:
import sys
import os

main_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(main_dir)

import model_classes
from model_classes import *
import torch
from transformers import PretrainedConfig
import torch.nn as nn
import math
import copy

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [67]:
config = PretrainedConfig(
    num_experts_per_token=2,
    hidden_size=128,
    num_attention_heads = 16,
    num_MH_MOE_heads = 4,
    num_experts=32,
    batch_size = 20,
    seq_len = 32,
    capacity_factor = 8,
    device = device,
    intermediate_size = 64,
    forward_layer_class = VectorizedMoE,
    vocab_size = 1000,
    n_layers = 8,
    no_lori_segments = 16,
)

In [42]:
# input shape: [batch size, no segments, num_heads, head_dim]
# for every head, the router schould return weights for each expert, so:
# output shape: [bs, no seq, num heads, num experts]
class Router_mh_lori(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = int(config.hidden_size / config.num_MH_MOE_heads)
        self.num_experts = config.num_experts
        self.expert_embeddings = nn.Parameter(torch.randn(self.hidden_size, self.num_experts)).to(config.device)
        torch.nn.init.kaiming_uniform_(self.expert_embeddings, nonlinearity='linear')

    def forward(self, x):
        dot = torch.einsum("bshd,de->bshe", x, self.expert_embeddings)
        res = torch.nn.functional.softmax(dot, dim=-1)
        return res

In [76]:
class MH_Lori(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.batch_size = config.batch_size
        self.hidden_dim = config.hidden_size
        self.seq_len = config.seq_len
        self.num_heads = config.num_MH_MOE_heads
        self.head_dim = int(config.hidden_size / config.num_MH_MOE_heads)
        self.no_segments = config.no_lori_segments
        self.segment_len = int(self.seq_len / self.no_segments)

        self.router = Router_mh_lori(config)

        self.multi_head_layer = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.merge_layer = nn.Linear(self.hidden_dim, self.hidden_dim)
        # Initialization
        nn.init.xavier_uniform_(self.multi_head_layer.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.merge_layer.weight)
        nn.init.constant_(self.merge_layer.bias, 0.0)

        self.num_experts = config.num_experts
        self.intermediate_size = config.intermediate_size

        self.first_linear = nn.Parameter(torch.randn(self.num_experts, self.intermediate_size, self.head_dim)).to(config.device)
        torch.nn.init.kaiming_uniform_(self.first_linear, nonlinearity='linear')
        self.second_linear = nn.Parameter(torch.randn(self.num_experts, self.head_dim, self.intermediate_size)).to(config.device)
        torch.nn.init.kaiming_uniform_(self.second_linear, nonlinearity='linear')

    def forward(self, x):
        #x.shape = [batch size, seq len, hidden dim]
        x = self.multi_head_layer(x) 
        #x.shape = [batch size, seq len, hidden dim]
        x = x.reshape(self.batch_size, self.seq_len, self.num_heads, self.head_dim).contiguous()
        #Dividing into lori segments
        x = x.reshape(self.batch_size, self.no_segments, self.segment_len, self.num_heads, self.head_dim).contiguous()
        #calculating routing weights
        average_segment_embedding = torch.mean(x, dim = 2)
        # average_segment_embedding.size = [batch size, no segments, num_heads, head_dim]
        expert_weights = self.router(average_segment_embedding)
        # expert_weights shape = [bs, no seq, num heads, num experts]
        # calculating merged experts
        expert_weights = expert_weights.reshape(self.num_experts, 1, 1, self.no_segments, self.num_heads, self.batch_size) 
        merged_experts_1 = self.first_linear.reshape(self.num_experts, self.intermediate_size, self.head_dim, 1, 1, 1)
        merged_experts_1 = (merged_experts_1 * expert_weights).sum(dim = 0)
        merged_experts_1 = merged_experts_1.reshape(self.batch_size, self.no_segments, self.num_heads, self.intermediate_size, self.head_dim)
        merged_experts_1 = merged_experts_1[:, :-1, :, :, :] #we discard the last segment as expert created for it is never used

        merged_experts_2 = self.second_linear.reshape(self.num_experts, self.head_dim, self.intermediate_size, 1, 1, 1)
        merged_experts_2 = (merged_experts_2 * expert_weights).sum(dim = 0)
        merged_experts_2 = merged_experts_2.reshape(self.batch_size, self.no_segments, self.num_heads, self.head_dim, self.intermediate_size)
        merged_experts_2 = merged_experts_2[:, :-1, :, :, :]
        
        # process x by experts
        x = x.reshape(self.batch_size, self.no_segments, self.num_heads, self.segment_len, self.head_dim).contiguous()
        x_causal = x[:, 1:, :, :, :]

        result = torch.einsum("bnhld,bnhid->bnhli", x_causal, merged_experts_1)
        result = nn.functional.relu(result, inplace=False)
        result = torch.einsum("bnhli,bnhdi->bnhld", result, merged_experts_2)

        # reshape back into orginal shape
        result = result.reshape(self.batch_size, (self.no_segments - 1) * self.segment_len, self.num_heads, self.head_dim)
        result = result.reshape(self.batch_size, (self.no_segments - 1) * self.segment_len, self.hidden_dim)
        result = self.merge_layer(result)

        return result




        # x = x.reshape(config.batch_size, config.seq_len, config.num_MH_MOE_heads,
        #                self.head_dim).reshape(config.batch_size, config.seq_len, config.hidden_size).contiguous()
        # x = self.merge_layer(x)
        # return x
        

test_input = torch.rand((config.batch_size, config.seq_len, config.hidden_size))
mh_lori = MH_Lori(config)
output = mh_lori(test_input)
print(output.shape)

torch.Size([20, 30, 128])


In [12]:
test_input = torch.rand((config.batch_size, config.seq_len, config.hidden_size))
mh_lori = MH_Lori(config)
output = mh_lori(test_input)
print(output.shape)

RuntimeError: einsum(): subscript h has size 64 for operand 1 which does not broadcast with previously seen size 8