In [1]:
import os

In [2]:
# Show the kernel's working directory before it is changed in the next cell.
%pwd

'c:\\Users\\rahul\\Desktop\\Project\\CaptionAI\\research'

In [3]:
# Step up to the project root so the relative config/artifact paths used by
# later cells resolve. Guarded on the directory name so that re-running this
# cell does not climb a second level above the project.
if os.path.basename(os.getcwd()) == "research":
    os.chdir("../")

In [4]:
# Confirm the working directory after the chdir above.
%pwd

'c:\\Users\\rahul\\Desktop\\Project\\CaptionAI'

Entity

In [5]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class CustomDatasetConfig:
    """Immutable bundle of settings for the custom-dataset stage.

    The dataclass is frozen, so fields cannot be reassigned once an
    instance has been constructed.
    """

    # Working directory of this stage; ConfigurationManager.get_dataset_config
    # creates it before building the config object.
    root_dir: Path

    # Handed to FlickrDataset as ``image_dir``.
    image_dir: Path

    # Handed to FlickrDataset as ``caption_file``.
    caption_file: Path

    # Handed to FlickrDataset as ``vocab_file``.
    vocab: Path

    # Destination file that DatasetCreation.save_dataset pickles batches into.
    save_file_path: Path

    # ConfigurationManager currently always sets this to False; nothing in
    # this notebook reads it.
    transform: bool

Configuration

In [6]:
from CaptionAI.constants import *
from CaptionAI.utils.common import read_yaml, create_directories

In [7]:
class ConfigurationManager:
    """Reads the project's YAML files and builds per-stage config objects."""

    def __init__(self,
                 config_file_path = CONFIG_FILE_PATH,
                 params_file_path = PARAMS_FILE_PATH):
        """Load both YAML files and create the artifacts root directory.

        Args:
            config_file_path: Location of the main config YAML.
            params_file_path: Location of the params YAML. It is loaded and
                kept on ``self.params`` but not otherwise used by this class.
        """
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)

        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_dataset_config(self) -> CustomDatasetConfig:
        """Build the config object for the custom-dataset stage.

        Creates the stage's ``root_dir`` as a side effect, then copies the
        ``custom_dataset`` section of the loaded config into a
        ``CustomDatasetConfig``.

        Returns:
            CustomDatasetConfig populated from the ``custom_dataset`` section,
            with ``transform`` fixed to False.
        """
        section = self.config.custom_dataset
        create_directories([section.root_dir])

        # Values are passed through exactly as read_yaml returned them.
        return CustomDatasetConfig(
            root_dir=section.root_dir,
            image_dir=section.image_dir,
            caption_file=section.caption_file,
            vocab=section.vocab,
            save_file_path=section.save_file_path,
            transform=False,
        )

Components

In [8]:
from CaptionAI.utils.dataset import FlickrDataset, generate_batch_captions
from torch.utils.data import DataLoader
from CaptionAI import logger
import pickle
from tqdm import tqdm

In [9]:
class DatasetCreation:
    """Builds the Flickr caption dataset, batches it, and pickles the batches.

    Typical use: construct the object (which builds the dataset), call
    ``create_dataloader()``, then ``save_dataset()``.
    """

    def __init__(self, config: CustomDatasetConfig):
        """Store the config and build the underlying dataset immediately.

        Args:
            config: Supplies ``image_dir``, ``caption_file``, ``vocab`` and
                ``save_file_path`` for this stage.
        """
        self.config = config
        self.flickr_dataset = None
        # Defined up front so save_dataset() can tell that no loader has been
        # built yet instead of failing with an AttributeError.
        self.data_loader = None
        self._create_dataset()

    def _create_dataset(self):
        """Instantiate ``FlickrDataset`` from the configured paths."""
        logger.info("Creating the Custom Dataset.")
        self.flickr_dataset = FlickrDataset(
            image_dir = self.config.image_dir,
            caption_file = self.config.caption_file,
            vocab_file = self.config.vocab
        )
        logger.info("Custom Dataset is created.")

    def create_dataloader(self,
                          batch_size: int = 512,
                          pad_idx: int = 1,
                          shuffle: bool = True):
        """Wrap the dataset in a ``DataLoader`` and keep it on ``self.data_loader``.

        Args:
            batch_size: Number of samples per batch.
            pad_idx: Forwarded to ``generate_batch_captions``; presumably the
                vocabulary index used for padding captions (confirm in
                CaptionAI.utils.dataset).
            shuffle: Whether the loader reshuffles the samples on each pass.
        """
        logger.info("Data Loader is getting created.")
        self.data_loader = DataLoader(
            dataset = self.flickr_dataset,
            batch_size = batch_size,
            shuffle = shuffle,
            collate_fn = generate_batch_captions(pad_idx = pad_idx,
                                                 batch_first = True)
        )
        logger.info("Data Loader is created.")

    def save_dataset(self):
        """Iterate the loader once and pickle the resulting list of batches.

        If ``create_dataloader()`` has not been called yet, a loader is built
        with the default settings first. Every batch is held in memory until
        the single ``pickle.dump`` call at the end.
        """
        if self.data_loader is None:
            logger.info("No Data Loader found; creating one with default settings.")
            self.create_dataloader()

        logger.info("Saving dataset batches...")
        all_batches = list(
            tqdm(self.data_loader, desc = "Saving Batches", unit = "batch")
        )

        # Make sure the destination folder exists; save_file_path is configured
        # independently of root_dir, so it is not guaranteed to be there.
        Path(self.config.save_file_path).parent.mkdir(parents = True, exist_ok = True)

        with open(self.config.save_file_path, "wb") as f:
            pickle.dump(all_batches, f)

        logger.info(f"DataLoader batches saved to {self.config.save_file_path}")


Pipeline

In [10]:
# End-to-end run of the dataset stage:
# load config -> build dataset -> build loader -> pickle the batches.
try:
    config = ConfigurationManager()
    custom_dataset_config = config.get_dataset_config()
    create_dataset = DatasetCreation(config = custom_dataset_config)
    create_dataset.create_dataloader()
    create_dataset.save_dataset()
except Exception:
    # Record the failure and its traceback in the project log, then re-raise
    # the same exception so the cell still fails visibly.
    logger.exception("Custom dataset pipeline failed.")
    raise

[2024-12-08 10:26:53,693: INFO: common: yaml file: config\config.yaml loaded successfully]
[2024-12-08 10:26:53,695: INFO: common: yaml file: params.yaml loaded successfully]
[2024-12-08 10:26:53,696: INFO: common: created directory at: artifacts]
[2024-12-08 10:26:53,696: INFO: common: created directory at: artifacts/custom_dataset]
[2024-12-08 10:26:53,698: INFO: 4238582117: Creating the Custom Dataset.]
[2024-12-08 10:26:53,786: INFO: dataset: Vocabulary loaded from artifacts/tokenization/data.]
[2024-12-08 10:26:53,787: INFO: 4238582117: Custom Dataset is created.]
[2024-12-08 10:26:53,788: INFO: 4238582117: Data Loader is getting created.]
[2024-12-08 10:26:53,789: INFO: 4238582117: Data Loader is created.]
[2024-12-08 10:26:53,789: INFO: 4238582117: Saving dataset batches...]


Saving Batches: 100%|██████████| 80/80 [03:33<00:00,  2.67s/batch]


[2024-12-08 10:35:17,697: INFO: 4238582117: DataLoader batches saved to artifacts/custom_dataset/data]
