# Setup

In [None]:
import os
# Enumerate GPUs in PCI bus order and expose only GPU 0 to this process.
# These variables are set before torch is imported in the next cell.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})
# Device string used later to move tensors and models.
device = "cuda:0"

In [None]:
import torch
import pandas as pd
import pickle 
import numpy as np

from transformers import (
    AutoModel, 
    AutoModelForSequenceClassification, 
    AutoConfig, 
    utils,
)
from bertviz import head_view

import sys
# Make the parent directory importable so the local packages imported below
# (main, models, dataset, args) can be resolved from this notebook.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)
    
from main import load_prsa

from models.configs import (
    RowTabBERTConfig,
    FieldyConfig,
)

from models.models import (
    Model,
    RowTabBERT,
    Fieldy,
)

from dataset.datacollator import (
    RowTabBERTDataCollatorForLanguageModeling,
    RowTabBERTDataCollatorForFineTuning,
    FieldyDataCollatorForLanguageModeling,
    FieldyDataCollatorForFineTuning,
)

from args import define_main_parser

utils.logging.set_verbosity_error()  # Only show transformers log messages at ERROR level (hides routine warnings)

In [None]:
# Work from the parent directory so relative paths such as ./data resolve.
# `module_path` is the absolute path of that directory, computed in the import
# cell. Using it instead of ".." keeps this cell idempotent: re-running it no
# longer climbs one directory further each time.
os.chdir(module_path)

In [None]:
# Experiment configuration used by the cells below.
name_suffix = ""
data_type = "prsa"

# Training schedule and embedding tags; the epoch counts also feed the CLI flags
# below so the two can no longer drift apart.
pt_ep = 24
ft_ep = 20
posemb = "posemb"
colemb = "colemb"
num_ft_labels = 10 * 2

parser = define_main_parser()

# Flags shared by both model families.
shared_cli_args = [
    f"--data-type={data_type}",
    "--hidden-size=800",
    "--fieldtransf-nheads=10",
    "--n-heads=10",
    "--scale-targets",
    "--scaling=std",
    "--dropout=0.1",
    "--pos-emb",
    "--col-emb",
    f"--pt-epochs={pt_ep}",
    f"--ft-epochs={ft_ep}",
    "--seed=1",
]

# RowTabBERT: 6 field-transformer layers, 10 sequence-transformer layers.
args_row_tabbert = parser.parse_args(args=shared_cli_args + [
    "--family=row_tabbert",
    "--fieldtransf-nlayers=6",
    "--n-layers=10",
])

# Fieldy: 8 field-transformer layers, 4 sequence-transformer layers.
# NOTE(review): no --family flag is passed here, so the parser default applies
# (presumably the Fieldy family) -- confirm in args.py.
args_fieldy = parser.parse_args(args=shared_cli_args + [
    "--fieldtransf-nlayers=8",
    "--n-layers=4",
])

In [None]:
# Requires the preprocessed PRSA dataset (run prsa.sh first if it is missing).
# NOTE: pickle.load can execute arbitrary code -- only load files you produced yourself.
dataset_path = f"./data/prsa/PRSADataset_labeled{name_suffix}.pkl"
with open(dataset_path, "rb") as dataset_file:
    dataset = pickle.load(dataset_file)

# Models

Requires models that have already been trained on the PRSA dataset (run `prsa.sh` first if needed).


In [None]:
# Build the RowTabBERT wrapper from the parsed CLI arguments and the dataset metadata.
row_tabbert = Model(
    family=args_row_tabbert.family,
    # Vocabulary and table geometry come from the dataset.
    vocab=dataset.vocab,
    special_tokens=dataset.vocab.get_special_tokens(),
    ncols=dataset.ncols,
    seq_len=dataset.seq_len,
    # Architecture hyper-parameters come from the CLI arguments.
    hidden_size=args_row_tabbert.hidden_size,
    n_heads=args_row_tabbert.n_heads,
    n_layers=args_row_tabbert.n_layers,
    fieldtransf_nheads=args_row_tabbert.fieldtransf_nheads,
    fieldtransf_nlayers=args_row_tabbert.fieldtransf_nlayers,
    pos_emb=args_row_tabbert.pos_emb,
    col_emb=args_row_tabbert.col_emb,
    max_position_embeddings=512,
    dropout=args_row_tabbert.dropout,
    # Training objectives.
    mlm_loss=args_row_tabbert.mlm_loss,
    num_ft_labels=num_ft_labels,
)

# Register the custom config / model classes with the transformers Auto* factories.
# NOTE(review): registration targets AutoModelForSequenceClassification while the
# checkpoint is loaded through AutoModel -- presumably AutoModel is registered
# elsewhere (e.g. inside Model); confirm.
AutoConfig.register("RowTabBERT", RowTabBERTConfig)
AutoModelForSequenceClassification.register(RowTabBERTConfig, RowTabBERT)

# Pre-training checkpoint directory; the run name encodes the hyper-parameters.
row_tabbert_model_path_pt = (
    f"../results/{args_row_tabbert.data_type}/RowTabBERT/"
    f"RowTabBERT_{args_row_tabbert.fieldtransf_nheads}fieldtransfheads_"
    f"{args_row_tabbert.fieldtransf_nlayers}fieldtransflayers_"
    f"{args_row_tabbert.hidden_size}hs_"
    f"{args_row_tabbert.n_heads}heads_"
    f"{args_row_tabbert.n_layers}layers_"
    f"{'posemb' if args_row_tabbert.pos_emb else 'noposemb'}_"
    f"{'colemb' if args_row_tabbert.col_emb else 'nocolemb'}_"
    f"MSE_pt{args_row_tabbert.pt_epochs}ep_ft{args_row_tabbert.ft_epochs}ep_seed{args_row_tabbert.seed}/pt"
)
row_tabbert_model_pt = AutoModel.from_pretrained(
    row_tabbert_model_path_pt,
    vocab=dataset.vocab,
    output_attentions=True,
)

In [None]:
# Build the Fieldy wrapper from the parsed CLI arguments and the dataset metadata.
fieldy = Model(
    family=args_fieldy.family,
    # Vocabulary and table geometry come from the dataset.
    vocab=dataset.vocab,
    special_tokens=dataset.vocab.get_special_tokens(),
    ncols=dataset.ncols,
    seq_len=dataset.seq_len,
    # Architecture hyper-parameters come from the CLI arguments.
    hidden_size=args_fieldy.hidden_size,
    n_heads=args_fieldy.n_heads,
    n_layers=args_fieldy.n_layers,
    fieldtransf_nheads=args_fieldy.fieldtransf_nheads,
    fieldtransf_nlayers=args_fieldy.fieldtransf_nlayers,
    pos_emb=args_fieldy.pos_emb,
    col_emb=args_fieldy.col_emb,
    max_position_embeddings=512,
    dropout=args_fieldy.dropout,
    # Training objectives.
    mlm_loss=args_fieldy.mlm_loss,
    num_ft_labels=num_ft_labels,
)

# Register the custom config / model classes with the transformers Auto* factories.
# NOTE(review): registration targets AutoModelForSequenceClassification while the
# checkpoint is loaded through AutoModel -- presumably AutoModel is registered
# elsewhere (e.g. inside Model); confirm.
AutoConfig.register("Fieldy", FieldyConfig)
AutoModelForSequenceClassification.register(FieldyConfig, Fieldy)

# Pre-training checkpoint directory; the run name encodes the hyper-parameters.
fieldy_model_path_pt = (
    f"../results/{args_fieldy.data_type}/Fieldy/"
    f"Fieldy_{args_fieldy.fieldtransf_nheads}fieldtransfheads_"
    f"{args_fieldy.fieldtransf_nlayers}fieldtransflayers_"
    f"{args_fieldy.hidden_size}hs_"
    f"{args_fieldy.n_heads}heads_"
    f"{args_fieldy.n_layers}layers_"
    f"{'posemb' if args_fieldy.pos_emb else 'noposemb'}_"
    f"{'colemb' if args_fieldy.col_emb else 'nocolemb'}_"
    f"MSE_pt{args_fieldy.pt_epochs}ep_ft{args_fieldy.ft_epochs}ep_seed{args_fieldy.seed}/pt"
)
fieldy_model_pt = AutoModel.from_pretrained(
    fieldy_model_path_pt,
    vocab=dataset.vocab,
    output_attentions=True,
)

# Create samples

In [None]:
# Pre-training (masked language modeling) collators, one per model family.
# Both receive the same dataset geometry and the seed of their own argument set.
# (The original cell had `randomize_seq=True;` -- a semicolon inside the call
# parentheses is a SyntaxError, so the cell could not run.)
tabbert_data_collator_pt = RowTabBERTDataCollatorForLanguageModeling(
    tokenizer=row_tabbert.tokenizer,
    mlm=args_row_tabbert.mlm,
    mlm_probability=args_row_tabbert.mlm_prob,
    ncols=dataset.ncols,
    seq_len=dataset.seq_len,
    data_type=args_row_tabbert.data_type,
    seed=args_row_tabbert.seed,
    randomize_seq=True,
)

fieldy_data_collator_pt = FieldyDataCollatorForLanguageModeling(
    tokenizer=fieldy.tokenizer,
    mlm=args_fieldy.mlm,
    mlm_probability=args_fieldy.mlm_prob,
    ncols=dataset.ncols,
    seq_len=dataset.seq_len,
    data_type=args_fieldy.data_type,
    seed=args_fieldy.seed,
    randomize_seq=True,
)

In [None]:
# Number of dataset items used for the toy task.
nsamples = 100

# Take the first `nsamples` items of the dataset.
samples = [dataset[idx] for idx in range(nsamples)]

# Rows whose fields are overwritten with [UNK] (none) and with [MASK] (rows 5-9).
rows_of_unk = []
rows_of_mask = [5, 6, 7, 8, 9]
# Row that always gets all 16 fields masked, regardless of `fields_of_mask`.
except_row = 0
# Field (column) indices affected inside the selected rows: all 16 of them.
fields_of_unk = list(range(16))
fields_of_mask = list(range(16))

In [None]:
def _build_masked_batch(collator):
    """Collate ``samples`` and overwrite the configured rows with special tokens.

    Uses the notebook-level settings ``rows_of_unk`` / ``fields_of_unk`` and
    ``rows_of_mask`` / ``fields_of_mask``. The row equal to ``except_row`` gets
    all 16 fields masked regardless of ``fields_of_mask``.

    Args:
        collator: pre-training data collator called on ``samples``; it must
            return a mapping holding ``'input_ids'`` and ``'labels'``.

    Returns:
        The collated batch with ``attention_mask`` set to ``None``,
        ``masked_lm_labels`` aliased to ``labels`` and every tensor moved to
        ``device``.
    """
    batch = collator(samples)
    # assumes input_ids is a tensor indexed [sample, row, field] -- the scoring
    # cell slices it the same way.
    input_ids = batch['input_ids']
    special_ids = dataset.vocab.token2id["SPECIAL"]

    # Replace whole fields with [UNK] in the selected rows (all samples at once).
    for row in rows_of_unk:
        unk_id = special_ids["[UNK]"][0]
        for field in fields_of_unk:
            input_ids[:nsamples, row, field] = unk_id

    # Replace whole fields with [MASK] in the selected rows (all samples at once).
    for row in rows_of_mask:
        mask_id = special_ids["[MASK]"][0]
        fields = range(16) if row == except_row else fields_of_mask
        for field in fields:
            input_ids[:nsamples, row, field] = mask_id

    batch['attention_mask'] = None
    batch['masked_lm_labels'] = batch['labels']
    # list(...) so the mapping is not modified while it is being iterated.
    for key, value in list(batch.items()):
        if torch.is_tensor(value):
            batch[key] = value.to(device)
    return batch


# NOTE(review): the original cell built the RowTabBERT batch with
# `fieldy_data_collator_pt`, leaving `tabbert_data_collator_pt` unused, which
# looks like a copy-paste slip. The RowTabBERT collator is used here; confirm it
# returns input_ids with the same [sample, row, field] layout.
samples_pt_row_tabbert = _build_masked_batch(tabbert_data_collator_pt)
row_tabbert_model_pt.to(device)

samples_pt_fieldy = _build_masked_batch(fieldy_data_collator_pt)
# Trailing semicolon keeps the notebook from printing the full model repr.
fieldy_model_pt.to(device);

# Field-wise attention toy task

In [None]:
def score(samples_pt, model_pt, family="row_tabbert", hour_field=14, horizon=5, n_rows=10, n_cols=16):
    """Accuracy of a pre-trained model on the hour-prediction toy task.

    The first ``horizon`` rows of every sample supply reference hours, read from
    column ``hour_field`` of ``input_ids``. The model is scored on the token it
    ranks first for the same column in rows ``horizon .. 2 * horizon - 1``; the
    expected value there is ``(reference hour + horizon) % 24``. A top-ranked
    token that does not belong to the ``"hour"`` field counts as a miss.

    Token ids are decoded with the notebook-level ``dataset.vocab.id2token``.

    Args:
        samples_pt: mapping of model inputs; ``samples_pt["input_ids"]`` is
            indexed as ``[sample, row, field]``.
        model_pt: callable returning ``(preds, full_outputs, preds_scores)``,
            with ``preds_scores`` indexed as ``[sample, flat_position, vocab]``.
        family: ``"row_tabbert"`` (positions flattened row by row) or
            ``"fieldy"`` (positions flattened field by field).
        hour_field: column index of the hour field.
        horizon: number of reference rows, and of predicted rows.
        n_rows: rows per sample (used for the ``"fieldy"`` flat index).
        n_cols: fields per row (used for the ``"row_tabbert"`` flat index).

    Returns:
        Fraction of predicted hours that equal the expected hour.

    Raises:
        ValueError: if ``family`` is not one of the two supported layouts.
    """
    target_rows = range(horizon, 2 * horizon)
    if family == "fieldy":
        # Field-major flattening: every field occupies n_rows consecutive slots.
        flat_positions = [n_rows * hour_field + row for row in target_rows]
    elif family == "row_tabbert":
        # Row-major flattening: every row occupies n_cols consecutive slots.
        flat_positions = [hour_field + n_cols * row for row in target_rows]
    else:
        raise ValueError(
            f"Unknown family {family!r}; expected 'row_tabbert' or 'fieldy'"
        )

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        _, _, preds_scores = model_pt(**dict(samples_pt.items()), output_attentions=True)

    id2token = dataset.vocab.id2token

    # Expected hours: reference hours of the first `horizon` rows, shifted forward.
    reference_ids = samples_pt["input_ids"][:, :horizon, hour_field].cpu().numpy().tolist()
    expected_hours = np.array([
        (id2token[token_id][0] + horizon) % 24
        for sample_ids in reference_ids
        for token_id in sample_ids
    ])

    # Predicted hours: top-1 vocabulary entry at each target position.
    guesses = []
    for sample_scores in preds_scores:
        for position in flat_positions:
            top_id = torch.topk(sample_scores[position, :], 1).indices[0].item()
            token = id2token[top_id]
            guesses.append(token[0] if token[1] == "hour" else -1)
    guesses = np.array(guesses)

    n_correct = (expected_hours == guesses).sum()
    return n_correct / guesses.shape[0]

In [None]:
# Toy-task accuracy of the pre-trained RowTabBERT model on its masked batch.
score(samples_pt_row_tabbert, row_tabbert_model_pt, family="row_tabbert")

In [None]:
# Toy-task accuracy of the pre-trained Fieldy model on its masked batch.
score(samples_pt_fieldy, fieldy_model_pt, family="fieldy")