In [1]:
import pickle
import os
from typing import Set, Tuple, List, Dict

import torch
import numpy as np
import scipy
from sklearn.metrics import pairwise
from tqdm import tqdm

from adversarial import AdversarialDecomposer, AdversarialConfig
from preprocessing.S4_export_training_corpus import Document

# Pin both random number generators so that anything sampled later in the
# notebook comes out the same after a kernel restart.
torch.manual_seed(42)
np.random.seed(42)

class Embedding():
    """A word-embedding matrix together with its vocabulary.

    Every loader populates three attributes, which the query methods rely on:

    * ``word_to_id`` -- looked up with a word to obtain a row index,
    * ``id_to_word`` -- looked up with a row index to obtain a word,
    * ``embedding`` -- indexable by row; each row is one word vector.
    """

    def __init__(self, source: str, path: str):
        """Load an embedding from ``path`` using the loader named by ``source``.

        Args:
            source: one of 'adversarial', 'skip_gram' or 'plain_text'.
            path: location of the serialized embedding. For 'skip_gram' the
                loader unpacks this as a ``(model_path, vocab_path)`` pair
                instead of a single string.

        Raises:
            ValueError: if ``source`` is not one of the three known names.
        """
        if source == 'adversarial':
            self.init_from_adversarial(path)
        elif source == 'skip_gram':
            self.init_from_skip_gram(path)
        elif source == 'plain_text':
            self.init_from_plain_text(path)
        else:
            raise ValueError('Unknown embedding source.')

    def init_from_adversarial(self, path: str, device=torch.device('cpu')):
        """Load vocabulary and vectors from a checkpoint holding a model object.

        Args:
            path: checkpoint file; it must deserialize to a mapping with a
                'model' entry.
            device: passed to ``torch.load`` as ``map_location`` and to the
                model's ``export_decomposed_embedding``.
        """
        # NOTE(security): torch.load unpickles the stored objects, which can run
        # arbitrary code. Only open checkpoints that come from a trusted source.
        payload = torch.load(path, map_location=device)
        model = payload['model']
        self.word_to_id = model.word_to_id
        self.id_to_word = model.id_to_word
        # Export the encoded layer. To inspect a different layer, feed all
        # vocabulary ids through the model by hand under torch.no_grad(), e.g.
        #   ids = torch.arange(len(self.word_to_id), dtype=torch.long, device=device)
        #   self.embedding = model.center_decoder(model.encoder_forward(ids))
        self.embedding = model.export_decomposed_embedding(device=device)

    def init_from_skip_gram(self, paths: Tuple[str, str]) -> None:
        """Directly extract the weights of a single layer.

        Args:
            paths: ``(model_path, vocab_path)``. The first file is a state dict
                containing 'center_embedding.weight'; the second is a pickle
                whose first two items are ``word_to_id`` and ``id_to_word``.
        """
        model_path, vocab_path = paths
        with open(model_path, 'rb') as model_file:
            state_dict = torch.load(model_file, map_location='cpu')
        self.embedding = state_dict['center_embedding.weight'].numpy()
        # NOTE(security): pickle.load can execute arbitrary code contained in
        # the file. Only use vocabulary files that come from a trusted source.
        with open(vocab_path, 'rb') as vocab_file:
            self.word_to_id, self.id_to_word, _ = pickle.load(vocab_file)

    def init_from_plain_text(self, path: str) -> None:
        """Load vectors from a text file with a ``vocab_size num_dimensions`` header.

        Each following line holds a word and then its vector components,
        separated by whitespace. Only the last ``num_dimensions`` fields are
        parsed as the vector, and the first field is taken as the word. Blank
        lines are skipped. Rows are numbered in file order.

        Args:
            path: path of the text file.
        """
        word_to_id: Dict[str, int] = {}
        vectors: List[np.ndarray] = []
        # The context manager closes the file even when a line fails to parse.
        with open(path) as embedding_file:
            vocab_size, num_dimensions = map(int, embedding_file.readline().split())
            print(f'vocab_size = {vocab_size:,}, num_dimensions = {num_dimensions}')
            print(f'Loading embeddings from {path}', flush=True)
            for line in embedding_file:
                fields = line.split()
                if not fields:
                    # e.g. a trailing empty line at the end of the file
                    continue
                # The row index of a word is its position in `vectors`.
                word_to_id[fields[0]] = len(vectors)
                vectors.append(
                    np.array(fields[-num_dimensions:], dtype=np.float64))
        print('Done')
        self.id_to_word = {val: key for key, val in word_to_id.items()}
        self.word_to_id = word_to_id
        self.embedding = np.array(vectors)

    def write_to_tensorboard_projector(
            self, tb_dir: str, max_points: int = 9999) -> None:
        """Write the first ``max_points`` vectors and their words to TensorBoard.

        Args:
            tb_dir: log directory handed to ``SummaryWriter``.
            max_points: upper bound on the number of rows exported.
        """
        # Imported here so the rest of the class works without tensorboard.
        from torch.utils import tensorboard
        num_labels = min(len(self.word_to_id), max_points)
        embedding_labels = [
            self.id_to_word[word_id] for word_id in range(num_labels)]
        tb = tensorboard.SummaryWriter(log_dir=tb_dir)
        try:
            tb.add_embedding(
                self.embedding[:max_points],
                embedding_labels,
                global_step=0)
        finally:
            # Closing flushes pending events and releases the event file.
            tb.close()

    def export_web_projector(self, out_dir: str, num_samples: int = 10000) -> None:
        """Write a random sample of vectors and labels as two TSV files.

        ``tensorboard.tsv`` receives one tab-separated vector per line and
        ``tensorboard_labels.tsv`` the matching word per line. Rows are drawn
        with ``np.random.randint``, i.e. with replacement, so the same word can
        appear more than once.

        Args:
            out_dir: output directory; it is created when missing.
            num_samples: number of rows to draw.
        """
        random_indices = np.random.randint(len(self.embedding), size=num_samples)
        subset_embedding = self.embedding[random_indices].tolist()

        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        vector_path = os.path.join(out_dir, 'tensorboard.tsv')
        with open(vector_path, 'w') as vector_file:
            for vector in subset_embedding:
                vector_file.write('\t'.join(map(str, vector)) + '\n')

        label_path = os.path.join(out_dir, 'tensorboard_labels.tsv')
        with open(label_path, 'w') as label_file:
            for index in random_indices:
                label_file.write(self.id_to_word[int(index)] + '\n')

    def cosine_similarity(self, query1: str, query2: str) -> float:
        """Return the cosine similarity between the vectors of two words.

        Raises:
            KeyError: if either word is missing from ``word_to_id``. A message
                naming the word is printed before the error propagates.
        """
        query_ids = []
        for query in (query1, query2):
            try:
                query_ids.append(self.word_to_id[query])
            except KeyError:
                print(f'{query} is out of vocabulary. Sorry!')
                raise
        first_vec = self.embedding[query_ids[0]]
        second_vec = self.embedding[query_ids[1]]
        similarity = 1 - scipy.spatial.distance.cosine(first_vec, second_vec)
        return similarity

    def nearest_neighbor(self, query: str, top_k: int = 10):
        """Print the ``top_k`` words closest to ``query`` by cosine similarity.

        The query word itself is never listed. When the vocabulary holds fewer
        than ``top_k`` other words, all of them are printed.

        Raises:
            KeyError: if ``query`` is missing from ``word_to_id``.
        """
        try:
            query_id = self.word_to_id[query]
        except KeyError:
            raise KeyError(f'{query} is out of vocabulary. Sorry!')

        matrix = np.asarray(self.embedding, dtype=np.float64)
        query_vec = matrix[query_id]
        norm_products = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_vec)
        # All-zero rows divide by zero and become NaN; argsort places NaN last.
        with np.errstate(divide='ignore', invalid='ignore'):
            distances = 1.0 - (matrix @ query_vec) / norm_products
        # Rounding can push a distance marginally outside [0, 2].
        distances = np.clip(distances, 0.0, 2.0)

        ranked = np.argsort(distances, kind='stable')
        # Drop the query by id. Relying on it being at rank 0 breaks as soon as
        # another word has an identical vector.
        neighbors = ranked[ranked != query_id][:max(top_k, 0)]

        print(f"{query}'s neareset neighbors:")
        for word_id in neighbors:
            word = self.id_to_word[int(word_id)]
            cosine_similarity = 1 - distances[word_id]
            print(f'{cosine_similarity:.4f}\t{word}')
        print()
        

## Load Models

In [2]:
# Registry of loaded embeddings, keyed by a short human-readable model name.
models = dict()

In [3]:
# Baseline vectors read from a plain-text file (comparable to d1c0).
models['w2v'] = Embedding(
    source='plain_text',
    path='../../results/baseline/word2vec_Obama.txt')

vocab_size = 34,100, num_dimensions = 300
Loading embeddings from ../../results/baseline/word2vec_Obama.txt
Done


In [4]:
# Encoder: linear-ReLU-linear, frozen uniform init.
base_dir = '../../results/adversarial/MLP encoder/'

models['vanilla deno'] = Embedding(
    source='adversarial', path=f'{base_dir}1d0c/epoch50.pt')

# This run needs more epochs.
models['deno minus cono'] = Embedding(
    source='adversarial', path=f'{base_dir}1d-1c/epoch15.pt')

models['vanilla cono'] = Embedding(
    source='adversarial', path=f'{base_dir}0d1c/epoch50.pt')

# TODO: no 'cono minus deno' checkpoint is loaded for this configuration yet.



In [None]:
# Encoder: linear-ReLU-linear, frozen uniform init, adversary-free context,
# no gradient clipping.
# NOTE(review): this cell reuses the keys 'vanilla deno' and 'vanilla cono',
# so running it replaces whatever was stored under those names before.
base_dir = '../../results/adversarial/context sans adversary/'

models['vanilla deno'] = Embedding(
    source='adversarial', path=f'{base_dir}1d0c bs1/epoch100.pt')

# TODO: no 'deno minus cono' checkpoint is loaded for this configuration yet.

models['vanilla cono'] = Embedding(
    source='adversarial', path=f'{base_dir}0d1c/epoch30.pt')

# TODO: no 'cono minus deno' checkpoint is loaded for this configuration yet.

In [None]:
# Encoder: linear-ReLU, frozen uniform init.
# NOTE(review): this cell reuses the keys of the other model-loading cells,
# so running it replaces whatever was stored under those names before.

models['vanilla deno'] = Embedding(
    'adversarial',
    '../../results/adversarial/nonlinear encoder/d1c0/epoch45.pt')

models['deno minus cono'] = Embedding(
    'adversarial',
    '../../results/adversarial/nonlinear encoder/d1c-1_w2v/epoch50.pt')  # inconsistent init!

models['vanilla cono'] = Embedding(
    'adversarial',
    '../../results/adversarial/nonlinear encoder/d0c1/epoch25.pt')

# Key fixed from the misspelled 'cono mius deno' so that it matches the
# 'cono minus deno' name used for this variant everywhere else.
models['cono minus deno'] = Embedding(
    'adversarial',
    '../../results/adversarial/nonlinear encoder/d-0.1c1/epoch50.pt')

In [None]:
# models['cono'].write_to_tensorboard_projector(
#     '../../results/adversarial/Obama/p8_.55to.75/d0_c1/embedding_projector')
# models['cono'].export_web_projector('../../results/adversarial/Obama/p8_.55to.75/d0_c1/web_projector')

## Nearest Neighbors

In [None]:
# Print the nearest neighbours of a handful of query words.
model = models['vanilla cono']
for query in (
        'estate_tax',
        'death_tax',
        # 'undocumented_immigrants',
        'illegal_aliens',
        'music',
        'cry'):
    model.nearest_neighbor(query)

## Distances

In [5]:
def cherry_pick(model):
    """Print cosine similarities for three hand-picked groups of word pairs.

    Each group is preceded by a line stating the expected effect. Pairs for
    which ``model.cosine_similarity`` raises ``KeyError`` are left out of the
    table.

    Args:
        model: any object with a ``cosine_similarity(word1, word2)`` method
            that returns a number and raises ``KeyError`` for unknown words.
    """

    def print_similarities(pairs):
        """Print one aligned ``similarity  word1  word2`` row per pair."""
        # Columns are at least 20 wide and grow to fit the longest word plus a
        # space, so a long word such as 'second_amendment_rights' no longer
        # runs straight into the next column.
        width = max([20] + [len(word) + 1 for pair in pairs for word in pair])
        for word1, word2 in pairs:
            try:
                similarity = model.cosine_similarity(word1, word2)
            except KeyError:
                # Out-of-vocabulary pair: skip the row. Embedding.cosine_similarity
                # has already printed which word was missing.
                continue
            print(f'{similarity:.4f}  {word1:<{width}}{word2:<{width}}')

    print('Same entity, different parties. Removing connotation should increase similarity:')
    cherries = [
        ('estate_tax', 'death_tax'),
        ('undocumented', 'illegal_aliens'),
        ('obamacare', 'protection_and_affordable'),
        ('socialized_medicine', 'public_option'),
        ('second_amendment_rights', 'guns'),
        # Candidate words not paired up yet: health_care_bill,
        # the_wall_street_reform_legislation, financial_stability,
        # capital_gains_tax, deficit_spending, bush_tax_cuts
    ]
    print_similarities(cherries)

    print('\n\nDifferent entities, same party. Removing denotation should increase similarity')
    ideologies = [
        ('tax_cuts', 'entitlement_reform'),
        ('religious_freedom', 'right_to_life')
    ]
    print_similarities(ideologies)

    print('\n\nDifferent entities, different parties. Removing connotation should not increase similarity:')
    controls = [
        ('taxes', 'antitrust_laws'),
        ('carbon', 'guns'),
        ('apple', 'piano'),
        ('beef', 'burger')
    ]
    print_similarities(controls)

In [6]:
# Similarity table for the embedding stored under 'w2v'.
cherry_pick(model=models['w2v'])

Same entity, different parties. Removing connotation should increase similarity:
0.7583  estate_tax          death_tax           
0.6868  undocumented        illegal_aliens      
0.3852  obamacare           protection_and_affordable
0.4785  socialized_medicine public_option       
0.4932  second_amendment_rightsguns                


Different entities, same party. Removing denotation should increase similarity
entitlement_reform is out of vocabulary. Sorry!
0.5378  religious_freedom   right_to_life       


Different entities, different parties. Removing connotation should not increase similarity:
0.2094  taxes               abortion            
0.1810  carbon              guns                
0.5238  apple               piano               
0.2939  beef                burger              


In [7]:
# Similarity table for the embedding stored under 'vanilla deno'.
cherry_pick(model=models['vanilla deno'])

Same entity, different parties. Removing connotation should increase similarity:
0.9995  estate_tax          death_tax           
0.6701  undocumented        illegal_aliens      
0.9935  obamacare           protection_and_affordable
0.1158  socialized_medicine public_option       
0.9998  second_amendment_rightsguns                


Different entities, same party. Removing denotation should increase similarity
entitlement_reform is out of vocabulary. Sorry!
0.8678  religious_freedom   right_to_life       


Different entities, different parties. Removing connotation should not increase similarity:
0.1889  taxes               abortion            
0.3363  carbon              guns                
0.2952  apple               piano               
0.9531  beef                burger              


In [8]:
# Similarity table for the embedding stored under 'deno minus cono'.
cherry_pick(model=models['deno minus cono'])

Same entity, different parties. Removing connotation should increase similarity:
0.8995  estate_tax          death_tax           
0.8180  undocumented        illegal_aliens      
0.8484  obamacare           protection_and_affordable
0.6634  socialized_medicine public_option       
0.9449  second_amendment_rightsguns                


Different entities, same party. Removing denotation should increase similarity
entitlement_reform is out of vocabulary. Sorry!
0.8614  religious_freedom   right_to_life       


Different entities, different parties. Removing connotation should not increase similarity:
0.8174  taxes               abortion            
0.8542  carbon              guns                
0.4833  apple               piano               
0.7790  beef                burger              
