In [None]:
import torch
import numpy as np
from lm_from_scratch.models.t5_small import T5
from lm_from_scratch.corpus.decision_corpus import DecisionCorpus
from artifacts import DECISION_CORPUS_RAW
from tqdm import tqdm
import math

import sentencepiece as spm
from transformers import T5Tokenizer
from tokenizers import AddedToken

# tokenizer & model parameters
VOCAB_SIZE = 15000

MAX_LEN = 128 # 512 # what is the maximum context length for predictions?
MAX_SENTENCE_LEN = MAX_LEN // 2
MIN_SENTENCE_LEN = 100

EMBED_DIM = 384 # 768
N_LAYERS = 1
ATTN_HEADS = 6 # 64 * 6 = 384
DROPOUT = 0.0


PAD_TOKEN_ID = 0
EOS_TOKEN_ID = 1
UNK_TOKEN_ID = 2
BOS_TOKEN_ID = 3

DESOINING_RATE = 0.15
SENTINEL_TOKEN_COUNT = math.ceil(MAX_LEN * DESOINING_RATE) + 1
SENTINEL_TOKEN_ID = VOCAB_SIZE - SENTINEL_TOKEN_COUNT
IGNORE_INDEX = -100


# training parameters
BATCH_SIZE = 64 # how many independent sequences will we process in parallel?
MAX_ITERS = 1000
EVAL_INTERVAL = 100
LEARNING_RATE = 1e-3
EVAL_ITERS = 100

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


To specify which task the model should perform, we add a task-specific (text) prefix to the original input sequence before feeding it to the model.

See Appendix D for full examples of preprocessed inputs.

In [2]:
corpus = DecisionCorpus()
data = corpus.get_text()

# BPE tokenizer

In [28]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

pad_token='<pad>'
unk_token='<unk>'
eos_token='</s>'
bos_token='<s>'

sentinel_tokens = [f"<extra_id_{i}>" for i in range(SENTINEL_TOKEN_COUNT)]

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
trainer = BpeTrainer(vocab_size=VOCAB_SIZE,
                     special_tokens=[pad_token, unk_token, eos_token, bos_token] + sentinel_tokens)
tokenizer.pre_tokenizer = Whitespace()

tokenizer.train_from_iterator(data, trainer)

# tokenizer.enable_padding(direction="right", 
#                          pad_id=0,
#                          pad_token=pad_token,
#                          length=MAX_LEN)

SENTINEL_TOKEN_ID = 4






# Sentence piece tokenizer

In [3]:
with open(DECISION_CORPUS_RAW, "w", encoding="utf-8") as f:
    for d in data:
        f.write(d + "\n")

spm.SentencePieceTrainer.Train(
    input=DECISION_CORPUS_RAW,
    model_prefix='sentencepiece_tokenizer',
    vocab_size=VOCAB_SIZE - SENTINEL_TOKEN_COUNT,
    pad_id=PAD_TOKEN_ID,
    unk_id=UNK_TOKEN_ID,
    eos_id=EOS_TOKEN_ID,
    bos_id=BOS_TOKEN_ID,
    pad_piece='<pad>',
    unk_piece='<unk>',
    eos_piece='</s>',
    bos_piece='<s>',
    model_type='unigram',
)

sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input: /home/clem/Source/sandbox/lm-from-scratch/artifacts/decision-raw.txt
  input_format: 
  model_prefix: sentencepiece_tokenizer
  model_type: UNIGRAM
  vocab_size: 14979
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 2
  bos_id: 3
  eos_id: 1
  pad_id: 0
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  �

In [4]:
# tokenizer = spm.SentencePieceProcessor(model_file="sentencepiece_tokenizer.model")
tokenizer = T5Tokenizer("sentencepiece_tokenizer.model", extra_ids=0)

tokenizer.add_special_tokens({
    "additional_special_tokens": [
        AddedToken(content=f"<extra_id_{i}>",
                   single_word=False,
                   normalized=False,
                   special=True) for i in range(SENTINEL_TOKEN_COUNT)
    ]})

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


21

# Load sentences

In [25]:
sentences = corpus.get_sentences()

len(sentences)

100%|██████████| 10000/10000 [00:04<00:00, 2458.60it/s]


172244

In [35]:
sentences_ids = [tokenizer.encode(sentence, add_special_tokens=False).ids for sentence in tqdm(sentences)]
sentences_ids = [sentence_ids for sentence_ids in sentences_ids if len(sentence_ids) > MIN_SENTENCE_LEN]

len(sentences_ids)

100%|██████████| 172244/172244 [00:21<00:00, 7931.94it/s]


32821

# Split corpus

In [36]:
# Train and test splits
# first 90% will be train, rest val
sentence_split = int(0.9*len(sentences_ids))

train_data = sentences_ids[:sentence_split]
val_data = sentences_ids[sentence_split:]

# Batch preparation

In [37]:
def get_sentence_batch(split, batch_size=BATCH_SIZE):
    data = train_data if split == 'train' else val_data

    sentence_indices = torch.randint(len(data), (batch_size,))
    
    for sentence_ix  in sentence_indices:
        sentence_ids = data[int(sentence_ix)]

        ix = torch.randint(max(1, len(sentence_ids) - MAX_LEN), (1,))
        sentence_ids = sentence_ids[ix:ix+MAX_LEN]

        pred_count = min(SENTINEL_TOKEN_COUNT - 1, max(1, round(len(sentence_ids) * DESOINING_RATE)))
        
        masked_positions = np.random.choice(range(1, len(sentence_ids)), pred_count, replace=False)
        masked_positions.sort()
        
        masked_token_ids = sentence_ids.copy()
        target_token_ids = [PAD_TOKEN_ID]

        for sentinel_id, masked_position in enumerate(masked_positions):
            target_token_ids.append(SENTINEL_TOKEN_ID + sentinel_id)
            target_token_ids.append(masked_token_ids[masked_position])

            masked_token_ids[masked_position] = SENTINEL_TOKEN_ID + sentinel_id
        
        target_token_ids.append(SENTINEL_TOKEN_ID + sentinel_id + 1)
        target_token_ids.append(EOS_TOKEN_ID)

        mask_padding = MAX_LEN - len(target_token_ids)
        target_attn_mask = np.zeros(MAX_LEN)
        target_attn_mask[:-mask_padding] = 1

        sentence_padding = MAX_LEN - len(sentence_ids)

        target_tokenid_for_loss = np.concatenate([target_token_ids[1:], [IGNORE_INDEX] * (mask_padding +1)])
        # target_tokenid_for_loss[target_tokenid_for_loss >= SENTINEL_TOKEN_ID] = IGNORE_INDEX
    
        yield [
            np.concatenate([sentence_ids, [PAD_TOKEN_ID] * sentence_padding]),
            np.concatenate([masked_token_ids, [PAD_TOKEN_ID] * sentence_padding]),
            np.concatenate([target_token_ids, [PAD_TOKEN_ID] * mask_padding]),
            target_tokenid_for_loss,
            np.concatenate([(np.array(sentence_ids) != PAD_TOKEN_ID) * 1, [0] * sentence_padding]),
            np.array(target_attn_mask),
        ]

In [38]:
def get_batch(split, batch_size):
    return map(
         lambda x: torch.tensor(x, device=DEVICE, dtype=torch.long), 
         zip(*get_sentence_batch(split, batch_size=batch_size)))

(token_ids,
 masked_token_ids,
 target_token_ids,
 target_token_loss_ids,
 src_attention_masks,
 target_attention_masks) = get_batch("train", batch_size=1)

print("token_ids              \n", token_ids, "\n")
print("masked_token_ids       \n", masked_token_ids, "\n")
print("target_token_ids       \n", target_token_ids, "\n")
print("target_token_loss_ids       \n", target_token_loss_ids, "\n")
print("src_attention_masks    \n", src_attention_masks, "\n")
print("target_attention_masks \n", target_attention_masks, "\n")

token_ids              
 tensor([[   61,    82,    81,    75,   232,   643,    81,   330,    43,   232,
            86,   360,   169,   318,    99,   123,    70,   399,    37,   676,
            38,  6930,   371,    97,    31,   638,  2204,   169,   535,   655,
           575,   203,   169,   482,   166,    97,    31,  4062,   181,  2268,
           222,    81,   824,    47,    82,  9107,   195,   966,   720,   166,
          4266,   101, 12248,    36,   319,   169,   584,    97,    31,   673,
            50,    41,   356,   140,   173,  2542,   222,    81,   824,    44,
           232,  2994,   203,   259,  1704,   171,  1390,    36,   643,   171,
           553,   716,   171,    97,  3239,   330,    46,   232,    42,   356,
           140,   246,    81,    69,    82,    81,    57,   232,   801,    81,
           330,    41,   232,    43,   356,   195,  7043,   222,    81,   824,
            47,   232,   643,   171,   173, 11271,    36,    81,   330,    42,
           232,    44,   35

  lambda x: torch.tensor(x, device=DEVICE, dtype=torch.long),


# Model Training

During pre-training, we use an “inverse square root” learning rate schedule: $$1/\sqrt{max(n, k)}$$

where:
* n is the current training iteration and
* k is the number of warm-up steps (set to $10^4$ in all of our experiments).

This sets a constant learning rate of 0.01 for the first $10^4$ steps, then exponentially decays the learning rate until pre-training is over.

In [39]:
model = T5(
    dim=EMBED_DIM,
    vocab_size=VOCAB_SIZE,
    enc_depth=N_LAYERS,
    enc_heads=ATTN_HEADS,
    enc_dim_head=EMBED_DIM // ATTN_HEADS,
    enc_mlp_mult=4,
    dec_depth=N_LAYERS,
    dec_heads=ATTN_HEADS,
    dec_dim_head=EMBED_DIM // ATTN_HEADS,
    dec_mlp_mult=4,
    dropout=DROPOUT,
)
m = model.to(DEVICE)

In [40]:
import torch.nn.functional as F

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(EVAL_ITERS)
        for k in range(EVAL_ITERS):
            (_,
            masked_token_ids,
            target_token_ids,
            target_token_ids_loss,
            src_attention_masks,
            target_attention_masks) = get_batch(split, BATCH_SIZE)
            
            logits = model(masked_token_ids, target_token_ids,
                           src_attention_masks, target_attention_masks)
            loss = F.cross_entropy(logits.transpose(-2,-1), target_token_ids_loss)

            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# lambda_scheduler = lambda x: 1 / math.sqrt(max(x * 100, 10000))
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_scheduler)


In [41]:
for iter in range(MAX_ITERS):
    optimizer.zero_grad()

    # every once in a while evaluate the loss on train and val sets
    if iter % EVAL_INTERVAL == 0:
        losses = estimate_loss()
        # current_lr = scheduler.get_last_lr()[0] "lr {current_lr:.4f}"
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    (_,
    masked_token_ids,
    target_token_ids,
    target_token_ids_loss,
    src_attention_masks,
    target_attention_masks) = get_batch('train', BATCH_SIZE)

    # evaluate the loss
    logits = model(masked_token_ids, target_token_ids,
                   src_attention_masks, target_attention_masks)
    
    loss = F.cross_entropy(logits.transpose(-2,-1), target_token_ids_loss)
    
    loss.backward()
    optimizer.step()
    # scheduler.step()

step 0: train loss 9.7630, val loss 9.7669
step 100: train loss 2.7886, val loss 2.9647
step 200: train loss 2.6368, val loss 2.8275
step 300: train loss 2.5558, val loss 2.7598
step 400: train loss 2.5198, val loss 2.7206
step 500: train loss 2.5002, val loss 2.7166
step 600: train loss 2.4914, val loss 2.7015
step 700: train loss 2.4516, val loss 2.6770
step 800: train loss 2.4570, val loss 2.6797
step 900: train loss 2.4314, val loss 2.6626


# Test time

In [42]:
from torch.nn import functional as F

# generate from the model
def generate(model, token_ids, attn_mask, max_new_tokens, max_len=MAX_LEN):

    targets = torch.zeros((1, 1), device=DEVICE).long()

    # idx is (B, T) array of indices in the current context
    for i in range(max_new_tokens):

        # get the predictions
        model.eval()
        logits = model(token_ids, targets[:, -max_len:], attn_mask)
        model.train()

        # focus only on the last time step
        logits = logits[:, -1, :] # becomes (B, C)

        # apply softmax to get probabilities
        probs = F.softmax(logits, dim=-1) # (B, C)

        targets_next = torch.argmax(probs, dim=-1)[None,:]
        # targets_next = torch.multinomial(probs, num_samples=1) # (B, 1)

        # append sampled index to the running sequence
        targets = torch.cat((targets, targets_next), dim=1) # (B, T+1)
    return targets

In [51]:
(token_ids,
masked_token_ids,
target_token_ids,
target_token_ids_loss,
src_attention_masks,
target_attention_masks) = get_batch("train", batch_size=1)

In [52]:
print(tokenizer.decode(masked_token_ids[0].cpu().numpy(), skip_special_tokens=False))

print("\nexpected\n")
print(tokenizer.decode(target_token_ids[0].cpu().numpy(), skip_special_tokens=False))
print("\nactual\n")
print(tokenizer.decode(
        generate(model,
                 masked_token_ids, src_attention_masks,
                 max_new_tokens=6)[0].tolist(),
        skip_special_tokens=False))

les condamner in solidum <extra_id_0> payer à l ' AGS CGEA d '[ Localité 21 ] des dommages - intérêts représentant le montant des avances consenties <extra_id_1> salariés , alors « qu ' en condamnant les <extra_id_2> Bosal holding France et <extra_id_3> Nederland <extra_id_4> à réparer le préjudice subi par l ' AGS CGEA , correspondant <extra_id_5> montant des avances consenties <extra_id_6> salariés <extra_id_7> après avoir pourtant écarté la demande des <extra_id_8> tendant <extra_id_9> voir attribuer la qualité de co employeur à ces deux sociétés <extra_id_10> la cour d ' appel a ainsi fait peser sur les sociétés <extra_id_11> la réparation d ' un préjudice sans lien de <extra_id_12> direct avec les fautes <extra_id_13> <extra_id_14> étaient imputées , et <extra_id_15> violé les articles <extra_id_16> <extra_id_17> 1383 du <extra_id_18>

expected

<pad> <extra_id_0> à <extra_id_1> aux <extra_id_2> sociétés <extra_id_3> Bosal <extra_id_4> BV <extra_id_5> au <extra_id_6> aux <extra_id

In [66]:
print(tokenizer.decode([token_ids[0, 5]], skip_special_tokens=False))

token_ids[0, 5] = SENTINEL_TOKEN_ID

print(tokenizer.decode(token_ids[0].cpu().numpy(), skip_special_tokens=False))

print(tokenizer.decode(
        generate(model,
                 token_ids, src_attention_masks,
                 max_new_tokens=1)[0].tolist()))


la
Ainsi fait et jugé par <extra_id_2> Cour de cassation, chambre sociale, et prononcé par le président en son audience publique du vingt-deux mars deux mille vingt-trois.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
<pad> <extra_id_7>


# Model visualization

https://opendelta.readthedocs.io/en/latest/notes/overview.html

In [193]:
from opendelta import Visualization
model_vis = Visualization(model)
model_vis.structure_graph()