<a href="https://colab.research.google.com/github/Druvith/Animation-Nation/blob/master/MakemoreMoE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [11]:
import os

# Create the TinyStories directory if it doesn't exist
os.makedirs("TinyStories", exist_ok=True)

# Download the archive
!wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz

# Extract the archive into the TinyStories directory
!tar -xzf TinyStories_all_data.tar.gz -C TinyStories

--2024-03-21 15:10:40--  https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz
Resolving huggingface.co (huggingface.co)... 65.8.178.12, 65.8.178.93, 65.8.178.118, ...
Connecting to huggingface.co (huggingface.co)|65.8.178.12|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/42/7f/427f7497b6c6596c18b46d5a72e61364fcad12aa433c60a0dbd4d344477b9d81/26cf7605aca15bc4ea6fa637256400d9d01317b28ed296172b2d1dd160cd7699?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27TinyStories_all_data.tar.gz%3B+filename%3D%22TinyStories_all_data.tar.gz%22%3B&response-content-type=application%2Fgzip&Expires=1711293040&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxMTI5MzA0MH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy80Mi83Zi80MjdmNzQ5N2I2YzY1OTZjMThiNDZkNWE3MmU2MTM2NGZjYWQxMmFhNDMzYzYwYTBkYmQ0ZDM0NDQ3N2I5ZDgxLzI2Y2Y3NjA1

In [12]:
import os
import glob
import json

In [13]:
shard_filenames = sorted(glob.glob(os.path.join('TinyStories', "*.json")))
with open(shard_filenames[0], "r") as f:
        data = json.load(f)

stories = [x['story'] for x in data]
text = "\n".join(stories)

In [14]:
len(text)

77586884

In [15]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

	
 !"$%&'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]`abcdefghijklmnopqrstuvwxyz|~ éñ–—‘’“”…
97


In [16]:
import torch

stoi = { ch:i for i, ch in enumerate(chars) }
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: "".join([itos[x] for x in l])

data = torch.tensor(encode(text), dtype = torch.long)


In [17]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [18]:
block_size = 32

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

device= 'cuda' if torch.cuda.is_available() else 'cpu'


In [19]:
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1137)

<torch._C.Generator at 0x7892d9d505b0>

In [20]:
class MoeLayer(nn.Module):
    def __init__(self, experts, gate, k=1):
        super().__init__()
        assert len(experts) > 0
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.k = k

    def forward(self, inputs: torch.Tensor):
        inputs_squashed = inputs.view(-1, inputs.shape[-1])
        gate_logits = self.gate(inputs_squashed)
        weights, selected_experts = torch.topk(
            gate_logits, self.k
        )
        weights = nn.functional.softmax(
            weights,
            dim=1,
            dtype=torch.float,
        ).type_as(inputs)
        results = torch.zeros_like(inputs_squashed)
        for i, expert in enumerate(self.experts):
            batch_idx, nth_expert = torch.where(selected_experts == i)
            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(
                inputs_squashed[batch_idx]
            )
        return results.view_as(inputs)

In [23]:
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias = False)
        self.query = nn.Linear(n_embed, head_size, bias = False)
        self.value = nn.Linear(n_embed, head_size, bias = False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out

class MulitHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x =  torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.dropout(self.proj(x))
        return out


class FeedForward(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4* n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
         nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, n_embed, n_head, num_experts=4):
        super().__init__()
        self.sa_head= MulitHeadAttention(n_head, n_embed//n_head)
        self.ffw = MoeLayer(
            experts=[FeedForward(n_embed) for _ in range(num_experts)],
            gate=nn.Linear(n_embed, num_experts, bias=False),
        )

#         self.ffw=  FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa_head(self.ln1(x))
        x = x+self.ffw(self.ln2(x))
        return x


class Transformer(nn.Module):
    def __init__(self):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embed, device=device)
        self.position_embedding_table = nn.Embedding(block_size, n_embed, device=device)
        self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head) for _ in range(n_layer)])
        self.lm_head = nn.Linear(n_embed, vocab_size)


    def forward(self, idx, targets=None):
        B, T = idx.shape

        token_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T).to(device))
        x = token_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x)
        if targets == None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokes):
        for _ in range(max_new_tokes):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim = -1)
            idx_next = torch.multinomial(probs, num_samples = 1)
            idx = torch.cat((idx, idx_next), dim = 1)
        return idx

In [24]:
!pip install wandb

Collecting wandb
  Downloading wandb-0.16.4-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.42-py3-none-any.whl (195 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m195.4/195.4 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.43.0-py2.py3-none-any.whl (264 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m264.6/264.6 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wan

In [70]:
import wandb
!wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [25]:
# hyperparameters
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384
n_embed = 384
n_head = 6
n_layer = 6
dropout = 0.0
# ------------

In [26]:
model = Transformer()
optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate)
print(sum(p.numel() for p in model.parameters()), 'total parameters')

32090209 total parameters


In [27]:
model.eval()

Transformer(
  (token_embedding_table): Embedding(97, 384)
  (position_embedding_table): Embedding(256, 384)
  (blocks): Sequential(
    (0): Block(
      (sa_head): MulitHeadAttention(
        (heads): ModuleList(
          (0-5): 6 x Head(
            (key): Linear(in_features=384, out_features=64, bias=False)
            (query): Linear(in_features=384, out_features=64, bias=False)
            (value): Linear(in_features=384, out_features=64, bias=False)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ffw): MoeLayer(
        (experts): ModuleList(
          (0-3): 4 x FeedForward(
            (net): Sequential(
              (0): Linear(in_features=384, out_features=1536, bias=True)
              (1): ReLU()
              (2): Linear(in_features=1536, out_features=384, bias=True)
              (3): Dropout(p=0.0, inplace=Fals

In [71]:


@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            X = X.to(device)
            Y = Y.to(device)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [74]:
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(),lr=1e-4)

wandb.init(
    # set the wandb project where this run will be logged
    project="mixture of experts",

    # track hyperparameters and run metadata
    config={
    'batch_size': batch_size,
    'block_size': block_size,
    'max_iters': max_iters,
    'eval_interval': eval_interval,
    'learning_rate': learning_rate,
    'device': device,
    'eval_iters': eval_iters,
    'n_embd': n_embd,
    'n_embed': n_embed,
    'n_head': n_head,
    'n_layer': n_layer,
    'dropout': dropout
    }
)

wandb.watch(model)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % 100 == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        wandb.log({'train_loss': losses['train'], 'val_loss': losses['val']}, step=iter)

    # sample a batch of data
    xb, yb = get_batch('train')
    xb = xb.to(device)
    yb = yb.to(device)

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

model_dir = 'Mixture of experts models'
os.makedirs(model_dir, exist_ok=True)
final_model_path = os.path.join(model_dir, 'moe_model.pth')
torch.save(model.state_dict(), final_model_path)
print('Final trained model saved!')

wandb.finish()

step 0: train loss 4.8669, val loss 4.8668
step 100: train loss 2.3447, val loss 2.3453
step 200: train loss 2.3033, val loss 2.3023
step 300: train loss 2.2774, val loss 2.2778
step 400: train loss 2.2385, val loss 2.2381
step 500: train loss 2.1665, val loss 2.1676
step 600: train loss 2.0180, val loss 2.0165
step 700: train loss 1.8367, val loss 1.8385
step 800: train loss 1.7020, val loss 1.7083
step 900: train loss 1.6097, val loss 1.6092
step 1000: train loss 1.5354, val loss 1.5375
step 1100: train loss 1.4760, val loss 1.4777
step 1200: train loss 1.4261, val loss 1.4225
step 1300: train loss 1.3794, val loss 1.3810
step 1400: train loss 1.3411, val loss 1.3403
step 1500: train loss 1.3069, val loss 1.3092
step 1600: train loss 1.2763, val loss 1.2792
step 1700: train loss 1.2519, val loss 1.2577
step 1800: train loss 1.2294, val loss 1.2297
step 1900: train loss 1.2060, val loss 1.2077
step 2000: train loss 1.1881, val loss 1.1883
step 2100: train loss 1.1667, val loss 1.1700


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

[34m[1mwandb[0m: [32m[41mERROR[0m Control-C detected -- Run data was not synced


In [76]:
print(sum(p.numel() for p in model.parameters()), 'total parameters')

32090209 total parameters


In [79]:
# generation
context = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokes=500)[0].tolist()))

	t's hanga nepagerous."
Lily says bling and asked, "I can you something wre next time!" Max says. They want to watch the snows find the other rock. He made a surprise or scared the water. The next day, Max's friend, Benny. In the tack and crew never long that chighs heals knew things, never lemss ok fun on the game. Timmy stopped creames and counted any happy.
Once upon a time, there was a big fanama love in his was so not a went bace. He lived its forest. He had a little many closer and watch th


In [80]:
model_dir = 'Mixture of experts models'
final_model_path = os.path.join(model_dir, 'moe_model1.pth')
torch.save(model.state_dict(), final_model_path)
