In [2]:
import os
import torch
import torch.nn as nn
import json
import os
import sys
import logging
import random
import datetime
import numpy as np
import math
import itertools
from tqdm import tqdm, trange
from torch.nn import functional as F
from sentence_transformers import SentenceTransformer, models, losses, util
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Connect to GPUs
os.environ["CUDA_VISIBLE_DEVICES"] = '0' # Change to GPUs being used
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [4]:
# Capture cell output in case session disconnects
today = datetime.date.today()
so = open(str(today) + ".log", 'w', 10) # Replace with current date
sys.stdout.echo = so
get_ipython().log.handlers[0].stream = so
get_ipython().log.setLevel(logging.INFO)

In [5]:
# Check current directory
cwd = os.getcwd()
files = os.listdir(cwd)
print("Files in %r: %s" % (cwd, files))

Files in '/home/jchauhan/Legal Ideas': ['07_wiki_validation-cross-validation.ipynb', 'more_words.ipynb', 'separate_ideas.ipynb', '2023-03-02.log', 'data-wiki', '.ipynb_checkpoints', '20230213.log', '2023-03-01.log', '20230212.log', 'clean_si_results_2', 'experiments_1', 'clean_si_results', 'repeatable.ipynb', 'clean_separate_ideas.ipynb', 'experiment_2.ipynb']


In [6]:
def get_vocab(fp):
    """
    Returns mapping of word to frequency from vocab file
    
    :param fp: Filepath to vocab
    :return: Dictionary, word (annotated with sense) -> frequency
    """
    vocab = {}
    with open(fp) as f: # 'Legal Ideas/data-wiki/vocab.txt'
        lines = f.readlines()
        for line in lines:
            line = line.split()
            vocab[line[0]] = int(line[1])
    
    return vocab

def get_word_senses(vocab):
    """
    Returns mapping of word to sense(s)
    
    :param vocab: Dictionary, word (annotated with sense) -> frequency
    :return: Dictionary, word -> sense(s)
    """
    terms = [k for k in vocab if k[0] == '@' and '@' in k[1:]]
    words = defaultdict(set)
    for term in terms:
        freq = vocab[term]
        word = term[1:term[1:].index('@')+1]
        category = term[term[1:].index('@') + 3:]
        if freq < 1000 or len(category) == 0: # Ignore words w/o annotated sense or over 1000 sentences 
            continue
        words[word].add(category)
    
    return words

In [7]:
def get_ts_words(words):
    """
    Return mapping of word to senses for all words that have two senses
    
    :param words: Dictionary, word -> sense(s)
    :return: Dictionary, word -> sense(s), for every word with two senses
    """
    candidates = {}
    for word in words:
        if len(words[word]) == 2:
            candidates[word] = list(words[word])
            random.shuffle(candidates[word])
    
    return candidates

In [8]:
def _count_generator(reader):
    """
    See https://stackoverflow.com/questions/75480052/can-someone-please-explain-how-this-function-works
    """
    b = reader(1024 * 1024)
    while b:
        yield b
        b = reader(1024 * 1024)

def count_lines(fp):
    """
    :param fp: Filepath
    :return: Int, number of \n characters in file
    """
    with open(fp, 'rb') as fp:
        c_generator = _count_generator(fp.raw.read)
        count = sum(buffer.count(b'\n') for buffer in c_generator)
        return count

In [9]:
def get_index(vocab, index_path, corpus_path):
    """
    Builds and saves (as JSON file) mapping from word to indices of sentences containing word. 
    Retrieves the index if it already exists.
    
    :param vocab: Dictionary, word (annotated with sense) -> frequency
    :param index_path: Filepath the index should be saved to or retrieved from
    :param corpus_path: Filepath to corpus
    :return: Dictionary, word (annotated with sense) -> list of sentence_indices
    """
    index = defaultdict(list)

    if os.path.exists(index_path): # 
        print("Starting index load")
        with open(index_path, 'r') as f:
            index = json.load(f)
    else:
        corpus_size = count_lines(corpus_path)
        
        print("Building index")
        with open(corpus_path) as f: # 
            for i, line in tqdm(enumerate(f), total=corpus_size):
                line = line.split()
                for part in line:
                    if part in vocab:
                        index[part].append(i)     
        
        print("Starting index dump")
        with open(index_path, 'w') as f:
            json.dump(index, f)
            
    return index

In [10]:
def load_sentences_by_indices(indices, corpus_path): 
    """
    Returns a mapping of index to sentence at index in corpus 
    
    :param indices: List of ints, indices of sentences to gather
    :param corpus_path: Filepath containing sentences
    :return: Dictionary, index -> sentence in corpus at index 
    """
    indices_set = set(indices)
    indices_map = {}
    
    corpus_size = count_lines(corpus_path)
    with open(corpus_path) as f:
        for i, line in tqdm(enumerate(f), total=corpus_size):
            if i in indices_set:
                indices_map[i] = line.strip()
    return indices_map

def get_indexed_sentences(candidates, index, corpus_path):
    """
    Returns a mapping of index to sentence at index for every sentence containing a key from candidates
    
    :param candidates: Dictionary, word -> sense(s)
    :return: List of ints, indices for all sentences containing words in candidates
    """
    indices = set()
    for k in candidates:
        for v in candidates[k]:
            word = '@{}@-{}'.format(k, v)
            indices = indices | set(index[word])

    sentences_map = load_sentences_by_indices(indices, corpus_path)
    return sentences_map

In [11]:
def sentences_embeddings(model, sentences):
    """
    Returns embeddings for sentences
    
    :param model: Sentence encoder
    :param sentences: List of strings, sentences to encode
    :return: np.array, embeddings for sentences
    """
    sentences_revealed = []
    for sentence in sentences:
        parts = sentence.split()
        parts_revealed = []
        for part in parts:
            if part[0] == '@' and '@' in part[1:]: # Removes sense annotation
                part = part[1:part[1:].index('@') + 1]
            parts_revealed.append(part)
        sentences_revealed.append(' '.join(parts_revealed))
    return model.encode(sentences_revealed)

In [12]:
def init_weights(m):
    """
    Initializes Linear layers of m using Uniform Xavier initialization
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

class MLP1(nn.Module):
    """
    1 hidden layer, fully connected NN
    """
    def __init__(self, input_size, hidden_size, pdrop=.5):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.GELU(),
            nn.Dropout(pdrop),
            nn.Linear(hidden_size, 2),
        ).to(device)
        
        self.mlp.apply(init_weights)  

    def forward(self, x):
        output = self.mlp(x)
        return output

In [13]:
 class MLP2(nn.Module):
    """
    2 hidden layers, fully connected NN
    """
    def __init__(self, input_size, hidden_size, pdrop=.5):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Dropout(pdrop),
            nn.Linear(hidden_size, 2),
        ).to(device)
        
        self.mlp.apply(init_weights)
        
    def forward(self, x):
        output = self.mlp(x)
        return output

In [14]:
class SentenceDataset(Dataset):
    """
    Dataset for sentence embeddings
    """
    def __init__(self,x,y):
        self.x = torch.tensor(x,dtype=torch.float32).to(device)
        self.y = torch.tensor(y,dtype=torch.float32).to(device)
        self.length = self.x.shape[0]

    def __getitem__(self,idx):
        return self.x[idx],self.y[idx]
    
    def __len__(self):
        return self.length

In [15]:
class NpEncoder(json.JSONEncoder):
    """
    JSON Encoder. Handles ints, floats and numpy arrays
    """
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()

        return super(NpEncoder, self).default(obj)

In [16]:
def find_separation(model, epochs, optimizer, embeddings, res_folder_path):
    """
    For each word in embeddings, computes the separation between the word's two senses and saves the results in json files
    
    :param model: PyTorch model, gets trained to separate embeddings for each word's senses
    :param epochs: Int, training epochs
    :param optimizer: PyTorch optimizer
    :param embeddings: Dictionary, word --> { sense 1: [embeddings], sense 2: [embeddings] }
    :param res_folder_path: Filepath of folder to save results in (folder must already exist)
    :return: None, saves output in files
    """

    for word in list(embeddings.keys()):
        word_embds = embeddings[word]
        [a,b] = list(word_embds.keys())
        total_sizes = [len(word_embds[a]), len(word_embds[b])]

        x = np.concatenate((np.array(word_embds[a]), np.array(word_embds[b]))) # Make sure 0 is right axis
        y = np.concatenate((np.zeros(len(word_embds[a])), np.ones(len(word_embds[b]))), axis=None)

        dataset = SentenceDataset(x,y)

        c_weights = [len(word_embds[b])/x.shape[0], len(word_embds[a])/x.shape[0]]
        loss_fn = nn.CrossEntropyLoss(weight = torch.Tensor(c_weights).to(device))  

        train_size = int(0.875 * len(dataset))
        test_size = len(dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

        batch_size = 256
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

        total_pos, total_neg, total, best_b_accuracy = 0, 0, 0, 0
        best_i = -1

        for i in tqdm(range(epochs), position=0, leave=True):   
            # Training
            for x_train, y_train in train_dataloader: 
                y_train = y_train.type(torch.LongTensor)
                x_train, y_train = x_train.to(device), y_train.to(device)

                optimizer.zero_grad()  
                pred = model(x_train)   
                train_loss = loss_fn(pred, y_train)  
                train_loss.backward()  
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10, norm_type=2.0) # gradient clipping

                optimizer.step()

            # Evaluation
            if i%5 == 0:
                model.eval()
                r_tp, r_fp, r_fn, r_tn = 0, 0, 0, 0

                with torch.no_grad(): 
                    for x_test, y_test in test_dataloader:
                        x_test, y_test = x_test.to(device), y_test.to(device)
                        y_test = y_test.to(torch.float32) 
                        predicted_outputs = model(x_test) 
                        _, predicted = torch.max(predicted_outputs, 1)             

                        pos_mask, neg_mask = y_test == 1, y_test == 0
                        pos_indices, neg_indices = torch.nonzero(pos_mask), torch.nonzero(neg_mask)  
                        pos_y, neg_y =  y_test[pos_indices], y_test[neg_indices]
                        pos_y_pred, neg_y_pred = predicted[pos_indices], predicted[neg_indices]

                        r_tp += (pos_y == pos_y_pred).sum().item()
                        r_fn += (pos_y != pos_y_pred).sum().item()
                        r_fp += (neg_y != neg_y_pred).sum().item()
                        r_tn += (neg_y == neg_y_pred).sum().item()

                        if i == 0:
                            total_pos += pos_y.size(0)
                            total_neg += neg_y.size(0)
                            test_sizes = [total_neg, total_pos]

                    tpr, tnr = r_tp / total_pos, r_tn / total_neg
                    b_accuracy = (tpr + tnr) / 2
                    precision = r_tp / (r_tp + r_fp)
                    recall = r_tp / (r_tp + r_fn)

                    f1 = 2 * (recall * precision) / (recall + precision) if recall != 0 or precision != 0 else 0

                    # Performance is evaluated as balanced accuracy
                    if b_accuracy > best_b_accuracy:
                        best_f1, best_precision, best_recall, best_i = f1, precision, recall, i
                        best_tpr, best_tnr, best_b_accuracy = tpr, tnr, b_accuracy         

                model.train()

        # Save best performance
        with open(res_folder_path + '/{}.json'.format(word), 'w') as f:
            json.dump([best_f1, best_precision, best_recall, best_b_accuracy, best_tnr, best_tpr, total_sizes, test_sizes, best_i], f, cls=NpEncoder)

In [17]:
# Get words with two senses
vocab = get_vocab("data-wiki/vocab.txt")
words = get_word_senses(vocab)
ts_words = get_ts_words(words)

list(ts_words.keys())

['united_states', 'france', 'american', 'germany', 'canada', 'england', 'brazil', 'italy', 'new_zealand', 'mexico', 'south_africa', 'german', 'scotland', 'poland', 'netherlands', 'sweden', 'ireland', 'italian', 'english', 'switzerland', 'new_south_wales', 'denmark', 'melbourne', 'wales', 'portugal', 'dutch', 'swedish', 'greek', 'northern_ireland', 'milan', 'liverpool', 'scottish', 'barcelona', 'norwegian', 'kent', 'great_britain', 'uruguay', 'danish', 'essex', 'finnish', 'portuguese', 'oxford', 'hindu', 'mexican', 'surrey', 'russian', 'cambridge', 'hampshire', 'somerset', 'leeds', 'japanese', 'hungarian', 'persian', 'chinese', 'chelsea', 'derbyshire', 'richmond', 'southampton', 'celtic', 'republic_of_ireland', 'java', 'hiv', 'union', 'warwickshire', 'confederate', 'cork', 'leicestershire', 'middlesex', 'sunderland', 'nottinghamshire', 'northamptonshire', 'middlesbrough', 'serbian', 'columbia', 'dundee', 'reading', 'pc', 'czech', 'wigan', 'jewish', 'armenian', 'parma', 'darlington', 'ge

In [None]:
index = get_index(vocab, "data-wiki/index.json", "data-wiki/corpus.txt")

In [None]:
sentences_map = get_indexed_sentences(ts_words, index, 'data-wiki/corpus.txt')

In [None]:
model = SentenceTransformer('all-mpnet-base-v2', device='cuda:0') # Change to GPUs being used

embeddings = { word: {
    senses[0]: sentences_embeddings(model, [sentences_map[i] for i in index['@{}@-{}'.format(word, senses[0])]]),
    senses[1]: sentences_embeddings(model, [sentences_map[i] for i in index['@{}@-{}'.format(word, senses[1])]]),
} for word, senses in ts_words.items()}

In [None]:
epochs = 100
optimizer = torch.optim.Adam(model.parameters(),lr=0.001, weight_decay=.0001)  

model1 = MLP1(768, 4 * 768)
model1.to(device)
res1_fp = "si_results"

find_separation(model1, epochs, optimizer, embeddings, res1_fp)

In [None]:
model2 = MLP1(768, 4 * 768)
model2.to(device)
res2_fp = "si_results2"

find_separation(model2, epochs, optimizer, embeddings, res2_fp)

In [82]:
def get_results(fp, words):
    """
    :param fp: Filepath to results file
    :param candidates: Dictionary with keys contains 
    """
    res = {}
    for word in words:
        with open(fp + "/{}.json".format(word), 'r') as f:
            [best_f1, best_precision, best_recall, best_b_accuracy, best_tnr, best_tpr, total_sizes, test_sizes, best_i] = json.load(f)
            res[word] = best_b_accuracy
    return res

In [85]:
# Evaluate model results. Higher values = easier separation of a word's two senses
res1 = get_results('si_results', list(embeddings.keys()))

sorted_res1 = sorted(res1.items(), key=lambda x: x[1])
print("Model 1 Results:")
for item in sorted_res1:
    print(item[0], item[1])

Model 1 Results:
great_britain 0.419407563069804
england 0.42508495215352454
ireland 0.42873802951562906
uruguay 0.43855120562145333
scotland 0.44131963209120395
united_states 0.44197264193617386
new_south_wales 0.44828621284540454
reading 0.4534919487550176
canada 0.45355081442736955
nottinghamshire 0.4572876447876448
hampshire 0.46088345045486734
warwickshire 0.46228623028158633
union 0.4694937497219627
northamptonshire 0.4696754563894523
dundee 0.4705170460397569
southampton 0.4741333490816748
brazil 0.47562776035082793
france 0.4785275727925571
northern_ireland 0.4790547953115076
japanese 0.4793795620437956
surrey 0.4799681750235932
wigan 0.4805020170327207
poland 0.482544733444014
denmark 0.48319142888082733
italy 0.4861067428964341
sunderland 0.4884146341463415
confederate 0.4888315828957239
leeds 0.48966640347414137
oxford 0.4909090909090909
serbian 0.4913074712643678
american 0.4947743148844707
finnish 0.494937030031831
liverpool 0.49570481941684513
german 0.49879154552113514
c

In [88]:
res2 = get_results('si_results_2', list(embeddings.keys()))

sorted_res2 = sorted(res2.items(), key=lambda x: x[1])
print("Model 2 Results:")
for item in sorted_res2:
    print(item[0], item[1])

Model 2 Results:
american 0.6532726671972702
hiv 0.7045093890505157
dutch 0.714530761364456
scottish 0.7222987245800934
portuguese 0.7287643034449356
confederate 0.7296852401613763
hindu 0.7298761430721148
german 0.7650070215478468
mexican 0.7661094748715378
italian 0.7748689993238675
oxford 0.8329721921271217
union 0.833610400682012
jewish 0.8526707612134747
cambridge 0.8550757679388385
chinese 0.8681893434059311
japanese 0.8713872832369942
persian 0.8831094478733208
armenian 0.8836914430622665
serbian 0.8902146045861197
danish 0.9003957317504063
wigan 0.9024086378737541
russian 0.9074290955270213
greek 0.9118544323728546
norwegian 0.913595925358172
swedish 0.9162692900645673
hungarian 0.9186148985700033
czech 0.9217703349282296
united_states 0.9278182439853285
finnish 0.9289295382101841
middlesbrough 0.9374578708488743
albanian 0.9387878787878787
italy 0.9431395948193542
new_zealand 0.9452389621859482
wrexham 0.9458029197080291
reading 0.9517601547388781
south_africa 0.95275234381845

In [93]:
dif = {}
for word in list(embeddings.keys()):
    dif[word] = res2[word] - res1[word]

sorted_dif = sorted(dif.items(), key=lambda x: x[1])
print("Model 2 - Model 1 Difference:")
for item in sorted_dif:
    print(item[0], item[1])

Model 2 - Model 1 Difference:
glamorgan -0.008887549232376779
parma -0.002251290217779145
darlington 0.005716610699449398
wrexham 0.007035463727458335
genesis 0.009140859140859048
armenian 0.0106238948970524
albanian 0.014029480260497218
american 0.15849835231279946
hiv 0.18144825620042915
hindu 0.21327333336585674
dutch 0.2149021851328512
scottish 0.21999997247409375
portuguese 0.22392989041593514
confederate 0.24085365726565233
italian 0.25929087432386755
mexican 0.26605722528953446
german 0.2662154760267117
south_africa 0.32982632312291527
jewish 0.34190013014551357
oxford 0.3420631012180308
cambridge 0.3477563429739017
chinese 0.3491431314501775
union 0.36411665096004925
persian 0.3812452872343396
new_zealand 0.391626614870353
japanese 0.3920077211931986
greek 0.3954008357525719
serbian 0.39890713332175187
danish 0.39924630646305004
russian 0.404181329989711
swedish 0.41002638427569094
norwegian 0.41118758257420307
hungarian 0.41507534546868063
wigan 0.4219066208410334
czech 0.4222