In [1]:
import time
import torch
import torch.nn as nn
import numpy as np
from model import GPT, GPTConfig
from context_free_grammar import CFG
import wandb
import lightning.pytorch as pl
import pytorch_lightning

In [2]:
# Authenticate this kernel with Weights & Biases; the returned flag (True on success) is displayed.
logged_in = wandb.login()
logged_in

[34m[1mwandb[0m: Currently logged in as: [33maboitrea[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
# Grammar used by every experiment in this notebook (arguments forwarded unchanged to CFG).
cfg = CFG(
    L=3,
    ns=[1, 3, 3, 10],
    nr=[2, 2, 2],
    T=[8, 8, 8],
)
# Number of tokens in one flattened sentence: the product of cfg.T (8 * 8 * 8 = 512 here).
sentence_length = np.prod(cfg.T)

In [5]:
start = time.time()  # NOTE(review): `start` is not read anywhere visible in this notebook
# 6-layer / 6-head GPT with 384-dim embeddings. The vocabulary size is the last entry of
# cfg.ns (10 here, matching the Embedding(10, 384) token table printed below).
config = GPTConfig(
    vocab_size=cfg.ns[-1],
    n_embd=384,
    n_head=6,
    n_layer=6,
)
# DataParallel splits each forward batch across the visible GPUs; the wrapped module is then
# moved to config.device. The cell displays the module repr.
m = nn.DataParallel(GPT(config))
m.to(config.device)

number of parameters: 10.64M


DataParallel(
  (module): GPT(
    (transformer): ModuleDict(
      (wte): Embedding(10, 384)
      (wpe): Embedding(256, 384)
      (drop): Dropout(p=0.0, inplace=False)
      (h): ModuleList(
        (0-5): 6 x Block(
          (ln_1): LayerNorm()
          (attn): MultiHeadAttention(
            (heads): ModuleList(
              (0-5): 6 x Head(
                (key): Linear(in_features=384, out_features=64, bias=False)
                (query): Linear(in_features=384, out_features=64, bias=False)
                (value): Linear(in_features=384, out_features=64, bias=False)
                (dropout): Dropout(p=0.0, inplace=False)
              )
            )
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (ln_2): LayerNorm()
          (mlp): MLP(
            (c_fc): Linear(in_features=384, out_features=1536, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj)

In [6]:
 # print the number of parameters in the model
# Every parameter element of the model, in millions. This is ~0.098M (= 256 x 384, the size of the
# wpe position table) above GPT's own "10.64M" report; GPT presumably excludes wpe from its count.
n_params_millions = sum(param.numel() for param in m.parameters()) / 1e6
print(n_params_millions, "M parameters")

10.742784 M parameters


In [4]:
# data loading: sample ONE new sentence from the grammar and cut every row of the mini-batch out of it
def get_batch(config: "GPTConfig | None" = None):
    """Sample a fresh next-token-prediction mini-batch from the grammar.

    One new sentence is drawn from the global grammar ``cfg`` and flattened to
    ``sentence_length`` tokens. ``config.batch_size`` start offsets are then drawn
    uniformly at random. Each row of ``x`` is the ``config.block_size`` tokens
    starting at one offset. The matching row of ``y`` is the same window shifted
    one token to the right.

    Args:
        config: supplies ``block_size``, ``batch_size`` and ``device``. When
            omitted, a fresh ``GPTConfig()`` is built on each call. Because it is
            not a definition-time default, no config instance is shared between
            calls.
            NOTE(review): this default is NOT the model's global ``config``; it
            only matches if the model config keeps GPTConfig's default
            block_size/batch_size.

    Returns:
        ``(x, y)``, each of shape ``(config.batch_size, config.block_size)``,
        moved to ``config.device``.

    Raises:
        ValueError: if ``config.block_size >= sentence_length``, since there is
            then no room for a window plus its shifted target.
    """
    if config is None:
        config = GPTConfig()
    n_starts = sentence_length - config.block_size
    if n_starts < 1:
        raise ValueError(
            f"block_size ({config.block_size}) must be smaller than sentence_length ({sentence_length})"
        )
    sentence = cfg.sample_flattened(1)[0][0].view(sentence_length)  # reshape in a 1d tensor
    # randint's upper bound is exclusive, so the largest start keeps i + block_size + 1 <= sentence_length.
    ix = torch.randint(0, n_starts, size=(config.batch_size,))
    x = torch.stack([sentence[i: i + config.block_size] for i in ix])
    y = torch.stack([sentence[i + 1: i + config.block_size + 1] for i in ix])
    return x.to(config.device), y.to(config.device)

In [5]:
@torch.no_grad()
def estimate_loss(m, eval_iters, batch_fn=None):
    """Average the model's next-token cross-entropy over freshly sampled batches.

    With the default batch source (``get_batch``), every batch is cut from a newly
    sampled grammar sentence, so this acts as a loss on unseen data.

    Args:
        m: model mapping an input batch ``X`` to logits whose last dimension is
            the vocabulary.
        eval_iters: number of batches to average over; must be >= 1.
        batch_fn: optional zero-argument callable returning ``(X, Y)``; defaults
            to ``get_batch``.

    Returns:
        ``{"val": mean_loss}`` where ``mean_loss`` is a 0-dim tensor.

    Raises:
        ValueError: if ``eval_iters < 1``. Otherwise the mean of an empty tensor
            would silently be NaN.

    The model is put in eval mode while evaluating. Its previous train/eval mode
    is restored afterwards, even if evaluation raises.
    """
    if eval_iters < 1:
        raise ValueError(f"eval_iters must be >= 1, got {eval_iters}")
    if batch_fn is None:
        batch_fn = get_batch
    was_training = m.training
    m.eval()
    try:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = batch_fn()
            logits = m(X)
            # reshape (unlike view) also accepts non-contiguous logits; targets equal to -1 are ignored
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)), Y.reshape(-1), ignore_index=-1
            )
            losses[k] = loss.item()
    finally:
        m.train(was_training)
    return {"val": losses.mean()}

In [6]:
# Number of ground-truth tokens given to the model as a prompt before it completes a sentence.
context_length = 3
@torch.no_grad()
def estimate_grammar_err(m, n_gen=100):
    """Count how many model-completed sentences are error-free at each grammar level.

    Each of the ``n_gen`` trials does the following:

    - Sample a real sentence from the global grammar ``cfg``.
    - Give its first ``context_length`` tokens to the model as a prompt.
    - Let the model generate the remaining ``sentence_length - context_length`` tokens.
    - Reshape the full generated sequence to ``cfg.T`` and check it with
      ``cfg.collapse_and_get_err``, whose second output is iterated level by level.

    Args:
        m: the model, either wrapped in ``nn.DataParallel`` or bare. The
            underlying module must provide ``generate(idx, max_new_tokens=...)``.
        n_gen: number of sentences to generate.

    Returns:
        An integer ``np.ndarray`` of length ``cfg.L``. Entry ``l`` is the number of
        generated sentences with zero mistakes at level ``l``. It is all zeros when
        ``n_gen <= 0``.

    The model's previous train/eval mode is restored afterwards, even on error.
    """
    # Unwrap nn.DataParallel; a bare model is used as-is.
    model = getattr(m, "module", m)
    if n_gen <= 0:
        return np.zeros(cfg.L, dtype=np.int64)
    # Put the prompt on the same device as the model's weights.
    device = next(model.parameters()).device
    was_training = m.training
    m.eval()
    try:
        error_per_sentence = []
        for _ in range(n_gen):
            context = cfg.sample_flattened(1)[0][0][:, :context_length].to(device)
            gen_sentence = model.generate(
                context.reshape(1, context_length),
                max_new_tokens=sentence_length - context_length,
            )[0]
            _, err = cfg.collapse_and_get_err(gen_sentence.view(*cfg.T).cpu())
            # one mistake count (number of nonzero entries) per grammar level
            error_per_sentence.append([int(torch.count_nonzero(level_errors)) for level_errors in err])
    finally:
        m.train(was_training)
    error_per_sentence = np.array(error_per_sentence)
    # A sentence is correct at level l when it has no mistake at that level.
    return np.array([
        n_gen - np.count_nonzero(error_per_sentence[:, level]) for level in range(cfg.L)
    ])

In [10]:
# Settings for the 10.7M-parameter run. The whole dict (optimizer included) is handed to wandb as the run config.
training_parameters = dict(
    max_iters=15000,
    eval_interval=500,
    eval_iters=50,
    quality_metric_iters=50,
    learning_rate=1e-3,
    architecture="GPT 10.7M",
    grammar=str(cfg),
    batch_size=config.batch_size,
)
training_parameters['optimizer'] = torch.optim.AdamW(
    m.parameters(), lr=training_parameters['learning_rate']
)
# Divide lr by 10 once the monitored loss stops improving (patience = 2 scheduler steps).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    training_parameters['optimizer'], mode='min', patience=2, factor=0.1
)

In [17]:
# Training loop
def _evaluate_and_log(m, step):
    """Evaluate `m` on fresh samples, step the LR scheduler on the val loss, and log to wandb.

    Reads the global `training_parameters` and `scheduler`; `step` is the number of
    optimizer updates performed so far.
    """
    val_loss = estimate_loss(m, training_parameters['eval_iters'])['val']
    print(
        f"step {step}: val loss {val_loss:.4f}"
    )
    scheduler.step(metrics=val_loss)

    n_quality = training_parameters['quality_metric_iters']
    errors = estimate_grammar_err(m, n_quality)
    print(
        f"step {step}: correct sentences for each level{errors}"
    )
    log_dict = {
        # NOTE(review): get_batch cuts every row of a batch from a single sampled sentence,
        # so this counts training windows (batch rows), not distinct sentences.
        "nb sentences seen": step * training_parameters['batch_size'],
        "loss": val_loss,
        "learning_rate": training_parameters['optimizer'].param_groups[0]["lr"],
    }
    for level, n_correct in enumerate(errors):
        # fraction in [0, 1] despite the "%" in the key
        log_dict[f'% of correct sentences at level {level}'] = n_correct / n_quality
    wandb.log(log_dict)


def train(m):
    """Train `m` for `training_parameters['max_iters']` optimizer steps.

    Uses the optimizer stored in `training_parameters['optimizer']`. Every
    `eval_interval` steps (starting at step 0, before any update) it evaluates
    and logs. After the final update it evaluates once more, so the last logged
    metrics describe the fully trained model.
    """
    max_iters = training_parameters['max_iters']
    eval_interval = training_parameters['eval_interval']
    optimizer = training_parameters['optimizer']
    for step in range(max_iters):
        # every once in a while evaluate the loss on newly generated sentences
        if step % eval_interval == 0:
            _evaluate_and_log(m, step)

        # sample a batch of data
        xb, yb = get_batch()

        # evaluate the loss and update
        logits = m(xb)
        optimizer.zero_grad(set_to_none=True)
        loss = nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)), yb.reshape(-1), ignore_index=-1
        )
        loss.backward()
        optimizer.step()

    # Final evaluation after the last optimizer step.
    if max_iters > 0:
        _evaluate_and_log(m, max_iters)

In [None]:
wandb.init(project='CFG-experiments', config=training_parameters)
# Metrics are only logged every eval_interval steps, so compute parameter/gradient histograms at
# that cadence instead of on every backward pass (log_freq=1 is costly for the whole model).
wandb.watch(m, log='all', log_freq=training_parameters['eval_interval'])

try:
    train(m)
finally:
    # Close the run even when training is interrupted (e.g. KeyboardInterrupt) or raises.
    wandb.finish()

step 0: val loss 2.4564
step 0: correct sentences for each level[0 0 0]
step 500: val loss 0.4503
step 500: correct sentences for each level[0 0 0]
step 1000: val loss 0.2039
step 1000: correct sentences for each level[ 0  0 10]
step 1500: val loss 0.1848
step 1500: correct sentences for each level[ 0  0 14]
step 2000: val loss 0.1533
step 2000: correct sentences for each level[ 0  0 18]
step 2500: val loss 0.1441
step 2500: correct sentences for each level[ 0  0 39]
step 3000: val loss 0.1383
step 3000: correct sentences for each level[ 0  0 17]
step 3500: val loss 0.1402
step 3500: correct sentences for each level[ 0  0 22]
step 4000: val loss 0.1354
step 4000: correct sentences for each level[ 0  1 39]
step 4500: val loss 0.1329
step 4500: correct sentences for each level[ 0  0 38]
step 5000: val loss 0.1286
step 5000: correct sentences for each level[ 0  4 40]
step 5500: val loss 0.1253


# GPT-2-small-shaped model (12 layers, 12 heads, 768-dim embeddings) with ~85M parameters

In [18]:
# Return cached-but-unused CUDA allocator blocks to the driver before building the larger model.
# Tensors that are still referenced (e.g. the first model `m` and its optimizer state) are not freed.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [19]:
# New experiment: same grammar, larger model (12 layers, 12 heads, 768-dim embeddings).
# NOTE: this rebinds the global `config`; any code reading `config` afterwards sees these settings.
config = GPTConfig(
    vocab_size=cfg.ns[-1],
    n_embd=768,
    n_head=12,
    n_layer=12,
)
m_large = nn.DataParallel(GPT(config))
m_large.to(config.device)

number of parameters: 85.04M


DataParallel(
  (module): GPT(
    (transformer): ModuleDict(
      (wte): Embedding(10, 768)
      (wpe): Embedding(256, 768)
      (drop): Dropout(p=0.0, inplace=False)
      (h): ModuleList(
        (0-11): 12 x Block(
          (ln_1): LayerNorm()
          (attn): MultiHeadAttention(
            (heads): ModuleList(
              (0-11): 12 x Head(
                (key): Linear(in_features=768, out_features=64, bias=False)
                (query): Linear(in_features=768, out_features=64, bias=False)
                (value): Linear(in_features=768, out_features=64, bias=False)
                (dropout): Dropout(p=0.0, inplace=False)
              )
            )
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (ln_2): LayerNorm()
          (mlp): MLP(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): GELU(approximate='none')
            (c_p

In [20]:
# Every parameter element of the large model, in millions. This is ~0.197M (= 256 x 768, the size of the
# wpe position table) above GPT's own "85.04M" report; GPT presumably excludes wpe from its count.
n_params_large_millions = sum(param.numel() for param in m_large.parameters()) / 1e6
print(n_params_large_millions, "M parameters")

85.23264 M parameters


In [21]:
# Settings for the 85M-parameter run (10x lower learning rate than the small run).
# The whole dict (optimizer included) is handed to wandb as the run config.
training_parameters = dict(
    max_iters=15000,
    eval_interval=500,
    eval_iters=50,
    quality_metric_iters=50,
    learning_rate=1e-4,
    architecture="GPT 85.04M",
    grammar=str(cfg),
    batch_size=config.batch_size,
)
training_parameters['optimizer'] = torch.optim.AdamW(
    m_large.parameters(), lr=training_parameters['learning_rate']
)
# Divide lr by 10 once the monitored loss stops improving (patience = 2 scheduler steps).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    training_parameters['optimizer'], mode='min', patience=2, factor=0.1
)

In [None]:
wandb.init(project='CFG-experiments', config=training_parameters)
# Metrics are only logged every eval_interval steps, so compute parameter/gradient histograms at
# that cadence instead of on every backward pass (log_freq=1 is costly for an 85M-parameter model).
wandb.watch(m_large, log='all', log_freq=training_parameters['eval_interval'])

try:
    train(m_large)
finally:
    # Close the run even when training is interrupted (e.g. KeyboardInterrupt) or raises.
    wandb.finish()

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113957432098687, max=1.0…

step 0: val loss 2.3375
step 0: correct sentences for each level[0 0 0]
step 500: val loss 0.2701
step 500: correct sentences for each level[0 0 0]
step 1000: val loss 0.1835
step 1000: correct sentences for each level[0 0 0]
step 1500: val loss 0.1413
step 1500: correct sentences for each level[ 0  0 31]
step 2500: val loss 0.1309
