In [None]:
####IMPORTING LIBRARIES########
import pickle

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import BlipProcessor, Blip2ForConditionalGeneration
from transformers import Blip2Processor

In [None]:
#####HELPER FUNCTIONS#######
def imshow(img):
    """Display an image tensor with matplotlib.

    Args:
        img: A torch tensor. The axes are reordered with ``(1, 2, 0)`` before
            plotting, so a channel-first ``(C, H, W)`` layout is assumed.
    """
    # .numpy() raises for tensors that live on the GPU or that track gradients,
    # so detach and copy to host memory first (both are no-ops for plain CPU
    # tensors).
    npimg = img.detach().cpu().numpy()
    # matplotlib expects channel-last (H, W, C) data.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

In [None]:
####Loading the Caption Generation Model####
def create_processor_and_model():
  """Load the BLIP-2 (Flan-T5-XL) processor and caption-generation model.

  Returns:
    tuple: ``(processor, model)``. The model is moved to CUDA when a GPU is
    available and to the CPU otherwise.
  """
  checkpoint = "Salesforce/blip2-flan-t5-xl"
  device = "cuda" if torch.cuda.is_available() else "cpu"
  # Half precision is only requested on the GPU; on the CPU fallback the
  # weights stay in float32 because many CPU kernels lack float16 support.
  dtype = torch.float16 if device == "cuda" else torch.float32

  # This is a BLIP-2 checkpoint that ships a T5 tokenizer. The BLIP-1
  # BlipProcessor resolves to a BERT tokenizer class for it and emits a
  # tokenizer-class mismatch warning, so the BLIP-2 processor is used instead.
  processor = Blip2Processor.from_pretrained(checkpoint)
  model = Blip2ForConditionalGeneration.from_pretrained(checkpoint, torch_dtype=dtype)

  model.to(device)
  return processor, model

In [None]:
####Loading the Train and Test CIFAR Dataset####
def get_dataset_and_loader(batch_size=4, root='./data', num_workers=2):
  """Build the CIFAR-10 train/test datasets and their data loaders.

  Args:
    batch_size: Number of images per loader batch.
    root: Directory where CIFAR-10 is stored (downloaded there if missing).
    num_workers: Number of worker processes used by each loader.

  Returns:
    tuple: ``(trainset, trainloader, testset, testloader)``.
  """
  transform = transforms.Compose([transforms.ToTensor()])

  def _load_split(train):
    # Build one CIFAR-10 split together with its loader.
    split = torchvision.datasets.CIFAR10(root=root, train=train,
                                         download=True, transform=transform)
    # shuffle=False keeps the loader in dataset order, so anything collected
    # while iterating the loader lines up index-for-index with the dataset.
    loader = DataLoader(split, batch_size=batch_size,
                        shuffle=False, num_workers=num_workers)
    return split, loader

  trainset, trainloader = _load_split(train=True)
  testset, testloader = _load_split(train=False)
  return trainset, trainloader, testset, testloader

In [None]:
####Custom Dataset for the Images with their Captions####
class CIFAR10WithCaptions(Dataset):
    """Dataset that pairs each ``(image, label)`` item with a caption.

    Item ``idx`` is ``(image, label, caption)``, where ``image, label`` come
    from ``cifar_dataset[idx]`` and ``caption`` is ``captions[idx]``.
    """

    def __init__(self, cifar_dataset, captions):
        """Store the wrapped dataset and its captions.

        Args:
            cifar_dataset: Sized, indexable dataset yielding ``(image, label)``.
            captions: Sized, indexable collection with one caption per item.

        Raises:
            ValueError: If the number of captions differs from the dataset size.
        """
        # Fail at construction time instead of with an IndexError deep inside
        # a later __getitem__ call.
        if len(captions) != len(cifar_dataset):
            raise ValueError(
                f"Expected one caption per image: got {len(captions)} captions "
                f"for {len(cifar_dataset)} images"
            )
        self.cifar_dataset = cifar_dataset
        self.captions = captions

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.cifar_dataset)

    def __getitem__(self, idx):
        """Return ``(image, label, caption)`` for position ``idx``."""
        image, label = self.cifar_dataset[idx]
        caption = self.captions[idx]
        return image, label, caption

In [None]:
####Generating captions and creating a new pickle file for the dataset####
def generate_caption_dataset(dataset, dataset_loader, processor, model):
  """Caption every image from ``dataset_loader`` and attach the captions to ``dataset``.

  Args:
    dataset: Dataset of ``(image, label)`` items; wrapped unchanged in the result.
    dataset_loader: Iterable of ``(images, labels)`` batches over ``dataset``.
      Captions are stored by position, so the loader must iterate in dataset
      order (no shuffling) for caption ``i`` to describe ``dataset[i]``.
    processor: Processor used to encode the image and prompt and to decode the
      generated token ids.
    model: Generation model. Its ``device`` and ``dtype`` attributes decide
      where the inputs are placed (assumes a Hugging Face model exposing both).

  Returns:
    CIFAR10WithCaptions: ``dataset`` paired with one generated caption per image.
  """
  # Follow the model's placement instead of hard-coding "cuda", so the CPU
  # fallback chosen when the model was loaded also works here.
  device = model.device
  dtype = model.dtype

  # Loop-invariant objects are built once rather than once per image.
  to_pil = transforms.ToPILImage()
  question = "Describe everything in this image"

  captions = []
  for images, _ in dataset_loader:
    for image in images:
      # The loader yields tensors; the processor is given a PIL image.
      inputs = processor(to_pil(image), question, return_tensors="pt").to(device, dtype)

      out = model.generate(**inputs)
      captions.append(processor.decode(out[0], skip_special_tokens=True))
  return CIFAR10WithCaptions(dataset, captions)

In [None]:
# Load the captioning model once and share it across both CIFAR-10 splits.
processor, model = create_processor_and_model()
trainset, trainloader, testset, testloader = get_dataset_and_loader()

# Caption both splits before anything is written to disk.
cifar10_train_with_captions = generate_caption_dataset(trainset, trainloader, processor, model)
cifar10_test_with_captions = generate_caption_dataset(testset, testloader, processor, model)

# Persist each captioned split to its own pickle file (train first, then test).
for pickle_path, captioned_split in (
    ('CAPTIONED_CIFAR_TRAIN.pkl', cifar10_train_with_captions),
    ('CAPTIONED_CIFAR_TEST.pkl', cifar10_test_with_captions),
):
    with open(pickle_path, 'wb') as file:
        pickle.dump(captioned_split, file)



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'BertTokenizerFast'.
Some kwargs in processor config are unused and will not have any effect: num_query_tokens. 


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Files already downloaded and verified
Files already downloaded and verified


RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx