In [None]:
import os
import torch
import timm
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm

# Define the path to your dataset
dataset_root = "/net/polaris/storage/deeplearning/sur_data/binary_rgb_daa/split_1/train"

# Define the device for computation.
# The second GPU ("cuda:1") is preferred, but only when it actually exists:
# on a single-GPU machine requesting "cuda:1" would fail later, when the model
# is moved to a non-existent device. Fall back to the first GPU, then to CPU.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Instantiate the pretrained ViT as a pure feature extractor:
# num_classes=0 removes the classifier (nn.Linear). The model is switched to
# eval mode and moved onto the selected device in one chained expression.
model = timm.create_model(
    'vit_huge_patch14_224.orig_in21k',
    pretrained=True,
    num_classes=0,
).eval().to(device)

# Build the inference-time (is_training=False) preprocessing pipeline from the
# data configuration resolved for this model.
data_config = timm.data.resolve_model_data_config(model)
data_transforms = timm.data.create_transform(is_training=False, **data_config)

In [None]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
# Wrap the image directory in a dataset that applies the model's preprocessing,
# then iterate over it in fixed order (no shuffling) in batches of 1024 images.
dataset = ImageFolder(dataset_root, transform=data_transforms)
dataloader = DataLoader(dataset=dataset, shuffle=False, batch_size=1024)

In [None]:
import pickle
# Feature extraction: run every batch through the model and group the resulting
# feature vectors by class label, then persist the result as a pickle file.
#
# Output format: a dict {class_index: float64 array of shape
# (n_samples_of_class, FEATURE_DIM)} for the class indices listed below.
FEATURE_DIM = 1280        # feature size
CLASS_INDICES = (0, 1)    # the two class labels whose features are collected

# NOTE(review): the directory name says "split_0" while the dataset path and the
# file name say "split_1" — confirm this is the intended destination.
output_path = '/home/sur06423/hiwi/vit_exp/vision_tranformer_baseline/notebooks/varience_analysis/Features_binary_split_0_kinect_rgb/features_by_class_split_1_rgb.pickle'

# Create the destination directory up front so that a missing folder is caught
# before the extraction runs, not when the file is opened afterwards.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# Collect per-batch chunks in lists and concatenate once at the end; calling
# np.concatenate inside the loop would recopy the accumulated array every batch.
feature_chunks_by_class = {class_index: [] for class_index in CLASS_INDICES}

# Extract features
with torch.no_grad():
    for inputs, labels in tqdm(dataloader):
        inputs = inputs.to(device)
        features = model(inputs).cpu().numpy()  # Convert to NumPy after moving to CPU
        labels = labels.cpu().numpy()

        # Boolean mask selects the rows of this batch that belong to each class
        for class_index, chunks in feature_chunks_by_class.items():
            chunks.append(features[labels == class_index])

# Seeding each concatenation with an empty float array keeps the result identical
# to incremental concatenation: a class with no samples yields an empty
# (0, FEATURE_DIM) array, and the stored dtype is float (float64).
features_by_class = {
    class_index: np.concatenate([np.empty((0, FEATURE_DIM), dtype=float), *chunks])
    for class_index, chunks in feature_chunks_by_class.items()
}

# Save the features to a pickle file
with open(output_path, 'wb') as handle:
    pickle.dump(features_by_class, handle, protocol=pickle.HIGHEST_PROTOCOL)

print("Features saved successfully.")