<a href="https://colab.research.google.com/github/abudubai16/NLP-using-MoE/blob/main/MoE_NLP_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook implements a Switch Transformer–style Mixture of Experts, based on the paper *Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity* (https://arxiv.org/pdf/2101.03961).

I implemented a switching 'Mixture of Experts' layer. Then I wrote a custom transformer encoder layer, modeled on `torch.nn.TransformerEncoderLayer`, whose feed-forward sublayer is replaced by that MoE layer. The rest of the code is a standard transformer used for next-word prediction.

## Downloading required libraries

In [None]:
! pip install portalocker --q
! pip install datasets --q --q
! pip install positional-encodings[pytorch] --q --q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.3/21.3 MB[0m [31m56.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
! pip uninstall JAX

Found existing installation: jax 0.4.26
Uninstalling jax-0.4.26:
  Would remove:
    /usr/local/lib/python3.10/dist-packages/jax-0.4.26.dist-info/*
    /usr/local/lib/python3.10/dist-packages/jax/*
Proceed (Y/n)? Y
  Successfully uninstalled jax-0.4.26


## Importing libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from tqdm import tqdm
from positional_encodings.torch_encodings import PositionalEncoding1D
from datasets import load_dataset, concatenate_datasets

from transformers import RobertaTokenizer
from tokenizers import Tokenizer

import os
import time
import math

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

## Downloading the dataset

In [None]:
# Clone the wikitext dataset repository from the Hugging Face Hub into /content,
# skipping the download when that directory already exists (e.g. on a re-run).
if not os.path.exists('/content/wikitext'):
  ! git clone 'https://huggingface.co/datasets/wikitext' --q

In [None]:
# Load the locally cloned wikitext-103 parquet shards as four named splits.
dataset = load_dataset('parquet', data_files={
    'train1':'/content/wikitext/wikitext-103-v1/train-00000-of-00002.parquet',
    'train2': '/content/wikitext/wikitext-103-v1/train-00001-of-00002.parquet',
    'val':'/content/wikitext/wikitext-103-v1/validation-00000-of-00001.parquet',
    'test':'/content/wikitext/wikitext-103-v1/test-00000-of-00001.parquet'
})

# Dataset for training the model: the 'text' column (raw lines) of both train shards plus 'test'.
# NOTE(review): the 'test' split is folded into the training text. Validation below uses the
# separate 'val' split, but any later evaluation on 'test' would be contaminated.
train_ds = concatenate_datasets([
    dataset['train1'],
      dataset['train2'],
      dataset['test']
])['text']

# Dataset for creating the tokenizer
# NOTE(review): `stacked` is not referenced by any cell shown here (the tokenizer is the
# pretrained RoBERTa one), so this concatenation appears to be unused.
stacked = concatenate_datasets([
      dataset['train1'],
      dataset['train2'],
      dataset['val'],
      dataset['test']
  ])

Generating train1 split: 0 examples [00:00, ? examples/s]

Generating train2 split: 0 examples [00:00, ? examples/s]

Generating val split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

### Tokenizer

In [None]:
# Pretrained RoBERTa BPE tokenizer (downloaded from the Hugging Face Hub).
# Later cells hard-code its special ids: 1 is used as padding and 2 as the
# end-of-sentence marker (see token_ds and generate).
tokenizer = RobertaTokenizer.from_pretrained('FacebookAI/roberta-base')

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

## Creating the dataloader

In [None]:
class token_ds(Dataset):
  """Sentence-level next-token dataset built from raw text lines.

  Every line of ``ds`` is split on '.', and each piece is encoded with the
  tokenizer. Each encoding becomes one row of exactly ``max_length + 1`` ids:
  shorter encodings are right-padded with ``pad_id``, and longer ones keep only
  their LAST ``max_length + 1`` ids. Pieces whose second id is ``eos_id``
  (nothing between the start and end markers, i.e. an empty sentence) are
  skipped.

  ``value`` is every row without its last id, and ``target`` is every row
  without its first id. So ``target[i, t]`` is the id that follows
  ``value[i, t]``.

  Args:
    ds: sized iterable of raw text strings.
    max_length: length of every returned input/target sequence.
    tok: optional object with ``encode(str) -> list[int]``. Defaults to the
      notebook-global ``tokenizer``.
    pad_id: id used for right-padding (default 1, as before).
    eos_id: end-of-sentence id used to detect empty sentences (default 2).
  """
  def __init__(self, ds, max_length, tok=None, pad_id=1, eos_id=2):
    super().__init__()
    encoder = tokenizer if tok is None else tok
    row_len = max_length + 1  # one extra id so value/target can be shifted by one

    # Accumulate plain int lists and build a single tensor at the end. Repeatedly
    # concatenating onto a growing tensor would copy all earlier rows every time.
    rows = []
    for line in tqdm(ds, total=len(ds)):
      for sentence in line.split('.'):
        ids = list(encoder.encode(sentence))

        # Skip empty sentences (they encode to just start/end markers).
        if len(ids) < 2 or ids[1] == eos_id:
          continue

        if len(ids) < row_len:
          ids = ids + [pad_id] * (row_len - len(ids))
        else:
          ids = ids[-row_len:]
        rows.append(ids)

    if rows:
      data = torch.tensor(rows, dtype=torch.long)
    else:
      data = torch.empty((0, row_len), dtype=torch.long)

    self.value = data[:, :-1]
    self.target = data[:, 1:]

  def __len__(self):
    """Number of stored sentences."""
    return self.value.shape[0]

  def __getitem__(self, i):
    """Return the (input ids, next-token target ids) pair for sentence ``i``."""
    return self.value[i, :], self.target[i, :]

In [None]:
seq_len = 100     # tokens per model input; token_ds stores seq_len + 1 ids per row for the shift
batch_size = 32

# Validation loader over the wikitext 'val' split (token_ds encodes everything up front).
val_dl = DataLoader(token_ds(dataset['val']['text'], seq_len), num_workers=2, batch_size=batch_size)

100%|██████████| 3760/3760 [00:05<00:00, 724.94it/s]


In [None]:
# Inspect the first validation batch: Y should be X shifted left by one token.
X, Y = next(iter(val_dl))
print(f'{X} \n {Y}')
print(f'{X.shape} \n {Y.shape}')

tensor([[    0,  5457, 11858,  ...,     1,     1,     1],
        [    0, 11858, 42292,  ...,     1,     1,     1],
        [    0,    85,    16,  ...,     1,     1,     1],
        ...,
        [    0,    20,    80,  ...,     1,     1,     1],
        [    0,    20,  4533,  ...,     1,     1,     1],
        [    0, 38187, 40037,  ...,     1,     1,     1]]) 
 tensor([[ 5457, 11858, 42292,  ...,     1,     1,     1],
        [11858, 42292, 20577,  ...,     1,     1,     1],
        [   85,    16,  3615,  ...,     1,     1,     1],
        ...,
        [   20,    80,  4707,  ...,     1,     1,     1],
        [   20,  4533,  6031,  ...,     1,     1,     1],
        [38187, 40037, 10361,  ...,     1,     1,     1]])
torch.Size([32, 100]) 
 torch.Size([32, 100])


## Switching 'Mixture of Experts' Block

In [None]:
class SwitchingMoE(nn.Module):
  """Switch (top-1) Mixture-of-Experts feed-forward layer.

  A linear router turns each token into a softmax distribution over
  ``num_experts`` experts. Each expert is a two-layer ReLU MLP. Every token is
  processed only by its highest-probability expert, and that expert's output
  is scaled by the router probability (y = p_i(x) * E_i(x), as in the Switch
  Transformers paper). The scaling is what lets the main loss train the
  router. No expert-capacity limit is applied.

  The input tensor is never modified; a new tensor is returned.

  Args:
    d_model: token feature size (input and output).
    noise_const, decay_rate: stored for API compatibility; no router noise is
      applied, so they currently have no effect.
    dim_feedforward: hidden size of each expert.
    num_experts: number of experts.
    dropout: dropout probability inside each expert.
    loss_const: weight of the load-balancing auxiliary loss.
    compute_loss: if True, every training-mode forward stores the
      load-balancing loss in ``self.aux_loss`` (a scalar tensor) so a training
      loop can add it to its objective. Otherwise ``self.aux_loss`` is None.
      The loss is not back-propagated inside this module.
  """
  def __init__(self, d_model, noise_const=0.8, decay_rate=1e-3, dim_feedforward=2048, num_experts=8, dropout=0.1, loss_const=0.01, compute_loss=True):
    super().__init__()

    self.gate = nn.Linear(d_model, num_experts)  # router: token features -> expert scores
    self.num_experts = num_experts

    self.noise_const = noise_const
    self.loss_const = loss_const
    self.decay_rate = decay_rate
    self.compute_loss = compute_loss

    # Expert i computes FC2[i](dropout(relu(FC1[i](x)))).
    self.FC1 = nn.ModuleList([nn.Linear(d_model, dim_feedforward) for _ in range(num_experts)])
    self.FC2 = nn.ModuleList([nn.Linear(dim_feedforward, d_model) for _ in range(num_experts)])

    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)

    # Only used by the former router-noise formula; kept so existing attribute access works.
    self.root2 = math.sqrt(2)

    # Load-balancing loss from the most recent forward pass (None when not computed).
    self.aux_loss = None

  def _compute_aux_loss(self, conditions, values, T):
    """Load-balancing loss ``loss_const * num_experts * sum_i f_i * P_i``.

    Args:
      conditions: (T, num_experts) router probabilities.
      values: (T,) index of the expert chosen for each token.
      T: number of tokens.
    Returns:
      Scalar tensor. f_i is not differentiable, so gradients reach the router
      only through P_i.
    """
    # f_i: fraction of tokens dispatched to expert i.
    f = torch.bincount(values, minlength=self.num_experts).to(conditions.dtype) / T
    # P_i: mean router probability given to expert i.
    p = conditions.sum(dim=0) / T
    return self.loss_const * self.num_experts * (f * p).sum()

  def forward(self, x):
    """Route each token of ``x`` (B, T, d_model) through one expert; returns (B, T, d_model)."""
    B, T, C = x.shape
    flat = x.reshape(B * T, C)

    conditions = F.softmax(self.gate(flat), dim=1)  # (B*T, num_experts)
    top_prob, values = conditions.max(dim=1)        # chosen probability and expert, each (B*T,)

    if self.compute_loss and self.training:
      self.aux_loss = self._compute_aux_loss(conditions, values, B * T)
    else:
      self.aux_loss = None

    # Write expert outputs into a fresh tensor. Assigning into `flat` would mutate the
    # caller's tensor through the view and destroy the encoder layer's residual input.
    out = torch.zeros_like(flat)
    for i in range(self.num_experts):
      mask = values == i
      hidden = self.dropout1(F.relu(self.FC1[i](flat[mask])))
      out[mask] = top_prob[mask].unsqueeze(1) * self.dropout2(self.FC2[i](hidden))

    return out.view(B, T, C)

## Creating the model

In [None]:
class TransformerEncoderLayer(nn.Module):
  """Post-norm transformer encoder layer whose feed-forward sublayer is a SwitchingMoE.

  Computes:
    h   = norm1(x + dropout(self_attn(x, x, x)))
    out = norm2(switching_moe(h) + h)

  Args:
    d_model: feature size. Attention is built with batch_first=True.
    nhead: number of attention heads.
    noise_const: must lie strictly between 0 and 1; passed to SwitchingMoE.
    dim_feedforward: expert hidden size. When left at 2048 with d_model != 512,
      it is replaced by 4 * d_model.
    num_experts, loss_const, compute_loss: passed to SwitchingMoE
      (loss_const must be > 0).
    dropout: passed to SwitchingMoE and also applied to the attention output.
  """
  def __init__(self, d_model, nhead, noise_const=0.8, dim_feedforward=2048, num_experts=8, dropout=0.1, loss_const=0.01, compute_loss=True):
    super().__init__()

    assert 0 < noise_const < 1
    assert loss_const > 0

    # 2048 is the conventional width for d_model=512; for other sizes scale it as 4x.
    expert_width = 4 * d_model if (dim_feedforward == 2048 and d_model != 512) else dim_feedforward

    # Keep this creation order: it fixes parameter-initialisation (RNG) and registration order.
    self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, batch_first=True)
    self.switching_moe = SwitchingMoE(
        d_model,
        noise_const=noise_const,
        dim_feedforward=expert_width,
        num_experts=num_experts,
        dropout=dropout,
        loss_const=loss_const,
        compute_loss=compute_loss,
    )
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, src_mask=None):
    """Apply attention and MoE sublayers, each with a residual connection and LayerNorm."""
    attended = self.norm1(x + self._sa_block(x, src_mask))
    return self.norm2(self._ff_block(attended) + attended)

  def _sa_block(self, x, mask=None):
    """Self-attention over ``x`` followed by dropout (attention weights are discarded)."""
    attn_out, _ = self.self_attn(x, x, x, attn_mask=mask)
    return self.dropout(attn_out)

  def _ff_block(self, x):
    """Feed-forward sublayer: the Switch Mixture-of-Experts."""
    return self.switching_moe(x)

In [None]:
class TransformerModel(nn.Module):
  """Causally-masked next-token language model built from Switch-MoE encoder layers.

  Args:
    ntokens: vocabulary size (embedding rows and output-layer width).
    d_model: size of each token embedding.
    nhead: number of attention heads in each encoder layer.
    seq_len: maximum input length; the causal mask is precomputed for it.
    device: device the causal mask is created on.
    noise_const, num_experts, dim_feedforward, dropout, compute_loss: forwarded
      to every TransformerEncoderLayer.
    nlayers: number of transformer encoder layers.
    padding_idx: embedding row that is zero-initialised and receives no
      gradient. Defaults to 1, the id the dataset pads with. The previous
      value, 0, is the id that starts every encoded sentence in the printed
      batch, so it froze a real token instead of padding.
  """
  def __init__(self, ntokens, d_model, nhead, seq_len, device, noise_const=0.8, nlayers=6, num_experts=8, dim_feedforward=2048, dropout=0.1, compute_loss=True, padding_idx=1):
    super(TransformerModel, self).__init__()
    self.d_model = d_model
    self.device = device
    self.seq_len = seq_len

    self.embed = nn.Embedding(ntokens, d_model, padding_idx=padding_idx)
    self.pos_embed = PositionalEncoding1D(d_model)

    self.encoder = nn.ModuleList([
        TransformerEncoderLayer(d_model,
                                nhead,
                                noise_const=noise_const,
                                dim_feedforward=dim_feedforward,
                                num_experts=num_experts,
                                dropout=dropout,
                                compute_loss=compute_loss)
        for _ in range(nlayers)])

    self.fc = nn.Linear(d_model, ntokens)
    self.dropout = nn.Dropout(dropout)

    # Non-persistent buffer: model.to(...) moves it with the weights, but it is not saved in state_dict.
    self.register_buffer(
        'src_mask',
        nn.Transformer.generate_square_subsequent_mask(seq_len).to(device),
        persistent=False)
    self.init_weights()

  def init_weights(self):
    """Uniform(-0.1, 0.1) init for embedding and output weights; zero output bias."""
    initrange = 0.1
    self.embed.weight.data.uniform_(-initrange, initrange)
    # uniform_ also overwrote the padding row; restore nn.Embedding's all-zero padding row.
    if self.embed.padding_idx is not None:
      self.embed.weight.data[self.embed.padding_idx].zero_()
    self.fc.bias.data.zero_()
    self.fc.weight.data.uniform_(-initrange, initrange)

  def forward(self, src, src_mask = None):
    """Return next-token logits.

    Args:
      src: (B, T) token ids. T must be <= seq_len when no mask is given.
      src_mask: optional attention mask; defaults to the causal mask cropped to T.
    Returns:
      (B, T, ntokens) unnormalized scores. Pass them directly to
      nn.CrossEntropyLoss (which applies log-softmax itself), or apply softmax
      over the last dimension to get probabilities.
    """
    T = src.shape[1]
    h = self.embed(src) * math.sqrt(self.d_model)
    h = h + self.pos_embed(h)

    if src_mask is None:
      if T > self.seq_len:
        raise ValueError(f'input length {T} exceeds seq_len={self.seq_len}')
      src_mask = self.src_mask[:T, :T]

    for layer in self.encoder:
      h = layer(h, src_mask=src_mask)

    return self.fc(self.dropout(h))

### Sequence generator

In [None]:
def generate(src: str, num_words: int, top_k: int = 3, stop_at_eos: bool = False):
  """Extend the prompt ``src`` by up to ``num_words`` tokens using the global ``model``.

  At each step:
    1. The model is fed the last ``seq_len`` tokens, right-padded with id 1 when
       the text is shorter.
    2. The scores at the last real token are read.
    3. A token is drawn uniformly at random from the ``top_k`` highest-scoring
       candidates and appended.

  Uses the notebook globals ``model``, ``tokenizer``, ``seq_len`` and ``device``.
  The model's train/eval mode is restored afterwards.

  Args:
    src: prompt text.
    num_words: number of tokens (not words) to generate.
    top_k: size of the candidate pool sampled at each step (default 3).
    stop_at_eos: stop early once id 2 (end of sentence) is generated.
  Returns:
    The decoded prompt plus generated text.
  Raises:
    ValueError: if ``top_k`` < 1.
  """
  if top_k < 1:
    raise ValueError('top_k must be >= 1')

  was_training = model.training
  model.eval()
  ids = torch.tensor(tokenizer.encode(src)[:-1]).long()  # drop the trailing end-of-sentence id

  try:
    with torch.no_grad():
      for _ in range(num_words):
        n = ids.shape[0]
        if n < seq_len:
          window = torch.cat([ids, torch.ones(seq_len - n, dtype=torch.long)])
          last = max(n - 1, 0)
        else:
          # Keep the most recent seq_len tokens; the last real token sits in the final slot.
          window = ids[-seq_len:]
          last = seq_len - 1

        scores = model(window.unsqueeze(0).to(device)).squeeze(0)[last]  # (vocab,)

        candidates = torch.topk(scores, k=min(top_k, scores.shape[0])).indices  # best first
        pick = int(torch.rand(1) * candidates.shape[0])
        token = candidates[pick].reshape(1).to('cpu')

        ids = torch.cat([ids, token])

        if stop_at_eos and int(token) == 2:
          break
  finally:
    model.train(was_training)

  return tokenizer.decode(ids.tolist())

# Training function

In [None]:
# Id used for right-padding by the dataset (also RoBERTa's <pad>). Padded target positions
# are excluded from the loss, so the model is not trained to predict padding.
PAD_ID = 1
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)

In [None]:
def _collect_aux_loss(model):
  """Sum every ``aux_loss`` tensor left on a submodule by the last forward pass (0.0 if none)."""
  total = 0.0
  for module in model.modules():
    aux = getattr(module, 'aux_loss', None)
    if torch.is_tensor(aux):
      total = total + aux
  return total


def _run_epoch(model, dl, l_fn, optim=None):
  """Run one pass over ``dl``: a training pass when ``optim`` is given, else evaluation.

  Returns:
    (mean batch loss, token accuracy). Accuracy counts target tokens that are
    not ``l_fn.ignore_index`` (all tokens if ``l_fn`` has no such attribute).
  """
  training = optim is not None
  model.train(training)
  dev = next(model.parameters()).device
  ignore_index = getattr(l_fn, 'ignore_index', None)

  loss_sum, n_batches, n_correct, n_counted = 0.0, 0, 0, 0
  with torch.set_grad_enabled(training):
    for X, Y in tqdm(dl, total=len(dl)):
      X, Y = X.to(dev), Y.to(dev).reshape(-1)
      logits = model(X)
      logits = logits.reshape(-1, logits.shape[-1])
      loss = l_fn(logits, Y)

      if training:
        optim.zero_grad()
        # Include any MoE load-balancing losses stored during this forward pass.
        (loss + _collect_aux_loss(model)).backward()
        optim.step()

      loss_sum += loss.item()
      n_batches += 1

      # Vectorized accuracy: avoids a Python loop and a GPU sync per token.
      if ignore_index is None:
        keep = torch.ones_like(Y, dtype=torch.bool)
      else:
        keep = Y != ignore_index
      n_correct += (logits.argmax(dim=1)[keep] == Y[keep]).sum().item()
      n_counted += int(keep.sum().item())

  return loss_sum / max(n_batches, 1), n_correct / max(n_counted, 1)


def train(model, optim, train_dl, val_dl, l_fn, epochs, check_overfit=True):
  """Train ``model`` for up to ``epochs`` epochs, validating after each one.

  Args:
    model: module mapping (B, T) token ids to (B, T, vocab) logits.
    optim: optimizer over ``model``'s parameters.
    train_dl, val_dl: loaders yielding (input ids, target ids) batches.
    l_fn: loss taking (N, vocab) logits and (N,) targets. It is used for both
      training and validation.
    epochs: maximum number of epochs.
    check_overfit: if True, ask on stdin whether to continue whenever validation
      accuracy drops compared with the previous epoch.
  Returns:
    (H, e):
      H: maps 'train_loss'/'val_loss' to per-epoch mean batch losses, and
        'train_acc'/'val_acc' to per-epoch token accuracies in [0, 1].
      e: the 0-based index of the last epoch that ran.
  """
  H = {
      'train_loss': [],
      'train_acc': [],
      'val_loss': [],
      'val_acc': [],
  }
  last_epoch = 0

  for e in range(epochs):
    last_epoch = e
    print("------------------------------------------------------------")
    print(f"EPOCH : {e+1}")

    print("Training Step:")
    train_loss, train_acc = _run_epoch(model, train_dl, l_fn, optim)

    print("\nValidation Step:")
    val_loss, val_acc = _run_epoch(model, val_dl, l_fn)

    # Store the important specifications of the training process
    H["train_loss"].append(train_loss)
    H["train_acc"].append(train_acc)
    H["val_loss"].append(val_loss)
    H["val_acc"].append(val_acc)

    print(f"\n\nTrain Loss : {train_loss:.4f}")
    print(f"Val Loss : {val_loss:.4f}")
    print(f"Train Accuracy : {train_acc:.4f}")
    print(f"Val Accuracy : {val_acc:.4f}")

    # Check for overfitting: validation accuracy fell compared with the previous epoch.
    if check_overfit and len(H['val_acc']) > 1 and H['val_acc'][-1] < H['val_acc'][-2]:
      print("The model is showing signs of over fitting enter Y to continue, or N for breaking the training loop")
      if input() != 'Y':
        return H, e
  return H, last_epoch

# Training the model

In [None]:
# Hyperparameters
nhead = 4       # attention heads per encoder layer
d_model = 400   # embedding size; since it is not 512, each expert's hidden width becomes 4*d_model
ntokens = tokenizer.vocab_size  # output vocabulary = tokenizer vocabulary
num_inputs = seq_len            # context length the model's causal mask is built for
model = TransformerModel(nhead=nhead,
                         d_model=d_model,
                         ntokens=ntokens,
                         num_experts=4,
                         seq_len=num_inputs,
                         nlayers=4,
                         device=device,
                         compute_loss=True,
                         # Only range-checked (0 < noise_const < 1); SwitchingMoE applies no router noise.
                         noise_const=0.4).to(device)

opt = torch.optim.Adam(model.parameters(), lr=3e-5)

In [None]:
# Train on consecutive chunks of `num_sequences` raw training lines, `epochs` epochs per chunk.
# Each chunk is tokenized only when reached, which keeps memory bounded.
epochs = 3
num_sequences = 10_000
num_repetitions = 5
histories = []  # one (H, last_epoch) tuple from train() per chunk, kept for later plotting/inspection

for i in range(num_repetitions):
  chunk = train_ds[i*num_sequences:(i+1)*num_sequences]
  # Shuffle so batches are not runs of consecutive sentences from the same article.
  train_dl = DataLoader(token_ds(chunk, seq_len), num_workers=2, batch_size=batch_size, shuffle=True)

  start_time = time.perf_counter()
  histories.append(train(model, opt, train_dl, val_dl, loss_fn, epochs, check_overfit=True))
  end_time = time.perf_counter()

  print(f'\nElapsed Time: {end_time-start_time} \n')

100%|██████████| 10000/10000 [00:11<00:00, 870.01it/s]


------------------------------------------------------------
EPOCH : 1
Training Step:


100%|██████████| 901/901 [07:32<00:00,  1.99it/s]



Validation Step:


100%|██████████| 352/352 [00:59<00:00,  5.88it/s]




Train Loss : 9728.70
Val Loss : 3797.47
Train Accuracy : 1.8001
Val Accuracy : 20.1414
------------------------------------------------------------
EPOCH : 2
Training Step:


100%|██████████| 901/901 [07:28<00:00,  2.01it/s]



Validation Step:


100%|██████████| 352/352 [01:00<00:00,  5.79it/s]




Train Loss : 9718.26
Val Loss : 3794.37
Train Accuracy : 2.5593
Val Accuracy : 30.5554
------------------------------------------------------------
EPOCH : 3
Training Step:


100%|██████████| 901/901 [07:39<00:00,  1.96it/s]



Validation Step:


100%|██████████| 352/352 [01:00<00:00,  5.81it/s]




Train Loss : 9710.96
Val Loss : 3792.22
Train Accuracy : 4.3856
Val Accuracy : 6.7732
The model is showing signs of over fitting enter Y to continue, or N for breaking the training loop
N

Elapsed Time: 1575.4972562789917 



 40%|███▉      | 3966/10000 [00:03<00:05, 1043.36it/s]


KeyboardInterrupt: 

# Testing

In [None]:
# Sample a 500-token continuation of the prompt. generate() draws each token at random
# from the model's top candidates, so re-running this cell gives a different result.
test = 'from what i understand'
generate(test, num_words=500)

In [None]:
# A second sample for the same prompt. Because generate() samples randomly among the
# top candidates, the output differs from the cell above.
test = 'from what i understand'
generate(test, num_words=500)