In [None]:
import os

# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous, so errors are
# reported at the call that caused them. It is exported before torch is
# imported so that it is already in the environment when CUDA is first touched.
# NOTE(review): this is a debugging aid that slows down GPU work -- consider
# dropping it once the notebook runs cleanly.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import json

import matplotlib.pyplot as plt
import numpy as np
import sentence_transformers
import torch

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Render figures at a higher resolution than the matplotlib default.
plt.rcParams['figure.dpi'] = 200

# The notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Sentence-embedding model used to embed the texts.
model = sentence_transformers.SentenceTransformer("Qwen/Qwen3-Embedding-0.6B", device=device)

In [None]:
# Run this cell for the CORE dataset

# JSON is UTF-8 by specification; pass the encoding explicitly so the file is
# decoded the same way regardless of the platform's locale default.
with open('CORE_edited.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Human-readable display names for the register codes stored in each item's
# 'label' field. Every code maps to a distinct name so that no two legend
# entries are indistinguishable.
label_fullnames = {
    "PB": "Personal Blog",
    "LY": "Lyrical",
    "SP": "Spoken",
    "IT": "Interview",
    "ID": "Discussion",
    "NA": "Narrative",
    "NE": "News Report",
    "SR": "Sports Report",
    "NB": "Blog",
    "HI": "Instructional",
    "RE": "Recipe",
    "IN": "Informational",
    "EN": "Encyclopedia",
    "RA": "Research",
    "DTP": "Thing/person",
    "FI": "FAQ",
    "LT": "Legal",
    "OP": "Opinion",
    "RV": "Review",
    # Previously "Opinion", which collided with "OP" above.
    "OB": "Opinion Blog",
    "RS": "Religious",
    "AV": "Advice",
    "IP": "Persuasion",
    "DS": "To sell",
    "ED": "Editorial",
}

In [None]:
# Run this cell for the synthetic dataset

# JSON is UTF-8 by specification; pass the encoding explicitly so the file is
# decoded the same way regardless of the platform's locale default.
with open('synthetic.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# The synthetic labels are already readable, so each one is its own display
# name (identity mapping).
label_fullnames = {
    name: name
    for name in ("code", "explanatory", "instructional", "narrative", "speech")
}

In [None]:
# One label per record, in the same order as `data`.
labels = [record['label'] for record in data]
# Distinct labels in sorted order.
unique_labels = sorted({record['label'] for record in data})

# Only the first 1000 characters of every text are passed to the model.
texts = [record['text'][:1000] for record in data]
embeddings = model.encode(
    texts,
    batch_size=16,
    show_progress_bar=True,
)

In [None]:
import phate

# Project the embeddings with PHATE (library defaults otherwise).
# n_jobs=-1 lets PHATE use all available cores; random_state pins the seed so
# the layout is reproducible between runs.
phate_operator = phate.PHATE(n_jobs=-1, random_state=42)
phate_data = phate_operator.fit_transform(embeddings)

In [None]:
# Scatter plot of the first two PHATE coordinates, one colour per label.

# Upper bound on the number of points drawn for any single label.
max_points_per_label = 200

# Bucket the row indices by label in a single pass instead of rescanning
# `labels` once for every unique label.
indices_by_label = {}
for row, row_label in enumerate(labels):
    indices_by_label.setdefault(row_label, []).append(row)

plt.figure(figsize=(6, 4))

handles = []
for label in unique_labels:
    # Keep only the first `max_points_per_label` rows of this label.
    indices = indices_by_label.get(label, [])[:max_points_per_label]
    if not indices:  # Nothing to draw for this label
        continue
    scatter = plt.scatter(
        phate_data[indices, 0],
        phate_data[indices, 1],
        s=3,
        # Fall back to the raw label when it has no display name; a None
        # label would silently drop the series from the legend.
        label=label_fullnames.get(label, label),
        alpha=0.7,
    )
    handles.append(scatter)

plt.xlabel("PHATE 1")
plt.ylabel("PHATE 2")
# Hide the tick marks and tick labels on both axes.
plt.xticks([])
plt.yticks([])
plt.legend(handles=handles, loc='best', markerscale=5, fontsize=9.5)
plt.tight_layout()
plt.show()
