In [None]:
!pip install matplotlib

In [None]:
import torch
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def tsne_cluster(tensor: torch.Tensor, n_components: int = 2, perplexity: float = 30.0, random_state: int = 42):
    """
    Embed the rows of a 2-D tensor with t-SNE and scatter-plot the result when it is 2D.

    Args:
        tensor (torch.Tensor): Input tensor of shape (n_samples, n_features),
            e.g. (3186, 256). It is detached and moved to the CPU before use.
        n_components (int): Dimensionality of the reduced space (2 or 3)
        perplexity (float): Perplexity parameter for t-SNE
        random_state (int): Random seed for reproducibility

    Returns:
        numpy.ndarray: The t-SNE embedding, one row per input row and
        n_components columns. It is returned for every n_components, even
        though only the 2D case is plotted.

    Raises:
        ValueError: If tensor is not 2-dimensional.
    """
    # Explicit raise rather than `assert`: assertions are stripped under `python -O`.
    if tensor.ndim != 2:
        raise ValueError(
            f"Expected a 2-D tensor of shape (n_samples, n_features), got {tuple(tensor.shape)}"
        )

    # scikit-learn works on NumPy arrays, so drop autograd tracking and copy to host memory.
    tensor_np = tensor.detach().cpu().numpy()

    tsne = TSNE(n_components=n_components, perplexity=perplexity, random_state=random_state)
    reduced = tsne.fit_transform(tensor_np)

    # Plotting for 2D
    if n_components == 2:
        plt.figure(figsize=(8, 6))
        # No `cmap` here: without per-point colour values (`c=`) matplotlib
        # ignores the colormap and only emits a warning.
        plt.scatter(reduced[:, 0], reduced[:, 1], s=10)
        plt.title("t-SNE Clustering (2D)")
        plt.xlabel("Component 1")
        plt.ylabel("Component 2")
        plt.grid(True)
        plt.show()
    else:
        print("Only 2D plotting supported in this example.")

    return reduced
