In [1]:
import argparse
import json
import math
import os
import random
from time import time
import mlflow
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, f1_score
from collections import defaultdict

# import pytrec_eval
import torch
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
from torch.utils.data import DataLoader, RandomSampler
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, AutoTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup
from accelerate import Accelerator
from copy import deepcopy

torch.backends.cuda.matmul.allow_tf32 = True

from watchog.dataset import (
    # collate_fn,
    TURLColTypeTablewiseDataset,
    TURLRelExtTablewiseDataset,
    SatoCVTablewiseDataset,
    ColPoplTablewiseDataset,
    GittablesColwiseMaxDataset
)

from watchog.dataset import TableDataset, SupCLTableDataset, SemtableCVTablewiseDataset, GittablesColwiseDataset, GittablesCVTablewiseDataset
from watchog.model import BertMultiPairPooler, BertForMultiOutputClassification, BertForMultiOutputClassificationColPopl, BertForMultiSelectionClassification
from watchog.model import SupCLforTable, UnsupCLforTable, lm_mp
from watchog.utils import load_checkpoint, f1_score_multilabel, collate_fn, get_col_pred, ColPoplEvaluator
from watchog.utils import task_num_class_dict
from accelerate import DistributedDataParallelKwargs
import wandb

[nltk_data] Downloading package punkt to /home/zhihao/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/zhihao/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [2]:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser()
    parser.add_argument("--wandb", type=bool, default=False)
    parser.add_argument("--model", type=str, default="Watchog")
    parser.add_argument("--unlabeled_train_only", type=bool, default=False)
    parser.add_argument("--context_encoding_type", type=str, default="v0")
    parser.add_argument("--pool_version", type=str, default="v0.2")
    parser.add_argument("--random_sample", type=bool, default=False)
    parser.add_argument("--comment", type=str, default="debug", help="to distinguish the runs")
    parser.add_argument(
        "--shortcut_name",
        default="bert-base-uncased",
        type=str,
        help="Huggingface model shortcut name ",
    )
    parser.add_argument(
        "--max_length",
        default=64,
        type=int,
        help=
        "The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument(
        "--adaptive_max_length",
        default=False,
        type=bool,
    )    
    parser.add_argument(
        "--max_num_col",
        default=8,
        type=int,
    )   

    parser.add_argument(
        "--batch_size",
        default=16,
        type=int,
        help="Batch size",
    )
    parser.add_argument(
        "--epoch",
        default=1,
        type=int,
        help="Number of epochs for training",
    )
    parser.add_argument(
        "--random_seed",
        default=4649,
        type=int,
        help="Random seed",
    )
    
    parser.add_argument(
        "--train_n_seed_cols",
        default=-1,
        type=int,
        help="number of seeding columns in training",
    )

    parser.add_argument(
        "--num_classes",
        default=78,
        type=int,
        help="Number of classes",
    )
    parser.add_argument("--multi_gpu",
                        action="store_true",
                        default=False,
                        help="Use multiple GPU")
    parser.add_argument("--fp16",
                        action="store_true",
                        default=False,
                        help="Use FP16")
    parser.add_argument("--warmup",
                        type=float,
                        default=0.,
                        help="Warmup ratio")
    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--task",
                        type=str,
                        default='gt-semtab22-dbpedia-all0',
                        choices=[
                            "sato0", "sato1", "sato2", "sato3", "sato4",
                            "msato0", "msato1", "msato2", "msato3", "msato4",
                            "gt-dbpedia0", "gt-dbpedia1", "gt-dbpedia2", "gt-dbpedia3", "gt-dbpedia4",
                            "gt-dbpedia-all0", "gt-dbpedia-all1", "gt-dbpedia-all2", "gt-dbpedia-all3", "gt-dbpedia-all4",
                            "gt-schema-all0", "gt-schema-all1", "gt-schema-all2", "gt-schema-all3", "gt-schema-all4",
                            "gt-semtab22-dbpedia", "gt-semtab22-dbpedia0", "gt-semtab22-dbpedia1", "gt-semtab22-dbpedia2", "gt-semtab22-dbpedia3", "gt-semtab22-dbpedia4",
                            "gt-semtab22-dbpedia-all", "gt-semtab22-dbpedia-all0", "gt-semtab22-dbpedia-all1", "gt-semtab22-dbpedia-all2", "gt-semtab22-dbpedia-all3", "gt-semtab22-dbpedia-all4",
                            "gt-semtab22-schema-class-all", "gt-semtab22-schema-property-all",
                            "turl", "turl-re", "col-popl-1", "col-popl-2", "col-popl-3", "row-popl",
                            "col-popl-turl-0", "col-popl-turl-1", "col-popl-turl-2",
                            "col-popl-turl-mdonly-0", "col-popl-turl-mdonly-1", "col-popl-turl-mdonly-2"
                        ],
                        help="Task names}")
    parser.add_argument("--colpair",
                        action="store_true",
                        help="Use column pair embedding")
    parser.add_argument("--metadata",
                        action="store_true",
                        help="Use column header metadata")
    parser.add_argument("--from_scratch",
                        action="store_true",
                        help="Training from scratch")
    parser.add_argument("--cl_tag",
                        type=str,
                        default="wikitables/simclr/bert_100000_10_32_256_5e-05_sample_row4,sample_row4_tfidf_entity_column_0.05_0_last.pt",
                        help="path to the pre-trained file")
    parser.add_argument("--dropout_prob",
                        type=float,
                        default=0.5)
    parser.add_argument("--eval_test",
                        action="store_true",
                        help="evaluate on testset and do not save the model file")
    parser.add_argument("--small_tag",
                        type=str,
                        default="semi1",
                        help="e.g., by_table_t5_v1")
    parser.add_argument("--data_path",
                        type=str,
                        default="/data/zhihao/TU/")
    parser.add_argument("--pretrained_ckpt_path",
                        type=str,
                        default="/data/zhihao/TU/Watchog/model/")    

    args = parser.parse_args([])
    task = args.task
    if args.small_tag != "":
        args.eval_test = True
    
    args.num_classes = task_num_class_dict[task]
    if args.colpair:
        assert "turl-re" == task, "colpair can be only used for Relation Extraction"
    if args.metadata:
        assert "turl-re" == task or "turl" == task, "metadata can be only used for TURL datasets"
    if "col-popl":
        # metrics = {
        #     "accuracy": CategoricalAccuracy(tie_break=True),
        # }
        if args.train_n_seed_cols != -1:
            if "col-popl" in task:
                assert args.train_n_seed_cols == int(task[-1]),  "# of seed columns must match"

    print("args={}".format(json.dumps(vars(args))))

    max_length = args.max_length
    batch_size = args.batch_size
    num_train_epochs = args.epoch

    shortcut_name = args.shortcut_name

    if args.colpair and args.metadata:
        taskname = "{}-colpair-metadata".format(task)
    elif args.colpair:
        taskname = "{}-colpair".format(task)
    elif args.metadata:
        taskname = "{}-metadata".format(task)
    elif args.train_n_seed_cols == -1 and 'popl' in task:
        taskname = "{}-mix".format(task)
    else:
        taskname = "".join(task)
    cv = int(task[-1])

    if args.from_scratch:
        if "gt" in task:
            tag_name = "{}/{}-{}-{}-pool{}-max_cols{}-rand{}-bs{}-ml{}-ne{}-do{}{}".format(
                taskname,  "{}-fromscratch".format(shortcut_name), args.small_tag, args.comment, args.pool_version, args.max_num_col, args.random_sample,
                batch_size, max_length, num_train_epochs, args.dropout_prob, 
                '-rs{}'.format(args.random_seed) if args.random_seed != 4649 else '')
        else:
            tag_name = "{}/{}-{}-{}-bs{}-ml{}-ne{}-do{}{}".format(
                taskname,  "{}-fromscratch".format(shortcut_name), args.small_tag, args.comment, 
                batch_size, max_length, num_train_epochs, args.dropout_prob, 
                '-rs{}'.format(args.random_seed) if args.random_seed != 4649 else '')
        
    else:
        if "gt" in task:
            tag_name = "{}/{}_{}-pool{}-max_cols{}-rand{}-bs{}-ml{}-ne{}-do{}{}".format(
                taskname, args.cl_tag.replace('/', '-'),  shortcut_name, args.small_tag, args.pool_version, args.max_num_col, args.random_sample,
                batch_size, max_length, num_train_epochs, args.dropout_prob,
                '-rs{}'.format(args.random_seed) if args.random_seed != 4649 else '')
        else:
            tag_name = "{}/{}_{}-{}-bs{}-ml{}-ne{}-do{}{}".format(
                taskname, args.cl_tag.replace('/', '-'),  shortcut_name, args.small_tag,
                batch_size, max_length, num_train_epochs, args.dropout_prob,
                '-rs{}'.format(args.random_seed) if args.random_seed != 4649 else '')

    # if args.eval_test:
    #     if args.small_tag != '':
    #         tag_name = tag_name.replace('outputs', 'small_outputs')
    #         tag_name += '-' + args.small_tag
    print(tag_name)
    file_path = os.path.join(args.data_path, "Watchog", "outputs", tag_name)

    dirpath = os.path.dirname(file_path)
    if not os.path.exists(dirpath):
        print("{} not exists. Created".format(dirpath))
        os.makedirs(dirpath)
    
    if args.fp16:
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
        
      
        
    # accelerator = Accelerator(mixed_precision="no" if not args.fp16 else "fp16")   
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(mixed_precision="no" if not args.fp16 else "fp16", kwargs_handlers=[ddp_kwargs])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ckpt_path = os.path.join(args.pretrained_ckpt_path, args.cl_tag)
    # ckpt_path = '/efs/checkpoints/{}.pt'.format(args.cl_tag)
    ckpt = torch.load(ckpt_path, map_location=device)
    ckpt_hp = ckpt['hp']
    print(ckpt_hp)
 
    setattr(ckpt_hp, 'batch_size', args.batch_size)
    setattr(ckpt_hp, 'hidden_dropout_prob', args.dropout_prob)
    setattr(ckpt_hp, 'shortcut_name', args.shortcut_name)
    setattr(ckpt_hp, 'num_labels', args.num_classes)
    
    
    
    tokenizer = BertTokenizer.from_pretrained(shortcut_name)
    padder = collate_fn(tokenizer.pad_token_id)


Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


args={"wandb": false, "model": "Watchog", "unlabeled_train_only": false, "context_encoding_type": "v0", "pool_version": "v0.2", "random_sample": false, "comment": "debug", "shortcut_name": "bert-base-uncased", "max_length": 64, "adaptive_max_length": false, "max_num_col": 8, "batch_size": 16, "epoch": 1, "random_seed": 4649, "train_n_seed_cols": -1, "num_classes": 101, "multi_gpu": false, "fp16": false, "warmup": 0.0, "lr": 5e-05, "task": "gt-semtab22-dbpedia-all0", "colpair": false, "metadata": false, "from_scratch": false, "cl_tag": "wikitables/simclr/bert_100000_10_32_256_5e-05_sample_row4,sample_row4_tfidf_entity_column_0.05_0_last.pt", "dropout_prob": 0.5, "eval_test": true, "small_tag": "semi1", "data_path": "/data/zhihao/TU/", "pretrained_ckpt_path": "/data/zhihao/TU/Watchog/model/"}
gt-semtab22-dbpedia-all0/wikitables-simclr-bert_100000_10_32_256_5e-05_sample_row4,sample_row4_tfidf_entity_column_0.05_0_last.pt_bert-base-uncased-poolsemi1-max_colsv0.2-rand8-bsFalse-ml16-ne64-do1

  ckpt = torch.load(ckpt_path, map_location=device)


Namespace(augment_op='sample_row4,sample_row4', batch_size=32, data_path='/data/zhihao/TU/TURL/', fp16=True, gpus='0', lm='bert', logdir='/data/zhihao/TU/Watchog/model/', lr=5e-05, max_len=256, mode='simclr', model='Watchog', n_epochs=10, pretrain_data='wikitables', pretrained_model_path='', projector=768, run_id=0, sample_meth='tfidf_entity', save_model=10, single_column=False, size=100000, table_order='column', temperature=0.05)


In [13]:
device

device(type='cuda')



In [6]:
model.eval()
model.sampler.training

False

In [7]:
device

device(type='cuda')

In [3]:

max_num_col = 16
model = BertForMultiSelectionClassification(ckpt_hp, device=device, lm=ckpt['hp'].lm, version=args.pool_version, max_num_cols=max_num_col)






In [4]:
dataset_cls = GittablesColwiseMaxDataset
test_dataset = dataset_cls(cv=cv,
                            split="test",
                            tokenizer=tokenizer,
                            max_length=max_length,
                            device=device,
                            base_dirpath=os.path.join(args.data_path, "GitTables/semtab_gittables/2022"),
                            small_tag=args.small_tag,
                            max_num_col=max_num_col,
                            random_sample=False,
                            context_encoding_type="v1.2",
                            adaptive_max_length=False ,
                            return_table_embedding=True)
padder = collate_fn(tokenizer.pad_token_id)
test_dataloader = DataLoader(test_dataset,
                                batch_size=batch_size,
                                collate_fn=padder)   

semi1_cv_{}.csv 16
test 1085


In [9]:
batch.keys()

dict_keys(['data', 'label', 'token_type_ids', 'table_embedding'])

In [17]:
model.bert.device

device(type='cpu')

In [19]:
batch["table_embedding"].dtype

torch.float64

In [23]:
for batch in test_dataloader:
    # print(batch)
    break
model.train()
model = model.to(device)
cls_indexes = torch.nonzero(
    batch["data"].T == tokenizer.cls_token_id)
logits = model(batch["data"].T, cls_indexes=cls_indexes, 
                               token_type_ids= None,
                               column_embeddings=batch['table_embedding'].float(), return_gates=True)

In [24]:
gates = logits[-1].clone().detach()

In [7]:
gates.shape

torch.Size([16, 15])

In [18]:
last_gates = gates.clone().detach()

In [25]:
torch.equal(gates, last_gates)

False

In [28]:
((last_gates != gates).sum(1)/ gates.shape[1]).mean()

tensor(0.4917, device='cuda:0')

In [30]:
torch.norm(last_gates - gates, p=2)

tensor(10.8628, device='cuda:0')

In [14]:
gates

tensor([[1., 1., 0., 1., 0., 1., 0., 0., 1., 0., 0., 1., 1., 0., 1.],
        [0., 1., 0., 0., 0., 1., 1., 1., 1., 0., 0., 1., 1., 1., 0.],
        [0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 0., 1., 1., 1., 1.],
        [1., 0., 0., 1., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1.],
        [1., 0., 0., 0., 1., 1., 0., 1., 1., 0., 1., 1., 0., 0., 1.],
        [0., 0., 1., 0., 1., 1., 1., 1., 0., 1., 0., 1., 0., 0., 1.],
        [0., 1., 1., 1., 0., 0., 1., 1., 0., 0., 1., 0., 1., 0., 1.],
        [0., 0., 0., 1., 0., 1., 1., 1., 0., 1., 0., 1., 0., 1., 1.],
        [0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1.],
        [1., 1., 1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 1., 1., 0.],
        [0., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 0., 0., 1.],
        [0., 1., 0., 0., 0., 1., 1., 0., 1., 0., 1., 1., 1., 0., 1.],
        [1., 1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 1., 0., 0.],
        [0., 1., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0., 0., 1.],
        [0., 1., 1.,

In [6]:
batch["table_embedding"].device

device(type='cpu')

In [6]:
for batch in test_dataloader:
    break
input_ids = batch["data"].T


In [10]:
len(input_ids.split(max_length, dim=1))

16

In [5]:
for batch in test_dataloader:
    break
model = model.to(device)
cls_indexes = torch.nonzero(
                    batch["data"].T == tokenizer.cls_token_id)
logits = model(batch["data"].T, cls_indexes=cls_indexes)

In [3]:
model = BertForMultiOutputClassification(ckpt_hp, device=device, lm=ckpt['hp'].lm, version=args.pool_version)
max_num_col = 8
dataset_cls = GittablesColwiseDataset
test_dataset = dataset_cls(cv=cv,
                            split="test",
                            tokenizer=tokenizer,
                            max_length=max_length,
                            device=device,
                            base_dirpath=os.path.join(args.data_path, "GitTables/semtab_gittables/2022"),
                            small_tag=args.small_tag,
                            max_num_col=max_num_col,
                            random_sample=False,
                            context_encoding_type="v1.2",
                            adaptive_max_length=False   )
test_dataloader = DataLoader(test_dataset,
                                batch_size=batch_size,
                                collate_fn=padder)   
padder = collate_fn(tokenizer.pad_token_id)
# model.eval()
# model = model.to(device)
# labels_context = []
# embeddings_context = []
# with torch.no_grad():
#     for batch_idx, batch in enumerate(test_dataloader):
#         cls_indexes = torch.nonzero(
#                 batch["data"].T == tokenizer.cls_token_id)
#         logits, embs = model(batch["data"].T, cls_indexes=cls_indexes, get_enc=True)
#         label = batch["label"].T.cpu()
#         labels_context += label.numpy().tolist()
#         embeddings_context.append(embs.cpu())
# embeddings_context = torch.cat(embeddings_context, dim=0).numpy()
# labels_context = np.array(labels_context)



test
test 1085


In [4]:
max_length

64

In [4]:
model = BertForMultiOutputClassification(ckpt_hp, device=device, lm=ckpt['hp'].lm, version=args.pool_version)
max_num_col = 8
dataset_cls = GittablesColwiseMaxDataset
test_dataset = dataset_cls(cv=cv,
                            split="test",
                            tokenizer=tokenizer,
                            max_length=max_length,
                            device=device,
                            base_dirpath=os.path.join(args.data_path, "GitTables/semtab_gittables/2022"),
                            small_tag=args.small_tag,
                            max_num_col=10,
                            random_sample=False,
                            context_encoding_type="v1.2",
                            adaptive_max_length=False   )
test_dataloader = DataLoader(test_dataset,
                                batch_size=batch_size,
                                collate_fn=padder)   
padder = collate_fn(tokenizer.pad_token_id)





test
test 1085


In [6]:
temp_dataloader = DataLoader(test_dataset,
                                batch_size=batch_size,
                                collate_fn=padder) 

In [25]:
test_dataset[3]['data'].shape

torch.Size([512])

In [4]:
table_embedding = model.bert.embeddings(batch["data"].T)

NameError: name 'batch' is not defined

In [19]:
batch["data"].T.shape

torch.Size([16, 512])

In [None]:
B = table_embedding.shape[0]
N = table_embedding.shape[1]
M = 4
D = table_embedding.shape[-1]

In [5]:
table_embedding.shape

torch.Size([13, 512, 768])

In [21]:
column_indexes = torch.arange(8, device=device).unsqueeze(1).repeat(1, 64).flatten().unsqueeze(0).repeat(table_embedding.shape[0], 1)
print(column_indexes.shape)

torch.Size([13, 512])


In [111]:
from torch_scatter import scatter
# Broadcasting in the first and last dim.
col_embeddings = []
for i in range(B):
    out = scatter(table_embedding[i], column_indexes[i], dim=0, reduce="mean")
    col_embeddings.append(out)
col_embeddings = torch.stack(col_embeddings)
col_to_token = torch.eye(M, device=device, dtype=torch.float).unsqueeze(-1).repeat(1, 1, N//M).transpose(1, 2).flatten(0, 1).transpose(0, 1).unsqueeze(0).repeat(B, 1, 1)


In [110]:
col_to_token.dtype

torch.float64

In [102]:
col_to_token

tensor([[1., 1., 1.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 1., 1.]], device='cuda:0')

In [45]:
gates = torch.randint(0, 2, (B, M), device=device)

In [None]:
gates

In [23]:
col_embeddings.shape

torch.Size([13, 8, 768])

In [28]:
import torch.nn.functional as F
logits = torch.randn(B, N, device=device)
# Sample soft categorical using reparametrization trick:
# F.gumbel_softmax(logits, tau=1, hard=False)
# Sample hard categorical using "Straight-through" trick:
gates = F.gumbel_softmax(logits, tau=1, hard=True)

In [26]:
gates.shape

torch.Size([13, 512])

In [39]:
col_to_token = torch.eye(M, device=device, dtype=torch.float).unsqueeze(-1).repeat(1, 1, N//M).transpose(1, 2).flatten(0, 1).transpose(0, 1).unsqueeze(0).repeat(B, 1, 1)


torch.Size([13, 512])

In [154]:
sampler = SubsetOperator(k=2, tau=1.0, hard=True)
gates = sampler(logits) # (B, M)

In [155]:
gates

tensor([[1., 0., 0., 1.],
        [1., 0., 0., 1.],
        [1., 0., 1., 0.],
        [1., 1., 0., 0.],
        [1., 1., 0., 0.],
        [0., 1., 0., 1.],
        [0., 0., 1., 1.],
        [0., 0., 1., 1.],
        [1., 0., 0., 1.],
        [0., 0., 1., 1.],
        [0., 0., 1., 1.],
        [0., 1., 1., 0.],
        [0., 0., 1., 1.]], device='cuda:0')

In [164]:
num_tokens_per_col = L // N
import torch.nn.functional as F
from torch_scatter import scatter
token_col_indexes = torch.cat([torch.tensor([i]*num_tokens_per_col) for i in range(N)]).to(device)

In [166]:
column_embeddings = scatter(table_embedding, token_col_indexes, dim=1, reduce="mean")

In [5]:
B = args.batch_size
L = 512
D = 768
N = 10 # total number of columns
M = 8-1 # number of columns to sample
num_tokens_per_col = args.max_length
import torch.nn as nn
for batch in test_dataloader:
    break
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # torch.device('cpu')
model = model.to(device)
model.train()

BertForMultiOutputClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps

In [25]:
context_embeddings.shape

torch.Size([16, 9, 1536])

In [16]:
for batch in test_dataloader:
    break
res = batch["data"].T.split(num_tokens_per_col, dim=1)
table_embedding = []
for i in range(N):
    if i == 0:
        table_embedding.append(model.bert.embeddings(res[i].to(device), token_type_ids=torch.zeros_like(res[i].to(device))))
    else:
        table_embedding.append(model.bert.embeddings(res[i].to(device), token_type_ids=torch.ones_like(res[i].to(device))))
table_embedding = torch.cat(table_embedding, dim=1)

In [8]:

projector = nn.Linear(2*D, 1).to(device)
import torch.nn.functional as F
from torch_scatter import scatter
token_col_indexes = torch.cat([torch.tensor([i]*num_tokens_per_col) for i in range(N)]).to(device) # [0, 0, ..., 1, 1, ..., 2, 2, ..., N-1, N-1, ..., N-1]
col_to_token = torch.eye(N, device=device, dtype=torch.float) \
            .unsqueeze(-1).repeat(1, 1, num_tokens_per_col).transpose(1, 2).flatten(0, 1) \
            .transpose(0, 1)[1:, :].unsqueeze(0).repeat(B, 1, 1) # (B, N-1, L) [1, ,1 ,1 , ..., 1, 0, 0, ..., 0, 0, 0], ...,  [0, 0, 0, ..., 0, 1, 1, ..., 1]
            
column_embeddings = scatter(table_embedding, token_col_indexes, dim=1, reduce="mean") # (B, N, D)
target_column_embeddings, context_column_embeddings = column_embeddings[:, 0, :].unsqueeze(1),  column_embeddings[:, 1:, :]# (B, M, D)
target_column_embeddings = target_column_embeddings.repeat(1, context_column_embeddings.shape[1], 1) # (B, N, D)
context_embeddings = torch.cat([target_column_embeddings, context_column_embeddings], dim=-1) # (B, N, D)
logits = projector(context_embeddings).squeeze() # (B, N)
sampler = SubsetOperator(k=M, tau=1.0, hard=True)
gates = sampler(logits) # (B, M)
chosen_mask = torch.matmul(gates.unsqueeze(1), col_to_token).squeeze(1).unsqueeze(-1) # (B, N, 1)
chosen_embeddings = table_embedding * chosen_mask 
chosen_embeddings = chosen_embeddings[chosen_mask.detach().bool().expand_as(chosen_embeddings)].view(B, -1, D)
embeddings = torch.cat([table_embedding[:, :num_tokens_per_col, :], chosen_embeddings], dim=1)



In [10]:
encoder_outputs = model.bert.encoder(
    embeddings,
)
sequence_output = encoder_outputs[0]

In [18]:
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa

In [20]:
model.bert.get_head_mask

<bound method ModuleUtilsMixin.get_head_mask of BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
   

In [13]:

labels = torch.randint(0, 3, (B, ), device=device)
linear = nn.Linear(D, 3).to(device)
res = linear(sequence_output.mean(1))
loss = F.cross_entropy(res, labels)
loss.backward()

In [14]:
projector.weight.grad

tensor([[ 5.5931e-11, -1.3787e-10, -4.0102e-10,  ...,  8.3100e-04,
          2.9412e-02, -1.1279e-03]], device='cuda:0')

In [30]:
table_embedding.shape

torch.Size([16, 640, 768])

In [31]:
chosen_embeddings.shape

torch.Size([16, 512, 768])

In [9]:
embeddings.shape

torch.Size([16, 512, 768])

In [20]:
table_embedding.shape

torch.Size([16, 640, 768])

In [17]:
column_embeddings.shape

torch.Size([16, 512, 768])

In [9]:
res[0]

tensor([[  101,  2110,  1035,  ...,     0,     0,     0],
        [  101,  2632, 22123,  ...,  2226,  9706,  6711],
        [  101,  9349,  1025,  ...,  9349,     0,     0],
        ...,
        [  101, 10283,  1025,  ..., 22304,  1025,  5367],
        [  101, 10283,  1025,  ...,  5215,  1025, 16598],
        [  101,  1016,  1012,  ...,  1017,  1012,  1014]], device='cuda:0')

In [10]:
res[1]

tensor([[  102,  2632, 22123,  ...,  2226,  9706,  6711],
        [  102,  2110,  1035,  ...,     0,     0,     0],
        [  102, 17350, 21926,  ..., 21472, 11387,  1025],
        ...,
        [  102,  1014,  1025,  ...,  2382,  1025,  2861],
        [  102,  1014,  1025,  ...,  2382,  1025,  2861],
        [  102,  1014,  1025,  ...,     0,     0,     0]], device='cuda:0')

In [36]:
col_to_token = torch.eye(N, device=device, dtype=torch.float) \
            .unsqueeze(-1).repeat(1, 1, num_tokens_per_col).transpose(1, 2).flatten(0, 1) \
            .transpose(0, 1)[1:, :].unsqueeze(0).repeat(B, 1, 1) # (B, N, L) [1, ,1 ,1 , ..., 1, 0, 0, ..., 0, 0, 0], ...,  [0, 0, 0, ..., 0, 1, 1, ..., 1]
            

In [27]:
from transformers import BertTokenizer

# Load the BERT tokenizer

# Example sentences


# Encode the sentences with padding
encoded = tokenizer("how are you " + "[PAD]" * 4, padding=True, truncation=True, return_tensors="pt", add_special_tokens=False)


print(encoded["input_ids"])
# Check the padding token and its ID
print("Padding token:", tokenizer.pad_token)
print("Padding token ID:", tokenizer.pad_token_id)


tensor([[2129, 2024, 2017,    0,    0,    0,    0]])
Padding token: [PAD]
Padding token ID: 0


In [28]:
"how are you " + "[PAD]" * 4

'how are you [PAD][PAD][PAD][PAD]'

In [42]:
chosen_embedding.shape

torch.Size([16, 256, 768])

In [31]:
col_to_token.shape

torch.Size([16, 8, 448])

In [28]:
col_to_token.shape

torch.Size([16, 8, 512])

In [27]:
torch.equal(target_column_embeddings[0, 1, :], target_column_embeddings[0, 2, :])

True

In [20]:
context_embeddings.shape

torch.Size([16, 7, 1536])

In [15]:
target_column_embeddings.shape

torch.Size([16, 1, 768])

In [16]:
context_column_embeddings.shape

torch.Size([16, 7, 768])

In [185]:
labels = torch.randint(0, 2, (B, 1), device=device)
linear = nn.Linear(D, 1).to(device)
res = linear(chosen_embedding.mean(1))

In [188]:
res.shape

torch.Size([13, 2])

In [183]:
chosen_embedding.shape

torch.Size([13, 256, 768])

In [184]:
chosen_embedding.mean(1).shape

torch.Size([13, 768])

In [7]:
projector.weight.grad

tensor([[ 9.5579e-11, -5.7231e-11,  2.1875e-11,  ...,  1.3810e-03,
         -3.9660e-04,  4.0686e-04]], device='cuda:0')

In [179]:
chosen_embedding.shape

torch.Size([13, 256, 768])

In [162]:
table_embedding.scatter(1, column_indexes, table_embedding)

64

In [160]:
col_to_token[0][0]

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 

In [159]:
col_to_token[0].shape

torch.Size([8, 512])

In [6]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, TensorDataset, ConcatDataset, random_split
from torch.utils.data.sampler import SubsetRandomSampler

import random
import time
from pathlib import Path

import numpy as np

import matplotlib
import matplotlib.pyplot as plt
EPSILON = np.finfo(np.float32).tiny

class SubsetOperator(torch.nn.Module):
    def __init__(self, k, tau=1.0, hard=False):
        super(SubsetOperator, self).__init__()
        self.k = k
        self.hard = hard
        self.tau = tau

    def forward(self, scores):
        assert len(scores.shape) == 2
        m = torch.distributions.gumbel.Gumbel(torch.zeros_like(scores), torch.ones_like(scores))
        g = m.sample()
        scores = scores + g

        # continuous top k
        khot = torch.zeros_like(scores)
        onehot_approx = torch.zeros_like(scores)
        for i in range(self.k):
            khot_mask = torch.max(1.0 - onehot_approx, torch.tensor([EPSILON]).to(scores.device))
            scores = scores + torch.log(khot_mask)
            onehot_approx = torch.nn.functional.softmax(scores / self.tau, dim=1)
            khot = khot + onehot_approx

        if self.hard:
            # straight through
            khot_hard = torch.zeros_like(khot)
            val, ind = torch.topk(khot, self.k, dim=1)
            khot_hard = khot_hard.scatter_(1, ind, 1)
            res = khot_hard - khot.detach() + khot
        else:
            res = khot

        return res


In [163]:
sampler = SubsetOperator(k=2, tau=1.0, hard=False)

x = torch.tensor([[1.,2.,3.,4.]])
y = sampler(x)
print(y, y.sum())

tensor([[0.1325, 0.1989, 0.3476, 1.3210]]) tensor(2.)
