In [None]:
import pytesseract
import torch
import torch.nn as nn
from transformers import ViTModel, ViTImageProcessor
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import csv
from tqdm import tqdm

## Setup Decoder
The decoder is set up with a fixed output size, so during training all images are scaled to that size.

In [None]:
class SimpleDecoder(nn.Module):
    """Convolutional decoder that reconstructs an image from a single embedding vector.

    The embedding is projected to a 64-channel 14x14 feature map, upsampled by
    four stride-2 transposed convolutions (14 -> 28 -> 56 -> 112 -> 224) and
    reduced to ``channels`` output channels. If ``img_size`` is not 224 the
    result is bilinearly resized, so the output always matches ``img_size``.

    Args:
        embedding_dim (int): Size of the input embedding vector.
        img_size (int | tuple[int, int]): Output height/width. An int means a
            square image; a ``(height, width)`` pair is also accepted.
        channels (int): Number of output channels (3 for RGB).
    """

    def __init__(self, embedding_dim=768, img_size=224, channels=3):
        super().__init__()

        self.img_size = img_size
        self.channels = channels
        # Spatial size and channel count of the feature map produced by the
        # linear projection, before any upsampling.
        self.initial_resolution = 14
        self.initial_channels = 64

        # Map the embedding vector to a 14x14 feature map with 64 channels.
        self.linear = nn.Linear(embedding_dim, self.initial_channels * self.initial_resolution * self.initial_resolution)

        # kernel_size=4, stride=2, padding=1 exactly doubles the spatial size.
        self.up1 = nn.ConvTranspose2d(self.initial_channels, self.initial_channels, kernel_size=4, stride=2, padding=1)  # 14->28
        self.up2 = nn.ConvTranspose2d(self.initial_channels, self.initial_channels, kernel_size=4, stride=2, padding=1)  # 28->56
        self.up3 = nn.ConvTranspose2d(self.initial_channels, self.initial_channels, kernel_size=4, stride=2, padding=1)  # 56->112
        self.up4 = nn.ConvTranspose2d(self.initial_channels, self.initial_channels, kernel_size=4, stride=2, padding=1)  # 112->224

        # Final convolution reduces the feature channels to the image channels.
        self.final_conv = nn.Conv2d(self.initial_channels, channels, kernel_size=3, stride=1, padding=1)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Decode embeddings into images.

        Args:
            x (torch.Tensor): Either ``[batch, embedding_dim]`` (one vector per
                image) or ``[batch, seq_len, embedding_dim]``, in which case the
                first sequence element ``x[:, 0, :]`` is used.

        Returns:
            torch.Tensor: ``[batch, channels, H, W]`` with ``(H, W)`` given by
            ``img_size``. No output activation is applied.
        """
        # Accept an already-pooled embedding as well as a token sequence.
        global_embedding = x if x.dim() == 2 else x[:, 0, :]

        # Project to the low-resolution feature map.
        out = self.linear(global_embedding)  # [batch, 64 * 14 * 14]
        out = out.view(-1, self.initial_channels, self.initial_resolution, self.initial_resolution)

        # Upsample step by step.
        out = self.relu(self.up1(out))  # [batch, 64, 28, 28]
        out = self.relu(self.up2(out))  # [batch, 64, 56, 56]
        out = self.relu(self.up3(out))  # [batch, 64, 112, 112]
        out = self.relu(self.up4(out))  # [batch, 64, 224, 224]

        out = self.final_conv(out)  # [batch, channels, 224, 224]

        # Honour the requested output size; a no-op for the default 224.
        if isinstance(self.img_size, int):
            target_hw = (self.img_size, self.img_size)
        else:
            target_hw = tuple(self.img_size)
        if tuple(out.shape[-2:]) != target_hw:
            out = nn.functional.interpolate(out, size=target_hw, mode="bilinear", align_corners=False)

        return out

## Load Foundation model
In the model I want to freeze the first layers so as not to lose the low-level feature extraction; tuning will happen on the last layers, which produce the final embedding.

In [None]:
import glob
from torchvision.models import convnext_small, ConvNeXt_Small_Weights


def load_latest_checkpoint(checkpoint_dir_model, checkpoint_dir_decoder, embedding_dim, img_size=420):
    """Build the encoder/decoder pair and restore the newest shared checkpoint, if any.

    Checkpoints are files named ``checkpoint_epoch_<N>.pt`` holding a dict with
    a ``'model_state_dict'`` entry. Only epochs present in BOTH directories are
    considered, so the encoder and decoder are always restored from the same
    epoch.

    Args:
        checkpoint_dir_model (str): Directory with encoder checkpoints.
        checkpoint_dir_decoder (str): Directory with decoder checkpoints.
        embedding_dim (int): Embedding size passed to ``SimpleDecoder``.
        img_size (int): Output size passed to ``SimpleDecoder``. The same value
            is used whether or not a checkpoint is found.

    Returns:
        tuple: ``(epoch, model, decoder)`` where ``epoch`` is 0 when nothing
        was restored.
    """

    def _epochs_in(directory):
        # Collect the epoch numbers encoded in the checkpoint file names.
        epochs = set()
        for path in glob.glob(f"{directory}/checkpoint_epoch_*.pt"):
            stem = path.rsplit('_', 1)[-1].split('.')[0]
            if stem.isdigit():
                epochs.add(int(stem))
        return epochs

    model = convnext_small(weights=ConvNeXt_Small_Weights.DEFAULT)
    decoder = SimpleDecoder(embedding_dim, img_size=img_size)

    # Using the intersection avoids picking an epoch that exists in only one
    # of the two directories.
    shared_epochs = _epochs_in(checkpoint_dir_model) & _epochs_in(checkpoint_dir_decoder)
    if not shared_epochs:
        print("No checkpoints found")
        return 0, model, decoder

    latest_epoch = max(shared_epochs)
    latest_model_checkpoint = f"{checkpoint_dir_model}/checkpoint_epoch_{latest_epoch}.pt"
    latest_decoder_checkpoint = f"{checkpoint_dir_decoder}/checkpoint_epoch_{latest_epoch}.pt"

    print(f"Loading latest model checkpoint: {latest_model_checkpoint}")
    print(f"Loading latest decoder checkpoint: {latest_decoder_checkpoint}")

    # map_location="cpu" lets a checkpoint written on a GPU machine load on a
    # CPU-only one; the modules are moved to the target device later.
    model_checkpoint = torch.load(latest_model_checkpoint, map_location="cpu")
    decoder_checkpoint = torch.load(latest_decoder_checkpoint, map_location="cpu")

    model.load_state_dict(model_checkpoint['model_state_dict'])
    decoder.load_state_dict(decoder_checkpoint['model_state_dict'])

    print(f"Resumed from epoch {latest_epoch}")
    return latest_epoch, model, decoder


In [None]:
import torch
# Release cached, currently unused CUDA memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [None]:
# Directories searched for `checkpoint_epoch_<N>.pt` files.
checkpoint_dir_model = "./checkpoint_model"
checkpoint_dir_decoder = "./checkpoint_decoder"
# Embedding width handed to the decoder; a later cell re-reads this value
# from the loaded model's classifier head.
embedding_dim = 768

# start_epoch is 0 when no checkpoint could be restored.
start_epoch, model, decoder = load_latest_checkpoint(
    checkpoint_dir_model,
    checkpoint_dir_decoder,
    embedding_dim,
)

In [None]:
# Count how many layers (blocks) the ConvNeXt model has
num_stages = len(model.features)  # Number of stages
# An nn.Sequential entry contributes one unit per direct child; any other
# module counts as a single unit. The freezing cell walks the same structure.
num_layers = sum(len(stage) if isinstance(stage, torch.nn.Sequential) else 1 for stage in model.features)

num_layers


In [None]:
# Number of leading units (in traversal order) whose weights stay fixed.
freeze_layers = int(0.75 * num_layers)

# Running index of the unit being visited.
current_layer = 0

# Walk model.features in order: children of an nn.Sequential are visited one
# by one, any other module is treated as a single unit. Units with an index
# below `freeze_layers` are frozen, the remaining ones are made trainable.
for stage in model.features:
    units = stage if isinstance(stage, torch.nn.Sequential) else [stage]
    for unit in units:
        trainable = current_layer >= freeze_layers
        for param in unit.parameters():
            param.requires_grad = trainable
        current_layer += 1

In [None]:
# Input width of the last classifier layer, used here as the embedding size.
embedding_dim = model.classifier[-1].in_features
embedding_dim

In [None]:
from torchvision import transforms

def convNeXt_processor(resize_size=420):
    """Build the preprocessing pipeline: tensor conversion, then per-channel normalisation.

    ``resize_size`` is currently unused because the resize step is disabled.
    The mean/std constants match the commonly used ImageNet statistics.
    """
    steps = [
        # transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)

processor = convNeXt_processor()

## Load Dataset

In [None]:
import os

# Dataset location. Set the TRAIN_SET_DIR environment variable (for example to
# "./nov27-train_set-labelled") to run on another machine; the fallback is the
# original author's local path.
folder_path = os.environ.get(
    "TRAIN_SET_DIR",
    "/Users/antonnovokhatskiy/Desktop/brocvoli/nov27-train_set-labelled",
)


# pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'

# Sorted so image indices used later in the notebook are reproducible;
# os.listdir returns entries in arbitrary order.
image_files = sorted(
    f for f in os.listdir(folder_path) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))
)

image_arrays = []
image_labels = []

for file_name in tqdm(image_files, desc="Loading Images"):
    image_path = os.path.join(folder_path, file_name)
    with Image.open(image_path) as img:
        img = img.convert("RGB")
        img_array = np.array(img, dtype=np.uint8)

        # The top-left 60x400 pixel strip is read with OCR to obtain the label.
        class_img = img_array[0:60, 0:400]
        # Strip surrounding whitespace so identical labels are not split into
        # separate classes by trailing newlines in the OCR output.
        label = pytesseract.image_to_string(class_img, lang="eng").strip()

        # Drop the top 60 rows so the label strip is not part of the image data.
        img_array = img_array[60:, :, :]
        image_arrays.append(img_array)
        image_labels.append(label)

images = image_arrays

In [None]:
# csv_file = "data.csv"
# 
# with open(csv_file, "r") as file:
#     reader = csv.reader(file)
#     shape = tuple(map(int, next(reader)))
#     flat_images = [list(map(int, row)) for row in tqdm(reader)]
# 
# images = [np.array(flat_image, dtype=np.uint8).reshape(shape) for flat_image in tqdm(flat_images)]

In [None]:
from torch.utils.data import Dataset, DataLoader, SequentialSampler
from PIL import Image
import numpy as np

class ResizingDataset(Dataset):
    """In-memory image dataset that resizes every image on access.

    Items are ``(uint8 image array resized to target_size, label)`` pairs.
    """

    def __init__(self, images, labels, target_size=(224, 224)):
        """
        Args:
            images (list): Image arrays.
            labels (list): Labels, one per image.
            target_size (tuple): Size handed to ``PIL.Image.resize``.
        """
        self.images, self.labels, self.target_size = images, labels, target_size

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        raw, label = self.images[idx], self.labels[idx]

        # Round-trip through PIL to resize, then back to a uint8 array.
        scaled = Image.fromarray(raw).resize(self.target_size)
        return np.array(scaled, dtype=np.uint8), label

    def get_original(self, idx):
        """Return the stored (un-resized) image and its label."""
        return self.images[idx], self.labels[idx]


# Every image is resized to 420x420 when it is fetched from the dataset.
dataset = ResizingDataset(images, image_labels, target_size=(420, 420))

# Sequential sampling keeps batches in dataset order, so embeddings computed
# from `loader` line up index-for-index with `dataset.images`.
sampler = SequentialSampler(dataset)
loader = DataLoader(dataset, sampler=sampler, batch_size=32)



In [None]:
# Preview sample 33 after resizing.
image1, label1 = dataset[33]
Image.fromarray(image1)

In [None]:
# Preview the same sample at its stored (un-resized) size for comparison.
image2, label2 = dataset.get_original(33)
Image.fromarray(image2)

## Visual setup

In [None]:
def tensor_to_pil(tensor):
    """
    Convert a [C, H, W] tensor in [0, 1] range to a PIL Image.

    The tensor is processed on whatever device it already lives on (CPU or
    GPU) and copied to the CPU once at the end, so this also works on
    machines without CUDA. Values are scaled by 255, clamped to [0, 255] and
    truncated to uint8.
    """
    with torch.no_grad():
        # detach() so tensors that are part of an autograd graph can be converted.
        scaled = torch.clamp(tensor.detach() * 255.0, 0, 255)

        # [C, H, W] -> [H, W, C], then a single blocking copy to host memory.
        # The copy must complete before .numpy() reads the data, so it is not
        # issued with non_blocking=True.
        hwc = scaled.permute(1, 2, 0).cpu()

        np_image = hwc.numpy().astype(np.uint8)
        pil_image = Image.fromarray(np_image)

    return pil_image

def show_image(original, transformed, title="Image"):
    """
    Show two image tensors next to each other: the input on the left and the
    transformed/reconstructed version on the right.

    Both tensors are converted with ``tensor_to_pil`` before plotting.
    """
    # Convert both tensors first, then build the figure.
    panels = (
        ("Original", tensor_to_pil(original)),
        ("Transformed", tensor_to_pil(transformed)),
    )

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    for axis, (caption, picture) in zip(axs, panels):
        axis.imshow(picture)
        axis.set_title(caption)
        axis.axis('off')

    plt.suptitle(title)
    plt.show()

    # Drop the local references and release cached GPU memory.
    del original, transformed
    torch.cuda.empty_cache()

In [None]:
# Fetch sample 33 (resized image and its OCR label) for inspection below.
image, label = dataset[33]

# image = image.resize((224, 224))

In [None]:
# Show the OCR label together with the image it was read from.
print(label)
Image.fromarray(image)

## Initial Performance

In [None]:
# Shallow copy: a new list that still references the same image arrays.
images_copy = dataset.images.copy()

# Container for the baseline (pre-fine-tuning) state; "embeddings" and
# "projection" are filled in by later cells.
im_vec_initial = {
    "images": images_copy,
    "labels": image_labels,
    "embeddings": [],
    "projection": []
}

In [None]:
import time

def compute_emebddings():
    """Embed every image in ``loader`` with the current ``model``.

    Puts the model in eval mode and moves it to the GPU when available. Each
    batch is scaled to [0, 1], reordered to channels-first and passed through
    ``model.features`` and ``model.avgpool``; the flattened result is the
    embedding. Returns a list with one NumPy vector per image, in loader order.
    """
    model.eval()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model.to(device)

    collected = []
    with torch.no_grad():
        for batch_x, batch_y in tqdm(loader, desc="Processing Images"):
            # uint8 [B, H, W, C] -> float [B, C, H, W] in [0, 1]
            pixels = (batch_x.float() / 255.0).permute(0, 3, 1, 2)
            pixels = pixels.to(device, non_blocking=True)

            pooled = model.avgpool(model.features(pixels))
            vectors = torch.flatten(pooled, 1)

            collected.extend(vectors.detach().cpu().numpy())

            # Free per-batch tensors before the next iteration.
            del vectors, pooled, pixels
            torch.cuda.empty_cache()

    return collected

In [None]:
# Baseline embeddings from the model as loaded, before the training cells run.
im_vec_initial["embeddings"] = compute_emebddings()

In [None]:
from sklearn.preprocessing import LabelEncoder
import numpy as np

# Initialize LabelEncoder
label_encoder = LabelEncoder()

# Encode from the raw OCR strings in `image_labels` rather than from
# im_vec_initial["labels"], which this cell overwrites. Re-running the cell
# therefore always produces the same mapping.
encoded_labels = label_encoder.fit_transform(image_labels)

# Store the integer labels. They are also exposed under "encoded_labels"
# because the training cell looks them up by that key.
im_vec_initial["labels"] = encoded_labels
im_vec_initial["encoded_labels"] = encoded_labels

# Optionally, store the mapping for future reference
label_mapping = dict(zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_)))
print("Label Mapping:", label_mapping)


In [None]:
def k_means(embeddings, labels, projection, k=3, max_iters=100, tol=1e-4, show_plot=False):
    """Cluster embeddings with k-means and score the clusters against known labels.

    Each cluster is mapped to the most frequent true label among its members
    (majority vote); accuracy is the fraction of samples whose mapped cluster
    label equals their true label.

    Args:
        embeddings (list | np.ndarray | torch.Tensor): ``[N, D]`` vectors.
        labels (list | np.ndarray | torch.Tensor): ``N`` non-negative integer
            class labels.
        projection (list | np.ndarray | torch.Tensor | None): ``[N, 2]``
            coordinates used only for plotting; required when ``show_plot``.
        k (int): Number of clusters, ``1 <= k <= N``.
        max_iters (int): Maximum number of centroid updates.
        tol (float): Stop when the mean centroid shift drops below this value.
        show_plot (bool): Plot clusters and true labels side by side.

    Returns:
        dict: ``cluster_labels`` (list of mapped labels per sample),
        ``accuracy``, ``centroids`` (CPU tensor) and ``label_mapping``
        (cluster index -> label, ``-1`` for an empty cluster).

    Raises:
        ValueError: If ``k`` is out of range, or ``show_plot`` is set without a
            usable projection.
    """
    # Ensure embeddings are a float torch.Tensor
    if isinstance(embeddings, list):
        # Going through one ndarray is much faster than building a tensor
        # from a list of arrays element by element.
        embeddings = torch.tensor(np.array(embeddings), dtype=torch.float32)
    elif isinstance(embeddings, np.ndarray):
        embeddings = torch.from_numpy(embeddings).float()
    elif isinstance(embeddings, torch.Tensor):
        embeddings = embeddings.float()

    # Ensure labels are a torch.Tensor
    if isinstance(labels, list):
        labels = np.array(labels)
    if isinstance(labels, np.ndarray):
        labels = torch.from_numpy(labels)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    embeddings = embeddings.to(device)
    labels = labels.to(device)

    N, D = embeddings.shape
    if not 1 <= k <= N:
        raise ValueError(f"k must be between 1 and the number of samples ({N}), got {k}")

    # A private generator keeps the initialisation reproducible without
    # reseeding the global torch RNG, which would also affect training.
    rng = torch.Generator().manual_seed(33)
    indices = torch.randperm(N, generator=rng)[:k]
    centroids = embeddings[indices].clone()

    # Loop-invariant part of (a - b)^2 = a^2 + b^2 - 2ab.
    embeddings_squared = torch.sum(embeddings ** 2, dim=1, keepdim=True)  # [N, 1]

    def _assign(current_centroids):
        # Index of the nearest centroid for every embedding.
        centroids_squared = torch.sum(current_centroids ** 2, dim=1)  # [k]
        distances = embeddings_squared + centroids_squared - 2 * torch.matmul(embeddings, current_centroids.t())  # [N, k]
        return torch.argmin(distances, dim=1)  # [N]

    for itr in tqdm(range(max_iters), desc="K-Means Clustering"):
        cluster_assignments = _assign(centroids)

        # New centroid = mean of the members of each cluster.
        new_centroids = torch.zeros_like(centroids)
        for cluster in range(k):
            members = embeddings[cluster_assignments == cluster]
            if members.shape[0] > 0:
                new_centroids[cluster] = members.mean(dim=0)
            else:
                # A cluster that lost all members is re-seeded from a random sample.
                new_centroids[cluster] = embeddings[torch.randint(0, N, (1,), generator=rng).item()]

        centroid_shift = torch.norm(new_centroids - centroids, dim=1).mean().item()
        centroids = new_centroids
        if centroid_shift < tol:
            print(f"Converged at iteration {itr}")
            break

    # Final assignment against the returned centroids. Computing it here
    # (rather than reusing a loop variable) also covers the cases where the
    # loop ends without converging or does not run at all.
    cluster_assignments = _assign(centroids)

    # Move data back to CPU for further processing
    centroids = centroids.to('cpu')
    cluster_assignments = cluster_assignments.to('cpu')
    labels = labels.to('cpu')

    # Assign labels to clusters based on majority vote
    label_mapping = {}
    for cluster in range(k):
        assigned_labels = labels[cluster_assignments == cluster]
        if len(assigned_labels) == 0:
            label_mapping[cluster] = -1  # Unassigned
            continue
        label_mapping[cluster] = torch.bincount(assigned_labels).argmax().item()

    assigned_cluster_labels = [label_mapping[cluster] for cluster in cluster_assignments.tolist()]

    accuracy = np.mean(np.array(assigned_cluster_labels) == labels.numpy())

    # Optionally, plot the clusters using the supplied 2-D projection
    if show_plot:
        if projection is None:
            raise ValueError("UMAP projections not provided. Please compute and pass the projection.")

        if isinstance(projection, torch.Tensor):
            projection = projection.cpu().numpy()
        elif isinstance(projection, list):
            projection = np.array(projection)
        elif not isinstance(projection, np.ndarray):
            raise ValueError("Projection must be a numpy array or torch.Tensor")

        plt.figure(figsize=(12, 6))

        # Plot K-Means Clusters
        plt.subplot(1, 2, 1)
        scatter = plt.scatter(projection[:, 0], projection[:, 1], c=assigned_cluster_labels, cmap='viridis', alpha=0.6)
        plt.title("K-Means Clustering Results (UMAP Projection)")
        plt.xlabel("UMAP Component 1")
        plt.ylabel("UMAP Component 2")
        plt.colorbar(scatter, label='Cluster Label')

        # Plot Original Classes
        plt.subplot(1, 2, 2)
        scatter = plt.scatter(projection[:, 0], projection[:, 1], c=labels.numpy(), cmap='viridis', alpha=0.6)
        plt.title("Original Class Labels (UMAP Projection)")
        plt.xlabel("UMAP Component 1")
        plt.ylabel("UMAP Component 2")
        plt.colorbar(scatter, label='Original Label')

        plt.tight_layout()
        plt.show()

    # Clean up GPU memory
    del embeddings, embeddings_squared, cluster_assignments, labels
    torch.cuda.empty_cache()

    return {
        "cluster_labels": assigned_cluster_labels,
        "accuracy": accuracy,
        "centroids": centroids,
        "label_mapping": label_mapping
    }
    

In [None]:
from umap import UMAP
import numpy as np

def compute_umap_embeddings(embeddings, dimensions=2):
    """Reduce embeddings to ``dimensions`` components with UMAP.

    Args:
        embeddings (array-like): The original high-dimensional embeddings.
        dimensions (int): Number of output components (default 2).

    Returns:
        The result of fitting a UMAP reducer (``random_state=42``) on
        ``embeddings`` and transforming them.
    """
    return UMAP(n_components=dimensions, random_state=42).fit_transform(embeddings)


In [None]:
# Stack the baseline embeddings into one 2-D array for UMAP.
embeddings = np.array(im_vec_initial["embeddings"])
# Note: rebinds the notebook-level name `images` to the baseline image list.
images = im_vec_initial["images"]

reduced_embeddings = compute_umap_embeddings(embeddings, dimensions=2)

# Keep the 2-D projection with the baseline data for the plots below.
im_vec_initial["projection"] =reduced_embeddings

In [None]:
# Baseline clustering quality of the embeddings before fine-tuning.
results = k_means(
    embeddings=im_vec_initial["embeddings"],
    labels=im_vec_initial["labels"],
    projection=im_vec_initial["projection"],
    k=3,           
    max_iters=100,    
    tol=1e-2,         
    show_plot=True   
)

print(f"Clustering Accuracy: {results['accuracy'] * 100:.2f}%")

In [None]:
import plotly.graph_objects as go
from IPython.display import display
import ipywidgets as widgets
from PIL import Image
import numpy as np
import io

def plot_umap_with_images(reduced_embeddings, images):
    """
    Show an interactive 2-D scatter of the embeddings next to an image grid.

    Clicking a point toggles it; a lasso/box selection replaces the current
    selection. Selected points turn red and their images are rendered in the
    grid on the right.

    Args:
        reduced_embeddings (np.ndarray): Array of shape (n_samples, 2).
        images (list): Images as NumPy arrays, one per embedding.

    Raises:
        ValueError: If the inputs fail the type, shape or length checks.
    """
    # ---- input validation ----
    if not isinstance(reduced_embeddings, np.ndarray):
        raise ValueError("reduced_embeddings must be a NumPy array.")
    if reduced_embeddings.shape[1] != 2:
        raise ValueError("reduced_embeddings must have shape (n_samples, 2).")
    if len(reduced_embeddings) != len(images):
        raise ValueError("The number of embeddings must match the number of images.")

    # Indices of the currently highlighted points; shared by all callbacks.
    selected_indices = set()

    # ---- scatter plot ----
    scatter_trace = go.Scatter(
        x=reduced_embeddings[:, 0],
        y=reduced_embeddings[:, 1],
        mode='markers',
        marker=dict(
            size=3,
            color=['blue'] * len(images),  # everything starts unselected
            opacity=0.7
        ),
        customdata=list(range(len(images))),  # image index per point
        hoverinfo='text',
        hovertext=[f'Image Index: {i}' for i in range(len(images))]
    )
    fig = go.FigureWidget(data=scatter_trace)

    fig.update_layout(
        title="UMAP Visualization",
        width=800,
        height=800,
        xaxis_title="UMAP 1",
        yaxis_title="UMAP 2",
        showlegend=False,
        margin=dict(l=40, r=40, t=40, b=40),
        dragmode='lasso'  # lasso is the default drag tool
    )

    # ---- image panel ----
    image_output = widgets.Output(layout={
        'border': '1px solid black',
        'width': '1500px',
        'height': '800px',
        'overflow': 'auto'
    })

    def _png_bytes(img_array):
        # Encode one image array as PNG bytes, casting to uint8 when needed.
        if img_array.dtype != np.uint8:
            img_array = img_array.astype(np.uint8)
        buffer = io.BytesIO()
        Image.fromarray(img_array).save(buffer, format='PNG')
        return buffer.getvalue()

    def update_image_grid(indices):
        # Re-render the panel with the images for `indices`, in index order.
        with image_output:
            image_output.clear_output()
            if not indices:
                display(widgets.HTML("<b>No images selected.</b>"))
                return
            n_cols = 6
            img_widgets = [
                widgets.Image(
                    value=_png_bytes(images[idx]),
                    format='png',
                    width=398,
                    height=224
                )
                for idx in sorted(indices)
            ]
            grid = widgets.GridBox(
                img_widgets,
                layout=widgets.Layout(
                    grid_template_columns=f"repeat({n_cols}, 224px)",
                    grid_gap='10px'
                )
            )
            display(grid)

    def _refresh(trace):
        # Recolour the markers from the current selection and redraw the grid.
        with fig.batch_update():
            trace.marker.color = [
                'red' if i in selected_indices else 'blue' for i in range(len(images))
            ]
        update_image_grid(selected_indices)

    def on_click(trace, points, state):
        # Toggle every clicked point; a click that hits nothing is ignored.
        if not points.point_inds:
            return
        for idx in points.point_inds:
            if idx in selected_indices:
                selected_indices.discard(idx)
            else:
                selected_indices.add(idx)
        _refresh(trace)

    def on_select(trace, points, state):
        # A lasso/box selection replaces the old one; an empty one clears it.
        selected_indices.clear()
        selected_indices.update(points.point_inds or ())
        _refresh(trace)

    fig.data[0].on_click(on_click)
    fig.data[0].on_selection(on_select)

    # Placeholder shown until something is selected.
    with image_output:
        display(widgets.HTML("<b>No images selected.</b>"))

    # Plot on the left, image grid on the right.
    display(widgets.HBox([fig, image_output]))


In [None]:
# Interactive explorer for the baseline projection and its images.
plot_umap_with_images(reduced_embeddings, images)

In [None]:
import random

def similarity(query_index, data_dict, k=10):
    """Plot and return the images most similar to a query image.

    Similarity is the cosine similarity between embeddings. The query itself
    is never part of the result: ``k`` is capped at ``n_samples - 1``.

    Args:
        query_index (int): Index of the query image in ``data_dict``.
        data_dict (dict):
            "images": list of images as NumPy arrays.
            "embeddings": list of embedding vectors, one per image.
        k (int): Number of top similar images to return. Default is 10.

    Returns:
        list: ``(image, score)`` tuples, most similar first.
    """
    embeddings_matrix = np.array(data_dict["embeddings"])  # (n_samples, dim)

    query_embedding = embeddings_matrix[query_index]
    query_image = data_dict["images"][query_index]

    embeddings_norm = embeddings_matrix / np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
    query_norm = query_embedding / np.linalg.norm(query_embedding)

    cosine_similarities = np.dot(embeddings_norm, query_norm)  # (n_samples,)

    # Exclude the query itself
    cosine_similarities[query_index] = -np.inf

    # Without this cap a large k would pull the query (score -inf) back in.
    k = max(0, min(k, len(embeddings_matrix) - 1))

    # Descending order, then the first k. Unlike [-k:], this is correct for k == 0.
    top_k_indices = np.argsort(cosine_similarities)[::-1][:k]

    top_k_images = [data_dict["images"][i] for i in top_k_indices]
    top_k_scores = [cosine_similarities[i] for i in top_k_indices]

    plt.figure(figsize=(5 * (k + 1), 5))

    # Display the original (query) image
    plt.subplot(1, k + 1, 1)
    plt.imshow(query_image)
    plt.axis('off')
    plt.title("Original Image", fontsize=20)

    # Display the top k similar images, titled with their similarity score
    for idx, (img, score) in enumerate(zip(top_k_images, top_k_scores), start=2):
        plt.subplot(1, k + 1, idx)
        plt.imshow(img)
        plt.axis('off')
        plt.title(f"{score:.4f}", fontsize=20)

    plt.tight_layout()
    plt.show()

    return list(zip(top_k_images, top_k_scores))


def similarity_range(data_dict):
    """Return the (min, max) cosine similarity between the embeddings and an all-ones vector.

    The all-ones reference has the same width as the embeddings, so the
    function works for any embedding dimension. A narrow range means the
    embeddings point in nearly the same direction.

    Args:
        data_dict (dict): Must contain "embeddings", a sequence of equally
            sized vectors.

    Returns:
        tuple: ``(min_similarity, max_similarity)``.
    """
    embeddings_matrix = np.array(data_dict["embeddings"])  # (n_samples, dim)

    # Reference direction sized from the data instead of a fixed 768.
    query_embedding = np.ones(embeddings_matrix.shape[1], dtype=np.float32)

    embeddings_norm = embeddings_matrix / np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
    query_norm = query_embedding / np.linalg.norm(query_embedding)

    cosine_similarities = np.dot(embeddings_norm, query_norm)  # (n_samples,)

    return np.min(cosine_similarities), np.max(cosine_similarities)

# Baseline spread of cosine similarities (the same call is repeated in the next cell).
similarity_range(im_vec_initial)

In [None]:
# (min, max) cosine similarity of the baseline embeddings to the all-ones vector.
similarity_range(im_vec_initial)

## Similarity Range
So this is the highlight of the problem. Since the foundation model is trained and designed to be used on diverse datasets, when it is applied to a single dataset, the items from that dataset fall into a tiny region of the embedding space. This is what we intend to fix: we want to increase the range of similarity scores within the data, thus giving the embedding more room to represent the features that are relevant to this specific problem.

In [None]:
# Nearest neighbours for two example queries; `top_k_results` keeps only the
# result of the second call.
top_k_results = similarity(650, im_vec_initial, k=10)
top_k_results = similarity(66, im_vec_initial, k=10)

## Training

The idea is to fine-tune the foundation model in an unsupervised way using an encoder–decoder setup. The intuition is as follows: in this task the main goal is to overfit the data as much as possible. Since we don't use these embeddings in any multi-dataset tasks or meaning-based retrieval, we only want the embeddings to be well comparable to each other. This implies that if the decoder can overfit on the general structure of the dataset and the image view, then the embedding from the encoder will describe all the features that distinguish one image from another.

### Setup Train
Choose the parameters to train in both the encoder and the decoder.

In [None]:
# All decoder weights are trained.
params_to_optimize = list(decoder.parameters())

# Add the encoder weights that the freezing cell left trainable. The
# torchvision ConvNeXt has no `encoder.layer` attribute (that is the
# HuggingFace ViT layout), so the trainable subset is selected through
# `requires_grad` on `model.features`, the part used to produce embeddings.
params_to_optimize.extend(
    param for param in model.features.parameters() if param.requires_grad
)

In [None]:
# One Adam optimizer over the decoder and the selected encoder parameters.
optimizer = torch.optim.Adam(params_to_optimize, lr=1e-3)
# Pixel-wise mean squared error between reconstruction and input image.
criterion = nn.MSELoss()

In [None]:
# Check whether a CUDA device is visible to PyTorch.
torch.cuda.is_available()

In [None]:
import time
import os
import torch
from tqdm import tqdm  # Optional: For progress bars


def train(num_epochs, save_checkpoints=False):
    """
    Trains the encoder (``model``) and ``decoder`` as an autoencoder.

    Each batch is embedded with ``model.features`` + ``model.avgpool`` (the
    same path ``compute_emebddings`` uses), decoded back to an image, and
    optimised with ``criterion`` against the input. After every epoch the
    embeddings are recomputed, clustered with ``k_means`` and one
    input/reconstruction pair is displayed.

    Uses the notebook globals ``model``, ``decoder``, ``loader``,
    ``optimizer``, ``criterion``, ``im_vec_initial`` and the checkpoint
    directory names.

    Parameters:
    - num_epochs (int): The number of epochs to train the model.
    - save_checkpoints (bool): When True, write an encoder and a decoder
      checkpoint after every epoch in the format ``load_latest_checkpoint``
      reads. Epoch numbers in the file names start at 1 on every call.
    """

    os.makedirs(checkpoint_dir_model, exist_ok=True)
    os.makedirs(checkpoint_dir_decoder, exist_ok=True)

    # -------------------------------
    # 1. Setup Device (CUDA if available)
    # -------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model.to(device)
    decoder.to(device)

    # -------------------------------
    # 2. Training Loop Over Epochs
    # -------------------------------
    for epoch in range(1, num_epochs + 1):
        print(f"\nEpoch {epoch}/{num_epochs}")
        epoch_loss = 0.0  # To accumulate loss over the epoch

        # compute_emebddings() switches the model to eval mode, so training
        # mode has to be restored at the start of every epoch.
        model.train()
        decoder.train()

        # Last input/reconstruction pair of the epoch, kept for visualisation.
        last_input = None
        last_reconstruction = None

        progress_bar = tqdm(enumerate(loader), total=len(loader), desc=f"Training Epoch {epoch}")

        for batch_idx, data in progress_bar:
            # -------------------------------
            # 3. Prepare the Batch
            # -------------------------------
            batch_x, batch_y = data

            # uint8 [B, H, W, C] -> float [B, C, H, W] in [0, 1], on the device.
            batch_x = (batch_x.permute(0, 3, 1, 2).float() / 255.0).to(device)

            # The reconstruction target is the input itself.
            target = batch_x

            # -------------------------------
            # 4. Forward Pass
            # -------------------------------
            optimizer.zero_grad()

            pooled = model.avgpool(model.features(batch_x))
            embedding = torch.flatten(pooled, 1)  # [B, embedding_dim]

            # The decoder reads the first element along dim 1, so the single
            # embedding is passed as a length-1 sequence.
            reconstruction = decoder(embedding.unsqueeze(1))

            # Make sure reconstruction and target agree spatially before the
            # loss; a no-op when the decoder already outputs the input size.
            if reconstruction.shape[-2:] != target.shape[-2:]:
                reconstruction = nn.functional.interpolate(
                    reconstruction, size=target.shape[-2:], mode="bilinear", align_corners=False
                )

            # -------------------------------
            # 5. Compute Loss
            # -------------------------------
            loss = criterion(reconstruction, target)
            epoch_loss += loss.item()

            # -------------------------------
            # 6. Backward Pass and Optimization
            # -------------------------------
            loss.backward()
            optimizer.step()

            # -------------------------------
            # 7. Logging
            # -------------------------------
            progress_bar.set_postfix({"Loss": f"{loss.item():.4f}"})

            last_input = batch_x[0].detach()
            last_reconstruction = reconstruction[0].detach()

        # -------------------------------
        # 8. Epoch Summary
        # -------------------------------
        avg_loss_start = time.time()
        # max(..., 1) avoids a division by zero for an empty loader.
        avg_epoch_loss = epoch_loss / max(len(loader), 1)
        avg_loss_end = time.time()

        print(f"(Avg Loss Calculation: {avg_loss_end - avg_loss_start:.4f}s). Average Loss: {avg_epoch_loss:.4f}")

        # Evaluate (and optionally checkpoint) every epoch.
        if epoch % 1 == 0:
            if save_checkpoints:
                checkpoint_name = f'checkpoint_epoch_{epoch}.pt'
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_epoch_loss,
                }, os.path.join(checkpoint_dir_model, checkpoint_name))
                # The decoder checkpoint stores the decoder's own weights
                # under the key load_latest_checkpoint expects.
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': decoder.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_epoch_loss,
                }, os.path.join(checkpoint_dir_decoder, checkpoint_name))

            # Project and cluster the embeddings of the CURRENT model state.
            emebddings = compute_emebddings()
            projection = compute_umap_embeddings(np.array(emebddings))

            results = k_means(
                embeddings=emebddings,
                labels=im_vec_initial["labels"],
                projection=projection,
                k=3,
                max_iters=100,
                tol=1e-2,
                show_plot=True
            )
            print(f"Clustering Accuracy: {results['accuracy'] * 100:.2f}%")

            if last_reconstruction is not None:
                show_image(last_input, last_reconstruction)

        print('-' * 50)  # Separator for readability


In [None]:
# Run 40 training epochs; clustering results and a reconstruction are shown after each epoch.
train(40)

In [None]:
# Reconstruct the first batch with the current encoder/decoder pair. The
# training loop's `outputs` variable is local to train(), so the embeddings
# are recomputed here.
model.eval()
decoder.eval()
device = next(model.parameters()).device
decoder.to(device)

with torch.no_grad():
    sample_x, _ = next(iter(loader))
    # uint8 [B, H, W, C] -> float [B, C, H, W] in [0, 1]
    sample_x = (sample_x.float() / 255.0).permute(0, 3, 1, 2).to(device)
    sample_embedding = torch.flatten(model.avgpool(model.features(sample_x)), 1)
    # The decoder reads the first element along dim 1, hence unsqueeze(1).
    reconstructed_images = decoder(sample_embedding.unsqueeze(1))