In [20]:
import torch
from torch.utils.data import DataLoader, random_split
import json
from tqdm import tqdm

In [None]:
config = {
    'IMG_WIDTH': 224,
    'IMG_HEIGHT': 224,
    'TRAINING_DATASET_DIR': "/home/mcv/datasets/C5/COCO/train2014",
    'TEST_DATASET_DIR': '/home/mcv/datasets/C5/COCO/val2014',
    'num_workers': 8,
    'batch_size': 64,
    'epochs': 80,
    'learning_rate': 0.0001,
    'n_neighbors': 5,
    'type_model': 'triplet',
    'num_blocks_unfreeze': 1,
    'device': torch.device("cuda" if torch.cuda.is_available() else "cpu")
}

In [38]:
# File containing images info (file_name)
with open(f"/home/mcv/datasets/C5/COCO/captions_train2014.json", "r") as f:
    captions_train = json.load(f)

with open(f"/home/mcv/datasets/C5/COCO/captions_val2014.json", "r") as f:
    captions_val = json.load(f)

In [39]:
# Dataset for finetuning the triplet network
finetuning_dataset = CocoMetricDataset(
    root=config["TRAINING_DATASET_DIR"], 
    captions_file=captions_train,
    transforms=CustomTransform(config, mode="train"))

# Dataset for the retrieval
retrieval_dataset = CocoMetricDataset(
    root=config["TEST_DATASET_DIR"],
    captions_file=captions_val,
    transforms=CustomTransform(config, mode="val"))

total_length = len(finetuning_dataset)
train_size = int(0.6 * total_length)  # e.g., 60% for training
valid_size = int(0.2 * total_length)  # e.g., 20% for validation
test_size = total_length - train_size - valid_size # remaining 20% for testing
train_dataset, validation_dataset, test_dataset = random_split(finetuning_dataset, [train_size, valid_size, test_size])

dataloader_train = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=config["num_workers"], collate_fn=coco_collator)
dataloader_validation = DataLoader(validation_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=config["num_workers"], collate_fn=coco_collator)
dataloader_test = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=config["num_workers"], collate_fn=coco_collator)

Creating image-answer pairs...: 100%|██████████| 414113/414113 [14:00<00:00, 492.67it/s]


In [54]:
for i, (images, captions, labels) in tqdm(enumerate(dataloader_train), total=len(dataloader_train), desc="Training..."):
    images, labels = images.to(config["device"]), labels.to(config["device"])
    print(i, images.shape, len(captions), labels.shape)
    break

Training...:   0%|          | 0/3883 [00:04<?, ?it/s]

0 torch.Size([64, 3, 224, 224]) 64 torch.Size([64])



