In [None]:
import os

# Step up to the parent directory so the relative "configs/..." path (and,
# presumably, the `src` / `data` / `utils` imports) used by later cells resolve.
# The guard makes this cell idempotent: without it, every re-run of the cell
# would climb one more directory and break the relative paths below.
# NOTE(review): assumes the project root is the directory that contains
# "configs/" — confirm against the repository layout.
if not os.path.isdir("configs"):
    os.chdir("../")

In [None]:
from src import HighlightedCluBartModelForGenerativeQuestionAnswering
from datasets import load_dataset
from data import WikiTQHighlightedCellsDataset
import json
from utils import process_config
from tqdm import tqdm

In [None]:
# Read the raw JSON experiment configuration (json.load accepts a binary file
# handle), then let the project's `process_config` turn it into the object the
# dataset and model constructors receive.
# NOTE(review): the config lives under "wiki_sql_clustering_and_highlighting"
# while the dataset loaded below is "wikitablequestions" — confirm this is the
# intended config.
with open("configs/wiki_sql_clustering_and_highlighting/tapex.json", "rb") as config_file:
    raw_config = json.load(config_file)

config = process_config(config=raw_config)

In [None]:
# Load the "wikitablequestions" dataset through `datasets.load_dataset`.
# Later cells read its "train" and "test" splits (fields "question" and "table").
dataset = load_dataset("wikitablequestions")

In [None]:
# Wrap the "train" split in the project's WikiTQHighlightedCellsDataset.
# Later cells index it to obtain the per-example tensors and use its
# `.tokenizer` to decode token ids back to text.
train_dataset = WikiTQHighlightedCellsDataset(dataset=dataset, config=config, data_type="train")

In [None]:
import torch

# Build the model from the processed config and restore the trained weights.
model = HighlightedCluBartModelForGenerativeQuestionAnswering(config)

# This notebook only runs inference/analysis, so switch to eval mode: a freshly
# constructed nn.Module starts in training mode, which would leave any dropout
# layers active and make the latent representations noisy.
# (Called before load_state_dict so the cell's displayed output is still the
# key-matching report returned by load_state_dict.)
model.eval()

# SECURITY NOTE: torch.load unpickles the checkpoint, which can execute
# arbitrary code. Only load checkpoints from a trusted source; on PyTorch
# versions that support it, consider passing weights_only=True.
# map_location="cpu" loads the tensors on CPU regardless of where they were saved.
model.load_state_dict(torch.load("omnitab_best_ckpt/epoch=28.pt", map_location="cpu"))

In [None]:
# Move the model to the first CUDA device (this fails on a machine without a
# usable CUDA GPU). latent_space_analysis places its input tensors on this
# same device.
model.to("cuda:0")

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

In [None]:
from sklearn.cluster import KMeans
import pandas as pd

In [None]:
def latent_space_analysis(index):
    """Visualise the decomposer's latent token space for one training example.

    For the example at position ``index`` of the train split this function:

    1. runs the model's decomposer and latent head to get one latent vector
       per input token,
    2. computes a per-token relevance score (sigmoid of
       ``noise * score1 + score2`` with Gaussian noise drawn from
       ``model.model.gaussian_dist``),
    3. projects the latent vectors to 2-D with t-SNE and splits the projected
       points into two groups with KMeans,
    4. shows a scatter plot of the two groups, and
    5. prints the decoded input, the tokens whose relevance score is at least
       the mean score ("Unsup"), the tokens marked in ``highlighted_cells``,
       the decoded target labels, and displays the source table.

    The function relies on notebook-level names: ``model``, ``train_dataset``,
    ``dataset``, ``torch``, ``TSNE``, ``KMeans``, ``plt``, ``pd`` and the
    IPython ``display`` helper.

    Args:
        index: Integer position of the example in ``train_dataset`` /
            ``dataset["train"]``.

    Returns:
        None. Output is the plot, the printed text and the displayed table.

    Notes:
        * KMeans cluster ids (0/1) carry no meaning by themselves, so the
          group whose tokens have the higher mean relevance score is the one
          drawn and labelled as "Relevant tokens".
        * The relevance score contains sampled Gaussian noise, so the "Unsup"
          tokens (and, in borderline cases, which group is called relevant)
          may differ between calls.
        * The forward pass runs in eval mode under ``torch.no_grad()``; the
          model's previous train/eval mode is restored afterwards.
    """
    # Use whatever device the model currently lives on instead of assuming one.
    device = next(model.parameters()).device

    # The dataset item is a 6-tuple; token_type_ids and decoder_input_ids are
    # not needed for this analysis.
    input_ids, attention_mask, _token_type_ids, _decoder_input_ids, highlighted_cells, labels = train_dataset[index]
    input_ids = input_ids.unsqueeze(0).to(device)
    attention_mask = attention_mask.unsqueeze(0).to(device)
    highlighted_cells = highlighted_cells.unsqueeze(0).to(device)
    labels = labels.unsqueeze(0).to(device)

    core = model.model.model
    decomposer = core.decomposer

    # Inference only: disable dropout and gradient tracking for the forward
    # pass, then put the model back into whatever mode it was in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            inputs_embeds = decomposer.embed_tokens(input_ids) * decomposer.embed_scale
            decomposer_outputs = decomposer(
                input_ids=None,
                attention_mask=attention_mask,
                head_mask=None,
                inputs_embeds=inputs_embeds,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
            )

            latent_rep = core.latent_rep_head(decomposer_outputs[0])

            # Noisy relevance logit: noise * score1 + score2, squashed to (0, 1).
            token_scores_1 = core.token_classifier_score1(latent_rep)
            token_scores_2 = core.token_classifier_score2(latent_rep)
            gaussian_rvs = model.model.gaussian_dist.sample(token_scores_1.shape).to(token_scores_1.device)
            relevance_logit = gaussian_rvs * token_scores_1 + token_scores_2
            relevance_score = core.sigmoid(relevance_logit).squeeze()
    finally:
        model.train(was_training)

    # NOTE: Uncomment as per requirement of the experiment
    # (blends the predicted score with the highlighted-cell mask).
    # relevance_score = 0.7 * relevance_score + 0.3 * highlighted_cells.squeeze()

    # One row per token: drop the batch dimension and hand the latents to t-SNE.
    latent_points = latent_rep.squeeze(0).detach().cpu().numpy()
    tsne = TSNE(n_components=2, random_state=0)
    tsne_result = tsne.fit_transform(latent_points)

    # Split the 2-D projection into two groups.
    kmeans = KMeans(n_clusters=2, init='k-means++', max_iter=300, n_init=10, random_state=0)
    cluster_labels = kmeans.fit(tsne_result).labels_

    # KMeans numbers its clusters arbitrarily, so decide which one to call
    # "relevant" from the data: the cluster with the higher mean relevance
    # score. Ties (or an empty cluster 0) fall back to cluster 1.
    scores_np = relevance_score.detach().cpu().numpy().reshape(-1)
    cluster_means = [
        scores_np[cluster_labels == k].mean() if (cluster_labels == k).any() else float("-inf")
        for k in (0, 1)
    ]
    relevant_cluster = int(cluster_means[1] >= cluster_means[0])
    is_relevant = cluster_labels == relevant_cluster

    fig, ax = plt.subplots()
    ax.scatter(tsne_result[~is_relevant, 0], tsne_result[~is_relevant, 1], label='Non-relevant tokens', alpha=0.3, c="b", s=15)
    ax.scatter(tsne_result[is_relevant, 0], tsne_result[is_relevant, 1], label='Relevant tokens', alpha=0.3, c="r", s=15)
    ax.axis('off')
    ax.legend()
    plt.show()

    tokenizer = train_dataset.tokenizer
    token_ids = input_ids[0]

    # Full input, tokens scoring at or above the mean relevance, the tokens
    # flagged in highlighted_cells, and the target (ignoring -100 positions).
    print(tokenizer.decode(token_ids, skip_special_tokens = True))
    print("Unsup: ", tokenizer.decode(token_ids[relevance_score >= relevance_score.mean()], skip_special_tokens = True))
    print("Highlighted cells: ", tokenizer.decode(token_ids[highlighted_cells.squeeze() == 1]))
    print(tokenizer.decode(labels[labels != -100]))

    # Rebuild the source table (lower-cased strings) for side-by-side reading.
    # Fetch the example once rather than once per field.
    table_info = dataset["train"][index]["table"]
    table_column_names = table_info["header"]
    table_content_values = table_info["rows"]

    table = pd.DataFrame.from_dict({
        str(col).lower(): [str(row[i]).lower() for row in table_content_values]
        for i, col in enumerate(table_column_names)
    })

    display(table)

In [None]:
# Plot and print the latent-space analysis for train-split example 23.
latent_space_analysis(23)

In [None]:
# Same analysis for train-split example 32.
latent_space_analysis(32)

In [None]:
# Same analysis for train-split example 1009.
latent_space_analysis(1009)

In [None]:
# Same analysis for train-split example 3057.
latent_space_analysis(3057)

In [None]:
# Same analysis for train-split example 1000.
latent_space_analysis(1000)

In [None]:
# Find the position of a specific question inside one split.
# The loop is bounded by the split that is actually being searched: the train
# split has a different length, so using its length here could index past the
# end of the test split when the question is not found.
# Only the "question" column is read, so the tables are not decoded per row.
# NOTE(review): the printed index refers to the "test" split, whereas
# latent_space_analysis reads the "train" split — confirm which split is intended.
search_split = "test"
target_question = "in how many games did the winning team score more than 4 points?"

for i, question in enumerate(dataset[search_split]["question"]):
    if question.lower().strip() == target_question:
        print(i)
        break
else:
    print(f"question not found in the {search_split!r} split")

In [None]:
# Latent-space analysis for train-split example 13.
# NOTE(review): if this index came from the question search in the previous
# cell, it was found in the "test" split while this function reads the
# "train" split — confirm.
latent_space_analysis(13)

In [None]:
# Same analysis for train-split example 931.
latent_space_analysis(931)

In [None]:
# Same analysis for train-split example 2345.
latent_space_analysis(2345)

In [None]:
# Same analysis for train-split example 435.
latent_space_analysis(435)

In [None]:
# Repeats example 1009 (already analysed in an earlier cell). The relevance
# scores include sampled Gaussian noise, so the printed "Unsup" tokens may
# differ from the earlier run.
latent_space_analysis(1009)