In [1]:
import os


In [2]:
%pwd

'c:\\Users\\rahul\\Desktop\\Project\\CaptionAI\\research'

In [3]:
os.chdir("../")

In [4]:
%pwd

'c:\\Users\\rahul\\Desktop\\Project\\CaptionAI'

Entity

In [5]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen = True)
class ModelTrainerConfig:
    root_dir: Path
    data_dir: Path
    token_dir: Path
    emb_size: int
    attn_size: int
    enc_hidden_size: int
    dec_hidden_size: int
    learning_rate: float

Configuration

In [6]:
from CaptionAI.constants import *
from CaptionAI.utils.common import read_yaml, create_directories

In [7]:
class ConfigurationManager:
    def __init__(self,
                 config_file_path = CONFIG_FILE_PATH,
                 params_file_path = PARAMS_FILE_PATH):
        
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)

        create_directories([self.config.artifacts_root])

    def get_model_trainer_config(self):
        config = self.config.model_trainer
        params = self.params.model_params

        create_directories([config.root_dir])

        modeltrainer_config = ModelTrainerConfig(
            root_dir = config.root_dir,
            data_dir = config.data_dir,
            token_dir = config.token_dir,
            emb_size = params.emb_size,
            attn_size = params.attn_size,
            enc_hidden_size = params.enc_hidden_size,
            dec_hidden_size = params.dec_hidden_size,
            learning_rate = params.learning_rate
        )

        return modeltrainer_config

Components

In [12]:
from CaptionAI.utils.model import EncoderCNN, DecoderRNN, Attention
from CaptionAI.utils.common import get_device
from CaptionAI import logger
import torch
import torch.optim as optim
import torch.nn as nn
import pickle
from tqdm import tqdm

In [13]:
class Img2Caption(nn.Module):
    def __init__(self,
                 emb_size,
                 vocab_size,
                 attn_size,
                 enc_hidden_size,
                 dec_hidden_size,
                 drop_prob = 0.3):
        super(Img2Caption, self).__init__()

        self.encoder = EncoderCNN()

        self.decoder = DecoderRNN(
            embd_size = emb_size,
            vocab_size = vocab_size,
            attn_size = attn_size,
            enc_hidden_state = enc_hidden_size,
            dec_hidden_state = dec_hidden_size
        )

    def forward(self, images, captions):
        features = self.encoder(images)
        output = self.decoder(features, captions)
        return output

In [None]:
class ModelTrain:
    def __init__(self, config: ModelTrainerConfig):
        self.config = config
        self.data_loader = None
        self._get_vocab()
        self._read_data_loader()
        self._init_model()

    def _get_vocab(self):
        with open(self.config.token_dir, "rb") as f:
            self.vocab = pickle.load(f)

    def _init_model(self):

        self.device = get_device()

        self.model = Img2Caption(
            emb_size = self.config.emb_size,
            vocab_size = len(self.vocab),
            attn_size = self.config.attn_size,
            enc_hidden_size = self.config.enc_hidden_size,
            dec_hidden_size = self.config.dec_hidden_size
        ).to(self.device)

        self.criterion = nn.CrossEntropyLoss(ignore_index = self.vocab["<pad>"])
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr = self.config.learning_rate
        )
        logger.info("Model is Initiated.")

    def save_model(self, model, num_epochs):

        model_state = {
            "num_epcohs": num_epochs,
            "emb_size": self.config.emb_size,
            "vocab_size": len(self.vocab),
            "attn_size": self.config.attn_size,
            "enc_hidden_size": self.config.enc_hidden_size,
            "dec_hidden_size": self.config.dec_hidden_size,
            "state_dict": model.state_dict()
        }
        torch.save(model_state, f"{self.config.root_dir}/attention_model_state.pth")

    def _read_data_loader(self):
        with open(self.config.data_dir, "rb") as f:
            self.data_loader = pickle.load(f)

    def train_model(self, num_epochs: int = 3, print_every: int = 100):

        logger.info("Model training has been started.")
        for epoch in range(num_epochs):
            with tqdm(enumerate(self.data_loader), total = len(self.data_loader), desc = f"Epoch {epoch + 1}/{num_epochs}") as pbar:
                for idx, (image, captions) in pbar:
                    image, captions = image.to(self.device), captions.to(self.device)
                    self.model.train()
                    self.optimizer.zero_grad()

                    outputs, attentions = self.model(image, captions)

                    targets = captions[:, 1:]

                    loss = self.criterion(outputs.view(-1, len(self.vocab)), targets.reshape(-1))

                    loss.backward()
                    self.optimizer.step()
                    if (idx + 1) % print_every == 0:
                        logger.info(f"Epoch [{epoch}/{num_epochs}], Step [{idx + 1}/{len(self.data_loader)}], Loss: {loss.item()}")

                        self.model.eval()
                        with torch.no_grad():
                            img, _ = next(iter(self.data_loader))
                            features = self.model.encoder(img[0:1].to(self.device))
                            caps, attn_weights = self.model.decoder.generate_caption(features, self.vocab)

                            caption = " ".join(caps)
                            print(caption)

                        self.model.train()

                self.save_model(self.model, epoch)

Pipeline

In [15]:
try:
    config = ConfigurationManager()
    model_trainer_config = config.get_model_trainer_config()
    model_train = ModelTrain(model_trainer_config)
    model_train.train_model()
except Exception as e:
    raise e

[2024-12-10 15:00:58,558: INFO: common: yaml file: config\config.yaml loaded successfully]
[2024-12-10 15:00:58,559: INFO: common: yaml file: params.yaml loaded successfully]
[2024-12-10 15:00:58,560: INFO: common: created directory at: artifacts]
[2024-12-10 15:00:58,560: INFO: common: created directory at: artifacts/model_trainer]
[2024-12-10 15:01:31,328: INFO: 1902623072: Model is Initiated.]
[2024-12-10 15:01:31,344: INFO: 1902623072: Model training has been started.]


Epoch 1/3: 100%|██████████| 40/40 [11:42<00:00, 17.57s/it]
Epoch 2/3: 100%|██████████| 40/40 [11:38<00:00, 17.47s/it]
Epoch 3/3: 100%|██████████| 40/40 [11:38<00:00, 17.47s/it]
