In [None]:
import torch
import torch.nn as nn
from transformers import ViTModel, ViTImageProcessor
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import csv
from tqdm import tqdm

## Load Foundation model
I freeze the first layers of the model so that the low-level feature extraction is not lost; fine-tuning happens only on the last layers, which produce the final embedding.

In [None]:
# Load the pretrained ViT model
# `from_pretrained` resolves the checkpoint from its identifier string; the same
# identifier is used for the image processor further down.
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

In [None]:
# Count how many layers the encoder has
# The count drives how many leading layers get frozen in the next cell.
num_layers = len(model.encoder.layer)  # Typically 12 for vit-base

In [None]:
# Freeze the first 75% of the encoder layers, together with the input embeddings that
# sit in front of them, so the pretrained low-level feature extraction is preserved.
# Only the last 25% of the layers remain trainable.
freeze_layers = int(0.75 * num_layers)  # e.g., 9 if num_layers=12

# The embeddings feed the frozen layers and are never handed to the optimizer. Leaving
# them with requires_grad=True would make autograd backpropagate through every frozen
# layer just to compute gradients that nobody uses.
model.embeddings.requires_grad_(False)

for i, layer in enumerate(model.encoder.layer):
    # requires_grad_ sets the flag on every parameter of the layer in one call:
    # False for the first `freeze_layers` layers, True for the rest.
    layer.requires_grad_(i >= freeze_layers)

In [None]:
# Hidden size of the encoder; handed to the decoder as its input embedding width.
embedding_dim = model.config.hidden_size

In [None]:
# Image processor loaded from the same checkpoint identifier as the model; it prepares
# image batches as model inputs in the embedding pass below.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

## Setup Decoder
The decoder is set up with a fixed output size, so during training all images are scaled to that size.

In [None]:
class SimpleDecoder(nn.Module):
    """Decode the CLS embedding of a ViT back into a square image.

    The CLS vector is projected to a low-resolution feature map, upsampled by four
    stride-2 transposed convolutions (a total factor of 16) and finally mapped to the
    requested number of output channels by a 3x3 convolution.

    Args:
        embedding_dim (int): Size of each token embedding coming from the encoder.
        img_size (int): Side length of the square output image. Must be a positive
            multiple of 16, because the decoder upsamples by exactly 16x.
        channels (int): Number of output channels (3 for RGB).

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 16.
    """

    # Four stride-2 upsampling stages double the resolution four times: 2**4.
    UPSCALE_FACTOR = 16

    def __init__(self, embedding_dim=768, img_size=224, channels=3):
        super().__init__()

        if img_size <= 0 or img_size % self.UPSCALE_FACTOR != 0:
            raise ValueError(
                f"img_size must be a positive multiple of {self.UPSCALE_FACTOR}, got {img_size}."
            )

        self.img_size = img_size
        self.channels = channels
        # Derived from img_size so the output really has the requested size.
        # For the default 224x224 image this is the 14x14 ViT patch grid.
        self.initial_resolution = img_size // self.UPSCALE_FACTOR
        self.initial_channels = 64

        # Map from embedding vector to the initial feature map with 64 channels
        self.linear = nn.Linear(embedding_dim, self.initial_channels * self.initial_resolution * self.initial_resolution)

        # A series of ConvTranspose2d layers to gradually upscale, e.g. 14 -> 28 -> 56 -> 112 -> 224
        # kernel_size=4, stride=2, padding=1 doubles spatial dimensions
        self.up1 = nn.ConvTranspose2d(self.initial_channels, self.initial_channels, kernel_size=4, stride=2, padding=1)
        self.up2 = nn.ConvTranspose2d(self.initial_channels, self.initial_channels, kernel_size=4, stride=2, padding=1)
        self.up3 = nn.ConvTranspose2d(self.initial_channels, self.initial_channels, kernel_size=4, stride=2, padding=1)
        self.up4 = nn.ConvTranspose2d(self.initial_channels, self.initial_channels, kernel_size=4, stride=2, padding=1)

        # Final convolution to reduce channels to `channels` (3 for RGB output)
        self.final_conv = nn.Conv2d(self.initial_channels, channels, kernel_size=3, stride=1, padding=1)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Reconstruct images from encoder token embeddings.

        Args:
            x (torch.Tensor): Token embeddings of shape [batch_size, seq_len, embedding_dim];
                only the first token (CLS) of each sequence is used.

        Returns:
            torch.Tensor: Tensor of shape [batch_size, channels, img_size, img_size].
            No output activation is applied, so values are not bounded.
        """
        # Take the CLS embedding (x[:, 0, :]) as a global image representation
        cls_embedding = x[:, 0, :]  # [batch_size, embedding_dim]

        # Project to low-res feature map
        out = self.linear(cls_embedding)
        out = out.view(-1, self.initial_channels, self.initial_resolution, self.initial_resolution)

        # Upsample steps, each doubling the spatial resolution
        out = self.relu(self.up1(out))
        out = self.relu(self.up2(out))
        out = self.relu(self.up3(out))
        out = self.relu(self.up4(out))

        # Convert to the output channels
        out = self.final_conv(out)

        return out

In [None]:
# Build the decoder for the encoder's hidden size; img_size and channels keep their defaults.
decoder = SimpleDecoder(embedding_dim)

## Load Dataset

In [None]:
csv_file = "data.csv"

# Layout as consumed below: the first row holds the shape of a single image array,
# every following row holds one flattened image as integer pixel values.
# newline="" is what the csv module asks for when it is handed a file object.
with open(csv_file, "r", newline="") as file:
    reader = csv.reader(file)
    shape = tuple(map(int, next(reader)))
    flat_images = [list(map(int, row)) for row in tqdm(reader)]

# Rebuild each image as a uint8 array with the shape announced in the header row.
images = [np.array(flat_image, dtype=np.uint8).reshape(shape) for flat_image in tqdm(flat_images)]

In [None]:
from torch.utils.data import Dataset, DataLoader, SequentialSampler

class ResizingDataset(Dataset):
    """Dataset over in-memory image arrays that serves every item resized.

    Items are returned as ``uint8`` NumPy arrays obtained by resizing the stored array
    through PIL to ``target_size``. The stored arrays themselves are not modified and
    stay reachable through :meth:`get_original`.
    """

    def __init__(self, images, target_size=(224, 224)):
        self.target_size = target_size
        self.images = images

    def __len__(self):
        """Number of stored images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the image at ``idx`` resized to ``target_size`` as a uint8 array."""
        # Round-trip through PIL: array -> PIL image -> resized PIL image -> array.
        resized = Image.fromarray(self.images[idx]).resize(self.target_size)
        return np.array(resized, dtype=np.uint8)

    def get_original(self, idx):
        """Return the stored image at ``idx`` exactly as it was passed in."""
        return self.images[idx]

# Every item is served resized to 224x224.
dataset = ResizingDataset(images, target_size=(224, 224))
# A sequential sampler keeps the order of the batches aligned with the order of `images`,
# which the embedding pass relies on when pairing embeddings with images.
sampler = SequentialSampler(dataset)
loader = DataLoader(dataset, sampler=sampler, batch_size=4)

In [None]:
# Preview item 33 as served by the dataset (i.e. after resizing).
image1 = Image.fromarray(dataset[33])
image1

In [None]:
# Preview the same item as stored, before resizing, for comparison.
image2 = Image.fromarray(dataset.get_original(33))
image2

## Visual setup

In [None]:
def tensor_to_pil(tensor):
    """
    Convert an image tensor into a PIL image.

    Args:
        tensor (torch.Tensor): Tensor of shape [C, H, W]. Values are expected in
            [0, 1]; anything outside that range is clamped.

    Returns:
        PIL.Image: Converted image.
    """
    # Work on a CPU copy that is detached from the graph, limited to [0, 1].
    clamped = torch.clamp(tensor.cpu().detach(), 0, 1)

    # PIL wants channels last: [C, H, W] -> [H, W, C].
    channels_last = np.transpose(clamped.numpy(), (1, 2, 0))

    # Scale to the 0..255 byte range.
    pixels = (channels_last * 255).astype(np.uint8)

    return Image.fromarray(pixels)


def show_image(orginal, transformed, title="Image"):
    """
    Show an image and its transformed version side by side.

    Args:
        orginal (torch.Tensor): Reference image tensor of shape [3, H, W].
        transformed (torch.Tensor): Image tensor of shape [3, H, W] to compare with it.
        title (str): Title drawn above both panels.
    """
    panels = (("original", orginal), ("transformed", transformed))

    # Create a matplotlib figure with one panel per image
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))

    # Draw the caller-supplied title above both panels.
    fig.suptitle(title)

    for ax, (label, tensor) in zip(axs, panels):
        # Convert tensor to PIL Image before drawing
        ax.imshow(tensor_to_pil(tensor))
        ax.set_title(label)
        ax.axis('off')  # Hide axis

    # Show the figure inline
    plt.show()



In [None]:
# Item 33 as served by the dataset, wrapped as a PIL image.
image = Image.fromarray(dataset[33])

# Left disabled: the dataset was built with target_size=(224, 224), so its items are already resized.
# image = image.resize((224, 224))

In [None]:
# Bare expression so the notebook renders the image as the cell output.
image

## Initial Performance

In [None]:
# Shallow copy: a new list object that refers to the same image arrays as the dataset.
images_copy = dataset.images.copy()

# Images paired with the embeddings of the model before fine-tuning.
# "embeddings" starts empty and is filled by the embedding pass in the next cell.
im_vec_initial = {
    "images": images_copy,
    "embeddings": []
}

In [None]:
model.eval()

# Start from an empty list so that re-running this cell does not append a second copy
# of every embedding.
im_vec_initial["embeddings"] = []

# Pure inference: without no_grad every batch would build and keep an autograd graph
# through the trainable layers.
with torch.no_grad():
    for batch in tqdm(loader, desc="Processing Images"):
        # Put the inputs on the model's device, so the cell also works when it is
        # re-run after train() has moved the model.
        inputs = processor(images=batch, return_tensors="pt").to(model.device)
        outputs = model(**inputs, output_hidden_states=False)

        # The first token (CLS) of each sequence is used as the image embedding.
        embeddings = outputs.last_hidden_state[:, 0, :]

        im_vec_initial["embeddings"].extend(embeddings.detach().cpu().numpy())
    

In [None]:
from umap import UMAP
import numpy as np

def compute_umap_embeddings(embeddings, dimensions=2):
    """
    Reduce high-dimensional embeddings with UMAP.

    Args:
        embeddings (array-like): The original high-dimensional embeddings.
        dimensions (int): The number of dimensions for UMAP reduction (default is 2).

    Returns:
        The result of ``UMAP.fit_transform`` on ``embeddings``.
    """
    # The seed is pinned via random_state=42.
    return UMAP(n_components=dimensions, random_state=42).fit_transform(embeddings)


In [None]:
# Collect the per-image embedding vectors into a single array for UMAP.
embeddings = np.array(im_vec_initial["embeddings"])
# Note: this rebinds the notebook-level name `images` to the list held by im_vec_initial.
images = im_vec_initial["images"]

# 2-D projection used by the interactive scatter plot below.
reduced_embeddings = compute_umap_embeddings(embeddings, dimensions=2)

In [None]:
import plotly.graph_objects as go
from IPython.display import display
import ipywidgets as widgets
from PIL import Image
import numpy as np
import io

def plot_umap_with_images(reduced_embeddings, images):
    """
    Plots UMAP-reduced embeddings with interactive selection to show corresponding images in a grid.

    Clicking a point toggles it in the selection; a lasso or box selection replaces the
    current selection. Selected points are drawn red, all others blue, and the images of
    the selected points are listed in a grid next to the plot.

    Args:
        reduced_embeddings (np.ndarray): The reduced embeddings from UMAP with shape (n_samples, 2).
        images (list): List of images (as NumPy arrays) corresponding to the embeddings.

    Raises:
        ValueError: If ``reduced_embeddings`` is not a NumPy array, does not have shape
            (n_samples, 2), or its length differs from ``len(images)``.
    """
    # Validate inputs
    if not isinstance(reduced_embeddings, np.ndarray):
        raise ValueError("reduced_embeddings must be a NumPy array.")
    # Check the rank before touching shape[1]: on a 1-D array that lookup would raise
    # IndexError instead of the intended ValueError.
    if reduced_embeddings.ndim != 2 or reduced_embeddings.shape[1] != 2:
        raise ValueError("reduced_embeddings must have shape (n_samples, 2).")
    if len(reduced_embeddings) != len(images):
        raise ValueError("The number of embeddings must match the number of images.")

    n_points = len(images)

    # Initialize a set to keep track of selected indices
    selected_indices = set()

    # Create a FigureWidget for interactivity
    fig = go.FigureWidget(
        data=go.Scatter(
            x=reduced_embeddings[:, 0],
            y=reduced_embeddings[:, 1],
            mode='markers',
            marker=dict(
                size=3,
                color=['blue'] * n_points,  # Initial color for all points
                opacity=0.7
            ),
            customdata=list(range(n_points)),  # Store image indices
            hoverinfo='text',
            hovertext=[f'Image Index: {i}' for i in range(n_points)]
        )
    )

    # Update layout to make the plot square
    fig.update_layout(
        title="UMAP Visualization",
        width=800,
        height=800,
        xaxis_title="UMAP 1",
        yaxis_title="UMAP 2",
        showlegend=False,
        margin=dict(l=40, r=40, t=40, b=40),
        dragmode='lasso'  # Set default drag mode to lasso for multiple selections
    )

    # Create an Output widget to display images as a grid
    image_output = widgets.Output(layout={
        'border': '1px solid black',
        'width': '1500px',
        'height': '800px',
        'overflow': 'auto'
    })

    # Function to update image grid based on selected indices
    def update_image_grid(indices):
        with image_output:
            image_output.clear_output()
            if not indices:
                display(widgets.HTML("<b>No images selected.</b>"))
                return
            n_cols = 6
            img_widgets = []
            for idx in sorted(indices):
                img_array = images[idx]
                # Ensure the image is in uint8 format
                if img_array.dtype != np.uint8:
                    img_array = img_array.astype(np.uint8)
                img = Image.fromarray(img_array)
                # Encode as PNG bytes, which is what the Image widget consumes
                buffer = io.BytesIO()
                img.save(buffer, format='PNG')
                img_bytes = buffer.getvalue()
                # Create Image widget
                img_widget = widgets.Image(
                    value=img_bytes,
                    format='png',
                    width=398,
                    height=224
                )
                img_widgets.append(img_widget)
            # Create GridBox layout
            grid = widgets.GridBox(
                img_widgets,
                layout=widgets.Layout(
                    grid_template_columns=f"repeat({n_cols}, 224px)",
                    grid_gap='10px'
                )
            )
            display(grid)

    # Single place that maps the current selection to marker colors
    def recolor(trace):
        with fig.batch_update():
            trace.marker.color = [
                'red' if i in selected_indices else 'blue' for i in range(n_points)
            ]

    # Function to handle click events (toggle selection)
    def on_click(trace, points, state):
        if points.point_inds:
            for idx in points.point_inds:
                if idx in selected_indices:
                    selected_indices.remove(idx)
                else:
                    selected_indices.add(idx)
            recolor(trace)
            update_image_grid(selected_indices)

    # Function to handle selection events (lasso or box select)
    def on_select(trace, points, state):
        # Replace the current selection with the new one; an empty selection clears it.
        selected_indices.clear()
        selected_indices.update(points.point_inds or ())
        recolor(trace)
        update_image_grid(selected_indices)

    # Attach the click and select events to the scatter plot
    fig.data[0].on_click(on_click)
    fig.data[0].on_selection(on_select)

    # Initial display message
    with image_output:
        display(widgets.HTML("<b>No images selected.</b>"))

    # Layout the plot and image grid side by side
    hbox = widgets.HBox([fig, image_output])
    display(hbox)


In [None]:
# Interactive view of the initial embeddings: the scatter plot with the image grid beside it.
plot_umap_with_images(reduced_embeddings, images)

In [None]:
import random

def similarity(query_index, data_dict, k=10):
    """
    Show and return the images most similar to a query image by cosine similarity.

    Args:
        query_index (int): Index of the query image in the data_dict.
        data_dict (dict):
            "images": List of images as NumPy arrays.
            "embeddings": List of embedding vectors, one per image.
        k (int): Number of top similar images to return. Default is 10. Capped at the
            number of images other than the query.

    Returns:
        list[tuple]: ``(image, score)`` pairs ordered from most to least similar. The
        query image itself is never part of the result.
    """
    embeddings_matrix = np.array(data_dict["embeddings"])  # Shape: (n_samples, dim)

    query_embedding = embeddings_matrix[query_index]
    query_image = data_dict["images"][query_index]

    # Cosine similarity = dot product of L2-normalized vectors
    embeddings_norm = embeddings_matrix / np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
    query_norm = query_embedding / np.linalg.norm(query_embedding)

    cosine_similarities = np.dot(embeddings_norm, query_norm)  # Shape: (n_samples,)

    # Exclude the query itself
    cosine_similarities[query_index] = -np.inf

    # Never ask for more neighbours than there are other images; otherwise the excluded
    # query (score -inf) would end up among the results.
    k = max(0, min(k, len(cosine_similarities) - 1))

    # Get top k indices in descending order of similarity. Reversing before slicing
    # also behaves for k == 0, where a [-k:] slice would return everything.
    top_k_indices = np.argsort(cosine_similarities)[::-1][:k]

    # Retrieve top k images and their similarity scores
    top_k_images = [data_dict["images"][i] for i in top_k_indices]
    top_k_scores = [cosine_similarities[i] for i in top_k_indices]

    plt.figure(figsize=(5 * (k + 1), 5))

    # Display the original (query) image
    plt.subplot(1, k + 1, 1)
    plt.imshow(query_image)
    plt.axis('off')
    plt.title("Original Image", fontsize=20)

    # Display the top k similar images, titled with their score
    for idx, (img, score) in enumerate(zip(top_k_images, top_k_scores), start=2):
        plt.subplot(1, k + 1, idx)
        plt.imshow(img)
        plt.axis('off')
        plt.title(f"{score:.4f}", fontsize=20)

    plt.tight_layout()
    plt.show()

    # Combine images and scores
    top_k_results = list(zip(top_k_images, top_k_scores))

    return top_k_results


def similarity_range(data_dict):
    """
    Measure how widely the embeddings spread around a fixed reference direction.

    Every embedding is compared, by cosine similarity, with the all-ones vector of the
    same dimensionality.

    Args:
        data_dict (dict): Must contain "embeddings", a non-empty list (or array) of
            equally sized embedding vectors.

    Returns:
        tuple: ``(min, max)`` of the cosine similarities to the all-ones vector.

    Raises:
        ValueError: If the embeddings do not form a non-empty 2-D array.
    """
    embeddings_matrix = np.array(data_dict["embeddings"])  # Shape: (n_samples, dim)
    if embeddings_matrix.ndim != 2 or embeddings_matrix.size == 0:
        raise ValueError("data_dict['embeddings'] must be a non-empty list of equally sized vectors.")

    # Size the reference vector from the data so any embedding width works.
    query_embedding = np.ones(embeddings_matrix.shape[1], dtype=np.float32)

    embeddings_norm = embeddings_matrix / np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
    query_norm = query_embedding / np.linalg.norm(query_embedding)

    cosine_similarities = np.dot(embeddings_norm, query_norm)  # Shape: (n_samples,)

    return np.min(cosine_similarities), np.max(cosine_similarities)

In [None]:
# (min, max) cosine similarity of the initial embeddings to the all-ones vector.
similarity_range(im_vec_initial)

## Similarity Range
This is the core of the problem. Since the foundation model is trained and designed to be used on diverse datasets, when a single dataset is fed into it, its items fall into a tiny region of the embedding space. This is what we intend to fix: we want to increase the range of similarity scores across the data, and thus give the embedding more room to represent features that are relevant to this specific problem.

In [None]:
# Plot the nearest neighbours for three sample queries.
# Each call rebinds `top_k_results`, so only the results of the last query are kept.
top_k_results = similarity(1381, im_vec_initial, k=10)
top_k_results = similarity(66, im_vec_initial, k=10)
top_k_results = similarity(21, im_vec_initial, k=10)

## Training

The idea is to fine-tune the foundation model in an unsupervised way using an encoder–decoder setup. The intuition is as follows: in this task we try to overfit the data as much as possible, since we do not use these embeddings in any multi-dataset tasks or meaning-based retrieval; we only want the embeddings to be well comparable to each other. This implies that if the decoder can overfit on the main idea of the dataset and the image view, then the embedding from the encoder will describe all the features that distinguish one image from another.

### Setup Train
Choose the parameters to train in both the encoder and the decoder.

In [None]:
# Every decoder parameter is trained ...
params_to_optimize = list(decoder.parameters())

# ... together with the parameters of the ViT layers from index `freeze_layers` onward,
# i.e. the last 25% that were left trainable.
params_to_optimize += [
    param
    for i, layer in enumerate(model.encoder.layer)
    if i >= freeze_layers
    for param in layer.parameters()
]

In [None]:
# One Adam optimizer updates the decoder and the unfrozen ViT layers with the same learning rate.
# NOTE(review): lr=1e-3 is applied to the pretrained encoder layers as well — confirm this is intended.
optimizer = torch.optim.Adam(params_to_optimize, lr=1e-3)
# Mean squared error between the reconstruction and the target image.
criterion = nn.MSELoss()

In [None]:
import torch
from tqdm import tqdm  # Optional: For progress bars

def train(num_epochs):
    """
    Trains the model and decoder for a specified number of epochs.

    Works on the notebook-level ``model``, ``decoder``, ``loader``, ``processor``,
    ``optimizer`` and ``criterion``. Each batch is reconstructed from the encoder output
    and compared with the input image scaled to [0, 1]. After every epoch the average
    loss is printed and the first image of the last batch is shown next to its
    reconstruction.

    Parameters:
    - num_epochs (int): The number of epochs to train the model.

    Raises:
    - ValueError: If the data loader yields no batches.
    """
    if len(loader) == 0:
        raise ValueError("The data loader is empty; there is nothing to train on.")

    # -------------------------------
    # 1. Setup Device (CUDA if available)
    # -------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Move the model and decoder to the selected device
    model.to(device)
    decoder.to(device)

    # The embedding pass feeds the encoder through `processor`, which normalizes the
    # pixels. Apply the same normalization here so the encoder is fine-tuned on the
    # input distribution it is later evaluated with.
    if processor.do_normalize:
        pixel_mean = torch.tensor(processor.image_mean, dtype=torch.float32, device=device).view(1, -1, 1, 1)
        pixel_std = torch.tensor(processor.image_std, dtype=torch.float32, device=device).view(1, -1, 1, 1)
    else:
        pixel_mean = pixel_std = None

    # -------------------------------
    # 2. Training Loop Over Epochs
    # -------------------------------
    for epoch in range(1, num_epochs + 1):
        print(f"\nEpoch {epoch}/{num_epochs}")
        epoch_loss = 0.0  # To accumulate loss over the epoch

        # Set models to training mode
        model.train()
        decoder.train()

        progress_bar = tqdm(enumerate(loader), total=len(loader), desc=f"Training Epoch {epoch}")

        for batch_idx, batch in progress_bar:
            # -------------------------------
            # 3. Prepare the Batch
            # -------------------------------
            # The loader yields [Batch, Height, Width, Channels] uint8 images.
            # Move to the device once, permute to [Batch, Channels, Height, Width] and
            # scale to [0, 1]; this is the reconstruction target.
            target = batch.to(device).permute(0, 3, 1, 2).float() / 255.0

            # The encoder sees the normalized version of the same pixels.
            if pixel_mean is None:
                model_input = target
            else:
                model_input = (target - pixel_mean) / pixel_std

            # -------------------------------
            # 4. Forward Pass
            # -------------------------------
            optimizer.zero_grad()  # Zero the gradients

            outputs = model(model_input, output_hidden_states=False)  # Forward pass through the model
            reconstruction = decoder(outputs.last_hidden_state)  # Decode the model's output

            # -------------------------------
            # 5. Compute Loss
            # -------------------------------
            loss = criterion(reconstruction, target)
            epoch_loss += loss.item()  # Accumulate loss

            # -------------------------------
            # 6. Backward Pass and Optimization
            # -------------------------------
            loss.backward()  # Backward pass to compute gradients
            optimizer.step()  # Update model parameters

            # -------------------------------
            # 7. Logging and Visualization
            # -------------------------------
            # Update the progress bar with the current loss
            progress_bar.set_postfix({"Loss": f"{loss.item():.4f}"})

        # -------------------------------
        # 8. Epoch Summary
        # -------------------------------
        avg_epoch_loss = epoch_loss / len(loader)
        print(f"Epoch {epoch} completed. Average Loss: {avg_epoch_loss:.4f}")
        # Show the [0, 1] target (not the normalized encoder input) next to its reconstruction.
        original_image = target[0].detach().cpu()
        reconstructed_image = reconstruction[0].detach().cpu()
        show_image(original_image, reconstructed_image)



In [None]:
# Run 20 epochs of reconstruction training.
train(20)

In [None]:
# `outputs` is the notebook-level variable left behind by the embedding pass (the
# `outputs` inside train() is local to that function). It was produced before train()
# moved the modules, so it can sit on a different device than the decoder; move it over
# explicitly to avoid a device-mismatch error.
# NOTE(review): these hidden states come from the last batch of the pre-training pass —
# recompute them with the fine-tuned model if post-training reconstructions are wanted.
decoder_device = next(decoder.parameters()).device
reconstructed_images = decoder(outputs.last_hidden_state.to(decoder_device))