In [None]:
import os

In [148]:
import mysql.connector
from mysql.connector import Error
from imageCaptioningWithAttention.components.data_processing import MySQLServer
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
from imageCaptioningWithAttention.constants import CONFIG_FILE_PATH, PARAMS_FILE_PATH
from imageCaptioningWithAttention.utils.common import read_yaml, create_directories
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator


In [164]:
# Connection settings are read from the environment so that no credential is
# stored in the notebook (or leaked through version control / shared copies).
# MYSQL_PASSWORD is required and raises KeyError when unset; the other settings
# fall back to the values this notebook used before.
server = MySQLServer(
    os.environ.get('MYSQL_HOST', 'localhost'),
    os.environ.get('MYSQL_USER', 'root'),
    os.environ['MYSQL_PASSWORD'],
    os.environ.get('MYSQL_DATABASE', 'image_captioning'),
)
db_connection = server.db_connection()

<mysql.connector.connection_cext.CMySQLConnection object at 0x000002220E099670>
[2023-07-22 21:31:24,861: INFO: data_processing: Database connection successful with host name localhost and username root.]


In [150]:
class DatasetConfig:
    """Plain holder for the dataset settings.

    Attributes:
        data_path: location of the image files.
        train_split, val_split, test_split: per-split values, stored unchanged.
    """

    def __init__(self, data_path, train_split, val_split, test_split):
        # Values are stored as given; no validation or normalisation happens here.
        self.data_path = data_path
        self.train_split, self.val_split, self.test_split = (
            train_split,
            val_split,
            test_split,
        )

In [151]:
class ConfigurationManager:
    """Loads the config and params YAML files and builds config objects from them."""

    def __init__(self, config_file_path, params_file_path):
        """Read both YAML files and create the directory named by `artifacts_root`.

        Args:
            config_file_path: path of the YAML file holding the configuration.
            params_file_path: path of the YAML file holding the parameters.
        """
        super().__init__()
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)
        # The artifacts root is created as soon as the configuration is loaded.
        artifacts_dirs = [self.config.artifacts_root]
        create_directories(artifacts_dirs, verbose=True)

    def get_dataset_config(self):
        """Return a `DatasetConfig` built from the `dataset` section of the config."""
        dataset_section = self.config.dataset
        return DatasetConfig(
            data_path=dataset_section.data_path,
            train_split=dataset_section.train,
            val_split=dataset_section.val,
            test_split=dataset_section.test,
        )

In [152]:
# Load the YAML configuration once, then pull out the dataset-specific settings.
config = ConfigurationManager(
    config_file_path=CONFIG_FILE_PATH,
    params_file_path=PARAMS_FILE_PATH,
)
dataset_config = config.get_dataset_config()

[2023-07-22 21:29:01,315: INFO: common: YAML file config\config.yaml loaded successfully]
[2023-07-22 21:29:01,319: INFO: common: YAML file params.yaml loaded successfully]
[2023-07-22 21:29:01,321: INFO: common: created directory at: artifacts]


In [200]:
class ImageCaptionDataset(Dataset):
    """Map-style dataset with one sample per (image, caption) pair.

    Every input row holds an image filename and a single string containing that
    image's captions separated by '@'. Rows are expanded so that each caption
    becomes its own sample; captions are tokenized and converted to tensors of
    vocabulary indices once, at construction time.

    Args:
        data: iterable of 3-item rows; the second item is the image filename and
            the third is the '@'-separated caption string (the first is ignored).
        data_path: prefix that is concatenated directly with the image filename,
            so it must already end with a path separator.
        split: one of 'train', 'val' or 'test'.
        transform: optional callable applied to the opened PIL image.
        vocab: optional vocabulary to reuse (for example the one built by the
            training dataset) so that token ids agree across splits. It should be
            able to resolve every token of these captions (e.g. by having a
            default index set). When omitted, a vocabulary is built from this
            dataset's own captions, which is the original behaviour.
    """

    def __init__(self, data, data_path, split, transform=None, vocab=None):
        assert split in {'train', 'val', 'test'}, (
            f"split must be 'train', 'val' or 'test', got {split!r}"
        )
        self.data = data
        self.data_path = data_path
        self.split = split
        self.transform = transform
        self.image_paths = []
        self.captions = []
        # For every sample, the [start, stop) range of samples that belong to the
        # same image. Recorded here so no fixed captions-per-image count is assumed.
        self._caption_spans = []
        for _, path, caption_string in data:
            image_captions = caption_string.split('@')
            start = len(self.captions)
            span = (start, start + len(image_captions))
            for caption in image_captions:
                self.image_paths.append(path)
                self.captions.append(caption)
                self._caption_spans.append(span)

        # Tokenize once and reuse the tokens for both the vocabulary and the tensors.
        tokenized_captions = [self.tokenizer(caption) for caption in self.captions]
        if vocab is None:
            vocab = build_vocab_from_iterator(tokenized_captions)
        self.vocab = vocab
        self.vectorized_captions = [
            torch.tensor([self.vocab[token] for token in tokens], dtype=torch.long)
            for tokens in tokenized_captions
        ]

    def tokenizer(self, text):
        """Split `text` with torchtext's 'basic_english' tokenizer and lowercase each token."""
        return [token.lower() for token in get_tokenizer('basic_english')(text)]

    def _captions_for_same_image(self, index):
        """Return the vectorized captions of every sample sharing the image of `index`."""
        start, stop = self._caption_spans[index]
        return self.vectorized_captions[start:stop]

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            For the 'train' split: (image tensor, caption tensor, caption length).
            For other splits the same triple plus the list of all caption tensors
            that belong to the same image.
        """
        img_filename = self.image_paths[index]
        vectorized_caption = self.vectorized_captions[index]
        with open(self.data_path + img_filename, 'rb') as f:
            # Image.open is lazy, so the conversion must happen while the file is open.
            img = Image.open(f)
            if self.transform is not None:
                img = self.transform(img)
            # A transform pipeline may already end in ToTensor; converting a tensor
            # again would raise, so only convert what is not a tensor yet.
            if not isinstance(img, torch.Tensor):
                img = transforms.ToTensor()(img)
        caption_length = vectorized_caption.shape[0]
        if self.split == 'train':
            return img, vectorized_caption, caption_length
        all_captions = self._captions_for_same_image(index)
        return img, vectorized_caption, caption_length, all_captions

    def __len__(self):
        """Number of (image, caption) samples."""
        return len(self.vectorized_captions)

In [188]:
# Fetch every row of the caption table; later cells slice and unpack these rows
# positionally, so the column order returned by `SELECT *` matters.
data = server.read_query(
    db_connection,
    'SELECT * FROM table_image_caption',
)

In [189]:
# Contiguous, unshuffled split of the rows in query order: the first
# `train_split` share goes to train, the next `val_split` share to val and the
# remainder to test (`test_split` itself is not consulted).
n_rows = len(data)
train_end = int(n_rows * dataset_config.train_split)
val_end = int(n_rows * (dataset_config.train_split + dataset_config.val_split))

train_data = data[:train_end]
val_data = data[train_end:val_end]
test_data = data[val_end:]

In [201]:
# Training samples: one entry per (image, caption) pair.
train_dataset = ImageCaptionDataset(
    data=train_data,
    data_path=dataset_config.data_path,
    split='train',
)

In [202]:
# NOTE(review): this dataset builds its vocabulary from the validation captions
# only, so its token ids are not guaranteed to match the training dataset's.
val_dataset = ImageCaptionDataset(
    data=val_data,
    data_path=dataset_config.data_path,
    split='val',
)

In [203]:
test_dataset = ImageCaptionDataset(
    data=test_data,
    data_path=dataset_config.data_path,
    split='test',
)
# Sample counts (one sample per caption) for train, val and test, in that order.
print(*(len(split_dataset) for split_dataset in (train_dataset, val_dataset, test_dataset)))

28315 10115 2025
