In [31]:
# Not required: frames are loaded with PIL, so OpenCV is not used anywhere below.
# !pip install opencv-python

In [32]:
import os
import numpy as np
from sklearn.cluster import SpectralClustering 
import matplotlib.pyplot as plt
from collections import Counter
import torch
from torchvision import models, transforms
from PIL import Image
import shutil

In [33]:
# Number of spectral clusters per video, and how many of the largest clusters
# each contribute one representative frame to the subset.
num_clusters = 4
num_key_clusters = 2

# Root shared by all inputs/outputs; override with the VID_DATA_ROOT env var
# when running somewhere other than the original machine.
data_root = os.environ.get("VID_DATA_ROOT", "/vol/bitbucket/wr323")

# Processed ILSVRC2017 VID frames, laid out as <split>/<class>/<video>/<frame>
# (the "train" split is the one iterated by the driver loop).
base_dir = os.path.join(data_root, "ILSVRC2017_VID_PROCESSED_SEP")

# Output root for the selected frames. The folder name is derived from
# num_key_clusters so it always states how many frames are kept per video.
img_save_base_dir = os.path.join(
    data_root, f"ILSVRC2017_VID_Subsets_{num_key_clusters}frames", "ResNet_SC"
)

# Debug figures (one <video name>.jpg per processed video) are written here.
result_save_dir = os.path.join(data_root, "Result 2017 Pre-Trained-CNN and Spectral")
os.makedirs(result_save_dir, exist_ok=True)

In [34]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Feature extractor: an ImageNet-pretrained ResNet-50 with its last child
# module (the classification layer) dropped, so a forward pass yields the
# pooled feature map instead of class logits.
feature_layers = list(models.resnet50(weights='ResNet50_Weights.DEFAULT').children())[:-1]
model = torch.nn.Sequential(*feature_layers).to(device).eval()  # eval(): inference behaviour for BatchNorm/Dropout

# Classic ImageNet evaluation preprocessing: resize the short side, take the
# central 224x224 crop, convert to a tensor, then normalise per channel with
# the usual ImageNet mean/std.
# NOTE(review): torchvision's bundled transform for the DEFAULT weights may use
# a different resize size; this keeps the 256/224 recipe on purpose — confirm.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

Using device: cuda


In [35]:
def read_images(video_dir, extensions=(".jpeg", ".jpg", ".png")):
    """Load every image frame stored directly in ``video_dir``.

    Frames are returned in sorted filename order, so results (and the
    clustering built on them) do not depend on the arbitrary order of
    ``os.listdir``. Each file is opened in a ``with`` block and converted to
    RGB. ``convert`` fully decodes the pixels into a new image, so no file
    handle stays open. It also gives grayscale or RGBA frames the 3 channels
    that the ImageNet ``Normalize`` step expects.

    Args:
        video_dir: Directory holding the frames of one video.
        extensions: Case-insensitive filename suffixes treated as images.

    Returns:
        ``(images, filenames, image_paths)``: three parallel lists of RGB PIL
        images, bare filenames, and full paths.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    images, filenames, image_paths = [], [], []
    for filename in sorted(os.listdir(video_dir)):
        image_path = os.path.join(video_dir, filename)
        # Skip non-images and sub-directories that merely look like images.
        if not filename.lower().endswith(suffixes) or not os.path.isfile(image_path):
            continue
        with Image.open(image_path) as img:
            images.append(img.convert("RGB"))
        filenames.append(filename)
        image_paths.append(image_path)
    return images, filenames, image_paths

def _embed_images(images):
    """Return ResNet features for ``images`` as an (n_images, n_features) float64 array.

    Uses the module-level ``model``, ``preprocess`` and ``device``; frames are
    pushed through the network one at a time.
    """
    rows = []
    with torch.no_grad():
        for img in images:
            img_tensor = preprocess(img).unsqueeze(0).to(device)  # add batch dimension
            rows.append(model(img_tensor).cpu().numpy().reshape(-1).astype('double'))
    return np.stack(rows)


def _representative_indices(embeddings, labels, cluster_ids):
    """Pick, for each cluster in ``cluster_ids``, the member closest to the cluster mean.

    Args:
        embeddings: (n_images, n_features) array.
        labels: length-n_images sequence of cluster labels.
        cluster_ids: labels to pick a representative for, in the order given.

    Returns:
        List of ``(cluster_id, image_index)`` pairs. Clusters without members
        are skipped; distance ties go to the lowest image index.
    """
    labels = np.asarray(labels)
    picks = []
    for cluster_id in cluster_ids:
        members = np.flatnonzero(labels == cluster_id)
        if members.size == 0:
            continue
        member_embeddings = embeddings[members]
        centroid = member_embeddings.mean(axis=0)
        distances = np.linalg.norm(member_embeddings - centroid, axis=1)
        picks.append((cluster_id, members[np.argmin(distances)]))
    return picks


def _show_representatives(picks, images, filenames, video_dir, result_save_dir):
    """Plot the picked frames side by side and save the figure as ``<video name>.jpg``."""
    # Panels ordered by filename (frame order, assuming frame-numbered names).
    ordered = sorted(picks, key=lambda pick: filenames[pick[1]])
    # squeeze=False keeps `axes` 2-D, so this also works when num_key_clusters == 1.
    fig, axes = plt.subplots(1, num_key_clusters, figsize=(31, 5), squeeze=False)
    panels = axes[0]
    for ax, (cluster_id, idx) in zip(panels, ordered):
        ax.imshow(images[idx])
        # Title with the representative's own cluster label.
        ax.set_title(f"Cluster {cluster_id}: {filenames[idx]}")
    for ax in panels:
        ax.axis('off')
    fig.tight_layout()
    # Save before showing so the file is written from the fully drawn figure.
    video_name = os.path.basename(os.path.normpath(video_dir))
    fig.savefig(os.path.join(result_save_dir, f'{video_name}.jpg'))
    plt.show()
    plt.close(fig)  # don't accumulate open figures across many videos


def process_video_dir(video_dir, save_dir, result_save_dir, debugging=True, random_state=0):
    """Select representative frames of one video and copy (or display) them.

    Steps:
    1. Load the frames with ``read_images``.
    2. Embed each frame with the global ResNet ``model``.
    3. Spectral-cluster the embeddings into ``num_clusters`` groups.
    4. For each of the ``num_key_clusters`` most populous clusters, keep the
       frame closest to that cluster's mean embedding.

    Videos with ``num_clusters`` frames or fewer are not clustered; every frame
    is kept instead.

    Args:
        video_dir: Folder with the frames of one video.
        save_dir: Destination for selected frames (used when not debugging).
        result_save_dir: Where the debug figure ``<video name>.jpg`` is saved.
        debugging: If True, show/save figures instead of copying files.
        random_state: Seed passed to SpectralClustering so the selection is
            reproducible; pass None for unseeded behaviour.
    """
    images, filenames, image_paths = read_images(video_dir)

    # Too few frames to split into num_clusters groups: keep all of them.
    if len(images) <= num_clusters:
        print(f"Saving all frames from {video_dir} (found {len(images)}, need more than {num_clusters})")
        for image, filename, image_path in zip(images, filenames, image_paths):
            if debugging:
                plt.imshow(image)
                plt.title(filename)
                plt.show()
            else:
                shutil.copy2(image_path, os.path.join(save_dir, filename))
        return

    all_embeddings = _embed_images(images)

    spectral_cluster = SpectralClustering(
        n_clusters=num_clusters,
        affinity='nearest_neighbors',
        n_neighbors=num_clusters,
        random_state=random_state,
    ).fit(all_embeddings)
    embedding_clusters = spectral_cluster.labels_

    # (label, count) pairs for the num_key_clusters most populous clusters.
    largest_clusters = Counter(embedding_clusters).most_common(num_key_clusters)
    print(largest_clusters)

    picks = _representative_indices(
        all_embeddings, embedding_clusters, [label for label, _ in largest_clusters]
    )

    if debugging:
        _show_representatives(picks, images, filenames, video_dir, result_save_dir)
    else:
        for _, idx in picks:
            shutil.copy2(image_paths[idx], os.path.join(save_dir, filenames[idx]))

    print("non_empty_cluster_count: ", len(picks))

In [None]:
debugging = False
# In debug mode only the first few videos of each class are processed.
max_videos_per_class_when_debugging = 3

# Only the train split is subsampled; the validation split is used in full.
for split in ["train"]:
    split_dir = os.path.join(base_dir, split)

    # Sorted for a reproducible processing order; stray files (e.g. .DS_Store)
    # are skipped because only directories are classes.
    for class_name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, class_name)
        if not os.path.isdir(class_dir):
            continue

        # Every video of a class writes into this one folder.
        # NOTE(review): if frame filenames repeat across videos (e.g. each video
        # starting at 000000.JPEG), later copies overwrite earlier ones — confirm
        # the processed frames have names that are unique within a class.
        save_dir = os.path.join(img_save_base_dir, split, class_name)
        os.makedirs(save_dir, exist_ok=True)

        video_names = sorted(
            name for name in os.listdir(class_dir)
            if os.path.isdir(os.path.join(class_dir, name))
        )
        if debugging:
            video_names = video_names[:max_videos_per_class_when_debugging]

        for video_name in video_names:
            video_dir = os.path.join(class_dir, video_name)
            process_video_dir(video_dir, save_dir, result_save_dir, debugging=debugging)