In [1]:
import sys

# Make the parent directory and the eqnet sources importable; both paths are
# relative to the notebook's current working directory.
sys.path.extend(["..", "../../eqnet"])

In [2]:
import os
import gzip
import csv
import torch
import numpy as np
import pandas as pd
import plotly.express as px
from expemb import ExpEmbTx, Tokenizer
from sklearn.decomposition import PCA
from typing import List

In [3]:
class EmbeddingPlot:
    """Project expression embeddings from a trained ExpEmbTx model to 2-D and plot them.

    The test file is a gzip-compressed CSV whose first column is an expression
    (later passed to ``tokenizer.prefix_to_sympy``) and whose second column is
    the class label used to color the points.

    Args:
        model_dir: Directory containing ``saved_models/<ckpt_name>.ckpt``; the
            plot files are written here too.
        test_file: Path to the gzip-compressed CSV of (expression, class) rows.
        outfile_prefix: Prefix for the output file names.
        ckpt_name: Checkpoint file name (without ``.ckpt``) under ``saved_models``.
        dim_red_algo: Dimensionality-reduction algorithm; only ``"pca"`` is supported.
        random_state: Seed for the PCA solver. The ``"auto"`` solver may pick a
            randomized SVD for large inputs, so this makes repeated runs reproducible.

    Raises:
        ValueError: If ``dim_red_algo`` is not supported.
    """

    # Algorithms understood by ``reduce_dimensionality``.
    SUPPORTED_DIM_RED_ALGOS = ("pca",)

    def __init__(self, model_dir: str, test_file: str, outfile_prefix: str, ckpt_name: str = "best", dim_red_algo = "pca", random_state: int = 0):
        # Fail fast here rather than after the expensive model load and embedding pass.
        self._check_dim_red_algo(dim_red_algo)
        self.model_dir = model_dir
        self.test_file = test_file
        self.outfile_prefix = outfile_prefix
        self.ckpt_name = ckpt_name
        self.dim_red_algo = dim_red_algo
        self.random_state = random_state


    @classmethod
    def _check_dim_red_algo(cls, algo):
        """Raise ValueError if ``algo`` is not a supported dimensionality-reduction algorithm."""
        if algo not in cls.SUPPORTED_DIM_RED_ALGOS:
            raise ValueError(
                f"Unsupported dim_red_algo {algo!r}; expected one of {cls.SUPPORTED_DIM_RED_ALGOS}"
            )


    def load_test_data(self):
        """Read (expression, class) pairs from the gzip-compressed CSV ``self.test_file``.

        Blank lines are skipped, and columns beyond the second are ignored.

        Returns:
            A tuple ``(exp_list, exp_cls)`` of equal-length lists of strings.

        Raises:
            ValueError: If a non-blank row has fewer than two columns.
        """
        exp_list, exp_cls = [], []
        # newline="" is what the csv module expects from text-mode file objects.
        with gzip.open(self.test_file, "rt", newline = "") as file:
            reader = csv.reader(file)
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines; indexing it would raise IndexError.
                    continue
                if len(row) < 2:
                    raise ValueError(
                        f"{self.test_file}, line {reader.line_num}: expected at least 2 columns, got {len(row)}"
                    )
                exp_list.append(row[0])
                exp_cls.append(row[1])

        return exp_list, exp_cls


    @torch.no_grad()
    def get_embeddings(self, model: "ExpEmbTx", exp_list: List[str]):
        """Embed each expression with ``model`` and stack the results into one array.

        Expressions are embedded one at a time via
        ``model.get_embedding([exp], mode="max")[0]``, moved to CPU and converted to numpy.

        Returns:
            np.ndarray with one row per expression. Assumes ``get_embedding``
            yields a 1-D vector per expression, so the result is ``(len(exp_list), emb_dim)``.
        """
        emb_list = []
        for exp in exp_list:
            emb = model.get_embedding([exp], mode = "max")[0]
            emb_list.append(emb.cpu().numpy())

        return np.array(emb_list)


    def reduce_dimensionality(self, emb_list: np.ndarray):
        """Project ``emb_list`` (n_samples, n_features) to 2 components using ``self.dim_red_algo``.

        Returns:
            np.ndarray of shape ``(n_samples, 2)``.

        Raises:
            ValueError: If ``self.dim_red_algo`` is not supported. Previously the
                method silently returned None in that case.
        """
        self._check_dim_red_algo(self.dim_red_algo)
        # "pca" is currently the only supported algorithm.
        pca = PCA(n_components = 2, random_state = self.random_state)
        return pca.fit_transform(emb_list)


    def _output_path(self, extension: str) -> str:
        """Return ``<model_dir>/<outfile_prefix>_<dim_red_algo>_plot.<extension>``."""
        return os.path.join(self.model_dir, f"{self.outfile_prefix}_{self.dim_red_algo}_plot.{extension}")


    def generate_plot(self, emb_list: np.ndarray, exp_list: List[str], exp_cls: List[str], tokenizer: "Tokenizer"):
        """Draw a 2-D scatter of the reduced embeddings, show it, and save it as SVG and HTML.

        Args:
            emb_list: Array whose first two columns are the x/y coordinates.
            exp_list: Expressions, one per row of ``emb_list``. Each is shown on
                hover after conversion with ``tokenizer.prefix_to_sympy(..., evaluate=False)``.
            exp_cls: Class label per row, used for both marker color and symbol.
            tokenizer: Tokenizer loaded from the checkpoint.

        Files are written to ``<model_dir>/<outfile_prefix>_<dim_red_algo>_plot.{svg,html}``.
        """
        df = pd.DataFrame({
            "Component 1": emb_list[:, 0],
            "Component 2": emb_list[:, 1],
            "Class": exp_cls,
            "Exp": [str(tokenizer.prefix_to_sympy(eq, evaluate = False)) for eq in exp_list],
        })

        fig = px.scatter(
            df,
            x="Component 1",
            y="Component 2",
            color="Class",
            symbol="Class",
            hover_data=["Exp"],
            template="plotly_white",
            width=1000,
            height=900,
            # With more than four classes plotly cycles through these colors again.
            color_discrete_sequence=["#dd8452", "#55a868", "#4c72b0", "#c44e52"]
        )
        fig.update_xaxes(showgrid = False, zeroline = False, linewidth = 0.5, linecolor = 'gray')
        fig.update_yaxes(showgrid = False, zeroline = False, linewidth = 0.5, linecolor = 'gray')
        fig.update_traces(marker = dict(size=4))
        fig.update_layout(
            font = dict(
                size=20
            ),
            # Horizontal legend placed above the top-right corner of the plot area.
            legend = dict(
                title_text = '',
                x = 1,
                y = 1.1,
                traceorder = "normal",
                orientation = "h",
                yanchor = "bottom",
                xanchor = "right",
                itemsizing = "constant",
                font = dict(
                    size = 30
                ),
            )
        )
        fig.show()
        # write_image (static export) requires plotly's image-export engine (kaleido).
        fig.write_image(self._output_path("svg"))
        fig.write_html(self._output_path("html"))


    def _load_model(self):
        """Load the tokenizer and ExpEmbTx model from ``<model_dir>/saved_models/<ckpt_name>.ckpt``.

        Returns:
            ``(model, tokenizer)``. The model is on CUDA when available (CPU
            otherwise) and in eval mode.
        """
        ckpt_path = os.path.join(self.model_dir, "saved_models", f"{self.ckpt_name}.ckpt")
        # map_location="cpu" lets a GPU-saved checkpoint be read on a CPU-only machine;
        # only the pickled tokenizer is needed from this load.
        # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
        # On torch >= 2.6 the default weights_only=True cannot unpickle the tokenizer;
        # pass weights_only=False there (trusted files only).
        tokenizer = torch.load(ckpt_path, map_location = "cpu")["tokenizer"]
        model = ExpEmbTx.load_from_checkpoint(ckpt_path, tokenizer = tokenizer)
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        model.to(device)
        # Put any dropout / normalization layers into inference mode so embeddings are deterministic.
        model.eval()
        return model, tokenizer


    def plot(self):
        """Load the model and test data, embed and reduce them to 2-D, then show and save the plot.

        Raises:
            ValueError: If the test file contains no expressions, or the
                dimensionality-reduction algorithm is unsupported.
        """
        print("Loading model...")
        model, tokenizer = self._load_model()

        print("Loading test data...")
        # Load expressions and their class
        exp_list, exp_cls = self.load_test_data()
        if not exp_list:
            raise ValueError(f"No expressions found in {self.test_file}")

        # Get embeddings
        print("Computing embedding vectors...")
        emb_list = self.get_embeddings(model, exp_list)
        print(f"Embedding dimensions: {emb_list.shape}")
        print("Reducing dimensionality...")
        emb_2d = self.reduce_dimensionality(emb_list)
        print(f"Size after dimensionality reduction: {emb_2d.shape}")

        self.generate_plot(emb_2d, exp_list, exp_cls, tokenizer)

In [4]:
# Plot the embeddings of run directory 20221126-104533874570; output files use prefix "expembe".
emb_plot = EmbeddingPlot(
    model_dir="../models/equivexp/20221126-104533874570/",
    test_file="../data/embedding_plot.txt.gz",
    outfile_prefix="expembe",
)
emb_plot.plot()

Loading model...
Loading test data...
Computing embedding vectors...
Embedding dimensions: (7030, 512)
Reducing dimensionality...
Size after dimensionality reduction: (7030, 2)


In [5]:
# Plot the embeddings of run directory 20221127-082715223727; output files use prefix "expemba".
# (The directory previously had a stray double trailing slash.)
emb_plot = EmbeddingPlot(
    model_dir="../models/equivexp/20221127-082715223727/",
    test_file="../data/embedding_plot.txt.gz",
    outfile_prefix="expemba",
)
emb_plot.plot()

Loading model...
Loading test data...
Computing embedding vectors...
Embedding dimensions: (7030, 512)
Reducing dimensionality...
Size after dimensionality reduction: (7030, 2)
