In [None]:
import numpy as np

from src.behavior import get_actor
from src.eval.eval_utils import get_model_from_api_or_cached


from src.common.files import get_processed_paths, path_override
from torch.utils.data import DataLoader
from src.dataset.dataset import FurnitureImageDataset
from src.train.bc import to_native

import torch
from src.behavior.base import Actor
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

from tqdm import tqdm

In [None]:
# Run identifiers handed to get_model_from_api_or_cached in the cells below.
# NOTE(review): presumably "<project>/<run>" paths on Weights & Biases (the loader is
# called with a wandb_mode argument) — confirm.
run_id_naive = "real-ol-demo-scaling-1/3js1f6n1"
run_id_upwt = "real-ol-demo-scaling-1/31xxjkpb"
run_id_conf = "real-ol-demo-scaling-1/1knzc1b4"
# The next two ids are not referenced by any cell visible in this notebook.
run_id_confusion4 = "real-one_leg-cotrain-2/7grrzinv"
run_id_confusion3 = "real-one_leg-cotrain-2/xwawbdtk"
run_id_confusion2 = "real-one_leg-cotrain-2/f7usetuv"

In [None]:
# Get the config to load in the standard model with only pretrained weights.
# Only the first element of the returned pair (the config) is kept here; the second
# element is discarded and requested again by the cell that loads this run's weights.
cfg, _ = get_model_from_api_or_cached(run_id_naive, "latest", wandb_mode="online")

# Bare expression: show the config as this cell's output
cfg

In [None]:
# Build the actor described by the config and place it on the "cuda" device.
# This cell does not call load_state_dict; checkpoints are loaded by later cells.
actor: Actor = get_actor(cfg=cfg, device="cuda")

In [None]:
# Inference only: switch the actor to eval mode, then freeze every parameter so
# that no gradients are tracked for them
actor.eval()

for param in actor.parameters():
    param.requires_grad_(False)

In [None]:
# Explicitly put the inner model in eval mode as well; the value returned by the
# call is what this cell displays.
actor.model.eval()

In [None]:
# Sanity check: show the `training` flags of the model and both image encoders
# (expected to all be False after the eval() calls above — TODO confirm).
actor.model.training, actor.encoder1.training, actor.encoder2.training

In [None]:
# Locate the processed data: an explicit override in the config takes precedence,
# otherwise the location is looked up from the run's control/data settings.
if cfg.data.data_paths_override is not None:
    data_path = path_override(cfg.data.data_paths_override)
else:
    path_query = {
        "controller": cfg.control.controller,
        "domain": cfg.data.environment,
        "task": cfg.data.furniture,
        "demo_source": cfg.data.demo_source,
        "randomness": cfg.data.randomness,
        "demo_outcome": cfg.data.demo_outcome,
        "suffix": cfg.data.suffix,
    }
    data_path = get_processed_paths(
        **{field: to_native(value) for field, value in path_query.items()}
    )

print(f"Using data from {data_path}")

# Dataset settings mirror the run's config, except `data_subset`, which this notebook
# fixes at 5 instead of reading cfg.data.data_subset.
dataset_kwargs = {
    "dataset_paths": data_path,
    "pred_horizon": cfg.data.pred_horizon,
    "obs_horizon": cfg.data.obs_horizon,
    "action_horizon": cfg.data.action_horizon,
    "data_subset": 5,
    "control_mode": cfg.control.control_mode,
    "predict_past_actions": cfg.data.predict_past_actions,
    "pad_after": cfg.data.get("pad_after", True),
    "max_episode_count": cfg.data.get("max_episode_count", None),
    "minority_class_power": cfg.data.get("minority_class_power", False),
}
dataset = FurnitureImageDataset(**dataset_kwargs)

# Loader settings; the batch size is fixed at 64 here rather than taken from
# cfg.training.batch_size.
trainload_kwargs = {
    "dataset": dataset,
    "batch_size": 64,
    "num_workers": cfg.data.dataloader_workers,
    "shuffle": True,
    "pin_memory": True,
    "drop_last": False,
    "persistent_workers": False,
}

trainloader = DataLoader(**trainload_kwargs)

In [None]:
# Show the resolved data path(s) as this cell's output
data_path

In [None]:
def get_embeddings(actor: Actor, batch):
    """Encode both camera images of a batch into projected embeddings.

    Args:
        actor: Actor exposing `encoder1`/`encoder2`, `encoder1_proj`/`encoder2_proj`
            and `parameters()`.
        batch: Mapping holding the tensors "color_image1" and "color_image2".
            Dimension 0 is treated as the batch dimension.

    Returns:
        Tuple `(emb1, emb2)` with the projected encoder output for each image key.
        The tensors are computed without gradient tracking.
    """
    # Follow the actor's own device instead of assuming "cuda"; fall back to the
    # previous hard-coded value only if the actor exposes no parameters.
    first_param = next(iter(actor.parameters()), None)
    device = first_param.device if first_param is not None else "cuda"

    def _load_view(key):
        img = batch[key].to(device)
        # Drop singleton dimensions like a bare squeeze() would, but never dim 0:
        # a final batch holding a single sample must keep its batch dimension.
        kept = [size for size in img.shape[1:] if size != 1]
        return img.reshape(img.shape[0], *kept)

    # Embeddings are only inspected, never trained on, so skip autograd bookkeeping
    with torch.no_grad():
        emb1 = actor.encoder1_proj(actor.encoder1(_load_view("color_image1")))
        emb2 = actor.encoder2_proj(actor.encoder2(_load_view("color_image2")))

    return emb1, emb2

In [None]:
def get_embeddings_and_domain_labels(
    actor: Actor, trainloader, sample_size=None, rng=None
):
    """Embed every batch of `trainloader` and collect the matching domain labels.

    Args:
        actor: Actor passed through to `get_embeddings`.
        trainloader: Iterable of batches; each batch must provide the image keys
            `get_embeddings` reads plus a "domain" tensor.
        sample_size: If given, the number of rows to draw (without replacement) for
            each distinct domain label. A domain with fewer rows is kept whole.
        rng: Optional `numpy.random.Generator` used for the draw. When None, NumPy's
            global random state is used, as before.

    Returns:
        Tuple `(embeddings, domain_labels)`: a 2-D array with both image embeddings
        concatenated per row, and a flat array of the corresponding domain labels.

    Raises:
        ValueError: If `trainloader` yields no batches.
    """
    embeddings = []
    domain_labels = []

    for batch in tqdm(trainloader):
        emb1, emb2 = get_embeddings(actor, batch)

        # One row per sample, both views side by side: (batch_size, 2 * embedding_size)
        emb = torch.cat([emb1, emb2], dim=1)
        # detach() keeps the numpy conversion valid even if gradients were tracked
        embeddings.append(emb.detach().cpu().numpy())

        domain_labels.extend(batch["domain"].cpu().numpy().tolist())

    if not embeddings:
        raise ValueError("trainloader yielded no batches; nothing to embed")

    embeddings = np.concatenate(embeddings, axis=0)
    domain_labels = np.array(domain_labels).reshape(-1)

    # Print the average standard deviation of the embeddings
    print(
        f"Average standard deviation of embeddings: {np.mean(np.std(embeddings, axis=0))}"
    )

    # Optional subsampling, stratified by domain label
    if sample_size is not None:
        sampler = np.random if rng is None else rng
        keep = []

        for domain_label in np.unique(domain_labels):
            idx = np.where(domain_labels == domain_label)[0]

            # Sampling without replacement cannot return more rows than exist,
            # so clamp instead of raising for small domains.
            n_take = min(sample_size, len(idx))
            if n_take < sample_size:
                print(
                    f"Domain {domain_label} has only {len(idx)} samples; "
                    f"using all of them instead of {sample_size}"
                )

            keep.append(sampler.choice(idx, size=n_take, replace=False))

        keep = np.concatenate(keep)
        embeddings = embeddings[keep]
        domain_labels = domain_labels[keep]

    return embeddings, domain_labels

In [None]:
simcolor = "#2398DA"
realcolor = "#E34A6F"


def _scatter_by_domain(points_2d, domain_labels, title=None):
    """Scatter 2-D points, drawing label 0 as "Sim" and label 1 as "Real"."""
    # Accept lists as well as arrays so the boolean masks below always work
    domain_labels = np.asarray(domain_labels).reshape(-1)

    fig, ax = plt.subplots(figsize=(4, 4))
    for label, color, name in ((0, simcolor, "Sim"), (1, realcolor, "Real")):
        points = points_2d[domain_labels == label]
        ax.scatter(
            points[:, 0],
            points[:, 1],
            color=color,
            label=name,
            alpha=0.2,
            s=2,
        )
    ax.legend(frameon=False)

    if title is not None:
        ax.set_title(title)
    plt.show()


def visualize_embeddings_tsne(embeddings, domain_labels, title=None):
    """Project `embeddings` to 2-D with t-SNE and plot them coloured by domain.

    Args:
        embeddings: 2-D array, one row per sample.
        domain_labels: One label per row; 0 is plotted as "Sim", 1 as "Real".
        title: Optional plot title; " (t-SNE)" is appended to it.
    """
    tsne = TSNE(n_components=2, random_state=42)
    embeddings_tsne = tsne.fit_transform(embeddings)

    _scatter_by_domain(
        embeddings_tsne,
        domain_labels,
        title=None if title is None else title + " (t-SNE)",
    )


def visualize_embeddings_pca(embeddings, domain_labels, title=None):
    """Project `embeddings` to 2-D with PCA and plot them coloured by domain.

    Args:
        embeddings: 2-D array, one row per sample.
        domain_labels: One label per row; 0 is plotted as "Sim", 1 as "Real".
        title: Optional plot title; " (PCA)" is appended to it.
    """
    # Seeded like the t-SNE projection so repeated runs give the same picture even
    # when scikit-learn selects a randomized SVD solver.
    pca = PCA(n_components=2, random_state=42)
    embeddings_pca = pca.fit_transform(embeddings)

    _scatter_by_domain(
        embeddings_pca,
        domain_labels,
        title=None if title is None else title + " (PCA)",
    )

In [None]:
from pytorch3d.transforms import so3_exponential_map, so3_relative_angle

### Plot embeddings for the pretrained R3M model

In [None]:
# Number of embeddings drawn per domain label by get_embeddings_and_domain_labels
sample_size = 1_000

In [None]:
# In a top-to-bottom run no checkpoint has been loaded into the actor yet, so these
# embeddings come from the actor exactly as it was constructed.
title = "Pre-trained R3M weights"

embeddings, domain_labels = get_embeddings_and_domain_labels(
    actor=actor, trainloader=trainloader, sample_size=sample_size
)

# Same data through both projections, t-SNE first and PCA second
for visualize in (visualize_embeddings_tsne, visualize_embeddings_pca):
    visualize(embeddings, domain_labels, title=title)

### Plot embeddings for co-trained model with no tricks

In [None]:
# Get weights
_, wts = get_model_from_api_or_cached(run_id_naive, "latest", wandb_mode="online")

# Load the weights into the actor.
# map_location="cpu" deserializes the checkpoint on the CPU instead of on the device
# it was saved from; load_state_dict then copies the values into the actor's
# existing parameters.
state_dict = torch.load(wts, map_location="cpu")
if "model_state_dict" in state_dict:
    # Checkpoint wraps the weights under "model_state_dict"
    actor.load_state_dict(state_dict["model_state_dict"])
else:
    actor.load_state_dict(state_dict)


# Get the embeddings and domain labels
embeddings, domain_labels = get_embeddings_and_domain_labels(
    actor, trainloader, sample_size=sample_size
)

# Visualize the embeddings using t-SNE and PCA
title = "Co-training, naive mixing"
visualize_embeddings_tsne(embeddings, domain_labels, title=title)
visualize_embeddings_pca(embeddings, domain_labels, title=title)

### Plot embeddings for co-trained model with up-weighting of real data

In [None]:
# Get weights
_, wts = get_model_from_api_or_cached(run_id_upwt, "latest", wandb_mode="online")

# Load the weights into the actor.
# map_location="cpu" deserializes the checkpoint on the CPU instead of on the device
# it was saved from; load_state_dict then copies the values into the actor's
# existing parameters.
state_dict = torch.load(wts, map_location="cpu")
if "model_state_dict" in state_dict:
    # Checkpoint wraps the weights under "model_state_dict"
    actor.load_state_dict(state_dict["model_state_dict"])
else:
    actor.load_state_dict(state_dict)

# Get the embeddings and domain labels
embeddings, domain_labels = get_embeddings_and_domain_labels(
    actor, trainloader, sample_size=sample_size
)

# Visualize the embeddings using t-SNE and PCA.
# The title names the run that is actually loaded here (run_id_upwt); it previously
# carried a confusion-loss label copied from another cell.
title = "Co-training, up-weighting of real data"
visualize_embeddings_tsne(embeddings, domain_labels, title=title)
visualize_embeddings_pca(embeddings, domain_labels, title=title)

### Plot embeddings for co-trained model with confusion loss and up-weighting of real data

In [None]:
# Get weights
import math


# NOTE: this rebinds the notebook-level `cfg` to the confusion-loss run's config;
# cells that read `cfg` will see this run's settings if they are re-run afterwards.
cfg, wts = get_model_from_api_or_cached(run_id_conf, "latest", wandb_mode="online")

# Load the weights into the actor.
# map_location="cpu" deserializes the checkpoint on the CPU instead of on the device
# it was saved from; load_state_dict then copies the values into the actor's
# existing parameters.
state_dict = torch.load(wts, map_location="cpu")
if "model_state_dict" in state_dict:
    # Checkpoint wraps the weights under "model_state_dict"
    actor.load_state_dict(state_dict["model_state_dict"])
else:
    actor.load_state_dict(state_dict)

# Get the embeddings and domain labels
embeddings, domain_labels = get_embeddings_and_domain_labels(
    actor, trainloader, sample_size=sample_size
)

confusion = cfg.actor.confusion_loss_beta

# round() rather than int(): a log10 result such as -2.9999999999999996 must still
# be reported as -3, not truncated to -2.
exponent = round(math.log10(float(confusion)))

# Visualize the embeddings using t-SNE and PCA.
# The doubled braces emit literal "{" and "}" so mathtext receives 10^{-3}; without
# them only the first character of the exponent would be superscripted.
title = f"Co-training, confusion loss $\\lambda=10^{{{exponent}}}$"
visualize_embeddings_tsne(embeddings, domain_labels, title=title)
visualize_embeddings_pca(embeddings, domain_labels, title=title)

In [None]:
# Get weights
_, wts = get_model_from_api_or_cached(run_id_confusion2, "latest", wandb_mode="online")

# Load the weights into the actor.
# map_location="cpu" deserializes the checkpoint on the CPU instead of on the device
# it was saved from; load_state_dict then copies the values into the actor's
# existing parameters.
state_dict = torch.load(wts, map_location="cpu")
if "model_state_dict" in state_dict:
    # Checkpoint wraps the weights under "model_state_dict"
    state_dict = state_dict["model_state_dict"]

# Drop this entry when present before loading.
# NOTE(review): presumably the actor built above has no such entry — confirm.
state_dict.pop("model._dummy_variable", None)

actor.load_state_dict(state_dict)

# Get the embeddings and domain labels
embeddings, domain_labels = get_embeddings_and_domain_labels(
    actor, trainloader, sample_size=sample_size
)

# Visualize the embeddings using t-SNE and PCA
title = "Co-training, confusion loss $\\lambda=10^{-2}$"
visualize_embeddings_tsne(embeddings, domain_labels, title=title)
visualize_embeddings_pca(embeddings, domain_labels, title=title)

In [None]:
# Per-dimension mean (over rows) of whichever `embeddings` array was assigned last
embeddings.mean(axis=0)