In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import geopandas as gpd
from scipy.spatial import Voronoi, voronoi_plot_2d
from scipy.special import softmax
import torch
from transformers import AutoTokenizer

# Paths to the cluster table and the saved classifier, plus the tokenizer
# checkpoint name used to encode the input text.
CLUSTER_DATA_FILE = 'models-clustering-v7m-3000/cluster_df.csv'
BYT5_MODEL_FILE = 'models/byt5-class-2-best-model'
BYT5_TOKENIZER_NAME = 'google/byt5-small'
# Every input is padded or truncated to exactly this many token ids.
MAX_LENGTH = 140

cluster_df = pd.read_csv(CLUSTER_DATA_FILE)
device = 'cpu'
# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints that come from a trusted source.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which refuses to
# unpickle a whole model object; pass weights_only=False there if this fails.
byt5_model = torch.load(BYT5_MODEL_FILE, map_location=torch.device(device))
# Inference only: switch off dropout and other train-time behaviour, which
# would otherwise make the predicted probabilities noisy from call to call.
byt5_model.eval()
byt5_tokenizer = AutoTokenizer.from_pretrained(BYT5_TOKENIZER_NAME)
def preprocess_text(text):
    """
    Normalise the input before tokenization.

    The only transformation applied is lower-casing; the lower-cased copy of
    ``text`` is returned.
    """
    lowered = text.lower()
    return lowered


def predict_cluster_probabilities(text):
    """
    Predict cluster probabilities for the given text using the BYT5 model.

    The text is lower-cased via ``preprocess_text``, encoded to exactly
    ``MAX_LENGTH`` token ids (padded or truncated) and passed through the
    module-level ``byt5_model`` with gradient tracking disabled.

    Args:
        text: The input string to classify.

    Returns:
        A NumPy array holding the first row of the softmax (taken over dim 1)
        of the model output -- presumably one probability per cluster.
    """
    text = preprocess_text(text)
    encoded = byt5_tokenizer([text], truncation=True, padding="max_length", max_length=MAX_LENGTH,
                             return_tensors='pt')
    # The tokenizer returns a single-item batch; unsqueeze(0) adds one more
    # leading axis. NOTE(review): presumably the layout the saved model's
    # forward() expects -- confirm against the model definition.
    input_ids = encoded['input_ids'].to(device).unsqueeze(0)
    with torch.no_grad():
        logits = byt5_model(input_ids)
    # Apply softmax directly to the output tensor: re-wrapping it in
    # torch.tensor() makes a needless copy and emits a UserWarning.
    probabilities = torch.softmax(logits.detach().cpu(), dim=1).numpy()
    return probabilities[0]

def plot_voronoi_diagram(cluster_data, prediction_scores, top_n=25):
    """
    Draw the clusters' Voronoi cells on a world map and highlight the best ones.

    Clusters are ranked by descending prediction score. The Voronoi diagram is
    built from every ranked cluster's (lng, lat) point, and the cells of the
    ``top_n`` highest-scoring clusters are filled in yellow with an opacity
    proportional to their score.

    Args:
        cluster_data: DataFrame with ``lng`` and ``lat`` columns; row position
            ``i`` must correspond to ``prediction_scores[i]``.
        prediction_scores: 1-D array-like of scores, one per cluster.
        top_n: Number of highest-scoring clusters to highlight (default 25).

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    scores = np.asarray(prediction_scores)
    # Rank once (highest score first) and reuse the ordering everywhere.
    ranking = np.argsort(-1 * scores)
    coords = cluster_data[['lng', 'lat']].to_numpy(dtype=float)

    # Points are fed to Voronoi in ranking order, so input point k of the
    # diagram belongs to cluster ranking[k].
    all_points = coords[ranking]
    vor = Voronoi(all_points)

    # Plot world map and Voronoi regions
    fig, ax = plt.subplots(figsize=(14, 10))
    # NOTE(review): gpd.datasets was deprecated in GeoPandas 0.14 and removed
    # in 1.0; on newer versions this needs a local copy of the Natural Earth
    # "naturalearth_lowres" data instead.
    world = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres"))
    world.plot(color="lightgrey", ax=ax)
    voronoi_plot_2d(vor, ax=ax, show_points=False, show_vertices=False, line_colors='gray', line_alpha=0.4)

    # Fill regions for the top clusters with varying transparency
    for point_idx, cluster_idx in enumerate(ranking[:top_n]):
        region = vor.regions[vor.point_region[point_idx]]
        # Skip empty regions and unbounded ones (marked by a -1 vertex index).
        if not region or -1 in region:
            continue
        polygon = vor.vertices[region]
        # Opacity follows the score (boosted by 1.5 for visibility) but must be
        # clamped: matplotlib raises ValueError for alpha outside [0, 1].
        alpha = min(1.0, max(0.0, float(scores[cluster_idx]) * 1.5))
        ax.fill(polygon[:, 0], polygon[:, 1], color="yellow", alpha=alpha)

    # Set plot limits and display
    ax.set_xlim(-180, 180)
    ax.set_ylim(-90, 90)
    plt.show()





In [None]:
# Demo: score one example sentence with the classifier, then draw the
# highest-scoring clusters on the world map.
sample_prediction = predict_cluster_probabilities("I live in Sweden")
plot_voronoi_diagram(cluster_df, sample_prediction)
