In [107]:
import torch
import json
import numpy as np
from transformers import BertTokenizer, BertForSequenceClassification, BertModel
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from utils import get_sorted_tweets, get_target_words


In [67]:
# Load BERT model and tokenizer
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name).cuda()
tweets = get_sorted_tweets()
target_words = get_target_words()

In [58]:
target_words

['frisk',
 'impostor',
 'containment',
 'gala',
 'recount',
 'lotte',
 'pogrom',
 'parasol',
 'pyre',
 'milker',
 'launchpad',
 'vanguard',
 'airdrop',
 'ventilator',
 'villager',
 'primo',
 'delta',
 'epicenter',
 'burnham',
 'bullpen',
 'virus',
 'turnip',
 'monet',
 'mask',
 'crt',
 'ido',
 'unlabeled',
 'teargas',
 'gaza',
 'folklore',
 'entanglement',
 'paternity',
 'bunker',
 'moxie']

In [91]:
def generate_vector_from_context(word, text):
    tok_w = tokenizer(word, return_tensors='pt', add_special_tokens=False)
    tok = int(tok_w['input_ids'].flatten()[0])
    len_tok = len(tok_w['input_ids'].flatten())
    tok_t = tokenizer(text, return_tensors='pt', padding='max_length')
    ids = tok_t['input_ids'].flatten().tolist()
    if tok in ids:
        idx = ids.index(tok)
    else:
        raise ValueError(f'{tok} from {tok_w} not in list {ids}. \n text: {text} word {word} \n tokenizer decode: {tokenizer.decode(ids)}')
    for item in tok_t:
        tok_t[item] = tok_t[item].to('cuda')
    vec = model(**tok_t)['last_hidden_state'].squeeze(0)[idx:idx+len_tok].cpu().detach().numpy()
    vec = np.average(vec, axis=0)
    return vec

In [93]:
def avg_vector_by_year(year):
    target_word_vectors = {wrd: [] for wrd in target_words}
    data = tweets[year]
    for t in data:
        word = t['word']
        text = t['text']
        try:
            vec = generate_vector_from_context(word, text)
        except ValueError:
            pass
        target_word_vectors[word].append(vec)
    for wrd in target_words:
        vecs = np.array(target_word_vectors[wrd])
        target_word_vectors[wrd] = np.average(vecs, axis=0)

    return target_word_vectors



In [94]:
vecs = avg_vector_by_year('2021')

  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)


In [None]:
def load_data(data_path, labels_path):
    # Load tweet instances
    with open(data_path, 'r', encoding='utf-8') as file:
        data_instances = [json.loads(line) for line in file]

    # Load labels
    with open(labels_path, 'r', encoding='utf-8') as file:
        labels = dict(line.strip().split('\t') for line in file)

    return data_instances, labels

In [100]:
# Set paths
train_data_path = 'data/train.data.jl'
train_labels_path = 'data/train.labels.tsv'

# Load data
data_instances, labels = load_data(train_data_path, train_labels_path)

pairs = {item['id']: [item['tweet1']['text'], item['tweet2']['text'], item['word']] for item in data_instances}

In [127]:
def find_acc(threshold):
    correct = 0
    count = 0
    for key in pairs:
        label = labels[key]
        t1, t2, word = pairs[key]
        try:
            vec1 = generate_vector_from_context(word, t1)
            vec2 = generate_vector_from_context(word, t2)
        except ValueError:
            continue
        res = float(cosine_similarity([vec1], [vec2]).flatten()[0])
        res = 1 if res > threshold else 0
        if res == int(label):
            correct += 1
        count += 1

    return correct / count


In [128]:
for threshold in range(10):
    i = threshold / 10
    print(find_acc(i))
# best threshold 0.7

0.45646067415730335
0.45646067415730335
0.45646067415730335
0.45646067415730335
0.46348314606741575
0.4768258426966292
0.5344101123595506
0.6587078651685393
0.613061797752809
0.5484550561797753
