In [None]:
###
# This code extracts image features with an ImageNet-pretrained ResNet50 (timm)
# and projects them into a two-dimensional space with t-SNE for visualization.
###

In [None]:
import timm
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from openTSNE import TSNE
import umap.umap_ as umap
import os, glob
from matplotlib.cm import get_cmap
from matplotlib.colors import to_rgba

def set_gpu(gpu_id):
    """Pick the torch device for this notebook.

    Args:
        gpu_id: Index of the CUDA device to use when CUDA is available.

    Returns:
        ``torch.device("cuda:<gpu_id>")`` if CUDA is available, otherwise
        ``torch.device("cpu")``.
    """
    return torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")

gpu_id = 0  # CUDA device index; ignored when no GPU is present
device = set_gpu(gpu_id)
print("Using device:", device)

In [None]:
# ResNet50 backbone used as a feature extractor (classification head removed).

model_name = "resnet50"
# num_classes=0 tells timm to build the network without a classification head,
# so the forward pass returns pooled penultimate-layer features. Unlike
# overwriting `model.fc` by hand, this keeps working if `model_name` is switched
# to an architecture whose head attribute has a different name.
model = timm.create_model(model_name, pretrained=True, num_classes=0)
model = model.to(device)

class ImageDataset(Dataset):
    """Map-style dataset that loads images from disk and applies ``transform``.

    Each item is a ``(transformed_image, image_path)`` pair, so callers can keep
    track of which file every sample came from.

    Args:
        image_paths: Indexable sequence of image file paths.
        transform: Callable applied to each RGB ``PIL.Image`` before returning it.
    """

    def __init__(self, image_paths, transform):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # Open with a context manager so the file handle is released as soon as
        # the pixels are read, rather than whenever PIL or the GC gets to it.
        # convert() returns a new image that does not depend on the open file.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        return self.transform(image), image_path

def extract_features_with_batches(image_paths, model, batch_size=128, device=None, num_workers=4):
    """Run ``model`` over every image and collect one output row per image.

    Images are resized to 224x224 (aspect ratio is not preserved), converted to
    tensors and normalized with the ImageNet channel mean/std before the
    forward pass. The model is switched to eval mode and run without gradients.

    Args:
        image_paths: Sequence of image file paths; output order follows it.
        model: Feature-extraction network.
        batch_size: Number of images per forward pass.
        device: Device to run inference on. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).
        num_workers: DataLoader worker processes; set to 0 if worker processes
            cannot be started in your environment.

    Returns:
        ``(features, filenames)``: ``features`` stacks the model outputs along
        axis 0 (one row per image; an empty ``(0,)`` array when there are no
        images) and ``filenames`` is the matching list of paths.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # fixed square size; aspect ratio is not kept
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    if device is None:
        # Run where the model's weights live instead of relying on a notebook global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    device = torch.device(device)

    dataset = ImageDataset(image_paths, transform)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # keep output rows aligned with image_paths
        num_workers=num_workers,
        pin_memory=device.type == "cuda",
    )

    feature_batches = []
    filenames = []

    model.eval()
    with torch.no_grad():
        for batch_images, batch_paths in dataloader:
            batch_images = batch_images.to(device, non_blocking=True)
            batch_features = model(batch_images)
            # Keep whole batches; concatenating once at the end avoids creating
            # one Python-level array object per image.
            feature_batches.append(batch_features.cpu().numpy())
            filenames.extend(batch_paths)

    if not feature_batches:
        # Matches the previous np.array([]) result for empty input.
        return np.empty((0,)), filenames
    return np.concatenate(feature_batches, axis=0), filenames

In [None]:
# Root folder containing one sub-directory per group, each holding PNG images.
d = '/PATH/TO/YOUR/IMAGES'
image_glob = os.path.join(d, '*', '*.png')
image_paths = sorted(glob.glob(image_glob))
if not image_paths:
    # Warn here instead of letting the extraction / t-SNE cells fail obscurely.
    print(f"WARNING: no images matched {image_glob!r}; check the directory `d`.")
print(len(image_paths))

In [None]:
# Use a larger batch than the function's default of 128; reduce it if the
# device runs out of memory.
batch_size = 512
features, filenames = extract_features_with_batches(
    image_paths,
    model,
    batch_size=batch_size,
)

In [None]:
# t-SNE: embed the extracted feature vectors into two dimensions (openTSNE).

tsne = TSNE(
    n_components=2,        # number of output dimensions
    metric="cosine",       # distance used between feature vectors
    perplexity=30,         # effective neighbourhood size for the affinities
    learning_rate="auto",  # let openTSNE choose the step size
    random_state=42,       # fixed seed
    n_jobs=-1,             # use all available CPU cores
)
features_tsne = tsne.fit(features)

In [None]:
# Save the raw features, their source paths and the 2-D t-SNE embedding as
# .npy files so later steps can reload them without recomputing.

output_dir = "/YOUR/SAVE/DIRECTORY"
os.makedirs(output_dir, exist_ok=True)

for out_name, out_array in (
    ("features.npy", features),
    ("filenames.npy", np.array(filenames)),
    ("features_tsne.npy", features_tsne),
):
    np.save(os.path.join(output_dir, out_name), out_array)
