In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init
from datasets import load_from_disk

# hyperparameters
batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 1024 # what is the maximum context length for predictions
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# eval_iters = 400
n_embed = 256
n_head = 8
head_size = n_embed // n_head
n_layer = 4
dropout = 0.2
num_experts = 8
top_k = 2
# ------------

torch.manual_seed(1337)
#D:\Downloads\DS_AI\VDT\MoE\moe\abtract_summarization
train_dataset = load_from_disk("abtract_summarization/train")
eval_dataset = load_from_disk("abtract_summarization/eval")
test_dataset = load_from_disk("abtract_summarization/test")


  from pandas.core import (


In [3]:
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

## Model

In [4]:
tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
vocab_size = tokenizer.vocab_size 
vocab_size

36096

In [None]:
# class Head(nn.Module):
#     """ one head of self-attention """

#     def __init__(self, head_size):
#         super().__init__()
#         self.head_size = head_size
#         self.key = nn.Linear(n_embed, head_size, bias=False)
#         self.query = nn.Linear(n_embed, head_size, bias=False)
#         self.value = nn.Linear(n_embed, head_size, bias=False)
#         #self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

#         self.dropout = nn.Dropout(dropout)
#         self.k_cache =None 
#         self.v_cache = None 
#         self.cache_index = 0

#     def forward(self, x):
#         B,T,C = x.shape
#         k = self.key(x)   # (B,T,C)
#         q = self.query(x) # (B,T,C)
#         v = self.value(x) # (B,T,C)
#         if self.training:
#             if self.k_cache is None or self.v_cache is None: 
#                 self.k_cache = torch.zeros(B, block_size, self.head_size, device = x.device)
#                 self.v_cache = torch.zeros(B, block_size, self.head_size, device = x.device)
#                 self.cache_index = 0
            
#             if self.cache_index + T <= block_size: 
#                 self.k_cache[:, self.cache_index: self.cache_index + T, :] = k 
#                 self.v_cache[:, self.cache_index: self.cache_index + T, :] = v
#             else: 
#                 shift = self.cache_index + T - block_size 
#                 self.k_cache[:, :-shift, :] = self.k_cache[:, shift:, :].clone()
#                 self.v_cache[:, :-shift, :] = self.v_cache[:, shift:, :].clone()
#                 self.k_cache[:, -T:, :] = k 
#                 self.v_cache[:, -T:, :] = v
            
#             self.cache_index = min(self.cache_index + T, block_size)
#             k_used = self.k_cache
#             v_used = self.v_cache
#         else:
#             k_used = k 
#             v_used = v
#         # compute attention scores ("affinities")
#         wei = q @ k_used.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
#         tril = torch.tril(torch.ones(T, k_used.shape[1], device=x.device))
#         wei = wei.masked_fill(tril == 0, float('-inf')) # (B, T, T)
#         wei = F.softmax(wei, dim=-1) # (B, T, T)
#         wei = self.dropout(wei)
#         # perform the weighted aggregation of the values
#         out = wei @ v_used # (B, T, T) @ (B, T, C) -> (B, T, C)
#         del q 
#         del k 
#         del v
#         del wei
#         return out

# #Multi-Headed Self Attention
# class MultiHeadAttention(nn.Module):
#     """ multiple heads of self-attention in parallel """

#     def __init__(self, num_heads, head_size):
#         super().__init__()
#         self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
#         self.proj = nn.Linear(n_embed, n_embed)
#         self.dropout = nn.Dropout(dropout)

#     def forward(self, x):
#         out = torch.cat([h(x) for h in self.heads], dim=-1)
#         out = self.dropout(self.proj(out))
#         return out

In [5]:
class Head(nn.Module):
    """One head of self-attention with optional KV caching."""
    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

        self.k_cache = None
        self.v_cache = None
        self.cache_index = 0
        
    def clear_cache(self):
        self.k_cache = None
        self.v_cache = None
        self.cache_index = 0

    def forward(self, x, kv_cache=False):
        q = self.query(x)  # (B, T, C)

        if not kv_cache:
            k = self.key(x)
            v = self.value(x)
            k_used = k
            v_used = v
        else:
            print("use kv_cache", end = " ")
            x_last = x[:, -1:, :]  # (B, 1, C)
            B, T, C = x_last.shape
            k_last = self.key(x_last)  # (B, 1, head_size)
            v_last = self.value(x_last)

            if self.k_cache is None:
                self.k_cache = torch.zeros(B, block_size, self.head_size, device=x.device)
                self.v_cache = torch.zeros(B, block_size, self.head_size, device=x.device)
                self.cache_index = 0

            if self.cache_index + T <= block_size:
                self.k_cache[:, self.cache_index:self.cache_index+T, :] = k_last
                self.v_cache[:, self.cache_index:self.cache_index+T, :] = v_last
            else:
                shift = self.cache_index + T - block_size
                self.k_cache[:, :-shift, :] = self.k_cache[:, shift:, :].clone()
                self.v_cache[:, :-shift, :] = self.v_cache[:, shift:, :].clone()
                self.k_cache[:, -T:, :] = k_last
                self.v_cache[:, -T:, :] = v_last

            self.cache_index = min(self.cache_index + T, block_size)
            k_used = self.k_cache
            v_used = self.v_cache

        attn_scores = q @ k_used.transpose(-2, -1) * self.head_size ** -0.5
        tril = torch.tril(torch.ones(q.shape[1], k_used.shape[1], device=x.device))
        attn_scores = attn_scores.masked_fill(tril == 0, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        out = attn_weights @ v_used
        del q
        del k_used
        del v_used
        return out

#Multi-Headed Self Attention
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, kv_cache):
        out = torch.cat([h(x, kv_cache) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

    def clear_cache(self):
        for head in self.heads:
            head.clear_cache()

In [6]:
#Expert module
class Expert(nn.Module):
    """ An MLP is a simple linear layer followed by a non-linearity i.e. each Expert """

    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

#noisy top-k gating
class NoisyTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.top_k = top_k
        #layer for router logits
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear =nn.Linear(n_embed, num_experts)


    def forward(self, mh_output):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.topkroute_linear(mh_output)

        #Noise logits
        noise_logits = self.noise_linear(mh_output)

        #Adding scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits)*F.softplus(noise_logits)
        noisy_logits = logits + noise

        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

#Now create the sparse mixture of experts module


class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k, capacity_factor=1.0):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.num_experts = num_experts
    
    def forward(self, x):
    # Assuming x has shape [batch_size, seq_len, n_embd]
        batch_size, seq_len, _ = x.shape
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)

        # Flatten the batch and sequence dimensions to treat each token independently
        flat_x = x.view(-1, x.size(-1))  
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        tokens_per_batch = batch_size * seq_len * self.top_k
        expert_capacity = int((tokens_per_batch / self.num_experts) * self.capacity_factor)

        updates = torch.zeros_like(flat_x)

        for i, expert in enumerate(self.experts):
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)
            selected_indices = torch.nonzero(flat_mask).squeeze(-1)

            limited_indices = selected_indices[:expert_capacity] if selected_indices.numel() > expert_capacity else selected_indices
            if limited_indices.numel() > 0:
                expert_input = flat_x[limited_indices]
                expert_output = expert(expert_input)

                gating_scores = flat_gating_output[limited_indices, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                updates.index_add_(0, limited_indices, weighted_output)

        # Reshape updates to match the original dimensions of x
        final_output += updates.view(batch_size, seq_len, -1)

        return final_output

In [7]:
#First create a self attention + mixture of experts block, that may be repeated several number of times
#Copy pasting key architecture variables for clarity

class Block(nn.Module):
    """ Mixture of Experts Transformer block: communication followed by computation (multi-head self attention + SparseMoE) """

    def __init__(self, n_embed, n_head, num_experts, top_k):
        # n_embed: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.smoe = SparseMoE(n_embed, num_experts, top_k)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x, kv_cache = False):
        x = x + self.sa(self.ln1(x), kv_cache)
        x = x + self.smoe(self.ln2(x))
        return x
    def clear_cache(self):
        self.sa.clear_cache()

In [8]:
from transformers.modeling_outputs import Seq2SeqLMOutput

In [9]:
#Finally putting it all together to crease a sparse mixture of experts language model
class SparseMoELanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.ModuleList([Block(n_embed, n_head=n_head, num_experts=num_experts,top_k=top_k) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed) # final layer norm
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, input_ids, labels=None, attention_mask = None, kv_cache = False):
        idx = input_ids
        targets = labels
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        for block in self.blocks:
            x = block(x, kv_cache)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return Seq2SeqLMOutput(logits=logits, loss=loss)

    def generate(self, input_ids, context_length,  max_new_tokens = 256):
        self.eval()
        batch_size = input_ids.shape[0]
        generated = torch.zeros((batch_size, max_new_tokens), dtype=torch.long, device=input_ids.device)
        finish = torch.zeros((batch_size,), device = input_ids.device) 
        for i in range(max_new_tokens):
            idx_cond = input_ids[:, -context_length : ]
            input_ids = idx_cond
            with torch.no_grad():
                output = self(idx_cond, kv_cache = True)
                logits = output.logits
                logits = logits[:, -1, :]

            probs = F.softmax(logits, dim = -1)
            idx_next = torch.multinomial(probs, num_samples=1)
            # print(f"generate token {i}!")
            # print(tokenizer.decode[idx_next[0].item()], end = " ")
            generated[:, i] = idx_next.squeeze(-1)
            input_ids = torch.cat((input_ids, idx_next), dim = -1)
            check_end = (idx_next == tokenizer.eos_token_id).int()
            finish = torch.logical_or(finish, check_end)
            if finish.all():
                break
        self.clear_cache()
        return generated

    def clear_cache(self):
        for block in self.blocks:
            block.clear_cache()

In [10]:

def kaiming_init_weights(m):
    if isinstance (m, (nn.Linear)):
        init.kaiming_normal_(m.weight)

In [11]:
model = SparseMoELanguageModel()
model.apply(kaiming_init_weights)
model.to('cuda')
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model size: {total_params / 1e6} M")

Model size: 36.668224 M


In [12]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt", 
                                       padding="max_length", max_length = 1024, label_pad_token_id=-100)

## Train

In [13]:
training_args = Seq2SeqTrainingArguments("moe_scratchv4_test/",
                                        do_train=True,
                                        do_eval=True,
                                        num_train_epochs=5,
                                        learning_rate=1e-5,
                                        warmup_ratio=0.05,
                                        weight_decay=0.01,
                                        per_device_train_batch_size=8,
                                        per_device_eval_batch_size=8,
                                        logging_dir='./log',
                                        group_by_length=True,
                                        save_strategy="epoch",
                                        save_total_limit=3,
                                        eval_strategy="steps",
                                        logging_steps=100,
                                        fp16=True,
                                      )

In [14]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset= eval_dataset,
    data_collator=data_collator,
)
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpavt1024[0m ([33mhust_edu_vn[0m). Use [1m`wandb login --relogin`[0m to force relogin


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.10 GiB. GPU 0 has a total capacity of 4.00 GiB of which 0 bytes is free. Of the allocated memory 4.98 GiB is allocated by PyTorch, and 42.13 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## Test

In [28]:
from datasets import load_metric
import numpy as np 
metric = load_metric("rouge") 

predictions = []
references = []
context_length = 128
dataloader = torch.utils.data.DataLoader(test_dataset, collate_fn=data_collator, batch_size=16, drop_last=True)
texts = []
labels_list = []
for i, batch in enumerate(tqdm(dataloader)):
  labels_list.append(batch['labels'])


  0%|          | 0/244 [00:00<?, ?it/s]

In [None]:
with open("generated_texts.txt", "w", encoding="utf-8") as f:
    for outputs in texts:
        decoded = [tokenizer.decode(out, clean_up_tokenization_spaces=False, skip_special_tokens=True) for out in outputs]
        for line in decoded:
            f.write(line.strip() + "\n")
print("Saved to generated_texts.txt")

Saved to generated_texts.txt


In [36]:
with open("generated_texts.txt", "r", encoding="utf-8") as f: 
    text_file = f.read()

texts = text_file.split('\n')

In [29]:
labels_list = []
for i, batch in enumerate(tqdm(dataloader)):
    labels_list.append(batch['labels'])

  0%|          | 0/244 [00:00<?, ?it/s]

In [30]:
predictions = []
references = []
for labels in labels_list:
  with tokenizer.as_target_tokenizer():
    labels = np.where(labels != -100,  labels, tokenizer.pad_token_id)
    actuals = [tokenizer.decode(out, clean_up_tokenization_spaces=False, skip_special_tokens=True) for out in labels]
  references.extend(actuals)



In [None]:
predictions = []
references = []
for outputs, labels in zip(texts, labels_list[:-1]):
  with tokenizer.as_target_tokenizer():
    outputs = [tokenizer.decode(out, clean_up_tokenization_spaces=False, skip_special_tokens=True) for out in outputs]

    labels = np.where(labels != -100,  labels, tokenizer.pad_token_id)
    actuals = [tokenizer.decode(out, clean_up_tokenization_spaces=False, skip_special_tokens=True) for out in labels]
  predictions.extend(outputs)
  references.extend(actuals)
  metric.add_batch(predictions=outputs, references=actuals)



In [41]:
import json
results = dict(metric.compute(predictions=texts[:-1], references=references).items() )
with open("rouge_results_v1.json", "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=4)
print("Saved to rouge_results.json")

Saved to rouge_results.json


In [36]:
import json
results = dict(metric.compute(predictions=predictions, references=references).items() )
with open("rouge_results.json", "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=4)
print("Saved to rouge_results.json")

Saved to rouge_results.json


In [42]:
results

{'rouge1': AggregateScore(low=Score(precision=0.15747547930683872, recall=0.5889597353936619, fmeasure=0.2372820014992402), mid=Score(precision=0.15934795706488486, recall=0.5933378046447517, fmeasure=0.2393352369141203), high=Score(precision=0.16110722956429294, recall=0.5976237666269457, fmeasure=0.24138524710664344)),
 'rouge2': AggregateScore(low=Score(precision=0.014301891835570115, recall=0.05236371094124646, fmeasure=0.021389394906551177), mid=Score(precision=0.014654653715164494, recall=0.05351161962600661, fmeasure=0.021860062992080254), high=Score(precision=0.014955272268918891, recall=0.05477300263038851, fmeasure=0.02227028522990267)),
 'rougeL': AggregateScore(low=Score(precision=0.08264730981685033, recall=0.3201909304895171, fmeasure=0.12499304949356468), mid=Score(precision=0.08353110177968809, recall=0.3232113365969094, fmeasure=0.12593306948524252), high=Score(precision=0.08438061871067698, recall=0.3262146010056825, fmeasure=0.1267747564229634)),
 'rougeLsum': Aggreg

## Load model trained

In [12]:
from safetensors.torch import load_file

In [13]:
model = SparseMoELanguageModel()
model.load_state_dict(load_file(r"D:\Downloads\DS_AI\VDT\MoE\moe_v3_saved\model.safetensors"))
model.to('cuda')

SparseMoELanguageModel(
  (token_embedding_table): Embedding(36096, 256)
  (position_embedding_table): Embedding(1024, 256)
  (blocks): ModuleList(
    (0-3): 4 x Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (key): Linear(in_features=256, out_features=32, bias=False)
            (query): Linear(in_features=256, out_features=32, bias=False)
            (value): Linear(in_features=256, out_features=32, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
        )
        (proj): Linear(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (smoe): SparseMoE(
        (router): NoisyTopkRouter(
          (topkroute_linear): Linear(in_features=256, out_features=8, bias=True)
          (noise_linear): Linear(in_features=256, out_features=8, bias=True)
        )
        (experts): ModuleList(
          (0-7): 8 x Expert(
            (net): Sequential(
  

In [14]:
predictions = []
references = []
context_length = 1024
test_dataloader = torch.utils.data.DataLoader(test_dataset, collate_fn=data_collator, batch_size=16, drop_last = True)
texts = []
labels_list = []
for i, batch in enumerate(tqdm(test_dataloader)):
  outputs = model.generate(
      input_ids=batch['input_ids'].to('cuda'),
      context_length=context_length,
  )
  texts.append(outputs)
  labels_list.append(batch['labels'])

with open("generated_texts.txt", "w", encoding="utf-8") as f:
    for outputs in texts:
        decoded = [tokenizer.decode(out, clean_up_tokenization_spaces=False, skip_special_tokens=True) for out in outputs]
        for line in decoded:
            f.write(line.strip() + "\n")
print("Saved to generated_texts.txt")
for outputs, labels in zip(texts, labels_list):
  with tokenizer.as_target_tokenizer():
    outputs = [tokenizer.decode(out, clean_up_tokenization_spaces=False, skip_special_tokens=True) for out in outputs]

    labels = np.where(labels != -100,  labels, tokenizer.pad_token_id)
    actuals = [tokenizer.decode(out, clean_up_tokenization_spaces=False, skip_special_tokens=True) for out in labels]
  predictions.extend(outputs)
  references.extend(actuals)
  metric.add_batch(predictions=outputs, references=actuals)
import json
results = dict(metric.compute(predictions=predictions, references=references).items() )
with open("rouge_v3_results.json", "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=4)
print("Saved to rouge_v3_results.json")

print(results)

  0%|          | 0/244 [00:00<?, ?it/s]

use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache use kv_cache

KeyboardInterrupt: 