In [None]:
import os
import random

import torch
from torch import nn
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from tqdm import tqdm

from evaluations.intrinsic_eval import cherry_words, generic_words
from decomposer import Decomposer, DecomposerConfig

# Seed the Python and PyTorch random number generators with one shared value,
# then apply seaborn's default plot styling.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
sns.set()

# All checkpoints are mapped onto this device when loaded.
DEVICE = 'cpu'

# 'model' entry of the pretrained checkpoint. The rest of the notebook reads
# its word_to_id, Dem_frequency, GOP_frequency and R_ratio members.
PE = torch.load(
    '../../results/pretrained/init.pt',
    map_location=DEVICE,
)['model']

In [None]:
def load(path):
    """Return the embedding matrix stored in a checkpoint as a NumPy array.

    Parameters
    ----------
    path : str
        Checkpoint location relative to ``../../results/``.

    Returns
    -------
    numpy.ndarray
        Detached copy-free view of ``checkpoint['model'].embedding.weight``.

    Notes
    -----
    ``torch.load`` unpickles arbitrary Python objects, so only point this at
    checkpoints you trust.
    """
    checkpoint = torch.load('../../results/' + path, map_location=DEVICE)
    weights = checkpoint['model'].embedding.weight
    # .cpu() is a no-op for CPU tensors, but keeps .numpy() valid if DEVICE is
    # ever switched to a GPU (Tensor.numpy() raises on non-CPU tensors).
    return weights.detach().cpu().numpy()

def gather(words):
    """Collect per-word lookups from the pretrained model ``PE``.

    For every entry of ``words`` this returns, as three parallel lists:

    * its index from ``PE.word_to_id``,
    * the sum of its ``PE.Dem_frequency`` and ``PE.GOP_frequency`` entries,
    * the value of ``PE.R_ratio(word)``.

    Returns
    -------
    tuple of (list, list, list)
        ``(word_ids, freq, skew)``, each ordered like ``words``.
    """
    word_ids = []
    for token in words:
        word_ids.append(PE.word_to_id[token])

    freq = []
    for token in words:
        freq.append(PE.Dem_frequency[token] + PE.GOP_frequency[token])

    skew = list(map(PE.R_ratio, words))
    return word_ids, freq, skew

def plot(coordinates, words, freq, skew, path):
    """Draw an annotated 2-D scatter plot of words and save it to ``path``.

    Parameters
    ----------
    coordinates : array indexable as ``coordinates[:, 0]`` / ``coordinates[:, 1]``
        Horizontal and vertical position of each point; iterated row by row
        to place the text labels.
    words : sequence of str
        Labels written next to the points, in the same order as the rows of
        ``coordinates``.
    freq : sequence of numbers
        Controls marker size (mapped into the 100-1000 size range).
    skew : sequence of numbers
        Controls marker colour on the 'coolwarm' palette, normalised over
        the interval (0, 1).
    path : str
        Output file, written at 300 dpi. Its parent directory is created
        when it does not exist yet.

    Returns
    -------
    None
    """
    fig, ax = plt.subplots(figsize=(15, 10))
    # x/y are passed by keyword: recent seaborn releases no longer accept
    # them positionally, while the keyword form works on old and new versions.
    sns.scatterplot(
        x=coordinates[:, 0], y=coordinates[:, 1],
        hue=skew, palette='coolwarm', hue_norm=(0, 1),
        size=freq, sizes=(100, 1000),
        legend=None, ax=ax)
    for coord, word in zip(coordinates, words):
        ax.annotate(word, coord, fontsize=12)

    # Make sure the destination directory exists so the save cannot fail with
    # FileNotFoundError on a fresh checkout.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'wb') as file:
        fig.savefig(file, dpi=300)

In [None]:
# Ids, frequencies and ratios for the two word lists imported from
# evaluations.intrinsic_eval.
cherry_ids, cherry_freq, cherry_skew = gather(cherry_words)
generic_ids, generic_freq, generic_skew = gather(generic_words)

# Vocabulary entries with GOP_frequency above 99 and R_ratio above 0.75.
GOP_words = list(filter(
    lambda w: PE.GOP_frequency[w] > 99 and PE.R_ratio(w) > 0.75,
    PE.word_to_id.keys()))
print(len(GOP_words))
# GOP_words = random.sample(GOP_words, 100)
GOP_ids, GOP_freq, GOP_skew = gather(GOP_words)

# 50 words drawn at random from the entries whose combined GOP + Dem
# frequency is above 99.
random_words = random.sample(
    list(filter(
        lambda w: PE.GOP_frequency[w] + PE.Dem_frequency[w] > 99,
        PE.word_to_id.keys())),
    50)
random_ids, random_freq, random_skew = gather(random_words)

In [None]:
# Embedding matrices to visualise, keyed by the label that ends up in the
# output file names. Each path is handed to load(), which resolves it
# relative to ../../results/.
models = {
    label: load(checkpoint)
    for label, checkpoint in (
        # ('M0 pretrained', 'pretrained/init.pt'),
        # ('M1 L1 -0.05d', 'cono space remove deno/L1 -0.05d/epoch50.pt'),
        # ('M2 L4 -0.05d', 'cono space remove deno/L4 -0.05d/epoch50.pt'),
        ('M3 +5 -0.05d', 'affine/L4 +5 -0.05d/epoch50.pt'),
        ('M4 +5 -0.1d', 'affine/L4 +5 -0.1d/epoch50.pt'),
        ('M5 +5 -0.2d', 'affine/L4 +5 -0.2d/epoch50.pt'),
        ('M6 +5 -0.5d', 'affine/L4 +5 -0.5d/epoch50.pt'),
        ('M7 +5 -1d', 'affine/L4 +5 -1d/epoch50.pt'),
        ('M8 +10 -1.5d', 'affine/L4 +10 -1.5d/epoch50.pt'),
        ('M9 +10 -2d', 'affine/L4 +10 -2d/epoch50.pt'),
        # NOTE(review): the label mentions "-1d" but the path does not --
        # confirm this is the intended checkpoint.
        ('M10 +5 0c -1d', 'affine/L4 +5 0c/epoch50.pt'),
    )
}

In [None]:
# One t-SNE figure per model for the GOP_words selection.
out_dir = '../../analysis/t-SNE'
for model_name, embed in models.items():
    space = embed[GOP_ids]
    # random_state pins the layout: scikit-learn draws from NumPy's RNG, which
    # the random / torch seeds set at the top of the notebook do not control,
    # so without it every run produces a different picture.
    visual = TSNE(
        perplexity=10, learning_rate=1,
        random_state=42).fit_transform(space)
    plot(visual, GOP_words, GOP_freq, GOP_skew,
         f'{out_dir}/GOP {model_name}.png')

In [None]:
# One t-SNE figure per model for the cherry_words list.
out_dir = '../../analysis/t-SNE'
for model_name, embed in models.items():
    space = embed[cherry_ids]
#     space = PCA(n_components=30).fit_transform(space)
    # random_state pins the layout: scikit-learn draws from NumPy's RNG, which
    # the random / torch seeds set at the top of the notebook do not control,
    # so without it every run produces a different picture.
    # NOTE(review): newer scikit-learn releases rename n_iter to max_iter --
    # confirm against the installed version before upgrading.
    visual = TSNE(
        perplexity=4, learning_rate=10,
        n_iter=5000, n_iter_without_progress=1000,
        random_state=42).fit_transform(space)

    plot(visual, cherry_words, cherry_freq, cherry_skew,
         f'{out_dir}/cherry {model_name}.png')

In [None]:
# One t-SNE figure per model for the generic_words list.
out_dir = '../../analysis/t-SNE'
for model_name, embed in models.items():
    space = embed[generic_ids]
    # random_state pins the layout: scikit-learn draws from NumPy's RNG, which
    # the random / torch seeds set at the top of the notebook do not control,
    # so without it every run produces a different picture.
    visual = TSNE(
        perplexity=5, learning_rate=1,
        random_state=42).fit_transform(space)
    plot(visual, generic_words, generic_freq, generic_skew,
         f'{out_dir}/generic {model_name}.png')

In [None]:
# One t-SNE figure per model for the randomly sampled words.
out_dir = '../../analysis/t-SNE'
for model_name, embed in models.items():
    space = embed[random_ids]
    # random_state pins the layout: scikit-learn draws from NumPy's RNG, which
    # the random / torch seeds set at the top of the notebook do not control,
    # so without it every run produces a different picture.
    # NOTE(review): newer scikit-learn releases rename n_iter to max_iter --
    # confirm against the installed version before upgrading.
    visual = TSNE(
        perplexity=20, learning_rate=10, n_iter=5000,
        random_state=42).fit_transform(space)
    plot(visual, random_words, random_freq, random_skew,
         f'{out_dir}/random {model_name}.png')