### **Import Dependencies**

In [16]:
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, ViTModel, ViTFeatureExtractor
from datasets import load_dataset
import numpy as np
from datasets.utils.file_utils import get_datasets_user_agent
from torch.utils.data import DataLoader
from tqdm.auto import tqdm 

import warnings
warnings.filterwarnings("ignore")

In [18]:
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import io
import urllib

import PIL.Image

from datasets import load_dataset
from datasets.utils.file_utils import get_datasets_user_agent


USER_AGENT = get_datasets_user_agent()


def fetch_single_image(image_url, timeout=None, retries=0):
    for _ in range(retries + 1):
        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": USER_AGENT},
            )
            with urllib.request.urlopen(request, timeout=timeout) as req:
                image = PIL.Image.open(io.BytesIO(req.read()))
            break
        except Exception:
            image = None
    return image


def fetch_images(batch, num_threads, timeout=None, retries=0):
    fetch_single_image_with_args = partial(fetch_single_image, timeout=timeout, retries=retries)
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        batch["image"] = list(executor.map(fetch_single_image_with_args, batch["image_url"]))
    return batch


In [20]:
# Extract image features
def extract_image_features(images, model, processor, batch_size=32):
    # Preprocess images
    inputs = processor(images=images, return_tensors="pt")

    # Create a dataloader for batch processing
    dataloader = DataLoader(inputs, batch_size=batch_size)

    # Extract features in batches
    features = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = {k: v.to(device) for k, v in batch.items()} 
            outputs = model(**batch)
            features.append(outputs.pooler_output.cpu().numpy())  

    return np.concatenate(features)

In [21]:
# Etract text embeddings
def extract_text_embeddings(texts, model, tokenizer, batch_size=32):
    # Tokenize the texts
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    # Create a dataloader for batch processing
    dataloader = DataLoader(inputs, batch_size=batch_size)

    # Extract embeddings in batches
    embeddings = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = {k: v.to(device) for k, v in batch.items()} 
            outputs = model(**batch)
            embeddings.append(outputs.pooler_output.cpu().numpy())

    return np.concatenate(embeddings)

#### **Model Loading & Setup**

In [22]:
# Image Encoder
image_encoder_name = "google/vit-base-patch16-224"
image_encoder = ViTModel.from_pretrained(image_encoder_name)
image_processor = ViTFeatureExtractor.from_pretrained(image_encoder_name) 

# Text Encoder
text_encoder_name = "bert-base-uncased" 
text_encoder = AutoModel.from_pretrained(text_encoder_name)
text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_name)

# Language Model
language_model_name = "distilgpt2"
language_model = AutoModelForCausalLM.from_pretrained(language_model_name)
language_tokenizer = AutoTokenizer.from_pretrained(language_model_name)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
# Check if GPU is available and move models to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_encoder.to(device)
text_encoder.to(device)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [24]:
dataset = load_dataset("conceptual_captions")

In [25]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['image_url', 'caption'],
        num_rows: 3318333
    })
    validation: Dataset({
        features: ['image_url', 'caption'],
        num_rows: 15840
    })
})


In [28]:
print(type(dataset['train'][0]['image']))

<class 'PIL.JpegImagePlugin.JpegImageFile'>


In [27]:
num_threads = 20

# Process both splits: 'train' and 'validation'
for split_name in ['train', 'validation']:
    # Sample 10% of the dataset
    dataset[split_name] = dataset[split_name].shuffle(seed=42)
    dataset[split_name] = dataset[split_name].select(range(int(len(dataset[split_name]) * 0.1)))

    # Extract features from images and captions
    dataset[split_name] = dataset[split_name].map(
        fetch_images, 
        batched=True, 
        batch_size=100, 
        fn_kwargs={"num_threads": num_threads}
    )

    dataset[split_name] = dataset[split_name].map(
        lambda examples: {
            'image_features': extract_image_features(examples['image'], image_encoder, image_processor, batch_size=32),
            'text_embeddings': extract_text_embeddings(examples['caption'], text_encoder, text_tokenizer, batch_size=32)
        },
        batched=True,
        batch_size=100,
    )

Map: 100%|██████████| 33183/33183 [3:24:36<00:00,  2.70 examples/s]  
Map:   0%|          | 0/33183 [00:00<?, ? examples/s]


ValueError: Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or jax.ndarray.

In [29]:
# Extract image features
dataset[split_name] = dataset[split_name].map(
    lambda examples: {
    'image_features': extract_image_features(examples['image'], image_encoder, image_processor, batch_size=32)
    },
    batched=True,
    batch_size=100,
)

Map:   0%|          | 0/33183 [00:00<?, ? examples/s]


ValueError: Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or jax.ndarray.

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#image_encoder.to(device)
#text_encoder.to(device)
print(device)

cuda


In [6]:
import pickle

with open("processed_data.pkl", "wb") as f:
    pickle.dump(dataset[split], f)

In [15]:
print(dataset[split])

KeyError: "Column train not in the dataset. Current columns in the dataset: ['image_url', 'user_id', 'caption', 'image']"

### **Feature Extraction**

In [None]:
# Extract text embeddings
def extract_text_embeddings(texts, model, tokenizer, batch_size=32):
    # Tokenize the texts
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    # Dataloader for batch processing
    dataloader = DataLoader(inputs, batch_size=batch_size)

    # Extract embeddings in batches
    embeddings = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = {k: v.to(device) for k, v in batch.items()} 
            outputs = model(**batch)
            embeddings.append(outputs.pooler_output.cpu().numpy())

    return np.concatenate(embeddings)