In [26]:
import pickle
import csv
import os
from typing import Set, Tuple, NamedTuple, List, Dict, Counter, Optional

import torch
from torch import nn
import numpy as np
from scipy.spatial import distance
from scipy.stats import spearmanr
import editdistance

from evaluations.intrinsic_eval import Embedding, PhrasePair
from evaluations.intrinsic_eval import cherry_words, generic_words
from decomposer import Decomposer, DecomposerConfig

import random
# Seed every RNG the notebook may touch so reruns are reproducible.
random.seed(42)

torch.manual_seed(42)
np.random.seed(42)

# Prefer the second GPU (the original setting), but fall back gracefully so
# the notebook still runs on single-GPU or CPU-only machines instead of
# failing inside the first `torch.load(..., map_location=DEVICE)`.
if torch.cuda.device_count() > 1:
    DEVICE = torch.device('cuda:1')
elif torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')

In [27]:
# Pretrained model; `grounding` is a per-word lookup that exposes at least
# the 'freq' and 'R_ratio' fields used below.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
PE = torch.load('../../results/pretrained/init.pt', map_location=DEVICE)['model']
GD = PE.grounding


def _ids_for(words):
    """Return a tensor holding the pretrained-vocabulary id of each word."""
    return torch.tensor([PE.word_to_id[word] for word in words])


def _passes_gop_filter(word):
    """True for words with freq > 99 and R_ratio > 0.75 in the grounding."""
    return GD[word]['freq'] > 99 and GD[word]['R_ratio'] > 0.75


cherry_ids = _ids_for(cherry_words)
generic_ids = _ids_for(generic_words)

GOP_words = [word for word in PE.word_to_id.keys() if _passes_gop_filter(word)]
print(len(GOP_words))
GOP_ids = _ids_for(GOP_words)

54


In [28]:
def heterogeneity_continuous(
        model,
        query_ids,
        top_k: int = 10,
        pretty_print: bool = True
        ):
    """Score how far each query word's nearest neighbors are from it in R_ratio.

    For every query id, all rows of the model's embedding matrix are ranked
    by cosine similarity to the query embedding. Walking down that ranking,
    the first ``top_k`` words that are neither the query itself nor within
    edit distance 2 of it are kept as neighbors. The query's score is the
    mean absolute difference between ``GD[query]['R_ratio']`` and
    ``GD[neighbor]['R_ratio']`` over the kept neighbors.

    Args:
        model: object exposing ``embedding.weight`` (one row per word id)
            and ``id_to_word`` (id -> word lookup).
        query_ids: tensor of word ids to evaluate.
        top_k: maximum number of neighbors kept per query.
        pretty_print: selects both the side effect and the return value
            (see Returns).

    Returns:
        If ``pretty_print`` is true: prints the first 100 queries, sorted by
        score, each with its first four neighbors, and returns the mean
        score over all queries.
        Otherwise: a list with one row per query of the form
        ``[score, query_word, neighbor_1, neighbor_2, ...]``.
        A score is NaN when no neighbor qualified.
    """
    embed = model.embedding.weight.detach()
    # Follow the embedding's own device instead of the global DEVICE so the
    # indexing below works wherever the model happens to live.
    query_ids = query_ids.to(embed.device)

    heterogeneity = []
    inspect = []  # one row per query: [score, query, n1, n2, ...]

    for query_id in query_ids.tolist():
        query_word = model.id_to_word[query_id]
        query_R_ratio = GD[query_word]['R_ratio']

        # Rank the vocabulary one query at a time; holding a full ranking
        # for every query simultaneously costs len(query_ids) * vocab memory.
        similarity = nn.functional.cosine_similarity(
            embed[query_id].view(1, -1), embed)
        ranked_ids = similarity.argsort(descending=True).tolist()

        neighbors = []
        freq_ratio_distances = []
        for target_id in ranked_ids:
            if len(neighbors) == top_k:
                break
            if target_id == query_id:
                continue
            target_word = model.id_to_word[target_id]
            # Skip near-identical spellings (edit distance below 3).
            if editdistance.eval(query_word, target_word) < 3:
                continue
            neighbors.append(target_word)
            freq_ratio_distances.append(
                abs(GD[target_word]['R_ratio'] - query_R_ratio))

        # np.mean([]) would emit a RuntimeWarning; be explicit about "no data".
        if freq_ratio_distances:
            score = np.mean(freq_ratio_distances)
        else:
            score = float('nan')

        heterogeneity.append(score)
        inspect.append([score, query_word, *neighbors])

    if not pretty_print:
        return inspect

    for row in sorted(inspect[:100], key=lambda r: r[0]):
        print(f'{row[0]:.4f} {"  ".join(row[1:6])}')
    return np.mean(heterogeneity) if heterogeneity else float('nan')
    
    
def heterogeneity_continuous_export(
        models,
        query_ids,
        out_path,
        top_k=10,
        print_freq=False,
        ) -> None:
    """Write a TSV comparing each model's nearest neighbors for every query.

    Runs ``heterogeneity_continuous(..., pretty_print=False)`` on every model
    and writes one line per (query, model) pair, grouped by query with a
    blank line after each group.

    Args:
        models: mapping of model name -> model.
        query_ids: tensor of word ids passed through to
            ``heterogeneity_continuous``.
        out_path: destination file; it is overwritten.
        top_k: number of neighbors requested per query; also determines how
            many ``n1..nK`` columns the header declares.
        print_freq: when true, each neighbor is written as
            ``word, freq, R_ratio%``. The query column is always written in
            that detailed form.
    """
    def detail_freq(word):
        combined_freq = GD[word]['freq']
        ratio = GD[word]['R_ratio']
        return f'{word}, {combined_freq}, {ratio:.2%}'

    # Forward top_k: previously it was accepted but silently ignored, so the
    # callee always fell back to its own default.
    tables = {
        model_name: heterogeneity_continuous(
            model, query_ids, top_k=top_k, pretty_print=False)
        for model_name, model in models.items()}

    # Every table has one row per query; with no models only the header is
    # written instead of raising IndexError.
    num_queries = len(next(iter(tables.values()), []))

    header = ['model', 'query', 'heterogeneity']
    header.extend(f'n{rank}' for rank in range(1, top_k + 1))

    with open(out_path, 'w') as file:
        file.write('\t'.join(header) + '\n')
        for row_index in range(num_queries):  # iterate over queries
            for model_name, table in tables.items():
                row = table[row_index]
                hetero = f'{row[0]:.4f}'
                query = detail_freq(row[1])
                neighbors = row[2:]
                if print_freq:
                    neighbors = [detail_freq(n) for n in neighbors]
                print(model_name, query, hetero, *neighbors, sep='\t', file=file)
            file.write('\n')

In [None]:
# Detailed exports: every neighbor is written with its frequency and R_ratio.
# NOTE(review): `models` is not defined in any cell visible here — confirm it
# is created earlier so this cell survives Restart & Run All.
for _query_ids, _out_path in (
        (cherry_ids, '../../results/detail_cherry_neighborhood.tsv'),
        (GOP_ids, '../../results/detail_GOP_neighborhood.tsv'),
        (generic_ids, '../../results/detail_generic_neighborhood.tsv')):
    heterogeneity_continuous_export(
        models, _query_ids, _out_path, print_freq=True)

In [16]:
# Plain exports: neighbors are written as bare words.
# NOTE(review): `models` is not defined in any cell visible here — confirm it
# is created earlier so this cell survives Restart & Run All.
for _query_ids, _out_path in (
        (cherry_ids, '../../results/cherry_neighborhood.tsv'),
        (GOP_ids, '../../results/GOP_neighborhood.tsv'),
        (generic_ids, '../../results/generic_neighborhood.tsv')):
    heterogeneity_continuous_export(models, _query_ids, _out_path)

### Export en masse 

In [29]:
def export_neighbor_en_masse(in_dir, endswith, out_path):
    """Export cherry-word neighborhoods for every matching checkpoint under a directory.

    Walks ``in_dir`` recursively, loads the ``'model'`` entry of every file
    whose name ends with ``endswith``, and writes a comparison against the
    pretrained model ``PE`` via ``heterogeneity_continuous_export``.

    Args:
        in_dir: root directory searched for checkpoints.
        endswith: filename suffix selecting checkpoints (e.g. 'epoch150.pt').
        out_path: TSV file to write.
    """
    models = {'pretrained': PE}
    for dirpath, _, filenames in os.walk(in_dir):
        for file in filenames:
            if not file.endswith(endswith):
                continue
            path = os.path.join(dirpath, file)
            # Name each model by its path relative to in_dir. The previous
            # `path.lstrip(in_dir)` treated in_dir as a *set of characters*
            # and could also eat the leading characters of the run name.
            name = os.path.relpath(path, in_dir)
            # NOTE(review): torch.load unpickles the file — only point this
            # at trusted checkpoints.
            models[name] = torch.load(path, map_location=DEVICE)['model']
    heterogeneity_continuous_export(models, cherry_ids, out_path)

In [25]:
# Compare every epoch-150 checkpoint under the given run directory.
# NOTE(review): the input directory is spelled 'NGC' but the output file is
# spelled 'NCG' — confirm which spelling is intended.
export_neighbor_en_masse(
    in_dir='../../results/toy/BS64 NGC', 
    endswith='epoch150.pt',
    out_path='../../analysis/adv64_NCG_cherry_neighborhood.tsv')

### Manual Inspection

In [33]:
def load(path):
    """Return a checkpoint's embedding matrix as a NumPy array.

    ``path`` is appended to ``'../../results/'``; the checkpoint's ``'model'``
    entry is read and its embedding weights are detached, moved to the CPU
    and converted to NumPy.
    """
    checkpoint = torch.load('../../results/' + path, map_location=DEVICE)
    weights = checkpoint['model'].embedding.weight
    return weights.detach().cpu().numpy()

In [34]:
# Embedding matrix (NumPy array) of one checkpoint; path is relative to ../../results/.
E5 = load('toy/only remove deno/L4 E5 embed42/epoch50.pt')

In [36]:
# Peek at the first 10 embedding rows only, rather than dumping the whole matrix.
E5[:10]

array([[ 0.20989545,  0.20820804,  0.16802971,  0.17149678,  0.13462569,
         0.11935333,  0.08635937,  0.19942373,  0.08815943,  0.15533209,
         0.16642086,  0.18624912,  0.18482862,  0.18867253,  0.23323508,
         0.25218633,  0.15253443,  0.15915084,  0.21690959,  0.14118168,
         0.16216236,  0.18454884,  0.08942422,  0.04033117,  0.03498505,
         0.15103433,  0.22899777,  0.1373053 ,  0.09291944,  0.11199819,
         0.25223726,  0.03155887,  0.05974406,  0.0735632 ,  0.07083415,
         0.22723348,  0.67370504,  0.20017044,  0.10953286,  0.31604183,
         0.104411  ,  0.11429827],
       [ 0.05128092,  0.12789717,  0.19922662,  0.02559984,  0.06493242,
        -0.01047843,  0.02186588, -0.04793435,  0.05556713,  0.02345957,
         0.04995031,  0.11724861,  0.08427715,  0.01839533,  0.08079092,
         0.3601503 ,  0.2232121 ,  0.10550333,  0.195397  , -0.00089734,
        -0.01246764,  0.04178951,  0.02131795,  0.04980367,  0.12333693,
         0.05483