In [None]:
# Shell escape: run the `nvidia-smi` command-line tool and show its output in the cell.
!nvidia-smi

In [None]:
import torch

# Index of the GPU this notebook would like to use. It is only applied when
# CUDA is available and that index actually exists on the machine.
PREFERRED_GPU_INDEX = 1

if torch.cuda.is_available():
    # Fall back to GPU 0 when the preferred index is not present (e.g. a
    # single-GPU machine) instead of raising an invalid-device error.
    gpu_index = PREFERRED_GPU_INDEX if PREFERRED_GPU_INDEX < torch.cuda.device_count() else 0
    torch.cuda.set_device(gpu_index)
    # An index-less 'cuda' device refers to the current CUDA device, i.e. the
    # one selected by set_device() above.
    device = torch.device('cuda')
    # Report the GPU that was actually selected, not unconditionally GPU 0.
    print('GPU device:', torch.cuda.get_device_name(gpu_index))
else:
    # No CUDA: never call torch.cuda.set_device here, it would raise.
    device = torch.device('cpu')
    print('No GPU avaialable, Using CPU')

In [None]:
import os
import pathlib
from typing import Tuple, Dict, List

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class BinaryClassificationDataset(Dataset):
    """Dataset over PNG files found at ``<targ_dir>/<class>/<subdir>/<file>.png``.

    The name of each immediate sub-directory of ``targ_dir`` is treated as the
    original class name. Every original class is additionally mapped to a
    binary label: 0 ('non_distracted') for the classes listed in
    ``non_distracted_classes`` and 1 ('distracted') for all others.

    Each item is a ``(sample, features)`` tuple where ``sample`` is a dict of
    metadata (path, original/binary class index and name) and ``features`` is a
    NumPy array. ``features`` is empty unless ``transform``,
    ``feature_extractor`` and ``device`` were all supplied.
    """

    def __init__(self, targ_dir: str, transform=None, feature_extractor=None, device=None) -> None:
        """Index the image files and build the class lookup tables.

        Args:
            targ_dir: Root directory that contains one sub-directory per class.
            transform: Optional callable applied to the PIL image; its result
                must support ``.unsqueeze(0).to(device)``.
            feature_extractor: Optional callable applied to the batched tensor.
            device: Optional device the tensor is moved to before extraction.

        Raises:
            FileNotFoundError: If ``targ_dir`` contains no sub-directories.
        """
        # Sorted so that index -> file is reproducible; glob order depends on
        # the underlying filesystem.
        self.paths = sorted(pathlib.Path(targ_dir).glob("*/*/*.png"))
        self.transform = transform
        self.feature_extractor = feature_extractor
        self.device = device
        self.classes, self.class_to_idx = self.find_classes(targ_dir)

        self.non_distracted_classes = {'sitting_still', 'entering_car', 'exiting_car'}
        self.class_to_idx_binary = {
            cls_name: 0 if cls_name in self.non_distracted_classes else 1
            for cls_name in self.classes
        }

        # Map binary labels to class names
        self.binary_label_to_class_name = {0: 'non_distracted', 1: 'distracted'}

    def load_image(self, index: int) -> "Image.Image":
        """Open the image at ``index`` and return it converted to RGB."""
        # The context manager releases the file handle deterministically;
        # convert() returns an independent, fully loaded image.
        with Image.open(self.paths[index]) as image:
            return image.convert("RGB")

    def __len__(self) -> int:
        """Return the number of indexed image files."""
        return len(self.paths)

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Return the sorted sub-directory names of ``directory`` and their indices.

        Raises:
            FileNotFoundError: If ``directory`` has no sub-directories.
        """
        with os.scandir(directory) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        if not classes:
            raise FileNotFoundError(f"Couldn't find any classes in {directory}.")
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Dict, np.ndarray]:
        """Return ``(sample, features)`` for the file at ``index``.

        ``sample`` holds the image path and the original/binary class index and
        name. ``features`` is the extractor output as a NumPy array (batch
        dimension squeezed), or an empty array when extraction is not
        configured.
        """
        image_path = self.paths[index]
        # The class is the directory two levels above the file (see glob pattern).
        class_name = image_path.parent.parent.name
        class_idx = self.class_to_idx[class_name]
        class_idx_binary = self.class_to_idx_binary[class_name]
        class_name_binary = self.binary_label_to_class_name[class_idx_binary]

        # Prepare sample information
        sample = {
            'image_path': str(image_path),
            'class_idx_binary': class_idx_binary,
            'class_idx_original': class_idx,
            'class_name_binary': class_name_binary,
            'class_name_original': class_name,
        }

        # Default to an empty array if no feature extraction is performed
        features = np.array([])
        extraction_configured = (
            self.feature_extractor is not None
            and self.device is not None
            and self.transform is not None
        )
        if extraction_configured:
            # Decode the image only when it is actually needed; metadata-only
            # iteration then does no image I/O at all.
            image = self.load_image(index)
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                features = self.feature_extractor(image_tensor)
                features = features.squeeze(0).cpu().numpy()

        return sample, features

In [None]:
import os
import torch
import numpy as np
import timm
from torchvision import transforms
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from multiprocessing import Pool
import pickle

def load_and_preprocess_image(image_path, transform):
    """Open ``image_path``, convert it to RGB and return ``transform`` applied to it."""
    rgb_image = Image.open(image_path).convert("RGB")
    preprocessed = transform(rgb_image)
    return preprocessed

def extract_features(sample_info, model, data_transforms, device):
    """Run ``model`` on the image referenced by ``sample_info['image_path']``.

    The image is preprocessed with ``data_transforms``, given a leading batch
    dimension, moved to ``device`` and passed through ``model`` with gradient
    tracking disabled.

    Returns:
        A ``(feature, sample_info)`` tuple; ``feature`` is the model output as
        a NumPy array with all size-1 dimensions squeezed out.
    """
    preprocessed = load_and_preprocess_image(sample_info['image_path'], data_transforms)
    batch = preprocessed.unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(batch)
        feature = output.cpu().numpy()
    squeezed = feature.squeeze()
    return squeezed, sample_info

def worker_init_fn(worker_id):
    """Seed NumPy's global RNG from torch's initial seed, reduced modulo 2**32.

    ``worker_id`` is accepted for the DataLoader callback signature but is not
    used.
    """
    numpy_seed = torch.initial_seed() % 2**32
    np.random.seed(numpy_seed)

def process_chunk(chunk, model_name, data_transforms, gpu_id):
    """Extract features for every sample in ``chunk`` on GPU ``gpu_id``.

    A pretrained timm model named ``model_name`` is created with
    ``num_classes=0``, switched to eval mode and moved to ``cuda:<gpu_id>``.

    Returns:
        A list with one ``extract_features`` result per element of ``chunk``.
    """
    device = f'cuda:{gpu_id}'
    backbone = timm.create_model(model_name, pretrained=True, num_classes=0)
    model = backbone.eval().to(device)

    processed_samples = []
    for sample in chunk:
        processed_samples.append(extract_features(sample, model, data_transforms, device))
    return processed_samples

def parallel_feature_extraction(dataset, model_name, num_gpus=4):
    data_config = timm.data.resolve_model_data_config(model_name)
    data_transforms = timm.data.create_transform(**data_config, is_training=False)
    
    # Create a DataLoader to handle batching and multiprocessing
    data_loader = DataLoader(dataset, batch_size=len(dataset) // num_gpus, shuffle=False, num_workers=num_gpus, worker_init_fn=worker_init_fn)
    
    # Using Pool to manage GPU allocation
    with Pool(num_gpus) as p:
        results = p.starmap(process_chunk, [(chunk, model_name, data_transforms, gpu_id % num_gpus) for gpu_id, chunk in enumerate(data_loader)])
    
    # Flattening the list of results
    all_features, all_samples_info = zip(*[item for sublist in results for item in sublist])
    
    return all_samples_info, all_features

def save_to_pickle(data, filename):
    """Pickle ``data`` into ``filename``, opened in binary write mode."""
    with open(filename, mode='wb') as sink:
        pickle.Pickler(sink).dump(data)

# Driver: index the training split, extract features on 4 GPUs and pickle the result.
if __name__ == "__main__":
    # Hard-coded absolute path to the training split; adjust for other machines.
    dataset_root = "/net/polaris/storage/deeplearning/sur_data/rgb_daa/split_0/train"
    # No transform / feature extractor / device is passed, so the dataset itself
    # only yields sample metadata together with an empty feature array.
    binary_dataset = BinaryClassificationDataset(dataset_root)
    # timm model name handed to parallel_feature_extraction.
    model_name = 'vit_huge_patch14_224.orig_in21k'
    all_samples_info, all_features = parallel_feature_extraction(binary_dataset, model_name, num_gpus=4)
    # Written relative to the current working directory.
    save_to_pickle({'samples_info': all_samples_info, 'features': all_features}, 'aggregated_features.pkl')
