In [2]:
import pandas as pd
import numpy as np
import plotly.express as px
from sklearn.preprocessing import LabelEncoder
from sklearn.manifold import TSNE



def plot_tsne(points, labels, n_dim=2, seed=None):
    """Project `points` with t-SNE and show an interactive Plotly scatter plot.

    Args:
        points: Data handed to `TSNE.fit_transform`, one row per sample.
        labels: One label per row of `points`; used to color the markers and
            shown on hover.
        n_dim: Number of t-SNE components. 2 draws a flat scatter plot, 3 draws
            a 3-D scatter plot.
        seed: Forwarded to `TSNE(random_state=...)`; pass an int to get the
            same layout on every run.

    Raises:
        ValueError: If `n_dim` is neither 2 nor 3.

    Returns:
        None. The figure is displayed with `fig.show()`; it is deliberately not
        returned so that a notebook cell ending with this call does not render
        the figure a second time.
    """
    # Fail before the (expensive) t-SNE fit rather than after it.
    if n_dim not in (2, 3):
        raise ValueError(f"n_dim must be 2 or 3 to be plotted, got {n_dim!r}.")

    tsne = TSNE(n_components=n_dim, random_state=seed)
    embeddings = tsne.fit_transform(points)

    # Step 1: Prepare DataFrame for Plotly, with one column per component
    axis_names = [f"TSNE-{k + 1}" for k in range(n_dim)]
    df = pd.DataFrame(embeddings, columns=axis_names)
    # list(...) so a pandas Series of labels is assigned by position, not
    # aligned on its own index.
    df["Label"] = list(labels)

    # Step 2: Plot using Plotly (options shared by the 2-D and 3-D variants)
    shared_kwargs = dict(
        color="Label",
        title="Interactive t-SNE Reduced Embeddings",
        labels={"Label": "Description"},
        hover_data={"Label": True},
    )
    if n_dim == 2:
        fig = px.scatter(df, x=axis_names[0], y=axis_names[1], **shared_kwargs)
    else:
        fig = px.scatter_3d(
            df, x=axis_names[0], y=axis_names[1], z=axis_names[2], **shared_kwargs
        )

    # Step 3: Improve interactivity
    fig.update_traces(marker=dict(size=6, opacity=0.8),
                      selector=dict(mode='markers'))
    fig.update_layout(
        legend_title_text='Descriptions',
        legend=dict(itemsizing='constant'),
        width=1000,   # Figure width in pixels
        height=1000,  # Figure height in pixels
    )

    # Show the plot
    fig.show()

In [3]:
import json
import random




def load_data(path, seed=None):
    """Load an observation/message CSV and prepare it for analysis.

    The CSV is read with its first column as the index and must contain an
    "obs" column (JSON-encoded lists) and a "lang" column (space-separated
    words; missing values are treated as empty messages).

    If rows with an empty message outnumber 10% of the loaded rows, a random
    subset of them is dropped so that only that many empty-message rows remain.

    Args:
        path: Path of the CSV file.
        seed: Optional seed for the random choice of dropped rows. With the
            default `None` the global `random` state is used, as before.

    Returns:
        Tuple `(df, obs_dim, max_len)`: the prepared DataFrame ("lang" as lists
        of words, "obs" as decoded JSON), the length of the first row's
        observation, and the largest number of words in a message.
    """
    df = pd.read_csv(path, index_col=0)
    df = df.fillna('')

    # Drop majority of rows with empty message
    empty_ids = df.index[df["lang"] == ''].tolist()
    max_empty = len(df) * 0.1
    if len(empty_ids) > max_empty:
        # A private generator makes the sub-sampling reproducible on request
        # without touching the global random state.
        rng = random if seed is None else random.Random(seed)
        drop_ids = rng.sample(empty_ids, int(len(empty_ids) - max_empty))
        df = df.drop(drop_ids)

    # Prepare data
    # ''.split(" ") returns [''] (one empty token), so empty messages are
    # mapped to an empty list explicitly.
    df["lang"] = df["lang"].apply(lambda x: x.split(" ") if x else [])
    df["obs"] = df["obs"].apply(json.loads)

    obs_dim = len(df.iloc[0]["obs"])
    max_len = int(df["lang"].apply(len).max())

    print(f"Obs dim = {obs_dim}, Max message length = {max_len}.")

    return df, obs_dim, max_len



In [5]:
# Load the FH_15 dataset; load_data also prints the observation size and the
# longest message length, which are needed to build the model below.
df, obs_dim, max_mess_len = load_data("../../results/data/lamarl_data/FH_15.csv")
# Preview the first rows.
df.head()


Obs dim = 78, Max message length = 8.


Unnamed: 0,obs,lang
2,"[0.5, 0.42857142857142855, 0.0, 0.0, 0.0, 0.0,...","[Green, Gem, Center]"
4,"[0.42857142857142855, 0.42857142857142855, 0.0...","[Green, Gem, Center]"
5,"[0.42857142857142855, 0.5, 0.0, 0.0, 0.0, 0.0,...","[Green, Gem, Center]"
8,"[0.42857142857142855, 0.42857142857142855, 0.0...",[]
9,"[0.42857142857142855, 0.5, 0.0, 0.0, 0.0, 0.0,...",[]


Create and load model

In [7]:
import torch

from src.algo.policy_diff.langground import LanguageGrounder


# Hyperparameters passed to LanguageGrounder.
# NOTE(review): presumably these must match the values used when the
# checkpoint loaded below was trained — confirm.
context_dim = 8
hidden_dim = 64
embed_dim = 4
n_layer = 2
lr = 0.007
vocab = [
    "Prey", "Center", "North", "South", "East", "West",
    "Gem", "Yellow", "Green", "Purple", "Danger"]

ll = LanguageGrounder(obs_dim, context_dim, hidden_dim, embed_dim, n_layer, lr, vocab, max_mess_len)
# map_location="cpu" lets a checkpoint that was saved from a GPU be loaded on a
# machine without one; the model is used on the CPU right after anyway.
ll.load_params(
    torch.load("../../results/data/lamarl_data/FH15_langground.pt", map_location="cpu"))
ll.prep_rollout("cpu")


Select the evaluation rows and drop the messages we do not want to plot

In [14]:
# Evaluate on the last 5000 rows only.
eval_data = df.iloc[-5000:]

# Remove rows whose message has more than 4 words or does not mention "Green".
drop_mask = eval_data["lang"].map(
    lambda msg: len(msg) > 4 or "Green" not in msg
).astype(bool)
rids = eval_data.index[drop_mask].tolist()
display_data = eval_data.drop(rids)
display_data

Unnamed: 0,obs,lang
3993350,"[0.5, 0.07142857142857142, 0.0, 0.0, 0.0, 0.0,...","[Green, Gem, West]"
3993357,"[0.14285714285714285, 0.7857142857142857, 0.0,...","[Green, Gem, North, East]"
3993373,"[0.2857142857142857, 0.07142857142857142, 0.0,...","[Green, Gem, West]"
3993378,"[0.07142857142857142, 0.14285714285714285, 0.0...","[Green, Gem, North, West]"
3993382,"[0.6428571428571429, 0.5714285714285714, 0.0, ...","[Green, Gem, Center]"
...,...,...
3999989,"[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[Green, Gem, North, East]"
3999990,"[0.8571428571428571, 0.7857142857142857, 0.0, ...","[Green, Gem, South, East]"
3999996,"[0.8571428571428571, 0.5, 0.0, 0.0, 0.0, 0.0, ...","[Green, Gem, South]"
3999997,"[0.0, 0.35714285714285715, 0.0, 0.0, 0.0, 0.0,...","[Green, Gem, North]"


Generate embeddings and labels

In [15]:
# Stack the per-row observation lists into one float tensor and build one
# space-joined label string per row.
obs = torch.Tensor(np.array(list(display_data["obs"])))
labels = [" ".join(s) for s in list(display_data["lang"])]

# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    enc_obs = ll.obs_encoder(obs).numpy()

# t-SNE is stochastic; a fixed seed keeps the figure stable across re-runs.
plot_tsne(enc_obs, labels, seed=0)

# Now with trained models

Lang in PredPrey_18

In [7]:
import torch

from src.utils.config import get_config
from src.utils.utils import load_args
from src.envs.make_env import make_env
from src.algo.lgmarl_diff import LanguageGroundedMARL

# Load data
# NOTE: this rebinds df, obs_dim and max_mess_len from the FH_15 section above.
df, obs_dim, max_mess_len = load_data("../../results/data/lamarl_data/PPrgb_18.csv")

# Build the run configuration from defaults (empty argument list), then point
# it at the saved run.
# NOTE(review): load_args presumably overwrites cfg with the arguments stored
# in cfg.model_dir — confirm.
cfg = get_config().parse_args("")
cfg.model_dir = "../../models/magym_PredPrey_RGB/18s50np_lang_ce0/run1"
cfg.continue_run = True
load_args(cfg)

# Create model
envs, parser = make_env(cfg, 1)
n_agents = envs.n_agents
obs_space = envs.observation_space
shared_obs_space = envs.shared_observation_space
act_space = envs.action_space
model = LanguageGroundedMARL(
    cfg, 
    n_agents, 
    obs_space, 
    shared_obs_space, 
    act_space,
    parser)

# Load params
pretrained_model_path = cfg.model_dir + "/model_ep.pt"
model.load(pretrained_model_path)

Obs dim = 77, Max message length = 6.


In [11]:
# Display the PPrgb_18 dataset loaded above (pandas abbreviates long frames).
df

Unnamed: 0,obs,lang
7,"[0.7058823529411765, 0.7058823529411765, 0.0, ...",[]
9,"[0.5882352941176471, 0.7058823529411765, 0.0, ...",[]
19,"[0.7058823529411765, 0.6470588235294118, 0.0, ...",[]
27,"[0.7647058823529411, 0.6470588235294118, 0.0, ...",[]
28,"[0.8823529411764706, 0.17647058823529413, 0.0,...","[Prey, South, West]"
...,...,...
3999987,"[1.0, 0.8235294117647058, 0.0, 0.0, 0.0, 0.0, ...",[]
3999988,"[0.47058823529411764, 0.5294117647058824, 0.0,...","[Prey, Center]"
3999991,"[1.0, 0.8823529411764706, 0.0, 0.0, 0.0, 0.0, ...",[]
3999992,"[0.47058823529411764, 0.47058823529411764, 0.0...","[Prey, Center]"


In [12]:
# Evaluate on the last 5000 rows only.
eval_data = df.iloc[-5000:]

# Remove rows whose message has more than 4 words.
too_long = (eval_data["lang"].map(len) > 4).astype(bool)
rids = eval_data.index[too_long].tolist()
display_data = eval_data.drop(rids)
print(len(display_data))

# Stack the per-row observation lists into one float tensor and build one
# space-joined label string per row.
obs = torch.Tensor(np.array(list(display_data["obs"])))
labels = [" ".join(s) for s in list(display_data["lang"])]

# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    enc_obs = model.model.agents[0].obs_in(obs).numpy()
print(enc_obs.shape)

4971
(4971, 128)


In [13]:
# t-SNE is stochastic; a fixed seed keeps the figure stable across re-runs.
plot_tsne(enc_obs, labels, seed=0)