# DINOv2

In [1]:
import torch
from torchvision import datasets, transforms

# Load DINOv2 onto GPU
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(device)
model.eval()

# Transform for DINOv2 (zero-padding + normalization)
dinov2_transform = transforms.Compose([
    transforms.Pad((96, 96)),  # (224-32)/2 = 96 pixels padding
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Transform for original CIFAR-10 (just ToTensor to get raw pixels)
original_transform = transforms.ToTensor()

# Load dataset twice (once for DINOv2, once for original)
cifar_dinov2 = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=dinov2_transform,
)

cifar_original = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=original_transform,
)

# Create DataLoaders
loader_dinov2 = torch.utils.data.DataLoader(
    cifar_dinov2,
    batch_size=512,
    num_workers=4,
    pin_memory=True,
)

loader_original = torch.utils.data.DataLoader(
    cifar_original,
    batch_size=512,
    num_workers=4,
    pin_memory=True,
)

# Extract DINOv2 embeddings
embeddings, labels = [], []
with torch.no_grad():
    for images, targets in loader_dinov2:
        images = images.to(device, non_blocking=True)
        embeddings.append(model(images).cpu())
        labels.append(targets)

embeddings = torch.cat(embeddings)  # Shape: [10000, 384]
labels = torch.cat(labels)  # Shape: [10000]

# Extract original images (32x32, no padding/normalization)
original_images = []
for images, _ in loader_original:
    original_images.append(images)

original_images = torch.cat(original_images)  # Shape: [10000, 3, 32, 32]

# Save results (optional)
torch.save({
    'embeddings': embeddings,
    'labels': labels,
    'original_images': original_images,
}, 'cifar10_dinov2_features_and_originals.pt')

print("Shapes:")
print(f"Embeddings: {embeddings.shape}")  # [10000, 384]
print(f"Labels: {labels.shape}")  # [10000]
print(f"Original Images: {original_images.shape}")  # [10000, 3, 32, 32]

Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


Shapes:
Embeddings: torch.Size([10000, 384])
Labels: torch.Size([10000])
Original Images: torch.Size([10000, 3, 32, 32])


In [2]:
import numpy as np
from scipy.stats import pearsonr

def correlation_dissimilarity(emb1, emb2):
    """
    emb1 (np.array) : embedding in one feature space
    emb2 (np.array) : embedding in another feature space
    """
    dissim1 = 1. - np.corrcoef(emb1)
    dissim2 = 1. - np.corrcoef(emb2)

    triu_indices = np.triu_indices_from(dissim1, k=1)
    flat1 = dissim1[triu_indices]
    flat2 = dissim2[triu_indices]

    # Compute second-order similarity (Pearson correlation)
    r, _ = pearsonr(flat1, flat2)
    return r


from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

def train_linear_classifier(X, y, test_size=0.2, random_state=42, **kwargs):
    """
    Trains a linear classifier (Logistic Regression) and returns the model and accuracy.

    Parameters:
    X (array-like): Feature matrix
    y (array-like): Target vector
    test_size (float): Proportion of data to use for testing (default: 0.2)
    random_state (int): Random seed for reproducibility (default: 42)
    **kwargs: Additional arguments to pass to LogisticRegression

    Returns:
    tuple: (trained_model, accuracy_score)
    """
    # Split data into training and test sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    # Initialize and train the linear classifier
    model = LogisticRegression(**kwargs)
    model.fit(X_train, y_train)

    # Make predictions and calculate accuracy
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)

    return model, accuracy

def encode_set(encoder_function: callable, loader, original_loader, device="cpu"):
    all_embeddings = []
    all_labels = []
    all_original_images = []

    with torch.no_grad():
        for (images_dino, targets), (images_orig, _) in zip(loader, original_loader):
            images_dino = images_dino.to(device, non_blocking=True)
            embeddings = encoder_function(images_dino).cpu()
            all_embeddings.append(embeddings)
            all_labels.append(targets)
            all_original_images.append(images_orig)

    embeddings = torch.cat(all_embeddings)  # [N, D]
    labels = torch.cat(all_labels)
    original_images = torch.cat(all_original_images)
    original_images = original_images.reshape(original_images.shape[0], -1)

    return (embeddings.numpy(),
            labels.numpy(),
            original_images.numpy())


def run(encoder_function: callable, loader: torch.utils.data.DataLoader, original_loader: torch.utils.data.DataLoader,
        logger = None, device = "cpu"):
    embeddings_np, labels_np, original_images_np = encode_set(encoder_function, loader, original_loader, device)
    if not logger is None:
        logger.log({
            "classification_accuracy" : train_linear_classifier(embeddings_np, labels_np)[1],
            "second_order_similarity" : correlation_dissimilarity(embeddings_np, original_images_np)
        })
    return embeddings_np, labels_np, original_images_np


In [4]:
# embeddings_np = embeddings.detach().cpu().numpy()
# original_images_np = original_images.detach().cpu().numpy().reshape(original_images.shape[0], -1)
# # correlation_dissimilarity(embeddings_np[:1000], embeddings_np[1000:2000])
# correlation_dissimilarity(embeddings_np[:100], original_images_np[:100])
import wandb
run_name = f'untrained_dino_metric'
config = {
    "encoder" : "dino",
    "type_log" : "metric",
}

logger = wandb.init(project = 'CV_frameworks', config = config, name = run_name)

embeddings, labels, original_images = run(model, loader_dinov2, loader_original,logger = logger, device= device)
logger.finish()

[34m[1mwandb[0m: Currently logged in as: [33mms-hate-life[0m ([33mms-hate-life-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


0,1
classification_accuracy,▁
second_order_similarity,▁

0,1
classification_accuracy,0.738
second_order_similarity,0.15215


In [33]:
from bokeh.plotting import figure, show
from bokeh.io import output_notebook
from bokeh.models import ColumnDataSource, ColorBar
from bokeh.transform import linear_cmap, factor_cmap
from bokeh.palettes import Viridis256
from bokeh.layouts import row

import plotly.express as px

def embedding_plotter(embedding, data=None, hue=None, hover=None, tools = None, nv_cat = 5, height = 400, width = 400, display_result=True):
    '''
    Рисовалка эмбеддинга. 2D renderer: bokeh. 3D renderer: plotly.
    Обязательные инструменты:
        - pan (двигать график)
        - box zoom
        - reset (вылезти из зума в начальное положение)

        embedding: something 2D/3D, slicable ~ embedding[:, 0] - валидно
            Эмбеддинг
        data: pd.DataFrame
            Данные, по которым был построен эмбеддинг
        hue: string
            Колонка из data, по которой красим точки. Поддерживает интерактивную легенду: по клику на каждое
                значение hue можно скрыть весь цвет.
        hover: string or list of strings
            Колонк[а/и] из data, значения которых нужно выводить при наведении мышки на точку
        nv_cat: int
            number of unique values to consider column categorical
        tools: iterable or string in form "tool1,tool2,..." or ["tool1", "tool2", ...]
            tools for the interactive plot
        height, width: int
            parameters of the figure
        display_result: boolean
            if the results are displayed or just returned

    '''
    if tools is None:
        tools = 'lasso_select,box_select,pan,zoom_in,zoom_out,reset,hover'
    else:
        if hover and not("hover" in tools):
            tools = 'hover,'+",".join(tools)


    if embedding.shape[1] == 3:
        if hover:
            hover_data = {h:True for h in hover}
        else:
            hover_data = None
        df = pd.DataFrame(embedding, columns = ['x', 'y', 'z'])
        df = pd.concat((df, data), axis=1)
        fig = px.scatter_3d(
            data_frame = df,
            x='x',
            y='y',
            z='z',
            color=df[hue],
            hover_data = hover_data
        )

        fig.update_layout(
            modebar_add=tools.split(","),
        )

        fig.update_traces(marker_size=1, selector=dict(type='scatter3d'))

        if display_result: fig.show()

    if embedding.shape[1] == 2:
        output_notebook()
        df = pd.DataFrame(embedding, columns = ['x', 'y'])
        df = pd.concat((df, data), axis=1)
        tooltips = [
            ('x, y', '$x, $y'),
            ('index', '$index')
        ]
        if hover:
            for col in hover:
                tooltips.append((col, "@"+col))
        fig = figure(tools=tools, width=width, height=height, tooltips=tooltips)
        if df[hue].nunique() < nv_cat or df[hue].dtype == "category":
            df[hue] = df[hue].astype(str)
            source = ColumnDataSource(df)
            color_mapper = factor_cmap(
            field_name=hue,
            palette='Category10_3',
            factors=df[hue].unique()
            )
            fig.scatter(
            x='x', y='y',
            color=color_mapper,
            source=source,
            legend_group=hue)

            fig.legend.location = 'bottom_left'
            fig.legend.click_policy = 'mute'
        else:
            source = ColumnDataSource(df)
            color_mapper = linear_cmap(
                field_name=hue,
                palette=Viridis256,
                low=min(df[hue]),
                high=max(df[hue]))
            fig.scatter(
                x='x', y='y',
                color=color_mapper,
                source=source)
            color_bar = ColorBar(color_mapper=color_mapper['transform'], width=8, location=(0,0), title = hue)
            fig.add_layout(color_bar, 'right')


        if display_result: show(fig)

    if embedding.shape[1] > 3:
        print("wrong species, doooooodes")
    else: return fig

import pandas as pd
import numpy as np
from sklearn.manifold import TSNE

# Assuming your embeddings are in a variable called 'embeddings'
# embeddings = np.random.rand(10000, 384)  # Example - replace with your actual embeddings

# Create t-SNE
tsne = TSNE(n_components=2, random_state=1,
            init='pca', n_iter=5000,
            metric='cosine')

# Fit and transform your data
tsne_results = tsne.fit_transform(embeddings[:1000])

# Prepare the data DataFrame correctly
data_df = pd.DataFrame({
    'label': np.array(labels[:1000])  # Assuming you have labels
    # Add any other columns you want for hover information
})

# Call the plotting function correctly
embedding_plotter(
    embedding=tsne_results,  # This should be your 2D t-SNE results (1000x2 array)
    data=data_df,            # This contains your labels and other metadata
    hue='label',             # Column name in data_df to use for coloring
)



In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import timm  # For ViT models
import copy

from tqdm import tqdm

import wandb

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- DINO Loss ---
class DINOLoss(nn.Module):
    def __init__(self, out_dim, teacher_temp=0.04, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.teacher_temp = teacher_temp
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.center = torch.zeros(1, out_dim).to(device)

    def forward(self, student_out, teacher_out):
        student_out = F.log_softmax(student_out / self.student_temp, dim=-1)
        teacher_out = F.softmax((teacher_out - self.center) / self.teacher_temp, dim=-1)
        loss = -torch.sum(teacher_out * student_out, dim=-1).mean()
        # update center
        self.center = self.center * self.center_momentum + (1 - self.center_momentum) * teacher_out.mean(dim=0, keepdim=True)
        return loss

# --- Data Augmentations ---
def get_dino_transforms():
    global_crop = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return global_crop

class MultiCropWrapper(torch.utils.data.Dataset):
    def __init__(self, base_dataset, transform, num_crops=2):
        self.base_dataset = base_dataset
        self.transform = transform
        self.num_crops = num_crops

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        img, _ = self.base_dataset[idx]
        crops = [self.transform(img) for _ in range(self.num_crops)]
        return crops

# --- Models ---
def get_vit_model():
    model = timm.create_model('vit_small_patch8_224', pretrained=False)
    model.head = nn.Identity()
    return model.to(device)

# --- Training Setup ---
def train_one_epoch(student, teacher, loss_fn, optimizer, dataloader, momentum=0.996):
    student.train()
    total_loss = 0
    num_batches = 0
    for crops in tqdm(dataloader):
        crops = [c.to(device, non_blocking=True) for c in crops]
        global_crops = crops[:2]

        student_out = [student(crop.unsqueeze(0)) for crop in global_crops]
        teacher.eval()
        with torch.no_grad():
            teacher_out = teacher(global_crops[0].unsqueeze(0))

        loss = sum(loss_fn(s_out, teacher_out) for s_out in student_out) / len(student_out)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # EMA update teacher
        with torch.no_grad():
            for ps, pt in zip(student.parameters(), teacher.parameters()):
                pt.data.mul_(momentum).add_((1. - momentum) * ps.data)

        total_loss += loss.item()
        num_batches += 1
    return total_loss/num_batches
        # print(f"Loss: {loss.item():.4f}")

# --- Putting It All Together ---
def main():
    transform = get_dino_transforms()
    dataset = datasets.CIFAR10(root='./data', train=True, download=True)
    dino_dataset = MultiCropWrapper(dataset, transform, num_crops=2)
    dataloader = torch.utils.data.DataLoader(dino_dataset, batch_size=64, shuffle=True, num_workers=4, collate_fn=lambda x: list(zip(*x))[0])

    run_name = f'finetuning_dino'
    num_epochs = 2
    config = {
        "encoder" : "vit_small_patch8_224",
        "num_epochs" : num_epochs
    }

    logger = wandb.init(project = 'CV_frameworks', config = config, name = run_name)

    student = get_vit_model()
    teacher = copy.deepcopy(student)
    for p in teacher.parameters():
        p.requires_grad = False

    loss_fn = DINOLoss(out_dim=384)  # ViT-small has 384 dim
    optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4, weight_decay=0.1)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}")
        epoch_loss = train_one_epoch(student, teacher, loss_fn, optimizer, dataloader)
        logger.log({
                    "epoch_loss": epoch_loss
                    },
                    step = epoch + 1
        )

    logger.finish()
    torch.save(student.state_dict(), "finetuned_dino_student.pth")


# main()


In [4]:
student = get_vit_model().to(device)
student.load_state_dict(torch.load("finetuned_dino_student.pth"))
student.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(8, 8), stride=(8, 8))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
  

In [5]:
dinov2_transform = transforms.Compose([
    transforms.Pad((96, 96)),  # (224-32)/2 = 96 pixels padding
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Transform for original CIFAR-10 (just ToTensor to get raw pixels)
original_transform = transforms.ToTensor()

# Load dataset twice (once for DINOv2, once for original)
cifar_dinov2 = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=dinov2_transform,
)

cifar_original = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=original_transform,
)

# Create DataLoaders
loader_dinov2 = torch.utils.data.DataLoader(
    cifar_dinov2,
    batch_size=32,
    num_workers=4,
    pin_memory=True,
)

loader_original = torch.utils.data.DataLoader(
    cifar_original,
    batch_size=32,
    num_workers=4,
    pin_memory=True,
)

embeddings, labels, original_images = encode_set(student, loader_dinov2, loader_original, device=device)


torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size

In [7]:
correlation_dissimilarity(embeddings[:5000], original_images[5000:10000])
# train_linear_classifier(embeddings, labels)

np.float64(-0.0005523259521992269)

In [8]:
run_name = f'dino_metric'
config = {
    "encoder" : "dino",
    "type_log" : "metric",
}

logger = wandb.init(project = 'CV_frameworks', config = config, name = run_name)

run(student, loader_dinov2, loader_original, logger, device=device)
logger.finish()

[34m[1mwandb[0m: Currently logged in as: [33mms-hate-life[0m ([33mms-hate-life-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size([32, 384])
torch.Size

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


0,1
classification_accuracy,▁
second_order_similarity,▁

0,1
classification_accuracy,0.2175
second_order_similarity,0.11566


In [24]:
from bokeh.plotting import figure, show
from bokeh.io import output_notebook
from bokeh.models import ColumnDataSource, ColorBar
from bokeh.transform import linear_cmap, factor_cmap
from bokeh.palettes import Viridis256
from bokeh.layouts import row

import plotly.express as px

def embedding_plotter(embedding, data=None, hue=None, hover=None, tools = None, nv_cat = 5, height = 400, width = 400, display_result=True):
    '''
    Рисовалка эмбеддинга. 2D renderer: bokeh. 3D renderer: plotly.
    Обязательные инструменты:
        - pan (двигать график)
        - box zoom
        - reset (вылезти из зума в начальное положение)

        embedding: something 2D/3D, slicable ~ embedding[:, 0] - валидно
            Эмбеддинг
        data: pd.DataFrame
            Данные, по которым был построен эмбеддинг
        hue: string
            Колонка из data, по которой красим точки. Поддерживает интерактивную легенду: по клику на каждое
                значение hue можно скрыть весь цвет.
        hover: string or list of strings
            Колонк[а/и] из data, значения которых нужно выводить при наведении мышки на точку
        nv_cat: int
            number of unique values to consider column categorical
        tools: iterable or string in form "tool1,tool2,..." or ["tool1", "tool2", ...]
            tools for the interactive plot
        height, width: int
            parameters of the figure
        display_result: boolean
            if the results are displayed or just returned

    '''
    if tools is None:
        tools = 'lasso_select,box_select,pan,zoom_in,zoom_out,reset,hover'
    else:
        if hover and not("hover" in tools):
            tools = 'hover,'+",".join(tools)


    if embedding.shape[1] == 3:
        if hover:
            hover_data = {h:True for h in hover}
        else:
            hover_data = None
        df = pd.DataFrame(embedding, columns = ['x', 'y', 'z'])
        df = pd.concat((df, data), axis=1)
        fig = px.scatter_3d(
            data_frame = df,
            x='x',
            y='y',
            z='z',
            color=df[hue],
            hover_data = hover_data
        )

        fig.update_layout(
            modebar_add=tools.split(","),
        )

        fig.update_traces(marker_size=1, selector=dict(type='scatter3d'))

        if display_result: fig.show()

    if embedding.shape[1] == 2:
        output_notebook()
        df = pd.DataFrame(embedding, columns = ['x', 'y'])
        df = pd.concat((df, data), axis=1)
        tooltips = [
            ('x, y', '$x, $y'),
            ('index', '$index')
        ]
        if hover:
            for col in hover:
                tooltips.append((col, "@"+col))
        fig = figure(tools=tools, width=width, height=height, tooltips=tooltips)
        if df[hue].nunique() < nv_cat or df[hue].dtype == "category":
            df[hue] = df[hue].astype(str)
            source = ColumnDataSource(df)
            color_mapper = factor_cmap(
            field_name=hue,
            palette='Category10_3',
            factors=df[hue].unique()
            )
            fig.scatter(
            x='x', y='y',
            color=color_mapper,
            source=source,
            legend_group=hue)

            fig.legend.location = 'bottom_left'
            fig.legend.click_policy = 'mute'
        else:
            source = ColumnDataSource(df)
            color_mapper = linear_cmap(
                field_name=hue,
                palette=Viridis256,
                low=min(df[hue]),
                high=max(df[hue]))
            fig.scatter(
                x='x', y='y',
                color=color_mapper,
                source=source)
            color_bar = ColorBar(color_mapper=color_mapper['transform'], width=8, location=(0,0), title = hue)
            fig.add_layout(color_bar, 'right')


        if display_result: show(fig)

    if embedding.shape[1] > 3:
        print("wrong species, doooooodes")
    else: return fig

import pandas as pd
import numpy as np
from sklearn.manifold import TSNE

# Assuming your embeddings are in a variable called 'embeddings'
# embeddings = np.random.rand(10000, 384)  # Example - replace with your actual embeddings

# Create t-SNE
tsne = TSNE(n_components=2, random_state=1,
            init='pca', n_iter=1000,
            metric='euclidean')

# Fit and transform your data
tsne_results = tsne.fit_transform(embeddings[:1000])

# Prepare the data DataFrame correctly
data_df = pd.DataFrame({
    'label': np.array(labels[:1000])  # Assuming you have labels
    # Add any other columns you want for hover information
})

# Call the plotting function correctly
embedding_plotter(
    embedding=tsne_results,  # This should be your 2D t-SNE results (1000x2 array)
    data=data_df,            # This contains your labels and other metadata
    hue='label',             # Column name in data_df to use for coloring
)






