In [None]:
from datasets import load_from_disk, Dataset
import torch

In [None]:
# Load the COCO validation split (images + captions) saved earlier with `save_to_disk`.
COCO_VAL_IMAGES_PATH = "../../../datasets/coco_val_images"
original_ds = load_from_disk(COCO_VAL_IMAGES_PATH)

In [None]:
# Inspect the loaded dataset; the cells below read its "image" and "captions" columns.
original_ds

In [None]:
# Report the visible CUDA devices: first the count, then their names (0 and [] on a CPU-only host).
visible_gpu_count = torch.cuda.device_count()
visible_gpu_names = list(map(torch.cuda.get_device_name, range(visible_gpu_count)))
print(visible_gpu_count)
print(visible_gpu_names)

In [None]:
from torchvision import models, transforms

# Choose the compute device once; the model and every input tensor are placed here.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load an ImageNet-pretrained ResNet-50, switch it to eval mode and move it to
# DEVICE a single time (previously this move happened on every per-image call).
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=models.ResNet50_Weights.IMAGENET1K_V1`; kept here for
# compatibility with older installs — confirm the torchvision version in use.
resnet = models.resnet50(pretrained=True)
resnet.eval()
resnet.to(DEVICE)

# Image preprocessing: force 3-channel RGB (handles greyscale/RGBA inputs),
# resize the short side to 256, center-crop to 224x224, convert to a tensor and
# normalize with the ImageNet channel mean/std.
preprocess = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

def image_to_features(example):
    """Run a single dataset example's image through ResNet-50.

    Args:
        example: a dataset row (mapping) whose "image" entry is a PIL image.

    Returns:
        dict with key "image_tensor": the model output for that image as a
        CPU numpy array, with size-1 dimensions (the batch dim) squeezed away.
    """
    image = example["image"]
    # Add a batch dimension of 1 and place the input on the model's device.
    input_tensor = preprocess(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        # NOTE(review): calling the full torchvision ResNet returns the final
        # classifier layer's output (ImageNet logits), not the pooled backbone
        # embedding — confirm that is the intended "features".
        features = resnet(input_tensor)
    return {"image_tensor": features.cpu().squeeze().numpy()}

In [None]:
# import os

# print(os.cpu_count())

In [None]:
def to_caption_feature_row(example):
    """Build an output row: the example's captions plus its image features.

    `image_to_features` already returns a numpy array, and `datasets` stores
    numpy arrays directly, so no torch.tensor round-trip is needed.
    """
    return {
        "captions": example["captions"],
        "features": image_to_features(example)["image_tensor"],
    }


# Drop every source column except "captions" (including the raw images), so the
# resulting dataset holds only captions and the newly computed "features".
columns_to_drop = [col for col in original_ds.column_names if col not in ["captions"]]

new_ds = original_ds.map(
    to_caption_feature_row,
    remove_columns=columns_to_drop,
    # NOTE(review): with a CUDA-resident model, num_proc > 1 typically needs the
    # "spawn" start method (forked workers cannot re-initialise CUDA) — verify
    # before enabling.
    # num_proc=24
)

In [None]:
# Persist the captions + features dataset so later notebooks can reload it with load_from_disk.
COCO_VAL_FEATURES_PATH = "../../../datasets/coco_val_features"
new_ds.save_to_disk(COCO_VAL_FEATURES_PATH)

In [None]:
# Sanity check: length of the stored feature vector for the first example
# (presumably 1000, the width of ResNet-50's classifier output — confirm).
len(new_ds[0]['features'])

In [None]:
# dataset[350]['captions']