# Interactive Clustering of CelebA/Flowers102 Images with PyTorch and Plotly

In this notebook, we will:

- Use **PyTorch** to extract features from the CelebA or Flowers-102 dataset using a pre-trained VGG16 model.
- Perform dimensionality reduction using **PCA**.
- Apply **K-Means** clustering to group similar images.
- Use **t-SNE** for visualization.
- Create an interactive plot within the Jupyter Notebook where clicking on a point displays the corresponding image.
- Analyze the RGB color distribution of the sampled images.

> **TODO**: Use `data_modules` instead.

## Prerequisites

Ensure you have the following libraries installed:

```bash
pip install numpy pandas tqdm pillow torch torchvision scikit-learn plotly ipywidgets
```

## 1. Import Necessary Libraries

First, we import all the necessary libraries.


```python
# Import necessary libraries
import os
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from PIL import Image
from tqdm import tqdm

# PyTorch libraries
import torch
from torchvision import models, transforms
from torchvision.models import VGG16_Weights

# Dimensionality reduction and clustering
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE

# Widgets for interactive display
import ipywidgets as widgets
from IPython.display import display
```

## 2. Check for GPU Availability

We check if a GPU is available and set the device accordingly. 

```python
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
```

```text
Using device: cuda
```

## 3. Set Up Image Path and Variables

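We set the path to the dataset (CelebA or Flowers-102; switch the commented line to change datasets), list the image files, draw a random subset to keep memory usage manageable, and set the number of K-Means clusters.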

```python
# Set the path to the CelebA or Flowers-102 dataset directory
current_dir = os.getcwd()
# image_folder = os.path.join(current_dir, 'datasets', 'celeba', 'img_align_celeba')
image_folder = os.path.join(current_dir, 'datasets', 'flowers-102', 'jpg')

print(f"Image folder path: {image_folder}")

# List image files (sorted so that the seeded sample below is reproducible)
image_files = sorted(f for f in os.listdir(image_folder) if f.endswith('.jpg'))

# Randomly sample a subset of images to limit memory usage
max_images = 1000  # Set to desired limit
rng = np.random.default_rng(42)
image_files = rng.choice(image_files, size=min(max_images, len(image_files)), replace=False)

# Number of K-Means clusters
clusters = 20

# Display the number of images
print(f"Number of images: {len(image_files)}")
```

```text
Image folder path: D:\Egyetem\Mester\deep_learning\hf\CtrlAltDiffuse\datasets\flowers-102\jpg
Number of images: 1000
```

> **Note**: Ensure that the `image_folder` path points to the correct location of your CelebA or Flowers dataset. 

## 4. Define Image Transformations

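We define the image transformations used to preprocess the images. Because the images are not all the same size, we resize them to 256x256, convert them to tensors in the `[0, 1]` range, and normalize each channel with the ImageNet mean and standard deviation that the pre-trained VGG16 weights expect.

```python
# Define image transformations: resize, convert to a [0, 1] tensor, then normalize
# with the ImageNet statistics used to train the pre-trained VGG16 weights
transform = transforms.Compose([
    transforms.Resize((256, 256), antialias=True),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
```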
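`transforms.Normalize` subtracts a per-channel mean and divides by a per-channel standard deviation. The NumPy sketch below reproduces only that arithmetic (it is not torchvision's implementation, and `normalize_chw` is a hypothetical helper) to compare `mean=0.5, std=0.5`, which maps the `[0, 1]` output of `ToTensor()` onto `[-1, 1]`, with the ImageNet statistics associated with torchvision's pre-trained VGG16 weights.

```python
import numpy as np

def normalize_chw(x, mean, std):
    """Per-channel (x - mean) / std on a (C, H, W) array, mimicking transforms.Normalize."""
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)  # broadcast over H and W
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (x - mean) / std

# A 3x2x2 "image" already scaled to [0, 1], as ToTensor() would produce
x = np.stack([np.zeros((2, 2)), np.full((2, 2), 0.5), np.ones((2, 2))]).astype(np.float32)

half = normalize_chw(x, mean=0.5, std=0.5)
imagenet = normalize_chw(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

print(half[:, 0, 0])      # values -1.0, 0.0, 1.0
print(imagenet[:, 0, 0])  # each channel shifted and scaled by its own statistics
```

Either choice yields inputs the network can process; matching the statistics the weights were trained with is the usual default for transfer learning.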

## 5. Load Pre-trained VGG16 Model

We load VGG16 with its default pre-trained weights from torchvision, move it to the selected device, and set it to evaluation mode (which, among other things, disables dropout).

```python
# Load the pre-trained VGG16 model
model = models.vgg16(weights=VGG16_Weights.DEFAULT)
model = model.to(device)
model.eval()  # Set model to evaluation mode
print("Model loaded.")
```

```text
Model loaded.
```


## 6. Create Feature Extractor

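We create a feature extractor that keeps VGG16's convolutional layers (`model.features`) and drops the fully connected classifier. An adaptive average pooling layer then fixes the spatial size at 7x7, so each image yields a flattened vector of 512x7x7 = 25,088 features.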

```python
# Create a feature extractor by removing the classification layers
class FeatureExtractor(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        # Keep only the convolutional backbone; the fully connected classifier is dropped
        self.features = model.features
        # Adaptive pooling fixes the spatial size at 7x7 regardless of input resolution
        self.avgpool = torch.nn.AdaptiveAvgPool2d((7, 7))

    def forward(self, x):
        x = self.features(x)      # (N, 512, H // 32, W // 32)
        x = self.avgpool(x)       # (N, 512, 7, 7)
        x = torch.flatten(x, 1)   # (N, 25088)
        return x

feature_extractor = FeatureExtractor(model).to(device)
feature_extractor.eval()
```
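The `(7, 7)` adaptive pooling is what makes the feature length independent of the input resolution. With 256x256 inputs, VGG16's five 2x2 max-pooling layers leave a 512x8x8 map, which the pool averages down to 512x7x7 = 25,088 values. The NumPy sketch below re-derives adaptive average pooling with a floor/ceil bin rule like the one PyTorch's adaptive pooling uses; it is an illustration, not PyTorch's kernel, and `adaptive_avg_pool2d` and `ceil_div` are hypothetical helpers.

```python
import numpy as np

def ceil_div(a, b):
    """Integer ceiling division without going through floats."""
    return -(-a // b)

def adaptive_avg_pool2d(x, out_h, out_w):
    """Average-pool a (C, H, W) array to (C, out_h, out_w) with adaptive bin edges."""
    c, h, w = x.shape
    out = np.empty((c, out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        # Output row i averages input rows [floor(i*H/out_h), ceil((i+1)*H/out_h))
        r0, r1 = (i * h) // out_h, ceil_div((i + 1) * h, out_h)
        for j in range(out_w):
            c0, c1 = (j * w) // out_w, ceil_div((j + 1) * w, out_w)
            out[:, i, j] = x[:, r0:r1, c0:c1].mean(axis=(1, 2))
    return out

# 256x256 input -> 8x8 map after VGG16's five 2x2 max-pools (256 / 2**5 = 8)
feature_map = np.random.default_rng(0).random((512, 8, 8), dtype=np.float32)
pooled = adaptive_avg_pool2d(feature_map, 7, 7)
flat = pooled.reshape(-1)  # what torch.flatten(x, 1) produces for a single image
print(pooled.shape, flat.shape)  # (512, 7, 7) (25088,)
```

With an 8x8 input, each of the 7x7 output cells averages a 2x2 window, and neighbouring windows overlap by one row or column.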

## 7. Extract Features from Images

We loop through the images, preprocess them, and extract features using the feature extractor. 

```python
# Extract features from images
print("Extracting features...")
features_list = []

for image_file in tqdm(image_files):
    img_path = os.path.join(image_folder, image_file)
    image = Image.open(img_path).convert('RGB')
    image = transform(image)
    image = image.unsqueeze(0).to(device)  # add a batch dimension: (1, 3, 256, 256)

    with torch.no_grad():  # inference only, no gradients needed
        feature = feature_extractor(image)
    feature = feature.cpu().numpy().flatten()
    features_list.append(feature)

features_array = np.array(features_list)
print(f"Shape of extracted features: {features_array.shape}")
```

```text
Extracting features...
100%|██████████| 1000/1000 [00:41<00:00, 24.03it/s]
Shape of extracted features: (1000, 25088)
```

> **Note**: This step may take some time depending on the number of images and your hardware. 
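Part of that cost comes from running the network on a single image per forward pass; a common speed-up is to process images in batches. The order-preserving chunking step is sketched below in plain Python. `iter_batches` and the batch size of 64 are illustrative assumptions, and in the notebook each chunk would still have to be preprocessed and combined into one tensor (for example with `torch.stack`) before calling `feature_extractor`.

```python
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")

def iter_batches(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of at most batch_size items, preserving order."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

# 1000 file names in batches of 64 -> 15 full batches plus one batch of 40
files = [f"image_{i:05d}.jpg" for i in range(1000)]
batches = list(iter_batches(files, 64))
print(len(batches), len(batches[-1]))  # 16 40
```

Preserving order matters: the rows of `features_array` must stay aligned with `image_files`, since both are later combined row by row in `tsne_df`.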

## 8. Apply PCA for Dimensionality Reduction

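We reduce the 25,088-dimensional features to 100 principal components with PCA. This discards low-variance directions and makes the subsequent K-Means and t-SNE steps much cheaper.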

```python
# Apply PCA for dimensionality reduction (random_state fixes the randomized solver, if used)
pca = PCA(n_components=100, random_state=42)
features_pca = pca.fit_transform(features_array)
print(f"Shape after PCA: {features_pca.shape}")
```

```text
Shape after PCA: (1000, 100)
```
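Conceptually, PCA centers the feature matrix and projects it onto its top right-singular vectors. The NumPy sketch below spells this out with an exact SVD on random stand-in data. scikit-learn's `PCA` may pick a randomized solver for a matrix of this size, so its components can differ slightly (and in sign); `pca_project` is a hypothetical helper.

```python
import numpy as np

def pca_project(X, n_components):
    """Project rows of X onto the top n_components principal axes."""
    X_centered = X - X.mean(axis=0)  # PCA operates on mean-centered data
    # Economy SVD: X_centered = U @ diag(S) @ Vt; rows of Vt are the principal axes
    U, S, Vt = np.linalg.svd(X_centered, full_matrices=False)
    components = Vt[:n_components]
    explained_var = S**2 / (X.shape[0] - 1)  # variance along each axis
    ratio = explained_var[:n_components] / explained_var.sum()
    return X_centered @ components.T, components, ratio

# Small stand-in for features_array: 200 samples, 50 features
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 50))
Z, components, ratio = pca_project(X, 10)
print(Z.shape, ratio.sum())  # (200, 10) and the fraction of variance kept
```

Checking `ratio.sum()` (or `pca.explained_variance_ratio_.sum()` in scikit-learn) is a quick way to judge whether the chosen number of components retains enough of the variance.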


## 9. Perform K-Means Clustering

We group the PCA-reduced features into `clusters` (here 20) groups using K-Means.

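```python
# Perform K-Means clustering
# n_init is set explicitly because its default changed in scikit-learn 1.4
kmeans = KMeans(n_clusters=clusters, n_init=10, random_state=42)
kmeans_labels = kmeans.fit_predict(features_pca)
print("Clustering completed.")
```

```text
Clustering completed.
```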
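K-Means alternates two steps until the assignments stop changing: assign each point to its nearest centroid, then move each centroid to the mean of its assigned points (Lloyd's algorithm). The NumPy sketch below implements those iterations with a simple farthest-point initialization, which is a swapped-in choice: scikit-learn's `KMeans` defaults to k-means++ seeding and can run several initializations (`n_init`). `lloyd_kmeans` is a hypothetical helper for illustration only.

```python
import numpy as np

def lloyd_kmeans(X, k, n_iter=100, seed=0):
    """Plain Lloyd iterations with farthest-point initialization; returns (labels, centroids)."""
    rng = np.random.default_rng(seed)
    # Initialization: one random point, then repeatedly the point farthest from all chosen ones
    chosen = [X[rng.integers(len(X))]]
    while len(chosen) < k:
        d2_min = ((X[:, None, :] - np.array(chosen)[None, :, :]) ** 2).sum(axis=2).min(axis=1)
        chosen.append(X[d2_min.argmax()])
    centroids = np.array(chosen, dtype=float)

    labels = None
    for _ in range(n_iter):
        # Assignment step: squared distance from every point to every centroid, shape (n, k)
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        new_labels = d2.argmin(axis=1)
        if labels is not None and np.array_equal(new_labels, labels):
            break  # assignments stopped changing, so the centroids would not move either
        labels = new_labels
        # Update step: move each centroid to the mean of its assigned points
        for j in range(k):
            members = X[labels == j]
            if len(members):  # keep the previous centroid if a cluster is empty
                centroids[j] = members.mean(axis=0)
    return labels, centroids

# Two well-separated blobs should come out as two clusters of 50 points each
rng = np.random.default_rng(42)
X = np.vstack([rng.normal(0, 0.5, (50, 2)), rng.normal(10, 0.5, (50, 2))])
labels, centroids = lloyd_kmeans(X, k=2)
print(np.bincount(labels))  # [50 50]
```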


## 10. Use t-SNE for Visualization

We use t-SNE to reduce the data to two dimensions for visualization. 

```python
# Use t-SNE for visualization
# Note: `max_iter` was called `n_iter` before scikit-learn 1.5
tsne = TSNE(n_components=2, random_state=42, perplexity=30, max_iter=1000)
features_tsne = tsne.fit_transform(features_pca)
print("t-SNE transformation completed.")
```

```text
t-SNE transformation completed.
```
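The `perplexity=30` setting is, roughly, the effective number of neighbours each point attends to. In the original t-SNE formulation, each point gets its own Gaussian bandwidth, chosen by binary search so that the perplexity 2^H of its neighbour distribution (H being the Shannon entropy in bits) equals the target; the 2-D embedding is then optimized to preserve these affinities. The NumPy sketch below performs that calibration for a single point on random data. It illustrates the idea rather than scikit-learn's implementation, and `conditional_probs` and `entropy_bits` are hypothetical helpers.

```python
import numpy as np

def entropy_bits(p):
    """Shannon entropy in bits, treating 0 * log(0) as 0."""
    nz = p[p > 0]
    return -np.sum(nz * np.log2(nz))

def conditional_probs(sq_dists, target_perplexity, tol=1e-5, max_steps=100):
    """Find p_j proportional to exp(-beta * d_j) whose perplexity 2**H matches the target.

    sq_dists holds squared distances from one point to every *other* point;
    beta plays the role of 1 / (2 * sigma**2) in the Gaussian kernel.
    """
    target_entropy = np.log2(target_perplexity)
    beta, lo, hi = 1.0, 0.0, np.inf
    for _ in range(max_steps):
        # Subtracting the minimum distance avoids underflow; it cancels when normalizing
        weights = np.exp(-beta * (sq_dists - sq_dists.min()))
        p = weights / weights.sum()
        entropy = entropy_bits(p)
        if abs(entropy - target_entropy) < tol:
            break
        if entropy > target_entropy:  # too spread out: sharpen the Gaussian
            lo = beta
            beta = beta * 2 if hi == np.inf else (beta + hi) / 2
        else:                         # too peaked: widen the Gaussian
            hi = beta
            beta = (beta + lo) / 2
    return p

# Toy data: calibrate point 0 against the other 199 points
rng = np.random.default_rng(0)
points = rng.normal(size=(200, 10))
sq = ((points[1:] - points[0]) ** 2).sum(axis=1)
p = conditional_probs(sq, target_perplexity=30)
print(2 ** entropy_bits(p))  # close to the target of 30
```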


## 11. Prepare DataFrame for Visualization 

We create a DataFrame containing the t-SNE results and other relevant data.

```python
# Prepare DataFrame for visualization
tsne_df = pd.DataFrame({
    'TSNE1': features_tsne[:, 0],
    'TSNE2': features_tsne[:, 1],
    'Cluster': kmeans_labels.astype(str),
    'Image File': image_files
})
```
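Because the DataFrame holds one row per image, in the same order as `features_array`, per-cluster summaries follow directly (in pandas, for example, `tsne_df['Cluster'].value_counts()`). The NumPy/stdlib sketch below shows the same idea on made-up stand-ins for `kmeans_labels` and `image_files`: count the images per cluster with `np.bincount` (`minlength` keeps empty clusters visible as 0) and group file names by cluster for browsing.

```python
from collections import defaultdict

import numpy as np

# Hypothetical stand-ins for kmeans_labels and image_files
labels = np.array([2, 0, 2, 1, 0, 2])
files = np.array([f"image_{i:05d}.jpg" for i in range(1, 7)])
k = 3

sizes = np.bincount(labels, minlength=k)  # images per cluster, indexed by label
by_cluster = defaultdict(list)
for label, name in zip(labels, files):
    by_cluster[int(label)].append(str(name))

print(sizes)          # [2 1 3]
print(by_cluster[2])  # ['image_00001.jpg', 'image_00003.jpg', 'image_00006.jpg']
```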

## 12. Create Interactive Plot and Image Display

We create an interactive plot using Plotly's `FigureWidget` and display the image corresponding to the selected point. 

```python
# Create a FigureWidget (a Plotly figure that supports Python callbacks in Jupyter)
scatter = go.FigureWidget(
    data=go.Scattergl(
        x=tsne_df['TSNE1'],
        y=tsne_df['TSNE2'],
        mode='markers',
        marker=dict(
            color=tsne_df['Cluster'].astype(int),
            colorscale='Viridis',
            showscale=True,
            size=5
        ),
        customdata=tsne_df['Image File'],
        hovertemplate='<b>Image File:</b> %{customdata}<br>' +
                      '<b>Cluster:</b> %{marker.color}<br>' +
                      '<extra></extra>'
    )
)

scatter.update_layout(
    title='t-SNE Visualization of VGG16 Image Features',
    xaxis_title='t-SNE Dimension 1',
    yaxis_title='t-SNE Dimension 2',
    width=800,
    height=600
)

# Create an Image widget to display the selected image
image_widget = widgets.Image(
    format='jpg',
    width=256,
    height=256
)

# Handle click events: show the image belonging to the clicked point
def update_image(trace, points, selector):
    if points.point_inds:
        ind = points.point_inds[0]  # point indices follow the row order of tsne_df
        image_file = tsne_df.iloc[ind]['Image File']
        img_path = os.path.join(image_folder, image_file)
        with open(img_path, 'rb') as f:
            image_widget.value = f.read()

# Attach the click event handler
scatter.data[0].on_click(update_image)

# Display the plot and the image widget
container = widgets.VBox([scatter, image_widget])
display(container)
```

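Plotly calls a handler registered with `on_click` with three arguments (the trace, a `Points` object, and the input-device state), and `points.point_inds` lists the indices of the clicked points within that trace. Since those indices follow the row order of the data, they index straight into the file list. The sketch below makes that lookup testable outside Jupyter: `make_click_handler` is a hypothetical factory, `types.SimpleNamespace` stands in for Plotly's `Points` object, and the bytes go to any object with a `.value` attribute, as `widgets.Image` has.

```python
import os
import tempfile
from types import SimpleNamespace

def make_click_handler(file_names, folder, image_sink):
    """Return an on_click-style callback that loads the clicked image's bytes."""
    def handler(trace, points, state):
        if not points.point_inds:    # nothing on this trace was clicked
            return
        idx = points.point_inds[0]   # first clicked point; same order as file_names
        with open(os.path.join(folder, file_names[idx]), "rb") as f:
            image_sink.value = f.read()
    return handler

# Exercise the handler with temporary files and a fake Points object
with tempfile.TemporaryDirectory() as folder:
    names = ["a.jpg", "b.jpg"]
    for n in names:
        with open(os.path.join(folder, n), "wb") as f:
            f.write(n.encode())      # file content = its own name, for checking
    sink = SimpleNamespace(value=b"")
    handler = make_click_handler(names, folder, sink)
    handler(None, SimpleNamespace(point_inds=[1]), None)
    print(sink.value)                # b'b.jpg'
```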

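## 13. Color Analysis

We compute the mean intensity of each RGB channel and per-channel intensity histograms, aggregated over the sampled images.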

```python
# Accumulators for the per-channel mean intensity
rgb_sum = np.zeros(3, dtype=np.float64)  # summed intensities on the [0, 1] scale
pixel_count = 0

# Histograms for the RGB channels (256 bins for pixel intensities 0-255)
hist_r = np.zeros(256, dtype=np.int64)
hist_g = np.zeros(256, dtype=np.int64)
hist_b = np.zeros(256, dtype=np.int64)

for image_file in tqdm(image_files):
    image_path = os.path.join(image_folder, image_file)
    # convert('RGB') guarantees 3 channels even for grayscale or RGBA files
    image_np = np.asarray(Image.open(image_path).convert('RGB'))  # (H, W, 3), uint8

    # Sum each channel on the same [0, 1] scale that transforms.ToTensor() uses
    rgb_sum += image_np.reshape(-1, 3).sum(axis=0) / 255.0
    # Count the number of pixels (height * width)
    pixel_count += image_np.shape[0] * image_np.shape[1]

    # Count integer intensities directly (no float round trip back to 0-255)
    hist_r += np.bincount(image_np[..., 0].ravel(), minlength=256)
    hist_g += np.bincount(image_np[..., 1].ravel(), minlength=256)
    hist_b += np.bincount(image_np[..., 2].ravel(), minlength=256)

mean_rgb_np = rgb_sum / pixel_count
```
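For 8-bit images, a 256-bin histogram over `range=(0, 256)` has unit-width bins, so each bin simply counts one integer intensity. On integer data, `np.bincount(..., minlength=256)` produces the same counts directly; the check below compares the two on synthetic random `uint8` values (not the dataset).

```python
import numpy as np

rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=10_000, dtype=np.uint8)  # fake 8-bit channel values

hist_np, edges = np.histogram(pixels, bins=256, range=(0, 256))  # unit-width bins [k, k+1)
hist_bc = np.bincount(pixels, minlength=256)                     # direct count per intensity

print(np.array_equal(hist_np, hist_bc), edges[:3])  # True [0. 1. 2.]
```

Counting the raw integers also avoids relying on a float round trip: `ToTensor()` divides by 255, and multiplying back by 255 is not guaranteed to reproduce the exact integer in floating point, which matters when values sit exactly on bin edges.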


```python
# Plotly Bar Chart for Mean RGB
fig_rgb = go.Figure()

fig_rgb.add_trace(go.Bar(
    x=['Red', 'Green', 'Blue'],
    y=mean_rgb_np,
    marker=dict(color=['red', 'green', 'blue']),
    name='Mean Color Intensity'
))

fig_rgb.update_layout(
    title="Average Color Distribution of Sampled Images",
    xaxis_title="Color Channels",
    yaxis_title="Mean Pixel Intensity (0 to 1)",
    showlegend=False
)

fig_rgb.show()
```

```python
# Aggregated per-channel histogram of pixel intensities
fig_histogram = go.Figure()

for hist, name, color in [(hist_r, 'Red Channel', 'red'),
                          (hist_g, 'Green Channel', 'green'),
                          (hist_b, 'Blue Channel', 'blue')]:
    fig_histogram.add_trace(go.Bar(
        x=np.arange(256),
        y=hist,
        name=name,
        marker_color=color,
        opacity=0.6
    ))

fig_histogram.update_layout(
    title="Aggregated Color Histogram of Sampled Images",
    xaxis_title="Pixel Intensity",
    yaxis_title="Pixel Count",
    barmode='overlay',
    legend_title="Color Channels"
)

fig_histogram.show()
```