# Imports and Configuration
This cell imports necessary libraries and sets up basic configurations like the device (CPU/GPU), the dataset to use (CIFAR-100 only), output directories, model details, and the similarity threshold for defining overlapping classes. It also creates the required output folders, including the new `test_set` directory.

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, ConcatDataset
from torchvision import datasets, transforms
from sklearn.metrics.pairwise import cosine_similarity
from collections import Counter, defaultdict
import numpy as np
import timm
from tqdm.notebook import tqdm
import os
import pandas as pd
from PIL import Image

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# Datasets handled by this notebook.
DATASET_NAMES = ['CIFAR100']

# Output locations: cached features, the saved training subset, the saved test subset.
OUTPUT_BASE_DIR = 'feature_extraction'
FEATURE_DIR = os.path.join(OUTPUT_BASE_DIR, 'features')
NEW_DATASET_DIR = 'new_dataset'
TEST_SET_DIR = 'test_set'

# Create every output folder up front; existing folders are left as they are.
for _output_dir in (FEATURE_DIR, NEW_DATASET_DIR, TEST_SET_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# Backbone identifier passed to timm, and the batch size used during feature extraction.
DINO_MODEL_NAME = 'vit_small_patch16_224.dino'
FEATURE_DIM = 384
BATCH_SIZE_FEATURES = 256

# Two classes count as overlapping when the cosine similarity of their centroids
# is strictly greater than this value.
SIMILARITY_THRESHOLD = 0.9
print(f"Using Cosine Similarity threshold: > {SIMILARITY_THRESHOLD} for overlap")

Using device: cuda
Using Cosine Similarity threshold: > 0.9 for overlap


Data Loading and Transforms
This cell defines the image transformations needed for basic image loading and for DINO feature extraction. It then loads the CIFAR-100 dataset, keeping the training and test sets separate for later saving, but creating concatenated versions with DINO transforms specifically for the feature extraction step.

In [8]:
# Minimal pipeline: only converts the image to a tensor. Used for the copies of the
# dataset that are later written back to disk.
transform_basic = transforms.Compose([transforms.ToTensor()])

# Pipeline for the feature-extraction copies of the dataset: bicubic resize to 256,
# centre crop to 224, tensor conversion, then per-channel normalisation.
_feature_norm_mean = [0.485, 0.456, 0.406]
_feature_norm_std = [0.229, 0.224, 0.225]
_feature_steps = [
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_feature_norm_mean, std=_feature_norm_std),
]
transform_features = transforms.Compose(_feature_steps)

# Registries filled by this cell and read by the later cells:
#   original_datasets        -> "<name>_train" / "<name>_test" splits with the basic transform
#   feature_extract_datasets -> train + test concatenated, with the feature transform
#   original_class_names     -> list of class names per dataset
#   dataset_target_attrs     -> name of the attribute that holds the labels
original_datasets = {}
feature_extract_datasets = {}
original_class_names = {}
dataset_target_attrs = {}

name = 'CIFAR100'
print(f"\n--- Loading Dataset: {name} ---")
num_classes = None
try:
    # Build all four dataset objects first and register them only once every one of
    # them succeeded, so a failure part-way never leaves the registries half-filled.
    base_dataset_train = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_basic)
    base_dataset_test = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_basic)
    feature_dataset_train = datasets.CIFAR100(root='./data', train=True, download=False, transform=transform_features)
    feature_dataset_test = datasets.CIFAR100(root='./data', train=False, download=False, transform=transform_features)

    original_datasets[f"{name}_train"] = base_dataset_train
    original_datasets[f"{name}_test"] = base_dataset_test
    # Train comes first, test second: the concatenation order fixes the row order
    # of everything extracted from this dataset.
    feature_extract_datasets[name] = ConcatDataset([feature_dataset_train, feature_dataset_test])

    num_classes = 100
    original_class_names[name] = base_dataset_train.classes
    dataset_target_attrs[name] = 'targets'

    print(f"  Loaded {name} Train: {len(original_datasets[f'{name}_train'])} samples.")
    print(f"  Loaded {name} Test: {len(original_datasets[f'{name}_test'])} samples.")
    print(f"  Total samples for feature extraction: {len(feature_extract_datasets[name])} samples, {num_classes} classes.")

except Exception as e:
    print(f"  Could not load {name}: {e}. Ensure it's downloaded or set download=True.")

# The load above is best-effort, so the training split may be absent. Look it up
# defensively instead of indexing, which would replace the real load error printed
# above with an unrelated KeyError.
_train_split = original_datasets.get(f"{name}_train")
len_train_dataset = len(_train_split) if _train_split is not None else 0


--- Loading Dataset: CIFAR100 ---
  Loaded CIFAR100 Train: 50000 samples.
  Loaded CIFAR100 Test: 10000 samples.
  Total samples for feature extraction: 60000 samples, 100 classes.


DINO Model Setup
This cell loads the pre-trained DINO Vision Transformer model (`vit_small_patch16_224.dino`) using the `timm` library, sets it to evaluation mode, moves it to the appropriate device (GPU/CPU), and defines a wrapper class `DINOFeatureExtractor` to easily extract features (specifically the CLS token) from the model.

In [9]:
# Load the pre-trained backbone through timm, switch it to inference mode and move
# it to the configured device.
print(f"Loading DINO model: {DINO_MODEL_NAME}")
try:
    dino_model = timm.create_model(DINO_MODEL_NAME, pretrained=True)
    dino_model.eval()
    dino_model.to(DEVICE)
except Exception as e:
    print(f"Error loading DINO model: {e}. Check model name and internet connection.")
    # `dino_model` is required immediately after this block. Swallowing the error
    # would only postpone the failure to an unrelated NameError, so surface the
    # real cause here.
    raise

class DINOFeatureExtractor(nn.Module):
    """Thin wrapper that turns a backbone into a gradient-free feature extractor.

    The wrapped model must expose a ``forward_features`` method. The extractor
    returns the entry at index 0 along dimension 1 of that method's output
    (described in this notebook as the CLS token).
    """

    def __init__(self, model):
        """Store the backbone whose ``forward_features`` output will be sliced."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Return ``forward_features(x)[:, 0]`` computed without gradient tracking.

        Raises:
            AttributeError: if the wrapped model has no ``forward_features`` method.
        """
        backbone = self.model
        if not hasattr(backbone, 'forward_features'):
            raise AttributeError("Model does not have 'forward_features' method.")

        with torch.no_grad():
            token_features = backbone.forward_features(x)
            return token_features[:, 0]

# Wrap the loaded backbone so that calling `extractor(batch)` yields one feature
# vector per image (index 0 along dim 1 of the backbone's `forward_features` output).
extractor = DINOFeatureExtractor(dino_model)

Loading DINO model: vit_small_patch16_224.dino


Feature Extraction
This cell performs the DINO feature extraction on the concatenated CIFAR-100 dataset (train + test). It checks if the feature and label files already exist to avoid re-computation. If not, it iterates through the data using a DataLoader, extracts features using the DINO model, and saves the resulting feature tensors (`X_all`) and corresponding labels (`y_all`) to files. Finally, it cleans up GPU memory if applicable.

In [10]:
name = 'CIFAR100'

# The cleanup at the bottom of this cell removes `feature_extract_datasets`, so the
# guard has to tolerate the name being absent when the cell is executed again;
# otherwise a re-run would die with a NameError instead of reporting a skip.
dataset_available = 'feature_extract_datasets' in globals() and name in feature_extract_datasets

if not dataset_available:
    print(f"\nSkipping feature extraction for {name} (not loaded).")
else:
    print(f"\n--- Extracting Features for: {name} ---")

    feature_path = os.path.join(FEATURE_DIR, f"{name}_X_all_dino.pt")
    label_path = os.path.join(FEATURE_DIR, f"{name}_y_all.pt")

    # Both files act as a cache: extraction only runs when at least one is missing.
    if os.path.exists(feature_path) and os.path.exists(label_path):
        print(f"  Features for {name} already exist. Skipping extraction.")
    else:
        all_feats = []
        all_labels = []

        # shuffle=False keeps the rows in dataset order, so features and labels
        # stay aligned with the concatenated dataset.
        current_loader = DataLoader(
            feature_extract_datasets[name],
            batch_size=BATCH_SIZE_FEATURES,
            shuffle=False,
            num_workers=2
        )

        for images, labels in tqdm(current_loader, desc=f"Extracting {name}"):
            images = images.to(DEVICE)
            feats = extractor(images)
            # Move each batch off the device right away so only one batch of
            # features is held there at a time.
            all_feats.append(feats.cpu())
            all_labels.append(labels.cpu())

        X_all = torch.cat(all_feats, dim=0)
        y_all = torch.cat(all_labels, dim=0)

        torch.save(X_all, feature_path)
        torch.save(y_all, label_path)
        print(f"  Saved all features ({X_all.shape}) to {feature_path}")
        print(f"  Saved all labels ({y_all.shape}) to {label_path}")

# Drop the model and the feature datasets; later cells only need the saved files.
# `pop(..., None)` removes each name when present and is a no-op otherwise, so this
# cleanup is safe to execute more than once.
for _stale_name in ('dino_model', 'extractor', 'feature_extract_datasets'):
    globals().pop(_stale_name, None)

# Compare the device *type* so an indexed device such as `cuda:0` is recognised too.
if DEVICE.type == 'cuda':
    torch.cuda.empty_cache()


--- Extracting Features for: CIFAR100 ---


Extracting CIFAR100:   0%|          | 0/235 [00:00<?, ?it/s]

  Saved all features (torch.Size([60000, 384])) to feature_extraction\features\CIFAR100_X_all_dino.pt
  Saved all labels (torch.Size([60000])) to feature_extraction\features\CIFAR100_y_all.pt


Overlap Identification
This cell identifies which classes in CIFAR-100 are considered "overlapping" based on their DINO features. It loads the previously extracted features and labels, calculates the mean feature vector (centroid) for each class, computes the pairwise cosine similarity between all class centroids, and identifies pairs with similarity exceeding the defined `SIMILARITY_THRESHOLD`. The indices of these overlapping classes are stored.

In [11]:
# Results of this cell, keyed by dataset name:
#   class_centroids     -> {class index: mean feature vector}
#   overlapping_classes -> set of class indices that take part in at least one
#                          pair whose centroid similarity exceeds the threshold
class_centroids = {}
overlapping_classes = {}

# Pre-bind every name the cleanup at the end deletes. Without this the final `del`
# raises NameError whenever the analysis is skipped (data not loaded, files missing,
# fewer than two centroids).
all_features = all_labels = centroid_tensor = sim_matrix = None

name = 'CIFAR100'
if name not in original_class_names:
    print(f"\nSkipping overlap analysis for {name} (data not loaded).")
else:
    print(f"\n--- Analyzing Feature Overlap for: {name} ---")

    try:
        feature_path = os.path.join(FEATURE_DIR, f"{name}_X_all_dino.pt")
        label_path = os.path.join(FEATURE_DIR, f"{name}_y_all.pt")

        all_features = torch.load(feature_path)
        all_labels = torch.load(label_path)
        num_classes = len(original_class_names[name])

    except FileNotFoundError:
        print(f"  Feature/label files not found for {name}. Skipping analysis.")
        # Features may have loaded before the label file turned out to be missing;
        # reset both so the analysis below never runs on a half-loaded pair.
        all_features = all_labels = None

    if all_features is not None:
        print(f"  Calculating centroids for {num_classes} classes...")
        current_centroids = {}
        for i in range(num_classes):
            class_mask = (all_labels == i)
            if class_mask.sum() == 0:
                print(f"  Warning: No samples found for class {i} in {name}. Skipping centroid.")
                continue

            # Centroid = mean of all feature rows that carry this label.
            class_feats = all_features[class_mask]
            current_centroids[i] = torch.mean(class_feats, dim=0)

        class_centroids[name] = current_centroids

        if len(current_centroids) < 2:
            print("  Not enough class centroids to compare. Skipping similarity matrix.")
        else:
            # Row/column k of the similarity matrix corresponds to centroid_indices[k];
            # the two differ only when a class had no samples and was skipped above.
            centroid_indices = sorted(current_centroids.keys())
            centroid_tensor = torch.stack([current_centroids[idx] for idx in centroid_indices])

            print(f"  Calculating {len(centroid_indices)}x{len(centroid_indices)} similarity matrix...")
            sim_matrix = cosine_similarity(centroid_tensor.numpy())

            current_overlapping_set = set()
            overlapping_pairs_found = []

            # Visit each unordered pair once (upper triangle, diagonal excluded).
            for i in range(len(centroid_indices)):
                for j in range(i + 1, len(centroid_indices)):
                    original_i = centroid_indices[i]
                    original_j = centroid_indices[j]
                    similarity = sim_matrix[i, j]

                    if similarity > SIMILARITY_THRESHOLD:
                        current_overlapping_set.add(original_i)
                        current_overlapping_set.add(original_j)
                        overlapping_pairs_found.append( (original_i, original_j, similarity) )

            overlapping_classes[name] = current_overlapping_set

            print(f"  Found {len(overlapping_pairs_found)} overlapping pairs.")
            print(f"  Total unique overlapping classes: {len(current_overlapping_set)}")


# All four names are bound (possibly to None) by the pre-binding above, so this
# never raises even when the analysis was skipped.
del all_features, all_labels, centroid_tensor, sim_matrix
# Compare the device *type* so an indexed device such as `cuda:0` is recognised too.
if DEVICE.type == 'cuda':
    torch.cuda.empty_cache()


--- Analyzing Feature Overlap for: CIFAR100 ---
  Calculating centroids for 100 classes...
  Calculating 100x100 similarity matrix...
  Found 70 overlapping pairs.
  Total unique overlapping classes: 54


Saving Training Set Images
This cell iterates through the original *training* set of CIFAR-100. For each image, it checks if its class label is in the set of overlapping classes identified previously. If it is, the image is converted to PIL format and saved as a PNG file in a subdirectory within `NEW_DATASET_DIR`, named according to the dataset and class name (e.g., `new_dataset/CIFAR100_apple`). Metadata for each saved image is collected.

In [12]:
print(f"\n--- Creating New Training Dataset Subset in '{NEW_DATASET_DIR}' ---")

# One metadata record per image that was actually written to disk.
metadata_train_list = []
total_train_samples_saved = 0
total_train_classes_saved = 0

name = 'CIFAR100'
if name not in overlapping_classes or not overlapping_classes[name]:
    print(f"\nNo overlapping classes found/processed for {name}. Skipping training dataset creation.")
else:
    print(f"\nProcessing overlapping classes for {name} (Training Set)...")

    base_dataset_train = original_datasets[f"{name}_train"]
    class_names = original_class_names[name]
    target_attr = dataset_target_attrs[name]
    overlapping_set = overlapping_classes[name]

    # Create one output folder per overlapping class. Spaces and slashes in class
    # names are replaced so every name is a valid single path component.
    current_dataset_class_folders = set()
    for class_idx in overlapping_set:
        class_name = class_names[class_idx].replace(' ', '_').replace('/', '_')
        new_folder_name = f"{name}_{class_name}"
        new_class_dir = os.path.join(NEW_DATASET_DIR, new_folder_name)
        os.makedirs(new_class_dir, exist_ok=True)
        current_dataset_class_folders.add(new_folder_name)

    total_train_classes_saved += len(current_dataset_class_folders)
    print(f"  Created/verified {len(current_dataset_class_folders)} subdirectories for {name} training set.")

    # Labels read straight from the dataset's label attribute: lets the loop skip
    # non-overlapping samples without decoding and transforming their images.
    split_labels = getattr(base_dataset_train, target_attr)
    # Built once; the converter carries no per-sample state.
    to_pil = transforms.ToPILImage()

    samples_saved_for_this_dataset = 0
    for i in tqdm(range(len(base_dataset_train)), desc=f"Saving training images for {name}"):
        try:
            if int(split_labels[i]) not in overlapping_set:
                continue

            img_tensor, label_idx = base_dataset_train[i]

            class_name = class_names[label_idx].replace(' ', '_').replace('/', '_')
            new_folder_name = f"{name}_{class_name}"
            new_class_dir = os.path.join(NEW_DATASET_DIR, new_folder_name)

            img_pil = to_pil(img_tensor)

            # The original index and label are embedded in the file name so each
            # file can be traced back to its source sample.
            img_filename = f"{name}_train_orig-idx-{i}_label-{label_idx}.png"
            img_save_path = os.path.join(new_class_dir, img_filename)

            img_pil.save(img_save_path)

            metadata_train_list.append({
                'new_class_name': new_folder_name,
                'original_dataset': name,
                'original_class_name': class_names[label_idx],
                'original_label_idx': label_idx,
                'original_index_in_split': i,
                'split': 'train',
                'saved_path': img_save_path
            })

            # Count only after the file and its metadata row exist, so the totals
            # reported below always match what is on disk and in the metadata.
            total_train_samples_saved += 1
            samples_saved_for_this_dataset += 1

        except Exception as e:
            print(f"  Error processing training sample index {i} for {name}: {e}")

    print(f"  Saved {samples_saved_for_this_dataset} training samples for {name}.")

print(f"\nTotal training classes created/verified: {total_train_classes_saved}")
print(f"Total training samples saved: {total_train_samples_saved}")


--- Creating New Training Dataset Subset in 'new_dataset' ---

Processing overlapping classes for CIFAR100 (Training Set)...
  Created/verified 54 subdirectories for CIFAR100 training set.


Saving training images for CIFAR100:   0%|          | 0/50000 [00:00<?, ?it/s]

  Saved 27000 training samples for CIFAR100.

Total training classes created/verified: 54
Total training samples saved: 27000


Saving Test Set Images
This cell performs the same saving process as the previous one, but specifically for the *test* set of CIFAR-100. Images belonging to overlapping classes are saved into subdirectories within the dedicated `TEST_SET_DIR` (e.g., `test_set/CIFAR100_apple`). Metadata for the saved test images is collected separately.

In [13]:
print(f"\n--- Creating New Test Dataset Subset in '{TEST_SET_DIR}' ---")

# One metadata record per image that was actually written to disk.
metadata_test_list = []
total_test_samples_saved = 0
total_test_classes_saved = 0

name = 'CIFAR100'
if name not in overlapping_classes or not overlapping_classes[name]:
    print(f"\nNo overlapping classes found/processed for {name}. Skipping test dataset creation.")
else:
    print(f"\nProcessing overlapping classes for {name} (Test Set)...")

    base_dataset_test = original_datasets[f"{name}_test"]
    class_names = original_class_names[name]
    target_attr = dataset_target_attrs[name]
    overlapping_set = overlapping_classes[name]

    # Create one output folder per overlapping class. Spaces and slashes in class
    # names are replaced so every name is a valid single path component.
    current_dataset_class_folders_test = set()
    for class_idx in overlapping_set:
        class_name = class_names[class_idx].replace(' ', '_').replace('/', '_')
        new_folder_name_test = f"{name}_{class_name}"
        new_class_dir_test = os.path.join(TEST_SET_DIR, new_folder_name_test)
        os.makedirs(new_class_dir_test, exist_ok=True)
        current_dataset_class_folders_test.add(new_folder_name_test)

    total_test_classes_saved += len(current_dataset_class_folders_test)
    print(f"  Created/verified {len(current_dataset_class_folders_test)} subdirectories for {name} test set.")

    # Labels read straight from the dataset's label attribute: lets the loop skip
    # non-overlapping samples without decoding and transforming their images.
    split_labels_test = getattr(base_dataset_test, target_attr)
    # Built once; the converter carries no per-sample state.
    to_pil_test = transforms.ToPILImage()

    samples_saved_for_this_dataset_test = 0
    for i in tqdm(range(len(base_dataset_test)), desc=f"Saving test images for {name}"):
        try:
            if int(split_labels_test[i]) not in overlapping_set:
                continue

            img_tensor, label_idx = base_dataset_test[i]

            class_name = class_names[label_idx].replace(' ', '_').replace('/', '_')
            new_folder_name_test = f"{name}_{class_name}"
            new_class_dir_test = os.path.join(TEST_SET_DIR, new_folder_name_test)

            img_pil = to_pil_test(img_tensor)

            # The original index and label are embedded in the file name so each
            # file can be traced back to its source sample.
            img_filename_test = f"{name}_test_orig-idx-{i}_label-{label_idx}.png"
            img_save_path_test = os.path.join(new_class_dir_test, img_filename_test)

            img_pil.save(img_save_path_test)

            metadata_test_list.append({
                'new_class_name': new_folder_name_test,
                'original_dataset': name,
                'original_class_name': class_names[label_idx],
                'original_label_idx': label_idx,
                'original_index_in_split': i,
                'split': 'test',
                'saved_path': img_save_path_test
            })

            # Count only after the file and its metadata row exist, so the totals
            # reported below always match what is on disk and in the metadata.
            total_test_samples_saved += 1
            samples_saved_for_this_dataset_test += 1

        except Exception as e:
            print(f"  Error processing test sample index {i} for {name}: {e}")

    print(f"  Saved {samples_saved_for_this_dataset_test} test samples for {name}.")


print(f"\nTotal test classes created/verified: {total_test_classes_saved}")
print(f"Total test samples saved: {total_test_samples_saved}")


--- Creating New Test Dataset Subset in 'test_set' ---

Processing overlapping classes for CIFAR100 (Test Set)...
  Created/verified 54 subdirectories for CIFAR100 test set.


Saving test images for CIFAR100:   0%|          | 0/10000 [00:00<?, ?it/s]

  Saved 5400 test samples for CIFAR100.

Total test classes created/verified: 54
Total test samples saved: 5400


Save Training Metadata
This cell takes the collected metadata for the saved *training* images and saves it into a CSV file named `metadata_train.csv` inside the `NEW_DATASET_DIR`.

In [14]:
# Write the collected training-image records to a CSV file next to the saved images.
csv_train_path = os.path.join(NEW_DATASET_DIR, 'metadata_train.csv')
metadata_train_df = pd.DataFrame(metadata_train_list)

try:
    # index=False: the positional row index carries no information worth storing.
    metadata_train_df.to_csv(csv_train_path, index=False)
except Exception as e:
    print(f"\nError saving training metadata CSV: {e}")
else:
    print(f"\nSuccessfully saved training metadata CSV to: {csv_train_path}")


Successfully saved training metadata CSV to: new_dataset\metadata_train.csv


Save Test Metadata
This cell takes the collected metadata for the saved *test* images and saves it into a CSV file named `metadata_test.csv` inside the `TEST_SET_DIR`.

In [15]:
# Write the collected test-image records to a CSV file next to the saved images.
metadata_test_df = pd.DataFrame(metadata_test_list)
csv_test_path = os.path.join(TEST_SET_DIR, 'metadata_test.csv')

try:
    # index=False: the positional row index carries no information worth storing.
    metadata_test_df.to_csv(csv_test_path, index=False)
    print(f"\nSuccessfully saved test metadata CSV to: {csv_test_path}")
except Exception as e:
    print(f"\nError saving test metadata CSV: {e}")

# Clean up original data loaded into memory. The name is already gone when this
# cell is executed a second time, so a missing name is not treated as an error.
try:
    del original_datasets
except NameError:
    pass


Successfully saved test metadata CSV to: test_set\metadata_test.csv


Final Report
This cell prints a summary of the dataset creation process, reporting the locations of the new training and test subset directories, the total number of overlapping classes found, and the total number of training and test samples saved.

In [16]:
# Number of classes flagged as overlapping; 0 when the analysis produced nothing.
num_overlap_classes = len(overlapping_classes.get('CIFAR100', set()))

# Assemble the whole report first and emit it with a single print call.
summary_lines = [
    "\n--- Final Dataset Summary ---",
    f"New training subset location: {os.path.abspath(NEW_DATASET_DIR)}",
    f"New test subset location: {os.path.abspath(TEST_SET_DIR)}",
    f"Total number of overlapping classes (folders created in each subset): {num_overlap_classes}",
    f"Total number of training samples saved: {total_train_samples_saved}",
    f"Total number of test samples saved: {total_test_samples_saved}",
]
print("\n".join(summary_lines))


--- Final Dataset Summary ---
New training subset location: c:\Users\mateo\Desktop\cifar-100-dataset-cdl\new_dataset
New test subset location: c:\Users\mateo\Desktop\cifar-100-dataset-cdl\test_set
Total number of overlapping classes (folders created in each subset): 54
Total number of training samples saved: 27000
Total number of test samples saved: 5400
