In [191]:
import torch
import numpy as np
from lm_from_scratch.models.t5 import T5
from lm_from_scratch.corpus.decision_corpus import DecisionCorpus
import pandas as pd
from artifacts import DECISION_CORPUS_RAW
from tqdm import tqdm

import sentencepiece as spm
from random import randint, random
from transformers import T5Tokenizer
from tokenizers import AddedToken

VOCAB_SIZE = 15000
MAX_LEN = 128 # 512 # what is the maximum context length for predictions?
EMBED_DIM = 384 # 768
N_LAYERS = 1
ATTN_HEADS = 6 # 64 * 6 = 384
DROPOUT = 0.0

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

EVAL_ITERS = 10
MAX_ITERS = 1000
EVAL_INTERVAL = 100
LEARNING_RATE = 1e-3

BATCH_SIZE = 64 # how many independent sequences will we process in parallel?

MAX_SENTENCE_LEN = MAX_LEN // 2
MIN_SENTENCE_LEN = 10

PAD_TOKEN_ID = 0
EOS_TOKEN_ID = 1
UNK_TOKEN_ID = 2
BOS_TOKEN_ID = 3

SENTINEL_TOKEN_COUNT = 200
SENTINEL_TOKEN_ID = VOCAB_SIZE - SENTINEL_TOKEN_COUNT
IGNORE_INDEX = -100


To specify which task the model should perform, we add a task-specific (text) prefix to the original input sequence before feeding it to the model.

See Appendix D for full examples of preprocessed inputs.

In [126]:
corpus = DecisionCorpus()
data = corpus.get_text()

In [127]:
with open(DECISION_CORPUS_RAW, "w", encoding="utf-8") as f:
    for d in data:
        f.write(d + "\n")

spm.SentencePieceTrainer.Train(
    input=DECISION_CORPUS_RAW,
    model_prefix='sentencepiece_tokenizer',
    vocab_size=VOCAB_SIZE - SENTINEL_TOKEN_COUNT,
    pad_id=PAD_TOKEN_ID,
    unk_id=UNK_TOKEN_ID,
    eos_id=EOS_TOKEN_ID,
    bos_id=BOS_TOKEN_ID,
    pad_piece='<pad>',
    unk_piece='<unk>',
    eos_piece='</s>',
    bos_piece='<s>',
    model_type='unigram',
)

sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input: /home/clem/Source/sandbox/lm-from-scratch/artifacts/decision-raw.txt
  input_format: 
  model_prefix: sentencepiece_tokenizer
  model_type: UNIGRAM
  vocab_size: 14800
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 2
  bos_id: 3
  eos_id: 1
  pad_id: 0
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  �

In [128]:
# tokenizer = spm.SentencePieceProcessor(model_file="sentencepiece_tokenizer.model")
tokenizer = T5Tokenizer("sentencepiece_tokenizer.model", extra_ids=0)

tokenizer.add_special_tokens({
    "additional_special_tokens": [
        AddedToken(content=f"<extra_id_{i}>",
                   single_word=False,
                   normalized=False,
                   special=True) for i in range(SENTINEL_TOKEN_COUNT)
    ]})

200

# Load sentences

In [129]:
corpus_df = corpus.df.sample(
        frac=1,
        random_state=42
    ).reset_index(
        drop=True)

In [130]:
sentences = corpus.get_sentences()

len(sentences)

  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [00:04<00:00, 2454.03it/s]


172244

In [131]:
sentences_ids = [tokenizer.encode(sentence, add_special_tokens=False) for sentence in tqdm(sentences)]
sentences_ids = [sentence_ids for sentence_ids in sentences_ids if len(sentence_ids) > MIN_SENTENCE_LEN]

len(sentences_ids)

  0%|          | 0/172244 [00:00<?, ?it/s]

100%|██████████| 172244/172244 [00:39<00:00, 4320.88it/s]


148798

# Split corpus

In [132]:
# Train and test splits
# first 90% will be train, rest val
sentence_split = int(0.9*len(sentences_ids))

train_data = sentences_ids[:sentence_split]
val_data = sentences_ids[sentence_split:]

# Batch preparation

In [133]:
def get_sentence_batch(split, batch_size=BATCH_SIZE):
    data = train_data if split == 'train' else val_data

    sentence_indices = torch.randint(len(data), (batch_size,))
    
    for sentence_ix  in sentence_indices:
        sentence_ids = data[int(sentence_ix)]

        ix = torch.randint(max(1, len(sentence_ids) - MAX_LEN), (1,))
        sentence_ids = sentence_ids[ix:ix+MAX_LEN]

        pred_count = min(MAX_LEN, max(1, round(len(sentence_ids) * 0.15)))
        
        masked_positions = np.random.choice(range(1, len(sentence_ids)), pred_count, replace=False)
        masked_positions.sort()
        
        masked_token_ids = sentence_ids.copy()
        target_token_ids = [PAD_TOKEN_ID]

        for sentinel_id, masked_position in enumerate(masked_positions):
            target_token_ids.append(SENTINEL_TOKEN_ID + sentinel_id)
            target_token_ids.append(masked_token_ids[masked_position])

            masked_token_ids[masked_position] = SENTINEL_TOKEN_ID + sentinel_id
        
        target_token_ids.append(SENTINEL_TOKEN_ID + sentinel_id + 1)
        target_token_ids.append(EOS_TOKEN_ID)

        mask_padding = MAX_LEN - len(target_token_ids)
        target_attn_mask = np.zeros(MAX_LEN)
        target_attn_mask[:-mask_padding] = 1

        sentence_padding = MAX_LEN - len(sentence_ids)
    
        yield [
            np.concatenate([sentence_ids, [PAD_TOKEN_ID] * sentence_padding]),
            np.concatenate([masked_token_ids, [PAD_TOKEN_ID] * sentence_padding]),
            np.concatenate([target_token_ids, [PAD_TOKEN_ID] * mask_padding]),
            np.concatenate([target_token_ids[1:], [IGNORE_INDEX] * (mask_padding +1)]),
            np.concatenate([(np.array(sentence_ids) != PAD_TOKEN_ID) * 1, [0] * sentence_padding]),
            target_attn_mask,
        ]

In [134]:
def get_batch(split, batch_size):
    return map(
         lambda x: torch.tensor(x, device=DEVICE, dtype=torch.long), 
         zip(*get_sentence_batch(split, batch_size=batch_size)))

(token_ids,
 masked_token_ids,
 target_token_ids,
 target_token_loss_ids,
 src_attention_masks,
 target_attention_masks) = get_batch("train", batch_size=1)

print("token_ids              \n", token_ids, "\n")
print("masked_token_ids       \n", masked_token_ids, "\n")
print("target_token_ids       \n", target_token_ids, "\n")
print("target_token_loss_ids       \n", target_token_loss_ids, "\n")
print("src_attention_masks    \n", src_attention_masks, "\n")
print("target_attention_masks \n", target_attention_masks, "\n")

token_ids              
 tensor([[ 383,  459, 1311,    9,    6,  509,   15,    6,   91,   78, 7006,    5,
          447,   16,  483,    5,    9,    6, 1190,   30,  151,    5,    7, 1421,
          489,   12, 3295,    7, 1598, 1510,    4,   13,  261,  727,   24,    9,
            6, 3212,   22,   65,  700,   15,    6,   99, 1280,   51,   21, 1472,
            5,   57,  116,   13,  399,  126,  846,   12,   33,    6,   91,  560,
           15,    6, 1190,   11,  116,   22,   65, 1854,   13,  316,  126,  846,
           19,   57, 1328,   12, 6255,   14,    7,  408,    8,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0]], device='cuda:0') 

masked_token_ids       
 tensor([[  383,   459,  1311,     9,     6,   509, 14800, 1480

# Model Training

In [192]:
model = T5(
    dim=EMBED_DIM,
    vocab_size=VOCAB_SIZE,
    enc_depth=N_LAYERS,
    enc_heads=ATTN_HEADS,
    enc_dim_head=EMBED_DIM // ATTN_HEADS,
    enc_mlp_mult=4,
    dec_depth=N_LAYERS,
    dec_heads=ATTN_HEADS,
    dec_dim_head=EMBED_DIM // ATTN_HEADS,
    dec_mlp_mult=4,
    dropout=DROPOUT,
)
m = model.to(DEVICE)

In [182]:
import torch.nn.functional as F

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(EVAL_ITERS)
        for k in range(EVAL_ITERS):
            (token_ids,
            masked_token_ids,
            target_token_ids,
            target_token_ids_loss,
            src_attention_masks,
            target_attention_masks) = get_batch(split, BATCH_SIZE)
            logits = model(masked_token_ids, target_token_ids,
                           src_attention_masks, target_attention_masks)
        
            loss = F.cross_entropy(logits.transpose(-2,-1), target_token_ids_loss)

            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

for iter in range(MAX_ITERS):
    optimizer.zero_grad()

    # every once in a while evaluate the loss on train and val sets
    if iter % EVAL_INTERVAL == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    (token_ids,
    masked_token_ids,
    target_token_ids,
    target_token_ids_loss,
    src_attention_masks,
    target_attention_masks) = get_batch('train', BATCH_SIZE)

    # evaluate the loss
    logits = model(masked_token_ids, target_token_ids,
                   src_attention_masks, target_attention_masks)
    
    loss = F.cross_entropy(logits.transpose(-2,-1), target_token_ids_loss)
    
    loss.backward()
    optimizer.step()


step 0: train loss 2.4647, val loss 2.6757


# Model visualization

https://opendelta.readthedocs.io/en/latest/notes/overview.html

In [193]:
from opendelta import Visualization
model_vis = Visualization(model)
model_vis.structure_graph()

# Test time

In [168]:
sentinel_token = tokenizer.decode(SENTINEL_TOKEN_ID)

(token_ids,
masked_token_ids,
target_token_ids,
target_token_ids_loss,
src_attention_masks,
target_attention_masks) = get_batch("train", batch_size=1)

print(tokenizer.decode([token_ids[0, 5]], skip_special_tokens=False))

token_ids[0, 5] = SENTINEL_TOKEN_ID

print(tokenizer.decode(token_ids[0].cpu().numpy(), skip_special_tokens=False))

France
4. La société RT <extra_id_39> a porté plainte et s'est constituée partie civile du chef de diffamation publique envers un particulier par courrier réceptionné par le service d'accueil unique du justiciable (SAUJ) le 21 janvier 2022, transmis au secrétariat commun de l'instruction, le 25 janvier suivant.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>


In [169]:
from torch.nn import functional as F

# generate from the model
def generate(model, token_ids, attn_mask, max_new_tokens, max_len=MAX_LEN):

    targets = torch.zeros((1, 1), device=DEVICE).long()

    # idx is (B, T) array of indices in the current context
    for i in range(max_new_tokens):

        # get the predictions
        model.eval()
        logits = model(token_ids, targets[:, -max_len:], attn_mask)
        model.train()

        # focus only on the last time step
        logits = logits[:, -1, :] # becomes (B, C)

        # apply softmax to get probabilities
        probs = F.softmax(logits, dim=-1) # (B, C)

        targets_next = torch.argmax(probs, dim=-1)[None,:]
        # targets_next = torch.multinomial(probs, num_samples=1) # (B, 1)

        # append sampled index to the running sequence
        targets = torch.cat((targets, targets_next), dim=1) # (B, T+1)
    return targets


print(tokenizer.decode(
        generate(model,
                 token_ids, src_attention_masks,
                 max_new_tokens=3)[0].tolist()))


<pad> <extra_id_39>, <extra_id_95>
