# Data Visualisation

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

def load_dance_dataset(file_path):
    """Load a NumPy array stored at ``file_path``.

    On success the array's shape is printed and the array is returned.
    Any exception raised while loading (or while reporting the shape) is
    printed instead of propagated, and ``None`` is returned, so callers must
    check the result for ``None``.
    """
    try:
        array = np.load(file_path)
        print(f"Dataset loaded successfully from {file_path}. Data shape: {array.shape}")
    except Exception as err:
        print(f"Error loading dataset: {err}")
        return None
    return array

def animate_dance(data, interval=50, margin=0.1):
    """Build a 3D scatter animation of joint positions, one frame per time step.

    Args:
        data: Array indexed as ``data[joint, frame, coord]``. Coords 0, 1 and 2
            are drawn as X, Y and Z.
        interval: Delay between frames in milliseconds, forwarded to
            ``FuncAnimation``.
        margin: Padding added below each axis minimum and above each axis
            maximum. Defaults to 0.1, the value that was previously hard-coded.

    Returns:
        The ``FuncAnimation`` object. Its figure is closed before returning,
        presumably so a notebook does not also render a static copy. Render
        the animation with e.g. ``anim.to_jshtml()``.

    Raises:
        ValueError: If ``data`` does not have exactly three axes.
    """
    if np.ndim(data) != 3:
        raise ValueError(
            f"expected data with 3 axes (joint, frame, coord), got shape {np.shape(data)}"
        )
    num_timesteps = data.shape[1]

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title("3D Dance Animation")

    # Limits are set once from all frames, so every frame shares the same view.
    for coord, set_lim in enumerate((ax.set_xlim, ax.set_ylim, ax.set_zlim)):
        values = data[:, :, coord]
        set_lim(np.min(values) - margin, np.max(values) + margin)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    scat = ax.scatter([], [], [], c='red', s=50)

    def init():
        # Start from an empty scatter. `_offsets3d` is the (private) attribute
        # holding a 3D scatter's (x, y, z) coordinate arrays.
        scat._offsets3d = ([], [], [])
        return scat,

    def update(frame):
        # Show every joint's position at time step `frame`.
        scat._offsets3d = (data[:, frame, 0], data[:, frame, 1], data[:, frame, 2])
        return scat,

    anim = FuncAnimation(fig, update, frames=num_timesteps, init_func=init,
                         interval=interval, blit=False)
    plt.close(fig)
    return anim

# `display` is only an injected builtin inside IPython; import it explicitly so
# this cell does not depend on that.
from IPython.display import display

# Location of the motion-capture array. This looks like a Google Colab path;
# adjust it when running elsewhere.
file_path = "/content/mariel_knownbetter.npy"
dance_data = load_dance_dataset(file_path)
if dance_data is not None:
    anim = animate_dance(dance_data, interval=50)
    # Embed the animation as interactive JS/HTML in the notebook output.
    display(HTML(anim.to_jshtml()))


# Preprocessing

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random


def slice_dance_phrases(mocap_data, window_size=30, step_size=10):
    """Cut a motion-capture array into fixed-length windows along its time axis.

    ``mocap_data`` has three axes, the second being time. Window ``k`` is
    ``mocap_data[:, k*step_size : k*step_size + window_size, :]``. Windows are
    produced only while a full window still fits, so trailing frames that
    cannot fill a window are dropped. A clip shorter than ``window_size``
    yields an empty list.
    """
    _, n_frames, _ = mocap_data.shape
    last_start = n_frames - window_size
    return [
        mocap_data[:, offset:offset + window_size, :]
        for offset in range(0, last_start + 1, step_size)
    ]

# Small vocabulary for the synthetic labels: each word gets a consecutive integer id.
vocab = {word: idx for idx, word in enumerate(("spin", "jump", "kick", "step", "wave", "run"))}
# Reverse lookup from integer id back to the word.
inv_vocab = dict(zip(vocab.values(), vocab.keys()))

def generate_synthetic_labels(num_phrases):
    """Return ``num_phrases`` words drawn with replacement from ``vocab``'s keys.

    Draws come from the module-level ``random`` generator, so the labels are
    only reproducible if the caller seeds it first (e.g. ``random.seed``).
    """
    candidates = list(vocab)
    return [random.choice(candidates) for _ in range(num_phrases)]

class DanceTextDataset(Dataset):
    """Fixed-length motion windows paired with randomly drawn word labels.

    Item ``i`` is a ``(phrase_tensor, token_tensor)`` pair:

    * ``phrase_tensor``: float tensor of shape ``(window_size, joints * coords)``.
      It is window ``i`` with the time axis moved to the front and each frame
      flattened.
    * ``token_tensor``: long tensor of shape ``(1,)`` holding the id of the
      window's label in ``vocab``.

    Labels come from ``generate_synthetic_labels`` once, at construction time.
    They are random, so they carry no information about the motion.
    """

    def __init__(self, mocap_data, window_size=30, step_size=10):
        super().__init__()
        self.phrases = slice_dance_phrases(mocap_data, window_size, step_size)
        self.labels = generate_synthetic_labels(len(self.phrases))
        self.window_size = window_size
        self.num_joints = mocap_data.shape[0]

    def __len__(self):
        return len(self.phrases)

    def __getitem__(self, idx):
        # (joint, time, coord) -> (time, joint, coord) -> (time, joint * coord)
        frames = self.phrases[idx].transpose(1, 0, 2).reshape(self.window_size, -1)
        token_id = vocab[self.labels[idx]]
        return (
            torch.tensor(frames, dtype=torch.float),
            torch.tensor([token_id], dtype=torch.long),
        )



# Define Model

In [None]:

class DanceEncoder(nn.Module):
    """Encode a motion window into a single embedding vector.

    An LSTM (``batch_first=True``) runs over the time axis. The final hidden
    state of its top layer is projected to ``output_dim`` by a linear layer.
    """

    def __init__(self, input_dim, hidden_dim=128, output_dim=64, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, time, input_dim)`` to ``(batch, output_dim)``."""
        _outputs, (h_n, _c_n) = self.lstm(x)
        # h_n stacks one final hidden state per layer; the last entry is the top layer.
        top_layer_state = h_n[-1]
        return self.fc(top_layer_state)

# Text Encoder: embeds tokenized text. Token embeddings are averaged over the
# sequence, so one-word labels and multi-word phrases use the same code path.
class TextEncoder(nn.Module):
    """Embed token ids, average over the sequence, and project to ``output_dim``.

    Args:
        vocab_size: Number of distinct token ids.
        embedding_dim: Size of each token embedding.
        hidden_dim: Unused. Kept so existing call sites that pass it still work.
        output_dim: Size of the returned embedding.
        num_layers: Unused. Kept for the same reason as ``hidden_dim``.
    """

    def __init__(self, vocab_size, embedding_dim=50, hidden_dim=128, output_dim=64, num_layers=1):
        super(TextEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.fc = nn.Linear(embedding_dim, output_dim)

    def forward(self, x):
        """Encode token ids.

        Args:
            x: Long tensor of shape ``(batch, seq_length)``. A 1-D ``(batch,)``
                tensor of single tokens is also accepted.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        emb = self.embedding(x)  # (batch, seq_length, embedding_dim) for 2-D x
        if emb.dim() == 3:
            # Mean over the sequence axis. For seq_length == 1 this equals the
            # old squeeze(1) exactly. Unlike squeeze, it also collapses longer
            # sequences, which squeeze left as (batch, seq_length, output_dim).
            emb = emb.mean(dim=1)
        return self.fc(emb)



# Contrastive Learning

In [None]:
def contrastive_loss(dance_embeds, text_embeds, temperature=0.07):
    """Symmetric InfoNCE loss between paired dance and text embeddings.

    Row ``i`` of ``dance_embeds`` is treated as the positive match for row ``i``
    of ``text_embeds``; every other row in the batch is a negative. The
    cross-entropy is computed in both directions (dance->text and text->dance)
    and averaged. Inputs are expected to be L2-normalised already, so the
    scaled dot products are cosine similarities divided by ``temperature``.
    """
    similarity = dance_embeds @ text_embeds.T / temperature
    targets = torch.arange(similarity.size(0), device=similarity.device)
    dance_to_text = F.cross_entropy(similarity, targets)
    text_to_dance = F.cross_entropy(similarity.T, targets)
    return 0.5 * (dance_to_text + text_to_dance)



# Training Loop

In [None]:
# Build the windowed dataset, both encoders and one shared optimizer, then
# train the encoders jointly with the symmetric contrastive loss.
if dance_data is None:
    raise RuntimeError("dance_data is None: the dataset failed to load in the visualisation cell.")
mocap_data = dance_data
window_size = 30
step_size = 10
dataset = DanceTextDataset(mocap_data, window_size=window_size, step_size=step_size)
if len(dataset) == 0:
    raise ValueError(
        f"No training windows: the clip has {mocap_data.shape[1]} frames "
        f"but window_size={window_size}."
    )
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# DanceTextDataset flattens each frame to (joints * coords) features, so the
# LSTM input size is derived from the data rather than hard-coded as 55 * 3.
input_dim = mocap_data.shape[0] * mocap_data.shape[2]
dance_encoder = DanceEncoder(input_dim=input_dim, hidden_dim=128, output_dim=64)
text_encoder = TextEncoder(vocab_size=len(vocab), embedding_dim=50, hidden_dim=128, output_dim=64)
optimizer = optim.Adam(list(dance_encoder.parameters()) + list(text_encoder.parameters()), lr=1e-3)

# NOTE(review): with only len(vocab) == 6 labels and batches of 32, a batch
# usually contains repeated labels. contrastive_loss still treats those
# identical texts as negatives of each other.
num_epochs = 10
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for dance_phrase, token_tensor in dataloader:
        dance_embed = F.normalize(dance_encoder(dance_phrase), dim=1)  # (batch, 64)
        text_embed = F.normalize(text_encoder(token_tensor), dim=1)    # (batch, 64)
        loss = contrastive_loss(dance_embed, text_embed)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Mean per-batch loss over the epoch.
    print(f"Epoch {epoch+1}/{num_epochs} Loss: {epoch_loss/len(dataloader):.4f}")



# Testing

In [None]:
# Use the last `holdout_size` windows of the dataset for the retrieval demos,
# clamped so datasets with fewer than 100 windows still work.
# NOTE(review): these windows were also in `dataset` during training, so this
# measures fit on seen data, not generalisation. Split the dataset before
# training for a true holdout.
holdout_size = min(100, len(dataset))
holdout_dataset = torch.utils.data.Subset(dataset, list(range(len(dataset) - holdout_size, len(dataset))))
holdout_loader = DataLoader(holdout_dataset, batch_size=holdout_size, shuffle=False)

with torch.no_grad():
    # batch_size == holdout_size, so the first (only) batch is the whole holdout set.
    dance_phrases_hold, token_hold = next(iter(holdout_loader))
    dance_emb_hold = F.normalize(dance_encoder(dance_phrases_hold), dim=1)
    text_emb_hold = F.normalize(text_encoder(token_hold), dim=1)

def retrieve_dance_from_text(query_text, dance_embeds, dance_dataset):
    """Return the dance phrase whose embedding best matches a one-word text query.

    Args:
        query_text: A word from ``vocab`` (e.g. ``"jump"``).
        dance_embeds: Dance embeddings of shape ``(k, d)`` for the LAST ``k``
            items of ``dance_dataset``, in order. They are expected to be
            L2-normalised, as ``dance_emb_hold`` is, so dot products are
            cosine similarities.
        dance_dataset: Dataset whose items are ``(phrase_tensor, token_tensor)``.

    Returns:
        The ``(phrase_tensor, token_tensor)`` item of the best match, or
        ``None`` (after printing a message) if ``query_text`` is not in ``vocab``.
    """
    # Tokenize the query. Only single words from the vocabulary are supported.
    token_idx = vocab.get(query_text)
    if token_idx is None:
        print("Unknown query word.")
        return None
    query_tensor = torch.tensor([[token_idx]])
    with torch.no_grad():
        query_embed = F.normalize(text_encoder(query_tensor), dim=1)
        # Use the embeddings passed in, not the global dance_emb_hold.
        sims = torch.matmul(query_embed, dance_embeds.t())
        best_idx = sims.argmax(dim=1).item()
    # dance_embeds covers the tail of dance_dataset. Offsetting by its length
    # works for both the full dataset and the holdout Subset.
    offset = len(dance_dataset) - len(dance_embeds)
    retrieved_phrase, retrieved_label = dance_dataset[offset + best_idx]
    return retrieved_phrase, retrieved_label

# Dance -> text retrieval: pick the closest existing label (retrieval, not generation).
def retrieve_text_from_dance(query_dance, text_embeds, text_dataset):
    """Return the label word whose text embedding best matches a dance window.

    Args:
        query_dance: One un-batched dance window, shaped like a dataset item's
            ``phrase_tensor``. A batch axis is added here.
        text_embeds: Text embeddings of shape ``(k, d)`` for the LAST ``k`` items
            of ``text_dataset``, in order. They are expected to be
            L2-normalised, as ``text_emb_hold`` is.
        text_dataset: Dataset whose items are ``(phrase_tensor, token_tensor)``.

    Returns:
        The ``vocab`` word of the best-matching item.
    """
    with torch.no_grad():
        query_embed = F.normalize(dance_encoder(query_dance.unsqueeze(0)), dim=1)
        # Use the embeddings passed in, not the global text_emb_hold.
        sims = torch.matmul(query_embed, text_embeds.t())
        best_idx = sims.argmax(dim=1).item()
    # Same tail-offset convention as retrieve_dance_from_text.
    offset = len(text_dataset) - len(text_embeds)
    _, retrieved_token = text_dataset[offset + best_idx]
    return inv_vocab[retrieved_token.item()]

# Text -> dance: the printed query always matches the one actually sent.
query_word = "wave"
text_query_result = retrieve_dance_from_text(query_word, dance_emb_hold, holdout_dataset)
if text_query_result is not None:  # None means the word is not in vocab
    retrieved_phrase, retrieved_label = text_query_result
    print(f"For text query '{query_word}', retrieved dance phrase has synthetic label:",
          inv_vocab[retrieved_label.item()])


# Dance -> text: query with the first holdout window.
sample_dance, sample_token = holdout_dataset[0]
retrieved_text = retrieve_text_from_dance(sample_dance, text_emb_hold, holdout_dataset)
print("For a dance sequence query, retrieved text description is:", retrieved_text)
