# Data loading and feature extraction

#### Load Karpathy split and organize the COCO data according to it

In [None]:
import json
import pandas as pd
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from tqdm import tqdm
from PIL import Image
import numpy as np
import os


karpathy_file = '/kaggle/input/karpathy-splits/dataset_coco.json'

if not os.path.exists(karpathy_file):
    raise FileNotFoundError(f"Karpathy split not found at: {karpathy_file}")

with open(karpathy_file, 'r') as f:
    karpathy_data = json.load(f)


def organize_by_split(karpathy_data):
    splits = {'train': [], 'val': [], 'test': []}
    
    for img_data in karpathy_data['images']:
        split = img_data['split']
        
        # handle 'restval' - we add them to the training set 
        if split == 'restval':
            split = 'train'
        
        if split in ['train', 'val', 'test']:
          
            image_info = {
                'image_id': img_data['cocoid'],
                'file_name': img_data['filename'],  
                'captions': [sent['raw'] for sent in img_data['sentences']]
            }
            splits[split].append(image_info)
    
    return splits

splits_data = organize_by_split(karpathy_data)

# convert to DataFrames
train_df = pd.DataFrame(splits_data['train'])
val_df = pd.DataFrame(splits_data['val'])
test_df = pd.DataFrame(splits_data['test'])
print(train_df.head())

Load the pretrained feature extractor models

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"device: {device}")

# VGG16 - fc7 features (4096-dim) - matching the paper

vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
vgg16.classifier = vgg16.classifier[:-1]  # remove last layer to get fc7
vgg16 = vgg16.to(device)
vgg16.eval()

# ResNet101

resnet101 = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
resnet101 = torch.nn.Sequential(*list(resnet101.children())[:-1])  # remove FC layer
resnet101 = resnet101.to(device)
resnet101.eval()

## Feature extraction

In [None]:
# image preprocessing
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
])


def extract_features(image_path, model):
    try:
        img = Image.open(image_path).convert('RGB')
        img = transform(img).unsqueeze(0).to(device)
        
        with torch.no_grad():
            features = model(img)
            if len(features.shape) > 2:
                features = features.squeeze()
        
        return features.cpu().numpy()
    except Exception as e:
        return None


def get_image_path(filename, base_paths):   
    # determine which folder based on filename
    if 'train2014' in filename:
        folder = 'train2014'
    elif 'val2014' in filename:
        folder = 'val2014'
    else:
        return None
    
    img_path = f"{base_paths[folder]}/{filename}"
    
    if os.path.exists(img_path):
        return img_path
    else:
        return None



def extract_and_save_split_features(df, split_name, base_paths, models_dict):
   
    features_by_model = {model_name: {} for model_name in models_dict.keys()}
    missing_images = []
    processed = 0
    
    for idx, row in tqdm(df.iterrows(), total=len(df), desc=f"{split_name}"):
        img_id = row['image_id']
        img_filename = row['file_name']
        
        img_path = get_image_path(img_filename, base_paths)
        
        if img_path is None:
            missing_images.append(img_filename)
            continue
        
        # extract features with each model
        for model_name, model in models_dict.items():
            features = extract_features(img_path, model)
            if features is not None:
                features_by_model[model_name][img_id] = features
        
        processed += 1
    

BASE_PATHS = {
    'train2014': '/kaggle/input/coco2014/train2014/train2014',
    'val2014': '/kaggle/input/coco2014/val2014/val2014'
}


models_dict = {
    'vgg16': vgg16,
    'resnet101': resnet101,
}

#start feature extraction
train_features = extract_and_save_split_features(train_df, 'train', BASE_PATHS, models_dict)
val_features = extract_and_save_split_features(val_df, 'val', BASE_PATHS, models_dict)
test_features = extract_and_save_split_features(test_df, 'test', BASE_PATHS, models_dict)

Save the caption metadata

In [None]:
train_df.to_pickle('train_captions.pkl')
val_df.to_pickle('val_captions.pkl')
test_df.to_pickle('test_captions.pkl')

train_df.to_csv('train_captions.csv', index=False)
val_df.to_csv('val_captions.csv', index=False)
test_df.to_csv('test_captions.csv', index=False)

# Vocabulary


In [5]:
import pandas as pd
import numpy as np
import pickle
import re
from collections import Counter
from tqdm import tqdm


# vocabulary size  ~9,221 words + special tokens ( matches the paper's approach)

class Vocabulary:
    
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.word_counts = Counter()
        
        # special tokens
        self.PAD_TOKEN = '<PAD>'
        self.START_TOKEN = '<START>'
        self.END_TOKEN = '<END>'
        self.UNK_TOKEN = '<UNK>'
        
        # initialize with special tokens
        self.word2idx = {
            self.PAD_TOKEN: 0,
            self.START_TOKEN: 1,
            self.END_TOKEN: 2,
            self.UNK_TOKEN: 3
        }
        self.idx2word = {v: k for k, v in self.word2idx.items()}
        self.idx = 4  # next available index
    
    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1
    
    def __len__(self):
        return len(self.word2idx)
    
    def __call__(self, word):
        return self.word2idx.get(word, self.word2idx[self.UNK_TOKEN])


def tokenize_caption(caption):
    """
    - Convert to lowercase
    - Remove punctuation (except hyphens in words)
    - Split into words

    """
    caption = caption.lower()
    
    # keeps alphanumeric, apostrophes, and hyphens
    caption = re.sub(r'[^\w\s\'-]', ' ', caption)
    
    # split and remove extra whitespace
    tokens = caption.split()
    
    # remove empty strings
    tokens = [t for t in tokens if t]
    
    return tokens


def build_vocabulary(train_captions_df, vocab_size=9221, min_word_freq=5):
    
    vocab = Vocabulary()
    
    # count all words in training captions
    all_tokens = []
    
    for idx, row in tqdm(train_captions_df.iterrows(), 
                         total=len(train_captions_df),
                         desc="Processing"):
        captions = row['captions']

        #process captions
        for caption in captions:
            tokens = tokenize_caption(caption)
            all_tokens.extend(tokens)
            vocab.word_counts.update(tokens)
    

    # filter by minimum frequency
    filtered_words = {word: count for word, count in vocab.word_counts.items() 
                      if count >= min_word_freq}
    

    most_common = sorted(filtered_words.items(), key=lambda x: x[1], reverse=True)
    
    # vocab_size - 4 (to account for special tokens)
    top_words = most_common[:vocab_size - 4]
    
    for word, count in tqdm(top_words, desc="Adding words"):
        vocab.add_word(word)
    
    return vocab


def save_vocabulary(vocab, filepath='vocabulary.pkl'):
    with open(filepath, 'wb') as f:
        pickle.dump(vocab, f)


def load_vocabulary(filepath='vocabulary.pkl'):
    with open(filepath, 'rb') as f:
        vocab = pickle.load(f)

    return vocab




# load training captions
train_df = pd.read_pickle('train_captions.pkl')

vocab = build_vocabulary(
    train_df, 
    vocab_size=9221,  
    min_word_freq=5  
)

save_vocabulary(vocab, 'vocabulary.pkl')

Processing: 100%|██████████| 113287/113287 [00:07<00:00, 15719.59it/s]
Adding words: 100%|██████████| 9217/9217 [00:00<00:00, 1943244.19it/s]


# Dataset and DataLoaders

Dataset Objects

In [6]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import pickle
import random
from typing import Tuple, List


class CaptionDataset(Dataset):
    """
    dataset for image captioning that loads pre-extracted features and captions.
    
    Args:
        captions_df: DataFrame with columns ['image_id', 'file_name', 'captions']
        features_dict: Dictionary mapping image_id -> feature vector 
        vocabulary: Vocabulary object
        max_length: Maximum caption length (default: 15, matching paper)
        training: If True, randomly sample one caption per image per epoch
    """
    
    def __init__(self, 
                 captions_df: pd.DataFrame,
                 features_dict: dict,
                 vocabulary,
                 max_length: int = 15,
                 training: bool = True):
        
        self.captions_df = captions_df.reset_index(drop=True)
        self.features_dict = features_dict
        self.vocab = vocabulary
        self.max_length = max_length
        self.training = training
        self.valid_indices = []
        for idx, row in self.captions_df.iterrows():
            if row['image_id'] in self.features_dict:
                self.valid_indices.append(idx)

        
    
    def __len__(self):
        return len(self.valid_indices)
    
    def tokenize_caption(self, caption: str) -> List[str]:
        
        import re
        caption = caption.lower()
        caption = re.sub(r'[^\w\s\'-]', ' ', caption)
        tokens = caption.split()
        tokens = [t for t in tokens if t]
        return tokens
    
    def caption_to_sequence(self, caption: str) -> Tuple[torch.Tensor, int]:
        """
        Convert caption to sequence of word indices with <START> and <END>. Pad to max length using <PAD>
        """
        
        words = self.tokenize_caption(caption)
        
        # leave room for START and END
        if len(words) > self.max_length:
            words = words[:self.max_length]
        
        # convert to indices and add START + END
        tokens = [self.vocab.word2idx[self.vocab.START_TOKEN]]
        tokens.extend([self.vocab(word) for word in words])
        tokens.append(self.vocab.word2idx[self.vocab.END_TOKEN])
        
        actual_length = len(tokens)
        
        # pad to max_length + 2 (for START and END)
        max_seq_len = self.max_length + 2
        if len(tokens) < max_seq_len:
            tokens.extend([self.vocab.word2idx[self.vocab.PAD_TOKEN]] * (max_seq_len - len(tokens)))
        
        return torch.tensor(tokens, dtype=torch.long), actual_length
    
    def __getitem__(self, idx):
        """
        returns:
            image_features: Tensor of shape (feature_dim,) - e.g., (4096,) for VGG16
            caption: Tensor of shape (max_length + 2,) - padded caption sequence
            caption_length: int - actual caption length including START/END
        """
        df_idx = self.valid_indices[idx]
        row = self.captions_df.iloc[df_idx]
        
        # Get image features
        image_id = row['image_id']
        image_features = self.features_dict[image_id]
        image_features = torch.from_numpy(image_features).float()
        
        # get caption
        captions = row['captions']
        if self.training:
            caption = random.choice(captions)
        else:
            caption = captions[0]
        
        # Convert caption to sequence
        caption_seq, caption_length = self.caption_to_sequence(caption)
        
        return image_features, caption_seq, caption_length


class CaptionDatasetAllCaptions(Dataset):
    """
    dataset that returns all captions for each image.
    
    Args:
        captions_df: DataFrame with columns ['image_id', 'file_name', 'captions']
        features_dict: Dictionary mapping image_id -> feature vector
        vocabulary: Vocabulary object
        max_length: Maximum caption length
    """
    
    def __init__(self, 
                 captions_df: pd.DataFrame,
                 features_dict: dict,
                 vocabulary,
                 max_length: int = 15):
        
        self.captions_df = captions_df.reset_index(drop=True)
        self.features_dict = features_dict
        self.vocab = vocabulary
        self.max_length = max_length
        
        # create expanded dataset with one entry per caption
        self.data = []
        for idx, row in self.captions_df.iterrows():
            if row['image_id'] in self.features_dict:
                for caption in row['captions']:
                    self.data.append({
                        'image_id': row['image_id'],
                        'caption': caption,
                        'all_captions': row['captions']
                    })
        

    
    def tokenize_caption(self, caption: str) -> List[str]:
        import re
        caption = caption.lower()
        caption = re.sub(r'[^\w\s\'-]', ' ', caption)
        tokens = caption.split()
        tokens = [t for t in tokens if t]
        return tokens
    
    def caption_to_sequence(self, caption: str) -> Tuple[torch.Tensor, int]:
        words = self.tokenize_caption(caption)
        if len(words) > self.max_length:
            words = words[:self.max_length]
        
        tokens = [self.vocab.word2idx[self.vocab.START_TOKEN]]
        tokens.extend([self.vocab(word) for word in words])
        tokens.append(self.vocab.word2idx[self.vocab.END_TOKEN])
        
        actual_length = len(tokens)
        max_seq_len = self.max_length + 2
        
        if len(tokens) < max_seq_len:
            tokens.extend([self.vocab.word2idx[self.vocab.PAD_TOKEN]] * (max_seq_len - len(tokens)))
        
        return torch.tensor(tokens, dtype=torch.long), actual_length
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        image_features = self.features_dict[item['image_id']]
        image_features = torch.from_numpy(image_features).float()
        
        # convert caption to sequence
        caption_seq, caption_length = self.caption_to_sequence(item['caption'])
        
        return image_features, caption_seq, caption_length, item['all_captions']

## Helper functions and main dataloader entrypoint

In [7]:
def collate_fn(batch):
    """
    Custom collate function for batching.
    
    Args:
        batch: List of (image_features, caption, length) tuples
    
    Returns:
        images: Tensor of shape (batch_size, feature_dim)
        captions: Tensor of shape (batch_size, max_length + 2)
        lengths: Tensor of shape (batch_size,)
    """
    # Separate the components
    images, captions, lengths = zip(*batch)
    
    # Stack into tensors
    images = torch.stack(images, dim=0)
    captions = torch.stack(captions, dim=0)
    lengths = torch.tensor(lengths, dtype=torch.long)
    
    return images, captions, lengths


def collate_fn_eval(batch):
    """
    Custom collate function for evaluation (includes all reference captions).
    
    Returns:
        images: Tensor of shape (batch_size, feature_dim)
        captions: Tensor of shape (batch_size, max_length + 2)
        lengths: Tensor of shape (batch_size,)
        all_captions: List of lists of reference captions
    """
    images, captions, lengths, all_captions = zip(*batch)
    
    images = torch.stack(images, dim=0)
    captions = torch.stack(captions, dim=0)
    lengths = torch.tensor(lengths, dtype=torch.long)
    
    return images, captions, lengths, list(all_captions)


def create_dataloaders(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    test_df: pd.DataFrame,
    train_features: dict,
    val_features: dict,
    test_features: dict,
    vocabulary,
    batch_size: int = 64,
    max_length: int = 15,
    num_workers: int = 4,
    shuffle_train: bool = True
):
    """
    Create train, validation, and test dataloaders.
    
    Args:
        train_df, val_df, test_df: DataFrames with captions
        train_features, val_features, test_features: Feature dictionaries
        vocabulary: Vocabulary object
        batch_size: Batch size for training
        max_length: Maximum caption length
        num_workers: Number of worker processes for data loading
        shuffle_train: Whether to shuffle training data
    
    Returns:
        train_loader, val_loader, test_loader
    """
  
    # create datasets
    train_dataset = CaptionDataset(
        train_df, 
        train_features, 
        vocabulary, 
        max_length=max_length,
        training=True
    )
    
    val_dataset = CaptionDataset(
        val_df, 
        val_features, 
        vocabulary, 
        max_length=max_length,
        training=False
    )
    
    test_dataset = CaptionDataset(
        test_df, 
        test_features, 
        vocabulary, 
        max_length=max_length,
        training=False
    )
    
    # create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle= shuffle_train,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )
    
    print(f"\nDataLoader Configuration:")
    print(f"  Batch size: {batch_size}")
    print(f"  Num workers: {num_workers}")
    print(f"  Max caption length: {max_length}")
    
    print(f"\nDataLoader Sizes:")
    print(f"  Train batches: {len(train_loader)}")
    print(f"  Val batches: {len(val_loader)}")
    print(f"  Test batches: {len(test_loader)}")
    
    return train_loader, val_loader, test_loader


# Decoder-Transformer Implementation

In [16]:
import torch.nn as nn
import torch.nn.functional as F
import math



class ImageCaptioningTransformer(nn.Module):
 
    def __init__(self, vocab_size, embed_dim=512, num_heads=8, 
                 num_layers=4, image_feat_dim=4096, max_len=17,dropout=0.1):
        super().__init__()
        
        self.embed_dim = embed_dim
        self.max_len = max_len
        self.vocab_size = vocab_size
        
        self.image_embed = nn.Sequential(
            nn.Linear(image_feat_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        self.word_embed = nn.Embedding(vocab_size, embed_dim)
        self.positional_embed = nn.Embedding(max_len+1, embed_dim) # +1 for the image

        self.embed_dropout = nn.Dropout(dropout)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=2048,
            dropout=dropout,
            activation='relu',
            batch_first=True, 
            norm_first=False  
        )

        self.transformer = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers = num_layers
        )
        
        self.output_project= nn.Linear(embed_dim, vocab_size)
        

    def generate_causal_mask(self, seq_len, device):
        """
        Generate causal mask: upper triangular matrix of -inf.
        
        """
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device) * float('-inf'),
            diagonal=1
        )
        return mask
    
    def generate_padding_mask(self,captions,pad_idx=0):
        return (captions == pad_idx)
        
    def forward(self,images,captions):
        """
        forward pass
        """
        if images.dim()==3:
            images = images.squeeze(1)
        device = captions.device
        batch_size, seq_len = captions.shape

        img_embed = self.image_embed(images).unsqueeze(1) # -> (batch, 1,512)

        caption_embed = self.word_embed(captions)

        sequence = torch.cat([img_embed,caption_embed], dim = 1)

        positional_encoding = torch.arange(seq_len + 1,device=device).unsqueeze(0)
        positional_encoding= self.positional_embed(positional_encoding)

        sequence = sequence + positional_encoding
        sequence = self.embed_dropout(sequence)
        
        causal_mask = self.generate_causal_mask(seq_len +1, device)
        padding_mask = self.generate_padding_mask(captions)

        img_padding_mask = torch.zeros(batch_size,1, dtype=bool, device=device)

        padding_mask = torch.cat([img_padding_mask, padding_mask], dim = 1)

        output = self.transformer(
            sequence,
            mask=causal_mask,          
            src_key_padding_mask=padding_mask
        )

        logits = self.output_project(output)
        
        return logits


    def generate(self,images, start_token_idx=1, end_token_idx =2):
        """
        Generate captions autoregressively
        
        Start with [START]
        Loop:
          1. Get predictions for current sequence
          2. Take last prediction
          3. Sample/argmax next token
          4. Append to sequence
          5. Stop if END token or max_len reached
        """
        self.eval()
        batch_size = images.shape[0]
        device = images.device
        generated = torch.full((batch_size,1), start_token_idx,dtype= torch.long, device=device)

        # track what sequences have finished
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        with torch.no_grad():
            for _ in range(self.max_len):
                logits = self.forward(images, generated)

                next_logits = logits[:,-1,:]
                next_token = next_logits.argmax(dim=1,keepdim = True)

                generated = torch.cat([generated, next_token], dim = 1)

                finished = finished | (next_token.squeeze(-1) == end_token_idx)

                # stop if all sequences have finished
                if finished.all():
                    break

        return generated
    


## Loss function and padding mask

In [17]:
class MaskedCrossEntropyLoss(nn.Module):
    def __init__(self,pad_idx=0):
        super().__init__()
        self.pad_idx = pad_idx
        self.criterion = nn.CrossEntropyLoss(reduction='none')

    def forward(self, logits, targets):
        """
        logits (batch,seq_len,vocab_size)
        targets (batch, seq_len, 1)
        """
        batch_size,seq_len, vocab_size = logits.shape
        device = logits.device
        
        logits_flat = logits.reshape(-1,vocab_size) # ( batch * seq_len, vocab_size)
        targets_flat = targets.reshape(-1)  # (batch * seq_len)

        
        loss = self.criterion(logits_flat, targets_flat)

        mask = (targets_flat != self.pad_idx).float()

        loss_masked = loss * mask

        return torch.sum(loss_masked) / torch.sum(mask)
        

## Train step and validation

In [30]:

def train_step(model, images, captions, lengths, criterion, optimizer, device):
    """
    One training step with teacher forcing
    
    Remember:
      - Input to model: captions
      - Model output: logits at each position
      - Predictions: logits[:, :-1, :] (drop last)
      - Targets: captions (original)
    """

    model.train()
    # Move to gpu
    images = images.to(device)
    captions = captions.to(device)
    lengths = lengths.to(device)
    logits = model(images,captions) # we pass ground truth for "teacher forcing"
    
    logits = logits[:,:-1,:]

    loss = criterion(logits, captions)

    
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    
    # Gradient clipping (prevent exploding gradients)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
    
    optimizer.step()
    
    return loss.item()


def validate(model, val_loader, criterion, device):
    """
    Validation loop.
    
    Args:
        model: ImageCaptioningTransformer model
        val_loader: validation DataLoader
        criterion: MaskedCrossEntropyLoss
        device: device to run on
    
    Returns:
        avg_loss: average validation loss
    """
    model.eval()
    total_loss = 0
    num_batches = 0
    
    with torch.no_grad():
        for images, captions, lengths in val_loader:
            images = images.to(device)
            captions = captions.to(device)
            lengths = lengths.to(device)
            
            # forward pass
            logits = model(images, captions)
            
            # compute loss 
            predictions = logits[:, :-1, :]
            targets = captions
            
            loss = criterion(predictions, targets)
            
            total_loss += loss.item()
            num_batches += 1
    
    return total_loss / num_batches

# Training

In [31]:
def train_epoch(model, criterion, train_loader,optimizer,config,epoch):
    model.train()

    total_loss = 0
    num_batches =0

    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config.num_epochs}')
    for batch_idx, (images,captions,lengths) in enumerate(pbar):
       
        loss = train_step(model, images, captions, lengths, 
                         criterion, optimizer, config.device)
        
        total_loss+=loss
        num_batches+=1

        if (batch_idx + 1) % config.print_every == 0:
            print(f'\n[Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_loader)}] '
                  f'Loss: {loss:.4f}, Avg Loss: {total_loss/num_batches:.4f}')
    return total_loss/ num_batches



In [32]:
class Config:
    """
    training configuration
    """
   # model
    vocab_size = 9221
    embed_dim = 512
    num_heads = 8
    num_layers = 4
    image_feat_dim = 4096
    max_len = 17
    dropout = 0.1
    
    # training
    batch_size = 128
    num_epochs = 30
    learning_rate = 5e-5
    weight_decay = 1e-4
    base_dir ="/kaggle/working"
    save_dir="checkpoints"
    save_every= 10 
    print_every=100
    
    
    # device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    


import time
import os
def train(config):
    """
    -. we need to call the train_step function for n epochs for all batches
    - 
    
    """
    train_df = pd.read_pickle('train_captions.pkl')
    val_df = pd.read_pickle('val_captions.pkl')

    train_features = np.load('coco_train_vgg16_features.npy', allow_pickle=True).item()
    val_features = np.load('coco_val_vgg16_features.npy',allow_pickle = True).item()
    
    vocab = build_vocabulary(train_df,config.vocab_size)
    
    train_loader, val_loader, _ = create_dataloaders(# test_features and test_df are not used during training, so i used val_df as a dummy
        train_df,val_df,val_df,train_features,val_features,val_features,vocab,config.batch_size
    )

    model = ImageCaptioningTransformer(config.vocab_size, 
                                       config.embed_dim,
                                       config.num_heads,
                                       config.num_layers, 
                                       config.image_feat_dim,
                                       config.max_len,
                                       config.dropout
                                      ).to(config.device)
    

    optimizer = torch.optim.RMSprop(model.parameters(), 
                                    lr=config.learning_rate)
    
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.num_epochs
    )

    criterion = MaskedCrossEntropyLoss()
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    os.makedirs(os.path.join(config.base_dir, config.save_dir),exist_ok=True)
    
    for epoch in range(config.num_epochs):
        epoch_start= time.time()
        
        train_loss = train_epoch(model, criterion, train_loader,optimizer,config,epoch)
        train_losses.append(train_loss)

        val_loss = validate(model,val_loader,criterion,config.device)
        val_losses.append(val_loss)
        
        scheduler.step()

        epoch_time = time.time() - epoch_start

        if (epoch +1) % config.save_every ==0:
            checkpoint_path = os.path.join(
                config.base_dir,
                config.save_dir,
                f'checkpoint_epoch_{epoch+1}.pt'
            )
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'config': config
            }, checkpoint_path)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint_path = os.path.join(
                config.base_dir,
                config.save_dir,
                f'best_model.pt'
            )
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'config': config
            }, checkpoint_path)
        
    history = {
        'train_losses': train_losses,
        'val_losses': val_losses
    }
    
    history_path = os.path.join(config.base_dir,config.save_dir, 'training_history.pkl')
    with open(history_path, 'wb') as f:
        pickle.dump(history, f)
        
    print("Training history saved")


# Main

In [33]:
config=Config()
train(config)




Processing:   0%|          | 0/113287 [00:00<?, ?it/s][A[A[A


Processing:   1%|          | 1366/113287 [00:00<00:08, 13653.27it/s][A[A[A


Processing:   3%|▎         | 2964/113287 [00:00<00:07, 15017.79it/s][A[A[A


Processing:   4%|▍         | 4539/113287 [00:00<00:07, 15351.84it/s][A[A[A


Processing:   5%|▌         | 6108/113287 [00:00<00:06, 15482.08it/s][A[A[A


Processing:   7%|▋         | 7657/113287 [00:00<00:06, 15130.79it/s][A[A[A


Processing:   8%|▊         | 9175/113287 [00:00<00:06, 15144.87it/s][A[A[A


Processing:   9%|▉         | 10691/113287 [00:00<00:06, 14968.97it/s][A[A[A


Processing:  11%|█         | 12195/113287 [00:00<00:06, 14989.35it/s][A[A[A


Processing:  12%|█▏        | 13695/113287 [00:00<00:06, 14912.93it/s][A[A[A


Processing:  13%|█▎        | 15187/113287 [00:01<00:06, 14914.11it/s][A[A[A


Processing:  15%|█▍        | 16751/113287 [00:01<00:06, 15134.13it/s][A[A[A


Processing:  16%|█▌        | 18265/113287 [00:


DataLoader Configuration:
  Batch size: 64
  Num workers: 4
  Max caption length: 15

DataLoader Sizes:
  Train batches: 1771
  Val batches: 79
  Test batches: 79








[Epoch 1, Batch 50/1771] Loss: 4.7236, Avg Loss: 5.3725

[Epoch 1, Batch 100/1771] Loss: 4.2177, Avg Loss: 4.9437

[Epoch 1, Batch 150/1771] Loss: 4.1572, Avg Loss: 4.7206

[Epoch 1, Batch 200/1771] Loss: 4.0510, Avg Loss: 4.5713

[Epoch 1, Batch 250/1771] Loss: 3.9284, Avg Loss: 4.4548

[Epoch 1, Batch 300/1771] Loss: 3.7850, Avg Loss: 4.3632

[Epoch 1, Batch 350/1771] Loss: 3.8697, Avg Loss: 4.2832

[Epoch 1, Batch 400/1771] Loss: 3.8575, Avg Loss: 4.2137

[Epoch 1, Batch 450/1771] Loss: 3.7071, Avg Loss: 4.1557

[Epoch 1, Batch 500/1771] Loss: 3.6258, Avg Loss: 4.1021

[Epoch 1, Batch 550/1771] Loss: 3.6565, Avg Loss: 4.0521

[Epoch 1, Batch 600/1771] Loss: 3.3626, Avg Loss: 4.0062

[Epoch 1, Batch 650/1771] Loss: 3.3727, Avg Loss: 3.9656

[Epoch 1, Batch 700/1771] Loss: 3.3961, Avg Loss: 3.9268

[Epoch 1, Batch 750/1771] Loss: 3.1984, Avg Loss: 3.8918

[Epoch 1, Batch 800/1771] Loss: 3.2946, Avg Loss: 3.8601

[Epoch 1, Batch 850/1771] Loss: 3.5771, Avg Loss: 3.8293

[Epoch 1, Batc

Epoch 1/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 2/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 2, Batch 50/1771] Loss: 2.9012, Avg Loss: 2.9990

[Epoch 2, Batch 100/1771] Loss: 2.8934, Avg Loss: 2.9876

[Epoch 2, Batch 150/1771] Loss: 3.0050, Avg Loss: 2.9824

[Epoch 2, Batch 200/1771] Loss: 3.1482, Avg Loss: 2.9795

[Epoch 2, Batch 250/1771] Loss: 3.0134, Avg Loss: 2.9813

[Epoch 2, Batch 300/1771] Loss: 2.9707, Avg Loss: 2.9782

[Epoch 2, Batch 350/1771] Loss: 2.8600, Avg Loss: 2.9729

[Epoch 2, Batch 400/1771] Loss: 2.9823, Avg Loss: 2.9661

[Epoch 2, Batch 450/1771] Loss: 3.0770, Avg Loss: 2.9637

[Epoch 2, Batch 500/1771] Loss: 2.7954, Avg Loss: 2.9578

[Epoch 2, Batch 550/1771] Loss: 2.9355, Avg Loss: 2.9525

[Epoch 2, Batch 600/1771] Loss: 3.0156, Avg Loss: 2.9487

[Epoch 2, Batch 650/1771] Loss: 3.0338, Avg Loss: 2.9447

[Epoch 2, Batch 700/1771] Loss: 2.9316, Avg Loss: 2.9394

[Epoch 2, Batch 750/1771] Loss: 2.6117, Avg Loss: 2.9329

[Epoch 2, Batch 800/1771] Loss: 2.9958, Avg Loss: 2.9307

[Epoch 2, Batch 850/1771] Loss: 2.9032, Avg Loss: 2.9293

[Epoch 2, Batc

Epoch 2/30:   0%|          | 0/1771 [00:54<?, ?it/s]



Epoch 3/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 3, Batch 50/1771] Loss: 2.9744, Avg Loss: 2.7697

[Epoch 3, Batch 100/1771] Loss: 2.7629, Avg Loss: 2.7628

[Epoch 3, Batch 150/1771] Loss: 2.7866, Avg Loss: 2.7619

[Epoch 3, Batch 200/1771] Loss: 2.6280, Avg Loss: 2.7650

[Epoch 3, Batch 250/1771] Loss: 2.6957, Avg Loss: 2.7584

[Epoch 3, Batch 300/1771] Loss: 2.7844, Avg Loss: 2.7579

[Epoch 3, Batch 350/1771] Loss: 2.8643, Avg Loss: 2.7514

[Epoch 3, Batch 400/1771] Loss: 2.7218, Avg Loss: 2.7485

[Epoch 3, Batch 450/1771] Loss: 2.9351, Avg Loss: 2.7476

[Epoch 3, Batch 500/1771] Loss: 2.5824, Avg Loss: 2.7451

[Epoch 3, Batch 550/1771] Loss: 2.6333, Avg Loss: 2.7462

[Epoch 3, Batch 600/1771] Loss: 2.8392, Avg Loss: 2.7435

[Epoch 3, Batch 650/1771] Loss: 2.5568, Avg Loss: 2.7414

[Epoch 3, Batch 700/1771] Loss: 2.6538, Avg Loss: 2.7380

[Epoch 3, Batch 750/1771] Loss: 2.7833, Avg Loss: 2.7363

[Epoch 3, Batch 800/1771] Loss: 2.7078, Avg Loss: 2.7340

[Epoch 3, Batch 850/1771] Loss: 2.5995, Avg Loss: 2.7327

[Epoch 3, Batc

Epoch 3/30:   0%|          | 0/1771 [00:54<?, ?it/s]



Epoch 4/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 4, Batch 50/1771] Loss: 2.6359, Avg Loss: 2.6674

[Epoch 4, Batch 100/1771] Loss: 2.7942, Avg Loss: 2.6478

[Epoch 4, Batch 150/1771] Loss: 2.7716, Avg Loss: 2.6407

[Epoch 4, Batch 200/1771] Loss: 2.7730, Avg Loss: 2.6452

[Epoch 4, Batch 250/1771] Loss: 2.5882, Avg Loss: 2.6412

[Epoch 4, Batch 300/1771] Loss: 2.7949, Avg Loss: 2.6409

[Epoch 4, Batch 350/1771] Loss: 2.6415, Avg Loss: 2.6325

[Epoch 4, Batch 400/1771] Loss: 2.4575, Avg Loss: 2.6336

[Epoch 4, Batch 450/1771] Loss: 2.6528, Avg Loss: 2.6322

[Epoch 4, Batch 500/1771] Loss: 2.3388, Avg Loss: 2.6297

[Epoch 4, Batch 550/1771] Loss: 2.3367, Avg Loss: 2.6271

[Epoch 4, Batch 600/1771] Loss: 2.4610, Avg Loss: 2.6243

[Epoch 4, Batch 650/1771] Loss: 2.5608, Avg Loss: 2.6228

[Epoch 4, Batch 700/1771] Loss: 2.6160, Avg Loss: 2.6222

[Epoch 4, Batch 750/1771] Loss: 2.6900, Avg Loss: 2.6222

[Epoch 4, Batch 800/1771] Loss: 2.6640, Avg Loss: 2.6213

[Epoch 4, Batch 850/1771] Loss: 2.7051, Avg Loss: 2.6216

[Epoch 4, Batc

Epoch 4/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 5/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 5, Batch 50/1771] Loss: 2.4859, Avg Loss: 2.5794

[Epoch 5, Batch 100/1771] Loss: 2.6279, Avg Loss: 2.5708

[Epoch 5, Batch 150/1771] Loss: 2.7766, Avg Loss: 2.5562

[Epoch 5, Batch 200/1771] Loss: 2.6087, Avg Loss: 2.5613

[Epoch 5, Batch 250/1771] Loss: 2.6899, Avg Loss: 2.5591

[Epoch 5, Batch 300/1771] Loss: 2.7071, Avg Loss: 2.5572

[Epoch 5, Batch 350/1771] Loss: 2.6221, Avg Loss: 2.5605

[Epoch 5, Batch 400/1771] Loss: 2.5225, Avg Loss: 2.5588

[Epoch 5, Batch 450/1771] Loss: 2.3819, Avg Loss: 2.5574

[Epoch 5, Batch 500/1771] Loss: 2.5070, Avg Loss: 2.5584

[Epoch 5, Batch 550/1771] Loss: 2.4461, Avg Loss: 2.5571

[Epoch 5, Batch 600/1771] Loss: 2.5547, Avg Loss: 2.5560

[Epoch 5, Batch 650/1771] Loss: 2.6599, Avg Loss: 2.5547

[Epoch 5, Batch 700/1771] Loss: 2.5471, Avg Loss: 2.5537

[Epoch 5, Batch 750/1771] Loss: 2.7237, Avg Loss: 2.5522

[Epoch 5, Batch 800/1771] Loss: 2.4671, Avg Loss: 2.5519

[Epoch 5, Batch 850/1771] Loss: 2.6620, Avg Loss: 2.5499

[Epoch 5, Batc

Epoch 5/30:   0%|          | 0/1771 [00:54<?, ?it/s]



Epoch 6/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 6, Batch 50/1771] Loss: 2.4767, Avg Loss: 2.4537

[Epoch 6, Batch 100/1771] Loss: 2.6974, Avg Loss: 2.4822

[Epoch 6, Batch 150/1771] Loss: 2.3154, Avg Loss: 2.4899

[Epoch 6, Batch 200/1771] Loss: 2.5951, Avg Loss: 2.4919

[Epoch 6, Batch 250/1771] Loss: 2.5937, Avg Loss: 2.4925

[Epoch 6, Batch 300/1771] Loss: 2.2983, Avg Loss: 2.4943

[Epoch 6, Batch 350/1771] Loss: 2.5343, Avg Loss: 2.4976

[Epoch 6, Batch 400/1771] Loss: 2.4852, Avg Loss: 2.4994

[Epoch 6, Batch 450/1771] Loss: 2.5924, Avg Loss: 2.4983

[Epoch 6, Batch 500/1771] Loss: 2.4710, Avg Loss: 2.5013

[Epoch 6, Batch 550/1771] Loss: 2.8367, Avg Loss: 2.5027

[Epoch 6, Batch 600/1771] Loss: 2.4244, Avg Loss: 2.5028

[Epoch 6, Batch 650/1771] Loss: 2.5815, Avg Loss: 2.5016

[Epoch 6, Batch 700/1771] Loss: 2.6141, Avg Loss: 2.4990

[Epoch 6, Batch 750/1771] Loss: 2.3119, Avg Loss: 2.4979

[Epoch 6, Batch 800/1771] Loss: 2.5451, Avg Loss: 2.4968

[Epoch 6, Batch 850/1771] Loss: 2.4486, Avg Loss: 2.4955

[Epoch 6, Batc

Epoch 6/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 7/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 7, Batch 50/1771] Loss: 2.4873, Avg Loss: 2.4373

[Epoch 7, Batch 100/1771] Loss: 2.5358, Avg Loss: 2.4440

[Epoch 7, Batch 150/1771] Loss: 2.3549, Avg Loss: 2.4551

[Epoch 7, Batch 200/1771] Loss: 2.6710, Avg Loss: 2.4614

[Epoch 7, Batch 250/1771] Loss: 2.4013, Avg Loss: 2.4669

[Epoch 7, Batch 300/1771] Loss: 2.4142, Avg Loss: 2.4699

[Epoch 7, Batch 350/1771] Loss: 2.2666, Avg Loss: 2.4640

[Epoch 7, Batch 400/1771] Loss: 2.4812, Avg Loss: 2.4622

[Epoch 7, Batch 450/1771] Loss: 2.2281, Avg Loss: 2.4589

[Epoch 7, Batch 500/1771] Loss: 2.3347, Avg Loss: 2.4628

[Epoch 7, Batch 550/1771] Loss: 2.5953, Avg Loss: 2.4611

[Epoch 7, Batch 600/1771] Loss: 2.3766, Avg Loss: 2.4588

[Epoch 7, Batch 650/1771] Loss: 2.4566, Avg Loss: 2.4571

[Epoch 7, Batch 700/1771] Loss: 2.4622, Avg Loss: 2.4568

[Epoch 7, Batch 750/1771] Loss: 2.6176, Avg Loss: 2.4564

[Epoch 7, Batch 800/1771] Loss: 2.5669, Avg Loss: 2.4571

[Epoch 7, Batch 850/1771] Loss: 2.4784, Avg Loss: 2.4579

[Epoch 7, Batc

Epoch 7/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 8/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 8, Batch 50/1771] Loss: 2.4010, Avg Loss: 2.4205

[Epoch 8, Batch 100/1771] Loss: 2.4848, Avg Loss: 2.4295

[Epoch 8, Batch 150/1771] Loss: 2.4269, Avg Loss: 2.4297

[Epoch 8, Batch 200/1771] Loss: 2.5145, Avg Loss: 2.4330

[Epoch 8, Batch 250/1771] Loss: 2.4033, Avg Loss: 2.4329

[Epoch 8, Batch 300/1771] Loss: 2.5372, Avg Loss: 2.4314

[Epoch 8, Batch 350/1771] Loss: 2.4082, Avg Loss: 2.4271

[Epoch 8, Batch 400/1771] Loss: 2.4552, Avg Loss: 2.4244

[Epoch 8, Batch 450/1771] Loss: 2.5642, Avg Loss: 2.4243

[Epoch 8, Batch 500/1771] Loss: 2.4710, Avg Loss: 2.4261

[Epoch 8, Batch 550/1771] Loss: 2.2055, Avg Loss: 2.4227

[Epoch 8, Batch 600/1771] Loss: 2.2503, Avg Loss: 2.4225

[Epoch 8, Batch 650/1771] Loss: 2.4742, Avg Loss: 2.4223

[Epoch 8, Batch 700/1771] Loss: 2.2480, Avg Loss: 2.4219

[Epoch 8, Batch 750/1771] Loss: 2.4745, Avg Loss: 2.4214

[Epoch 8, Batch 800/1771] Loss: 2.4982, Avg Loss: 2.4211

[Epoch 8, Batch 850/1771] Loss: 2.4166, Avg Loss: 2.4203

[Epoch 8, Batc

Epoch 8/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 9/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 9, Batch 50/1771] Loss: 2.3778, Avg Loss: 2.3797

[Epoch 9, Batch 100/1771] Loss: 2.4537, Avg Loss: 2.3858

[Epoch 9, Batch 150/1771] Loss: 2.2882, Avg Loss: 2.3937

[Epoch 9, Batch 200/1771] Loss: 2.3451, Avg Loss: 2.3944

[Epoch 9, Batch 250/1771] Loss: 2.4659, Avg Loss: 2.3884

[Epoch 9, Batch 300/1771] Loss: 2.3705, Avg Loss: 2.3892

[Epoch 9, Batch 350/1771] Loss: 2.2898, Avg Loss: 2.3862

[Epoch 9, Batch 400/1771] Loss: 2.2679, Avg Loss: 2.3879

[Epoch 9, Batch 450/1771] Loss: 2.4147, Avg Loss: 2.3923

[Epoch 9, Batch 500/1771] Loss: 2.4622, Avg Loss: 2.3944

[Epoch 9, Batch 550/1771] Loss: 2.1256, Avg Loss: 2.3937

[Epoch 9, Batch 600/1771] Loss: 2.3332, Avg Loss: 2.3930

[Epoch 9, Batch 650/1771] Loss: 2.4218, Avg Loss: 2.3923

[Epoch 9, Batch 700/1771] Loss: 2.2270, Avg Loss: 2.3912

[Epoch 9, Batch 750/1771] Loss: 2.6186, Avg Loss: 2.3919

[Epoch 9, Batch 800/1771] Loss: 2.4398, Avg Loss: 2.3909

[Epoch 9, Batch 850/1771] Loss: 2.2120, Avg Loss: 2.3903

[Epoch 9, Batc

Epoch 9/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 10/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 10, Batch 50/1771] Loss: 2.4220, Avg Loss: 2.3655

[Epoch 10, Batch 100/1771] Loss: 2.3943, Avg Loss: 2.3606

[Epoch 10, Batch 150/1771] Loss: 2.3991, Avg Loss: 2.3656

[Epoch 10, Batch 200/1771] Loss: 2.4193, Avg Loss: 2.3730

[Epoch 10, Batch 250/1771] Loss: 2.2467, Avg Loss: 2.3704

[Epoch 10, Batch 300/1771] Loss: 2.3159, Avg Loss: 2.3702

[Epoch 10, Batch 350/1771] Loss: 2.1315, Avg Loss: 2.3659

[Epoch 10, Batch 400/1771] Loss: 2.3937, Avg Loss: 2.3692

[Epoch 10, Batch 450/1771] Loss: 2.4436, Avg Loss: 2.3665

[Epoch 10, Batch 500/1771] Loss: 2.3457, Avg Loss: 2.3663

[Epoch 10, Batch 550/1771] Loss: 2.4308, Avg Loss: 2.3668

[Epoch 10, Batch 600/1771] Loss: 2.3735, Avg Loss: 2.3665

[Epoch 10, Batch 650/1771] Loss: 2.6051, Avg Loss: 2.3673

[Epoch 10, Batch 700/1771] Loss: 2.4102, Avg Loss: 2.3667

[Epoch 10, Batch 750/1771] Loss: 2.2874, Avg Loss: 2.3668

[Epoch 10, Batch 800/1771] Loss: 2.4220, Avg Loss: 2.3676

[Epoch 10, Batch 850/1771] Loss: 2.4431, Avg Loss: 2.369

Epoch 10/30:   0%|          | 0/1771 [00:54<?, ?it/s]



Epoch 11/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 11, Batch 50/1771] Loss: 2.2737, Avg Loss: 2.3246

[Epoch 11, Batch 100/1771] Loss: 2.3195, Avg Loss: 2.3385

[Epoch 11, Batch 150/1771] Loss: 2.0831, Avg Loss: 2.3477

[Epoch 11, Batch 200/1771] Loss: 2.5983, Avg Loss: 2.3540

[Epoch 11, Batch 250/1771] Loss: 2.3413, Avg Loss: 2.3528

[Epoch 11, Batch 300/1771] Loss: 2.4179, Avg Loss: 2.3471

[Epoch 11, Batch 350/1771] Loss: 2.4245, Avg Loss: 2.3507

[Epoch 11, Batch 400/1771] Loss: 2.3345, Avg Loss: 2.3504

[Epoch 11, Batch 450/1771] Loss: 2.3221, Avg Loss: 2.3511

[Epoch 11, Batch 500/1771] Loss: 2.3666, Avg Loss: 2.3484

[Epoch 11, Batch 550/1771] Loss: 2.3213, Avg Loss: 2.3490

[Epoch 11, Batch 600/1771] Loss: 2.1217, Avg Loss: 2.3479

[Epoch 11, Batch 650/1771] Loss: 2.5571, Avg Loss: 2.3483

[Epoch 11, Batch 700/1771] Loss: 2.3383, Avg Loss: 2.3465

[Epoch 11, Batch 750/1771] Loss: 2.4732, Avg Loss: 2.3462

[Epoch 11, Batch 800/1771] Loss: 2.4939, Avg Loss: 2.3462

[Epoch 11, Batch 850/1771] Loss: 2.5490, Avg Loss: 2.345

Epoch 11/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 12/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 12, Batch 50/1771] Loss: 2.3454, Avg Loss: 2.3196

[Epoch 12, Batch 100/1771] Loss: 2.0817, Avg Loss: 2.3107

[Epoch 12, Batch 150/1771] Loss: 2.2494, Avg Loss: 2.3109

[Epoch 12, Batch 200/1771] Loss: 2.3517, Avg Loss: 2.3099

[Epoch 12, Batch 250/1771] Loss: 2.5228, Avg Loss: 2.3132

[Epoch 12, Batch 300/1771] Loss: 2.2052, Avg Loss: 2.3157

[Epoch 12, Batch 350/1771] Loss: 2.6132, Avg Loss: 2.3198

[Epoch 12, Batch 400/1771] Loss: 2.2975, Avg Loss: 2.3217

[Epoch 12, Batch 450/1771] Loss: 2.2306, Avg Loss: 2.3232

[Epoch 12, Batch 500/1771] Loss: 2.1525, Avg Loss: 2.3253

[Epoch 12, Batch 550/1771] Loss: 2.5584, Avg Loss: 2.3275

[Epoch 12, Batch 600/1771] Loss: 2.2591, Avg Loss: 2.3263

[Epoch 12, Batch 650/1771] Loss: 2.1194, Avg Loss: 2.3250

[Epoch 12, Batch 700/1771] Loss: 2.2865, Avg Loss: 2.3238

[Epoch 12, Batch 750/1771] Loss: 2.2730, Avg Loss: 2.3247

[Epoch 12, Batch 800/1771] Loss: 2.2039, Avg Loss: 2.3250

[Epoch 12, Batch 850/1771] Loss: 2.4397, Avg Loss: 2.324

Epoch 12/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 13/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 13, Batch 50/1771] Loss: 2.3936, Avg Loss: 2.3348

[Epoch 13, Batch 100/1771] Loss: 2.4195, Avg Loss: 2.3318

[Epoch 13, Batch 150/1771] Loss: 2.2794, Avg Loss: 2.3391

[Epoch 13, Batch 200/1771] Loss: 2.1689, Avg Loss: 2.3336

[Epoch 13, Batch 250/1771] Loss: 2.1893, Avg Loss: 2.3333

[Epoch 13, Batch 300/1771] Loss: 2.2591, Avg Loss: 2.3222

[Epoch 13, Batch 350/1771] Loss: 2.3818, Avg Loss: 2.3195

[Epoch 13, Batch 400/1771] Loss: 2.4291, Avg Loss: 2.3212

[Epoch 13, Batch 450/1771] Loss: 2.4139, Avg Loss: 2.3188

[Epoch 13, Batch 500/1771] Loss: 2.2497, Avg Loss: 2.3186

[Epoch 13, Batch 550/1771] Loss: 2.3262, Avg Loss: 2.3175

[Epoch 13, Batch 600/1771] Loss: 2.3369, Avg Loss: 2.3161

[Epoch 13, Batch 650/1771] Loss: 2.3231, Avg Loss: 2.3158

[Epoch 13, Batch 700/1771] Loss: 2.3296, Avg Loss: 2.3152

[Epoch 13, Batch 750/1771] Loss: 2.2100, Avg Loss: 2.3144

[Epoch 13, Batch 800/1771] Loss: 2.2546, Avg Loss: 2.3126

[Epoch 13, Batch 850/1771] Loss: 2.4575, Avg Loss: 2.311

Epoch 13/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 14/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 14, Batch 50/1771] Loss: 2.3887, Avg Loss: 2.3094

[Epoch 14, Batch 100/1771] Loss: 2.2910, Avg Loss: 2.3027

[Epoch 14, Batch 150/1771] Loss: 2.2943, Avg Loss: 2.2995

[Epoch 14, Batch 200/1771] Loss: 2.1163, Avg Loss: 2.2975

[Epoch 14, Batch 250/1771] Loss: 2.4433, Avg Loss: 2.2997

[Epoch 14, Batch 300/1771] Loss: 2.1489, Avg Loss: 2.3010

[Epoch 14, Batch 350/1771] Loss: 2.3214, Avg Loss: 2.2991

[Epoch 14, Batch 400/1771] Loss: 2.3610, Avg Loss: 2.2998

[Epoch 14, Batch 450/1771] Loss: 2.1790, Avg Loss: 2.3027

[Epoch 14, Batch 500/1771] Loss: 2.2348, Avg Loss: 2.3030

[Epoch 14, Batch 550/1771] Loss: 2.5175, Avg Loss: 2.3014

[Epoch 14, Batch 600/1771] Loss: 2.5167, Avg Loss: 2.3023

[Epoch 14, Batch 650/1771] Loss: 2.1665, Avg Loss: 2.3041

[Epoch 14, Batch 700/1771] Loss: 2.4354, Avg Loss: 2.3016

[Epoch 14, Batch 750/1771] Loss: 2.3206, Avg Loss: 2.3013

[Epoch 14, Batch 800/1771] Loss: 2.3176, Avg Loss: 2.3006

[Epoch 14, Batch 850/1771] Loss: 2.2998, Avg Loss: 2.299

Epoch 14/30:   0%|          | 0/1771 [00:54<?, ?it/s]



Epoch 15/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 15, Batch 50/1771] Loss: 2.4858, Avg Loss: 2.2773

[Epoch 15, Batch 100/1771] Loss: 2.3037, Avg Loss: 2.2772

[Epoch 15, Batch 150/1771] Loss: 2.2506, Avg Loss: 2.2792

[Epoch 15, Batch 200/1771] Loss: 2.2713, Avg Loss: 2.2777

[Epoch 15, Batch 250/1771] Loss: 2.2212, Avg Loss: 2.2860

[Epoch 15, Batch 300/1771] Loss: 2.1943, Avg Loss: 2.2848

[Epoch 15, Batch 350/1771] Loss: 2.1915, Avg Loss: 2.2792

[Epoch 15, Batch 400/1771] Loss: 2.3507, Avg Loss: 2.2789

[Epoch 15, Batch 450/1771] Loss: 2.4260, Avg Loss: 2.2793

[Epoch 15, Batch 500/1771] Loss: 2.2689, Avg Loss: 2.2781

[Epoch 15, Batch 550/1771] Loss: 2.2336, Avg Loss: 2.2800

[Epoch 15, Batch 600/1771] Loss: 2.3788, Avg Loss: 2.2803

[Epoch 15, Batch 650/1771] Loss: 2.1933, Avg Loss: 2.2813

[Epoch 15, Batch 700/1771] Loss: 2.1901, Avg Loss: 2.2813

[Epoch 15, Batch 750/1771] Loss: 2.2751, Avg Loss: 2.2814

[Epoch 15, Batch 800/1771] Loss: 2.1871, Avg Loss: 2.2816

[Epoch 15, Batch 850/1771] Loss: 2.2386, Avg Loss: 2.280

Epoch 15/30:   0%|          | 0/1771 [00:54<?, ?it/s]



Epoch 16/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 16, Batch 50/1771] Loss: 2.2626, Avg Loss: 2.3032

[Epoch 16, Batch 100/1771] Loss: 2.1141, Avg Loss: 2.2761

[Epoch 16, Batch 150/1771] Loss: 2.3541, Avg Loss: 2.2741

[Epoch 16, Batch 200/1771] Loss: 2.4546, Avg Loss: 2.2720

[Epoch 16, Batch 250/1771] Loss: 2.6009, Avg Loss: 2.2764

[Epoch 16, Batch 300/1771] Loss: 2.1928, Avg Loss: 2.2792

[Epoch 16, Batch 350/1771] Loss: 2.0682, Avg Loss: 2.2770

[Epoch 16, Batch 400/1771] Loss: 2.2753, Avg Loss: 2.2756

[Epoch 16, Batch 450/1771] Loss: 2.2690, Avg Loss: 2.2718

[Epoch 16, Batch 500/1771] Loss: 2.3570, Avg Loss: 2.2715

[Epoch 16, Batch 550/1771] Loss: 2.3027, Avg Loss: 2.2694

[Epoch 16, Batch 600/1771] Loss: 2.4519, Avg Loss: 2.2695

[Epoch 16, Batch 650/1771] Loss: 2.2429, Avg Loss: 2.2676

[Epoch 16, Batch 700/1771] Loss: 2.3094, Avg Loss: 2.2691

[Epoch 16, Batch 750/1771] Loss: 2.4335, Avg Loss: 2.2692

[Epoch 16, Batch 800/1771] Loss: 2.3802, Avg Loss: 2.2701

[Epoch 16, Batch 850/1771] Loss: 2.3156, Avg Loss: 2.270

Epoch 16/30:   0%|          | 0/1771 [00:54<?, ?it/s]



Epoch 17/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 17, Batch 50/1771] Loss: 2.3066, Avg Loss: 2.2456

[Epoch 17, Batch 100/1771] Loss: 2.2147, Avg Loss: 2.2438

[Epoch 17, Batch 150/1771] Loss: 2.3120, Avg Loss: 2.2500

[Epoch 17, Batch 200/1771] Loss: 2.0291, Avg Loss: 2.2486

[Epoch 17, Batch 250/1771] Loss: 2.2093, Avg Loss: 2.2496

[Epoch 17, Batch 300/1771] Loss: 2.1697, Avg Loss: 2.2458

[Epoch 17, Batch 350/1771] Loss: 2.3185, Avg Loss: 2.2479

[Epoch 17, Batch 400/1771] Loss: 2.3030, Avg Loss: 2.2519

[Epoch 17, Batch 450/1771] Loss: 2.2970, Avg Loss: 2.2527

[Epoch 17, Batch 500/1771] Loss: 2.3704, Avg Loss: 2.2573

[Epoch 17, Batch 550/1771] Loss: 2.4236, Avg Loss: 2.2597

[Epoch 17, Batch 600/1771] Loss: 2.0919, Avg Loss: 2.2595

[Epoch 17, Batch 650/1771] Loss: 2.2593, Avg Loss: 2.2611

[Epoch 17, Batch 700/1771] Loss: 2.2046, Avg Loss: 2.2604

[Epoch 17, Batch 750/1771] Loss: 2.3928, Avg Loss: 2.2612

[Epoch 17, Batch 800/1771] Loss: 2.3238, Avg Loss: 2.2615

[Epoch 17, Batch 850/1771] Loss: 2.2204, Avg Loss: 2.261

Epoch 17/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 18/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 18, Batch 50/1771] Loss: 2.2987, Avg Loss: 2.2571

[Epoch 18, Batch 100/1771] Loss: 2.2432, Avg Loss: 2.2498

[Epoch 18, Batch 150/1771] Loss: 2.1955, Avg Loss: 2.2530

[Epoch 18, Batch 200/1771] Loss: 1.9804, Avg Loss: 2.2478

[Epoch 18, Batch 250/1771] Loss: 2.1048, Avg Loss: 2.2479

[Epoch 18, Batch 300/1771] Loss: 2.0937, Avg Loss: 2.2508

[Epoch 18, Batch 350/1771] Loss: 2.1957, Avg Loss: 2.2497

[Epoch 18, Batch 400/1771] Loss: 2.3405, Avg Loss: 2.2536

[Epoch 18, Batch 450/1771] Loss: 2.3609, Avg Loss: 2.2494

[Epoch 18, Batch 500/1771] Loss: 2.3201, Avg Loss: 2.2469

[Epoch 18, Batch 550/1771] Loss: 2.4236, Avg Loss: 2.2449

[Epoch 18, Batch 600/1771] Loss: 2.4742, Avg Loss: 2.2460

[Epoch 18, Batch 650/1771] Loss: 2.1872, Avg Loss: 2.2451

[Epoch 18, Batch 700/1771] Loss: 2.2477, Avg Loss: 2.2461

[Epoch 18, Batch 750/1771] Loss: 2.0798, Avg Loss: 2.2452

[Epoch 18, Batch 800/1771] Loss: 2.2444, Avg Loss: 2.2461

[Epoch 18, Batch 850/1771] Loss: 2.1383, Avg Loss: 2.245

Epoch 18/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 19/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 19, Batch 50/1771] Loss: 2.2459, Avg Loss: 2.2557

[Epoch 19, Batch 100/1771] Loss: 2.0568, Avg Loss: 2.2489

[Epoch 19, Batch 150/1771] Loss: 2.1107, Avg Loss: 2.2424

[Epoch 19, Batch 200/1771] Loss: 2.1970, Avg Loss: 2.2446

[Epoch 19, Batch 250/1771] Loss: 2.2022, Avg Loss: 2.2404

[Epoch 19, Batch 300/1771] Loss: 2.1798, Avg Loss: 2.2391

[Epoch 19, Batch 350/1771] Loss: 2.3785, Avg Loss: 2.2377

[Epoch 19, Batch 400/1771] Loss: 2.3732, Avg Loss: 2.2373

[Epoch 19, Batch 450/1771] Loss: 2.3735, Avg Loss: 2.2359

[Epoch 19, Batch 500/1771] Loss: 2.3216, Avg Loss: 2.2359

[Epoch 19, Batch 550/1771] Loss: 2.2121, Avg Loss: 2.2352

[Epoch 19, Batch 600/1771] Loss: 2.3939, Avg Loss: 2.2339

[Epoch 19, Batch 650/1771] Loss: 1.9395, Avg Loss: 2.2331

[Epoch 19, Batch 700/1771] Loss: 2.4352, Avg Loss: 2.2324

[Epoch 19, Batch 750/1771] Loss: 2.1220, Avg Loss: 2.2300

[Epoch 19, Batch 800/1771] Loss: 2.2945, Avg Loss: 2.2315

[Epoch 19, Batch 850/1771] Loss: 2.0436, Avg Loss: 2.232

Epoch 19/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 20/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 20, Batch 50/1771] Loss: 2.4386, Avg Loss: 2.2326

[Epoch 20, Batch 100/1771] Loss: 2.4213, Avg Loss: 2.2394

[Epoch 20, Batch 150/1771] Loss: 2.3698, Avg Loss: 2.2387

[Epoch 20, Batch 200/1771] Loss: 2.3341, Avg Loss: 2.2409

[Epoch 20, Batch 250/1771] Loss: 2.0126, Avg Loss: 2.2430

[Epoch 20, Batch 300/1771] Loss: 2.1382, Avg Loss: 2.2391

[Epoch 20, Batch 350/1771] Loss: 2.2646, Avg Loss: 2.2411

[Epoch 20, Batch 400/1771] Loss: 2.1966, Avg Loss: 2.2364

[Epoch 20, Batch 450/1771] Loss: 2.1864, Avg Loss: 2.2325

[Epoch 20, Batch 500/1771] Loss: 2.4007, Avg Loss: 2.2294

[Epoch 20, Batch 550/1771] Loss: 2.1404, Avg Loss: 2.2308

[Epoch 20, Batch 600/1771] Loss: 2.3288, Avg Loss: 2.2304

[Epoch 20, Batch 650/1771] Loss: 2.3069, Avg Loss: 2.2289

[Epoch 20, Batch 700/1771] Loss: 2.1570, Avg Loss: 2.2285

[Epoch 20, Batch 750/1771] Loss: 2.2777, Avg Loss: 2.2282

[Epoch 20, Batch 800/1771] Loss: 2.1993, Avg Loss: 2.2276

[Epoch 20, Batch 850/1771] Loss: 2.2349, Avg Loss: 2.227

Epoch 20/30:   0%|          | 0/1771 [00:54<?, ?it/s]



Epoch 21/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 21, Batch 50/1771] Loss: 2.1483, Avg Loss: 2.2447

[Epoch 21, Batch 100/1771] Loss: 2.3169, Avg Loss: 2.2350

[Epoch 21, Batch 150/1771] Loss: 2.2553, Avg Loss: 2.2365

[Epoch 21, Batch 200/1771] Loss: 1.9893, Avg Loss: 2.2373

[Epoch 21, Batch 250/1771] Loss: 2.2044, Avg Loss: 2.2341

[Epoch 21, Batch 300/1771] Loss: 2.3283, Avg Loss: 2.2299

[Epoch 21, Batch 350/1771] Loss: 2.1494, Avg Loss: 2.2297

[Epoch 21, Batch 400/1771] Loss: 2.4839, Avg Loss: 2.2343

[Epoch 21, Batch 450/1771] Loss: 2.2500, Avg Loss: 2.2327

[Epoch 21, Batch 500/1771] Loss: 2.0954, Avg Loss: 2.2345

[Epoch 21, Batch 550/1771] Loss: 2.2295, Avg Loss: 2.2334

[Epoch 21, Batch 600/1771] Loss: 1.9832, Avg Loss: 2.2317

[Epoch 21, Batch 650/1771] Loss: 2.1471, Avg Loss: 2.2299

[Epoch 21, Batch 700/1771] Loss: 2.3039, Avg Loss: 2.2283

[Epoch 21, Batch 750/1771] Loss: 2.3130, Avg Loss: 2.2280

[Epoch 21, Batch 800/1771] Loss: 2.4995, Avg Loss: 2.2283

[Epoch 21, Batch 850/1771] Loss: 2.2443, Avg Loss: 2.230

Epoch 21/30:   0%|          | 0/1771 [00:54<?, ?it/s]



Epoch 22/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 22, Batch 50/1771] Loss: 2.3383, Avg Loss: 2.1988

[Epoch 22, Batch 100/1771] Loss: 2.3009, Avg Loss: 2.2068

[Epoch 22, Batch 150/1771] Loss: 2.1983, Avg Loss: 2.2099

[Epoch 22, Batch 200/1771] Loss: 2.2795, Avg Loss: 2.2064

[Epoch 22, Batch 250/1771] Loss: 2.1655, Avg Loss: 2.2126

[Epoch 22, Batch 300/1771] Loss: 2.2374, Avg Loss: 2.2150

[Epoch 22, Batch 350/1771] Loss: 2.2273, Avg Loss: 2.2147

[Epoch 22, Batch 400/1771] Loss: 2.1338, Avg Loss: 2.2128

[Epoch 22, Batch 450/1771] Loss: 2.1415, Avg Loss: 2.2121

[Epoch 22, Batch 500/1771] Loss: 2.1998, Avg Loss: 2.2109

[Epoch 22, Batch 550/1771] Loss: 2.1570, Avg Loss: 2.2106

[Epoch 22, Batch 600/1771] Loss: 2.0281, Avg Loss: 2.2108

[Epoch 22, Batch 650/1771] Loss: 2.2672, Avg Loss: 2.2116

[Epoch 22, Batch 700/1771] Loss: 2.3106, Avg Loss: 2.2125

[Epoch 22, Batch 750/1771] Loss: 2.2275, Avg Loss: 2.2136

[Epoch 22, Batch 800/1771] Loss: 2.2791, Avg Loss: 2.2158

[Epoch 22, Batch 850/1771] Loss: 2.2240, Avg Loss: 2.216

Epoch 22/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 23/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 23, Batch 50/1771] Loss: 2.0700, Avg Loss: 2.1955

[Epoch 23, Batch 100/1771] Loss: 2.3704, Avg Loss: 2.2103

[Epoch 23, Batch 150/1771] Loss: 2.4058, Avg Loss: 2.2064

[Epoch 23, Batch 200/1771] Loss: 2.0665, Avg Loss: 2.2019

[Epoch 23, Batch 250/1771] Loss: 2.2772, Avg Loss: 2.2031

[Epoch 23, Batch 300/1771] Loss: 2.2272, Avg Loss: 2.2095

[Epoch 23, Batch 350/1771] Loss: 2.2672, Avg Loss: 2.2105

[Epoch 23, Batch 400/1771] Loss: 2.2161, Avg Loss: 2.2140

[Epoch 23, Batch 450/1771] Loss: 2.2200, Avg Loss: 2.2130

[Epoch 23, Batch 500/1771] Loss: 2.0227, Avg Loss: 2.2130

[Epoch 23, Batch 550/1771] Loss: 2.3945, Avg Loss: 2.2139

[Epoch 23, Batch 600/1771] Loss: 2.4308, Avg Loss: 2.2138

[Epoch 23, Batch 650/1771] Loss: 2.1818, Avg Loss: 2.2136

[Epoch 23, Batch 700/1771] Loss: 2.2749, Avg Loss: 2.2116

[Epoch 23, Batch 750/1771] Loss: 2.1454, Avg Loss: 2.2126

[Epoch 23, Batch 800/1771] Loss: 2.1269, Avg Loss: 2.2140

[Epoch 23, Batch 850/1771] Loss: 2.1839, Avg Loss: 2.212

Epoch 23/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 24/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 24, Batch 50/1771] Loss: 2.1628, Avg Loss: 2.1947

[Epoch 24, Batch 100/1771] Loss: 2.1978, Avg Loss: 2.1825

[Epoch 24, Batch 150/1771] Loss: 2.3507, Avg Loss: 2.1946

[Epoch 24, Batch 200/1771] Loss: 2.2716, Avg Loss: 2.1984

[Epoch 24, Batch 250/1771] Loss: 2.1988, Avg Loss: 2.2034

[Epoch 24, Batch 300/1771] Loss: 2.2827, Avg Loss: 2.2057

[Epoch 24, Batch 350/1771] Loss: 2.0702, Avg Loss: 2.2042

[Epoch 24, Batch 400/1771] Loss: 2.3365, Avg Loss: 2.2057

[Epoch 24, Batch 450/1771] Loss: 2.1327, Avg Loss: 2.2049

[Epoch 24, Batch 500/1771] Loss: 2.1842, Avg Loss: 2.2061

[Epoch 24, Batch 550/1771] Loss: 2.2300, Avg Loss: 2.2054

[Epoch 24, Batch 600/1771] Loss: 2.0871, Avg Loss: 2.2064

[Epoch 24, Batch 650/1771] Loss: 2.5161, Avg Loss: 2.2046

[Epoch 24, Batch 700/1771] Loss: 1.9792, Avg Loss: 2.2045

[Epoch 24, Batch 750/1771] Loss: 2.3551, Avg Loss: 2.2039

[Epoch 24, Batch 800/1771] Loss: 2.2051, Avg Loss: 2.2053

[Epoch 24, Batch 850/1771] Loss: 2.0663, Avg Loss: 2.205

Epoch 24/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 25/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 25, Batch 50/1771] Loss: 2.1293, Avg Loss: 2.2099

[Epoch 25, Batch 100/1771] Loss: 2.2535, Avg Loss: 2.2036

[Epoch 25, Batch 150/1771] Loss: 2.2617, Avg Loss: 2.2076

[Epoch 25, Batch 200/1771] Loss: 2.2543, Avg Loss: 2.2080

[Epoch 25, Batch 250/1771] Loss: 2.2151, Avg Loss: 2.2079

[Epoch 25, Batch 300/1771] Loss: 2.0802, Avg Loss: 2.2087

[Epoch 25, Batch 350/1771] Loss: 2.0270, Avg Loss: 2.2049

[Epoch 25, Batch 400/1771] Loss: 2.2389, Avg Loss: 2.2084

[Epoch 25, Batch 450/1771] Loss: 2.2996, Avg Loss: 2.2096

[Epoch 25, Batch 500/1771] Loss: 2.1100, Avg Loss: 2.2064

[Epoch 25, Batch 550/1771] Loss: 2.2677, Avg Loss: 2.2067

[Epoch 25, Batch 600/1771] Loss: 2.3651, Avg Loss: 2.2055

[Epoch 25, Batch 650/1771] Loss: 2.2274, Avg Loss: 2.2051

[Epoch 25, Batch 700/1771] Loss: 2.2970, Avg Loss: 2.2044

[Epoch 25, Batch 750/1771] Loss: 2.0629, Avg Loss: 2.2048

[Epoch 25, Batch 800/1771] Loss: 2.2239, Avg Loss: 2.2051

[Epoch 25, Batch 850/1771] Loss: 2.2790, Avg Loss: 2.207

Epoch 25/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 26/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 26, Batch 50/1771] Loss: 2.1508, Avg Loss: 2.1824

[Epoch 26, Batch 100/1771] Loss: 2.2714, Avg Loss: 2.1924

[Epoch 26, Batch 150/1771] Loss: 2.2416, Avg Loss: 2.1915

[Epoch 26, Batch 200/1771] Loss: 2.1424, Avg Loss: 2.1883

[Epoch 26, Batch 250/1771] Loss: 2.4590, Avg Loss: 2.1935

[Epoch 26, Batch 300/1771] Loss: 2.0928, Avg Loss: 2.1889

[Epoch 26, Batch 350/1771] Loss: 2.0725, Avg Loss: 2.1909

[Epoch 26, Batch 400/1771] Loss: 2.3279, Avg Loss: 2.1925

[Epoch 26, Batch 450/1771] Loss: 2.2121, Avg Loss: 2.1933

[Epoch 26, Batch 500/1771] Loss: 2.1343, Avg Loss: 2.1909

[Epoch 26, Batch 550/1771] Loss: 2.0507, Avg Loss: 2.1922

[Epoch 26, Batch 600/1771] Loss: 2.3887, Avg Loss: 2.1939

[Epoch 26, Batch 650/1771] Loss: 2.2588, Avg Loss: 2.1933

[Epoch 26, Batch 700/1771] Loss: 2.0720, Avg Loss: 2.1961

[Epoch 26, Batch 750/1771] Loss: 2.2305, Avg Loss: 2.1962

[Epoch 26, Batch 800/1771] Loss: 2.1699, Avg Loss: 2.1974

[Epoch 26, Batch 850/1771] Loss: 2.0328, Avg Loss: 2.197

Epoch 26/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 27/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 27, Batch 50/1771] Loss: 1.9956, Avg Loss: 2.1988

[Epoch 27, Batch 100/1771] Loss: 2.2490, Avg Loss: 2.1987

[Epoch 27, Batch 150/1771] Loss: 2.2409, Avg Loss: 2.1923

[Epoch 27, Batch 200/1771] Loss: 1.8675, Avg Loss: 2.1887

[Epoch 27, Batch 250/1771] Loss: 2.0841, Avg Loss: 2.1908

[Epoch 27, Batch 300/1771] Loss: 2.2791, Avg Loss: 2.1920

[Epoch 27, Batch 350/1771] Loss: 2.2683, Avg Loss: 2.1930

[Epoch 27, Batch 400/1771] Loss: 2.2535, Avg Loss: 2.1951

[Epoch 27, Batch 450/1771] Loss: 2.2106, Avg Loss: 2.1951

[Epoch 27, Batch 500/1771] Loss: 2.0086, Avg Loss: 2.1954

[Epoch 27, Batch 550/1771] Loss: 2.1660, Avg Loss: 2.1946

[Epoch 27, Batch 600/1771] Loss: 2.2971, Avg Loss: 2.1965

[Epoch 27, Batch 650/1771] Loss: 2.2148, Avg Loss: 2.1948

[Epoch 27, Batch 700/1771] Loss: 2.1681, Avg Loss: 2.1955

[Epoch 27, Batch 750/1771] Loss: 1.9127, Avg Loss: 2.1962

[Epoch 27, Batch 800/1771] Loss: 2.1559, Avg Loss: 2.1961

[Epoch 27, Batch 850/1771] Loss: 2.5184, Avg Loss: 2.196

Epoch 27/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 28/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 28, Batch 50/1771] Loss: 2.1973, Avg Loss: 2.1834

[Epoch 28, Batch 100/1771] Loss: 2.1059, Avg Loss: 2.1975

[Epoch 28, Batch 150/1771] Loss: 2.2427, Avg Loss: 2.1965

[Epoch 28, Batch 200/1771] Loss: 2.1591, Avg Loss: 2.1995

[Epoch 28, Batch 250/1771] Loss: 2.2981, Avg Loss: 2.1952

[Epoch 28, Batch 300/1771] Loss: 2.2859, Avg Loss: 2.1980

[Epoch 28, Batch 350/1771] Loss: 2.1952, Avg Loss: 2.1999

[Epoch 28, Batch 400/1771] Loss: 2.0886, Avg Loss: 2.2037

[Epoch 28, Batch 450/1771] Loss: 2.1251, Avg Loss: 2.2006

[Epoch 28, Batch 500/1771] Loss: 2.4557, Avg Loss: 2.2022

[Epoch 28, Batch 550/1771] Loss: 2.0456, Avg Loss: 2.2012

[Epoch 28, Batch 600/1771] Loss: 2.2082, Avg Loss: 2.2030

[Epoch 28, Batch 650/1771] Loss: 2.3771, Avg Loss: 2.2066

[Epoch 28, Batch 700/1771] Loss: 2.0028, Avg Loss: 2.2078

[Epoch 28, Batch 750/1771] Loss: 2.4223, Avg Loss: 2.2079

[Epoch 28, Batch 800/1771] Loss: 2.2813, Avg Loss: 2.2058

[Epoch 28, Batch 850/1771] Loss: 2.1719, Avg Loss: 2.205

Epoch 28/30:   0%|          | 0/1771 [00:55<?, ?it/s]



Epoch 29/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 29, Batch 50/1771] Loss: 2.0787, Avg Loss: 2.2053

[Epoch 29, Batch 100/1771] Loss: 2.3391, Avg Loss: 2.2078

[Epoch 29, Batch 150/1771] Loss: 2.0476, Avg Loss: 2.2108

[Epoch 29, Batch 200/1771] Loss: 2.0452, Avg Loss: 2.2113

[Epoch 29, Batch 250/1771] Loss: 2.3037, Avg Loss: 2.2120

[Epoch 29, Batch 300/1771] Loss: 2.1267, Avg Loss: 2.2054

[Epoch 29, Batch 350/1771] Loss: 2.1275, Avg Loss: 2.2055

[Epoch 29, Batch 400/1771] Loss: 2.2317, Avg Loss: 2.2059

[Epoch 29, Batch 450/1771] Loss: 2.1878, Avg Loss: 2.2071

[Epoch 29, Batch 500/1771] Loss: 2.0066, Avg Loss: 2.2062

[Epoch 29, Batch 550/1771] Loss: 2.3812, Avg Loss: 2.2079

[Epoch 29, Batch 600/1771] Loss: 2.1753, Avg Loss: 2.2064

[Epoch 29, Batch 650/1771] Loss: 2.4037, Avg Loss: 2.2045

[Epoch 29, Batch 700/1771] Loss: 1.9831, Avg Loss: 2.2047

[Epoch 29, Batch 750/1771] Loss: 2.2034, Avg Loss: 2.2045

[Epoch 29, Batch 800/1771] Loss: 2.3397, Avg Loss: 2.2047

[Epoch 29, Batch 850/1771] Loss: 2.1914, Avg Loss: 2.204

Epoch 29/30:   0%|          | 0/1771 [00:54<?, ?it/s]



Epoch 30/30:   0%|          | 0/1771 [00:00<?, ?it/s][A[A[A


[Epoch 30, Batch 50/1771] Loss: 2.3198, Avg Loss: 2.1936

[Epoch 30, Batch 100/1771] Loss: 2.3063, Avg Loss: 2.2031

[Epoch 30, Batch 150/1771] Loss: 2.2017, Avg Loss: 2.2058

[Epoch 30, Batch 200/1771] Loss: 2.2128, Avg Loss: 2.1980

[Epoch 30, Batch 250/1771] Loss: 2.2696, Avg Loss: 2.1965

[Epoch 30, Batch 300/1771] Loss: 2.0632, Avg Loss: 2.1976

[Epoch 30, Batch 350/1771] Loss: 2.2163, Avg Loss: 2.1983

[Epoch 30, Batch 400/1771] Loss: 2.1462, Avg Loss: 2.1985

[Epoch 30, Batch 450/1771] Loss: 2.2036, Avg Loss: 2.1994

[Epoch 30, Batch 500/1771] Loss: 2.3156, Avg Loss: 2.2000

[Epoch 30, Batch 550/1771] Loss: 2.1348, Avg Loss: 2.1974

[Epoch 30, Batch 600/1771] Loss: 2.0946, Avg Loss: 2.1952

[Epoch 30, Batch 650/1771] Loss: 2.3804, Avg Loss: 2.1963

[Epoch 30, Batch 700/1771] Loss: 2.1737, Avg Loss: 2.1962

[Epoch 30, Batch 750/1771] Loss: 2.3843, Avg Loss: 2.1984

[Epoch 30, Batch 800/1771] Loss: 2.1990, Avg Loss: 2.1983

[Epoch 30, Batch 850/1771] Loss: 2.3002, Avg Loss: 2.199

Epoch 30/30:   0%|          | 0/1771 [00:54<?, ?it/s]


Training history saved


# Testing

In [None]:
def generate_predictions(model, dataloader, config, method='greedy'):
    """
    Generate captions for images using the trained model.
    
    Args:
        model: Trained captioning model with generate() method
        dataloader: DataLoader with images
        config: Configuration dict
        method: Generation method (only 'greedy' for now)
        max_length: Maximum caption length
    
    Returns:
        predictions: List of dict with image_id and generated caption
    """
    vocab = pd.read_pickle("/kaggle/working/vocabulary.pkl")
    
    # Get reverse vocabulary (index -> word)
    idx2word = {idx: word for word, idx in vocab.items()}
    
    model.eval()  # Set model to evaluation mode
    device = next(model.parameters()).device
    
    predictions = []
    
    pbar = tqdm(dataloader, desc=f'Generating predictions using {method} method')
    
    with torch.no_grad():  # No gradient computation needed
        for batch_idx, batch in enumerate(pbar):
            # Handle different batch formats
            if len(batch) == 3:
                images, captions, lengths = batch
                image_ids = None
            elif len(batch) == 2:
                images, image_ids = batch
            else:
                images = batch[0]
                image_ids = None
            
            images = images.to(device)
            
            # Generate predictions using model's generate() method
            generated_indices = model.generate(images, max_length=config.max_length)
            
            # Convert indices to words
            batch_predictions = indices_to_captions(
                generated_indices, 
                idx2word,
                remove_special_tokens=True
            )
            
            # Store predictions
            for i, caption in enumerate(batch_predictions):
                pred_dict = {
                    'caption': caption,
                    'image_id': image_ids[i] if image_ids is not None else batch_idx * len(images) + i
                }
                predictions.append(pred_dict)
    
    return predictions


def indices_to_captions(indices_batch, idx2word, remove_special_tokens=True):
    """
    Convert batch of token indices to text captions.
    
    Args:
        indices_batch: Tensor of shape (batch_size, seq_len)
        idx2word: Dictionary mapping indices to words
        remove_special_tokens: Whether to remove <START>, <END>, <PAD>
    
    Returns:
        captions: List of string captions
    """
    captions = []
    special_tokens = {'<START>', '<END>', '<PAD>', '<UNK>'}
    
    # Convert to numpy if needed
    if torch.is_tensor(indices_batch):
        indices_batch = indices_batch.cpu().numpy()
    
    for indices in indices_batch:
        words = []
        for idx in indices:
            idx = int(idx)
            word = idx2word.get(idx, '<UNK>')
            
            # Stop at <END> token
            if word == '<END>':
                break
            
            # Skip special tokens if requested
            if remove_special_tokens and word in special_tokens:
                continue
            
            words.append(word)
        
        caption = ' '.join(words)
        captions.append(caption)
    
    return captions


        

In [52]:
# os.makedirs("/kaggle/working/test")
print(os.path.join(config.base_dir,"checkpoints"))

checkpoint_path = os.path.join(
                config.base_dir,
                config.save_dir,
                f'best_model.pt'
            )
checkpoint = torch.load(checkpoint_path, map_location=config.device, weights_only= False)
test_model = ImageCaptioningTransformer(config.vocab_size, 
                                       config.embed_dim,
                                       config.num_heads,
                                       config.num_layers, 
                                       config.image_feat_dim,
                                       config.max_len,
                                       config.dropout).to(config.device)
test_model.load_state_dict(checkpoint['model_state_dict'])

/kaggle/working/checkpoints


<All keys matched successfully>

In [61]:
train_df = pd.read_pickle('train_captions.pkl')
val_df = pd.read_pickle('val_captions.pkl')

train_features = np.load('coco_train_vgg16_features.npy', allow_pickle=True).item()
val_features = np.load('coco_val_vgg16_features.npy',allow_pickle = True).item()

vocab = build_vocabulary(train_df,config.vocab_size)

train_loader, val_loader, _ = create_dataloaders(# test_features and test_df are not used during training, so i used val_df as a dummy
    train_df,val_df,val_df,train_features,val_features,val_features,vocab,config.batch_size,shuffle_train=False
)
batch = next(iter(train_loader))





Processing:   0%|          | 0/113287 [00:00<?, ?it/s][A[A[A


Processing:   1%|          | 1343/113287 [00:00<00:08, 13426.08it/s][A[A[A


Processing:   3%|▎         | 2872/113287 [00:00<00:07, 14519.61it/s][A[A[A


Processing:   4%|▍         | 4382/113287 [00:00<00:07, 14780.18it/s][A[A[A


Processing:   5%|▌         | 5866/113287 [00:00<00:07, 14802.68it/s][A[A[A


Processing:   7%|▋         | 7366/113287 [00:00<00:07, 14870.68it/s][A[A[A


Processing:   8%|▊         | 8894/113287 [00:00<00:06, 15006.63it/s][A[A[A


Processing:   9%|▉         | 10415/113287 [00:00<00:06, 15072.46it/s][A[A[A


Processing:  11%|█         | 11923/113287 [00:00<00:06, 15053.93it/s][A[A[A


Processing:  12%|█▏        | 13434/113287 [00:00<00:06, 15069.46it/s][A[A[A


Processing:  13%|█▎        | 14953/113287 [00:01<00:06, 15104.65it/s][A[A[A


Processing:  15%|█▍        | 16465/113287 [00:01<00:06, 15109.08it/s][A[A[A


Processing:  16%|█▌        | 17976/113287 [00:


DataLoader Configuration:
  Batch size: 64
  Num workers: 4
  Max caption length: 15

DataLoader Sizes:
  Train batches: 1771
  Val batches: 79
  Test batches: 79


In [64]:
print(batch)
print(train_df.head())

[tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.7481, 0.0000]],

        [[0.0000, 1.2407, 0.0000,  ..., 0.0000, 1.4061, 0.0000]],

        ...,

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 2.2256]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.4438, 0.0000]],

        [[0.1779, 0.0000, 0.0000,  ..., 0.0000, 1.4963, 1.8912]]]), tensor([[   1,    4,   23,  ...,    0,    0,    0],
        [   1,    4,   49,  ...,    0,    0,    0],
        [   1,    4,   34,  ...,    0,    0,    0],
        ...,
        [   1,    4,  424,  ...,    0,    0,    0],
        [   1,    4,  145,  ...,    0,    0,    0],
        [   1,    4,   12,  ...,  562, 1074,    2]]), tensor([13, 11, 11, 14, 16, 11, 12, 13, 14, 14, 15, 12, 13, 10, 11, 12, 16, 11,
        10, 10, 12, 12, 13, 17, 15, 13, 10, 12, 12, 11, 17, 16, 13, 13, 15, 11,
        16, 10, 10, 12, 10, 13, 12, 11, 12, 11, 10, 14, 10, 13, 10, 12, 13, 12,
        13, 15,