### Imports

In [24]:
!pip install datasets transformers torch  # install required packages
!pip install transformers faiss-cpu

import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
import math
from datasets import load_dataset
import torch.utils.data as d
from transformers import BertTokenizer, BertModel
from torchtext.data import get_tokenizer
from torch import nn as nn
from typing import List, Tuple, Any, Optional
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import normalize




### Parameters

In [25]:
# Number of articles per batch produced by the initial (unsorted) dataloader.
BATCH_SIZE=16
# Number of batches to draw; NUM_BATCHES * BATCH_SIZE articles are embedded in total.
NUM_BATCHES=10
# NOTE(review): presumably the Contriever hidden size (the runs below report 768);
# not referenced by the cells visible in this notebook - confirm before relying on it.
HIDDEN_DIM=768
# NOTE(review): presumably the per-article token length; the give_dataloader calls below
# do not pass seq_len, so that function's own default (20) is what is actually used.
SEQ_LEN=20

### Data similarity batching for Lory

From the Lory paper:


"We adapt the pipeline of in-context pre-training (Shi et al., 2024) in our approach. Given a set of documents D, for each document d ∈ D, we first use Contriever (Izacard et al., 2022) to retrieve top-k most similar documents N(d). The similarity between the document di and dj is defined as the cosine similarity of their Contriever embeddings, i.e., sim(di , dj) = cos(C(di), C(dj)), where C denotes the Contriever encoder model. We implement an efficient approximate nearest-neighbors search based on the FAISS library (Johnson et al., 2019). Then, we sort all the documents according to the similarity and construct training instances by batch consecutive documents. We use the same greedy algorithm as Shi et al. (2024). We start from a single document and repeatedly add the document that has the highest similarity value and has not been added to the list; we restart the process with a new document if all documents that are connected to the last document of the list are selected. We repeat this process until there are no documents left."

### Facebook Contriever as tokenizer and encoder

In [3]:
class MyIterableDataset(d.IterableDataset):
    """Iterable dataset yielding one token-embedding matrix per selected article.

    Each item is produced lazily: the article text is tokenized (padded/truncated to
    ``seq_len`` tokens), passed through ``model`` without gradient tracking, and the
    ``last_hidden_state`` (batch axis squeezed away) is returned as a NumPy array.

    Args:
        dataset: Indexable collection whose items expose a ``"text"`` entry.
        tokenizer: Callable returning a mapping with ``input_ids`` and ``attention_mask``.
        model: Callable returning an object with a ``last_hidden_state`` tensor.
        seq_len: Fixed token length every article is padded/truncated to.
        article_indices: Sized, indexable collection of dataset positions to serve,
            in the order they should be yielded.
    """

    def __init__(self, dataset, tokenizer, model, seq_len, article_indices):
        super().__init__()
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.model = model
        self.seq_len = seq_len
        self.article_indices = article_indices

    def _embed(self, article_index):
        """Tokenize and encode the article stored at ``article_index`` of the dataset."""
        article = self.dataset[article_index]["text"]
        tokenized = self.tokenizer(article, padding='max_length', truncation=True, max_length=self.seq_len, return_tensors='pt')

        # Inference only: no autograd graph is needed for the embeddings.
        with torch.no_grad():
            outputs = self.model(tokenized['input_ids'], attention_mask=tokenized['attention_mask'])
        return outputs.last_hidden_state.squeeze(0).numpy()  # [seq_len, hidden_size]

    def __iter__(self):
        num_articles = len(self.article_indices)
        worker_info = d.get_worker_info()
        if worker_info is None:
            # Single-process loading: this iterator serves every selected article.
            start, end = 0, num_articles
        else:
            # Multi-process loading: give each worker a contiguous, disjoint slice.
            per_worker = int(math.ceil(num_articles / float(worker_info.num_workers)))
            start = worker_info.id * per_worker
            end = min(start + per_worker, num_articles)

        # Map the slice position through article_indices; previously the position itself
        # was used as the dataset index, so the contents of article_indices were ignored.
        return (self._embed(self.article_indices[pos]) for pos in range(start, end))

def give_dataloader(development=True, batch_size=64, seq_len=20, num_batches=None):
    """Build a DataLoader that yields batches of Contriever embeddings of Wikipedia articles.

    Args:
        development: If True load the "20220301.simple" Wikipedia configuration,
            otherwise the "20220301.en" one.
        batch_size: Number of articles per batch.
        seq_len: Token length every article is padded/truncated to.
        num_batches: If given, only the first ``num_batches * batch_size`` articles are
            served (clamped to the dataset size); if None, every article is served.

    Returns:
        A ``DataLoader`` whose batches are plain lists of per-article embedding arrays
        (the identity ``collate_fn`` keeps them unstacked).
    """
    # NOTE(review): per the captured cell output, the installed `datasets` version asks
    # for interactive confirmation before running this dataset's loading script unless
    # `trust_remote_code=True` is passed - decide explicitly whether to opt in.
    config_name = "20220301.simple" if development else "20220301.en"
    wiki_huggingface_dataset = load_dataset("wikipedia", config_name)["train"]

    num_rows = wiki_huggingface_dataset.num_rows
    if num_batches is None:
        article_indices = range(num_rows)
    else:
        # Never request more articles than exist, otherwise iteration would fail
        # with an out-of-range index part-way through.
        article_indices = range(min(num_batches * batch_size, num_rows))

    # Contriever provides both the tokenizer and the encoder; the full model is needed
    # because the dataset yields its hidden states, not just token ids.
    tokenizer = AutoTokenizer.from_pretrained('facebook/contriever')
    model = AutoModel.from_pretrained('facebook/contriever').eval()

    ds = MyIterableDataset(wiki_huggingface_dataset, tokenizer, model, seq_len, article_indices=article_indices)
    return d.DataLoader(ds, batch_size=batch_size, collate_fn=lambda x: x)


# Build the development dataloader (Simple English Wikipedia) limited to NUM_BATCHES batches.
data_loader = give_dataloader(development=True, batch_size=BATCH_SIZE, num_batches=NUM_BATCHES)

# Pull one batch for inspection, then count batches with a full pass.
# Note: both lines iterate the loader, so the encoder is run again over the articles.
sample = next(iter(data_loader))
num_batches = sum(1 for _ in data_loader)

# Print the shape of the sample batch to verify the embedding dimensions
print("Number of batches in the dataloader:", num_batches)
print("Batch Size:", len(sample))
print("Seq Length:", len(sample[0]))
print("Embedding size (hidden size):", len(sample[0][0]))


Access to the secret `HF_TOKEN` has not been granted on this notebook.
You will not be requested again.
Please restart the session if you want to be prompted again.


Downloading builder script:   0%|          | 0.00/36.7k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/16.0k [00:00<?, ?B/s]

The repository for wikipedia contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/wikipedia.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Downloading data:   0%|          | 0.00/134M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/205328 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Number of batches in the dataloader: 10
Batch Size: 16
Seq Length: 20
Embedding size (hidden size): 768


In [4]:
class SequenceDataset(Dataset):
    """Map-style dataset that serves items straight from an in-memory, indexable collection."""

    def __init__(self, sequences):
        # Keep a reference to the caller's collection; nothing is copied.
        self.sequences = sequences

    def __len__(self):
        """Number of stored sequences."""
        count = len(self.sequences)
        return count

    def __getitem__(self, idx):
        """Return the sequence stored at position ``idx``."""
        item = self.sequences[idx]
        return item

def reorganize_dataloader(dataloader, num_batches, batch_size, k=10):
    """Re-batch the sequences of ``dataloader`` so that similar sequences sit together.

    Every sequence embedding is flattened and L2-normalized, its ``k`` nearest
    neighbours are looked up with FAISS, and batches are then grown greedily: start
    from an unvisited sequence and keep appending the most similar unvisited neighbour
    of the last added sequence until the batch is full or no neighbour is left.

    Args:
        dataloader: Iterable of batches; each batch is an iterable of per-sequence
            arrays exposing ``.flatten()`` (all of the same size).
        num_batches: Maximum number of batches to build; surplus sequences are dropped.
        batch_size: Target number of sequences per batch.
        k: Number of nearest neighbours retrieved per sequence (the sequence itself is
            normally among them). Clamped to the number of available sequences.

    Returns:
        A ``DataLoader`` (``shuffle=False``) over the reordered sequences, using
        ``batch_size`` as its batch size.
    """
    # Extract embeddings and keep the original (unflattened) sequences alongside them.
    embeddings = []
    sequences = []

    for batch in dataloader:
        for sequence_embedding in batch:
            embeddings.append(sequence_embedding.flatten())
            sequences.append(sequence_embedding)

    if not sequences:
        # Nothing to reorder; avoid building an index over an empty matrix.
        return DataLoader(SequenceDataset([]), batch_size=batch_size, shuffle=False)

    # Unit-normalize so that dot product == cosine similarity and L2 ranking == cosine ranking.
    # FAISS requires a contiguous float32 matrix.
    embeddings = np.ascontiguousarray(normalize(np.array(embeddings), axis=1), dtype=np.float32)
    num_sequences = len(embeddings)

    # IndexFlatL2 is an exact (brute-force) nearest-neighbour index.
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    # Asking for more neighbours than there are vectors makes FAISS pad the result
    # with -1, which would silently index the *last* embedding below.
    k = max(1, min(k, num_sequences))
    _, indices = index.search(embeddings, k)

    # Adjacency list of retrieved neighbours (padding labels, if any, are discarded).
    adjacency_list = {
        i: {int(j) for j in indices[i] if j >= 0} for i in range(num_sequences)
    }

    def create_batches():
        """Use a greedy algorithm to create batches based on similarity."""
        visited = set()
        batches = []

        for i in range(num_sequences):
            if i in visited:
                continue
            batch = [i]
            visited.add(i)
            while len(batch) < batch_size and len(batch) < num_sequences:
                last_seq = batch[-1]
                candidates = adjacency_list[last_seq] - visited
                if not candidates:
                    # Every neighbour of the chain's tail is already used: restart elsewhere.
                    break
                next_seq = max(candidates, key=lambda x: np.dot(embeddings[last_seq], embeddings[x]))
                batch.append(next_seq)
                visited.add(next_seq)
            batches.append(batch)
            if len(batches) == num_batches:
                break

        # Top up short batches with leftover sequences (these are NOT similarity-matched).
        remaining_indices = set(range(num_sequences)) - visited
        for batch in batches:
            if len(batch) < batch_size and remaining_indices:
                needed = batch_size - len(batch)
                for _ in range(needed):
                    if not remaining_indices:
                        break
                    next_seq = remaining_indices.pop()
                    batch.append(next_seq)

        # If fewer than num_batches were built, form extra batches from the leftovers.
        while remaining_indices and len(batches)<num_batches:
            new_batch = [remaining_indices.pop() for _ in range(batch_size) if remaining_indices]
            batches.append(new_batch)

        return batches

    batches = create_batches()
    # Flatten the batches back into one ordered list; the DataLoader re-chunks it.
    reorganized_sequences = []
    for batch in batches:
        reorganized_sequences.extend([sequences[i] for i in batch])
    reorganized_dataset = SequenceDataset(reorganized_sequences)
    reorganized_dataloader = DataLoader(reorganized_dataset, batch_size=batch_size, shuffle=False)
    return reorganized_dataloader



### Testing

In [5]:
# Create the dataloader
data_loader = give_dataloader(development=True, batch_size=BATCH_SIZE, num_batches=NUM_BATCHES)
# Each of the two lines below iterates the loader, i.e. re-runs the encoder.
sample = next(iter(data_loader))
num_batches = sum(1 for _ in data_loader)

# Print the shape of the sample batch to verify the embedding dimensions
print("Number of batches in the dataloader:", num_batches)
print("Batch Size:", len(sample))
print("Seq Length:", len(sample[0]))
print("Embedding size (hidden size):", len(sample[0][0]))

# Reorganize the dataloader
# NOTE(review): batch_size=10 differs from BATCH_SIZE (16) used above; with NUM_BATCHES=10
# only 100 of the 160 embedded sequences are kept - confirm this is intentional.
reorganized_dataloader = reorganize_dataloader(data_loader, num_batches=NUM_BATCHES, batch_size=10, k=10)

# Verify the reorganized dataloader
sample = next(iter(reorganized_dataloader))
num_batches = sum(1 for _ in reorganized_dataloader)

# Print the shape of the sample batch to verify the embedding dimensions
print("Number of batches in the reorganized dataloader:", num_batches)
print("Batch Size reorganized:", len(sample))
print("Seq Length reorganized:", len(sample[0]))
print("Embedding size (hidden size) reorganized:", len(sample[0][0]))

Number of batches in the dataloader: 10
Batch Size: 16
Seq Length: 20
Embedding size (hidden size): 768
Number of batches in the reorganized dataloader: 10
Batch Size reorganized: 10
Seq Length reorganized: 20
Embedding size (hidden size) reorganized: 768


### Model classes

In [26]:

# Annotation aliases. They point at the tensor *class* (torch.Tensor); torch.tensor is a
# factory function and is not a valid type for annotations or isinstance checks.
AttentionT = torch.Tensor  # torch tensor of shape [BATCH, SEQ_LEN, NUM_HEADS, HEAD_DIM]
HiddenT = torch.Tensor  # [BATCH, SEQ_LEN, HIDDEN_DIM]
TokensT = torch.Tensor # [BATCH, SEQ_LEN]
ModelLT = torch.Tensor # [BATCH, SEQ_LEN, VOCAB_SIZE]

class AttentionCreateQKV(torch.nn.Module):
    """
    Projects hidden states into per-head queries, keys and values.

    Input:  tensor of shape [BATCH, SEQ_LEN, HIDDEN_DIM].
    Output: list [Q, K, V]; each entry has shape [BATCH, SEQ_LEN, NUM_HEADS, HEAD_DIM]
            with HEAD_DIM = HIDDEN_DIM // NUM_HEADS.
    All three projections are bias-free linear maps HIDDEN_DIM -> HIDDEN_DIM.
    """

    def __init__(self, hidden_dim, num_heads) -> None:
        super().__init__()
        assert hidden_dim % num_heads == 0
        self.head_dim = hidden_dim // num_heads
        self.num_heads = num_heads

        # Creation order (key, query, value) is kept on purpose: it determines both the
        # RNG draws used for initialization and the parameter order of the module.
        self.key_transform = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.query_transform = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_transform = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x: HiddenT) -> Tuple[AttentionT, AttentionT, AttentionT]:
        assert len(x.shape) == 3  # [BATCH, SEQ_LEN, HIDDEN_DIM]

        batch, seq_len = x.shape[0], x.shape[1]
        per_head_shape = (batch, seq_len, self.num_heads, self.head_dim)

        # Returned in the order queries, keys, values.
        projections = (self.query_transform, self.key_transform, self.value_transform)
        result = [torch.reshape(project(x), per_head_shape) for project in projections]

        # Sanity checks on the produced shapes.
        assert len(result) == 3
        for r in result:
            assert len(r.shape) == 4  # [BATCH, SEQ_LEN, NUM_HEADS, HEAD_DIM]
            assert r.shape[-2:] == (self.num_heads, self.head_dim)
            assert r.shape[:-2] == x.shape[:2]

        return result

class RoPEPosEncoding(torch.nn.Module):
    """
    Applies Rotary Positional Encoding to a tensor of shape
    [BATCH, SEQ_LEN, NUM_HEADS, HEAD_DIM].

    Consecutive feature pairs (2i, 2i+1) are rotated by the angle position * theta_i,
    where theta_i = number ** (-2i / HEAD_DIM).
    ``offset`` lets a sequence be encoded part by part: it is the number of tokens
    that precede the given input in the full sequence.
    """

    def __init__(self, head_dim, number) -> None:
        super().__init__()

        assert head_dim % 2 == 0
        self.hidden_dim = head_dim
        self.number = number
        # One frequency per feature pair -> length head_dim // 2.
        exponents = torch.arange(0, head_dim, 2).float() / head_dim
        self.theta = 1. / (self.number ** exponents)

    def forward(self, x: AttentionT, offset: int = 0):
        assert len(x.shape) == 4  # [BATCH, SEQ_LEN, NUM_HEADS, HEAD_DIM]
        assert offset >= 0

        batch, seq_len, num_heads, head_dim = x.shape

        # Keep the frequencies on the same device as the input.
        self.theta = self.theta.to(x.device)
        positions = offset + torch.arange(seq_len).float().to(x.device)

        # angle[j, i] = position_j * theta_i, laid out to broadcast over batch and heads.
        angles = torch.outer(positions, self.theta).reshape(1, seq_len, 1, head_dim // 2)
        cos, sin = torch.cos(angles), torch.sin(angles)

        # View the feature axis as (head_dim // 2) pairs and rotate every pair.
        pairs = x.reshape(batch, seq_len, num_heads, head_dim // 2, 2)
        first, second = pairs[..., 0], pairs[..., 1]

        rotated = torch.zeros_like(pairs)
        rotated[..., 0] = first * cos - second * sin
        rotated[..., 1] = first * sin + second * cos

        result = rotated.reshape(x.shape)

        assert result.shape == x.shape

        return result

number = 10000  # base of the RoPE frequencies (used when calculating theta)

# Annotation alias for a (keys, values) cache. Uses the tensor class torch.Tensor;
# torch.tensor is a factory function, not a type.
ACacheT = Tuple[
    torch.Tensor, torch.Tensor
]  # key, value, both of shape [BATCH, SEQ_LEN, NUM_HEADS, HEAD_DIM]


class Attention(torch.nn.Module):
    """
    Causal multi-head self-attention with rotary position encoding.

    Input: tensor x of shape [BATCH, SEQ_LEN, hidden_dim].
    Steps:
      1. ``head_proj`` creates q, k, v, each of shape [BATCH, SEQ_LEN, num_heads, head_dim].
      2. RoPE is applied to q and k.
      3. Scaled dot-product attention with a causal mask is computed inside each head.
      4. The heads are concatenated and linearly projected back to hidden_dim.

    ``forward`` returns the projected output together with the attention weights.

    ``get_empty_cache`` builds an empty (keys, values) pair of shape
    [BATCH, 0, num_heads, head_dim]. Note that ``forward`` as written here neither
    accepts nor returns a cache, so it always recomputes keys/values for the full input.
    """

    def __init__(
        self, hidden_dim: int, num_heads: int, head_proj=AttentionCreateQKV
    ) -> None:
        super().__init__()

        assert hidden_dim % num_heads == 0

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        # Sub-module creation order is kept (projector, encoder, linear).
        self.projector = head_proj(self.hidden_dim, self.num_heads)
        self.encoder = RoPEPosEncoding(self.head_dim, number)
        self.linear = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)

    def get_empty_cache(self, batch_size: int, device) -> ACacheT:
        """Return an empty (keys, values) cache with zero sequence length."""
        cache_shape = (batch_size, 0, self.num_heads, self.head_dim)
        empty_keys = torch.empty(*cache_shape, device=device)
        empty_values = torch.empty(*cache_shape, device=device)
        return empty_keys, empty_values

    def forward(self, x: HiddenT) -> Tuple:
        assert len(x.shape) == 3  # [BATCH, SEQ_LEN, HIDDEN_DIM]

        projected = self.projector(x)
        Q, K, V = projected[0], projected[1], projected[2]

        # Rotate, then move heads in front of the sequence axis:
        # [BATCH, SEQ_LEN, num_heads, head_dim] -> [BATCH, num_heads, SEQ_LEN, head_dim]
        Q = self.encoder(Q).transpose(1, 2)
        K = self.encoder(K).transpose(1, 2)

        # Scaled scores A = Q K^T / sqrt(head_dim), shape [BATCH, num_heads, SEQ_LEN, SEQ_LEN]
        scale = torch.sqrt(torch.tensor([self.head_dim], device=x.device))
        A = torch.matmul(Q, K.transpose(2, 3)) / scale

        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        above_diagonal = torch.triu(torch.ones_like(A), diagonal=1)
        mask = above_diagonal.masked_fill(above_diagonal == 1, float('-inf'))

        A = torch.softmax(A + mask, dim=3)

        # Weighted sum of values; V is still [BATCH, SEQ_LEN, num_heads, head_dim].
        O = torch.einsum("bhij,bjhd->bihd", A, V).reshape(x.shape)

        # Output projection: the concatenated head outputs are mapped linearly back to
        # hidden_dim. (Translated open question from the original author: "Note, there is
        # this linear layer here and I am not sure whether it should be here - what is it for?")
        attention = self.linear(O)

        attention_weights = A

        return attention, attention_weights

class LayerNorm(torch.nn.Module):
    """Layer normalization over the last (hidden) dimension with learnable scale and shift.

    Args:
        hidden_dim: Size of the last dimension of the inputs.
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, hidden_dim, eps=1e-05) -> None:
        super().__init__()

        self.hidden_dim = hidden_dim
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones((hidden_dim)))
        self.beta = nn.Parameter(torch.zeros((hidden_dim)))

    def forward(self, x: HiddenT) -> HiddenT:
        assert len(x.shape) == 3  # torch tensor of shape [BATCH, SEQ_LEN, HIDDEN_DIM]

        mean = torch.mean(input = x, dim = -1, keepdim = True)
        centered = x - mean
        # Biased variance as the mean of squared deviations. The E[x^2] - E[x]^2 form
        # cancels catastrophically for inputs with a large mean and can even turn
        # negative, which makes the square root below return NaN.
        var = torch.mean(input = centered ** 2, dim = -1, keepdim = True)
        x_norm = centered / torch.sqrt(var + self.eps)

        result = x_norm * self.gamma + self.beta

        assert x.shape == result.shape
        return result

class TransformerBlock(torch.nn.Module):
    """Pre-norm transformer block: attention and feed-forward sub-layers, each with a skip connection."""

    def __init__(self, config) -> None:
        """
        config fields read here:
          hidden_size          - model width
          num_attention_heads  - number of attention heads
          forward_layer_class  - nn.Module class for the feed-forward sub-layer
                                 (e.g. MoE / Lory variants); instantiated as cls(config)
        """
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.num_heads = config.num_attention_heads

        # Creation order is kept (norms, feed-forward, attention).
        self.layer_norm1 = LayerNorm(self.hidden_dim)
        self.layer_norm2 = LayerNorm(self.hidden_dim)
        self.forward_layer = config.forward_layer_class(config)
        self.attention = Attention(self.hidden_dim, self.num_heads)

    def forward(self, x: HiddenT) -> HiddenT:
        # Each sub-layer follows: out = inp + sublayer(norm(inp))
        # The attention module returns (output, weights); only the output is used.
        after_attention = x + self.attention(self.layer_norm1(x))[0]
        result = after_attention + self.forward_layer(self.layer_norm2(after_attention))

        assert x.shape == result.shape

        return result

# TokensT = torch.tensor # [BATCH, SEQ_LEN]
# ModelLT = torch.tensor # [BATCH, SEQ_LEN, VOCAB_SIZE]


class Transformer(torch.nn.Module):
    """Token embedding, a stack of TransformerBlocks, and a final projection to vocabulary logits.

    config fields read here: vocab_size, n_layers, hidden_size, forward_layer_class,
    num_attention_heads.
    """

    def __init__(self, config) -> None:
        super().__init__()

        self.vocab_size = config.vocab_size
        self.n_layers = config.n_layers
        self.hidden_dim = config.hidden_size
        self.forward_layer = config.forward_layer_class
        self.num_heads = config.num_attention_heads

        self.embedding = torch.nn.Embedding(self.vocab_size, self.hidden_dim)

        blocks = []
        for _ in range(self.n_layers):
            blocks.append(TransformerBlock(config=config))
        self.layers = torch.nn.ModuleList(blocks)

        self.final_proj = torch.nn.Linear(self.hidden_dim, self.vocab_size)

    def forward(self, x: TokensT) -> ModelLT:
        assert len(x.shape) == 2  # [BATCH, SEQ_LEN]

        hidden = self.embedding(x)
        for block in self.layers:
            hidden = block(hidden)

        logits = self.final_proj(hidden)  # [BATCH, SEQ_LEN, VOCAB_SIZE]
        return logits

# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, num_experts] - expert routing weights
class Router(nn.Module):
    """Top-k softmax router.

    Scores every token against one learned embedding per expert, keeps the
    ``num_experts_per_token`` highest scores and applies a softmax over them; all
    other experts receive weight 0.

    config fields read: num_experts_per_token, hidden_size, num_experts, device.
    Note: a class with this name is defined again later in this notebook; the later
    definition is the one in effect after running all cells.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_token = config.num_experts_per_token
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts

        # Allocate directly on the target device. Calling .to(device) on an nn.Parameter
        # returns a plain non-leaf tensor whenever the device differs, so the weights
        # would not be registered with the module (missing from parameters(), never trained).
        self.expert_embeddings = nn.Parameter(
            torch.randn(self.num_experts, self.hidden_size, device=config.device)
        )
        torch.nn.init.kaiming_uniform_(self.expert_embeddings, nonlinearity='linear')

    def forward(self, x):
        # x: [batch, seq, hidden] -> scores: [batch, seq, num_experts]
        dot = torch.einsum("bsh,eh->bse", x, self.expert_embeddings)
        top_k_out = torch.topk(dot, k=self.num_experts_per_token)
        # Non-selected experts are set to -inf so the softmax assigns them exactly 0.
        top_k = (float("-inf") * torch.ones_like(dot)).scatter_(dim=-1, index=top_k_out.indices, src=top_k_out.values)
        res = torch.nn.functional.softmax(top_k, dim=-1)
        return res

# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, hidden_size] - output embeddings
class VectorizedMoE(nn.Module):
    """Vectorized mixture-of-experts feed-forward layer with a per-expert token capacity.

    Tokens are assigned to experts by ``Router``. For every expert only the first
    tokens (in flattened token order, not randomly chosen) up to ``expert_capacity``
    are processed; the rest are dropped for that expert.

    Input / output: [batch_size, seq_len, hidden_size].
    config fields read: num_experts, hidden_size, num_experts_per_token,
    capacity_factor, intermediate_size, device.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.hidden_size = config.hidden_size
        self.num_experts_per_token = config.num_experts_per_token
        self.capacity_factor = config.capacity_factor
        self.intermediate_size = config.intermediate_size

        # Expert weights are stacked along the first axis. They are allocated directly on
        # the target device: nn.Parameter(...).to(device) returns a plain non-leaf tensor
        # when the device differs, which would leave the experts unregistered and untrained.
        self.first_linear = nn.Parameter(torch.randn(self.num_experts, self.intermediate_size, self.hidden_size, device=config.device))
        torch.nn.init.kaiming_uniform_(self.first_linear, nonlinearity='linear')
        self.second_linear = nn.Parameter(torch.randn(self.num_experts, self.hidden_size, self.intermediate_size, device=config.device))
        torch.nn.init.kaiming_uniform_(self.second_linear, nonlinearity='linear')

        self.router = Router(config)

    def forward(self, x):
        batch_size, seq_len, hidden_size = x.shape
        #assert hidden_size == self.hidden_size
        # NOTE(review): topk below needs expert_capacity <= batch_size * seq_len, which
        # appears to require capacity_factor <= num_experts - confirm with callers.
        expert_capacity = math.ceil(batch_size * seq_len / self.num_experts * self.capacity_factor)

        weights = self.router(x) #[batch_size, seq_len, num_experts]

        # 1 where the router selected the expert for the token, 0 otherwise.
        experts_where_ones = torch.where((weights <= 0), 0, 1) #ceiling of weights
        experts_where_ones = torch.reshape(experts_where_ones, shape=(-1, self.num_experts)) #[num_of_tokens, num_experts]
        # Running count per expert: keep an assignment only while the expert is within capacity.
        capacity_aware_ones = torch.where((torch.cumsum(experts_where_ones, dim= 0) <= expert_capacity), input = experts_where_ones, other = 0)

        # dec_seq = experts_where_ones.shape[0] - torch.arange(experts_where_ones.shape[0]).unsqueeze(dim = 1)
        # numbered = (experts_where_ones * dec_seq)
        # which = torch.topk(numbered, k=expert_capacity, dim = 0)
        capacity_aware_weights = weights.reshape(shape=(-1, self.num_experts)) * capacity_aware_ones
        # For each expert pick expert_capacity token slots; slots with weight 0 act as
        # padding (their output is multiplied by 0 further down).
        which = torch.topk(capacity_aware_weights, k=expert_capacity, dim = 0)
        indices = which.indices.transpose(1,0)
        index = indices.reshape((-1))

        tokens_for_experts = torch.index_select(input=x.reshape((-1, hidden_size)), dim=0, index=index) #[capacity*num_experts, hidden_size]
        tokens_for_experts  = tokens_for_experts.reshape((self.num_experts, expert_capacity, hidden_size))
        # Inputs are now grouped per expert; apply each expert's first layer.

        intermediate_result = torch.einsum("ech,eih->eci", tokens_for_experts, self.first_linear)
        intermediate_result = torch.nn.functional.relu(intermediate_result)
        result = torch.einsum("eci,ehi->ech", intermediate_result, self.second_linear)
        # Expert outputs are ready; weight them by the router and add them up per token.

        w = which.values.transpose(1,0).unsqueeze(-1)

        result = result * w

        # Scatter-add the weighted expert outputs back to their token positions.
        final_result = torch.zeros_like(x).reshape((-1, hidden_size)).index_add_(dim = 0, index=index, source = result.reshape((-1, hidden_size)))

        return final_result.reshape(x.shape)

class FeedForward(torch.nn.Module):  # used to check that swapping VectorizedMoE in the config for something else still works
    """
    Position-wise feed-forward network for tensors of shape [BATCH, SEQ_LEN, HIDDEN_DIM]:
    * linear projection hidden_dim -> inner_dim (with bias)
    * GELU activation
    * linear projection inner_dim -> hidden_dim (with bias)

    config fields read: hidden_size, intermediate_size.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.inner_dim = config.intermediate_size
        # Both projections carry a bias term (the nn.Linear default), as in the paper.
        self.first_linear = nn.Linear(self.hidden_dim, self.inner_dim)
        self.activation = nn.GELU()
        self.second_linear = nn.Linear(self.inner_dim, self.hidden_dim)

    def forward(self, x: HiddenT) -> HiddenT:
        assert len(x.shape) == 3  # [BATCH, SEQ_LEN, HIDDEN_DIM]

        expanded = self.first_linear(x)
        activated = self.activation(expanded)
        result = self.second_linear(activated)

        assert len(result.shape) == 3  # [BATCH, SEQ_LEN, HIDDEN_DIM]
        return result

### Lory

In [27]:
from transformers import PretrainedConfig

In [28]:
from torch import nn
import torch
from transformers import PretrainedConfig
import torch.nn.functional as F

class MLP(nn.Module):
    """Two-layer perceptron: hidden_size -> intermediate_size -> hidden_size with a ReLU in between.

    config fields read: hidden_size, intermediate_size.
    """

    def __init__(self, config):
        super().__init__()
        # Layers are created in application order and stored under positions 0, 1, 2.
        up_projection = nn.Linear(config.hidden_size, config.intermediate_size)
        down_projection = nn.Linear(config.intermediate_size, config.hidden_size)
        self.mlp = nn.Sequential(up_projection, nn.ReLU(), down_projection)

    def forward(self, x):
        """Apply the perceptron to the last dimension of ``x``."""
        output = self.mlp(x)
        return output

In [29]:
# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, num_experts] - expert routing weights
class Router(nn.Module):
    """Top-k softmax router.

    Scores every token against one learned embedding per expert, keeps the
    ``num_experts_per_token`` highest scores and applies a softmax over them; all
    other experts receive weight 0.

    config fields read: num_experts_per_token, hidden_size, num_experts, device.
    Note: this redefines the ``Router`` class declared earlier in the notebook.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_token = config.num_experts_per_token
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts

        # Allocate directly on the target device. Calling .to(device) on an nn.Parameter
        # returns a plain non-leaf tensor whenever the device differs, so the weights
        # would not be registered with the module (missing from parameters(), never trained).
        self.expert_embeddings = nn.Parameter(
            torch.randn(self.num_experts, self.hidden_size, device=config.device)
        )
        torch.nn.init.kaiming_uniform_(self.expert_embeddings, nonlinearity='linear')

    def forward(self, x):
        # x: [batch, seq, hidden] -> scores: [batch, seq, num_experts]
        dot = torch.einsum("bsh,eh->bse", x, self.expert_embeddings)
        top_k_out = torch.topk(dot, k=self.num_experts_per_token)
        # Non-selected experts are set to -inf so the softmax assigns them exactly 0.
        top_k = (float("-inf") * torch.ones_like(dot)).scatter_(dim=-1, index=top_k_out.indices, src=top_k_out.values)
        res = torch.nn.functional.softmax(top_k, dim=-1)
        return res

In [65]:
import math

# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, hidden_size] - output embeddings
class MoE_Lory(nn.Module):
    """Lory-style mixture of experts with segment-level expert merging.

    Every sequence is split into ``T_segments`` consecutive segments. For each segment
    the experts' weight matrices are merged into a single feed-forward network using
    routing weights, and that merged network processes all tokens of the segment:

    * segment 0 is routed with its own mean representation, computed without gradient;
    * segment t > 0 is routed with the weights computed from segment t - 1.

    Input / output: [batch_size, seq_len, hidden_size].
    config fields read: num_experts, hidden_size, num_experts_per_token,
    capacity_factor, T_segments, intermediate_size, device.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.hidden_size = config.hidden_size
        self.num_experts_per_token = config.num_experts_per_token
        self.capacity_factor = config.capacity_factor
        self.T_segments=config.T_segments
        # Experts are stored as stacked weight tensors (instead of a ModuleList of MLPs)
        # so they can be compared more easily with the vectorized version.
        self.intermediate_size = config.intermediate_size
        # Allocate directly on the target device: nn.Parameter(...).to(device) returns a
        # plain non-leaf tensor when the device differs, which would leave the experts
        # unregistered with the module and therefore untrained.
        self.first_linear = nn.Parameter(torch.randn(self.num_experts, self.intermediate_size, self.hidden_size, device=config.device))
        torch.nn.init.kaiming_uniform_(self.first_linear, nonlinearity='linear')
        self.second_linear = nn.Parameter(torch.randn(self.num_experts, self.hidden_size, self.intermediate_size, device=config.device))
        torch.nn.init.kaiming_uniform_(self.second_linear, nonlinearity='linear')

        self.router = Router(config)

    def compute_out(self, data,linear1,linear2):
        """Apply the merged two-layer network: linear2 @ relu(linear1 @ data).

        ``data`` is a column-oriented input: a vector [hidden] or a matrix [hidden, n].
        """
        return linear2 @ torch.nn.functional.relu(linear1 @ data)

    def merge_expert(self, weights):
        """Merge all experts into one pair of weight matrices.

        ``weights`` has shape [num_experts, 1, 1]; each expert matrix is scaled by its
        weight and the results are summed over the expert axis.
        """
        weighted_first_linear = torch.sum(weights * self.first_linear, dim=0)
        weighted_second_linear = torch.sum(weights * self.second_linear, dim=0)
        return weighted_first_linear,weighted_second_linear

    def _route_segment(self, segment):
        """Return routing weights [num_experts, 1, 1] for the mean token of ``segment``."""
        h_x = (segment.sum(dim=0) / segment.shape[0]).reshape(1, 1, -1)
        return self.router(h_x).permute(2, 0, 1)

    def forward(self, x):
        batch_size, seq_len, hidden_size = x.shape
        #assert hidden_size == self.hidden_size
        if seq_len < self.T_segments:
            # A zero-length segment would yield a 0/0 mean and NaN routing weights.
            raise ValueError(
                f"seq_len ({seq_len}) must be at least T_segments ({self.T_segments})"
            )
        segment_size = seq_len // self.T_segments
        result = torch.zeros_like(x)

        for i in range(batch_size):
            previous_weights = None
            for t in range(self.T_segments):
                start = t * segment_size
                # The last segment absorbs the remainder when seq_len is not a multiple
                # of T_segments; otherwise the trailing tokens would be left as zeros.
                end = seq_len if t == self.T_segments - 1 else start + segment_size
                segment = x[i, start:end]

                if t == 0:
                    # First segment: routed by itself, with gradients stopped on the
                    # routing decision only. The merge below stays differentiable so the
                    # expert weights still receive gradients from this segment.
                    with torch.no_grad():
                        current_weights = self._route_segment(segment)
                    merge_weights = current_weights
                else:
                    # Later segments are routed causally with the previous segment's weights.
                    current_weights = self._route_segment(segment)
                    merge_weights = previous_weights
                previous_weights = current_weights

                merged_linear1, merged_linear2 = self.merge_expert(merge_weights)
                # Process the whole segment at once: tokens become columns, then are
                # transposed back to [segment_len, hidden_size].
                segment_out = self.compute_out(segment.transpose(0, 1), merged_linear1, merged_linear2)
                result[i, start:end] = segment_out.transpose(0, 1)

        return result

In [66]:
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Settings shared by every model variant built from this notebook.
base_config = dict(
    vocab_size=5000,
    max_position_embeddings=256,
    num_attention_heads=8,
    num_hidden_layers=4,
    hidden_dropout_prob=0.1,
    hidden_size=128,
    intermediate_size=512,
    num_labels=2,
    device = DEVICE # not a standard field: read by the Router / MoE layers as config.device
)
# Lory-specific settings layered on top of the shared ones.
# NOTE(review): the Transformer / TransformerBlock classes above read `config.n_layers`
# and `config.forward_layer_class`, while this config provides `num_hidden_layers` and
# `ff_cls` - confirm the intended field names before building a full Transformer from it.
moe_config = PretrainedConfig(
    **base_config,
    T_segments=5,
    num_experts=6,
    capacity_factor=2.0,
    num_experts_per_token=1,
    ff_cls=MoE_Lory
)

In [67]:
# Smoke test: run one random batch through the Lory layer.
lory = MoE_Lory(moe_config)
batch_size, seq_len, hidden_size=16,20,128

# Create the input on the same device as the layer's weights (DEVICE); a CPU input
# fails with a device-mismatch error when DEVICE is a GPU.
# Note: the name `input` shadows the Python builtin; it is kept for later cells.
input = torch.randn((batch_size, seq_len, hidden_size), device=DEVICE)
v = lory(input)

### Not done

In [None]:
# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, hidden_size] - output embeddings
class VectorizedMoE_Lory(nn.Module):
    """Work-in-progress vectorized layer (the notebook section is marked "Not done").

    As written, the computation is a capacity-limited token-routing mixture of experts:
    tokens are assigned to experts by ``Router`` and, for every expert, only the first
    tokens (in flattened token order, not randomly chosen) up to ``expert_capacity``
    are processed. No Lory-style segment merging is implemented here yet.

    Input / output: [batch_size, seq_len, hidden_size].
    config fields read: num_experts, hidden_size, num_experts_per_token,
    capacity_factor, intermediate_size, device.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.hidden_size = config.hidden_size
        self.num_experts_per_token = config.num_experts_per_token
        self.capacity_factor = config.capacity_factor
        self.intermediate_size = config.intermediate_size

        # Expert weights are stacked along the first axis. They are allocated directly on
        # the target device: nn.Parameter(...).to(device) returns a plain non-leaf tensor
        # when the device differs, which would leave the experts unregistered and untrained.
        self.first_linear = nn.Parameter(torch.randn(self.num_experts, self.intermediate_size, self.hidden_size, device=config.device))
        torch.nn.init.kaiming_uniform_(self.first_linear, nonlinearity='linear')
        self.second_linear = nn.Parameter(torch.randn(self.num_experts, self.hidden_size, self.intermediate_size, device=config.device))
        torch.nn.init.kaiming_uniform_(self.second_linear, nonlinearity='linear')

        self.router = Router(config)

    def forward(self, x):
        batch_size, seq_len, hidden_size = x.shape
        #assert hidden_size == self.hidden_size
        # NOTE(review): topk below needs expert_capacity <= batch_size * seq_len, which
        # appears to require capacity_factor <= num_experts - confirm with callers.
        expert_capacity = math.ceil(batch_size * seq_len / self.num_experts * self.capacity_factor)

        weights = self.router(x) #[batch_size, seq_len, num_experts]

        # 1 where the router selected the expert for the token, 0 otherwise.
        experts_where_ones = torch.where((weights <= 0), 0, 1) #ceiling of weights
        experts_where_ones = torch.reshape(experts_where_ones, shape=(-1, self.num_experts)) #[num_of_tokens, num_experts]
        # Running count per expert: keep an assignment only while the expert is within capacity.
        capacity_aware_ones = torch.where((torch.cumsum(experts_where_ones, dim= 0) <= expert_capacity), input = experts_where_ones, other = 0)

        # dec_seq = experts_where_ones.shape[0] - torch.arange(experts_where_ones.shape[0]).unsqueeze(dim = 1)
        # numbered = (experts_where_ones * dec_seq)
        # which = torch.topk(numbered, k=expert_capacity, dim = 0)
        capacity_aware_weights = weights.reshape(shape=(-1, self.num_experts)) * capacity_aware_ones
        # For each expert pick expert_capacity token slots; slots with weight 0 act as
        # padding (their output is multiplied by 0 further down).
        which = torch.topk(capacity_aware_weights, k=expert_capacity, dim = 0)
        indices = which.indices.transpose(1,0)
        index = indices.reshape((-1))

        tokens_for_experts = torch.index_select(input=x.reshape((-1, hidden_size)), dim=0, index=index) #[capacity*num_experts, hidden_size]
        tokens_for_experts  = tokens_for_experts.reshape((self.num_experts, expert_capacity, hidden_size))
        # Inputs are now grouped per expert; apply each expert's first layer.

        intermediate_result = torch.einsum("ech,eih->eci", tokens_for_experts, self.first_linear)
        intermediate_result = torch.nn.functional.relu(intermediate_result)
        result = torch.einsum("eci,ehi->ech", intermediate_result, self.second_linear)
        # Expert outputs are ready; weight them by the router and add them up per token.

        w = which.values.transpose(1,0).unsqueeze(-1)

        result = result * w

        # Scatter-add the weighted expert outputs back to their token positions.
        final_result = torch.zeros_like(x).reshape((-1, hidden_size)).index_add_(dim = 0, index=index, source = result.reshape((-1, hidden_size)))

        return final_result.reshape(x.shape)