In [1]:
import torch

# Index of the GPU this notebook should run on. When that index does not exist
# (e.g. a single-GPU machine) GPU 0 is used instead.
PREFERRED_GPU_INDEX = 1

if torch.cuda.is_available():
    # Select the GPU only inside the CUDA branch: calling torch.cuda.set_device
    # unconditionally raises on machines without CUDA, which defeated the CPU
    # fallback below.
    gpu_index = PREFERRED_GPU_INDEX if torch.cuda.device_count() > PREFERRED_GPU_INDEX else 0
    torch.cuda.set_device(gpu_index)
    # 'cuda' without an index refers to the current device selected above
    device = torch.device('cuda')
    # Report the GPU that is actually used rather than always GPU 0
    print('GPU device:',torch.cuda.get_device_name(gpu_index))
else:
    device = torch.device('cpu')
    print('No GPU avaialable, Using CPU')

GPU device: Tesla V100-SXM2-32GB


In [2]:
# Seeding helper
def set_seeds(seed: int=42):
    """Seed the random number generators used by torch operations.

    Args:
        seed (int, optional): Seed value handed to each seeding call. Defaults to 42.
    """
    # First the general torch seed, then the seed for CUDA operations
    # (the ones that happen on the GPU).
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

set_seeds(42)

In [3]:
import torch
import timm
from pathlib import Path
from PIL import Image
# Pretrained ViT backbone, put into eval mode and moved onto `device`
model = timm.create_model(
    'vit_huge_patch14_224.orig_in21k',
    pretrained=True,
    num_classes=0,  # no classification head: removes the classifier nn.Linear
).eval().to(device)

# Preprocessing pipeline built from the model's own data configuration
data_config = timm.data.resolve_model_data_config(model)
data_transforms = timm.data.create_transform(is_training=False, **data_config)

In [4]:
import os
from torchvision.datasets import DatasetFolder
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image

class CustomImageFolder(DatasetFolder):
    """DatasetFolder restricted to 'png' files whose items also carry the file path."""

    def __init__(self, root, transform=None, target_transform=None):
        """Set up the folder dataset rooted at ``root`` with the RGB PIL loader.

        Args:
            root: Root directory handed to ``DatasetFolder``.
            transform: Optional callable applied to each loaded image.
            target_transform: Optional callable applied to each target.
        """
        super().__init__(
            root,
            loader=self.pil_loader,
            extensions='png',
            transform=transform,
            target_transform=target_transform,
        )

    def pil_loader(self, path: str) -> Image.Image:
        """Open the image stored at ``path`` and return it converted to RGB."""
        # Going through an explicit file handle avoids a ResourceWarning
        # (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, 'rb') as handle:
            return Image.open(handle).convert('RGB')

    def __getitem__(self, index):
        """Return ``(sample, target, path)`` for the entry at ``index``."""
        image_path, label = self.samples[index]
        image = self.loader(image_path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        # The path travels with the sample so extracted features can be traced
        # back to their source file.
        return image, label, image_path

In [5]:
# Create a CustomImageFolder instance over the training split, preprocessed
# with the model's own transforms
dataset = CustomImageFolder(root="/net/polaris/storage/deeplearning/sur_data/binary_rgb_daa/split_0/train", 
                             transform=data_transforms
                             )

# Create a DataLoader (default collate function).
# Shuffling is disabled on purpose: this loader only feeds feature extraction,
# where a random order brings no benefit and would make the order of the saved
# features depend on the RNG state and on which cells were run beforehand.
dataloader = DataLoader(dataset, 
                        batch_size=1024, 
                        shuffle=False,
                        num_workers=10,
                        drop_last=False,  # keep the final partial batch so every image is processed
                        )

# Calculate the total number of batches
total_batches = len(dataloader)
total_batches

254

In [8]:
# Pull a single batch from the loader for manual inspection; each batch unpacks
# into the images, their labels and the matching file paths.
images, labels, paths = next(iter(dataloader))

In [14]:
# Show the file path of the second sample in the inspected batch
# (swap in the commented line below to show its image instead).
# images[1]
paths[1]

'/net/polaris/storage/deeplearning/sur_data/binary_rgb_daa/split_0/train/distracted/img_033911.png'

In [6]:
from tqdm import tqdm
import pickle

# Destination for the extracted features, labels and image paths. The directory
# is created up front so that a missing output folder is reported before the
# long extraction loop starts instead of after it has finished.
output_dir = Path('/home/sur06423/hiwi/vit_exp/vision_tranformer_baseline/src/components/distraction_detection_d_b/clustering_experiments/features_split_0_kinect_rgb')
output_dir.mkdir(parents=True, exist_ok=True)

# Per-batch accumulators
feature_batches = []
label_batches = []
img_paths_batchwise = []

# Inference only: no gradients are needed for feature extraction
with torch.no_grad():
    for batch in tqdm(dataloader, total=len(dataloader)):
        images, targets, img_paths = batch
        images = images.to(device)
        features = model(images)
        # Move the result off the GPU right away so accumulated features do not
        # occupy GPU memory
        features = features.to('cpu')
        feature_batches.append(features)
        label_batches.append(targets)
        img_paths_batchwise.append(img_paths)

# Collect features and labels into single tensors along the batch dimension
all_features = torch.cat(feature_batches, dim=0)
all_gt_labels = torch.cat(label_batches, dim=0)

# Note: Paths are stored batch wise (one entry per batch), not flattened
all_img_paths = img_paths_batchwise

# Sanity check: every feature row must have a label and a path
num_paths = sum(len(batch_paths) for batch_paths in all_img_paths)
assert all_features.shape[0] == all_gt_labels.shape[0] == num_paths, \
    "features, labels and image paths are out of sync"

# Save features, labels and image paths in pickle format
for filename, payload in (
    ('all_split_0_rgb_features.pkl', all_features),
    ('all_split_0_rgb_labels.pkl', all_gt_labels),
    ('all_split_0_rgb_imagepaths.pkl', all_img_paths),
):
    with open(output_dir / filename, 'wb') as file:
        pickle.dump(payload, file)

  return F.conv2d(input, weight, bias, self.stride,
100%|██████████| 254/254 [2:01:53<00:00, 28.79s/it]  
