In [1]:
from pathlib import Path
from typing import Tuple, Union, List, Dict, Iterable, Optional

import torch
from tqdm.notebook import tqdm

from decomposer import Decomposer, DecomposerConfig
from recomposer import Recomposer, RecomposerConfig
# from evaluations.helpers import GroundedWord, load_recomposers_en_masse
# from evaluations.clustering import graph_en_masse
# from evaluations.euphemism import cherry_words

In [2]:
from dataclasses import dataclass

BASE_DIR = Path.home() / 'Research/congressional_adversary/results'
# sup_PE = torch.load(BASE_DIR / 'SGNS deno/pretrained super large/init.pt')['model']
sup_PE = torch.load(BASE_DIR / 'news/validation/pretrained/init.pt')['model']
WTI = sup_PE.word_to_id
ITW = sup_PE.id_to_word
sup_PE = sup_PE.embedding.weight.detach().cpu().numpy()
print(f'Vocab size = {len(WTI):,}')


sub_PE = torch.load(BASE_DIR / 'bill topic/pretrained subset/init.pt')['model']
sub_PE_WID = sub_PE.word_to_id
sub_PE_GD = sub_PE.grounding
del sub_PE


@dataclass
class GroundedWord():
    word: str

#     def __post_init__(self) -> None:
#         self.word_id: int = WTI[self.word]
#         metadata = sub_PE_GD[self.word]
#         self.freq: int = metadata['freq']
#         self.R_ratio: float = metadata['R_ratio']
#         self.majority_deno: int = metadata['majority_deno']

#         self.PE_neighbors = self.neighbors(sup_PE)
            
#     def neighbors(self, embed, top_k=10): 
#         query_id = sup_PE.word_to_id[self.word]
#         query_vec = sup_PE[query_id]
#         distances = [
#             distance.cosine(query_vec, neighbor_vec)
#             for neighbor_vec in sup_PE]
#         self.sup_PE_neighbors = set()
#         for sort_rank, neighbor_id in enumerate(sorted_neighbor_indices):
#             if num_neighbors == top_k:
#                 break
#             if query_id == neighbor_id:
#                 continue
#             neighbor_word = self.id_to_word[neighbor_id]
#             if editdistance.eval(query_word, neighbor_word) < 3:
#                 continue
            
#         self.sub_PE_neighbors: List[str] = nearest(, sub_PE)

    def __str__(self) -> str:
        return str(vars(self))
    
    
capitalism: List[GroundedWord] = []
socialism: List[GroundedWord] = []
for word in sub_PE_WID.keys():
    ratio = sub_PE_GD[word]['R_ratio']
    freq = sub_PE_GD[word]['freq']
    word = GroundedWord(word)
    if ratio < 0.2 and freq > 100:  # 0.2:
        socialism.append(word)
    elif ratio > 0.8 and freq > 100:  # 0.8:
        capitalism.append(word)

print(
    f'{len(capitalism)} capitalists\n'
    f'{len(socialism)} socialists')
polarization = capitalism + socialism

Vocab size = 77,647




29 capitalists
64 socialists


In [3]:
import numpy as np

def get_embed(model: Decomposer) -> np.ndarray:
    return model.embedding.weight.detach().cpu().numpy()


def load(
        path: Path,
        match_vocab: bool = False,
        device: str = 'cpu'
        ) -> np.ndarray:
    model = torch.load(path, map_location=device)['model']
    try:
        assert model.word_to_id == WTI
    except AssertionError:
        print(f'Vocabulary mismatch: {path}')
        print(f'Vocab size = {len(model.word_to_id)}')
        if match_vocab:
            raise RuntimeError
        else:
            return None
    return get_embed(model)


def load_decomposers_en_masse(
        in_dirs: Union[Path, List[Path]],
        patterns: Union[str, List[str]]
        ) -> Tuple[Dict[str, np.ndarray], ...]:
    if not isinstance(in_dirs, List):
        in_dirs = [in_dirs, ]
    if not isinstance(patterns, List):
        patterns = [patterns, ]
    checkpoints: List[Path] = []
    for in_dir in in_dirs:
        for pattern in patterns:
            checkpoints += list(in_dir.glob(pattern))
    if len(checkpoints) == 0:
        raise FileNotFoundError('No model with path pattern found at in_dir?')

    models = {
#         'pretrained superset': load(BASE_DIR / 'bill topic/pretrained superset/init.pt'),
#         'pretrained subset': load(BASE_DIR / 'bill topic/pretrained subset/init.pt')
    }
    for path in tqdm(checkpoints):
        tqdm.write(f'Loading {path}')
        embed = load(path) 
        if embed is None:
            continue
#         name = path.parent.name
        name = path.name 
        models[name] = embed
    return models

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

sns.set()

def plot(
        coordinates: np.ndarray,
        words: List[GroundedWord],
        path: Path
        ) -> None:
    fig, ax = plt.subplots(figsize=(15, 10))
#     skew = [w.R_ratio for w in words]
#     freq = [w.freq for w in words]
    sns.scatterplot(
        coordinates[:, 0], coordinates[:, 1],
#         hue=skew, palette='coolwarm',  # hue_norm=(0, 1),
#         size=freq, sizes=(200, 1000),
        legend=None, ax=ax)
    for coord, w in zip(coordinates, words):
        ax.annotate(w.word, coord, fontsize=20)
    with open(path, 'wb') as file:
        fig.savefig(file, dpi=300)
    plt.close(fig)


def plot_categorical(
        coordinates: np.ndarray,
        words: List[GroundedWord],
        path: Path,
        fancy: bool = True
        ) -> None:
    if fancy:
        fig, ax = plt.subplots(figsize=(20, 10))
        categories = [w.majority_deno for w in words]
        freq = [w.freq for w in words]
        sns.scatterplot(
            coordinates[:, 0], coordinates[:, 1],
            hue=categories, palette='muted', hue_norm=(0, 1),
            size=freq, sizes=(200, 1000),
            legend='brief', 
            ax=ax)
        chartBox = ax.get_position()
        ax.set_position(  # adjust legend
            [chartBox.x0, chartBox.y0, chartBox.width * 0.6, chartBox.height])
        ax.legend(loc='upper center', bbox_to_anchor=(1.45, 0.8), ncol=1)
    else:
        fig, ax = plt.subplots(figsize=(20, 10))
        freq = [w.freq for w in words]
        sns.scatterplot(
            coordinates[:, 0], coordinates[:, 1], ax=ax)

    for coord, w in zip(coordinates, words):
        ax.annotate(w.word, coord, fontsize=12)
    with open(path, 'wb') as file:
        fig.savefig(file, dpi=300)
    plt.close(fig)


def graph_en_masse(
        models: Dict[str, np.ndarray],
        out_dir: Path,
        reduction: str,  # 'PCA', 'TSNE', or 'both'
        words: List[GroundedWord],
        # hues: Union[List[float], List[int]],
        # sizes: List[int],
        perplexity: Optional[int] = None,
        categorical: bool = False
        ) -> None:
    Path.mkdir(out_dir, parents=True, exist_ok=True)
    word_ids = np.array([w.word_id for w in words])
    for model_name, embed in tqdm(models.items()):
        space = embed[word_ids]
        if reduction == 'PCA':
            visual = PCA(n_components=2).fit_transform(space)
        elif reduction == 'TSNE':
            assert perplexity is not None
            visual = TSNE(
                perplexity=perplexity, learning_rate=10,
                n_iter=5000, n_iter_without_progress=1000).fit_transform(space)
        elif reduction == 'both':
            assert perplexity is not None
            space = PCA(n_components=30).fit_transform(space)
            visual = TSNE(
                perplexity=perplexity, learning_rate=10,
                n_iter=5000, n_iter_without_progress=1000).fit_transform(space)
        else:
            raise ValueError('unknown dimension reduction method')
        if categorical:
            plot_categorical(visual, words, out_dir / f'{model_name}.png')
        else:
            plot(visual, words, out_dir / f'{model_name}.png')


In [None]:
cherry_words = [
    'government', 'washington',
    'estate_tax', 'death_tax',
    'public_option', 'governmentrun',
    'foreign_trade', 'international_trade',
    'cut_taxes', 'trickledown'
]

cherry_words = [GroundedWord(w) for w in cherry_words]

In [4]:
# base_dir = Path('../../results/SGNS deno/sans recomposer')
# deno_space = load_decomposers_en_masse(base_dir, patterns='*/epoch10.pt')

base_dir = Path('../../results/news/validation')
deno_space = load_decomposers_en_masse(base_dir, patterns='*/epoch*.pt')

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

Loading ../../results/news/validation/-3c/epoch3.pt
Loading ../../results/news/validation/-3c/epoch5.pt
Loading ../../results/news/validation/-3c/epoch4.pt
Loading ../../results/news/validation/-3c/epoch2.pt
Loading ../../results/news/validation/-3c/epoch1.pt



In [None]:
models = deno_space
stuff = polarization

graph_en_masse(
    models, out_dir=base_dir / 'cherry/topic/t-SNE p5',
    reduction='TSNE', perplexity=5, words=stuff, categorical=True)
graph_en_masse(
    models, out_dir=base_dir / 'cherry/topic/t-SNE p3',
    reduction='TSNE', perplexity=3, words=stuff, categorical=True)
graph_en_masse(
    models, out_dir=base_dir / 'cherry/topic/t-SNE p2',
    reduction='TSNE', perplexity=2, words=stuff, categorical=True)

In [5]:
from scipy.spatial.distance import cosine as cos_dist
import editdistance

def vec(query: str, embed: np.ndarray) -> np.ndarray:
    try:
        query_id = WTI[query]
    except KeyError:
        raise KeyError(f'Out of vocabulary: {query}')
    return embed[query_id]


def nearest_neighbors(
        query: str,
        embed: np.ndarray,
        top_k: int = 10
        ) -> None:
    query_vec = vec(query, embed)
    print(f"{query}'s neareset neighbors:")
    distances = [
        cos_dist(query_vec, neighbor_vec)
        for neighbor_vec in embed]
    neighbor_indices = np.argsort(distances)
    num_neighbors = 0        
    for sort_rank, neighbor_id in enumerate(neighbor_indices):
        if num_neighbors == top_k:
            break
#         if query_id == neighbor_id:
#             continue
        neighbor_word = ITW[neighbor_id]

        if editdistance.eval(query, neighbor_word) < 3:
            continue
        cosine_similarity = 1 - distances[neighbor_id]
        # neighbor_ids.append(neighbor_id)
        num_neighbors += 1
        print(f'{cosine_similarity:.4f}\t{neighbor_word}')
    print()

In [None]:
deno_space['bill topic'] = 

In [6]:
deno_space.keys()

dict_keys(['epoch3.pt', 'epoch5.pt', 'epoch4.pt', 'epoch2.pt', 'epoch1.pt'])

In [8]:
nn = nearest_neighbors
our_model = deno_space['epoch5.pt']

In [9]:
query = 'estate_tax'
nn(query, sup_PE)
nn(query, our_model)

KeyError: 'Out of vocabulary: estate_tax'

In [21]:
query = 'pro_choice'
nn(query, sup_PE)
nn(query, our_model)

KeyError: 'Out of vocabulary: pro_choice'

In [10]:
query = 'undocumented'
nn(query, sup_PE)
nn(query, our_model)

undocumented's neareset neighbors:
0.9609	immigrants
0.9382	unauthorized
0.9333	aliens
0.9316	deportation
0.9298	immigrant
0.9217	citizenship
0.9067	eligible
0.9053	sanctuary
0.9024	deport
0.9001	illegal

undocumented's neareset neighbors:
0.6909	immigrants
0.6336	deportation
0.5718	immigration
0.5396	citizenship
0.5113	deport
0.5095	immigrant
0.5012	aliens
0.4949	illegals
0.4696	deported
0.4663	border



In [11]:
query = 'illegals'
nn(query, sup_PE)
nn(query, our_model)

illegals's neareset neighbors:
0.9825	dreamers
0.9733	deport
0.9731	deporting
0.9725	birthright
0.9701	stripped
0.9680	discriminating
0.9646	registering
0.9641	penalized
0.9638	second_amendment
0.9632	bribe

illegals's neareset neighbors:
0.5105	deportation
0.5005	immigrants
0.4949	undocumented
0.4411	amnesty
0.4403	aliens
0.3950	deport
0.3861	alien
0.3749	immigration
0.3669	minors
0.3661	deported



In [25]:
query = 'aliens'
nn(query, sup_PE)
nn(query, our_model)

aliens's neareset neighbors:
0.9636	unauthorized
0.9584	deportation
0.9532	illegal
0.9431	offenses
0.9406	inmates
0.9400	jails
0.9372	prisons
0.9333	undocumented
0.9311	deport
0.9292	felons

aliens's neareset neighbors:
0.5995	illegal
0.5481	deportation
0.5104	immigrants
0.5012	undocumented
0.4507	amnesty
0.4403	illegals
0.4392	deport
0.4235	deporting
0.4070	immigration
0.3991	citizenship



In [26]:
query = 'leftists'
nn(query, sup_PE)
nn(query, our_model)

leftists's neareset neighbors:
0.9869	fundamentalists
0.9849	atheists
0.9832	patriotism
0.9818	feminists
0.9774	righteous
0.9764	patriotic
0.9753	believers
0.9752	fundamentalist
0.9743	fanatics
0.9740	intolerant

leftists's neareset neighbors:
0.3259	progressives
0.3237	liberals
0.3066	moderates
0.2980	anarchists
0.2965	movement
0.2925	anti-american
0.2820	islamists
0.2793	nationalists
0.2793	activists
0.2742	supremacists



In [27]:
query = 'antifa'
nn(query, sup_PE)
nn(query, our_model)

antifa's neareset neighbors:
0.9769	mob
0.9750	kkk
0.9749	neo-nazis
0.9701	slurs
0.9679	martyr
0.9676	muhammad
0.9630	black_lives_matter
0.9617	mohammed
0.9606	preacher
0.9604	savage

antifa's neareset neighbors:
0.3674	black_lives_matter
0.3640	marches
0.3483	supremacists
0.3327	charlottesville
0.3285	rally
0.3278	occupy_wall_street
0.3072	protesters
0.2905	supremacist
0.2844	chants
0.2842	sympathizer



In [30]:
query = 'anti-abortion'
nn(query, sup_PE)
nn(query, our_model)

anti-abortion's neareset neighbors:
0.9377	challenged
0.9376	sonia_sotomayor
0.9361	clarence_thomas
0.9349	the_national_rifle_association
0.9348	anthony_kennedy
0.9348	naacp
0.9333	aclu
0.9329	pro-life
0.9327	u.s._supreme_court
0.9327	samuel_alito

anti-abortion's neareset neighbors:
0.5418	abortion
0.4835	fetal
0.4782	planned_parenthood
0.4431	abortions
0.4280	fetus
0.4271	reproductive
0.4253	fetuses
0.3958	pro-life
0.3908	aborted
0.3750	clinics



In [31]:
query = 'wealthiest'
nn(query, sup_PE)
nn(query, our_model)

wealthiest's neareset neighbors:
0.9701	earners
0.9667	benefit
0.9638	entitlements
0.9632	millionaires
0.9622	benefiting
0.9616	brackets
0.9604	retirees
0.9602	jobless
0.9598	regressive
0.9589	richest

wealthiest's neareset neighbors:
0.4102	tax
0.4082	earners
0.4035	rich
0.3713	taxes
0.3692	wealthy
0.3581	income
0.3485	corporations
0.3439	cuts
0.3406	richest
0.3352	breaks



In [14]:
query = 'wall'
nn(query, sup_PE)
nn(query, our_model)

wall's neareset neighbors:
0.9527	street
0.8507	off
0.8440	back
0.8347	fence
0.8298	along
0.8263	onto
0.8225	away
0.8206	stops
0.8204	started
0.8198	pulling

wall's neareset neighbors:
0.6686	street
0.3769	wall_street
0.3458	border
0.3075	bailout
0.3048	build
0.2892	mexican
0.2856	window
0.2788	building
0.2679	market
0.2647	managed



In [None]:
query = ''
nn(query, sup_PE)
nn(query, our_model)

In [16]:
query = 'obamacare'
nn(query, sup_PE)
nn(query, our_model)

obamacare's neareset neighbors:
0.9803	the_affordable_care_act
0.9688	repeal
0.9678	repealing
0.9674	aca
0.9565	medicare
0.9515	payer
0.9392	healthcare
0.9346	ahca
0.9302	skinny
0.9254	social_security

obamacare's neareset neighbors:
0.5771	the_affordable_care_act
0.5704	repeal
0.5469	insurance
0.5284	aca
0.5032	repealing
0.4914	healthcare
0.4864	insurers
0.4863	medicaid
0.4794	premiums
0.4517	marketplaces



In [20]:
query = 'aca'
nn(query, sup_PE)
nn(query, our_model)

aca's neareset neighbors:
0.9905	the_affordable_care_act
0.9862	repealing
0.9713	payer
0.9674	obamacare
0.9626	medicare
0.9607	repealed
0.9553	entitlement
0.9547	marketplaces
0.9543	enrollees
0.9537	eligibility

aca's neareset neighbors:
0.5284	obamacare
0.5278	insurers
0.5112	insurance
0.5005	the_affordable_care_act
0.4786	pre-existing
0.4594	premiums
0.4515	uninsured
0.4511	coverage
0.4305	marketplaces
0.4281	mandate



In [18]:
query = 'socialized'
nn(query, sup_PE)
nn(query, our_model)

socialized's neareset neighbors:
0.9868	needy
0.9863	sicker
0.9845	handouts
0.9835	livelihood
0.9830	separating
0.9821	professions
0.9818	unaffordable
0.9807	chronically
0.9794	takers
0.9792	vaccinations

socialized's neareset neighbors:
0.4164	medicine
0.3018	hancock
0.2846	nostalgic
0.2830	alleys
0.2787	jettisoned
0.2777	sanford
0.2770	asses
0.2748	the_french_revolution
0.2726	alphabet
0.2692	unalienable



In [22]:
query = 'guns'
nn(query, sup_PE)
nn(query, our_model)

guns's neareset neighbors:
0.9117	cops
0.9054	carry
0.9046	firearms
0.9013	stealing
0.8996	physically
0.8964	criminals
0.8947	smoking
0.8851	bars
0.8837	alien
0.8836	aliens

guns's neareset neighbors:
0.5922	ammunition
0.5623	firearms
0.5047	rifles
0.4887	nra
0.4859	rifle
0.4815	semi-automatic
0.4495	handgun
0.4401	automatic
0.4378	concealed
0.4259	firearm



In [19]:
query = 'trade'
nn(query, sup_PE)
nn(query, our_model)

trade's neareset neighbors:
0.9401	nafta
0.9194	deficits
0.9179	restructuring
0.9134	liberalization
0.9096	china_economics
0.9086	consolidation
0.9086	the_european_central_bank
0.9083	expansionary
0.9081	agreements
0.9078	competitiveness

trade's neareset neighbors:
0.4774	agreements
0.4137	the_trans_-_pacific_partnership
0.3983	nafta
0.3968	agreement
0.3711	wto
0.3511	free
0.3454	deals
0.3384	tpp
0.3326	the_world_trade_organization
0.3273	liberalization



In [24]:
query = 'bernie'
nn(query, sup_PE)
nn(query, our_model)

bernie's neareset neighbors:
0.9878	jeb
0.9877	newt
0.9803	loser
0.9794	mitt
0.9709	gopers
0.9705	huckabee
0.9694	barack
0.9684	dems
0.9658	cavuto
0.9646	flop

bernie's neareset neighbors:
0.3841	bernie_sanders
0.3549	hillary
0.3037	sanders
0.2437	delegate
0.2394	delegates
0.2379	opponent
0.2339	ralph_nader
0.2320	tying
0.2283	dictate
0.2234	dnc



In [None]:
models = cono_space

graph_en_masse(
    models, out_dir=base_dir / 'cherry/party/t-SNE p5',
    reduction='TSNE', perplexity=5, words=cherry_words, categorical=False)

graph_en_masse(
    models, out_dir=base_dir / 'cherry/party/t-SNE p3',
    reduction='TSNE', perplexity=3, words=cherry_words, categorical=False)

graph_en_masse(
    models, out_dir=base_dir / 'cherry/party/t-SNE p2',
    reduction='TSNE', perplexity=2, words=cherry_words, categorical=False)

In [None]:
models = deno_space

graph_en_masse(
    models,
    out_dir=f'{base_dir}/decomposed deno/party/t-SNE p25',
    reduction='TSNE', perplexity=25,
    word_ids=J_ids, words=J_words, hues=J_skew, sizes=J_freq)

graph_en_masse(
    models,
    out_dir=f'{base_dir}/decomposed deno/party/t-SNE p50',
    reduction='TSNE', perplexity=50,
    word_ids=J_ids, words=J_words, hues=J_skew, sizes=J_freq)

In [None]:
models = cono_space

graph_en_masse(
    models, out_dir=f'{base_dir}/decomposed cono/topic/t-SNE p5',
    reduction='TSNE', perplexity=5,
    word_ids=J_ids, words=J_words, hues=J_deno, sizes=J_freq,
    categorical=True)

graph_en_masse(
    models,
    out_dir=f'{base_dir}/decomposed cono/topic/t-SNE p3',
    reduction='TSNE', perplexity=3,
    word_ids=J_ids, words=J_words, hues=J_deno, sizes=J_freq,
    categorical=True)

graph_en_masse(
    models,
    out_dir=f'{base_dir}/decomposed cono/topic/t-SNE p10',
    reduction='TSNE', perplexity=10,
    word_ids=J_ids, words=J_words, hues=J_deno, sizes=J_freq,
    categorical=True)

In [None]:
models = cono_space

# graph_en_masse(
#     models,
#     out_dir=f'{base_dir}/Joint/topic/PCA',
#     reduction='PCA',
#     word_ids=J_ids, words=J_words, hues=J_skew, sizes=J_freq)

graph_en_masse(
    models,
    out_dir=f'{base_dir}/decomposed cono/party/t-SNE p25',
    reduction='TSNE', perplexity=25,
    word_ids=J_ids, words=J_words, hues=J_skew, sizes=J_freq)

graph_en_masse(
    models,
    out_dir=f'{base_dir}/decomposed cono/party/t-SNE p50',
    reduction='TSNE', perplexity=50,
    word_ids=J_ids, words=J_words, hues=J_skew, sizes=J_freq)

# Homogeneity V-Measure

In [None]:
# Deno space, eval deno, higher is better
for model_name, model in deno_space.items():
    cluster_labels, true_labels = NN_cluster_ids(
        model, J_ids, categorical=True, top_k=10)    
    homogeneity, completeness, v_measure = np.around(
        homogeneity_completeness_v_measure(true_labels, cluster_labels), 4)
    print(model_name, homogeneity, completeness, v_measure, sep='\t')
#     print(pred_labels)

In [None]:
# Deno space, eval cono, lower is better
for model_name, model in deno_space.items():
    cluster_labels, true_labels = NN_cluster_ids(
        model, J_ids, categorical=False, top_k=5)    
    homogeneity, completeness, v_measure = np.around(
        homogeneity_completeness_v_measure(true_labels, cluster_labels), 4)
    print(model_name, homogeneity, completeness, v_measure, sep='\t')
#     print(pred_labels)

In [None]:
# Cono space, eval cono, higher is better
for model_name, model in cono_space.items():
    cluster_labels, true_labels = NN_cluster_ids(
        model, J_ids, categorical=False, top_k=5)    
    homogeneity, completeness, v_measure = np.around(
        homogeneity_completeness_v_measure(true_labels, cluster_labels), 4)
    print(model_name, homogeneity, completeness, v_measure, sep='\t')

In [None]:
# Cono space, eval deno, lower is better
for model_name, model in cono_space.items():
    cluster_labels, true_labels = NN_cluster_ids(
        model, J_ids, categorical=True, top_k=5)    
    homogeneity, completeness, v_measure = np.around(
        homogeneity_completeness_v_measure(true_labels, cluster_labels), 4)
    print(model_name, homogeneity, completeness, v_measure, sep='\t')