In [1]:
import random
import os

import torch
from torch import nn
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from tqdm import tqdm

from evaluations.intrinsic_eval import cherry_words, generic_words
from decomposer import Decomposer, DecomposerConfig

# Seed Python's and PyTorch's RNGs so that `random.sample` calls in later
# cells are repeatable.
# NOTE(review): these seeds presumably do not cover scikit-learn's t-SNE,
# which takes its own `random_state` - confirm if exact plots must be reproduced.
random.seed(42)
torch.manual_seed(42)
# Apply seaborn's default plot styling to every figure produced below.
sns.set()

# Device that checkpoints are mapped onto when they are loaded.
DEVICE = 'cpu'
# PE is the object stored under the checkpoint's 'model' key.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
PE = torch.load(
    '../../results/pretrained/init.pt', map_location=DEVICE)['model']
# GD is that object's `grounding` attribute; later cells index it by word.
GD = PE.grounding

In [60]:
def load(path):
    """Return the embedding matrix stored in the checkpoint at `path`.

    The checkpoint is read with ``torch.load`` (mapped onto ``DEVICE``). The
    object under its ``'model'`` key must expose ``embedding.weight``, which
    is detached from the autograd graph and converted to a NumPy array.
    """
    checkpoint = torch.load(path, map_location=DEVICE)
    weight = checkpoint['model'].embedding.weight
    return weight.detach().numpy()

def gather(words):
    """Collect ids and grounding statistics for every word in `words`.

    Returns four parallel lists ordered like `words`: the ids taken from
    ``PE.word_to_id``, followed by the ``'freq'``, ``'R_ratio'`` and
    ``'majority_deno'`` entries of ``GD[word]``.
    """
    def grounding_column(key):
        # One pass over `words` per statistic, in the order given.
        return [GD[word][key] for word in words]

    word_ids = []
    for word in words:
        word_ids.append(PE.word_to_id[word])
    freq = grounding_column('freq')
    skew = grounding_column('R_ratio')
    maj_deno = grounding_column('majority_deno')
    return word_ids, freq, skew, maj_deno

def plot(coordinates, words, freq, skew, path):
    """Draw an annotated scatter plot of 2-D word coordinates and save it.

    Parameters
    ----------
    coordinates : 2-D array
        Column 0 is used as x and column 1 as y; one row per word.
    words : sequence of str
        Labels written next to each point, aligned with `coordinates`.
    freq : sequence
        Per-word values mapped to marker size (size range 100 to 1000).
    skew : sequence
        Per-word values mapped to marker colour with the 'coolwarm' palette.
    path : str
        Destination file; the figure is written at 300 dpi.
    """
    fig, ax = plt.subplots(figsize=(15, 10))
    try:
        # x and y are passed by keyword: recent seaborn releases only accept
        # `data` positionally, so the positional form raises a TypeError.
        sns.scatterplot(
            x=coordinates[:, 0], y=coordinates[:, 1],
            hue=skew, palette='coolwarm',  # hue_norm=(0, 1),
            size=freq, sizes=(100, 1000),
            legend=False, ax=ax)
        for coord, word in zip(coordinates, words):
            ax.annotate(word, coord, fontsize=12)
        with open(path, 'wb') as file:
            fig.savefig(file, dpi=300)
    finally:
        # Always release the figure, even when drawing or saving fails;
        # otherwise figures pile up when plotting many models in a loop.
        plt.close(fig)

def plot_categorical(coordinates, words, freq, skew, path):
    """Draw an annotated scatter plot with a legend beside it and save it.

    Same inputs as `plot`, but the 'muted' palette is used and a brief
    legend is placed to the right of the (narrowed) axes.

    Parameters
    ----------
    coordinates : 2-D array
        Column 0 is used as x and column 1 as y; one row per word.
    words : sequence of str
        Labels written next to each point, aligned with `coordinates`.
    freq : sequence
        Per-word values mapped to marker size (size range 100 to 1000).
    skew : sequence
        Per-word values mapped to marker colour. Despite the name, callers
        in this notebook pass the majority-denotation labels here.
    path : str
        Destination file; the figure is written at 300 dpi.
    """
    fig, ax = plt.subplots(figsize=(20, 10))
    try:
        # x and y are passed by keyword: recent seaborn releases only accept
        # `data` positionally, so the positional form raises a TypeError.
        sns.scatterplot(
            x=coordinates[:, 0], y=coordinates[:, 1],
            hue=skew, palette='muted', hue_norm=(0, 1),
            size=freq, sizes=(100, 1000),
            legend='brief', ax=ax)
        # Shrink the axes to 60% of their width to make room for the legend.
        chartBox = ax.get_position()
        ax.set_position(
            [chartBox.x0, chartBox.y0, chartBox.width * 0.6, chartBox.height])
        ax.legend(loc='upper center', bbox_to_anchor=(1.45, 0.8), ncol=1)
        for coord, word in zip(coordinates, words):
            ax.annotate(word, coord, fontsize=12)
        with open(path, 'wb') as file:
            fig.savefig(file, dpi=300)
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
    
def load_en_masse(in_dir, endswith):
    """Load every checkpoint under `in_dir` whose file name ends with `endswith`.

    Walks `in_dir` recursively and returns a dict mapping a display name to
    the embedding matrix returned by `load`. The display name is the file's
    path relative to `in_dir` with path separators replaced by spaces
    (e.g. ``'E1 A1/epoch10.pt'`` becomes ``'E1 A1 epoch10.pt'``). The
    collected names are printed, one per line, before returning.
    """
    models = {}
    for dirpath, _, filenames in tqdm(os.walk(in_dir)):
        for file in filenames:
            if file.endswith(endswith):
                path = os.path.join(dirpath, file)
                # str.lstrip(in_dir) strips any leading *characters* found in
                # in_dir rather than the prefix itself, which can eat the
                # start of the run name. relpath removes exactly the prefix.
                relative = os.path.relpath(path, in_dir)
                name = relative.replace(os.sep, ' ').replace('/', ' ')
                models[name] = load(path)
    print(*models.keys(), sep='\n')
    return models
    
def graph_en_masse(
        in_dir,
        out_dir,
        reduction,  # 'PCA', 'TSNE', or 'both'
        word_ids,
        words,
        hues,
        sizes,
        perplexity=None,
        categorical=False,
        random_state=None):
    """Project selected word embeddings to 2-D and save one plot per model.

    Parameters
    ----------
    in_dir : mapping
        Despite its name, this is the ``{model_name: embedding_matrix}``
        mapping to plot; every call in this notebook passes the `models`
        dict here. Previously the argument was ignored and the global
        `models` was read instead.
    out_dir : str
        Directory for the images; created if missing. Each model is saved
        as ``<out_dir>/<model_name>.png``.
    reduction : {'PCA', 'TSNE', 'both'}
        'PCA' projects straight to 2 components, 'TSNE' runs t-SNE on the
        raw vectors, 'both' runs PCA (up to 30 components) and then t-SNE.
    word_ids : sequence of int
        Row indices selecting the words from each embedding matrix.
    words, hues, sizes : sequences
        Per-word labels, colour values and size values, aligned with
        `word_ids`; forwarded to the plotting function.
    perplexity : number, optional
        t-SNE perplexity; required for 'TSNE' and 'both', unused for 'PCA'.
    categorical : bool
        Use `plot_categorical` (with legend) instead of `plot`.
    random_state : int, optional
        Seed forwarded to PCA and t-SNE. The default ``None`` keeps the
        previous unseeded behaviour.

    Raises
    ------
    TypeError
        If `in_dir` is not a mapping.
    ValueError
        If `reduction` is not one of the three supported names.
    AssertionError
        If t-SNE is requested without a `perplexity`.
    """
    if not hasattr(in_dir, 'items'):
        raise TypeError(
            'the first argument must be a {model_name: embedding} mapping, '
            f'got {type(in_dir).__name__}')
    embeddings = in_dir
    draw = plot_categorical if categorical else plot

    def run_tsne(vectors):
        # NOTE(review): newer scikit-learn releases rename `n_iter` to
        # `max_iter` - confirm against the installed version.
        return TSNE(
            perplexity=perplexity, learning_rate=10,
            n_iter=5000, n_iter_without_progress=1000,
            random_state=random_state).fit_transform(vectors)

    os.makedirs(out_dir, exist_ok=True)
    for model_name, embed in tqdm(embeddings.items()):
        space = embed[word_ids]
        if reduction == 'PCA':
            visual = PCA(
                n_components=2, random_state=random_state).fit_transform(space)
        elif reduction == 'TSNE':
            assert perplexity is not None
            visual = run_tsne(space)
        elif reduction == 'both':
            assert perplexity is not None
            # PCA cannot keep more components than there are samples or
            # features, so cap the intermediate dimensionality.
            n_components = min(30, *space.shape)
            space = PCA(
                n_components=n_components,
                random_state=random_state).fit_transform(space)
            visual = run_tsne(space)
        else:
            raise ValueError('unknown dimension reduction method')
        draw(visual, words, sizes, hues,
             os.path.join(out_dir, f'{model_name}.png'))

In [30]:
# Ids and grounding statistics for the two imported evaluation word lists.
ch_ids, ch_freq, ch_skew, ch_deno = gather(cherry_words)
gen_ids, gen_freq, gen_skew, gen_deno = gather(generic_words)

# Draw 50 words at random from the vocabulary entries whose frequency
# exceeds 99, then gather the same statistics for them.
random_words = random.sample(
    [w for w in PE.word_to_id.keys() if GD[w]['freq'] > 99],
    50)
rand_ids, rand_freq, rand_skew, rand_deno = gather(random_words)

51


In [47]:
# Words that are outliers in the clustering graphs and are therefore left out.
# Filtering (instead of list.remove) keeps the cell from raising when one of
# them is not in the vocabulary.
R_outliers = {
    'federal_debt_stood',
    'statements_relating',
    'legislative_days_within',
}
# Vocabulary words with frequency above 99 and an R_ratio above 0.75
# (presumably the share of Republican usage - verify against the grounding).
R_words = [w for w in PE.word_to_id.keys()
           if GD[w]['freq'] > 99 and GD[w]['R_ratio'] > 0.75
           and w not in R_outliers]
# This used to print len(GOP_words), a name that is never defined.
print(len(R_words))
# R_words = random.sample(R_words, 50)
R_ids, R_freq, R_skew, R_deno = gather(R_words)

50


In [43]:
# Earlier, automatic selection rule (disabled): every word with frequency
# above 99 and an R_ratio below 0.25.
# D_words = [w for w in PE.word_to_id.keys()
#            if GD[w]['freq'] > 99 and GD[w]['R_ratio'] < 0.25]

# Explicit list of words used instead of the rule above.
D_words = ['war_in_iraq', 'unemployed', 'detainees', 'solar', 
    'wealthiest', 'minorities', 'gun_violence', 
    'amtrak', 'unemployment_benefits', 
    'citizens_united', 'mayors', 'prosecutor', 'working_families', 
    'cpsc', 'sexual_assault',
    'affordable_housing', 'vietnam_veterans', 'drug_companies', 'handguns',
    'hungry', 'college_education', 
    'main_street', 'trauma', 'simon', 'pandemic', 
    'reagan_administration', 'guns', 
    'million_jobs', 'airline_industry', 'mergers', 'blacks', 
    'industrial_base', 'unemployment_insurance',
    'vacancies', 'trade_deficit', 'lost_their_jobs', 'food_safety', 
    'darfur', 'trains', 'deportation', 'credit_cards', 
    'surface_transportation', 'solar_energy', 'ecosystems', 'layoffs', 
    'wall_street', 'steelworkers', 'puerto_rico', 'hunger', 
    'child_support', 'naacp', 'domestic_violence', 'seaports', 
    'hate_crimes', 'underfunded', 'registrants', 'sanctuary', 
    'coastal_zone_management', 'vermonters', 'automakers', 
    'violence_against_women', 'unemployment_rate', 
    'select_committee_on_indian_affairs', 'judicial_nominees', 
    'school_construction', 'clarence_mitchell', 'confidential', 
    'domain_name', 'community_development', 'pell_grant', 'asylum', 'vawa', 
    'somalia', 'african_american', 'traders', 'jersey', 'fdic', 'shameful', 
    'homelessness', 'african_americans', 'payroll_tax',]
# The candidates below are commented out and are NOT part of D_words.
#     'retraining', 'unemployed_workers', 'the_disclose_act', 'baltimore', 
#     'assault_weapons', 'credit_card', 'the_patriot_act', 'young_woman', 
#     'trades', 'aye', 'poisoning', 'police_officers', 'mammal', 'toys', 
#     'whistleblowers', 'north_dakota', 'californias', 'computer_crime', 
#     'explosives', 'fast_track', 'bus', 'redlining', 'seclusion', 'gender', 
#     'hawaiian', 'pay_discrimination', 'ledbetter', 'phd', 'supra', 'baggage', 
#     'las_vegas', 'the_voting_rights_act', 'enron', 'richest', 'vra', 'chip', 
#     'tax_break', 'the_usa_patriot_act', 'advance_notice', 'derivatives', 
#     'the_patients_bill_of_rights', 'shelf', 'divestment', 'sa', 
#     'submitted_an_amendment', 'bill_hr', 'first_responders',
#     'unemployment_compensation', 'tax_breaks', 'carbon', 
#     'college_cost_reduction', 'clean_energy', 'waives', 
#     'unregulated', 'taa', 'truman', 'lesbian', 'coupons', 
#     'large_numbers', 'anonymous', 'whites', 'logging']

# Report the size of the full list before it is subsampled.
print(len(D_words))
# Keep 50 words drawn without replacement; the draw depends on the current
# state of the `random` module, so re-running this cell alone changes it.
D_words = random.sample(D_words, 50)
# Ids and grounding statistics for the sampled words.
D_ids, D_freq, D_skew, D_deno = gather(D_words)

81


In [80]:
# Joint word set: each J_* list is the D_* list followed by the matching
# R_* list, so positions stay aligned across the five lists.
J_words, J_ids, J_freq, J_skew, J_deno = (
    d_part + r_part
    for d_part, r_part in (
        (D_words, R_words),
        (D_ids, R_ids),
        (D_freq, R_freq),
        (D_skew, R_skew),
        (D_deno, R_deno),
    )
)

In [75]:
# Spot-check: display the grounding entry stored for a single word.
GD['joliet']

Counter({'Labor and employment': 6,
         'R': 105,
         'Taxation': 87,
         'Law': 2,
         'D': 7,
         'Health': 3,
         'Public lands and natural resources': 1,
         'International affairs': 5,
         'Transportation and public works': 2,
         'Science, technology, communications': 2,
         'Families': 2,
         'Crime and law enforcement': 1,
         'Education': 1,
         'majority_deno': 'Taxation',
         'freq': 112,
         'R_ratio': 0.9375})

In [85]:
# Experiment directory to visualise; switch by swapping which line is active.
# base_dir = '../../results/only remove deno BS128'
base_dir = '../../results/cono space remove deno'
# base_dir = '../../results/deno space remove cono'

# Every 'epoch10.pt' checkpoint below base_dir, plus two pretrained baselines.
models = load_en_masse(base_dir, endswith='epoch10.pt')
for label, checkpoint_path in (
        ('pretrained superset', '../../results/pretrained/init.pt'),
        ('pretrained', '../../results/pretrained bill mentions/init.pt')):
    models[label] = load(checkpoint_path)

21it [00:03,  5.99it/s]


E1 A1 epoch10.pt
E4 A1 superset epoch10.pt
E4 A1 epoch10.pt
E1 A1 superset epoch10.pt


### Graph by Party Skew (for removing connotation)

In [61]:
# Party-skew plots for the R_* word set: one PCA pass and two t-SNE passes
# (perplexity 5 and 3), each written to its own sub-directory of base_dir.
for folder, method, perp in (
        ('PCA', 'PCA', None),
        ('t-SNE p5', 'TSNE', 5),
        ('t-SNE p3', 'TSNE', 3)):
    graph_en_masse(
        models,
        out_dir=f'{base_dir}/{folder}',
        reduction=method,
        perplexity=perp,
        word_ids=R_ids,
        words=R_words,
        hues=R_skew,
        sizes=R_freq,
    )

100%|██████████| 2/2 [00:02<00:00,  1.36s/it]
100%|██████████| 2/2 [00:03<00:00,  1.84s/it]
100%|██████████| 2/2 [00:04<00:00,  2.12s/it]


In [86]:
# Party-skew plots for the joint J_* word set: PCA, then t-SNE with
# perplexity 5 and 3.
# NOTE(review): these out_dir values are the same as the ones used for the
# R_* plots, so files with the same model name are overwritten - confirm
# that this is intended.
for folder, method, perp in (
        ('PCA', 'PCA', None),
        ('t-SNE p5', 'TSNE', 5),
        ('t-SNE p3', 'TSNE', 3)):
    graph_en_masse(
        models,
        out_dir=f'{base_dir}/{folder}',
        reduction=method,
        perplexity=perp,
        word_ids=J_ids,
        words=J_words,
        hues=J_skew,
        sizes=J_freq,
    )

100%|██████████| 6/6 [00:08<00:00,  1.36s/it]
100%|██████████| 6/6 [00:14<00:00,  2.35s/it]
100%|██████████| 6/6 [00:14<00:00,  2.35s/it]


### Graph by Topic Denotation (for removing denotation)

In [73]:
# Topic-denotation plots for the Republican-leaning words.
# The original calls referenced GOP_ids / GOP_words / GOP_deno / GOP_freq,
# which are never defined in this notebook; the data lives in the R_* lists.
# Output still goes to the 'Highly GOP' folders.
# perplexity is passed for the PCA run as before, where it is simply unused.
for folder, method, perp in (
        ('PCA', 'PCA', 5),
        ('t-SNE p5', 'TSNE', 5),
        ('t-SNE p3', 'TSNE', 3)):
    graph_en_masse(
        models,
        out_dir=f'{base_dir}/Highly GOP/{folder}',
        reduction=method,
        perplexity=perp,
        word_ids=R_ids,
        words=R_words,
        hues=R_deno,
        sizes=R_freq,
        categorical=True
    )

100%|██████████| 5/5 [00:09<00:00,  1.86s/it]
100%|██████████| 5/5 [00:11<00:00,  2.37s/it]
100%|██████████| 5/5 [00:12<00:00,  2.46s/it]


In [74]:
# Topic-denotation plots for the D_* word set, written to the 'Highly Dem'
# folders: PCA (perplexity passed but unused), then t-SNE at 5 and 3.
for folder, method, perp in (
        ('PCA', 'PCA', 5),
        ('t-SNE p5', 'TSNE', 5),
        ('t-SNE p3', 'TSNE', 3)):
    graph_en_masse(
        models,
        out_dir=f'{base_dir}/Highly Dem/{folder}',
        reduction=method,
        perplexity=perp,
        word_ids=D_ids,
        words=D_words,
        hues=D_deno,
        sizes=D_freq,
        categorical=True
    )

100%|██████████| 5/5 [00:09<00:00,  1.91s/it]
100%|██████████| 5/5 [00:12<00:00,  2.47s/it]
100%|██████████| 5/5 [00:12<00:00,  2.41s/it]


In [81]:
# Topic-denotation plots for the joint J_* word set, written to the 'Joint'
# folders: PCA (perplexity passed but unused), then t-SNE at 5 and 3.
for folder, method, perp in (
        ('PCA', 'PCA', 5),
        ('t-SNE p5', 'TSNE', 5),
        ('t-SNE p3', 'TSNE', 3)):
    graph_en_masse(
        models,
        out_dir=f'{base_dir}/Joint/{folder}',
        reduction=method,
        perplexity=perp,
        word_ids=J_ids,
        words=J_words,
        hues=J_deno,
        sizes=J_freq,
        categorical=True
    )

100%|██████████| 5/5 [00:09<00:00,  1.94s/it]
100%|██████████| 5/5 [00:15<00:00,  3.07s/it]
100%|██████████| 5/5 [00:15<00:00,  3.02s/it]


### Graph Manually

In [None]:
# PCA view of the cherry-picked words for every loaded model, coloured by
# majority denotation and sized by frequency.
out_dir = '../../analysis/PCA/cherry/topic_live'
os.makedirs(out_dir, exist_ok=True)
for model_name, embed in models.items():
    target = os.path.join(out_dir, f'{model_name}.png')
    visual = PCA(n_components=2).fit_transform(embed[ch_ids])
    # Optional t-SNE refinement of the PCA output (disabled):
    # visual = TSNE(
    #     perplexity=4, learning_rate=10,
    #     n_iter=5000, n_iter_without_progress=1000).fit_transform(visual)
    plot(visual, cherry_words, ch_freq, ch_deno, target)

In [None]:
# t-SNE view of the generic evaluation words for every loaded model.
out_dir = '../../analysis/t-SNE'
# `plot` opens the target file directly, so the folder has to exist first.
os.makedirs(out_dir, exist_ok=True)
for model_name, embed in models.items():
    # The statistics for generic_words are stored in gen_ids / gen_freq /
    # gen_skew; the generic_* names used here before are never defined.
    space = embed[gen_ids]
    visual = TSNE(perplexity=5, learning_rate=1).fit_transform(space)
    plot(visual, generic_words, gen_freq, gen_skew,
         f'{out_dir}/generic {model_name}.png')

In [None]:
# t-SNE view of the randomly sampled words for every loaded model.
out_dir = '../../analysis/t-SNE'
# `plot` opens the target file directly, so the folder has to exist first.
os.makedirs(out_dir, exist_ok=True)
for model_name, embed in models.items():
    # The statistics for random_words are stored in rand_ids / rand_freq /
    # rand_skew; the random_* names used here before are never defined.
    space = embed[rand_ids]
    visual = TSNE(
        perplexity=20, learning_rate=10, n_iter=5000).fit_transform(space)
    plot(visual, random_words, rand_freq, rand_skew,
         f'{out_dir}/random {model_name}.png')

In [None]:
# Alternative model set: display label -> checkpoint path, loaded in order.
# NOTE(review): unlike other cells these paths carry no '../../results/'
# prefix, so they presumably expect a different working directory - confirm.
models = {
    label: load(checkpoint_path)
    for label, checkpoint_path in (
        ('M0 pretrained', 'pretrained/init.pt'),
        ('M1 L1 -0.05d', 'cono space remove deno/L1 -0.05d/epoch50.pt'),
        ('M2 L4 -0.05d', 'cono space remove deno/L4 -0.05d/epoch50.pt'),
        ('M3 +5 -0.05d', 'affine/L4 +5 -0.05d/epoch50.pt'),
        ('M4 +5 -0.1d', 'affine/L4 +5 -0.1d/epoch50.pt'),
        ('M5 +5 -0.2d', 'affine/L4 +5 -0.2d/epoch50.pt'),
        ('M6 +5 -0.5d', 'affine/L4 +5 -0.5d/epoch50.pt'),
        ('M7 +5 -1d', 'affine/L4 +5 -1d/epoch50.pt'),
        ('M8 +10 -1.5d', 'affine/L4 +10 -1.5d/epoch50.pt'),
        ('M9 +10 -2d', 'affine/L4 +10 -2d/epoch50.pt'),
        ('M10 +5 0c -1d', 'affine/L4 +5 0c/epoch50.pt'),
    )
}

In [None]:
# Replace the model set with the two pretrained baselines only.
models = {
    label: load(checkpoint_path)
    for label, checkpoint_path in (
        ('all', '../../results/pretrained/init.pt'),
        ('bill mentioned', '../../results/pretrained bill mentions/init.pt'),
    )
}

In [None]:
# t-SNE (perplexity 5) of the cherry-picked words for each loaded model,
# coloured by party skew and sized by frequency.
out_dir = '../../analysis/pretrained/t-SNE p5'
os.makedirs(out_dir, exist_ok=True)
for model_name, embed in models.items():
    reducer = TSNE(perplexity=5, learning_rate=1)
    visual = reducer.fit_transform(embed[ch_ids])
    target = f'{out_dir}/cherry {model_name}.png'
    plot(visual, cherry_words, ch_freq, ch_skew, target)