In [6]:
import collections
import six 

def convert_tokens_to_ids(vocab, tokens):
    """Map every token in `tokens` to its id in `vocab`.

    A token that is not a key of `vocab` is replaced by the id stored under
    the '[UNK]' key. That key is only looked up when such a token is actually
    encountered, so a vocab without '[UNK]' works as long as every token is
    known.
    """
    return [
        vocab[tok] if tok in vocab.keys() else vocab['[UNK]']
        for tok in tokens
    ]
def convert_to_unicode(text):
    """Return `text` as a unicode string, decoding byte strings as utf-8.

    Bytes that cannot be decoded are dropped ("ignore" error handler).
    Raises ValueError for any other input type, or when the interpreter is
    reported as neither Python 2 nor Python 3.
    """
    if not (six.PY3 or six.PY2):
        raise ValueError("Not running on Python2 or Python 3?")

    # Pick the text / byte-string types of the running interpreter.
    if six.PY3:
        text_type, byte_type = str, bytes
    else:
        text_type, byte_type = unicode, str

    if isinstance(text, text_type):
        return text
    if isinstance(text, byte_type):
        return text.decode("utf-8", "ignore")
    raise ValueError("Unsupported string type: %s" % (type(text)))
    
def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary.

    The file is read as UTF-8 text with one token per line. Surrounding
    whitespace (including the trailing newline) is stripped from each line,
    and the token is mapped to its 0-based line index. If the same token
    occurs on several lines, the last index wins.

    Args:
        vocab_file: path of the vocabulary file.

    Returns:
        collections.OrderedDict mapping token (str) -> line index (int).
    """
    vocab = collections.OrderedDict()
    # The encoding is given explicitly so the result does not depend on the
    # platform's locale default. Text-mode reads already yield `str`, so no
    # further unicode conversion is required.
    with open(vocab_file, "r", encoding="utf-8") as reader:
        for index, line in enumerate(reader):
            vocab[line.strip()] = index
    return vocab
class WordLevelTokenizer(object):
    """Runs end-to-end tokenization by splitting text on a delimiter.

    Tokens are looked up in a vocabulary loaded from `vocab_file`; the ids of
    the special tokens (pad, bos, eos, unk, mask) are taken from `config`.
    """
    def __init__(self, vocab_file, config, delimiter=" ", max_seq_len=128):
        """Load the vocabulary and record special-token ids.

        Args:
            vocab_file: path passed to `load_vocab`.
            config: object exposing `pad_token_id`, `bos_token_id`,
                `eos_token_id`, `unk_token_id` and `mask_token_id`.
            delimiter: string used to split text and to join decoded tokens.
            max_seq_len: maximum length of an encoded sequence, including
                the BOS and EOS ids added by `__call__`.
        """
        self.vocab = load_vocab(vocab_file)
        # id -> token lookup used by `batch_decode`.
        self.vocab_reverse = collections.OrderedDict(
            (token_id, token) for token, token_id in self.vocab.items()
        )
        self.pad_token_id = config.pad_token_id
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        self.unk_token_id = config.unk_token_id
        self.mask_token_id = config.mask_token_id
        self.special_token_ids = {
            config.pad_token_id, config.bos_token_id, config.eos_token_id,
            config.unk_token_id, config.mask_token_id,
        }

        self.max_seq_len = max_seq_len
        self.delimiter = delimiter

    def tokenize(self, text):
        """Split `text` on the delimiter and return the list of pieces.

        Consecutive delimiters produce empty-string tokens, exactly as
        `str.split(delimiter)` does.
        """
        return text.split(self.delimiter)

    def convert_tokens_to_ids(self, tokens):
        """Map tokens to ids via the module-level `convert_tokens_to_ids`."""
        return convert_tokens_to_ids(self.vocab, tokens)

    def __call__(self, text):
        """Encode `text` as `[bos] + ids + [eos]`.

        The token ids are truncated to `max_seq_len - 2` so the result,
        including BOS and EOS, never exceeds `max_seq_len`.
        """
        original = self.convert_tokens_to_ids(self.tokenize(text))
        original = original[:(self.max_seq_len-2)]
        return [self.bos_token_id] + original + [self.eos_token_id]

    def batch_decode(self, pred_labels, skip_special_tokens=True):
        """Decode a batch of id sequences into delimiter-joined strings.

        Args:
            pred_labels: iterable of sequences; each element must provide
                `.tolist()` returning a flat list of ids.
            skip_special_tokens: when True (default), ids contained in
                `special_token_ids` are omitted from the output. When False
                they are decoded through the vocabulary like any other id.

        Returns:
            A list with one decoded string per input sequence. Decoding of a
            sequence always stops at the first EOS id, which is never
            included in the output regardless of `skip_special_tokens`.
        """
        decode_labels_batch = []
        for labels in pred_labels:
            decode_labels = []
            for label_id in labels.tolist():
                if label_id == self.eos_token_id:
                    break
                if skip_special_tokens and label_id in self.special_token_ids:
                    continue
                decode_labels.append(self.vocab_reverse[label_id])
            decode_labels_batch.append(self.delimiter.join(decode_labels))
        return decode_labels_batch

In [34]:
import os
from transformers import AutoConfig
import torch
from torch import Tensor 
from jaxtyping import Shaped
from typing import Callable

In [15]:

# Directory holding the encoder/decoder configs and the source/target vocabularies.
# NOTE(review): machine-specific absolute path; point this at your local
# checkout before running.
MODEL_DIR = "/Users/georgenakayama/Documents/Academics/CS/CS224U/ReCOGS_224U/model"

config_encoder = AutoConfig.from_pretrained(
    os.path.join(MODEL_DIR, "encoder_config.json")
)
config_decoder = AutoConfig.from_pretrained(
    os.path.join(MODEL_DIR, "decoder_config.json")
)

src_tokenizer = WordLevelTokenizer(
    os.path.join(MODEL_DIR, "src_vocab.txt"),
    config_encoder,
    max_seq_len=512
)
tgt_tokenizer = WordLevelTokenizer(
    os.path.join(MODEL_DIR, "tgt_vocab.txt"),
    config_decoder,
    max_seq_len=512
)

# The tokenizer returns [bos, <token id>, eos], so index 1 is the id of the
# single token that was encoded.
AND_token_id = tgt_tokenizer("AND")[1]
SEMICOLON_token_id = tgt_tokenizer(";")[1]
# Print the variables defined above; the previous names (AND_token,
# SEMICOLON_token) do not exist on a fresh kernel.
print(AND_token_id, SEMICOLON_token_id)


68 67


In [41]:
def chamferToken(loss_fn: Callable , 
                 a: Shaped[Tensor, "bs nt nl vs"], 
                 b: Shaped[Tensor, "bs nt nl"], 
                 mask_a: Shaped[Tensor, "bs nt nl"], 
                 mask_b: Shaped[Tensor, "bs nt nl"],
                 reduce: bool=True):
    """
    Chamfer distance between two sets of tokens.

    For every token of `a` the closest token of `b` is found (and vice
    versa), where the distance between two tokens is the per-position
    `loss_fn` value averaged over the unmasked positions of the token being
    matched against.

    loss_fn: called as loss_fn(x, y) with x of shape [bs*nt, vs, nt, nl] and
        y of shape [bs*nt, nt, nl]; its result must be reshapeable to
        [bs, nt, nt, nl]. It should NOT average its output and should NOT
        ignore indices itself; ignored positions are given through the masks.
    a: [bs, nt, nl, vs] (nt = number of tokens, nl = max token length,
        vs = vocab size)
    b: [bs, nt, nl]
    mask_a, mask_b: [bs, nt, nl] boolean masks. True marks a position that is
        IGNORED; a token whose nl positions are all True is treated as absent.
    reduce: if True, each direction is averaged over the non-absent tokens
        and the mean over the batch is returned (a scalar). If False, the
        [bs, nt] tensor a_loss + b_loss is returned, with absent tokens
        contributing 0.

    A sample in which every token of `a` (or of `b`) is absent contributes 0
    for that direction's average instead of NaN.
    """
    bs, nt, nl, vs = a.shape
    # A token is absent when all nl of its positions are masked.
    token_mask_a, token_mask_b = mask_a.sum(-1) == nl, mask_b.sum(-1) == nl
    # Number of unmasked positions per token; fully masked tokens get nl so the
    # division below never divides by zero (their numerator is 0 anyway).
    num_unmasked_a, num_unmasked_b = nl - mask_a.sum(-1), nl - mask_b.sum(-1) 
    num_unmasked_a = torch.where(num_unmasked_a == 0, nl, num_unmasked_a)
    num_unmasked_b = torch.where(num_unmasked_b == 0, nl, num_unmasked_b)
    # Number of present tokens per sample, clamped to >= 1: when a sample has
    # no present token its summed loss is 0, so 0 / 1 = 0 rather than 0 / 0 = NaN.
    num_tokens_a = (nt - token_mask_a.sum(-1)).clamp(min=1)
    num_tokens_b = (nt - token_mask_b.sum(-1)).clamp(min=1)

    # Direction a -> b: for each token i of a, the min distance over tokens j of b.
    _a = a.unsqueeze(2).repeat_interleave(nt, dim=2).permute(0, 1, 4, 2, 3).reshape(bs*nt, vs, nt, nl) # [bs*nt, vs, nt, nl]; row (i, j) holds a[:, i]
    _b = b.unsqueeze(1).repeat_interleave(nt, dim=1).reshape(bs*nt, nt, nl) # [bs*nt, nt, nl]; row (i, j) holds b[:, j]
    a_loss = loss_fn(_a, _b).reshape(bs, nt, nt, nl) # [bs, nt, nt, nl]
    # Average over the unmasked positions of b's token j -> [bs, nt, nt].
    a_loss = a_loss.masked_fill(mask_b.reshape(bs, 1, nt, nl), value=0).sum(-1) / num_unmasked_b.unsqueeze(1)
    # Absent tokens of b must never be selected by the min.
    a_loss = a_loss.masked_fill(token_mask_b.unsqueeze(1), value=1e9)
    a_loss = a_loss.min(-1)[0] # [bs, nt]
    a_loss = a_loss.masked_fill(token_mask_a, value=0)
    if reduce:
        a_loss = a_loss.sum(-1) / num_tokens_a
    
    # Direction b -> a: for each token i of b, the min distance over tokens j of a.
    _a = a.unsqueeze(1).repeat_interleave(nt, dim=1).permute(0, 1, 4, 2, 3).reshape(bs*nt, vs, nt, nl) # [bs*nt, vs, nt, nl]; row (i, j) holds a[:, j]
    _b = b.unsqueeze(2).repeat_interleave(nt, dim=2).reshape(bs*nt, nt, nl) # [bs*nt, nt, nl]; row (i, j) holds b[:, i]
    b_loss = loss_fn(_a, _b).reshape(bs, nt, nt, nl) # [bs, nt, nt, nl]
    # Average over the unmasked positions of a's token j -> [bs, nt, nt].
    b_loss = b_loss.masked_fill(mask_a.reshape(bs, 1, nt, nl), value=0).sum(-1) / num_unmasked_a.unsqueeze(1)
    # Absent tokens of a must never be selected by the min.
    b_loss = b_loss.masked_fill(token_mask_a.unsqueeze(1), value=1e9)
    b_loss = b_loss.min(-1)[0] # [bs, nt]
    b_loss = b_loss.masked_fill(token_mask_b, value=0)
    if reduce:
        b_loss = b_loss.sum(-1) / num_tokens_b
        
    loss = a_loss + b_loss 
    if reduce:
        return loss.mean()
    return loss
    
# Smoke test for chamferToken: `a` is the one-hot encoding of `b`, so every
# token has an exact match on the other side and the printed loss is 0.
bs, nt, nl, vs = 2, 3, 4, 5
ids = torch.randint(0, vs, size=(bs * nt * nl,))
b = ids.reshape(bs, nt, nl)
a = torch.eye(vs)[ids].reshape(bs, nt, nl, vs)


def loss_func(x, y):
    """Squared difference between the argmax of `x` over dim 1 and `y`."""
    return (x.argmax(1) - y) ** 2


# All-False masks: no position is ignored.
mask_a = torch.zeros(b.shape, dtype=torch.bool)
mask_b = torch.zeros(b.shape, dtype=torch.bool)
loss = chamferToken(loss_func, a, b, mask_a, mask_b)
print(loss)



tensor(0.)


In [None]:
# Inspect the token-removal test split: for every example, tokenize source and
# target and locate the AND-separated conjuncts that follow the last ';'.
# NOTE(review): machine-specific absolute path; adjust before running.
path = "/Users/georgenakayama/Documents/Academics/CS/CS224U/ReCOGS_224U/cogs_token_removal/test_remove_x_.tsv"
items = []
eval_cat = []
# Context manager so the file handle is closed once the loop finishes (or
# raises); lines are streamed instead of being read into memory up front.
with open(path, "r") as tsv_file:
    for line in tsv_file:
        # Each row is expected to hold exactly three tab-separated fields.
        text, sparse, cat = line.split("\t")
        src_input_ids = src_tokenizer(text)
        tgt_input_ids = tgt_tokenizer(sparse)
        try:
            # Index of the LAST ';' in the target ids.
            semicolon_id = len(tgt_input_ids) - tgt_input_ids[::-1].index(SEMICOLON_token_id) - 1
        except ValueError:
            # No ';' present: the slice below then starts at index 0.
            semicolon_id = -1
        print(semicolon_id)
        conj_sparse = torch.tensor(tgt_input_ids[semicolon_id + 1:])
        AND_mask = conj_sparse == AND_token_id
        # Positions of the AND tokens inside conj_sparse (ascending).
        ids = torch.nonzero(AND_mask, as_tuple=True)[0]
        # Boundaries: start (0), every AND position, and the last index.
        ids = torch.cat([torch.zeros(1), ids, torch.ones(1) * conj_sparse.shape[0] - 1])
        # Distance between consecutive boundaries.
        conj_lens = ids[1:] - ids[:-1]
        print(ids)
        print(conj_lens)
        conjs = sparse.split(";")[-1].split("AND")
        print(src_input_ids, tgt_input_ids)
        print(text)
        print(sparse)
        print(conjs)
        print(cat.strip())
        print("="*80)