In [1]:
import sys
import os

main_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(main_dir)

import model_classes
from model_classes import *
from MH_Lori_poprawianie import *
from dataloader import *
import dataloader
from helper_functions import *
import torch
from transformers import PretrainedConfig
import torch.nn as nn
import math
import copy
import lightning.pytorch as pl
# from pytorch_lightning.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import ModelCheckpoint

device = 'cuda' if torch.cuda.is_available() else 'cpu'

  from .autonotebook import tqdm as notebook_tqdm


Kod miriam

In [2]:
# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, num_experts] - expert routing weights
class Router_miriam(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_token = config.num_experts_per_token
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts

        self.expert_embeddings = nn.Parameter(torch.randn(self.num_experts, self.hidden_size)).to(config.device)
        torch.nn.init.kaiming_uniform_(self.expert_embeddings, nonlinearity='linear')

    def forward(self, x):
        dot = torch.einsum("bsh,eh->bse", x, self.expert_embeddings)
        # top_k_out = torch.topk(dot, k=self.num_experts_per_token)
        # top_k = (float("-inf") * torch.ones_like(dot)).scatter_(dim=-1, index=top_k_out.indices, src=top_k_out.values)
        # res = torch.nn.functional.softmax(top_k, dim=-1)
        res = torch.nn.functional.softmax(dot, dim=-1)
        return res
    
import math

# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, hidden_size] - output embeddings
class MoE_Lory(nn.Module):
    """version which takes first not random tokens up to expert_capacity"""
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.hidden_size = config.hidden_size
        self.num_experts_per_token = config.num_experts_per_token
        self.capacity_factor = config.capacity_factor
        self.T_segments=config.T_segments
        # You can change experts representation if you want
        # self.experts = nn.ModuleList([MLP(config) for _ in range(self.num_experts)])
        #not as above but as below instead so as to compare more easily with the vectorized version
        self.intermediate_size = config.intermediate_size
        self.first_linear = nn.Parameter(torch.randn(self.num_experts, self.intermediate_size, self.hidden_size))
        torch.nn.init.kaiming_uniform_(self.first_linear, nonlinearity='linear')
        self.second_linear = nn.Parameter(torch.randn(self.num_experts, self.hidden_size, self.intermediate_size))
        torch.nn.init.kaiming_uniform_(self.second_linear, nonlinearity='linear')

        self.router = Router_miriam(config)

    def compute_out(self, data,linear1,linear2):
        return linear2 @ torch.nn.functional.relu(linear1 @ data)

    def merge_expert(self, weights):
        num_exp,_,_=weights.shape
        # expanded_weights1 = torch.ones((num_exp, self.intermediate_size, self.hidden_size)) * weights
        # expanded_weights2 = torch.ones((num_exp, self.hidden_size, self.intermediate_size)) * weights
        # linear1 = expanded_weights1 @ self.first_linear
        # linear2 = expanded_weights1 @ self.second_linear
        weighted_first_linear = torch.sum(weights * self.first_linear, dim=0)
        weighted_second_linear = torch.sum(weights * self.second_linear, dim=0)
        return weighted_first_linear,weighted_second_linear

    def forward(self, x):
        batch_size, seq_len, hidden_size = x.shape
        #assert hidden_size == self.hidden_size
        # expert_capacity = math.ceil(batch_size * seq_len / self.num_experts * self.capacity_factor)
        result = torch.zeros_like(x)
        segment_size=seq_len//self.T_segments
        for i in range(batch_size):
            for t in range(self.T_segments):
                segment=x[i, t*segment_size:(t+1)*segment_size]
                #print("segment shape check:",segment.shape,"seq_len/T, hidden_dim")

                if t==0:
                  with torch.no_grad():
                    h_x=segment.sum(axis=0)/segment_size
                    h_start = h_x.unsqueeze(dim  = 0)
                    h_x=h_x.unsqueeze(0)
                    h_x=h_x.unsqueeze(0)
                    #print("h_x shapecheck:",h_x.shape)
                    print(f'h_x shape: {h_x.shape}, input routera')
                    old_weights = self.router(h_x)
                    print(f'output routera: {old_weights.shape}')
                    log_w = old_weights
                    # print(old_weights.shape)
                    old_weights=old_weights.permute(2,0,1)
                    #print("old_weights shapecheck:",old_weights.shape)
                    merged_linear1,merged_linear2=self.merge_expert(old_weights)
                  for j in range(segment_size):
                    result[i, t*segment_size+j] = self.compute_out(x[i, t*segment_size+j],merged_linear1,merged_linear2)
                else:
                  h_x=segment.sum(axis=0)/segment_size
                  
                  h_x=h_x.unsqueeze(0)
                  h_start = torch.cat((h_start, h_x), dim = 0)

                  h_x=h_x.unsqueeze(0)
                  weights=self.router(h_x)
                  log_w = torch.cat((log_w, weights), dim = 0)
                  weights=weights.permute(2,0,1)
                  #print("weights shapecheck:",weights.shape)
                  merged_linear1,merged_linear2=self.merge_expert(old_weights)
                  old_weights=weights
                  for j in range(segment_size):
                    result[i, t*segment_size+j] = self.compute_out(x[i, t*segment_size+j],merged_linear1,merged_linear2)

        return result, h_start, log_w ################################### TEN KOD JEST ZLY!

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_config = dict(
    vocab_size=5000,
    max_position_embeddings=None,#256
    num_attention_heads=8,
    num_hidden_layers=None,
    hidden_dropout_prob=0.1,
    hidden_size=128,
    intermediate_size=512,
    num_labels=2,#2 co to robi?
    device = DEVICE #I added this one
)
moe_config = PretrainedConfig(
    **base_config,
    T_segments=5,
    num_experts=2,
    capacity_factor=None, #2.0
    num_experts_per_token=None,#1
    ff_cls=MoE_Lory
)

config_small= PretrainedConfig(
    num_experts_per_token=2,
    hidden_size=128,
    num_attention_heads = 8,
    num_MH_MOE_heads = 1,
    num_experts=2,
    batch_size = 1,
    seq_len = 20,
    capacity_factor = 3,
    device = device,
    intermediate_size = 512,
    forward_layer_class = MH_Lori,
    vocab_size = 5000,
    n_layers = 8,
    no_lori_segments = 5,
    py_lightning_loging = False,
    loss_fn = torch.nn.CrossEntropyLoss(),
    lr = 0.0006, #SET TO 0.0002
    betas = (0.9, 0.95),
    treat_mh_lori_as_regular_lori = True,
)

In [4]:
lory = MoE_Lory(moe_config).to(DEVICE)
lory.eval()
batch_size, seq_len, hidden_size=1,20,128

input = torch.randn((batch_size, seq_len, hidden_size)).to(DEVICE) * 10
v, h_start, w_m = lory(input)
print(v.shape)

h_x shape: torch.Size([1, 1, 128]), input routera
output routera: torch.Size([1, 1, 2])
torch.Size([1, 20, 128])


In [5]:
mh_lori = MH_Lori(config_small).to(config_small.device)
mh_lori.router.expert_embeddings.data = torch.transpose(lory.router.expert_embeddings.data, 0, 1)
mh_lori.first_linear.data = lory.first_linear.data
mh_lori.second_linear.data = lory.second_linear.data

In [6]:
mh_lori.eval()
o, h_antoni, w_a = mh_lori(input)
print(o.shape)

avarage segment embeding shape = torch.Size([1, 5, 1, 128]) [batch size, no segments, num_heads, head_dim] (to jest input routera)
expert_weights shape = torch.Size([1, 5, 1, 2]) [bs, no seq, num heads, num experts]
torch.Size([1, 20, 128])


In [7]:
print(torch.equal(o, v), torch.max(abs(o - v)).detach().item())

False 0.14374220371246338


In [11]:
#Router test
for i in range(5):
    router_a = Router_mh_lori(config_small)
    router_m = Router_miriam(moe_config)
    router_a.expert_embeddings.data = torch.transpose(router_m.expert_embeddings.data, 0, 1)
    input = torch.randn((batch_size, seq_len, hidden_size)).to(DEVICE)
    o_m = router_m(input).squeeze()
    o_a = router_a(input.unsqueeze(dim = 2)).squeeze()
    print(o_m.shape, o_a.shape)
    print(torch.equal(o_m, o_a))

#ROUTERY ROBIĄ TO SAMO, również na batch size 1 po malym mieszaniu wymiarow

torch.Size([20, 2]) torch.Size([20, 2])
True
torch.Size([20, 2]) torch.Size([20, 2])
True
torch.Size([20, 2]) torch.Size([20, 2])
True
torch.Size([20, 2]) torch.Size([20, 2])
True
torch.Size([20, 2]) torch.Size([20, 2])
True


In [9]:
h_start.shape

torch.Size([5, 128])

In [12]:
h_antoni[0].squeeze().shape

torch.Size([5, 128])

In [14]:
torch.equal(h_start, h_antoni[0].squeeze())

True

In [16]:
print(torch.transpose(w_m.squeeze(), 1, 0).shape, w_a.squeeze()[:, :].shape)

torch.Size([2, 5]) torch.Size([2, 5])


In [17]:
torch.equal(torch.transpose(w_m.squeeze(), 1, 0), w_a.squeeze()[:, :])

False