In [1]:
import torch
import numpy as np
from lm_from_scratch.models.bert import BERT
from lm_from_scratch.corpus.decision_corpus import DecisionCorpus
import pandas as pd
from artifacts import TOKENIZER_PATH

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing
from random import randint, random

VOCAB_SIZE = 15000
N_SEGMENTS = 2
MAX_LEN = 20# 128 # 512 # what is the maximum context length for predictions?
EMBED_DIM = 384 # 768
N_LAYERS = 3
ATTN_HEADS = 6 # 32 * 4 = 128
DROPOUT = 0.1

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

EVAL_ITERS = 10
MAX_ITERS = 10000
EVAL_INTERVAL = 1000
LEARNING_RATE = 1e-3

BATCH_SIZE = 64 # how many independent sequences will we process in parallel?

MAX_SENTENCE_LEN = MAX_LEN // 2
MIN_SENTENCE_LEN = 10

# Corpus and tokenizer setup

In [2]:
CLS_TOKEN_ID = 0
SEP_TOKEN_ID = 1
PAD_TOKEN_ID = 2
MASK_TOKEN_ID = 3
UNK_TOKEN_ID = 4
IGNORE_INDEX = -100

corpus = DecisionCorpus()

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
trainer = BpeTrainer(vocab_size=VOCAB_SIZE,
                     special_tokens=["[CLS]", "[SEP]", "[PAD]", "[MASK]", "[UNK]"])
tokenizer.pre_tokenizer = Whitespace()

tokenizer.train_from_iterator(corpus.get_text(), trainer)

# post-processing to traditional BERT inputs
tokenizer.post_processor = TemplateProcessing(
    single="[CLS] $A [SEP]",
    pair="[CLS] $A [SEP] $B:1 [SEP]:1",
    special_tokens=[
        ("[CLS]", tokenizer.token_to_id("[CLS]")),
        ("[SEP]", tokenizer.token_to_id("[SEP]")),
    ],
)

# pad the outputs to the longest sentence present
tokenizer.enable_padding(pad_id=PAD_TOKEN_ID, pad_token="[PAD]", length=MAX_LEN)

tokenizer.save(str(TOKENIZER_PATH))






# Load sentence pairs

In [3]:
sentences_pairs = corpus.get_sentence_pairs()

df = pd.DataFrame(sentences_pairs, columns=["sentence_1", "sentence_2_isnext"])

len(df)

162244

# Test for single sentence

In [None]:
output = tokenizer.encode(df.loc[2,0])

print(df.loc[2,0])
print(output.tokens)
print(output.ids)

# Test for paired sentences

In [None]:
output = tokenizer.encode(*df.loc[2,])

print(df.loc[2,])
print(output.tokens)
print(output.ids)
print(output.type_ids) # segment ids
print(output.attention_mask)

In [None]:
output = tokenizer.encode_batch([["Il résulte de l'arrêt attaqué.", "Le 13 avril 2018."],
                                ["Le 13 avril 2018.", "Une enquête préliminaire a été ouverte."]])

for out in output:
    print(out.tokens)
    print(out.type_ids) # segment ids
    print(out.attention_mask)
    print("\n")

# Split dataset

In [4]:
# Train and test splits
sentence_pair_split = int(0.9*len(df))

df_train = df[:sentence_pair_split]
df_eval = df[sentence_pair_split:].reset_index(drop=True)

train_col_1_shuffled = df_train["sentence_2_isnext"].sample(
    frac=1, 
    random_state=212,
    replace=True).reset_index(drop=True)
df_train["sentence_2_notnext"] = train_col_1_shuffled.values

sum(train_col_1_shuffled == df_train["sentence_2_isnext"]) == 0

eval_col_1_shuffled = df_eval["sentence_2_isnext"].sample(
    frac=1,
    random_state=23,
    replace=True).reset_index(drop=True)
df_eval["sentence_2_notnext"] = eval_col_1_shuffled.values

sum(eval_col_1_shuffled == df_eval["sentence_2_isnext"]) == 0

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_train["sentence_2_notnext"] = train_col_1_shuffled.values


False

In [5]:
sentence_pair_train_isnext = tokenizer.encode_batch(
    df_train[["sentence_1", "sentence_2_isnext"]].values)

sentence_pair_train_notnext = tokenizer.encode_batch(
    df_train[["sentence_1", "sentence_2_notnext"]].values)

sentence_pair_val_isnext = tokenizer.encode_batch(
    df_eval[["sentence_1","sentence_2_isnext"]].values)

sentence_pair_val_notnext = tokenizer.encode_batch(
    df_eval[["sentence_1","sentence_2_isnext"]].values)

In [6]:
def filter_by_sentence_length(sentence_pairs_isnext, sentence_pairs_notnext, max_len):
    pairs_isnext = []
    pairs_notnext = []

    for pair_isnext, pair_notnext in zip(sentence_pairs_isnext, sentence_pairs_notnext):
        if len(pair_isnext) > max_len or len(pair_isnext) < MIN_SENTENCE_LEN:
            continue
        if len(pair_notnext) > max_len or len(pair_notnext) < MIN_SENTENCE_LEN:
            continue
        pairs_isnext.append(pair_isnext)
        pairs_notnext.append(pair_notnext)
    
    return pairs_isnext, pairs_notnext

sentence_pair_train_data = filter_by_sentence_length(
                                sentence_pair_train_isnext, 
                                sentence_pair_train_notnext,
                                MAX_LEN)

sentence_pair_val_data = filter_by_sentence_length(
                                sentence_pair_val_isnext,
                                sentence_pair_val_notnext,
                                MAX_LEN)

In [7]:
def get_sentence_pair_batch(split, batch_size=BATCH_SIZE):
    data_isnext, data_notnext = sentence_pair_train_data if split == 'train' else sentence_pair_val_data
    pair_ix = torch.randint(len(data_isnext), (batch_size,))
    max_pred_count = len(data_isnext[0])
    
    for i, ix  in enumerate(pair_ix):
        is_next = i % 2 == 0

        sentence_pair = data_isnext[ix] if is_next else data_notnext[ix]

        available_mask = np.where(np.array(sentence_pair.special_tokens_mask) == 0)[0]
        pred_count = min(max_pred_count, max(1, round(len(available_mask) * 0.15)))
        
        masked_positions = np.random.choice(available_mask, pred_count, replace=False)
        masked_positions.sort()
        
        masked_token_ids = sentence_pair.ids.copy()
        for masked_position in masked_positions:
            if random() < 0.8:  # 80%
                masked_token_ids[masked_position] = MASK_TOKEN_ID
            elif random() < 0.5:  # 10%
                index = randint(5, VOCAB_SIZE - 1) # random index in vocabulary
                masked_token_ids[masked_position] = index

        masked_tokens = np.array(sentence_pair.ids)[masked_positions]
        mask_padding = max_pred_count - len(masked_positions)
    
        yield [
            sentence_pair.ids,
            masked_token_ids,
            sentence_pair.type_ids,
            np.concatenate([masked_tokens, [IGNORE_INDEX] * mask_padding]),
            np.concatenate([masked_positions, [0] * mask_padding]),
            sentence_pair.attention_mask,
            is_next,
        ]

In [8]:
def get_batch(split, batch_size):
    return map(
         lambda x: torch.tensor(x, device=DEVICE, dtype=torch.long), 
         zip(*get_sentence_pair_batch(split, batch_size=batch_size)))

(token_ids,
 masked_token_ids,
 segment_ids,
 masked_tokens,
 masked_positions,
 attention_masks,
 is_next) = get_batch("train", batch_size=1)

print("token_ids          \n", token_ids, "\n")
print("masked_token_ids   \n", masked_token_ids, "\n")
print("segment_ids        \n", segment_ids, "\n")
print("masked_tokens      \n", masked_tokens, "\n")
print("masked_positions   \n", masked_positions, "\n")
print("attention_masks    \n", attention_masks, "\n")
print("is_next            \n", is_next, "\n")

token_ids          
 tensor([[  0, 993, 149, 298,  31,   1, 850,  47,  18,  61,  40,  62, 304, 758,
          31,   1,   2,   2,   2,   2]], device='cuda:0') 

masked_token_ids   
 tensor([[  0, 993, 149, 298,  31,   1, 850,  47,  18,   3,  40,   3, 304, 758,
          31,   1,   2,   2,   2,   2]], device='cuda:0') 

segment_ids        
 tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]],
       device='cuda:0') 

masked_tokens      
 tensor([[  61,   62, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100]], device='cuda:0') 

masked_positions   
 tensor([[ 9, 11,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0]], device='cuda:0') 

attention_masks    
 tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]],
       device='cuda:0') 

is_next            
 tensor([1], device='cuda:0') 



  lambda x: torch.tensor(x, device=DEVICE, dtype=torch.long),


# Building the model

In [9]:
model = BERT(
    vocab_size=VOCAB_SIZE,
    n_segments=N_SEGMENTS,
    max_len=MAX_LEN,
    embed_dim=EMBED_DIM,
    num_heads=ATTN_HEADS,
    dropout=DROPOUT,
    n_layers=N_LAYERS,
)
m = model.to(DEVICE)

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# `ignore_index`: specifies a target value that is ignored
# and does not contribute to the input gradient. When :attr:`size_average` is
# `True`, the loss is averaged over non-ignored targets. Note that
# Only applicable when the target contains class indices

loss_fn_lm = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
loss_fn_clsf = torch.nn.CrossEntropyLoss() 

# Training loop

In [10]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        clsf_losses = torch.zeros(EVAL_ITERS)
        lm_losses = torch.zeros(EVAL_ITERS)
        for k in range(EVAL_ITERS):
            _, token_ids, segment_ids, masked_tokens, masked_positions, attention_masks, is_next = get_batch("train", BATCH_SIZE)
            logits_lm, logits_clsf = model(token_ids, segment_ids, attention_masks, masked_positions)

            loss_lm = loss_fn_lm(logits_lm.transpose(-2,-1), masked_tokens)
            loss_clsf = loss_fn_clsf(logits_clsf, is_next)
            
            clsf_losses[k] = loss_clsf.item()
            lm_losses[k] = loss_lm.item()

        out[split] = (clsf_losses.mean(), lm_losses.mean())
    model.train()
    return out


for iter in range(MAX_ITERS):
    optimizer.zero_grad()

    # every once in a while evaluate the loss on train and val sets
    if iter % EVAL_INTERVAL == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train'][0]:.4f}|{losses['train'][1]:.4f}," +
              f"val loss {losses['val'][0]:.4f}|{losses['val'][1]:.4f}")

    # sample a batch of data
    _, token_ids, segment_ids, masked_tokens, masked_positions, attention_masks, is_next = get_batch("train", BATCH_SIZE)

    # evaluate the loss
    logits_lm, logits_clsf = model(token_ids, segment_ids, attention_masks, masked_positions)
    loss_lm = loss_fn_lm(logits_lm.transpose(-2,-1), masked_tokens) # for masked LM
    loss_clsf = loss_fn_clsf(logits_clsf, is_next) # for sentence classification
    loss = loss_lm + loss_clsf
    
    loss.backward()
    optimizer.step()


step 0: train loss 0.7123|9.6134,val loss 0.7121|9.6099


# Test time

In [60]:
test_sentence = "[CLS]alors qu'il relevait que l'assignation[MASK] [SEP]"

test_ids = tokenizer.encode(test_sentence)

test_token_ids = torch.tensor(test_ids.ids, dtype=torch.long, device=DEVICE)
test_attn_mask = torch.tensor(test_ids.ids, dtype=torch.long, device=DEVICE) != PAD_TOKEN_ID
test_segment = torch.zeros(MAX_LEN, device=DEVICE, dtype=torch.long)
test_mask_position = torch.tensor(test_ids.ids, dtype=torch.long, device=DEVICE) == MASK_TOKEN_ID

In [61]:
logit_lm, logits_clsf = model(
      test_token_ids[None, :],
      test_segment[None, :],
      test_attn_mask[None, :].long(),
      test_mask_position[None, :].long())

tokenizer.decode(torch.argmax(logit_lm[0], dim=1).cpu().numpy(), skip_special_tokens=True)

'1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1'

In [81]:
(token_ids,
 masked_token_ids,
 segment_ids,
 masked_tokens,
 masked_positions,
 attention_masks,
 is_next) = get_batch("test", batch_size=1)

print(tokenizer.decode([token_ids[0, 5]], skip_special_tokens=False))

masked_tokens = torch.ones((1, MAX_LEN), dtype=torch.long, device=DEVICE) * -100
masked_tokens[0,0] = token_ids[0, 5]
token_ids[0, 5] = MASK_TOKEN_ID

masked_positions = torch.zeros((1, MAX_LEN), dtype=torch.long, device=DEVICE)
masked_positions[0,0] = 5

print(tokenizer.decode(token_ids[0].cpu().numpy(), skip_special_tokens=False))
print(token_ids)

diligences
[CLS] Dit que sur les [MASK] du procureur général près la Cour de cassation , le présent arrêt sera transmis pour être transcrit en marge ou à la suite de la décision cassée ; [SEP] Ainsi fait et jugé par la Cour de cassation , première chambre civile , et prononcé par le président en son audience publique du huit mars deux mille vingt - trois . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
tensor([[   0, 1514,  182,  295,  194,    3,  161,  596,  423,  329,  153,  265,
          146,  253,   16,  149,  333,  259, 1296, 1409,  217,  486, 1571,  151,
         1364,  211,  120,  153,  888,  146,  153,  360, 8163,   31,    1,  643,
          381,  166,  878,  183,

In [83]:
model.eval()
logit_lm, logits_clsf = model(token_ids, segment_ids, attention_masks, masked_positions)
model.train()

predicted_masked_tok = torch.argmax(logit_lm, dim=2).cpu().numpy()

print(tokenizer.decode(predicted_masked_tok[0], skip_special_tokens=True))
print(predicted_masked_tok[0])


diligences être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être être
[1410  486  486  486  486  486  486  486  486  486  486  486  486  486
  486  486  486  486  486  486  486  486  486  486  486  486  486  486
  486  486  486  486  486  486  486  486  486  486  486  486  486  486
  486  486  486  486  486  486  486  486  486  486  486  486  486  486
  486  486  486  486  486  486  486  486  486  486  486  486  486  486

In [None]:
loss_fn_lm(logit_lm.view(-1,VOCAB_SIZE) , masked_tokens.view(-1))

In [None]:
masked_tokens[masked_tokens == -100] = 0
loss_fn_lm(logit_lm.view(-1,VOCAB_SIZE) , masked_tokens.view(-1))

In [None]:
torch.argmax(logit_lm[0,0])

In [None]:
masked_tokens