In [1]:
import sys
import os

main_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(main_dir)

import model_classes
from model_classes import *
import torch
from transformers import PretrainedConfig
import torch.nn as nn
import math
import copy

device = 'cuda' if torch.cuda.is_available() else 'cpu'


# device = 'cpu'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# input shape: [batch size, no segments, num_heads, head_dim]
# for every head, the router schould return weights for each expert, so:
# output shape: [bs, no seq, num heads, num experts]
class Router_mh_lori(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = int(config.hidden_size / config.num_MH_MOE_heads)
        self.num_experts = config.num_experts
        self.device = config.device
        self.expert_embeddings = nn.Parameter(torch.randn(self.hidden_size, self.num_experts))
        torch.nn.init.kaiming_uniform_(self.expert_embeddings, nonlinearity='linear')

    def forward(self, x):
        dot = torch.einsum("bshd,de->bshe", x, self.expert_embeddings)
        res = torch.nn.functional.softmax(dot, dim=-1)
        return res

In [3]:
class MH_Lori(nn.Module):
    def __init__(self, config):
        super(MH_Lori, self).__init__()
        self.config = config

        self.batch_size = config.batch_size
        self.hidden_dim = config.hidden_size
        self.seq_len = config.seq_len
        self.num_heads = config.num_MH_MOE_heads
        self.head_dim = int(config.hidden_size / config.num_MH_MOE_heads)
        self.no_segments = config.no_lori_segments
        self.segment_len = int(self.seq_len / self.no_segments)
        self.device = config.device
        

        self.router = Router_mh_lori(config).to(self.device)

        self.multi_head_layer = nn.Linear(self.hidden_dim, self.hidden_dim).to(self.device)
        self.merge_layer = nn.Linear(self.hidden_dim, self.hidden_dim).to(self.device)
        # Initialization
        nn.init.xavier_uniform_(self.multi_head_layer.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.merge_layer.weight)
        nn.init.constant_(self.merge_layer.bias, 0.0)

        self.num_experts = config.num_experts
        self.intermediate_size = config.intermediate_size

        self.first_linear = nn.Parameter(torch.randn((self.num_experts, self.intermediate_size, self.head_dim)))
        torch.nn.init.kaiming_uniform_(self.first_linear, nonlinearity='linear')
        self.second_linear = nn.Parameter(torch.randn((self.num_experts, self.head_dim, self.intermediate_size)))
        torch.nn.init.kaiming_uniform_(self.second_linear, nonlinearity='linear')

        self.to(self.device)


    def forward(self, x):
        #x.shape = [batch size, seq len, hidden dim]
        x = self.multi_head_layer(x) 
        #x.shape = [batch size, seq len, hidden dim]
        x = x.reshape(self.batch_size, self.seq_len, self.num_heads, self.head_dim).contiguous()
        #Dividing into lori segments
        x = x.reshape(self.batch_size, self.no_segments, self.segment_len, self.num_heads, self.head_dim).contiguous()
        average_segment_embedding = torch.mean(x, dim = 2).to(self.device)
        # average_segment_embedding.size = [batch size, no segments, num_heads, head_dim]
        expert_weights = self.router(average_segment_embedding).to(self.device)
        # expert_weights shape = [bs, no seq, num heads, num experts]
        # calculating merged experts
        expert_weights = expert_weights.reshape(self.num_experts, 1, 1, self.no_segments, self.num_heads, self.batch_size) 
        merged_experts_1 = self.first_linear.reshape(self.num_experts, self.intermediate_size, self.head_dim, 1, 1, 1)
        merged_experts_1 = (merged_experts_1 * expert_weights).sum(dim = 0)
        merged_experts_1 = merged_experts_1.reshape(self.batch_size, self.no_segments, self.num_heads, self.intermediate_size, self.head_dim)
        print(f'merged expert 1 parameter count: {torch.numel(merged_experts_1):,}')
        merged_experts_1 = merged_experts_1[:, :-1, :, :, :] #we discard the last segment as expert created for it is never used

        merged_experts_2 = self.second_linear.reshape(self.num_experts, self.head_dim, self.intermediate_size, 1, 1, 1)
        merged_experts_2 = (merged_experts_2 * expert_weights).sum(dim = 0)
        merged_experts_2 = merged_experts_2.reshape(self.batch_size, self.no_segments, self.num_heads, self.head_dim, self.intermediate_size)
        merged_experts_2 = merged_experts_2[:, :-1, :, :, :]
        
        # process x by experts
        x = x.reshape(self.batch_size, self.no_segments, self.num_heads, self.segment_len, self.head_dim).contiguous()
        x_causal = x[:, 1:, :, :, :]
        # process segments s>1 throuth which gradient flows
        result = torch.einsum("bnhld,bnhid->bnhli", x_causal, merged_experts_1)
        result = nn.functional.relu(result, inplace=False)
        result = torch.einsum("bnhli,bnhdi->bnhld", result, merged_experts_2)
        # process segment s=1 without gradient

        with torch.no_grad():
            segment_1 = x[:, 0, :, :, :]
            expert_segment_1 = merged_experts_1[:, 0, :, :, :]
            expert_segment_2 = merged_experts_2[:, 0, :, :, :]

            result_segment_1 = torch.einsum("bhld,bhid->bhli", segment_1, expert_segment_1)
            result_segment_1 = nn.functional.relu(result_segment_1, inplace=False)
            result_segment_1 = torch.einsum("bhli,bhdi->bhld", result_segment_1, expert_segment_2)

            result_segment_1 = result_segment_1.unsqueeze(1)

        # concatenate processed segments
        result = torch.cat((result_segment_1, result), dim = 1)

        # reshape back into orginal shape
        result = result.reshape(self.batch_size, self.no_segments * self.segment_len, self.num_heads, self.head_dim)
        result = result.reshape(self.batch_size, self.no_segments * self.segment_len, self.hidden_dim)
        result = self.merge_layer(result)

        return result
    
    def test_if_reshaping_works(self, x):
        input = x
        x = x.reshape(self.batch_size, self.seq_len, self.num_heads, self.head_dim).contiguous()
        x = x.reshape(self.batch_size, self.no_segments, self.segment_len, self.num_heads, self.head_dim).contiguous()
        x = x.reshape(self.batch_size, self.no_segments, self.num_heads, self.segment_len, self.head_dim).contiguous()
        result = x
        result = result.reshape(self.batch_size, self.no_segments * self.segment_len, self.num_heads, self.head_dim)
        result = result.reshape(self.batch_size, self.no_segments * self.segment_len, self.hidden_dim)
        print(torch.equal(result, input))


In [4]:
#test how it works when I conect few of these
class test_model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.loris = nn.ModuleList([MH_Lori(config) for _ in range(config.n_layers)])
    def forward(self, x):
        for i, l in enumerate(self.loris):
            x = l(x)
        return x
    
# mh_lori = test_model(config)

In [5]:
# with torch.no_grad():
#     test_input = torch.rand((config.batch_size, config.seq_len, config.hidden_size)).to(device)
#     print('input shape: ', test_input.shape)
#     output = mh_lori(test_input)
#     print(output.shape)

# test_input = torch.rand((config.batch_size, config.seq_len, config.hidden_size)).to(device)
# print('input shape: ', test_input.shape)
# output = mh_lori(test_input)
# print(output.shape)

In [6]:
import torch

def get_gpu_memory():
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(0).total_memory
        reserved_memory = torch.cuda.memory_reserved(0)
        allocated_memory = torch.cuda.memory_allocated(0)
        free_memory = reserved_memory - allocated_memory

        print(f"Total GPU memory: {total_memory / 1e9} GB")
        print(f"Reserved GPU memory: {reserved_memory / 1e9} GB")
        print(f"Allocated GPU memory: {allocated_memory / 1e9} GB")
        print(f"Free GPU memory: {free_memory / 1e9} GB")
    else:
        print("No GPU available.")

get_gpu_memory()
# output = output.detach()
# torch.cuda.empty_cache()


Total GPU memory: 12.8843776 GB
Reserved GPU memory: 0.0 GB
Allocated GPU memory: 0.0 GB
Free GPU memory: 0.0 GB


In [7]:
def estimate_model_size(model):
    total_params = sum(p.numel() for p in model.parameters())
    total_buffers = sum(b.numel() for b in model.buffers())
    
    # Assuming float32 (4 bytes) for parameters and buffers
    model_size = (total_params + total_buffers) * 4
    print(f"Estimated Model Size: {model_size / (1024 ** 2):.2f} MB, total number of parameters: {model_size:,}")
    return model_size


# model = mh_lori


# model_size = estimate_model_size(model)
# print(f"Estimated Model Size: {model_size / (1024 ** 2):.2f} MB")


In [8]:
config_big= PretrainedConfig(
    num_experts_per_token=2,
    hidden_size=512,
    num_attention_heads = 8,
    num_MH_MOE_heads = 4,
    num_experts=8,
    batch_size = 1,
    seq_len = 1024,
    capacity_factor = 8,
    device = device,
    intermediate_size = 1024,
    forward_layer_class = MH_Lori,
    vocab_size = 30000,
    n_layers = 8,
    no_lori_segments = 64,
)
config_small= PretrainedConfig(
    num_experts_per_token=2,
    hidden_size=1024,
    num_attention_heads = 8,
    num_MH_MOE_heads = 4,
    num_experts=8,
    batch_size = 1,
    seq_len = 1024,
    capacity_factor = 8,
    device = device,
    intermediate_size = 2048,
    forward_layer_class = MH_Lori,
    vocab_size = 30000,
    n_layers = 8,
    no_lori_segments = 64,
)

In [9]:
config = config_small
mh_lori = Transformer(config_small).to(config_small.device)

In [10]:
estimate_model_size(mh_lori)
estimate_model_size(mh_lori.layers)
estimate_model_size(mh_lori.layers[0].forward_layer)
# estimate_model_size(mh_lori_singular)

Estimated Model Size: 682.77 MB, total number of parameters: 715,936,960
Estimated Model Size: 448.28 MB, total number of parameters: 470,056,960
Estimated Model Size: 40.02 MB, total number of parameters: 41,959,424


41959424

In [12]:
test_input = torch.randint(0, config.vocab_size, (config.batch_size, config.seq_len)).to(config.device)

# with torch.no_grad():
    # output = mh_lori(test_input)

output = mh_lori(test_input)
# output = output.detach()
get_gpu_memory()

merged expert 1 parameter count: 134,217,728
merged expert 1 parameter count: 134,217,728
merged expert 1 parameter count: 134,217,728
merged expert 1 parameter count: 134,217,728
merged expert 1 parameter count: 134,217,728
merged expert 1 parameter count: 134,217,728
merged expert 1 parameter count: 134,217,728
merged expert 1 parameter count: 134,217,728
Total GPU memory: 12.8843776 GB
Reserved GPU memory: 15.384707072 GB
Allocated GPU memory: 10.456372736 GB
Free GPU memory: 4.928334336 GB


In [None]:
for m in mh_lori.parameters():
    print(type(m), m.shape, m.is_cuda)

<class 'torch.nn.parameter.Parameter'> torch.Size([100, 128]) True
<class 'torch.nn.parameter.Parameter'> torch.Size([128]) True
<class 'torch.nn.parameter.Parameter'> torch.Size([128]) True
<class 'torch.nn.parameter.Parameter'> torch.Size([128]) True
<class 'torch.nn.parameter.Parameter'> torch.Size([128]) True
<class 'torch.nn.parameter.Parameter'> torch.Size([8, 256, 32]) True
<class 'torch.nn.parameter.Parameter'> torch.Size([8, 32, 256]) True
<class 'torch.nn.parameter.Parameter'> torch.Size([32, 8]) True
<class 'torch.nn.parameter.Parameter'> torch.Size([128, 128]) True
<class 'torch.nn.parameter.Parameter'> torch.Size([128]) True
<class 'torch.nn.parameter.Parameter'> torch.Size([128, 128]) True
<class 'torch.nn.parameter.Parameter'> torch.Size([128]) True
<class 'torch.nn.parameter.Parameter'> torch.Size([128, 128]) True
<class 'torch.nn.parameter.Parameter'> torch.Size([128, 128]) True
<class 'torch.nn.parameter.Parameter'> torch.Size([128, 128]) True
<class 'torch.nn.paramet