In [1]:
import csv
import os
import pickle
import random
from typing import Counter, Dict, List, NamedTuple, Optional, Set, Tuple, Union

import editdistance
import numpy as np
import torch
from scipy.spatial import distance
from scipy.stats import spearmanr
from torch import nn

from evaluations.intrinsic_eval import Embedding, PhrasePair
from bill_decomposer import Decomposer, DecomposerConfig
# Reproducibility: one seed shared by every RNG used in this notebook.
SEED = 42

random.seed(SEED)
torch.manual_seed(SEED)  # per the torch docs, seeds the CPU and CUDA generators
np.random.seed(SEED)

# Fall back to CPU so the notebook still runs on machines without a GPU;
# checkpoints are loaded with map_location=DEVICE, so they follow this choice.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [43]:
def heterogeneity_continuous(
        model,
        query_ids,
        top_k: int = 10,
        pretty_print: bool = True
        ) -> Union[float, List[list]]:
    """Score how mixed each query's embedding neighbourhood is in R_ratio.

    For every query id, the whole vocabulary is ranked by cosine similarity to
    the query's embedding. The ``top_k`` most similar words are kept, skipping
    the query itself and any word within edit distance 2 of it. A query's
    heterogeneity is the mean absolute difference between its
    ``model.R_ratio`` and the ``R_ratio`` of each kept neighbour.

    Args:
        model: object exposing ``embedding`` (called on ids, with a ``weight``
            matrix), ``id_to_word`` and ``R_ratio(word)``.
        query_ids: tensor of vocabulary ids (assumed 1-D); moved to the
            device of ``model.embedding.weight``.
        top_k: number of neighbours kept per query.
        pretty_print: if True, print up to the first 100 queries (sorted by
            score, showing the query and its first 4 neighbours) and return
            the mean score; if False, return the raw rows.

    Returns:
        When ``pretty_print`` is True, the mean heterogeneity over all queries.
        Otherwise a list with one row per query, in ``query_ids`` order:
        ``[heterogeneity, query_word, neighbour_1, ..., neighbour_k]``.
        A query with no eligible neighbours gets a ``nan`` score.
    """
    # Follow the model's own device instead of a global, so the function
    # works wherever the model has been loaded.
    embedding_weight = model.embedding.weight
    query_ids = query_ids.to(embedding_weight.device)

    # Pure inference: keep both the lookup and the similarity search out of autograd.
    with torch.no_grad():
        query_embed = model.embedding(query_ids)
        top_neighbor_ids = [
            nn.functional.cosine_similarity(
                q.view(1, -1), embedding_weight).argsort(descending=True)
            for q in query_embed]

    heterogeneity = []  # one score per query
    rows = []  # each row is [score, query, n1, n2, ...]

    for query_id_tensor, sorted_target_indices in zip(query_ids, top_neighbor_ids):
        query_id = query_id_tensor.item()
        query_word = model.id_to_word[query_id]
        query_R_ratio = model.R_ratio(query_word)

        neighbors = []
        ratio_distances = []
        for target in sorted_target_indices:
            if len(neighbors) == top_k:
                break
            target_id = target.item()
            if target_id == query_id:
                continue
            target_word = model.id_to_word[target_id]
            # Skip near-identical strings (presumably spelling / inflection
            # variants of the query) so they do not dominate the neighbourhood.
            if editdistance.eval(query_word, target_word) < 3:
                continue
            neighbors.append(target_word)
            ratio_distances.append(abs(model.R_ratio(target_word) - query_R_ratio))

        # np.mean([]) warns and yields nan; make the empty case explicit instead.
        score = float(np.mean(ratio_distances)) if ratio_distances else float('nan')
        heterogeneity.append(score)
        rows.append([score, query_word, *neighbors])

    if pretty_print:
        for row in sorted(rows[:100], key=lambda r: r[0]):
            print(f'{row[0]:.4f} {"  ".join(row[1:6])}')
        return np.mean(heterogeneity)
    return rows

# def heterogeneity_continuous_compare(
#         model,
#         baseline_model,
#         query_ids,
#         top_k=10,
#         sample=None,
#         ) -> None:
    
#     for model_name, model in models.item():
#         models[model_name]
    
#     table0 = heterogeneity_continuous(baseline_model, query_ids, pretty_print=False)
#     table1 = heterogeneity_continuous(model, query_ids, pretty_print=False)
    
    
#     def print_row(model_name, row):
#         print(f'{row[0]:.4f} {"\t".join(row[1:])}')
        
    
#     for index, row in enumerate(table0):
#         print('Pretrained:', end='\t')
#         print_row(row)
#         print('Model:', end='\t')
#         print_row(table1[index])
#         print()
    
    
def heterogeneity_continuous_export(
        models,
        query_ids,
        out_path,
        top_k=10,
        print_freq=False,
        freq_model_name='pretrained',
        ) -> None:
    """Write a TSV comparing every model's neighbourhood of each query.

    For each query, one line per model is written (in ``models`` order):
    model name, the query annotated as ``"word, combined_freq, R_ratio%"``,
    the heterogeneity score, then the ``top_k`` neighbours. A blank line
    separates consecutive queries.

    Args:
        models: mapping of label -> model; must contain ``freq_model_name``.
        query_ids: ids passed to ``heterogeneity_continuous`` for every model.
        out_path: destination TSV path (overwritten).
        top_k: number of neighbours per query (also sizes the header).
        print_freq: if True, annotate neighbours with frequency / R_ratio too.
        freq_model_name: key of the model whose ``Dem_frequency``,
            ``GOP_frequency`` and ``R_ratio`` are used for every annotation.
    """
    # top_k was previously accepted but never forwarded, so it had no effect.
    tables = {
        model_name: heterogeneity_continuous(
            model, query_ids, top_k=top_k, pretty_print=False)
        for model_name, model in models.items()}

    # Annotate all rows from one reference model so the statistics are comparable.
    freq_model = models[freq_model_name]

    def detail_freq(word):
        """Return ``'word, combined_freq, R_ratio%'`` from the reference model."""
        combined_freq = freq_model.Dem_frequency[word] + freq_model.GOP_frequency[word]
        return f'{word}, {combined_freq}, {freq_model.R_ratio(word):.2%}'

    neighbor_columns = '\t'.join(f'n{i}' for i in range(1, top_k + 1))
    with open(out_path, 'w', encoding='utf-8') as file:
        file.write(f'model\tquery\theterogeneity\t{neighbor_columns}\n')
        for row_index in range(len(tables[freq_model_name])):  # one group per query
            for model_name, table in tables.items():
                hetero, query, *neighbors = table[row_index]
                if print_freq:
                    neighbors = [detail_freq(n) for n in neighbors]
                print(model_name, detail_freq(query), f'{hetero:.4f}', *neighbors,
                      sep='\t', file=file)
            file.write('\n')  # blank line between query groups

In [15]:
# Hand-picked query phrases ("cherries"). Entries that are commented out are
# not part of the tuple.
# dict.fromkeys removes repeated phrases ('equal_pay' and 'trickledown' are
# listed twice below) while keeping first-seen order. Without it, those
# queries would be exported and averaged twice.
cherries = tuple(dict.fromkeys((
    'military_budget', 'defense_budget',
    # 'nuclear_option', 'constitutional_option',
    'prochoice', # 'proabortion',
    'star_wars', # 'strategic_defense_initiative',
    'political_speech', 'campaign_spending',
    'singlepayer', # 'socialized_medicine',
    # 'voodoo', 'supplyside',
    'tax_expenditures', # 'spending_programs',
    'waterboarding', 'interrogation',
    'cap_and_trade', 'national_energy_tax',
    'governmentrun', 'public_option',
    'medical_liability_reform', # 'tort_reform',
    # 'corporate_profits', 'earnings',
    'equal_pay',  # 'the_paycheck_fairness_act',
    'military_spending', # 'washington_spending',
    'higher_taxes', # 'bigger_government',
    'social_justice', # 'womens_rights',
    # 'national_health_insurance', # 'welfare_state', 
    'nuclear_war', 'deterrence',
    'suffrage', # 'womens_rights',
    'inequality', 'racism',
    # 'sweatshops', 'factories',
    'trickledown', 'cut_taxes',
    'equal_pay', 'pay_discrimination',
    'wealthiest_americans', 'tax_breaks',
    'record_profits', 'big_oil_companies',
    # 'private_insurance_companies', 'medicare_advantage_program',
    'trickledown', # 'universal_health_care',
    'big_banks', # 'occupation_of_iraq',
    # 'obamacare', 'islamists'
)))

# Broad policy terms, used as a second query set alongside the cherries.
generics = (
    'government',
    'taxes',
    'laws',
    'jobs',
    'tariff',
    'health_care',
    'finance',
    'social_security',
    'medicare',
    'regulations',
    'immigration',
    'research',
    'technology',
)

In [4]:
# Baseline model; presumably the initial checkpoint, judging by the
# 'pretrained/init.pt' path (confirm). It is aliased as PM below to build the
# query-id tensors from its vocabulary.
# NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
path = '../../results/pretrained/init.pt'
pretrained_model = torch.load(path, map_location=DEVICE)['model']

In [None]:
# A single trained checkpoint bound to `model`. Later cells use it for the
# R_ratio spot check and the GOP heterogeneity comparison.
# NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
path = '../../results/cono space remove deno/L1 -0.05d/epoch50.pt'
model = torch.load(path, map_location=DEVICE)['model']

In [None]:
# Spot check: R_ratio that `model` reports for a single phrase.
model.R_ratio('marriage_tax_penalty')
# model.GOP_frequency['four_trillion']

### Cherries

In [52]:
# Query-id tensors for the three query sets, resolved against the
# pretrained model's vocabulary.
PM = pretrained_model
cherry_ids = torch.tensor([PM.word_to_id[c] for c in cherries])
generic_ids = torch.tensor([PM.word_to_id[c] for c in generics])

MIN_GOP_FREQUENCY = 100  # minimum GOP_frequency for a GOP word to be eligible
NUM_GOP_QUERIES = 100    # size of the random GOP query sample

select = torch.tensor(
    [i.item()
     for i in PM.GOP_ids
     if PM.GOP_frequency[PM.id_to_word[i.item()]] >= MIN_GOP_FREQUENCY])
# Sample WITHOUT replacement. torch.randint could draw the same word several
# times, duplicating rows in the exported tables and over-weighting it in means.
sample = torch.randperm(len(select))[:NUM_GOP_QUERIES]
GOP_ids = select[sample]

In [11]:
def load(path, root='../../results', device=None):
    """Load the ``'model'`` entry of a checkpoint stored under ``root``.

    Args:
        path: checkpoint path relative to ``root``, e.g. ``'pretrained/init.pt'``.
        root: results directory; defaults to the notebook's ``../../results``.
        device: ``map_location`` for ``torch.load``; defaults to the global
            ``DEVICE`` when None.

    Returns:
        Whatever object was saved under the checkpoint's ``'model'`` key.
    """
    # NOTE: torch.load unpickles arbitrary objects, so only load trusted files.
    # The stored models are presumably whole pickled objects (callers use
    # methods such as R_ratio). On torch versions where weights_only defaults
    # to True, this call may need weights_only=False.
    if device is None:
        device = DEVICE
    return torch.load(os.path.join(root, path), map_location=device)['model']

In [26]:
# Checkpoints to compare, keyed by the label written into the exported tables.
# The baseline must be present under 'pretrained' because the export helper
# reads that key. Checkpoints are loaded in the order listed.
models = {
    label: load(relative_path)
    for label, relative_path in (
        ('pretrained', 'pretrained/init.pt'),
        ('L1 -0.05d', 'cono space remove deno/L1 -0.05d/epoch50.pt'),
        ('L4 -0.05d', 'cono space remove deno/L4 -0.05d/epoch50.pt'),

        ('L4 +5 -0.05d', 'affine/L4 +5 -0.05d/epoch50.pt'),
        ('L4 +5 -0.1d', 'affine/L4 +5 -0.1d/epoch50.pt'),
        ('L4 +5 -0.2d', 'affine/L4 +5 -0.2d/epoch50.pt'),
        ('L4 +5 -0.5d', 'affine/L4 +5 -0.5d/epoch50.pt'),
        ('L4 +5 -1d', 'affine/L4 +5 -1d/epoch50.pt'),
        ('L4 +10 -1.5d', 'affine/L4 +10 -1.5d/epoch50.pt'),
        ('L4 +10 -2d', 'affine/L4 +10 -2d/epoch50.pt'),
        ('L4 +5 0c -1d', 'affine/L4 +5 0c/epoch50.pt'),
    )
}

In [22]:
# models = {
#     'pretrained': load('pretrained/init.pt'),
#     '1': load('analysis/retrained/L4 from L4 LLR/epoch10.pt'),
#     '2': load('analysis/retrained/L4 from L4 LLR/epoch20.pt'),
#     '3': load('analysis/retrained/L4 from L4 LLR/epoch30.pt'),
#     '4': load('analysis/retrained/L4 from L4 LLR/epoch40.pt'),
#     '5': load('analysis/retrained/L4 from L4 LLR/epoch50.pt'),
#     '6': load('analysis/retrained/L4 from L4 LLR/epoch80.pt'),
#     '7': load('analysis/retrained/L4 from L4 LLR/epoch100.pt'),
#     '8': load('analysis/retrained/L4 from L4 LLR/epoch150.pt'),
#     '9': load('analysis/retrained/L4 from L4 LLR/epoch200.pt'),
#     '10': load('analysis/retrained/L4 from L4 LLR/epoch250.pt'),
#     '11': load('analysis/retrained/L4 from L4 LLR/epoch300.pt'),
#     '12': load('analysis/retrained/L4 from L4 LLR/epoch400.pt'),
#     '13': load('analysis/retrained/L4 from L4 LLR/epoch500.pt'),
# }

In [53]:
# Detailed tables: neighbours are annotated with frequency and R_ratio too
# (print_freq=True).
heterogeneity_continuous_export(
    models=models,
    query_ids=cherry_ids,
    out_path='../../results/detail_cherry_neighborhood.tsv',
    print_freq=True,
)
heterogeneity_continuous_export(
    models=models,
    query_ids=GOP_ids,
    out_path='../../results/detail_GOP_neighborhood.tsv',
    print_freq=True,
)
heterogeneity_continuous_export(
    models=models,
    query_ids=generic_ids,
    out_path='../../results/detail_generic_neighborhood.tsv',
    print_freq=True,
)

In [54]:
# Compact tables: only the query column carries frequency / R_ratio details.
heterogeneity_continuous_export(
    models=models,
    query_ids=cherry_ids,
    out_path='../../results/cherry_neighborhood.tsv',
)
heterogeneity_continuous_export(
    models=models,
    query_ids=GOP_ids,
    out_path='../../results/GOP_neighborhood.tsv',
)
heterogeneity_continuous_export(
    models=models,
    query_ids=generic_ids,
    out_path='../../results/generic_neighborhood.tsv',
)

## Partisan Skew > 80%
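Heterogeneity: the mean absolute difference between a query's R_ratio and the R_ratios of its top-k nearest embedding neighbours (by cosine similarity, excluding near-duplicate spellings).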

In [None]:
# Baseline: heterogeneity of the pretrained embeddings over `model`'s GOP ids.
heterogeneity_continuous(
    model=pretrained_model,
    query_ids=model.GOP_ids,
)

In [None]:
# Same measurement for the trained model, to compare against the pretrained baseline.
heterogeneity_continuous(
    model=model,
    query_ids=model.GOP_ids,
)