In [None]:
import torch

# Pick the compute device for the rest of the notebook.
# The second GPU (index 1) is preferred, but it is only selected when it
# actually exists; single-GPU hosts fall back to GPU 0 and hosts without
# CUDA fall back to the CPU instead of crashing in torch.cuda.set_device.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    torch.cuda.set_device(gpu_index)
    # A 'cuda' device without an index resolves to the current device set above.
    device = torch.device('cuda')
    # Report the GPU that was really selected, not unconditionally GPU 0.
    print('GPU device:', torch.cuda.get_device_name(gpu_index))
else:
    device = torch.device('cpu')
    print('No GPU avaialable, Using CPU')

In [None]:
# Announce how many GPUs are visible, but only on multi-GPU hosts
_visible_gpus = torch.cuda.device_count()
if _visible_gpus > 1:
    print("Let's use", _visible_gpus, "GPUs!")

In [None]:
# Set seeds
def set_seeds(seed: int=42):
    """Sets random seeds for torch operations on the CPU and on every GPU.

    Args:
        seed (int, optional): Random seed to set. Defaults to 42.
    """
    # Set the seed for general torch operations
    torch.manual_seed(seed)
    # Set the seed for CUDA torch operations on *all* visible GPUs, not only
    # the current one (the call is a no-op when CUDA is unavailable)
    torch.cuda.manual_seed_all(seed)

set_seeds(42)

In [None]:
import os
import getpass
import sys

def setup_ccname():
    """Export KRB5CCNAME from the cache file kept alongside k5start.

    Reads the k5start PID from ``/tmp/k5pid_<user>`` and probes that process
    with signal 0. Then reads ``/tmp/kccache_<user>``, takes the text after
    the first ``=`` and stores it in ``os.environ['KRB5CCNAME']``.

    Raises:
        SystemExit: With code 1, after writing a message to stderr, when the
            PID file is missing or invalid, the process cannot be signalled,
            or the cache file is missing or has no ``=`` in it.
    """
    user = getpass.getuser()
    # check if k5start is running, exit otherwise
    try:
        # Context manager so the handle is closed on every call (this runs
        # once per batch in the extraction loop).
        with open("/tmp/k5pid_" + user) as pid_file:
            pid = pid_file.read().strip()
        # Signal 0 delivers nothing; it only checks the PID can be signalled.
        os.kill(int(pid), 0)
    except (OSError, ValueError, OverflowError):
        # Only expected failures are handled, so KeyboardInterrupt and other
        # unrelated errors are no longer swallowed.
        sys.stderr.write("Unable to setup KRB5CCNAME!\nk5start not running!\n")
        sys.exit(1)
    try:
        with open("/tmp/kccache_" + user) as cache_file:
            # Split once so cache names that themselves contain '=' survive.
            ccname = cache_file.read().split("=", 1)[1].strip()
        os.environ['KRB5CCNAME'] = ccname
    except (OSError, IndexError, ValueError):
        sys.stderr.write("Unable to setup KRB5CCNAME!\nmaybe k5start not running?\n")
        sys.exit(1)

In [None]:
import glob
import os
from typing import Dict, List, Tuple

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader

class CustomImageDataset(Dataset):
    """
    Custom dataset for loading image data from a directory.

    Every immediate subdirectory of ``root_dir`` is one class, and every file
    found recursively beneath it is one sample of that class. Classes and
    samples are kept in sorted order so that indices are reproducible.

    Args:
        root_dir (str): Root directory containing class subdirectories.
        transform (callable, optional): A function/transform to apply to the image.

    Attributes:
        classes: Sorted class (subdirectory) names.
        class_to_idx: Mapping from class name to integer label.
        samples: List of ``(path, class_index)`` pairs.
        class_ratios: Fraction of all samples that belongs to each class.
    """

    def __init__(self, root_dir: str, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes, self.class_to_idx = self._find_classes(root_dir)
        self.samples = self._load_samples()
        self.class_ratios = self._calculate_class_ratios()

    def _find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Return the sorted subdirectory names and their name -> index map."""
        classes = sorted(d.name for d in os.scandir(directory) if d.is_dir())
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def _load_samples(self) -> List[Tuple[str, int]]:
        """Collect ``(path, class_index)`` for every file under each class directory."""
        samples = []
        for target_class in self.classes:
            class_dir = os.path.join(self.root_dir, target_class)
            class_idx = self.class_to_idx[target_class]
            for root, dirnames, fnames in os.walk(class_dir):
                # os.walk yields entries in arbitrary filesystem order; sorting
                # dirnames in place also fixes the order of the traversal.
                dirnames.sort()
                for fname in sorted(fnames):
                    samples.append((os.path.join(root, fname), class_idx))
        return samples

    def _calculate_class_ratios(self) -> List[float]:
        """Return, per class index, its share of the total number of samples."""
        class_counts = [0] * len(self.classes)
        for _, class_idx in self.samples:
            class_counts[class_idx] += 1

        total_samples = len(self.samples)
        if total_samples == 0:
            # Class directories without any files: avoid dividing by zero.
            return [0.0] * len(class_counts)

        return [count / total_samples for count in class_counts]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, class_index, image_path)`` for sample ``idx``."""
        img_path, target = self.samples[idx]
        # convert() returns a fully loaded copy, so the underlying file can be
        # closed right away instead of lingering until garbage collection.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, target, img_path

In [None]:
import torch
import timm
from pathlib import Path
from PIL import Image
# Build the pretrained ViT as a feature extractor (num_classes=0 removes the
# classifier nn.Linear), switch it to eval mode and move it to the device
model = timm.create_model(
    'vit_huge_patch14_224.orig_in21k', pretrained=True, num_classes=0
).eval().to(device)

# Derive the inference-time preprocessing from the model's own data config
data_config = timm.data.resolve_model_data_config(model)
data_transforms = timm.data.create_transform(is_training=False, **data_config)

In [None]:
# Dataset over the split_1 training directory, preprocessed with the model's transforms
dataset = CustomImageDataset(
    root_dir='/net/polaris/storage/deeplearning/sur_data/rgb_daa/split_1/train',
    transform=data_transforms,
)

# Shuffled, batched loader using the default collation; the last partial batch is kept
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=1024,
    num_workers=16,
    drop_last=False,
)

# Number of batches in one full pass, shown as the cell output
total_batches = len(dataloader)
total_batches

In [None]:
from tqdm import tqdm
import pickle

# Per-batch accumulators: extracted features, ground-truth labels and image paths
all_features, all_gt_labels, img_paths_batchwise = [], [], []

# Gradients are not needed for feature extraction
with torch.no_grad():
    for images, targets, img_paths in tqdm(dataloader, total=len(dataloader)):
        # Re-export KRB5CCNAME before every batch (exits if k5start is not running)
        setup_ccname()
        # Run the model on the device, then bring the features back to the CPU
        features = model(images.to(device)).to('cpu')
        all_features.append(features)
        all_gt_labels.append(targets)
        img_paths_batchwise.append(img_paths)

# Merge the per-batch tensors along the batch dimension
all_features = torch.cat(all_features, dim=0)
all_gt_labels = torch.cat(all_gt_labels, dim=0)

# Image paths stay grouped per batch (one entry per batch)
all_img_paths = img_paths_batchwise

# Persist each result as its own pickle file ::: Change names in the paths
for out_path, payload in (
    ('/home/sur06423/hiwi/vit_exp/vision_tranformer_baseline/src/components/Trainer_D_2/Dataloader_2/features_store/rgb_split_1_daa/all_d2_s1_features.pkl', all_features),
    ('/home/sur06423/hiwi/vit_exp/vision_tranformer_baseline/src/components/Trainer_D_2/Dataloader_2/features_store/rgb_split_1_daa/all_d2_s1_labels.pkl', all_gt_labels),
    ('/home/sur06423/hiwi/vit_exp/vision_tranformer_baseline/src/components/Trainer_D_2/Dataloader_2/features_store/rgb_split_1_daa/all_d2_s1_imagepaths.pkl', all_img_paths),
):
    with open(out_path, 'wb') as file:
        pickle.dump(payload, file)



In [None]:
import pickle
# Restore the saved features, labels and image paths from their pickle files
with open('/home/sur06423/hiwi/vit_exp/vision_tranformer_baseline/src/components/Trainer_D_2/Dataloader_2/features_store/rgb_split_1_daa/all_d2_s1_features.pkl', 'rb') as features_fh:
    all_features_loaded = pickle.load(features_fh)

with open('/home/sur06423/hiwi/vit_exp/vision_tranformer_baseline/src/components/Trainer_D_2/Dataloader_2/features_store/rgb_split_1_daa/all_d2_s1_labels.pkl', 'rb') as labels_fh:
    all_gt_labels_loaded = pickle.load(labels_fh)

with open('/home/sur06423/hiwi/vit_exp/vision_tranformer_baseline/src/components/Trainer_D_2/Dataloader_2/features_store/rgb_split_1_daa/all_d2_s1_imagepaths.pkl', 'rb') as paths_fh:
    all_img_paths_loaded = pickle.load(paths_fh)

In [None]:
# Sanity check: shape of the feature tensor restored from the pickle file
# (presumably (num_images, feature_dim) — TODO confirm)
all_features_loaded.shape

In [None]:
# Sanity check: shape of the label tensor restored from the pickle file
# (presumably one label per image — TODO confirm it matches the features)
all_gt_labels_loaded.shape

In [None]:
# Number of saved path groups (the paths were stored one entry per batch)
len(all_img_paths_loaded)

In [None]:
# Class names in label-index order, as discovered by the dataset
classes_in_dataset = dataset.classes
class_at_index_8 = classes_in_dataset[8]
print(f"The class with 8th index in the DAA Image split_1 train dataset are:{class_at_index_8}")

In [None]:
from PIL import Image

# Open the first image of the first saved batch; the bare name renders it as cell output
first_batch_paths = all_img_paths_loaded[0]
first_path = first_batch_paths[0]
img = Image.open(first_path)
img