In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import random
from PIL import Image
from tqdm import tqdm
import json

from datasets import load_dataset, Image

import torch
from torchvision import transforms
from torch.utils.data import DataLoader

import sys
sys.path.append("..")

import aiohttp

In [None]:
# Setup device-agnostic code: use the GPU whenever torch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG the notebook relies on so "Restart & Run All" is reproducible:
# `random` picks the sample displayed below, and torch drives the DataLoader shuffle.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Build DOCCI Dataset

In [None]:
# DOCCI is not hosted on the Hugging Face Hub, so downloading it needs a much longer
# aiohttp timeout than usual. Workaround from:
# https://github.com/huggingface/datasets/issues/7164#issuecomment-2439589751
# NOTE(review): trust_remote_code=True executes the dataset's loading script; keep it
# only for sources you trust.
docci_storage_options = {
    'client_kwargs': {
        'timeout': aiohttp.ClientTimeout(total=10000),
    },
}
docci_dataset = load_dataset(
    'google/docci',
    name='docci',
    trust_remote_code=True,
    storage_options=docci_storage_options,
)

In [None]:
# Summarize both splits, then show one randomly chosen training image with its caption.
for split_name in ('train', 'test'):
    print(docci_dataset[split_name])

train_split = docci_dataset['train']
random_sample = train_split[random.randrange(len(train_split))]

plt.imshow(np.array(random_sample['image']))
plt.axis('off')
plt.show()
print('Description:\n', random_sample['description'])

In [None]:
# Per-channel mean/std shared by both splits (the normalization the train pipeline
# already used). Test data must be normalized the same way, or any model trained on
# train tensors sees a differently-scaled input at evaluation time.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Every image is squashed to this size before the torchvision pipeline runs.
# NOTE(review): this discards most of the source resolution and the aspect ratio before
# Resize(256) upsamples it again; confirm this is intentional (e.g. for speed).
PRE_RESIZE = (100, 100)

test_compose = transforms.Compose(
    [
        # antialias=True matches what the legacy None value does for PIL inputs.
        transforms.Resize(size=[256], interpolation=transforms.InterpolationMode.BILINEAR, max_size=None, antialias=True),
        transforms.CenterCrop(size=[224]),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

train_compose = transforms.Compose(
    [
        transforms.Resize(size=[256], interpolation=transforms.InterpolationMode.BILINEAR, max_size=None, antialias=True),
        transforms.CenterCrop(size=[224]),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        # Augmentations currently disabled:
        # transforms.RandomHorizontalFlip(p=0.5),
        # transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.5),
    ]
)


def _apply_compose(images, compose):
    """Convert each PIL image to RGB, pre-resize it to PRE_RESIZE and run `compose` on it."""
    return [compose(image.convert("RGB").resize(PRE_RESIZE)) for image in images]


def transforms_test(examples):
    """Batched `datasets.map` function: add test-time `pixel_values` built from `examples["image"]`."""
    examples["pixel_values"] = _apply_compose(examples["image"], test_compose)
    return examples


def transforms_train(examples):
    """Batched `datasets.map` function: add train-time `pixel_values` built from `examples["image"]`."""
    examples["pixel_values"] = _apply_compose(examples["image"], train_compose)
    return examples

In [None]:
# Apply the preprocessing eagerly. The "image" column is dropped, so the mapped datasets
# carry pixel_values plus the remaining metadata columns.
dataset_train = docci_dataset['train'].map(transforms_train, remove_columns=["image"], batched=True)
dataset_test = docci_dataset['test'].map(transforms_test, remove_columns=["image"], batched=True)

# Show the first test example's metadata and only the *shape* of its pixels. Printing the
# raw pixel_values would dump 3 x 224 x 224 (about 150k) numbers into the notebook.
first_test_example = dataset_test[0]
print({key: value for key, value in first_test_example.items() if key != "pixel_values"})
print("pixel_values shape:", np.asarray(first_test_example["pixel_values"]).shape)

In [None]:
# Return torch tensors placed directly on `device` when rows are indexed.
dataset_train, dataset_test = (
    split.with_format("torch", device=device)
    for split in (dataset_train, dataset_test)
)

In [None]:
# NOTE: disabled alternative to the with_format("torch", device=device) calls in the
# previous cell. Passing `columns=` here would restrict the torch-formatted output to the
# listed columns, and unlike those calls it does not move tensors to `device`.
# dataset_train.set_format(type="torch", columns=["pixel_values", 'example_id', 'description'])
# dataset_test.set_format(type="torch", columns=["pixel_values", 'example_id', 'description'])

In [None]:
print(dataset_test[0]['pixel_values'].shape)
print(dataset_train[0]['pixel_values'].shape)

# The train pipeline ends with Normalize(mean, std), so the stored tensor is not in [0, 1].
# Undo that before plotting; otherwise imshow clips out-of-range values and distorts colors.
display_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
display_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

first_train_example = dataset_train[0]
preprocessed_image = (first_train_example['pixel_values'].cpu() * display_std + display_mean).clamp(0, 1)

plt.imshow(preprocessed_image.permute(1, 2, 0).numpy())
plt.title('Preprocessed (train[0])')
plt.axis('off')
plt.show()

plt.imshow(docci_dataset['train'][0]['image'])
plt.title('Original (train[0])')
plt.axis('off')
plt.show()

print('Description:\n', first_train_example['description'])

In [None]:
# Confirm both splits yield the same tensor dtype (train first, then test).
for _split in (dataset_train, dataset_test):
    print(_split[0]['pixel_values'].dtype)

In [None]:
from torch.utils.data import DataLoader

# Settings shared by both loaders: one sample per batch, loaded in the main process
# (no worker subprocesses).
_loader_kwargs = dict(batch_size=1, num_workers=0)

# Training data is shuffled; test data keeps its original order.
train_dataloader = DataLoader(dataset=dataset_train, shuffle=True, **_loader_kwargs)
test_dataloader = DataLoader(dataset=dataset_test, shuffle=False, **_loader_kwargs)

train_dataloader, test_dataloader

In [None]:
# Pull only the first batch to check the collated tensor shape.
try:
    batch = next(iter(train_dataloader))
except StopIteration:
    pass  # empty loader: nothing to show
else:
    print(batch['pixel_values'].shape)

## Evaluate Dataset After Transformations

In [None]:
import wget
import json

# Sanity-check the preprocessed images with a pretrained ImageNet CNN (ResNet-152)

# Helper function to get find index to class name
def get_imagenet_class(outputs):
  idx = outputs.argmax(dim=1).item()

  if not os.path.isfile("imagenet_class_index.json"):
    wget.download("https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json", "imagenet_class_index.json")
  with open("imagenet_class_index.json", "r") as fp:
    class_idx = json.load(fp)
  idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
  return idx2label[idx]


import torchvision.models as models
import torchvision.transforms as T

# ImageNet-pretrained ResNet-152. IMAGENET1K_V1 are the weights that the deprecated
# `pretrained=True` flag selected.
cnn_model = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)
cnn_model = cnn_model.to(device)
cnn_model.eval()

# Only inspect a handful of batches. Iterating the whole loader (batch_size=1) would plot
# one figure per training image.
NUM_SANITY_SAMPLES = 5

# pixel_values are mean/std-normalized; undo that so imshow gets values in [0, 1].
display_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
display_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# zip stops on the exhausted range first, so no extra batch is loaded.
for _, batch in zip(range(NUM_SANITY_SAMPLES), train_dataloader):
    with torch.no_grad():
        outputs = cnn_model(batch['pixel_values'])
    predicted_label = get_imagenet_class(outputs)
    print(predicted_label)

    image = (batch['pixel_values'][0].cpu() * display_std + display_mean).clamp(0, 1)
    plt.imshow(image.permute(1, 2, 0).numpy())
    plt.title(predicted_label)
    plt.axis('off')
    plt.show()

## Load Encoder Architectures