In [1]:
import os
import sys
import time
import math
import random
import torch
import torchvision

from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

In [2]:
import os
import pathlib
import torch
from torch.utils.data import Dataset
from PIL import Image
from typing import Tuple, Dict, List
import matplotlib.pyplot as plt

class BinaryClassificationDataset(Dataset):
    """Image dataset that collapses per-class folders into a binary label.

    Images are discovered with the glob ``<targ_dir>/*/*/*.png``. The name of
    the first directory level below ``targ_dir`` is the fine-grained class.
    Classes listed in ``non_distracted_classes`` get binary label 0
    ('non_distracted'); every other class gets binary label 1 ('distracted').

    Args:
        targ_dir: Root directory containing one sub-directory per class.
        transform: Optional callable applied to the loaded RGB PIL image.
        target_transform: Optional callable applied to the binary label.

    Raises:
        FileNotFoundError: If ``targ_dir`` contains no sub-directories.
    """

    def __init__(self, targ_dir: str, transform=None, target_transform=None) -> None:
        # glob() yields files in filesystem order, which is not stable across
        # machines or runs; sorting makes the index -> sample mapping reproducible.
        self.paths = sorted(pathlib.Path(targ_dir).glob("*/*/*.png"))
        self.transform = transform
        self.target_transform = target_transform
        self.classes, self.class_to_idx = self.find_classes(targ_dir)

        # Fine-grained classes that count as "not distracted".
        self.non_distracted_classes = {'sitting_still', 'entering_car', 'exiting_car'}
        self.class_to_idx_binary = {
            cls_name: 0 if cls_name in self.non_distracted_classes else 1
            for cls_name in self.classes
        }

        # Map binary labels to class names
        self.binary_label_to_class_name = {0: 'non_distracted', 1: 'distracted'}

        # Binary label of every sample, in the same order as self.paths
        self.all_binary_labels = [
            self.class_to_idx_binary[path.parent.parent.name] for path in self.paths
        ]

    def load_image(self, index: int) -> "Image.Image":
        """Open the image at ``index`` and return it converted to RGB."""
        image_path = self.paths[index]
        # The context manager closes the underlying file as soon as the pixel
        # data has been converted, instead of relying on garbage collection.
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def __len__(self) -> int:
        """Return the number of images found under the target directory."""
        return len(self.paths)

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Return the sorted sub-directory names of ``directory`` and their indices.

        Raises:
            FileNotFoundError: If ``directory`` has no sub-directories.
        """
        classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
        if not classes:
            raise FileNotFoundError(f"Couldn't find any classes in {directory}.")
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int, int, str]:
        """Return one sample as a 4-tuple.

        Returns:
            ``(image, class_idx_binary, class_idx, class_name_binary)`` where
            ``image`` is the RGB image after ``transform`` (the raw PIL image
            if no transform was given), ``class_idx_binary`` is the binary
            label after ``target_transform`` (if given), ``class_idx`` is the
            fine-grained class index and ``class_name_binary`` is
            'non_distracted' or 'distracted'.
        """
        image = self.load_image(index)
        class_name = self.paths[index].parent.parent.name
        class_idx = self.class_to_idx[class_name]
        class_idx_binary = self.class_to_idx_binary[class_name]

        # Resolve the readable name before target_transform can change the
        # label's type or value.
        class_name_binary = self.binary_label_to_class_name[class_idx_binary]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            class_idx_binary = self.target_transform(class_idx_binary)

        # Image, binary class index, fine-grained class index, binary class name
        return image, class_idx_binary, class_idx, class_name_binary

In [3]:
# Root of the dataset split. Defaults to the cluster path this notebook was
# written against; set the DAA_SPLIT_ROOT environment variable to point the
# notebook at a copy of the data somewhere else.
split_root = os.environ.get(
    "DAA_SPLIT_ROOT",
    "/net/polaris/storage/deeplearning/sur_data/rgb_daa/split_0",
).rstrip("/")

train_dir = f"{split_root}/train"
val_dir = f"{split_root}/val"
test_dir = f"{split_root}/test"

In [4]:
def prepare_dataloader(
    dataset: Dataset,
    batch_size: int,
    num_workers: int,
    prefetch_factor: int,
    shuffle: bool = False,
) -> DataLoader:
    """Wrap ``dataset`` in a ``DataLoader`` with pinned memory.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes; 0 loads in the main process.
        prefetch_factor: Forwarded to ``DataLoader`` only when
            ``num_workers > 0``.
        shuffle: Whether to reshuffle the data every epoch. Defaults to
            ``False``, which keeps the dataset order.

    Returns:
        The configured ``DataLoader``.
    """
    loader_kwargs = dict(
        batch_size=batch_size,
        pin_memory=True,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    # DataLoader raises ValueError when prefetch_factor is given without
    # worker processes, so only forward it for multi-process loading.
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = prefetch_factor
    return DataLoader(dataset, **loader_kwargs)

In [5]:
# Settings handed to prepare_dataloader for the train/val/test loaders.
prefetch_factor = 2          # forwarded as DataLoader's prefetch_factor
num_workers = 4              # forwarded as DataLoader's num_workers
batch_size_per_gpu = 1_024   # forwarded as DataLoader's batch_size

In [6]:
import torch.nn as nn

# Load ViT-B/16 with torchvision's default pretrained weights.
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights)

# Freeze the base parameters so only the new head (created below) is trainable.
for base_param in pretrained_vit.parameters():
    base_param.requires_grad = False

# Change the classifier head to match with binary classification:
# {distracted_driver, non_distracted_driver}
# The input width is read from the model (768 for ViT-B/16) rather than
# hard-coded, so the cell keeps working if the backbone variant is swapped.
pretrained_vit.heads = nn.Linear(in_features=pretrained_vit.hidden_dim, out_features=2)

# Preprocessing pipeline bundled with the pretrained weights.
pretrained_vit_transforms = pretrained_vit_weights.transforms()

In [7]:
# Build one BinaryClassificationDataset per split, all sharing the
# preprocessing that ships with the pretrained ViT weights.
train_dataset, val_dataset, test_dataset = (
    BinaryClassificationDataset(split_dir, transform=pretrained_vit_transforms)
    for split_dir in (train_dir, val_dir, test_dir)
)

In [8]:
# One loader per split, each built with the per-GPU batch size and the
# shared worker / prefetch settings.
train_dataloader, val_dataloader, test_dataloader = (
    prepare_dataloader(
        split_dataset,
        batch_size_per_gpu,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor,
    )
    for split_dataset in (train_dataset, val_dataset, test_dataset)
)