In [4]:
import torch
import os
from src.backend.module import nn, modules, loss as loss_module
from src.backend.utils import load_base_model, load_tokenizer
from torch.optim import Adam
import pandas as pd

In [3]:
import re

def ASSERTION_DATAFRAME_COLUMN(columns):
    columns = tuple(columns)
    DATASET_COLUMNS_WITH_ID = ("ID", "Input", "Output")
    DATASET_COLUMNS = ("Input", "Output")
    return DATASET_COLUMNS == columns or DATASET_COLUMNS_WITH_ID == columns


def preprocessed(df, tokenizer, vocab=30522, context_length=512, flag="~_~"):
    """"""
    
    assert isinstance(df, pd.DataFrame), "pre-process precondition needs to be in a DataFrame"
    assert ASSERTION_DATAFRAME_COLUMN(df.columns)

    # parse string to convert into a list
    df["Output"] = df["Output"].apply(lambda x : x.split(','))

    # regular expression that obtains flag and associated word
    pattern = rf"{flag}(\w+)"

    # if ID exist within the column then remove it
    if "ID" in df.columns:
        df.drop("ID", axis=1, inplace=True)

    inputs, outputs = [], []
    for i, sentence in enumerate(df["Input"]):
        replace = re.findall(pattern, sentence) # find all replaced tokens
        replacement : list = df["Output"][i]    # index the associated replacement list
        
        # replace tokens should equal replacement
        if len(replacement) != len(replace):
            continue

        # construct the replacement dictionary
        replacement_dict = {replace : replacement[i] for i, replace in enumerate(replace)}   

        # TODO Not ideal in initializing each iteration 
        def replaced(match):
            word = match.group(1)
            return replacement_dict.get(word, match.group(0))

        output_text = re.sub(pattern, replaced, sentence)
        input_text = sentence.replace(flag, "")
        inputs.append(input_text)
        outputs.append(output_text)


    input_tokens = tokenizer(inputs, padding=True, max_length=context_length, truncation=True, return_tensors="pt")
    # TODO: change to make 
    output_tokens = tokenizer(outputs, padding=True, max_length=context_length, truncation=True, return_tensors="pt")
    mask = ((input_tokens["input_ids"] - output_tokens["input_ids"]) != 0).to(torch.long)
    prob = torch.nn.functional.one_hot(output_tokens["input_ids"], num_classes=vocab)
    
    return input_tokens, output_tokens, mask, prob

In [3]:
syntax = nn.SyntaxBert.load_local_weights(nn.BertConfig)
tokenizer = load_tokenizer()

  from .autonotebook import tqdm as notebook_tqdm


In [43]:
df_training = pd.read_csv("./data/wiki_examples_flagged.csv")
df_training

Unnamed: 0,ID,Input,Output
0,1,"The Labrador Retriever, or simply Labrador, is...",developed
1,2,An angel is a ~_~celestial being in various tr...,"supernatural,religions,benevolent,religions"
2,3,"Attachment theory is a psychological, evolutio...","emotional,interactions"
3,4,The Gulag was the government ~_~administration...,"agency,convicts,convicts,convicts"
4,5,TreeHugger is a sustainability website that re...,"boasts,bought"
5,6,Mr. Clean (or Mr. Proper) is a brand name and ...,conceived
6,7,Foreign Affairs is an American journal of inte...,"policy,policy,magazine,policy"
7,8,Power Rangers is an American entertainment and...,"series,series,developed"
8,9,"Greenland (Greenlandic: Kalaallit Nunaat, pron...","territory,residents"
9,10,Amazing Grace' is a Christian hymn published i...,"song,song"


In [4]:

# Input must now be of shape (batch_size, num_tokens, num_features)
# Let's assume a batch_size of 10 for this example
batch_size = 2
num_tokens = 9
input_features = 768

config = modules.attn_config(embed_dim=768, 
                             num_heads=[2, 2], 
                             dropout=[0.1, 0.1], 
                             input_dim=input_features, 
                             dict_dim=30522, 
                             synonym_head="softmax", 
                             replace_head="sigmoid")

# Generating random input to simulate a batch of sequences
df_training = pd.read_csv("./wiki_examples_flagged.csv")
tok, tok_outputs, mask, prob = preprocessed(df_training, tokenizer)
embd = syntax(tok["input_ids"][:1, :])[1]
print("model output: ",embd.shape)


# Instantiate the attention module with the given configuration
attn_mech = modules.attn_module(config)

# Forward pass through the attention mechanism
# Note that config is no longer passed as an argument to the forward method
replace_probs, synonym_probs = attn_mech(embd)

# Print out shapes and values
print(f"Replacement Probabilities Shape: {replace_probs.shape}")  # Expected: (batch_size, num_tokens, 1)
print(f"Synonym Probabilities Shape: {synonym_probs.shape}")      # Expected: (batch_size, num_tokens, dict_dim)
# print(f"Replacement probabilities for each token in Batch 2: {replace_probs[1]}")           # Expected: A vector of probabilities for each token in the batch. There are three words, so three probabilities.
print(f"Replacement probability Batch 1, Token 1: {replace_probs[0][0]}")        # Expected: Value between 0 and 1
print(f"Replacement probability Batch 1, Token 1: {replace_probs[0]}")        # Expected: Value between 0 and 1
print(f"Synonym Probability Distribution for the Batch 1, Token 2: {synonym_probs[0][1]}") # Expected: A vector of probabilities for each word in the dictionary. There are three words, so three probabilities.
print(f"Synonym Probabilities Sum-to-1 Constraint for Token 1: {torch.sum(synonym_probs[0][0])}") # Expected : Sum to 1 constraint for the softmax probabilities, for the first token

model output:  torch.Size([1, 512, 768])
Replacement Probabilities Shape: torch.Size([1, 512, 1])
Synonym Probabilities Shape: torch.Size([1, 512, 30522])
Replacement probability Batch 1, Token 1: tensor([0.4955], grad_fn=<SelectBackward0>)
Replacement probability Batch 1, Token 1: tensor([[0.4955],
        [0.4960],
        [0.4957],
        [0.4958],
        [0.4956],
        [0.4952],
        [0.4957],
        [0.4955],
        [0.4956],
        [0.4957],
        [0.4958],
        [0.4957],
        [0.4959],
        [0.4962],
        [0.4959],
        [0.4959],
        [0.4958],
        [0.4955],
        [0.4953],
        [0.4959],
        [0.4952],
        [0.4959],
        [0.4963],
        [0.4955],
        [0.4956],
        [0.4954],
        [0.4958],
        [0.4961],
        [0.4959],
        [0.4954],
        [0.4956],
        [0.4955],
        [0.4964],
        [0.4963],
        [0.4958],
        [0.4961],
        [0.4954],
        [0.4957],
        [0.4961],
        [0.4957

'SyntaxBert'

In [67]:
def training(model        : nn.SyntaxBert, 
             head         : modules.attn_module,
             X            : torch.Tensor,
             replacements : torch.Tensor,
             synonyms     : torch.Tensor,
             optimizer    : any,
             loss_fn      : any,
             batch_size   : int=16,
             epoch        : int=2):
    
    
    # pre-train process =========================================
    flag = 10 if epoch > 50 else 1
    model_name = type(model).__name__
    total_dataset = len(X)
    # off load forward and back propagation to the cuda kernel
    device = (
                "cuda"
                if torch.cuda.is_available()
                else "mps"
                if torch.backends.mps.is_available()
                else "cpu"
             )

    attn_mech.to(device)

    # freeze Bert Weights
    # 
    for param in model.parameters():
        param.required_grad = False


    avg_loss = []
    # train process ============================================
    for i in range(epoch):
        losses = []
        for batch in range(0, total_dataset, batch_size):
        
            x = X[batch:batch+batch_size, ...]
            
            syn_y = synonyms[batch:batch+batch_size, ...].float()
            rep_y = replacements[batch:batch+batch_size, ...].float()
            
            # for each batch zero grad 
            optimizer.zero_grad()
                
            _, hidden_layer = model(x)[1]
            logits_r, logits_s = head(hidden_layer.to(device))
        
            # Compute the loss and its gradients
            loss = loss_fn(logits_s, logits_r, syn_y, rep_y)
            loss.backward()

            # Adjust learning weights
            optimizer.step()

        avg_loss.append(sum(losses)/len(losses)) 

        if i % flag == 0:
            print(f"[INFO] |{f'model: {model_name:<5}':^10}|{f'epoch: {i:<5}':^10}|{f'avg loss: {avg_loss[i]:<5}':^10}|")
        
    # ==========================================================


In [65]:
from torch.optim import Adam

In [6]:
import pandas as pd
df_training = pd.read_csv("./data/wiki_examples_flagged.csv")
df_training
tokenizer = load_tokenizer()
inputs, outputs, replacement, syn = preprocessed(df_training, tokenizer)

In [11]:
inputs["input_ids"].shape, outputs["input_ids"].shape, replacement.shape, syn.shape

(torch.Size([29, 512]),
 torch.Size([29, 512]),
 torch.Size([29, 512]),
 torch.Size([29, 512, 30522]))

In [68]:
syntax = nn.SyntaxBert.load_local_weights(nn.BertConfig)
tokenizer = load_tokenizer()
df_training = pd.read_csv("./wiki_examples_flagged.csv")


lr=0.001
optimizer=Adam(syntax.parameters(), lr=lr)
loss_fn = loss_module.JointCrossEntropy(head_type="linear")

inputs, labels, replacement, synonyms = preprocessed(df_training, tokenizer)

config = modules.attn_config(embed_dim=768, 
                             num_heads=[2, 2], 
                             dropout=[0.1, 0.1], 
                             input_dim=input_features, 
                             dict_dim=30522, 
                             synonym_head="softmax", 
                             replace_head="sigmoid")
attn_mech = modules.attn_module(config)

training(model=syntax, 
         head=attn_mech, 
         X=inputs["input_ids"], 
         replacements=replacement, 
         synonyms=synonyms, 
         optimizer=optimizer, 
         loss_fn=loss_fn, 
         batch_size=16)

: 

In [13]:
loss_fn = loss.JointCrossEntropy(head_type="linear")
synonym_probs.shape, replace_probs[0, :, :].shape, prob[:1, :, :].shape, mask[:1, :].T.shape
torch.nn.BCEWithLogitsLoss(reduction='mean')(replace_probs[0, :, :], mask[:1, :].T.float())
mask.unsqueeze(-1)
loss = loss_fn(synonym_probs, replace_probs[0, :, :], prob[:1, :, :].float(), mask[:1, :].T.float())
loss.backward()

NameError: name 'loss' is not defined