# t-SNE visualisation of the model's latent space

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import torch.nn.functional as F
import torch

from dataset import Dataset
from initialize import initialize_model, load_config
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader

In [2]:
# Dataset settings come from dataset.cfg; datasets_dir is overridden so it
# points at the shared datasets folder relative to this notebook.
dataset_config = dict(load_config("./dataset.cfg"))
dataset_config["datasets_dir"] = "../../../../../datasets"

model_config = load_config("./model.cfg")
run_config = load_config("./run.cfg")

# Device identifier from run.cfg (later passed to torch.device / .to()).
device = run_config["device"]

In [None]:
# Load the precomputed train/test split index arrays and echo their sizes and
# contents as a quick sanity check.
train_indexes = np.load("./train_indexes.npy")
test_indexes = np.load("./test_indexes.npy")

for split_name, split in (("train_indexes", train_indexes), ("test_indexes", test_indexes)):
    print(f"{split_name} ({len(split)}): {split}")

In [4]:
# Checkpoint tag to analyse: selects checkpoints/<EPOCH>.pth and names the saved figure.
EPOCH = "last"

In [5]:
# Split to embed: train_indexes by default. Assign test_indexes here instead
# to visualise the held-out split.
indexes = train_indexes

# Unshuffled loading keeps samples in the same order as `indexes`, so each
# embedding can be matched back to its dataset index when labelling the plot.
dataset = Dataset(dataset_config, indexes=indexes)
dataloader = DataLoader(dataset, shuffle=False, batch_size=1)

In [None]:
# Make sure the output folder for the t-SNE figures exists.
os.makedirs("tsne", exist_ok=True)

# Load the checkpoint for the selected epoch, mapping its tensors onto `device`.
checkpoint_path = f"checkpoints/{EPOCH}.pth"
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))

# Build the architecture named in model.cfg, show it, then restore its weights.
model = initialize_model(model_config["name"])
print(model)
state_dict = checkpoint["model_state_dict"]
model.load_state_dict(state_dict)

# Move to the target device and switch to evaluation mode for inference.
model.to(device).eval()

# Step 1: Extract latent space representations.
# Runs every sample from `dataloader` through the trained `model` and collects
# one flattened embedding per sample (`latent_space`, shape (n_samples, -1))
# together with the dataset index it came from (`labels`).

# Model names whose forward pass returns a tuple; its first element is used as
# the embedding.
# NOTE(review): confirm output[0] is the latent code for these models (and not,
# e.g., the reconstruction).
VAE_MODEL_NAMES = ("vae", "unet_vae")

latent_space = []
labels = []

# Position in `indexes` of the first sample of the current batch. This is valid
# because the dataloader does not shuffle, so batches arrive in dataset order.
offset = 0

with torch.no_grad():
    for video, _target in dataloader:  # target is not needed for the embedding
        video = video.to(device)
        batch_len = video.shape[0]

        output = model(video)

        # Dispatch on the configured model name. The original check compared
        # the nn.Module object itself against strings, which is never true.
        if model_config["name"] in VAE_MODEL_NAMES:
            output = output[0]
        output = output.detach().cpu()

        # NOTE(review): this L2-normalises along the last axis only. If the
        # embedding has more than one non-batch dimension, the flattened vector
        # below is not unit-norm; confirm this is intended.
        output = F.normalize(output, p=2, dim=-1)

        latent_space.append(output.numpy())
        # Colour each point by the dataset index of the sample it came from;
        # slicing by offset stays correct for any batch size.
        labels.extend(indexes[offset:offset + batch_len])
        offset += batch_len

if not latent_space:
    raise ValueError("dataloader yielded no samples; nothing to embed")

latent_space = np.concatenate(latent_space, axis=0)
# Flatten each sample's embedding into a single feature vector for t-SNE.
latent_space = latent_space.reshape(latent_space.shape[0], -1)

In [None]:
# Step 2: Reduce dimensionality with t-SNE.
# scikit-learn's TSNE expects perplexity to be below the number of samples.
tsne = TSNE(n_components=2, perplexity=1, random_state=42)
latent_space_tsne = tsne.fit_transform(latent_space)

# Step 3: Plot the 2-D embedding, one point per sample, coloured by `labels`,
# and save it under tsne/<EPOCH>.png.
# NOTE(review): `labels` holds dataset indexes; confirm 'Class' is the intended
# colour-bar caption.
fig, ax = plt.subplots(figsize=(8, 6))
points = ax.scatter(
    latent_space_tsne[:, 0],
    latent_space_tsne[:, 1],
    c=labels,
    cmap='viridis',
)
fig.colorbar(points, ax=ax, label='Class')
ax.set_xlabel('t-SNE Dimension 1')
ax.set_ylabel('t-SNE Dimension 2')
ax.set_title(f"t-SNE Plot of Latent Space (Epoch: {EPOCH})")
fig.savefig(f"tsne/{EPOCH}.png")
plt.show()