# Image Clustering with ImageBind

This notebook clusters images using embeddings generated by the ImageBind model. It covers four steps: loading and preprocessing the images, extracting embeddings, clustering the embeddings with KMeans, and visualizing the results.

In [2]:
%pip install torch torchvision scikit-learn matplotlib


In [4]:
# Install ImageBind from source. The URL assumes the official
# facebookresearch repository; adjust it if you use a fork.
%pip install git+https://github.com/facebookresearch/ImageBind.git

# Import necessary libraries
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from PIL import Image
import torch
from imagebind.models import imagebind_model

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the pretrained ImageBind model and switch it to inference mode
model = imagebind_model.imagebind_huge(pretrained=True).eval().to(device)


## Data Preprocessing

This section loads every image in the dataset folder and converts it to RGB, so that each image has the three channels the embedding step expects.

In [2]:
# Load images
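def load_images_from_folder(folder, extensions=('.jpg', '.jpeg', '.png', '.bmp', '.webp')):
    images = []
    # Sort for a stable order and skip files that are not images
    for filename in sorted(os.listdir(folder)):
        if not filename.lower().endswith(extensions):
            continue
        img_path = os.path.join(folder, filename)
        # The context manager closes the file handle once the pixels are loaded
        with Image.open(img_path) as img:
            images.append(img.convert('RGB'))
    return images

images = load_images_from_folder('data/images')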


## Embedding Extraction

Next, we will extract embeddings from the loaded images using the ImageBind model.

In [3]:
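# Extract embeddings
from torchvision import transforms
from imagebind.models.imagebind_model import ModalityType

# Preprocessing assumed to mirror ImageBind's own image pipeline (CLIP-style statistics)
transform = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])

def extract_embeddings(images):
    embeddings = []
    for img in images:
        img_tensor = transform(img).unsqueeze(0).to(device)
        with torch.no_grad():
            # ImageBind takes a dict keyed by modality and returns a dict of embeddings
            outputs = model({ModalityType.VISION: img_tensor})
        embeddings.append(outputs[ModalityType.VISION].cpu().numpy())
    return np.vstack(embeddings)

embeddings = extract_embeddings(images)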
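
KMeans groups points by Euclidean distance, while embeddings are often compared by cosine similarity. One common option, not part of the original notebook, is to L2-normalize each embedding first: for unit-length vectors, squared Euclidean distance equals 2 minus twice the cosine similarity, so the two orderings agree. Some models already return unit-length vectors, in which case this step changes nothing. A minimal sketch, where `l2_normalize` is a hypothetical helper and the random array only stands in for real embeddings:

```python
import numpy as np

def l2_normalize(vectors, eps=1e-12):
    """Scale each row to unit length; eps guards against division by zero."""
    vectors = np.asarray(vectors, dtype=np.float64)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    return vectors / np.maximum(norms, eps)

# Stand-in data: 8 fake "embeddings" of dimension 16
# (real ones would come from extract_embeddings)
rng = np.random.default_rng(0)
demo_embeddings = rng.normal(size=(8, 16))
demo_normalized = l2_normalize(demo_embeddings)
```

In the notebook this would be applied as `embeddings = l2_normalize(embeddings)` before the clustering cell.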


## Clustering Algorithms

We will apply KMeans clustering to the extracted embeddings.

In [4]:
# Apply KMeans clustering
n_clusters = 5
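# Fix the seed and the number of restarts so cluster assignments are reproducible
kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)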
kmeans.fit(embeddings)
labels = kmeans.labels_
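
The value `n_clusters = 5` is fixed by hand. One possible way to choose it from the data, offered here as a sketch and not as the notebook's method, is to fit KMeans for several candidate values and keep the one with the highest silhouette score. The helper name `pick_n_clusters` is hypothetical, and the synthetic blobs from `make_blobs` only stand in for real image embeddings:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

def pick_n_clusters(features, candidates, random_state=0):
    """Return (best_k, scores), where scores maps each k to its silhouette score."""
    scores = {}
    for k in candidates:
        km = KMeans(n_clusters=k, n_init=10, random_state=random_state)
        cluster_labels = km.fit_predict(features)
        scores[k] = silhouette_score(features, cluster_labels)
    best_k = max(scores, key=scores.get)
    return best_k, scores

# Synthetic stand-in for image embeddings: 3 well-separated blobs in 16 dimensions
demo_features, _ = make_blobs(n_samples=150, centers=3, n_features=16,
                              cluster_std=0.5, random_state=0)
best_k, scores = pick_n_clusters(demo_features, candidates=range(2, 7))
```

With real data the call would be `pick_n_clusters(embeddings, range(2, 11))`, keeping every candidate smaller than the number of images. The silhouette score is only a heuristic, so it is worth checking the chosen clusters visually as well.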


## Visualization of Results

Finally, we display the images in a grid, each titled with the cluster it was assigned to.

In [5]:
# Visualize clusters
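def visualize_clusters(images, labels, max_images=25, n_cols=5):
    # Show at most max_images thumbnails; a fixed 5x5 grid fails on larger datasets
    n_shown = min(len(images), max_images)
    n_rows = max(1, -(-n_shown // n_cols))  # ceiling division
    plt.figure(figsize=(15, 3 * n_rows))
    for i in range(n_shown):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(images[i])
        plt.title(f'Cluster: {labels[i]}')
        plt.axis('off')
    plt.show()

visualize_clusters(images, labels)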
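
The image grid shows individual assignments but not how the clusters sit relative to each other. Since `PCA` is imported at the top of the notebook, one plausible complement is a 2D scatter plot of the embeddings colored by cluster. The sketch below assumes that intent; `plot_clusters_2d` is a hypothetical helper, and the random features and labels are placeholders for the real `embeddings` and `labels`:

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def plot_clusters_2d(features, cluster_labels):
    """Project features to 2D with PCA and scatter-plot them colored by cluster."""
    points = PCA(n_components=2).fit_transform(features)
    fig, ax = plt.subplots(figsize=(8, 6))
    scatter = ax.scatter(points[:, 0], points[:, 1], c=cluster_labels, cmap="tab10", s=30)
    ax.set_xlabel("PC 1")
    ax.set_ylabel("PC 2")
    ax.set_title("Embeddings projected onto the first two principal components")
    # legend_elements() builds one legend entry per distinct label value
    ax.legend(*scatter.legend_elements(), title="Cluster")
    return fig, points

# Stand-in data: 30 fake embeddings with made-up cluster labels
rng = np.random.default_rng(0)
fake_features = rng.normal(size=(30, 16))
fake_labels = rng.integers(0, 3, size=30)
fig, points = plot_clusters_2d(fake_features, fake_labels)
```

In the notebook, `plot_clusters_2d(embeddings, labels)` followed by `plt.show()` would render the plot. Two principal components usually capture only part of the variance, so clusters that overlap in this view may still be well separated in the full embedding space.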
