In [1]:
import pickle
import csv
import os
from pathlib import Path
from typing import Set, Tuple, NamedTuple, List, Dict, Counter, Optional

import torch
import numpy as np
from scipy.spatial import distance
from scipy.stats import spearmanr

from evaluations.euphemism import Embedding, PhrasePair  
from recomposer import Recomposer, RecomposerConfig
from decomposer import Decomposer, DecomposerConfig

# One seed value shared by every RNG this notebook touches, so stochastic
# steps repeat across kernel restarts.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

### Load Pretrained Embedding

In [8]:
# Baseline embedding that `tabulate` / `tabulate_rank` compare against.
# Keep exactly one `pretrained_path` assignment active; the commented ones
# record the files used for the other experiments.

# CR bill/topic paper submission
# pretrained_path = '../../data/pretrained_word2vec/no min freq/CR_ctx3_HS.txt'
# pretrained_path = '../../data/pretrained_word2vec/maybe ctx5/bill_mentions_HS.txt'
# pretrained_path = '../../data/pretrained_word2vec/CR_ctx3_HS.txt'
# pretrained_path = '../../data/pretrained_word2vec/CR_ctx3_SGNS.txt'

# CR Skip, eval on HS, train on SGNS
# pretrained_path = '../../data/pretrained_word2vec/for_real.txt'  # HS

# PN skip
pretrained_path = '../../data/pretrained_word2vec/partisan_news_HS.txt'

pretrained = Embedding(pretrained_path, 'plain_text')

vocab_size = 138,394, num_dimensions = 300
Loading embeddings from ../../data/pretrained_word2vec/partisan_news_HS.txt
Done


In [85]:
# Second plain-text embedding used for the ad-hoc similarity checks below.
# Alternatives kept for reference:
# pe_path = '../../data/pretrained_word2vec/maybe ctx5/bill_mentions_HS.txt'
# pe_path = '../../data/pretrained_word2vec/no min freq/CR_ctx3_HS.txt'
pe_path = '../../data/pretrained_word2vec/CR_ctx5_HS_replica.txt'
PE = Embedding(pe_path, 'plain_text')

# Short alias for the bound method, so the spot checks stay one-liners.
PE_cs = PE.cosine_similarity

vocab_size = 24,005, num_dimensions = 300
Loading embeddings from ../../data/pretrained_word2vec/CR_ctx5_HS_replica.txt
Done


In [84]:
# Mean and standard deviation of the row-by-row cosine similarity between the
# embedding tables of `PE` and `PE_r`.
#
# NOTE(review): `PE_r` is not defined in any visible cell, so this cell fails
# on a fresh kernel. Load it in an earlier cell before running this one.
# NOTE(review): rows are paired purely by index, which presumably relies on
# both embeddings sharing the same word-to-id mapping — confirm.
vs = len(PE.word_to_id)

# One vectorised call over the first `vs` rows replaces the per-row Python
# loop; dim=1 reduces over the embedding dimension of each (row, dim) table.
with torch.no_grad():
    row_cs = torch.nn.functional.cosine_similarity(
        PE.embedding[:vs], PE_r.embedding[:vs], dim=1)

# Plain Python floats, as before, so numpy aggregates in float64.
avgs = row_cs.tolist()
print(np.mean(avgs), np.std(avgs))

0.6996991923862983 0.055522398154501676


In [86]:
# Spot-check PE's cosine similarity on a handful of hand-picked phrase pairs;
# one value is printed per pair, in the order listed.
spot_check_pairs = (
    ('undocumented', 'illegal_aliens'),
    ('estate_tax', 'death_tax'),
    ('capitalism', 'free_market'),
    ('foreign_trade', 'international_trade'),
    ('public_option', 'governmentrun'),
    ('federal_government', 'washington'),
    ('death_taxes', 'death_tax'),
)
for left, right in spot_check_pairs:
    print(PE_cs(left, right))

0.7568801045417786
0.8945048451423645
0.6935302019119263
0.7467148900032043
0.554466724395752
-0.02232639491558075
0.6717826128005981


In [53]:
# Inspect which entries PE places closest to 'death_tax'. The call is left as
# the cell's bare last expression; presumably it prints the score/word rows
# shown in the output below — confirm against Embedding.nearest_neighbor.
PE.nearest_neighbor('death_tax')

0.8886	estate_tax
0.8805	taxed
0.8694	tax_burden
0.8672	wealthiest
0.8531	income_taxes
0.8518	taxes
0.8497	tax_code
0.8479	brackets
0.8477	breaks
0.8473	marriage_penalty




In [25]:
# Same kind of spot check as above, on a shorter list of pairs.
# NOTE(review): the recorded output differs from the earlier cell for the same
# pairs, and this cell's execution count is lower, so `PE` presumably pointed
# at a different embedding when it ran — confirm before citing these numbers.
queries = [
    ('undocumented', 'illegal_aliens'),
    ('estate_tax', 'death_tax'),
    ('capitalism', 'free_market'),
    ('foreign_trade', 'international_trade'),
    ('public_option', 'governmentrun'),
    ('death_taxes', 'death_tax'),
]
for first, second in queries:
    print(PE_cs(first, second))

0.8068656921386719
0.8990585207939148
0.791033148765564
0.8984659910202026
0.671758770942688
0.7146835327148438


In [9]:
# Checkpoint from which both recomposer spaces are loaded. Keep exactly one
# `stuff = ...` line active; the others record earlier runs.

# Spreadsheet Submission
# stuff = Path('../../results/submission/CR_topic/L4R B512 LR1e-03/epoch100.pt') #ctx3 HS
# stuff = Path('../../results/submission/CR_bill/L4RL B512 LR1e-03/epoch100.pt') # submission??

# Spreadsheet & Paper
# stuff = Path('../../results/submission/CR_skip/B8 NS10/epoch4.pt')
stuff = Path('../../results/submission/PN_skip/NS10/epoch45.pt')

# stuff = Path('../../results/CR_bill/Ctx3 HS/L4R B128 LR1e-03/epoch100.pt') # B128??

# Paper Submission
# stuff = Path('../../results/CR_topic/M236_epoch25.pt')
# stuff = Path('../../results/CR_bill/Ctx3/L2 B2048 LR1e-03/epoch50.pt')

# Both spaces are read from the same checkpoint and placed on the CPU.
cpu = torch.device('cpu')
deno_space = Embedding(stuff, 'recomposer_deno', device=cpu)
cono_space = Embedding(stuff, 'recomposer_cono', device=cpu)

In [4]:
# cucumber = torch.load(stuff, map_location='cpu')
# md = cucumber['model']
# cucumber['config']
# md

In [5]:
def tabulate(q1, q2):
    """Print one tab-separated row comparing a phrase pair across spaces.

    Columns, in order:
      1. cosine similarity of (q1, q2) in `pretrained`;
      2. signed change of that similarity in `deno_space`;
      3. signed change of that similarity in `cono_space`;
      4. q1;
      5. q2.

    Every numeric column is formatted with four decimals so that rows line up
    and trailing zeros are kept (0.6580 rather than 0.658).

    Args:
        q1: first query phrase, passed through to `cosine_similarity`.
        q2: second query phrase, passed through to `cosine_similarity`.

    Returns:
        None; the row is written to stdout.
    """
    # Local names avoid shadowing the notebook-level `PE_cs` alias.
    base_cs = pretrained.cosine_similarity(q1, q2)
    deno_cs = deno_space.cosine_similarity(q1, q2)
    cono_cs = cono_space.cosine_similarity(q1, q2)
    print(
        f'{base_cs:.4f}',
        f'{deno_cs - base_cs:+.4f}',
        f'{cono_cs - base_cs:+.4f}',
        q1, q2, sep='\t')

def tabulate_rank(q1, q2):
    """Print one tab-separated row of neighbor ranks for a phrase pair.

    For each of `pretrained`, `deno_space` and `cono_space` (in that order)
    the row holds `neighbor_rank(q1, q2)` followed by `neighbor_rank(q2, q1)`;
    the two query strings close the row.
    """
    ranks = []
    for space in (pretrained, deno_space, cono_space):
        # Rank is queried in both directions for every space.
        ranks.append(space.neighbor_rank(q1, q2))
        ranks.append(space.neighbor_rank(q2, q1))
    print(*ranks, q1, q2, sep='\t')

    
# def cf(q1, q2)
#     pretrained.nearest_neighbor(q1)
#     model.nearest_neighbor(q1)
#     print('\n')
#     pretrained.nearest_neighbor(q2)
#     model.nearest_neighbor(q2)

In [10]:
# Hand-picked phrase pairs that the cells below pass to `tabulate` and
# `tabulate_rank`. Three alternative lists live here; only the "PN" one is
# active, the "Luntz Report" and "CR" lists are commented out.
# Spelling differs between lists: the CR entries use tokens such as
# 'governmentrun' / 'singlepayer' / 'supplyside', while the PN entries use
# 'government_run' / 'single_payer' / 'supply_side'. Presumably each matches
# the tokenisation of the embedding it was written for — confirm before
# switching lists.
cherry_pairs: List[Tuple[str, str]] = [
#     # Luntz Report, all GOP euphemisms
#     ('federal_government', 'washington'),
#     ('private_account', 'personal_account'),
#     ('tax_reform', 'tax_simplification'),
#     ('undocumented', 'illegal_aliens'),  # OOV undocumented_workers
#     ('estate_tax', 'death_tax'),
#     ('capitalism', 'free_market'),  # global economy, globalization
#     ('outsourcing', 'innovation'),  # "root cause" of outsourcing, regulation
#     ('foreign_trade', 'international_trade'),  # foreign, global all bad
#     ('drilling_for_oil', 'exploring_for_energy'),
#     ('drilling', 'energy_exploration'),
#     ('tort_reform', 'lawsuit_abuse_reform'),
#     ('trial_lawyer', 'personal_injury_lawyer'),  # aka ambulance chasers
#     ('corporate_transparency', 'corporate_accountability'),
#     ('school_choice', 'parental_choice'),  # equal_opportunity_in_education
#     ('healthcare_choice', 'right_to_choose')

#     # CR
#     ('undocumented', 'illegal_aliens'),
#     ('estate_tax', 'death_tax'),
#     ('capitalism', 'free_market'),
#     ('foreign_trade', 'international_trade'),
#     ('public_option', 'governmentrun'),
    
#     ('undocumented_workers', 'illegal_aliens'),
#     ('trickledown', 'cut_taxes'),
#     ('socialized_medicine', 'singlepayer'),
#     ('voodoo', 'supplyside'),
#     ('tax_expenditures', 'spending_programs'),
#     ('waterboarding', 'interrogation'),
#     ('political_speech', 'campaign_spending'),
#     ('star_wars', 'strategic_defense_initiative'),
#     ('nuclear_option', 'constitutional_option'),
    
    # PN (active list)
    # NOTE(review): the "OOV undocumented_workers" remark on the next entry was
    # carried over from the Luntz list; the recorded `tabulate` output prints a
    # similarity for this pair, so it appears to be in-vocabulary here — confirm.
    ('undocumented_workers', 'illegal_aliens'),  # OOV undocumented_workers
    ('estate_tax', 'death_tax'),
    ('capitalism', 'free_market'),  # global economy, globalization
    ('foreign_trade', 'international_trade'),  # foreign, global all bad
    ('public_option', 'government_run'),
    ('voodoo', 'supply_side'), 
    ('supply_side', 'cut_taxes'),
    ('socialized_medicine', 'single_payer'),
    ('tax_expenditures', 'spending_programs'),
    ('waterboarding', 'interrogation'),
    ('political_speech', 'campaign_spending'),
]



In [11]:
# One similarity row per entry of whichever `cherry_pairs` list is active.
for pair in cherry_pairs:
    tabulate(*pair)

0.6677	-0.2315	-0.3633	undocumented_workers	illegal_aliens
0.658	+0.0599	-0.2083	estate_tax	death_tax
0.8313	-0.2579	-0.1974	capitalism	free_market
0.641	-0.0747	+0.0521	foreign_trade	international_trade
0.7108	-0.0622	-0.1200	public_option	government_run
0.2572	+0.4375	-0.1313	voodoo	supply_side
0.4925	+0.1381	-0.1421	supply_side	cut_taxes
0.6284	-0.1355	-0.0454	socialized_medicine	single_payer
0.6122	+0.0978	-0.2843	tax_expenditures	spending_programs
0.6368	+0.0823	-0.0421	waterboarding	interrogation
0.471	+0.1393	-0.2144	political_speech	campaign_spending


In [8]:
# One neighbor-rank row per entry of whichever `cherry_pairs` list is active.
# NOTE(review): the recorded output below lists pairs that are not in the
# currently active list, so it presumably comes from an earlier run — re-run
# before relying on it.
for pair in cherry_pairs:
    tabulate_rank(*pair)

14374	633	187	5736	1890	15918	undocumented	illegal_aliens
2	1	5	1	19423	18829	estate_tax	death_tax
49415	16256	3	6850	6765	15908	capitalism	free_market
122	27	155	8800	800	6300	foreign_trade	international_trade
35313	86391	743	15395	19468	19030	public_option	governmentrun
