In [2]:
import tmap as tm
import torch
import time
import numpy as np
from numpy.random import RandomState
import pandas as pd
import re

import graphistry

import os
from dotenv import load_dotenv
load_dotenv()  # take environment variables from .env.

from pyvis import network as net
import networkx as nx
from sklearn import manifold
from sklearn.decomposition import PCA

%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import ticker
plt.rcParams['figure.figsize'] = [20, 20]

In [3]:
graphistry.register(api=3, protocol="https", server="hub.graphistry.com", username=os.environ['GRAPHISTRY_USERNAME'], password=os.environ['GRAPHISTRY_PASSWORD'])

In [4]:
import wandb
from transformers import GPT2Tokenizer
from soft_prompt_tuning.soft_prompt_opt import ParaphraseOPT

# Init embedding space

In [5]:
wandb.init(project="test-popt-dump", entity="clyde013", name="test-model", allow_val_change=True)
wandb.config.update({"embedding_n_tokens": 111}, allow_val_change=True)

#checkpoint = r"training_checkpoints/30-05-2022-1.3b/soft-opt-epoch=179-val_loss=1.397.ckpt"
checkpoint_111 = r"training_checkpoints/optimize/soft-opt-epoch=029-val_loss=0.487-optimizer_type=Adam-embedding_n_tokens=111.ckpt"
checkpoint_59 = r"training_checkpoints/optimize/soft-opt-epoch=029-val_loss=0.793-optimizer_type=Adam-embedding_n_tokens=59.ckpt"
model_name = "facebook/opt-1.3b"

torch.cuda.empty_cache()

AVAIL_GPUS = min(1, torch.cuda.device_count())

model = ParaphraseOPT.load_from_custom_save(model_name, checkpoint_111)
model = model.eval()
learned_embeddings_111 = model.model.soft_embedding.learned_embedding.detach()

wandb.config.update({"embedding_n_tokens": 59}, allow_val_change=True)
model = ParaphraseOPT.load_from_custom_save(model_name, checkpoint_59)
model = model.eval()
learned_embeddings_59 = model.model.soft_embedding.learned_embedding.detach()

# default_model = ParaphraseOPT(model_name)

tokenizer = GPT2Tokenizer.from_pretrained(model_name)

original_embeddings = model.model.soft_embedding.wte.weight.detach()

[34m[1mwandb[0m: Currently logged in as: [33mclyde013[0m. Use [1m`wandb login --relogin`[0m to force relogin


Downloading:   0%|          | 0.00/653 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.45G [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/opt-1.3b were not used when initializing SoftOPTModelWrapper: ['model.decoder.final_layer_norm.weight', 'model.decoder.final_layer_norm.bias']
- This IS expected if you are initializing SoftOPTModelWrapper from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SoftOPTModelWrapper from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/opt-1.3b were not used when initializing SoftOPTModelWrapper: ['model.decoder.final_layer_norm.weight', 'model.decoder.final_layer_norm.bias']
- This IS expected if you are initializing SoftOPTModelWrapper from the checkpoint of a model trained on another task or with another architect

# PCA & TSNE Testing

As recommended by sklearn documentation, we first decompose the high dimensionality embeddings to a more reasonable size of 50 before running TSNE on it. TSNE however does not provide a function to fit the data and then transform any additional data fed into it like PCA does.

In [6]:
pca = PCA(n_components=50)
pca.fit(original_embeddings)

trans_original_embeddings = pca.transform(original_embeddings)
trans_learned_embeddings_59 = pca.transform(learned_embeddings_59)
trans_learned_embeddings_111 = pca.transform(learned_embeddings_111)

In [10]:
trans_original_embeddings.shape, trans_learned_embeddings_111.shape, trans_learned_embeddings_59.shape

((50272, 50), (111, 50), (59, 50))

In [18]:
concat_embeddings = np.concatenate([trans_original_embeddings, trans_learned_embeddings_59, trans_learned_embeddings_111])

In [19]:
tsne = manifold.TSNE(n_components=2, learning_rate='auto', init='pca')
embedded = tsne.fit_transform(concat_embeddings)
embedded.shape

(50442, 2)

In [33]:
df = pd.DataFrame(embedded, columns=['x', 'y']).reset_index()
df["title"] = np.concatenate([np.zeros(50272), np.ones(111), np.full(59, 2)])
df["title"] = df["title"].astype(int)
df

Unnamed: 0,index,x,y,title
0,0,0.516557,87.041412,0
1,1,1.971469,80.593559,0
2,2,70.881744,-50.733025,0
3,3,1.032740,87.292603,0
4,4,73.105972,-49.037640,0
...,...,...,...,...
50437,50437,75.097878,-53.299850,2
50438,50438,53.052044,69.566902,2
50439,50439,73.531258,-49.463772,2
50440,50440,46.075760,54.189522,2


In [34]:
df_edges = pd.DataFrame({'source':[0], 'target':[0]})
df_edges

Unnamed: 0,source,target
0,0,0


In [35]:
graph = graphistry.bind(source="source", destination="target", point_x="x", point_y="y", point_title="index")
graph = graph.edges(df_edges).nodes(df, 'index')
graph = graph.encode_point_color('title', categorical_mapping={'0': '#ff9999', '2': '#99F', '1': '#32a834'}, default_mapping='silver')
graph = graph.encode_point_size('title', categorical_mapping={'0': 1, '2': 3, '1': 3}, default_mapping=1)
graph = graph.settings(url_params={
      'play': 0,
      'menu': True, 'info': False,
      'showArrows': False,
      'pointSize': 0.07, 'edgeCurvature': 0.01, 'edgeSize': 1.0,
      'edgeOpacity': 0.5, 'pointOpacity': 0.9,
      'lockedX': True, 'lockedY': True, 'lockedR': False,
      'linLog': False, 'strongGravity': False, 'dissuadeHubs': False,
      'edgeInfluence': 1.0, 'precisionVsSpeed': 1.0, 'gravity': 1.0, 'scalingRatio': 0.5,
      'showLabels': True, 'showLabelOnHover': True,
      'showPointsOfInterest': False, 'showPointsOfInterestLabel': False, 'showLabelPropertiesOnHover': True,
      'pointsOfInterestMax': 0
    })
graph.plot()

The inconsistencies in the relative positions of the embedding points could be attributed to phase 4 in the tmap algorithm described [here](https://jcheminf.biomedcentral.com/articles/10.1186/s13321-020-0416-x#Sec2). They conduct kruskal's to get a constructed MST tree, reducing computation times by large margins, before using a spring based graph layout alogrithm (probably the Fruchterman-Reingold force-directed algorithm) to plot out the points. But the tradeoff is that if points are not connected on the MST then their relative distances are not accounted for when placing points on the graph, only the neighbours. While this *might* yield multiple locally optimal solutions, it is possible that relative distances between global clusters of points are not accounted for, resulting in vastly different positionings of embedding points. While their neighbours might always be close together no matter the random initialisation, their global position might vary, as there is no way to solve for a deterministic solution without the connections of a fully connected graph (which is too expensive to compute).

Have to create both an MST and the complete graph. We have 2 options:
1. using the weights of the complete graph to generate a layout and then only plotting the MST's edge connections (very computationally expensive)
1. following the original paper implementation of using MST for both layout generation and plotting edge connections.

When using `spring_connection` we have to invert the edge weights as the edge weights become spring attractive coefficients, whereas in other layouts such as `kamada_kawai` edge weights are used as cost functions.

# Custom class

Create a custom class to deal with visualisations. Should be able to initialise from a model's fixed embeddings. Then in the init function construct a minhash encoder and lsh forest (with a seed), and then using the lsh forest create an initial MST use the networkx layouts to find the x, y positions, then use those as anchor points. Should then be able to take in inputs of learned embeddings with another function, and then using `query linear scan` find the closest neighbours of all the passed in learned embeddings, from there add the points to the MST, either without trying to form another MST (just leave knn connections) or find a locally optimal MST solution. Output graphs should ideally be pyvis networks as they allow for interactive visualisations.

In [None]:
class vis():    
    def __init__(self, fixed_embeddings, dims:int=512, load_path:str=None, save_path:str=r"visualisations/vis_lf_fixed.dat", seed:int=69, verbose:bool=True):
        if verbose: print("seeding...")
        self.seed = seed
        np.random.seed(self.seed)
        
        self.fixed_embeddings = fixed_embeddings
        self.enc = tm.Minhash(self.fixed_embeddings.size(dim=1), seed=self.seed, sample_size=dims)
        self.lf = tm.LSHForest(dims * 2, 128)
        
        # init the LSHForest
        if load_path is None:
            tmp = []
            if verbose: print("batch add and indexing...")
            for i in fixed_embeddings:
                tmp.append(tm.VectorFloat(i.tolist()))
            self.lf.batch_add(self.enc.batch_from_weight_array(tmp))
            self.lf.index()
            self.lf.store(save_path)
        else:
            if verbose: print(f"loading from {load_path}...")
            self.lf.restore(load_path)
            
        # Construct the k-nearest neighbour graph
        if verbose: print("Getting KNN graph...")
        knng_from = tm.VectorUint()
        knng_to = tm.VectorUint()
        knng_weight = tm.VectorFloat()
        _ = self.lf.get_knn_graph(knng_from, knng_to, knng_weight, 10)        

        # find the MST of the knn graph
        if verbose: print("Finding MST...")
        self.g_mst = self.create_mst([i for i in zip(knng_from, knng_to, knng_weight) if i[0] != i[1]])

        # find x, y positions of the fixed embeddings layout
        if verbose: print("Generating layout...")
        self.pos = nx.nx_agraph.graphviz_layout(self.g_mst, prog="sfdp")
        self.fixed = list(self.pos.keys())
    
    def graph_learned_embeddings(self, learned_embeddings, type_str:str, g: nx.Graph=None):
        # create deepcopy of g_mst
        if g is None:
            g = nx.Graph(self.g_mst)
        # index to begin from (since indexes are 0 indexed we start from len)
        index = len(g)
        for i in learned_embeddings:
            query_hash = self.enc.from_weight_array(tm.VectorFloat(i.tolist()))
            # query_linear_scan returns list of tuples(weight, neighbour). invert the weights because spring layout.
            scan = self.lf.query_linear_scan(query_hash, 1)[0]
            g.add_edge(index, scan[1], weight=1-scan[0])
            g.nodes[index]['type_str'] = type_str
            index += 1
        
        return g
            
    # kruskals algorithm for finding MST
    def create_mst(self, edgelist):
        self.par = [i for i in range(0, self.fixed_embeddings.size(dim=0)+1)]
        self.rnk = [0 for i in range(0, self.fixed_embeddings.size(dim=0)+1)]
        edges = sorted(edgelist, key=lambda x:x[2])
        g_mst = nx.Graph()

        for edge in edges:
            x = edge[0]
            y = edge[1]

            if self._find_par(x) != self._find_par(y):
                # append edge to the mst. invert the weights because spring layout.
                g_mst.add_edge(edge[0], edge[1], weight=edge[2])
                self._join(x, y)

        return g_mst
    
    def _find_par(self, i):
        if self.par[i] == i:
            return i
        self.par[i] = self._find_par(self.par[i])
        return self.par[i]

    def _join(self, x, y):
        x = self._find_par(x)
        y = self._find_par(y)
        if x == y:
            return
        if self.rnk[x] < self.rnk[y]:
            self.par[x] = y
        else:
            self.par[y] = x
        if self.rnk[x] == self.rnk[y]:
            self.rnk[x] += 1

In [None]:
visualisation = vis(original_embeddings, load_path=r"visualisations/vis_lf_fixed.dat")

In [None]:
g = visualisation.graph_learned_embeddings(learned_embeddings_111, "111")
g = visualisation.graph_learned_embeddings(learned_embeddings_59, "59", g)
pos = nx.spring_layout(g, fixed=visualisation.g_mst.nodes, pos=visualisation.pos, k=0.0001)

mapping = {v: re.sub(r'(\\xc4|\\xa0)|[\'\"\\]', '', repr(k.encode("utf-8"))[2:-1]) for k, v in tokenizer.get_vocab().items()}

for n, p in pos.items():
    g.nodes[n]['x'] = float(p[0])
    g.nodes[n]['y'] = float(p[1])
    # there are some unreachable tokens as the tokenizer's vocab size does not match that of the config
    g.nodes[n]['title'] = mapping[n] if n < len(tokenizer) else f"learned embedding {n-len(visualisation.g_mst.nodes)}"
    # denote learned vs fixed embeddings
    if n < len(visualisation.g_mst.nodes):
        g.nodes[n]['type_str'] = 'F'

I really cannot tell if the bugginess comes from networkx integration with graphistry or what, so to be safe everything is being converted to pandas dataframes and fed in that way, which has actual documentation support.

In [None]:
edges = nx.to_pandas_edgelist(g)
nodes = pd.DataFrame.from_dict(dict(g.nodes(data=True)), orient='index').reset_index(level=0)
nodes

In [None]:
graph = graphistry.bind(source='source', destination='target', point_x="x", point_y="y", point_title="title")
graph = graph.edges(edges).nodes(nodes, 'index')
graph = graph.encode_point_color('type_str', categorical_mapping={'F': '#ff9999', '59': '#99F', '111': '#32a834'}, default_mapping='silver')
graph = graph.encode_point_size('type_str', categorical_mapping={'F': 1, '59': 3, '111': 3}, default_mapping=1)
graph = graph.settings(url_params={
      'play': 0,
      'menu': True, 'info': False,
      'showArrows': False,
      'pointSize': 0.07, 'edgeCurvature': 0.01, 'edgeSize': 1.0,
      'edgeOpacity': 0.5, 'pointOpacity': 0.9,
      'lockedX': True, 'lockedY': True, 'lockedR': False,
      'linLog': False, 'strongGravity': False, 'dissuadeHubs': False,
      'edgeInfluence': 1.0, 'precisionVsSpeed': 1.0, 'gravity': 1.0, 'scalingRatio': 0.5,
      'showLabels': True, 'showLabelOnHover': True,
      'showPointsOfInterest': False, 'showPointsOfInterestLabel': False, 'showLabelPropertiesOnHover': True,
      'pointsOfInterestMax': 0
    })
graph.plot()

In [85]:
graph = graphistry.bind(source='source', destination='target', point_x="x", point_y="y", point_title="title")
graph = graph.edges(edges).nodes(nodes, 'index')
graph = graph.encode_point_color('type_str', categorical_mapping={'F': '#ff9999', '59': '#99F', '111': '#32a834'}, default_mapping='silver')
graph = graph.encode_point_size('type_str', categorical_mapping={'F': 1, '59': 3, '111': 3}, default_mapping=1)
graph = graph.settings(url_params={
      'play': 0,
      'menu': True, 'info': False,
      'showArrows': False,
      'pointSize': 0.07, 'edgeCurvature': 0.01, 'edgeSize': 1.0,
      'edgeOpacity': 0.5, 'pointOpacity': 0.9,
      'lockedX': True, 'lockedY': True, 'lockedR': False,
      'linLog': False, 'strongGravity': False, 'dissuadeHubs': False,
      'edgeInfluence': 1.0, 'precisionVsSpeed': 1.0, 'gravity': 1.0, 'scalingRatio': 0.5,
      'showLabels': True, 'showLabelOnHover': True,
      'showPointsOfInterest': False, 'showPointsOfInterestLabel': False, 'showLabelPropertiesOnHover': True,
      'pointsOfInterestMax': 0
    })
graph.plot()