In [1]:
from pathlib import Path
from typing import Set, Tuple, Union, List, Dict, Iterable, Optional

import numpy as np
import torch
from torch.nn import functional as F
from tqdm.notebook import tqdm
import editdistance

from decomposer import Decomposer, DecomposerConfig
# from recomposer import Recomposer, RecomposerConfig
from utils.improvised_typing import Scalar, Vector, Matrix, R3Tensor

DEVICE = torch.device('cuda:1')

In [2]:
BASE_DIR = Path.home() / 'Research/congressional_adversary/results'
base_path = BASE_DIR / 'news/validation/3bins/-3c L1/epoch3.pt'
model = torch.load(base_path)['model'].to(DEVICE)



In [3]:
def nearest_neighbors(
        self,
        query_ids: Vector,
        top_k: int = 10,
        verbose: bool = False,
        ) -> Matrix:
    with torch.no_grad():
        query_vectors = self.embedding(query_ids)
        try:
            cos_sim = F.cosine_similarity(
                query_vectors.unsqueeze(1),
                self.embedding.weight.unsqueeze(0),
                dim=2)
        except RuntimeError:  # insufficient GPU memory
            cos_sim = torch.stack([
                F.cosine_similarity(qv.unsqueeze(0), self.embedding.weight)
                for qv in query_vectors])
        cos_sim, neighbor_ids = cos_sim.topk(k=top_k, dim=-1)
        if verbose:
            return cos_sim[:, 1:], neighbor_ids[:, 1:]
        else:  # excludes the first neighbor, which is always the query itself
            return neighbor_ids[:, 1:]


# def init_deno_grounding(self, top_k = 10):  
#     self.deno_grounding: Dict[int, Set[int]] = {} 
# #     self.deno_grounding: List[Set[int]] = []  # only when iterating all vocab ids
#     all_vocab_ids = torch.arange(self.embedding.num_embeddings, device=DEVICE)
#     for qid in tqdm(all_vocab_ids, desc='Initializing deno grounding'):
#         qv = self.pretrained_embed(qid)
#         qid = qid.item()
#         qw = self.id_to_word[qid]
#         cos_sim = F.cosine_similarity(qv.unsqueeze(0), self.pretrained_embed.weight)
#         cos_sim, neighbor_ids = cos_sim.topk(k=top_k + 5, dim=-1)
#         neighbor_ids = [
#             nid for nid in neighbor_ids.tolist()
#             if editdistance.eval(qw, self.id_to_word[nid]) > 3]
#         self.deno_grounding[qid] = set(neighbor_ids[:top_k])
# #         self.deno_grounding.append(set(neighbor_ids[:top_k]))


def init_deno_grounding(self, query_ids: Vector, top_k = 10) -> Dict[int, Set[int]]:  
    deno_grounding: Dict[int, Set[int]] = {} 

    with torch.no_grad():
        query_vectors = self.embedding(query_ids)
    cos_sim = F.cosine_similarity(
        query_vectors.unsqueeze(1),
        self.pretrained_embed.weight.unsqueeze(0),
        dim=2)
    cos_sim, top_neighbor_ids = cos_sim.topk(k=top_k, dim=-1)
    
    for query_index, sorted_target_indices in enumerate(top_neighbor_ids):
        qid = query_ids[query_index].item()
        qw = self.id_to_word[qid]
        neighbor_ids = [
            nid for nid in sorted_target_indices.tolist()
            if editdistance.eval(qw, self.id_to_word[nid]) > 3]
        deno_grounding[qid] = set(neighbor_ids[:top_k])
    return deno_grounding



def deno_homogeneity(
        self,
        query_ids: Vector,
        top_k: int = 10
        ) -> float:
    neighbor_ids = self.nearest_neighbors(query_ids, top_k + 5)
    deno_homogeneity = []
    for query_index, sorted_neighbor_indices in enumerate(neighbor_ids):
        query_id = query_ids[query_index].item()
        query_word = self.id_to_word[query_id]
        query_deno: Set[int] = self.deno_grounding[query_id]

        num_neighbors = 0
        
        deno_overlap = len([
            nid for nid in sorted_neighbor_indices.tolist()
            if editdistance.eval(query_word, self.id_to_word[nid]) > 3
                and nid in query_deno]) / len(query_deno)
        deno_homogeneity.append(deno_overlap)
    return deno_homogeneity #np.mean(deno_homogeneity)

In [5]:
query = torch.tensor([42, 32, 52]).to(DEVICE)
dg = init_deno_grounding(model, query)
model.deno_grounding = dg
deno_homogeneity(model, query)

[1.0, 1.0, 0.7777777777777778]

In [None]:
dg

In [None]:
nearest_neighbors(model, query)

In [None]:
model.deno_grounding