In [1]:
import random
import os

import numpy as np
import torch
from torch import nn
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from tqdm import tqdm

from evaluations.intrinsic_eval import cherry_words, generic_words
from decomposer import Decomposer, DecomposerConfig

SEED = 42
random.seed(SEED)
# TSNE/PCA below are constructed without random_state, so scikit-learn falls
# back to NumPy's global RNG; seed it too so the projections are reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)
sns.set()

DEVICE = 'cpu'
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# trusted checkpoints. Newer torch releases default to weights_only=True,
# which may refuse a pickled model object; pass weights_only=False there.
PE = torch.load(
    '../../results/pretrained/init.pt', map_location=DEVICE)['model']
# Per-word grounding records (read below via keys such as 'freq', 'R_ratio',
# 'majority_deno').
GD = PE.grounding

In [17]:
def load(path):
    """Load a checkpoint and return its embedding weights as a NumPy array.

    The file at ``path`` must unpickle to a dict whose ``'model'`` entry has an
    ``embedding`` attribute exposing ``.weight``. Rows are presumably indexed by
    vocabulary id (callers index them with ``PE.word_to_id`` ids); confirm.

    Args:
        path: filesystem path of the checkpoint.

    Returns:
        numpy.ndarray copy of ``model.embedding.weight`` (detached, on CPU).
    """
    # NOTE(review): torch.load unpickles arbitrary objects -- trusted files only.
    checkpoint = torch.load(path, map_location=DEVICE)
    model = checkpoint['model']
    # .cpu() lets this work when DEVICE is a GPU; it is a no-op for CPU tensors.
    return model.embedding.weight.detach().cpu().numpy()

def gather(words):
    """Collect vocabulary ids and grounding statistics for ``words``.

    Returns:
        A 4-tuple of parallel lists (one entry per word, in input order):
        ids from ``PE.word_to_id``, then ``GD[word]['freq']``,
        ``GD[word]['R_ratio']`` and ``GD[word]['majority_deno']``.
    """
    word_ids = [PE.word_to_id[word] for word in words]
    records = [GD[word] for word in words]
    freq = [record['freq'] for record in records]
    skew = [record['R_ratio'] for record in records]
    maj_deno = [record['majority_deno'] for record in records]
    return word_ids, freq, skew, maj_deno

def plot(coordinates, words, freq, skew, path):
    """Save an annotated 2-D scatter plot of words to ``path`` (300 dpi).

    Args:
        coordinates: 2-D array; ``coordinates[:, 0]`` / ``coordinates[:, 1]``
            are the x / y positions, one row per word.
        words: labels drawn at each point, same order as ``coordinates``.
        freq: per-word values mapped to marker size (range 100-1000).
        skew: per-word values coloured with the 'coolwarm' palette, with the
            colour scale normalised to (0, 1).
        path: output file; opened in binary mode and written by savefig.
    """
    fig, ax = plt.subplots(figsize=(15, 10))
    try:
        # x/y as keywords: seaborn >= 0.12 rejects them positionally.
        sns.scatterplot(
            x=coordinates[:, 0], y=coordinates[:, 1],
            hue=skew, palette='coolwarm', hue_norm=(0, 1),
            size=freq, sizes=(100, 1000),
            legend=None, ax=ax)
        for coord, word in zip(coordinates, words):
            ax.annotate(word, coord, fontsize=12)
        with open(path, 'wb') as file:
            fig.savefig(file, dpi=300)
    finally:
        # Release the figure even if plotting/saving fails, so loops over many
        # models do not accumulate open figures.
        plt.close(fig)

def plot_categorical(coordinates, words, freq, skew, path):
    """Save an annotated 2-D scatter plot with a qualitative colour legend.

    Takes the same inputs as ``plot``. ``skew`` is coloured with the
    qualitative 'muted' palette, and a brief legend is placed to the right of
    the axes, which are narrowed to 60% width to make room.

    Args:
        coordinates: 2-D array; columns 0 / 1 are x / y, one row per word.
        words: labels drawn at each point, same order as ``coordinates``.
        freq: per-word values mapped to marker size (range 100-1000).
        skew: per-word hue values (typically category labels).
        path: output file; opened in binary mode and written by savefig.
    """
    fig, ax = plt.subplots(figsize=(20, 10))
    try:
        # x/y as keywords: seaborn >= 0.12 rejects them positionally.
        sns.scatterplot(
            x=coordinates[:, 0], y=coordinates[:, 1],
            hue=skew, palette='muted', hue_norm=(0, 1),
            size=freq, sizes=(100, 1000),
            legend='brief', ax=ax)
        # Shrink the axes horizontally and put the legend in the freed space.
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.6, box.height])
        ax.legend(loc='upper center', bbox_to_anchor=(1.45, 0.8), ncol=1)
        for coord, word in zip(coordinates, words):
            ax.annotate(word, coord, fontsize=12)
        with open(path, 'wb') as file:
            fig.savefig(file, dpi=300)
    finally:
        # Close via matplotlib's pyplot (`plt`) even on failure.
        plt.close(fig)

In [3]:
# Selection thresholds for the word lists built in this cell.
FREQ_FLOOR = 99         # keep words whose 'freq' is strictly greater than this
GOP_SKEW_FLOOR = 0.75   # "GOP" words: 'R_ratio' strictly greater than this
N_RANDOM_WORDS = 50     # size of the random comparison sample

ch_ids, ch_freq, ch_skew, ch_deno = gather(cherry_words)
gen_ids, gen_freq, gen_skew, gen_deno = gather(generic_words)

GOP_words = [w for w in PE.word_to_id.keys()
             if GD[w]['freq'] > FREQ_FLOOR and GD[w]['R_ratio'] > GOP_SKEW_FLOOR]
print(len(GOP_words))
# Optional subsampling (disabled):
# GOP_words = random.sample(GOP_words, 100)
GOP_ids, GOP_freq, GOP_skew, GOP_deno = gather(GOP_words)

random_words = [w for w in PE.word_to_id.keys()
                if GD[w]['freq'] > FREQ_FLOOR]
# random.sample raises ValueError when k exceeds the population; cap k so a
# small candidate pool still works (with >= 50 candidates the draw is unchanged).
random_words = random.sample(random_words, min(N_RANDOM_WORDS, len(random_words)))
rand_ids, rand_freq, rand_skew, rand_deno = gather(random_words)

54


In [None]:
# Checkpoints live under the results directory; the first cell loads
# '../../results/pretrained/init.pt', so these relative paths need the same prefix.
RESULTS_DIR = '../../results'
checkpoints = {
    'M0 pretrained': 'pretrained/init.pt',
    'M1 L1 -0.05d': 'cono space remove deno/L1 -0.05d/epoch50.pt',
    'M2 L4 -0.05d': 'cono space remove deno/L4 -0.05d/epoch50.pt',
    'M3 +5 -0.05d': 'affine/L4 +5 -0.05d/epoch50.pt',
    'M4 +5 -0.1d': 'affine/L4 +5 -0.1d/epoch50.pt',
    'M5 +5 -0.2d': 'affine/L4 +5 -0.2d/epoch50.pt',
    'M6 +5 -0.5d': 'affine/L4 +5 -0.5d/epoch50.pt',
    'M7 +5 -1d': 'affine/L4 +5 -1d/epoch50.pt',
    'M8 +10 -1.5d': 'affine/L4 +10 -1.5d/epoch50.pt',
    'M9 +10 -2d': 'affine/L4 +10 -2d/epoch50.pt',
    'M10 +5 0c -1d': 'affine/L4 +5 0c/epoch50.pt',
}
# Model name -> embedding matrix, loaded in the order listed above.
models = {
    name: load(os.path.join(RESULTS_DIR, rel_path))
    for name, rel_path in checkpoints.items()
}

In [None]:
out_dir = '../../analysis/t-SNE'
# plot() opens the output file directly, so the directory must exist first.
os.makedirs(out_dir, exist_ok=True)
for model_name, embed in models.items():
    space = embed[GOP_ids]
    # NOTE(review): learning_rate=1 is far below scikit-learn's default; confirm.
    visual = TSNE(perplexity=10, learning_rate=1).fit_transform(space)
    plot(visual, GOP_words, GOP_freq, GOP_skew,
         os.path.join(out_dir, f'GOP {model_name}.png'))

In [29]:
# Inspect the raw grounding record for one word. Its printed output includes
# 'majority_deno', 'freq' and 'R_ratio', the fields read by gather().
GD['deterrence']

Counter({'Crime and law enforcement': 23,
         'D': 74,
         'Armed forces and national security': 13,
         'International affairs': 9,
         'Foreign trade and international finance': 8,
         'R': 53,
         'Native Americans': 7,
         'Arms control': 5,
         'Civil rights and liberties, minority issues': 2,
         'Finance and financial sector': 5,
         'Law': 2,
         'Labor and employment': 4,
         'Commerce': 7,
         'Transportation and public works': 9,
         'Public lands and natural resources': 4,
         'Environmental protection': 15,
         'Water resources development': 1,
         'Taxation': 1,
         'Agriculture and food': 1,
         'Families': 2,
         'Health': 1,
         'Science, technology, communications': 1,
         'Government operations and politics': 5,
         'Economics and public finance': 2,
         'majority_deno': 'Crime and law enforcement',
         'freq': 127,
         'R_ratio': 0.417322

In [25]:
def load_en_masse(in_dir, endswith):
    """Load the embedding matrix of every matching checkpoint under ``in_dir``.

    Walks ``in_dir`` recursively and calls ``load`` on each file whose name
    ends with ``endswith``.

    Args:
        in_dir: root directory to search.
        endswith: filename suffix to match, e.g. ``'epoch30.pt'``.

    Returns:
        dict mapping model name -> embedding array. The name is the file's
        path relative to ``in_dir`` with path separators replaced by spaces,
        e.g. ``'<subdir> epoch30.pt'``.
    """
    models = {}
    for dirpath, _, filenames in tqdm(os.walk(in_dir)):
        for file in filenames:
            if file.endswith(endswith):
                path = os.path.join(dirpath, file)
                # relpath removes the in_dir prefix. str.lstrip(in_dir) would
                # instead strip any leading characters that occur anywhere in
                # in_dir, mangling names and letting distinct models collide.
                name = os.path.relpath(path, in_dir).replace(os.sep, ' ')
                models[name] = load(path)
    return models
    
def _tsne(space, perplexity):
    """Run t-SNE with the fixed optimisation settings used by graph_en_masse."""
    # NOTE: newer scikit-learn releases renamed n_iter to max_iter.
    return TSNE(
        perplexity=perplexity, learning_rate=10,
        n_iter=5000, n_iter_without_progress=1000).fit_transform(space)


def _reduce_to_2d(space, reduction, perplexity):
    """Project ``space`` (one row per word) to 2-D with the chosen method."""
    if reduction == 'PCA':
        return PCA(n_components=2).fit_transform(space)
    if reduction == 'TSNE':
        return _tsne(space, perplexity)
    if reduction == 'both':
        # PCA down to (at most) 30 dims before t-SNE; the cap keeps PCA from
        # raising when there are fewer than 30 words or embedding dimensions.
        n_components = min(30, *space.shape)
        return _tsne(PCA(n_components=n_components).fit_transform(space), perplexity)
    raise ValueError('unknown dimension reduction method')


def graph_en_masse(
        in_dir,
        out_dir,
        reduction,  #  'PCA', 'TSNE', or 'both'
        word_ids,
        words,
        hues,
        sizes,
        perplexity=None):
    """Reduce each model's embeddings of ``words`` to 2-D and save one plot per model.

    Args:
        in_dir: mapping of model name -> embedding matrix (e.g. the dict
            returned by ``load_en_masse``). The name is historical; this is
            not a directory path.
        out_dir: directory for the PNGs (created if missing); each file is
            named ``'<model name>.png'``.
        reduction: 'PCA', 'TSNE', or 'both' (PCA to <= 30 dims, then t-SNE).
        word_ids: row indices selecting the words from each embedding matrix.
        words: labels for the plotted points, same order as ``word_ids``.
        hues: per-word colour values passed to ``plot``.
        sizes: per-word marker-size values passed to ``plot``.
        perplexity: t-SNE perplexity; required for 'TSNE' and 'both'.

    Raises:
        ValueError: unknown ``reduction``.
        AssertionError: t-SNE requested without ``perplexity``.
    """
    # Validate once, before touching the filesystem or looping over models.
    if reduction not in ('PCA', 'TSNE', 'both'):
        raise ValueError('unknown dimension reduction method')
    if reduction in ('TSNE', 'both'):
        assert perplexity is not None, 't-SNE requires a perplexity'
    os.makedirs(out_dir, exist_ok=True)
    # Iterate the mapping we were given (not a notebook-global `models`).
    for model_name, embed in tqdm(in_dir.items()):
        visual = _reduce_to_2d(embed[word_ids], reduction, perplexity)
        plot(visual, words, sizes, hues,
             os.path.join(out_dir, f'{model_name}.png'))

In [11]:
models = load_en_masse(in_dir='../../results/search delta/BS512', endswith='epoch30.pt')

201it [01:22,  2.44it/s]


In [26]:
# Cherry words: PCA then t-SNE (perplexity 5), one PNG per model.
graph_en_masse(
    models,
    reduction='both',
    perplexity=5,
    out_dir='../../analysis/search delta/PCA + t-SNE p5',
    words=cherry_words,
    word_ids=ch_ids,
    sizes=ch_freq,
    hues=ch_skew,
)


  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:01<02:48,  1.70s/it][A
  2%|▏         | 2/100 [00:03<02:47,  1.71s/it][A
  3%|▎         | 3/100 [00:05<02:54,  1.80s/it][A
  4%|▍         | 4/100 [00:07<02:53,  1.81s/it][A
  5%|▌         | 5/100 [00:08<02:45,  1.74s/it][A
  6%|▌         | 6/100 [00:10<02:41,  1.71s/it][A
  7%|▋         | 7/100 [00:12<02:44,  1.77s/it][A
  8%|▊         | 8/100 [00:14<02:44,  1.79s/it][A
  9%|▉         | 9/100 [00:16<02:44,  1.80s/it][A
 10%|█         | 10/100 [00:17<02:36,  1.74s/it][A
 11%|█         | 11/100 [00:19<02:37,  1.77s/it][A
 12%|█▏        | 12/100 [00:21<02:38,  1.80s/it][A
 13%|█▎        | 13/100 [00:22<02:27,  1.70s/it][A
 14%|█▍        | 14/100 [00:24<02:19,  1.63s/it][A
 15%|█▌        | 15/100 [00:26<02:24,  1.70s/it][A
 16%|█▌        | 16/100 [00:28<02:26,  1.74s/it][A
 17%|█▋        | 17/100 [00:29<02:26,  1.77s/it][A
 18%|█▊        | 18/100 [00:31<02:24,  1.77s/it][A
 19%|█▉        | 19/100 [00:3

In [27]:
# Cherry words: PCA then t-SNE (perplexity 10), one PNG per model.
graph_en_masse(
    models,
    reduction='both',
    perplexity=10,
    out_dir='../../analysis/search delta/PCA + t-SNE p10',
    words=cherry_words,
    word_ids=ch_ids,
    sizes=ch_freq,
    hues=ch_skew,
)


  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:01<02:39,  1.61s/it][A
  2%|▏         | 2/100 [00:03<02:46,  1.70s/it][A
  3%|▎         | 3/100 [00:04<02:38,  1.63s/it][A
  4%|▍         | 4/100 [00:06<02:30,  1.57s/it][A
  5%|▌         | 5/100 [00:07<02:25,  1.54s/it][A
  6%|▌         | 6/100 [00:09<02:26,  1.55s/it][A
  7%|▋         | 7/100 [00:10<02:20,  1.51s/it][A
  8%|▊         | 8/100 [00:12<02:17,  1.49s/it][A
  9%|▉         | 9/100 [00:13<02:13,  1.47s/it][A
 10%|█         | 10/100 [00:15<02:11,  1.46s/it][A
 11%|█         | 11/100 [00:16<02:09,  1.45s/it][A
 12%|█▏        | 12/100 [00:18<02:08,  1.46s/it][A
 13%|█▎        | 13/100 [00:19<02:06,  1.46s/it][A
 14%|█▍        | 14/100 [00:21<02:06,  1.47s/it][A
 15%|█▌        | 15/100 [00:22<02:10,  1.53s/it][A
 16%|█▌        | 16/100 [00:24<02:10,  1.56s/it][A
 17%|█▋        | 17/100 [00:25<02:06,  1.52s/it][A
 18%|█▊        | 18/100 [00:27<02:03,  1.51s/it][A
 19%|█▉        | 19/100 [00:2

In [22]:
# Cherry words: t-SNE only (perplexity 10), one PNG per model.
graph_en_masse(
    models,
    reduction='TSNE',
    perplexity=10,
    out_dir='../../analysis/search delta/t-SNE p10',
    words=cherry_words,
    word_ids=ch_ids,
    sizes=ch_freq,
    hues=ch_skew,
)


  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:01<02:33,  1.55s/it][A
  2%|▏         | 2/100 [00:03<02:29,  1.53s/it][A
  3%|▎         | 3/100 [00:04<02:25,  1.50s/it][A
  4%|▍         | 4/100 [00:05<02:23,  1.49s/it][A
  5%|▌         | 5/100 [00:07<02:20,  1.48s/it][A
  6%|▌         | 6/100 [00:09<02:25,  1.55s/it][A
  7%|▋         | 7/100 [00:10<02:24,  1.55s/it][A
  8%|▊         | 8/100 [00:12<02:19,  1.52s/it][A
  9%|▉         | 9/100 [00:13<02:17,  1.51s/it][A
 10%|█         | 10/100 [00:15<02:13,  1.49s/it][A
 11%|█         | 11/100 [00:16<02:11,  1.48s/it][A
 12%|█▏        | 12/100 [00:17<02:08,  1.46s/it][A
 13%|█▎        | 13/100 [00:19<02:06,  1.45s/it][A
 14%|█▍        | 14/100 [00:20<02:07,  1.48s/it][A
 15%|█▌        | 15/100 [00:22<02:06,  1.49s/it][A
 16%|█▌        | 16/100 [00:23<02:05,  1.49s/it][A
 17%|█▋        | 17/100 [00:25<02:04,  1.50s/it][A
 18%|█▊        | 18/100 [00:27<02:11,  1.61s/it][A
 19%|█▉        | 19/100 [00:2

In [None]:
# Manual graphing: ad-hoc per-model loops that call plot() directly instead of graph_en_masse

In [31]:
out_dir = '../../analysis/PCA/cherry/topic_live'
os.makedirs(out_dir, exist_ok=True)
# Cherry words projected with plain PCA, coloured by 'majority_deno' (a string
# category). An extra t-SNE pass on the PCA output was tried here and disabled.
# NOTE(review): plot() uses a continuous 'coolwarm' palette with hue_norm=(0, 1)
# and no legend; plot_categorical may be the intended helper here -- confirm.
for model_name, embed in models.items():
    projected = PCA(n_components=2).fit_transform(embed[ch_ids])
    png_path = os.path.join(out_dir, f'{model_name}.png')
    plot(projected, cherry_words, ch_freq, ch_deno, png_path)

In [None]:
out_dir = '../../analysis/t-SNE'
# plot() opens the output file directly, so the directory must exist first.
os.makedirs(out_dir, exist_ok=True)
for model_name, embed in models.items():
    # gen_* are the results of gather(generic_words) from the word-list cell.
    space = embed[gen_ids]
    visual = TSNE(perplexity=5, learning_rate=1).fit_transform(space)
    plot(visual, generic_words, gen_freq, gen_skew,
         os.path.join(out_dir, f'generic {model_name}.png'))

In [None]:
out_dir = '../../analysis/t-SNE'
# plot() opens the output file directly, so the directory must exist first.
os.makedirs(out_dir, exist_ok=True)
for model_name, embed in models.items():
    # rand_* are the results of gather(random_words) from the word-list cell.
    space = embed[rand_ids]
    visual = TSNE(
        perplexity=20, learning_rate=10, n_iter=5000).fit_transform(space)
    plot(visual, random_words, rand_freq, rand_skew,
         os.path.join(out_dir, f'random {model_name}.png'))