In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from PIL import Image
import os
import shutil

# Load an ImageNet-pretrained ResNet-18 to use as a feature extractor.
try:
    # torchvision >= 0.13: the ``weights`` enum replaces the deprecated
    # ``pretrained`` flag. For ResNet-18, DEFAULT is IMAGENET1K_V1, i.e. the
    # same checkpoint that ``pretrained=True`` loads.
    _resnet_weights = models.ResNet18_Weights.DEFAULT
except AttributeError:
    # Older torchvision releases without the weights enums.
    resnet = models.resnet18(pretrained=True)
else:
    resnet = models.resnet18(weights=_resnet_weights)
# Drop the final fully connected layer (``fc``, the last child) so the network
# returns the pooled feature map instead of class logits.
resnet = nn.Sequential(*list(resnet.children())[:-1])
# Evaluation mode: BatchNorm layers use their running statistics.
resnet.eval()

# Standard ImageNet evaluation transform: resize the shorter side to 256,
# center-crop to 224x224, convert to a tensor, and normalize with the ImageNet
# channel mean/std. Built once and reused instead of on every call.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image_path):
    """Load an image file and turn it into a single-image input batch.

    Args:
        image_path: Path of the image file to load.

    Returns:
        A float tensor of shape (1, 3, 224, 224): the RGB image after the
        resize / center-crop / normalize pipeline, with a leading batch axis.
    """
    # The context manager closes the file handle as soon as the pixels are
    # read. convert() returns a new, fully loaded image, so ``rgb`` remains
    # valid after the with-block exits.
    with Image.open(image_path) as image:
        rgb = image.convert('RGB')
    # Add the batch dimension expected by the model.
    return _PREPROCESS(rgb).unsqueeze(0)

# Embed a single image with the headless ResNet defined above.
def extract_features(image_path):
    """Return the backbone embedding of one image as a NumPy array.

    The image is prepared by ``preprocess_image`` and run through the
    module-level ``resnet`` with gradient tracking disabled. All singleton
    dimensions of the output (the batch axis and, for the pooled ResNet trunk,
    the 1x1 spatial axes) are squeezed away.

    Args:
        image_path: Path of the image file to embed.

    Returns:
        The squeezed model output converted to a NumPy array.
    """
    with torch.no_grad():
        output = resnet(preprocess_image(image_path))
    return output.squeeze().numpy()

# Cluster the image embeddings with K-means.
def perform_clustering(feature_vectors, num_clusters, *, random_state=None):
    """Cluster feature vectors with K-means and return one label per vector.

    Args:
        feature_vectors: Sequence (or 2-D array) of equal-length feature
            vectors, one per sample.
        num_clusters: Number of clusters to form. scikit-learn raises
            ValueError if this exceeds the number of samples.
        random_state: Seed forwarded to ``KMeans``. The default ``None``
            keeps the previous behavior (centroid initialization differs run
            to run). Pass an int to make the clustering, and therefore the
            selected subset, reproducible.

    Returns:
        An array of cluster labels aligned with ``feature_vectors``.
    """
    kmeans = KMeans(n_clusters=num_clusters, random_state=random_state)
    return kmeans.fit_predict(np.asarray(feature_vectors))

# Pick the image(s) nearest to each cluster centroid.
def select_representative_images(cluster_labels, image_features, image_paths, num_representatives):
    """Select the images closest to each cluster's centroid.

    Args:
        cluster_labels: Sequence of cluster ids, one per image. NumPy arrays
            and plain Python lists are both accepted.
        image_features: Sequence of equal-length 1-D feature vectors, aligned
            with ``cluster_labels``.
        image_paths: Sequence of image paths, aligned with ``cluster_labels``.
        num_representatives: How many images to pick per cluster (>= 1).

    Returns:
        A dict mapping each cluster id (in sorted order) to its
        representatives. When ``num_representatives == 1`` the value is the
        single closest path. Otherwise it is a list of up to
        ``num_representatives`` paths, ordered from closest to farthest from
        the cluster centroid (clusters smaller than that return all members).

    Raises:
        ValueError: If ``num_representatives < 1`` or the three input
            sequences have different lengths.
    """
    if num_representatives < 1:
        raise ValueError('num_representatives must be at least 1')
    labels = np.asarray(cluster_labels)
    if not (len(labels) == len(image_features) == len(image_paths)):
        raise ValueError('cluster_labels, image_features and image_paths must have the same length')

    # Stack the features once so each cluster's rows can be taken by fancy indexing.
    features = np.asarray(image_features)
    representatives = {}
    for cluster_label in np.unique(labels):
        member_indices = np.flatnonzero(labels == cluster_label)
        members = features[member_indices]
        centroid = members.mean(axis=0)
        distances = np.linalg.norm(members - centroid, axis=1)
        # Stable sort so ties resolve to the earliest image, as np.argmin does.
        nearest = np.argsort(distances, kind='stable')[:num_representatives]
        chosen = [image_paths[member_indices[k]] for k in nearest]
        representatives[cluster_label] = chosen[0] if num_representatives == 1 else chosen
    return representatives

# Input directory of image patches and the directory that receives the subset.
image_dir = '/content/drive/MyDrive/patches/1'
output_dir = '/content/drive/MyDrive/patches/1_subset/'

# Collect image files only. Skip sub-directories (e.g. .ipynb_checkpoints) and
# files whose extension PIL does not recognize. Sort the paths because
# os.listdir returns entries in arbitrary order.
_image_extensions = {ext.lower() for ext in Image.registered_extensions()}
image_paths = sorted(
    os.path.join(image_dir, img_name)
    for img_name in os.listdir(image_dir)
    if os.path.splitext(img_name)[1].lower() in _image_extensions
    and os.path.isfile(os.path.join(image_dir, img_name))
)
if not image_paths:
    raise FileNotFoundError(f'No image files found in {image_dir}')
image_features = [extract_features(img_path) for img_path in image_paths]

# Perform clustering. K-means needs at least as many samples as clusters, so
# cap the requested cluster count at the number of images.
num_clusters = min(600, len(image_paths))
cluster_labels = perform_clustering(image_features, num_clusters)

# Select the single image closest to each cluster centroid.
num_representatives = 1
representatives = select_representative_images(cluster_labels, image_features, image_paths, num_representatives)

# Copy the representative images into the output directory. Create the
# directory first, because shutil.copy fails if it does not exist.
os.makedirs(output_dir, exist_ok=True)
for image_path in representatives.values():
    shutil.copy(image_path, output_dir)
