The dataset is the Dogs vs. Cats `train.zip` archive from [Kaggle](https://www.kaggle.com/competitions/dogs-vs-cats/data?select=train.zip).

In [None]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

In [None]:
# Set up the plot style: apply seaborn's "whitegrid" theme to subsequent figures
sns.set_theme(style="whitegrid")

# Load the model

In [None]:
# Load the DINO model
# Prefer CUDA, then Apple's MPS backend, and finally fall back to the CPU so
# the notebook still runs on machines without a supported accelerator
# (unconditionally selecting "mps" fails when that backend is unavailable).
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Fetch the 'dino_vits8' entry point from the facebookresearch/dino hub repo
model = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')
model = model.to(device)
model.eval()  # evaluation mode: the model is only used for inference here
print(f"Model loaded at device {device}")

In [None]:
# Define the transformation applied to every image before it reaches the model:
# resize to 256 with bicubic interpolation, center-crop to 224, convert to a
# tensor and normalize each channel with the given mean / std values.
transform = transforms.Compose([
    # InterpolationMode.BICUBIC is the named equivalent of the PIL integer
    # constant 3 used previously.
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Load the dataset

In [None]:
# https://www.kaggle.com/competitions/dogs-vs-cats/data
# Custom Dataset
class CatsDogDataset(Dataset):
    """Dataset over the ``.jpg`` files of a single folder.

    Each item is a ``(image, label, img_name)`` tuple where ``label`` is 0 when
    the file name starts with ``'cat'`` and 1 for every other file name.

    Args:
        folder_path: Directory that directly contains the ``.jpg`` images.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # sample order (and any seeded split derived from it) reproducible.
        self.images = sorted(
            f for f in os.listdir(folder_path) if f.endswith('.jpg')
        )

    def __len__(self):
        """Return the number of ``.jpg`` files found in the folder."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and return ``(image, label, img_name)``."""
        img_name = self.images[idx]
        img_path = os.path.join(self.folder_path, img_name)

        # The context manager closes the underlying file as soon as the pixel
        # data has been converted, instead of leaving it to garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        # The class is encoded in the file name prefix: 'cat...' -> 0, else 1.
        label = 0 if img_name.startswith('cat') else 1
        return image, label, img_name

In [None]:
# Build the dataset and wrap it in a DataLoader.
# DATASET_PATH points at the extracted train folder of the competition data.
DATASET_PATH = 'catsNdogs'
dataset = CatsDogDataset(folder_path=DATASET_PATH, transform=transform)
# Fixed, unshuffled batches of 16 images
dataloader = DataLoader(
    dataset=dataset,
    shuffle=False,
    batch_size=16,
)
print(f"Dataset loaded with {len(dataset)} images")

# Extract features

In [None]:
# Accumulators filled batch by batch: model outputs, labels and file names
features, labels, image_names = [], [], []

# Inference only, so gradient tracking is disabled for the whole loop
with torch.no_grad():
    for images, targets, file_names in tqdm(dataloader):
        outputs = model(images.to(device))
        # Move results back to host memory as NumPy data before storing them
        features.append(outputs.cpu().numpy())
        labels.extend(targets.cpu().numpy())
        image_names.extend(file_names)

In [None]:
# Group features: stack the per-batch arrays into a single matrix.
# The isinstance guard makes the cell safe to re-run: calling np.concatenate
# on the already-stacked 2-D array would silently flatten it into 1-D.
if isinstance(features, list):
    features = np.concatenate(features, axis=0)
features.shape

# Perform t-SNE

In [None]:
# Project the feature matrix to two dimensions with t-SNE (fixed seed)
tsne = TSNE(
    random_state=42,
    n_components=2,
)
features_tsne = tsne.fit_transform(features)

In [None]:
# Inspect the shape of the t-SNE output
features_tsne.shape

In [None]:
# Collect the two t-SNE coordinates together with each sample's label and
# file name in one DataFrame, which makes plotting straightforward.
df_tsne = pd.DataFrame(
    {
        'tsne_1': features_tsne[:, 0],
        'tsne_2': features_tsne[:, 1],
        'label': labels,
        'image_name': image_names,
    }
)

In [None]:
# Set font size globally for the plot and force black text/ticks so the figure
# stays legible once it is saved with a transparent background.
plt.rcParams.update({
    'font.size': 14,
    'text.color': 'black',
    'axes.labelcolor': 'black',
    'xtick.color': 'black',
    'ytick.color': 'black',
    'legend.title_fontsize': 14,
    'legend.fontsize': 14,
})

fig, ax = plt.subplots(figsize=(8, 6))

# Main scatter plot of the 2-D t-SNE embedding, coloured by class label
sns.scatterplot(
    data=df_tsne,
    x='tsne_1', y='tsne_2',
    hue='label',
    palette={0: 'skyblue', 1: 'salmon'},
    legend=False,  # the automatic legend is replaced by the custom one below
    alpha=0.7,
    ax=ax,
)

# Custom legend: one round marker per class, drawn with a black line/edge and
# filled with the same colour the class uses in the scatter plot.
legend_elements = [
    plt.Line2D(
        [0], [0],
        marker='o', color='black',
        markerfacecolor=face_color, markersize=10,
        label=class_name,
    )
    for class_name, face_color in (('Cat', 'skyblue'), ('Dog', 'salmon'))
]
legend = ax.legend(
    handles=legend_elements,
    title='Class',
    loc='best',
    title_fontsize=14,
    labelcolor='black',
    frameon=False,
)
legend.get_title().set_color('black')

# Title and axis labels
ax.set_title('t-SNE Visualization of DINO Features', fontsize=16, pad=20, color='black')
ax.set_xlabel('t-SNE Feature 1', fontsize=14, color='black')
ax.set_ylabel('t-SNE Feature 2', fontsize=14, color='black')

# --- Ensure transparency ---
# Fully transparent RGBA backgrounds for both the axes and the figure
ax.set_facecolor((0, 0, 0, 0))
ax.patch.set_alpha(0.0)
fig.set_facecolor((0, 0, 0, 0))

# Tick marks, tick labels and axis spines in black
ax.tick_params(axis='both', colors='black')
for spine in ax.spines.values():
    spine.set_color('black')

# Improve the layout
fig.tight_layout()

# Save the plot with a transparent background
fig.savefig('t-SNE.png', dpi=300, transparent=True)

# Show the plot
plt.show()

# KNN: Nearest Neighbors

In [None]:
# Hold out 20% of the samples for validation. Features, labels and file names
# are split together so they stay aligned; the split is stratified on the
# labels and uses a fixed seed.
(
    X_train, X_val,
    y_train, y_val,
    names_train, names_val,
) = train_test_split(
    features,
    labels,
    image_names,
    test_size=0.2,
    stratify=labels,
    random_state=42,
)

In [None]:
# Create and train the KNN classifier: each prediction is based on the
# 20 nearest training samples
knn = KNeighborsClassifier(n_neighbors=20)
knn.fit(X_train, y_train)

In [None]:
# Predict a label for every validation sample
y_pred = knn.predict(X_val)

# Report the share of validation samples whose prediction matches the label
accuracy = accuracy_score(y_true=y_val, y_pred=y_pred)
print(f"Validation Accuracy: {accuracy:.4f}")

In [None]:
# One row per validation image: file name, true label and KNN prediction
results_df = pd.DataFrame(
    dict(
        image_name=names_val,
        ground_truth=y_val,
        prediction=y_pred,
    )
)
results_df.head()

In [None]:
# Keep only the rows where the prediction disagrees with the ground truth
is_wrong = results_df['ground_truth'] != results_df['prediction']
misclassified = results_df[is_wrong]

# Split the errors by true class (0 = cat, 1 = dog)
true_class = misclassified['ground_truth']
misclassified_cats = misclassified[true_class == 0]
misclassified_dogs = misclassified[true_class == 1]

# Summary counts
n_wrong, n_total = len(misclassified), len(results_df)
print(f"Total misclassified: {n_wrong} out of {n_total}")
print(f"Misclassified cats: {len(misclassified_cats)}")
print(f"Misclassified dogs: {len(misclassified_dogs)}")

In [None]:
# Display the first 5 misclassified cats (rows whose ground truth is 0)
misclassified_cats.head()

In [None]:
# Display the first 5 misclassified dogs (rows whose ground truth is 1).
# Author's observation: some of these images actually show cats, i.e. the
# label rather than the prediction appears to be wrong — inspect the files to confirm.
misclassified_dogs.head()