In [None]:
from promptore_utils_n import *


# Use the GPU when torch can see one, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)


# Simulating argparse in a notebook environment: every knob of the run lives here.
class Args:
    """Run configuration for this notebook (stand-in for an argparse namespace)."""

    def __init__(self):
        self.seed = 0  # Random seed (passed to n_rel estimation and k-means below)
        self.n_rel = 160  # Number of relations/clusters (used when auto_n_rel=False)
        self.max_len = 300  # Maximum length of tokens
        self.auto_n_rel = False  # Set to True if you want to estimate the number of clusters
        self.min_n_rel = 777  # Minimum number of relations to estimate (if auto_n_rel=True)
        self.max_n_rel = 1000  # Maximum number of relations to estimate (if auto_n_rel=True)
        self.step_n_rel = 5  # Step size for relation estimation (if auto_n_rel=True)
        self.files = []  # Files to load from Fewrel (leave empty for now)
        # Dataset name forwarded to compute_promptore_relation_embedding. This used to be
        # "ls" (Label Studio) although the wikiphi3 pickle below is what is actually loaded,
        # and the value was not read anywhere in this notebook.
        self.data = "wikiphi3"
        self.data_path = "DATA/wikiphi3_data_49410.pickle"  # Input file for the dataset above


args = Args()

# Head/tail entity markers and the mask token handed to the parser.
MARKERS = ("[E1] ", " [/E1]", "[E2] ", " [/E2]", "[MASK]")

# Read wikiphi3 files.
# Label Studio / plain alternatives (not wired to args.data):
#   parse_labelstudio_with_dynamic_markers("DATA/project-6-at-2025-04-22-13-14-67864b63.json", *MARKERS)
#   parse_labelstudio("DATA/project-6-at-2025-04-22-13-14-67864b63.json")
#   parse_wikiphi3("DATA/wikiphi3_data_49410.pickle")
df_dataset = parse_wikiphi3_with_dynamic_markers(args.data_path, *MARKERS)

# Compute relation embeddings
print("Compute relation embeddings")
relation_embeddings = compute_promptore_relation_embedding(
    df_dataset,
    template="{sent}",
    max_len=args.max_len,
    device=device,
    emb=4,  # NOTE(review): meaning of emb=4 is defined in promptore_utils_n — document it there
    data=args.data)

# Compute clustering: either estimate the number of clusters in
# [min_n_rel, max_n_rel] with the given step, or use the fixed n_rel.
print("Compute clustering")
if args.auto_n_rel:
    n_rel = estimate_n_rel(
        relation_embeddings, args.seed, (args.min_n_rel, args.max_n_rel), args.step_n_rel)
    print(f'Estimated n_rel={n_rel}')
else:
    n_rel = args.n_rel

print("Predict labels")
predicted_labels = compute_kmeans_clustering(relation_embeddings, n_rel, args.seed)

# Evaluation: B-cubed, V-measure and ARI of the predicted clusters.
b3, b3_prec, b3_rec, v, v_hom, v_comp, ari = evaluate_promptore(relation_embeddings,
                                                                predicted_labels)
print(f'B3: prec={b3_prec} rec={b3_rec} f1={b3}')
print(f'V-measure: hom={v_hom} comp={v_comp} f1={v}')
print(f'ARI={ari}')



In [None]:
# Attach each row's k-means cluster id so clusters can be inspected and summarised below.
# The bare `len(predicted_labels)` that used to sit here was not the last line of the cell,
# so its value was never shown; check the lengths explicitly instead.
assert len(predicted_labels) == len(relation_embeddings), (
    f"{len(predicted_labels)} labels for {len(relation_embeddings)} relation embeddings"
)
relation_embeddings["predicted_labels"] = predicted_labels

In [None]:
# Inspect the members of one cluster, ordered by their output_r value.
CLUSTER_TO_INSPECT = 34
(
    relation_embeddings[relation_embeddings["predicted_labels"] == CLUSTER_TO_INSPECT]
    .sort_values(by="output_r")
    [["sentence", "head", "output_r", "tail"]]
)

## Visuals: cluster size and `output_r` diversity

In [None]:
from sklearn.preprocessing import MinMaxScaler

# One row per cluster: the number of non-null output_r entries it holds and
# how many distinct output_r values appear among them.
summary = (
    relation_embeddings
    .groupby('predicted_labels')
    .agg(total_instances=('output_r', 'count'),
         unique_output_r=('output_r', 'nunique'))
    .reset_index()
)

# Cluster size rescaled to [0, 1] (smallest cluster -> 0, largest -> 1).
scaler = MinMaxScaler()
summary['total_instances_scaled'] = scaler.fit_transform(summary[['total_instances']])

print(summary)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Bar chart comparing total instances vs unique output_r values per cluster.
# Only the two count columns are melted: `summary` also carries
# total_instances_scaled (a 0-1 value) and, once the diversity cell has run,
# ratio columns, none of which belong on a "Count" axis.
count_columns = ['total_instances', 'unique_output_r']
counts_long = summary.melt(id_vars='predicted_labels', value_vars=count_columns)

fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(data=counts_long, x='predicted_labels', y='value', hue='variable', ax=ax)
ax.set_title('Cluster-wise: Total Instances and Unique output_r Values')
ax.set_xlabel('Predicted Label (Cluster)')
ax.set_ylabel('Count')
ax.tick_params(axis='x', labelrotation=90)  # one tick per cluster; horizontal labels overlap when there are many
ax.legend(title='')
fig.tight_layout()
plt.show()


In [None]:
# Diversity ratio = distinct output_r values / total instances in the cluster.
# It is at most 1, which matches the plot title and the 0-1.05 y-limit below.
summary['diversity_ratio'] = summary['unique_output_r'] / summary['total_instances']

# The previous formula additionally divided by total_instances_scaled. That min-max
# scaled size is 0 for the smallest cluster (and for every cluster when all sizes are
# equal), which produced inf/NaN, and its values above 1 were clipped out of the chart.
# It is kept as a separate, zero-safe column (NaN where the scaled size is 0).
summary['diversity_ratio_size_weighted'] = (
    summary['diversity_ratio']
    / summary['total_instances_scaled'].where(summary['total_instances_scaled'] > 0)
)

# Keep the table sorted ascending (as before); draw bars from most to least diverse.
summary = summary.sort_values(by='diversity_ratio', ascending=True)
plot_order = summary.sort_values(by='diversity_ratio', ascending=False)['predicted_labels']

fig, ax = plt.subplots(figsize=(16, 10))
sns.barplot(data=summary, x='predicted_labels', y='diversity_ratio', width=0.5,
            palette='viridis', order=plot_order, ax=ax)
ax.set_title('Diversity Ratio per Cluster (Unique output_r / Total Instances)')
ax.set_xlabel('Predicted Label (Cluster)')
ax.set_ylabel('Diversity Ratio')
ax.tick_params(axis='x', labelrotation=90)
ax.set_ylim(0, 1.05)
ax.grid(axis='y', linestyle='--', alpha=0.7)
fig.tight_layout()
plt.show()


In [None]:
# Per-cluster table built above (sorted by diversity_ratio once the diversity cell has run).
summary

In [None]:
# Spot-check one example and the character length of a few random sentences.
# Note: these are character counts, whereas args.max_len is a token limit.
ROW_TO_INSPECT = 1397
example = df_dataset.iloc[ROW_TO_INSPECT]
print(example["sent"])
print(example["r"])
print(len(example["sent"]))

# Seeded so the same 20 rows are drawn on every re-run.
for sent_len in df_dataset.sample(20, random_state=args.seed)["sent"].map(len):
    print(sent_len)
    