# Data loading and feature extraction

#### Load Karpathy split and organize the COCO data according to it

In [None]:
import json
import pandas as pd
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from tqdm import tqdm
from PIL import Image
import numpy as np
import os


# Karpathy's COCO split file: per-image split assignment plus raw captions.
karpathy_file = '/kaggle/input/karpathy-splits/dataset_coco.json'

# Fail fast with an explicit message instead of a bare open() error.
if not os.path.exists(karpathy_file):
    raise FileNotFoundError(f"Karpathy split not found at: {karpathy_file}")

with open(karpathy_file) as f:
    karpathy_data = json.load(f)


def organize_by_split(karpathy_data):
    """Group Karpathy-split images into 'train', 'val' and 'test' buckets.

    Images tagged 'restval' are folded into 'train'; any other unknown split
    tag is ignored.

    Args:
        karpathy_data: parsed Karpathy JSON with an 'images' list whose entries
            carry 'split', 'cocoid', 'filename' and 'sentences' (each with 'raw').

    Returns:
        dict with keys 'train', 'val', 'test', each a list of
        {'image_id', 'file_name', 'captions'} records.
    """
    buckets = {name: [] for name in ('train', 'val', 'test')}

    for entry in karpathy_data['images']:
        # 'restval' images are extra training data in the Karpathy split.
        target = 'train' if entry['split'] == 'restval' else entry['split']
        if target not in buckets:
            continue

        buckets[target].append({
            'image_id': entry['cocoid'],
            'file_name': entry['filename'],
            'captions': [sentence['raw'] for sentence in entry['sentences']],
        })

    return buckets

splits_data = organize_by_split(karpathy_data)

# One DataFrame per split with columns image_id, file_name, captions.
train_df, val_df, test_df = (
    pd.DataFrame(splits_data[split_name]) for split_name in ('train', 'val', 'test')
)
print(train_df.head())

Load the pretrained feature extractor models

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"device: {device}")

# VGG16 - fc7 features (4096-dim) - matching the paper.
# Dropping the last classifier layer makes the network output fc7 activations.
vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
vgg16.classifier = vgg16.classifier[:-1]
vgg16 = vgg16.to(device).eval()

# ResNet101: keep every child module except the final FC layer.
resnet101 = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
resnet101_trunk = list(resnet101.children())[:-1]
resnet101 = torch.nn.Sequential(*resnet101_trunk).to(device).eval()

## Feature extraction

In [None]:
# image preprocessing: resize, center-crop to 224x224, convert to a tensor and
# normalize per channel with the ImageNet statistics the pretrained backbones use.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


def extract_features(image_path, model):
    """Run one image through ``model`` and return a flat feature vector.

    Args:
        image_path: path of an image file readable by PIL.
        model: feature extractor already moved to ``device`` (presumably in
            eval mode, as the backbone cell above leaves them).

    Returns:
        1-D numpy array of shape (feature_dim,), or None if the image file
        could not be opened or decoded.

    Note:
        The output is always flattened. The VGG16 head returns a 2-D (1, 4096)
        tensor, which the previous ``len(shape) > 2`` check left un-squeezed.
        That produced (1, 4096) vectors that broke batching downstream.
    """
    try:
        # The context manager closes the underlying file handle PIL opens lazily.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')
    except (OSError, ValueError) as err:
        # Unreadable/corrupt files are skipped; other errors (e.g. CUDA) propagate.
        print(f"Warning: could not load {image_path}: {err}")
        return None

    batch = transform(img).unsqueeze(0).to(device)  # batch of a single image

    with torch.no_grad():
        features = model(batch)

    # Batch size is 1, so flattening removes the batch axis and any 1x1 spatial axes.
    return features.reshape(-1).cpu().numpy()


def get_image_path(filename, base_paths):
    """Resolve a COCO 2014 filename to an existing file on disk.

    The folder is inferred from the filename itself: names containing
    'train2014' map to base_paths['train2014'], names containing 'val2014'
    map to base_paths['val2014'].

    Returns:
        "<base>/<filename>" if that file exists, otherwise None (also None
        when the filename names neither folder).
    """
    for folder in ('train2014', 'val2014'):
        if folder in filename:
            candidate = f"{base_paths[folder]}/{filename}"
            return candidate if os.path.exists(candidate) else None
    return None



def extract_and_save_split_features(df, split_name, base_paths, models_dict, save_dir='.'):
    """Extract features for every image of one split with each model, then save them.

    For every model a file ``coco_{split_name}_{model_name}_features.npy`` is
    written to ``save_dir``. Each file holds a dict {image_id: 1-D float array}
    stored with ``np.save``. That is the naming scheme and format ``train()``
    reads back with ``np.load(..., allow_pickle=True).item()``.

    Args:
        df: DataFrame with 'image_id' and 'file_name' columns.
        split_name: split label used for the progress bar and output filenames.
        base_paths: folder mapping passed to ``get_image_path``.
        models_dict: {model_name: feature extractor}.
        save_dir: directory for the .npy outputs (default: current directory).

    Returns:
        dict mapping model_name -> {image_id: feature vector}.
    """
    features_by_model = {model_name: {} for model_name in models_dict}
    missing_images = []
    failed_extractions = 0
    processed = 0

    for row in tqdm(df.itertuples(index=False), total=len(df), desc=f"{split_name}"):
        img_path = get_image_path(row.file_name, base_paths)
        if img_path is None:
            missing_images.append(row.file_name)
            continue

        # extract features with each model
        for model_name, model in models_dict.items():
            features = extract_features(img_path, model)
            if features is None:
                failed_extractions += 1
                continue
            # Store 1-D vectors regardless of the model's output rank.
            features_by_model[model_name][row.image_id] = np.asarray(features).reshape(-1)

        processed += 1

    os.makedirs(save_dir, exist_ok=True)
    for model_name, feats in features_by_model.items():
        out_path = os.path.join(save_dir, f'coco_{split_name}_{model_name}_features.npy')
        np.save(out_path, feats, allow_pickle=True)

    print(f"{split_name}: {processed} images processed, "
          f"{len(missing_images)} missing, {failed_extractions} failed extractions")

    return features_by_model
    

# Image root folders of the COCO 2014 Kaggle dataset.
BASE_PATHS = dict(
    train2014='/kaggle/input/coco2014/train2014/train2014',
    val2014='/kaggle/input/coco2014/val2014/val2014',
)

# Feature extractors applied to every image.
models_dict = dict(vgg16=vgg16, resnet101=resnet101)

# Start feature extraction, one pass per Karpathy split (train, val, test in order).
_split_frames = (('train', train_df), ('val', val_df), ('test', test_df))
train_features, val_features, test_features = (
    extract_and_save_split_features(frame, name, BASE_PATHS, models_dict)
    for name, frame in _split_frames
)

Save the caption metadata

In [None]:
_caption_outputs = [
    (train_df, 'train_captions.pkl', 'train_captions.csv'),
    (val_df, 'val_captions.pkl', 'val_captions.csv'),
    (test_df, 'test_captions.pkl', 'test_captions.csv'),
]

# Pickles keep the list-valued 'captions' column intact; later cells reload these.
for frame, pkl_path, csv_path in _caption_outputs:
    frame.to_pickle(pkl_path)

# CSV copies are for inspection; the captions lists are written as text there.
for frame, pkl_path, csv_path in _caption_outputs:
    frame.to_csv(csv_path, index=False)

# Vocabulary


In [None]:
import pandas as pd
import numpy as np
import pickle
import re
from collections import Counter
from tqdm import tqdm


# vocabulary size  ~9,221 words + special tokens ( matches the paper's approach)

class Vocabulary:
    """Bidirectional word <-> index mapping with reserved special tokens.

    Indices 0-3 belong to <PAD>, <START>, <END> and <UNK>. Regular words are
    numbered from 4 in the order they are added. Calling the instance maps a
    word to its index, falling back to the <UNK> index for unknown words.
    """

    def __init__(self):
        # raw token frequencies; filled in by build_vocabulary
        self.word_counts = Counter()

        # special tokens
        self.PAD_TOKEN = '<PAD>'
        self.START_TOKEN = '<START>'
        self.END_TOKEN = '<END>'
        self.UNK_TOKEN = '<UNK>'

        specials = (self.PAD_TOKEN, self.START_TOKEN, self.END_TOKEN, self.UNK_TOKEN)
        self.word2idx = {token: position for position, token in enumerate(specials)}
        self.idx2word = dict(enumerate(specials))
        self.idx = len(specials)  # next free index

    def add_word(self, word):
        """Register ``word`` under the next free index; no-op if already known."""
        if word in self.word2idx:
            return
        self.word2idx[word] = self.idx
        self.idx2word[self.idx] = word
        self.idx += 1

    def __len__(self):
        return len(self.word2idx)

    def __call__(self, word):
        unk_index = self.word2idx[self.UNK_TOKEN]
        return self.word2idx.get(word, unk_index)


def tokenize_caption(caption):
    """Lower-case a caption and split it into word tokens.

    Every character that is not a word character, whitespace, an apostrophe
    or a hyphen is replaced by a space. So "dog," becomes "dog" while "man's"
    and "run-down" survive as single tokens.
    """
    cleaned = re.sub(r'[^\w\s\'-]', ' ', caption.lower())
    return [piece for piece in cleaned.split() if piece]


def build_vocabulary(train_captions_df, vocab_size=9221, min_word_freq=5):
    """Build a Vocabulary from the training captions.

    Args:
        train_captions_df: DataFrame whose 'captions' column holds a list of
            caption strings per image.
        vocab_size: total vocabulary size including the special tokens; at most
            ``vocab_size - <number of special tokens>`` regular words are kept.
        min_word_freq: words seen fewer times than this are dropped (they map
            to <UNK>).

    Returns:
        Vocabulary whose regular words are ordered by descending frequency
        (ties keep first-seen order); ``word_counts`` holds raw counts of all tokens.
    """
    vocab = Vocabulary()

    # Count tokens directly; the previous version also kept every token in an
    # unused list, holding the whole tokenized corpus in memory.
    for captions in tqdm(train_captions_df['captions'],
                         total=len(train_captions_df),
                         desc="Processing"):
        for caption in captions:
            vocab.word_counts.update(tokenize_caption(caption))

    # filter by minimum frequency
    frequent = Counter({word: count for word, count in vocab.word_counts.items()
                        if count >= min_word_freq})

    # Reserve room for the special tokens already registered; never request a
    # negative number of words.
    num_special = len(vocab)
    top_words = frequent.most_common(max(vocab_size - num_special, 0))

    for word, _count in tqdm(top_words, desc="Adding words"):
        vocab.add_word(word)

    return vocab


def save_vocabulary(vocab, filepath='vocabulary.pkl'):
    """Pickle ``vocab`` to ``filepath``, overwriting any existing file."""
    with open(filepath, mode='wb') as out_file:
        pickle.dump(vocab, out_file)


def load_vocabulary(filepath='vocabulary.pkl'):
    """Load a vocabulary previously written by ``save_vocabulary``.

    Security note: unpickling runs arbitrary code, so only load files this
    project produced itself.
    """
    with open(filepath, mode='rb') as in_file:
        return pickle.load(in_file)




# load training captions
train_df = pd.read_pickle('train_captions.pkl')

vocab = build_vocabulary(
    train_df,
    vocab_size=9221,  # includes the 4 special tokens
    min_word_freq=5,
)

# Save first, so the vocabulary is persisted even if the optional analysis fails.
save_vocabulary(vocab, 'vocabulary.pkl')

# `analyze_vocabulary` is not defined anywhere in the visible notebook. The
# unguarded call raised NameError and skipped save_vocabulary entirely.
if 'analyze_vocabulary' in globals():
    analyze_vocabulary(vocab, train_df)
else:
    print(f"Vocabulary size: {len(vocab)} (including special tokens)")

# Dataset and DataLoaders

Dataset Objects

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import pickle
import random
from typing import Tuple, List


class CaptionDataset(Dataset):
    """
    Image-level captioning dataset over pre-extracted features.

    Each item is one image: its feature vector plus one of its captions encoded
    as <START> w1 ... wk <END> <PAD> ..., with k <= max_length.

    Images are skipped when they have no entry in ``features_dict`` (e.g. files
    that failed extraction) or have no caption. This avoids a KeyError deep
    inside a DataLoader worker.

    Args:
        captions_df: DataFrame with columns ['image_id', 'file_name', 'captions']
        features_dict: Dictionary mapping image_id -> feature vector (numpy array)
        vocabulary: Vocabulary object
        max_length: Maximum number of caption words kept (default: 15, matching
            paper); sequences are padded to max_length + 2 for START and END.
        training: If True, randomly sample one caption per image on every access;
            otherwise always use the first caption.
    """

    def __init__(self,
                 captions_df: pd.DataFrame,
                 features_dict: dict,
                 vocabulary,
                 max_length: int = 15,
                 training: bool = True):

        self.captions_df = captions_df.reset_index(drop=True)
        self.features_dict = features_dict
        self.vocab = vocabulary
        self.max_length = max_length
        self.training = training

        # Positions (into captions_df) of rows that can actually be served.
        # __len__ and __getitem__ both rely on this; it was never set before.
        self.valid_indices = [
            position
            for position, (image_id, captions) in enumerate(
                zip(self.captions_df['image_id'], self.captions_df['captions'])
            )
            if image_id in self.features_dict and len(captions) > 0
        ]

    def __len__(self):
        return len(self.valid_indices)

    def tokenize_caption(self, caption: str) -> List[str]:
        """Lower-case and split a caption, keeping apostrophes and hyphens."""
        import re
        caption = caption.lower()
        caption = re.sub(r'[^\w\s\'-]', ' ', caption)
        tokens = caption.split()
        tokens = [t for t in tokens if t]
        return tokens

    def caption_to_sequence(self, caption: str) -> Tuple[torch.Tensor, int]:
        """
        Convert a caption to word indices wrapped in <START>/<END> and padded
        with <PAD> to max_length + 2.

        Returns:
            (LongTensor of shape (max_length + 2,), length including START/END)
        """
        words = self.tokenize_caption(caption)

        # Keep at most max_length words; START and END are added on top of that.
        if len(words) > self.max_length:
            words = words[:self.max_length]

        # convert to indices and add START + END
        tokens = [self.vocab.word2idx[self.vocab.START_TOKEN]]
        tokens.extend([self.vocab(word) for word in words])
        tokens.append(self.vocab.word2idx[self.vocab.END_TOKEN])

        actual_length = len(tokens)

        # pad to max_length + 2 (for START and END)
        max_seq_len = self.max_length + 2
        if len(tokens) < max_seq_len:
            tokens.extend([self.vocab.word2idx[self.vocab.PAD_TOKEN]] * (max_seq_len - len(tokens)))

        return torch.tensor(tokens, dtype=torch.long), actual_length

    def __getitem__(self, idx):
        """
        returns:
            image_features: float Tensor of shape (feature_dim,) - e.g., (4096,) for VGG16
            caption: Tensor of shape (max_length + 2,) - padded caption sequence
            caption_length: int - actual caption length including START/END
        """
        df_idx = self.valid_indices[idx]
        row = self.captions_df.iloc[df_idx]

        # Get image features
        image_id = row['image_id']
        image_features = self.features_dict[image_id]
        image_features = torch.from_numpy(image_features).float()

        # get caption
        captions = row['captions']
        if self.training:
            caption = random.choice(captions)
        else:
            caption = captions[0]

        # Convert caption to sequence
        caption_seq, caption_length = self.caption_to_sequence(caption)

        return image_features, caption_seq, caption_length


class CaptionDatasetAllCaptions(Dataset):
    """
    Caption-level dataset: one item per (image, caption) pair, also carrying
    every reference caption of that image for evaluation.

    Images without an entry in ``features_dict`` are left out entirely.

    Args:
        captions_df: DataFrame with columns ['image_id', 'file_name', 'captions']
        features_dict: Dictionary mapping image_id -> feature vector
        vocabulary: Vocabulary object
        max_length: Maximum number of caption words kept (START/END excluded)
    """

    def __init__(self,
                 captions_df: pd.DataFrame,
                 features_dict: dict,
                 vocabulary,
                 max_length: int = 15):

        self.captions_df = captions_df.reset_index(drop=True)
        self.features_dict = features_dict
        self.vocab = vocabulary
        self.max_length = max_length

        # Flatten to one record per caption, in row order then caption order.
        self.data = [
            {'image_id': row['image_id'], 'caption': text, 'all_captions': row['captions']}
            for _, row in self.captions_df.iterrows()
            if row['image_id'] in self.features_dict
            for text in row['captions']
        ]

    def tokenize_caption(self, caption: str) -> List[str]:
        """Lower-case and split a caption, keeping apostrophes and hyphens."""
        import re
        pieces = re.sub(r'[^\w\s\'-]', ' ', caption.lower()).split()
        return [piece for piece in pieces if piece]

    def caption_to_sequence(self, caption: str) -> Tuple[torch.Tensor, int]:
        """Encode a caption as START + up to max_length word ids + END, PAD-filled
        to max_length + 2; also return the unpadded length."""
        vocab = self.vocab
        # Slicing is a no-op when the caption is already short enough.
        words = self.tokenize_caption(caption)[:self.max_length]

        ids = [vocab.word2idx[vocab.START_TOKEN]]
        ids += [vocab(word) for word in words]
        ids.append(vocab.word2idx[vocab.END_TOKEN])

        true_len = len(ids)
        shortfall = (self.max_length + 2) - true_len
        if shortfall > 0:
            ids += [vocab.word2idx[vocab.PAD_TOKEN]] * shortfall

        return torch.tensor(ids, dtype=torch.long), true_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        feats = torch.from_numpy(self.features_dict[record['image_id']]).float()
        seq, seq_len = self.caption_to_sequence(record['caption'])
        return feats, seq, seq_len, record['all_captions']

## Helper functions and the main DataLoader entry point

In [None]:
def collate_fn(batch):
    """
    Batch (image_features, caption, length) samples into tensors.

    Args:
        batch: List of (image_features, caption, length) tuples

    Returns:
        images: Tensor of shape (batch_size, feature_dim)
        captions: Tensor of shape (batch_size, max_length + 2)
        lengths: LongTensor of shape (batch_size,)
    """
    feature_list, caption_list, length_list = zip(*batch)
    return (
        torch.stack(feature_list, dim=0),
        torch.stack(caption_list, dim=0),
        torch.tensor(length_list, dtype=torch.long),
    )


def collate_fn_eval(batch):
    """
    Batch evaluation samples that also carry every reference caption.

    Returns:
        images: Tensor of shape (batch_size, feature_dim)
        captions: Tensor of shape (batch_size, max_length + 2)
        lengths: LongTensor of shape (batch_size,)
        all_captions: list (one entry per sample) of reference-caption lists
    """
    feature_list, caption_list, length_list, reference_lists = zip(*batch)
    return (
        torch.stack(feature_list, dim=0),
        torch.stack(caption_list, dim=0),
        torch.tensor(length_list, dtype=torch.long),
        list(reference_lists),
    )


def create_dataloaders(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    test_df: pd.DataFrame,
    train_features: dict,
    val_features: dict,
    test_features: dict,
    vocabulary,
    batch_size: int = 64,
    max_length: int = 15,
    num_workers: int = 4,
    shuffle_train: bool = True
):
    """
    Create train, validation, and test dataloaders.

    Args:
        train_df, val_df, test_df: DataFrames with captions
        train_features, val_features, test_features: Feature dictionaries
        vocabulary: Vocabulary object
        batch_size: Batch size for all three loaders
        max_length: Maximum caption length in words (START/END excluded)
        num_workers: Number of worker processes for data loading
        shuffle_train: Whether to shuffle training data (val/test never shuffle)

    Returns:
        train_loader, val_loader, test_loader
    """

    def _make_dataset(df, features, training):
        # Training samples a random caption per image; eval always uses the first.
        return CaptionDataset(df, features, vocabulary, max_length=max_length, training=training)

    # create datasets
    train_dataset = _make_dataset(train_df, train_features, training=True)
    val_dataset = _make_dataset(val_df, val_features, training=False)
    test_dataset = _make_dataset(test_df, test_features, training=False)

    # Pinned host memory only helps host->GPU copies; skip it on CPU-only machines.
    pin_memory = torch.cuda.is_available()

    def _make_loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
        )

    # create dataloaders
    train_loader = _make_loader(train_dataset, shuffle_train)
    val_loader = _make_loader(val_dataset, False)
    test_loader = _make_loader(test_dataset, False)

    print(f"\nDataLoader Configuration:")
    print(f"  Batch size: {batch_size}")
    print(f"  Num workers: {num_workers}")
    print(f"  Max caption length: {max_length}")

    print(f"\nDataLoader Sizes:")
    print(f"  Train batches: {len(train_loader)}")
    print(f"  Val batches: {len(val_loader)}")
    print(f"  Test batches: {len(test_loader)}")

    return train_loader, val_loader, test_loader


# Decoder-Transformer Implementation

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import math



class ImageCaptioningTransformer(nn.Module):
    """Decoder-only transformer captioner conditioned on one global image feature.

    The projected image feature is placed at sequence position 0 and the
    caption token embeddings follow it. A causal mask lets each position
    attend only to itself and earlier positions. The output at position i is
    projected to vocabulary logits.

    Args:
        vocab_size: size of the output/input vocabulary.
        embed_dim: model width.
        num_heads: attention heads per layer.
        num_layers: number of transformer layers.
        image_feat_dim: dimensionality of the input image feature vector.
        max_len: longest caption sequence (in tokens) the model accepts; the
            positional table has max_len + 1 rows (one extra for the image).
        dropout: dropout probability.
    """

    def __init__(self, vocab_size, embed_dim=512, num_heads=8,
                 num_layers=4, image_feat_dim=4096, max_len=17, dropout=0.1):
        super().__init__()

        self.embed_dim = embed_dim
        self.max_len = max_len
        self.vocab_size = vocab_size

        self.image_embed = nn.Sequential(
            nn.Linear(image_feat_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        self.word_embed = nn.Embedding(vocab_size, embed_dim)
        self.positional_embed = nn.Embedding(max_len + 1, embed_dim)  # +1 for the image

        self.embed_dropout = nn.Dropout(dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=2048,
            dropout=dropout,
            activation='relu',
            batch_first=True,
            norm_first=False
        )

        self.transformer = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_layers
        )

        self.output_project = nn.Linear(embed_dim, vocab_size)

    def generate_causal_mask(self, seq_len, device):
        """
        Float causal mask: -inf strictly above the diagonal, 0 elsewhere.
        """
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device) * float('-inf'),
            diagonal=1
        )
        return mask

    def generate_padding_mask(self, captions, pad_idx=0):
        """Boolean mask, True where a caption token equals ``pad_idx``."""
        return (captions == pad_idx)

    def forward(self, images, captions):
        """
        Args:
            images: (batch, image_feat_dim) image features.
            captions: (batch, seq_len) token ids with seq_len <= max_len.

        Returns:
            logits of shape (batch, seq_len + 1, vocab_size); position 0 is the image slot.
        """
        device = captions.device
        batch_size, seq_len = captions.shape

        img_embed = self.image_embed(images).unsqueeze(1)  # (batch, 1, embed_dim)
        caption_embed = self.word_embed(captions)          # (batch, seq_len, embed_dim)
        sequence = torch.cat([img_embed, caption_embed], dim=1)

        positions = torch.arange(seq_len + 1, device=device).unsqueeze(0)
        sequence = sequence + self.positional_embed(positions)
        sequence = self.embed_dropout(sequence)

        # Use boolean masks (True = blocked) for both arguments: the transformer
        # warns when a float attention mask is mixed with a bool padding mask.
        causal_mask = torch.isinf(self.generate_causal_mask(seq_len + 1, device))

        # The image slot is never padding.
        img_padding_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=device)
        padding_mask = torch.cat([img_padding_mask, self.generate_padding_mask(captions)], dim=1)

        output = self.transformer(
            sequence,
            mask=causal_mask,
            src_key_padding_mask=padding_mask
        )

        return self.output_project(output)

    def generate(self, images, start_token_idx=1, end_token_idx=2):
        """
        Greedily generate captions autoregressively.

        Starts from [START] and repeatedly appends the argmax of the last
        position's logits. It stops once every sequence has emitted END, or
        after max_len steps.

        Returns:
            LongTensor (batch, L) with L <= max_len + 1, beginning with
            start_token_idx. Rows that finish early keep receiving tokens until
            the whole batch finishes, so truncate each row at its first END.

        The module's train/eval mode is restored on exit.
        """
        was_training = self.training
        self.eval()
        batch_size = images.shape[0]
        device = images.device
        generated = torch.full((batch_size, 1), start_token_idx, dtype=torch.long, device=device)

        # track which sequences have finished
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        try:
            with torch.no_grad():
                for _ in range(self.max_len):
                    logits = self.forward(images, generated)

                    next_token = logits[:, -1, :].argmax(dim=1, keepdim=True)
                    generated = torch.cat([generated, next_token], dim=1)

                    finished = finished | (next_token.squeeze(-1) == end_token_idx)

                    # stop if all sequences have finished
                    if finished.all():
                        break
        finally:
            self.train(was_training)

        return generated
    


## Loss function and padding mask

In [1]:
class MaskedCrossEntropyLoss(nn.Module):
    """Token-level cross-entropy averaged over non-padding target positions."""

    def __init__(self, pad_idx=0):
        # nn.Module must be initialized before submodules (self.criterion) are assigned.
        super().__init__()
        self.pad_idx = pad_idx
        self.criterion = nn.CrossEntropyLoss(reduction='none')

    def forward(self, logits, targets, lengths=None):
        """
        Args:
            logits: (batch, seq_len, vocab_size) unnormalized scores.
            targets: (batch, seq_len) token ids; positions equal to pad_idx are ignored.
            lengths: optional per-sample lengths. Accepted because some callers
                pass them, but unused: the pad mask already excludes padding.

        Returns:
            Scalar mean loss over non-pad positions (0 when every target is padding).
        """
        vocab_size = logits.shape[-1]

        logits_flat = logits.reshape(-1, vocab_size)  # (batch * seq_len, vocab_size)
        targets_flat = targets.reshape(-1)            # (batch * seq_len)

        loss = self.criterion(logits_flat, targets_flat)
        mask = (targets_flat != self.pad_idx).to(loss.dtype)

        # clamp avoids 0/0 = nan on an all-padding batch
        return (loss * mask).sum() / mask.sum().clamp(min=1.0)
        

SyntaxError: incomplete input (3557761403.py, line 1)

## Train step and validation

In [None]:

def train_step(model, images, captions, lengths, criterion, optimizer, device):
    """
    Run one teacher-forced optimization step on a single batch.

    The ground-truth captions are fed as input. The model returns one output
    position more than the caption length (the leading image slot). Dropping
    the last position aligns output position i with target captions[:, i].

    Args:
        model: captioning model returning (batch, seq_len + 1, vocab) logits.
        images, captions: batch tensors (moved to ``device`` here).
        lengths: per-sample caption lengths; unused, because the criterion
            masks padding via the target ids.
        criterion: callable (logits, targets) -> scalar loss.
        optimizer: optimizer over model parameters.
        device: device to run on.

    Returns:
        The batch loss as a Python float.
    """
    model.train()
    images = images.to(device)
    captions = captions.to(device)

    logits = model(images, captions)  # ground truth is the input: teacher forcing
    logits = logits[:, :-1, :]

    loss = criterion(logits, captions)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()

    # Gradient clipping (prevent exploding gradients)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

    optimizer.step()

    return loss.item()


def validate(model, val_loader, criterion, device):
    """
    Validation loop (teacher-forced, no gradient).

    Args:
        model: ImageCaptioningTransformer model
        val_loader: iterable of (images, captions, lengths) batches
        criterion: callable (logits, targets) -> scalar loss, e.g. MaskedCrossEntropyLoss
        device: device to run on

    Returns:
        avg_loss: mean of the per-batch validation losses

    Raises:
        ValueError: if val_loader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for images, captions, _lengths in val_loader:
            images = images.to(device)
            captions = captions.to(device)

            # Drop the final position so output i lines up with captions[:, i].
            predictions = model(images, captions)[:, :-1, :]
            loss = criterion(predictions, captions)

            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("val_loader yielded no batches")
    return total_loss / num_batches

# Training

In [None]:
def train_epoch(model, criterion, train_loader, optimizer, config, epoch):
    """Run one training epoch and return the mean per-batch loss.

    Reads ``config.num_epochs`` and ``config.device``. It also reads
    ``config.print_every`` (batches between progress prints), defaulting to
    100 when the config does not define it.

    Raises:
        ValueError: if train_loader yields no batches.
    """
    model.train()
    print_every = getattr(config, 'print_every', 100)

    total_loss = 0.0
    num_batches = 0

    # Iterate the progress bar itself so it actually advances.
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config.num_epochs}')
    for batch_idx, (images, captions, lengths) in enumerate(pbar):
        loss = train_step(model, images, captions, lengths,
                          criterion, optimizer, config.device)

        total_loss += loss
        num_batches += 1
        pbar.set_postfix(loss=f'{loss:.4f}', avg=f'{total_loss / num_batches:.4f}')

        if print_every and (batch_idx + 1) % print_every == 0:
            # tqdm.write keeps the progress bar intact.
            tqdm.write(f'[Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_loader)}] '
                       f'Loss: {loss:.4f}, Avg Loss: {total_loss/num_batches:.4f}')

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return total_loss / num_batches



In [None]:
class Config:
    """
    Training configuration (class attributes; train() only reads attributes).

    Directory fields are joined onto ``base_dir`` and must be relative. An
    absolute second argument makes os.path.join discard base_dir, which sent
    checkpoints to "/checkpoints" at the filesystem root.
    """
    # model
    vocab_size = 9221        # total size, including the 4 special tokens
    embed_dim = 512
    num_heads = 8
    num_layers = 4
    image_feat_dim = 4096    # VGG16 fc7 features
    max_len = 17             # caption tokens incl. <START>/<END> (15 words + 2)
    dropout = 0.1

    # training
    batch_size = 64
    num_epochs = 30
    learning_rate = 5e-5
    weight_decay = 1e-4
    print_every = 100        # batches between progress prints in train_epoch
    base_dir = "/kaggle/working"
    save_dir = "checkpoints"
    best_dir = "best_model"
    save_every = 10          # epochs between periodic checkpoints

    # device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    



def train(config):
    """
    Train the captioning transformer on pre-extracted VGG16 features.

    Loads the caption DataFrames and feature dicts, builds the vocabulary and
    trains for ``config.num_epochs`` epochs (RMSprop with cosine LR annealing).
    Under ``config.base_dir / config.save_dir`` it writes periodic
    checkpoints, ``best_model.pt`` (lowest validation loss) and a pickled loss
    history. ``config.best_dir`` is not used; the best model sits next to the
    periodic checkpoints.

    Args:
        config: Config-like object exposing the attributes read below.

    Returns:
        dict with per-epoch 'train_losses' and 'val_losses'.
    """
    import time  # `time` is not imported by any cell shown above

    train_df = pd.read_pickle('train_captions.pkl')
    val_df = pd.read_pickle('val_captions.pkl')

    # Expected to be dicts saved with np.save (a 0-d object array); .item() unwraps them.
    train_features = np.load('coco_train_vgg16_features.npy', allow_pickle=True).item()
    val_features = np.load('coco_val_vgg16_features.npy', allow_pickle=True).item()

    vocab = build_vocabulary(train_df, config.vocab_size)

    # test_df/test_features are unused during training, so val is passed as a
    # placeholder. Caption word length is tied to config.max_len, which also
    # counts the <START>/<END> tokens.
    train_loader, val_loader, _ = create_dataloaders(
        train_df, val_df, val_df,
        train_features, val_features, val_features,
        vocab,
        batch_size=config.batch_size,
        max_length=config.max_len - 2,
    )

    model = ImageCaptioningTransformer(config.vocab_size,
                                       config.embed_dim,
                                       config.num_heads,
                                       config.num_layers,
                                       config.image_feat_dim,
                                       config.max_len,
                                       config.dropout
                                       ).to(config.device)

    # weight_decay is declared in the config but was previously never applied.
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=config.learning_rate,
                                    weight_decay=config.weight_decay)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.num_epochs
    )

    criterion = MaskedCrossEntropyLoss()
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    # Strip a leading '/' so os.path.join keeps base_dir instead of writing to the filesystem root.
    checkpoint_dir = os.path.join(config.base_dir, config.save_dir.lstrip('/'))
    os.makedirs(checkpoint_dir, exist_ok=True)

    def _save_checkpoint(path, epoch, train_loss, val_loss):
        # Shared payload for periodic and best-model checkpoints.
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'config': config
        }, path)

    for epoch in range(config.num_epochs):
        epoch_start = time.time()

        train_loss = train_epoch(model, criterion, train_loader, optimizer, config, epoch)
        train_losses.append(train_loss)

        val_loss = validate(model, val_loader, criterion, config.device)
        val_losses.append(val_loss)

        scheduler.step()

        epoch_time = time.time() - epoch_start
        print(f'Epoch {epoch + 1}/{config.num_epochs}: train loss {train_loss:.4f}, '
              f'val loss {val_loss:.4f} ({epoch_time:.1f}s)')

        if (epoch + 1) % config.save_every == 0:
            _save_checkpoint(
                os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pt'),
                epoch, train_loss, val_loss,
            )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            _save_checkpoint(
                os.path.join(checkpoint_dir, 'best_model.pt'),
                epoch, train_loss, val_loss,
            )

    history = {
        'train_losses': train_losses,
        'val_losses': val_losses
    }

    history_path = os.path.join(checkpoint_dir, 'training_history.pkl')
    with open(history_path, 'wb') as f:
        pickle.dump(history, f)

    print("Training history saved")
    return history
