In [None]:
from pathlib import Path

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from transformers import GPT2TokenizerFast
from tqdm import tqdm

from model import CaptionModel

class ModelTrainer:
    """
    Trainer class for training the CaptionModel.

    Bundles a CaptionModel with an AdamW optimizer and a OneCycleLR schedule,
    and provides per-epoch training, validation, and a ``fit`` loop that
    checkpoints the weights with the lowest validation loss.
    """
    def __init__(self, model_config, train_config, data_loaders):
        """
        Build the model, tokenizer, optimizer and learning-rate scheduler.

        Args:
            model_config: Configuration forwarded unchanged to ``CaptionModel``.
            train_config: Training settings. This class reads ``device``,
                ``lr`` and ``epochs`` here, and ``model_path`` in ``fit``.
            data_loaders: A ``(train_loader, val_loader)`` pair. Each loader
                must support ``len()`` and yield ``(image, input_ids, labels)``
                batches.
        """
        # Kept so ``fit`` can read the checkpoint directory and epoch budget.
        self.train_config = train_config
        self.device = train_config.device
        self.model = CaptionModel(model_config).to(self.device)

        self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
        # GPT-2's tokenizer has no pad token by default; reuse EOS for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.train_loader, self.val_loader = data_loaders
        self.optimizer = AdamW(self.model.parameters(), lr=train_config.lr, weight_decay=1e-4)
        # The scheduler is stepped once per training batch, so its horizon is
        # the total number of batches over ``train_config.epochs`` epochs.
        # OneCycleLR raises ValueError if stepped past this budget.
        total_steps = len(self.train_loader) * train_config.epochs
        self.scheduler = OneCycleLR(self.optimizer, max_lr=train_config.lr, total_steps=total_steps)

    def _to_device(self, batch):
        """Move an ``(image, input_ids, labels)`` batch onto ``self.device``."""
        image, input_ids, labels = batch
        return image.to(self.device), input_ids.to(self.device), labels.to(self.device)

    def train_epoch(self):
        """
        Run one optimization pass over the training loader.

        The model's forward call returns the loss directly. The optimizer and
        the LR scheduler are both stepped once per batch.

        Returns:
            float: Mean per-batch training loss for the epoch.
        """
        self.model.train()
        total_loss = 0.0

        for batch in tqdm(self.train_loader, desc="Training"):
            image, input_ids, labels = self._to_device(batch)
            # Clear gradients before backward so nothing accumulated outside
            # this loop leaks into the first update.
            self.optimizer.zero_grad()
            loss = self.model(image, input_ids, labels)

            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            total_loss += loss.item()

        return total_loss / len(self.train_loader)

    @torch.no_grad()
    def evaluate(self):
        """
        Compute the mean per-batch loss on the validation loader.

        Runs with autograd disabled and the model in eval mode. The model is
        left in eval mode afterwards; ``train_epoch`` switches it back.

        Returns:
            float: Mean per-batch validation loss.
        """
        self.model.eval()
        total_loss = 0.0
        for batch in tqdm(self.val_loader, desc="Validating"):
            image, input_ids, labels = self._to_device(batch)
            loss = self.model(image, input_ids, labels)
            total_loss += loss.item()
        return total_loss / len(self.val_loader)

    def fit(self, epochs=None):
        """
        Train for ``epochs`` epochs, validating after each one.

        Whenever the validation loss improves, the model's ``state_dict`` is
        saved to ``<train_config.model_path>/best_model.pth``. The directory
        is created if it does not exist.

        Args:
            epochs: Number of epochs to run. Defaults to ``train_config.epochs``.
                Must not exceed ``train_config.epochs``, because the OneCycleLR
                schedule was sized for that many epochs. Note the step budget
                is shared across repeated ``fit`` calls.

        Returns:
            float: The best validation loss seen (``inf`` if no epoch ran).

        Raises:
            ValueError: If ``epochs`` exceeds ``train_config.epochs``.
        """
        scheduled_epochs = self.train_config.epochs
        if epochs is None:
            epochs = scheduled_epochs
        elif epochs > scheduled_epochs:
            # Fail fast rather than crashing inside scheduler.step() after the
            # earlier epochs have already trained.
            raise ValueError(
                f"fit(epochs={epochs}) exceeds the {scheduled_epochs} epochs "
                "the learning-rate schedule was built for"
            )

        # Accept either str or Path for model_path.
        checkpoint_dir = Path(self.train_config.model_path)
        best_path = checkpoint_dir / 'best_model.pth'
        best_loss = float('inf')

        for epoch in range(epochs):
            train_loss = self.train_epoch()
            val_loss = self.evaluate()

            if val_loss < best_loss:
                best_loss = val_loss
                checkpoint_dir.mkdir(parents=True, exist_ok=True)
                torch.save(self.model.state_dict(), best_path)

            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        return best_loss