In [1]:
"""PATHS & DIRS"""
from pathlib import Path

# Resolve where the dataset, checkpoints and experiment logs live, depending on
# whether the notebook runs inside Google Colab or on a local machine.
try:
    # google.colab is only importable inside a Colab runtime.
    from google.colab import drive
except ImportError:
    # Local machine: everything lives under the user's home directory.
    DATA_ROOT = Path.home()/"Desktop/research"
    SAVE_DIR = Path.home()/"Desktop/projects/deep_learning_essentials/vision_transformer"
    CHECKPOINT_DIR = SAVE_DIR/"vit_cifar_checkpoints"
    EXPERIMENT_DIR = SAVE_DIR/"experiments"
else:
    # Colab: the mount runs outside the `try` so that a failed mount raises
    # instead of silently switching to the local-machine paths.
    drive.mount("/content/drive")
    DATA_ROOT = Path("/content/data")
    SAVE_DIR = Path("/content/drive/MyDrive/research/vision_transformer")
    CHECKPOINT_DIR = SAVE_DIR/"vit_caltech_checkpoints"
    EXPERIMENT_DIR = SAVE_DIR/"experiments"

# parents=True creates missing intermediate directories (e.g. a not-yet-existing
# SAVE_DIR); exist_ok=True tolerates the directory appearing between the
# exists() check and the mkdir() call.
if not DATA_ROOT.exists():
    DATA_ROOT.mkdir(parents=True, exist_ok=True)
    print("Created Data dir")
else:
    print("DATA_ROOT exists at : ", DATA_ROOT)

if not CHECKPOINT_DIR.exists():
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    print("created CKPT dir")
else:
    print("ckpt exists at : ", CHECKPOINT_DIR)

DATA_ROOT exists at :  /home/avishkar/Desktop/research
ckpt exists at :  /home/avishkar/Desktop/projects/deep_learning_essentials/vision_transformer/vit_cifar_checkpoints


In [5]:
"""DATASET"""
import torch
from torchvision import datasets
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
!pip install gdown
def to_rgb(image):
    """Return ``image`` as a 3-channel RGB image.

    Args:
        image: A PIL-style image exposing a ``mode`` attribute and a
            ``convert(mode)`` method.

    Returns:
        ``image`` itself when its mode is already ``"RGB"``; otherwise the
        result of ``image.convert('RGB')``.
    """
    if image.mode == "RGB":
        return image
    # Covers grayscale as before, and additionally every other non-RGB mode
    # (e.g. "LA", "RGBA", "CMYK"), so the rest of the pipeline always
    # receives exactly three channels.
    return image.convert('RGB')

NUM_WORKERS = 2
BATCH_SIZE = 32
TRAIN_FRACTION = 0.8  # share of samples assigned to the training subset
SPLIT_SEED = 0        # fixes the train/test partition across kernel restarts

# Per-sample preprocessing: force RGB, shrink to a fixed 64x64 grid, then
# convert to a tensor.
transform = transforms.Compose([
    transforms.Lambda(to_rgb),
    transforms.Resize((64, 64)),  # original author's note: Caltech101 pics are around 200x300
    transforms.ToTensor(),
])

dataset = datasets.Caltech101(DATA_ROOT, transform=transform, download=True)
print("Dataset size : ",len(dataset))

# torchvision's Caltech101 lists its samples category by category, so slicing
# the raw index range would put whole categories only in the test subset.
# Shuffle the indices with a seeded generator before cutting the 80/20 split.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(len(dataset), generator=split_generator).tolist()

split = int(TRAIN_FRACTION * len(dataset))
train_indices, test_indices = indices[:split], indices[split:]
assert len(train_indices) + len(test_indices) == len(dataset)

# Create training and test subsets using Subset
train_dataset = torch.utils.data.Subset(dataset, train_indices)
test_dataset = torch.utils.data.Subset(dataset, test_indices)

# Only the training loader reshuffles every epoch; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)


Defaulting to user installation because normal site-packages is not writeable
Collecting gdown
  Downloading gdown-5.2.0-py3-none-any.whl (18 kB)
Collecting tqdm
  Downloading tqdm-4.66.4-py3-none-any.whl (78 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.3/78.3 KB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
Collecting PySocks!=1.5.7,>=1.5.6
  Downloading PySocks-1.7.1-py3-none-any.whl (16 kB)
Installing collected packages: tqdm, PySocks, gdown
Successfully installed PySocks-1.7.1 gdown-5.2.0 tqdm-4.66.4


Downloading...
From (original): https://drive.google.com/uc?id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
From (redirected): https://drive.usercontent.google.com/download?id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp&confirm=t&uuid=10c234c5-e045-4081-9c85-2afbe2888d65
To: /home/avishkar/Desktop/research/caltech101/101_ObjectCategories.tar.gz
100%|██████████| 132M/132M [00:20<00:00, 6.34MB/s] 


Extracting /home/avishkar/Desktop/research/caltech101/101_ObjectCategories.tar.gz to /home/avishkar/Desktop/research/caltech101


Downloading...
From (original): https://drive.google.com/uc?id=175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m
From (redirected): https://drive.usercontent.google.com/download?id=175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m&confirm=t&uuid=a65c85de-5496-499a-a38d-2bedf678b0df
To: /home/avishkar/Desktop/research/caltech101/Annotations.tar
100%|██████████| 14.0M/14.0M [00:01<00:00, 8.45MB/s]


Extracting /home/avishkar/Desktop/research/caltech101/Annotations.tar to /home/avishkar/Desktop/research/caltech101
Dataset size :  8677


In [None]:
"""VISUALIZE DATA"""
import matplotlib.pyplot as plt

# Pull a single batch and display its first image as a quick sanity check.
imgs, labels = next(iter(train_loader))
img = imgs[0]
# ToTensor produces channel-first (C, H, W) tensors while imshow expects
# (H, W, C). `.T` would reverse all axes to (W, H, C) and draw the picture
# transposed, so reorder the axes explicitly instead.
plt.imshow(img.permute(1, 2, 0).cpu().numpy())
plt.title(f"label index: {labels[0].item()}")
plt.show()
    