In [1]:
# Location of the binary RGB DAA data. Each partition directory below is read
# with torchvision's ImageFolder further down in the notebook.
# NOTE(review): absolute cluster path; consider making DATA_ROOT configurable.
DATA_ROOT = "/net/polaris/storage/deeplearning/sur_data/binary_rgb_daa"
SPLIT_INDEX = 0  # selects which numbered split_<N> directory to use

split_dir = f"{DATA_ROOT}/split_{SPLIT_INDEX}"
new_train_dir = f"{split_dir}/train"
new_val_dir = f"{split_dir}/val"
new_test_dir = f"{split_dir}/test"

In [4]:
import torch
import torchvision
import torch.nn as nn

# Start from the default pretrained ViT-B/16 weights; `weights.transforms()`
# (last line) returns the preprocessing pipeline that matches those weights.
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights)

# Freeze every pretrained parameter so only the new head below is trained.
for parameter in pretrained_vit.parameters():
    parameter.requires_grad = False

# Replace the classifier head for binary classification:
# {distracted_driver, non_distracted_driver}
# The head is created after the freeze loop, so its parameters remain trainable.
NUM_CLASSES = 2
HEAD_SEED = 42
# Seed right before building the head so its random initialisation is reproducible.
torch.manual_seed(HEAD_SEED)
# Size the head from the backbone's embedding width instead of hard-coding 768.
pretrained_vit.heads = nn.Linear(in_features=pretrained_vit.hidden_dim, out_features=NUM_CLASSES)
pretrained_vit_transforms = pretrained_vit_weights.transforms()

In [5]:
from torchvision.datasets import ImageFolder

# One ImageFolder dataset per partition, all using the ViT preprocessing.
train_dataset = ImageFolder(root=new_train_dir, transform=pretrained_vit_transforms)
val_dataset = ImageFolder(root=new_val_dir, transform=pretrained_vit_transforms)
test_dataset = ImageFolder(root=new_test_dir, transform=pretrained_vit_transforms)

# ImageFolder derives class indices from each directory's own sub-folders, so a
# partition with a missing or differently named class folder would silently get
# a different label mapping. Fail loudly instead.
for split_name, split_dataset in (("val", val_dataset), ("test", test_dataset)):
    assert split_dataset.class_to_idx == train_dataset.class_to_idx, (
        f"{split_name} class mapping {split_dataset.class_to_idx} "
        f"differs from train mapping {train_dataset.class_to_idx}"
    )
# The number of classes must match the classifier head's output size.
assert len(train_dataset.classes) == pretrained_vit.heads.out_features, (
    f"dataset has {len(train_dataset.classes)} classes but the model head "
    f"outputs {pretrained_vit.heads.out_features}"
)

In [3]:
# Report how many images each partition of split_0 contains.
for partition_name, partition_ds in (("Train", train_dataset),
                                     ("Validation", val_dataset),
                                     ("Test", test_dataset)):
    print(f"The length of the {partition_name} dataset split_0 RGB is: {len(partition_ds)}")

The length of the Train dataset split_0 RGB is: 259865
The length of the Validation dataset split_0 RGB is: 56024
The length of the Test dataset split_0 RGB is: 87315


In [7]:
# Turn the train / validation / test Datasets into DataLoaders.
from torch.utils.data import DataLoader

BATCH_SIZE = 1     # samples per batch
NUM_WORKERS = 1    # subprocesses used for data loading (higher = more parallel loading)
SHUFFLE_SEED = 42  # fixes the training shuffle order so runs are reproducible

train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              num_workers=NUM_WORKERS,
                              shuffle=True,
                              # A dedicated seeded generator keeps the shuffle order
                              # independent of global RNG use in earlier cells.
                              generator=torch.Generator().manual_seed(SHUFFLE_SEED))

# Validation and test data are read in a fixed order (no shuffling needed for evaluation).
val_dataloader = DataLoader(dataset=val_dataset,
                            batch_size=BATCH_SIZE,
                            num_workers=NUM_WORKERS,
                            shuffle=False)

test_dataloader = DataLoader(dataset=test_dataset,
                             batch_size=BATCH_SIZE,
                             num_workers=NUM_WORKERS,
                             shuffle=False)

train_dataloader, val_dataloader, test_dataloader

(<torch.utils.data.dataloader.DataLoader at 0x7f1335028310>,
 <torch.utils.data.dataloader.DataLoader at 0x7f132f413ac0>,
 <torch.utils.data.dataloader.DataLoader at 0x7f132f4138e0>)

In [8]:
# Pull a single batch from the training loader to sanity-check tensor shapes.
# The leading dimension equals the loader's batch size; change batch_size above to see it change.
first_batch = next(iter(train_dataloader))
img, label = first_batch

print(f"Image shape: {img.shape} -> [batch_size, color_channels, height, width]")
print(f"Label shape: {label.shape}")

Image shape: torch.Size([1, 3, 224, 224]) -> [batch_size, color_channels, height, width]
Label shape: torch.Size([1])


In [9]:
# torchinfo provides the `summary` function used below. Install it only when it
# is missing, and only a genuine ImportError triggers the install.
# NOTE(review): consider pinning a torchinfo version for reproducibility.
try:
    import torchinfo
except ImportError:
    import importlib
    import subprocess
    import sys

    # `sys.executable -m pip` installs into the interpreter running this kernel;
    # a bare `pip` on PATH may belong to a different Python environment.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torchinfo"])
    importlib.invalidate_caches()  # make the freshly installed package importable
    import torchinfo

In [11]:
from torchinfo import summary

# Layer-by-layer report of the modified ViT for a batch of one 3x224x224 image.
single_image_shape = [1, 3, 224, 224]
summary(pretrained_vit, input_size=single_image_shape)

  return F.conv2d(input, weight, bias, self.stride,


Layer (type:depth-idx)                        Output Shape              Param #
VisionTransformer                             [1, 2]                    768
├─Conv2d: 1-1                                 [1, 768, 14, 14]          (590,592)
├─Encoder: 1-2                                [1, 197, 768]             151,296
│    └─Dropout: 2-1                           [1, 197, 768]             --
│    └─Sequential: 2-2                        [1, 197, 768]             --
│    │    └─EncoderBlock: 3-1                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-2                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-3                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-4                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-5                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-6                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-

In [12]:
from torchinfo import summary

# Same report for a batch of 512 images; output shapes then carry a leading 512 dimension.
large_batch_shape = [512, 3, 224, 224]
summary(pretrained_vit, input_size=large_batch_shape)

Layer (type:depth-idx)                        Output Shape              Param #
VisionTransformer                             [512, 2]                  768
├─Conv2d: 1-1                                 [512, 768, 14, 14]        (590,592)
├─Encoder: 1-2                                [512, 197, 768]           151,296
│    └─Dropout: 2-1                           [512, 197, 768]           --
│    └─Sequential: 2-2                        [512, 197, 768]           --
│    │    └─EncoderBlock: 3-1                 [512, 197, 768]           (7,087,872)
│    │    └─EncoderBlock: 3-2                 [512, 197, 768]           (7,087,872)
│    │    └─EncoderBlock: 3-3                 [512, 197, 768]           (7,087,872)
│    │    └─EncoderBlock: 3-4                 [512, 197, 768]           (7,087,872)
│    │    └─EncoderBlock: 3-5                 [512, 197, 768]           (7,087,872)
│    │    └─EncoderBlock: 3-6                 [512, 197, 768]           (7,087,872)
│    │    └─EncoderBlock: 3-