In [1]:
# Auto-reload imported modules before each cell runs, so edits to the local
# project files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import time
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from transformers import get_scheduler
from custom_tokenizers import Tokenizer
from configs.config import DataConfig, EncoderConfig, DecoderConfig, PaliGemmaConfig
from decoder_layers import KVCache, PaliGemmaForConditionalGeneration
from dataloaders import CustomDataLoader
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')



In [3]:
def train(
    model,
    dataloader,
    optimizer,
    lr_scheduler,
    device,
    epoch,
    grad_accumulation_steps=1,
    max_grad_norm=1.0,
    use_amp=False,
):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module called with ``input_ids``, ``pixel_values`` and
            ``attention_mask`` keyword arguments; its output is indexed with
            ``"logits"``.
        dataloader: Sized iterable yielding ``(input_ids, pixel_values,
            attention_mask)`` tuples of tensors.
        optimizer: Optimizer that updates ``model.parameters()``.
        lr_scheduler: Scheduler stepped once per optimizer update.
        device: Device the batch tensors are moved to before the forward pass.
        epoch: Epoch number; used only in the progress-bar label.
        grad_accumulation_steps: Number of batches whose gradients are
            accumulated before each optimizer update.
        max_grad_norm: Maximum gradient norm passed to ``clip_grad_norm_``.
        use_amp: Enables ``torch.cuda.amp`` autocast and gradient scaling.

    Returns:
        The sum of the (undivided) batch losses divided by ``len(dataloader)``.
    """
    model.train()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    def _apply_update():
        # Unscale first so clipping sees the true gradient magnitudes.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        lr_scheduler.step()

    total_loss = 0.0
    has_pending_grads = False

    # Start from clean gradients so nothing left over from an earlier call
    # is folded into this epoch's first update.
    optimizer.zero_grad()

    for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}")):
        input_ids, pixel_values, attention_mask = batch
        # Inputs must live on the same device as the model.
        input_ids = input_ids.to(device)
        pixel_values = pixel_values.to(device)
        attention_mask = attention_mask.to(device)

        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask
            )

            logits = outputs["logits"]
            # NOTE(review): targets are the unshifted input_ids and id 0 is
            # ignored -- presumably 0 is the padding id and next-token
            # alignment is handled elsewhere; confirm against the
            # model/dataloader.
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                input_ids.view(-1),
                ignore_index=0
            )

        # Divide so the accumulated gradient is the mean over the window
        # rather than the sum (a no-op when grad_accumulation_steps == 1).
        scaler.scale(loss / grad_accumulation_steps).backward()
        has_pending_grads = True

        if (step + 1) % grad_accumulation_steps == 0:
            _apply_update()
            has_pending_grads = False

        total_loss += loss.item()

    # Apply a trailing partial accumulation window instead of silently
    # carrying its gradients into the next epoch.
    if has_pending_grads:
        _apply_update()

    return total_loss / len(dataloader)

@torch.no_grad()
def validate(model, dataloader, device, use_amp=False):
    """Evaluate the model on ``dataloader`` and return the mean per-batch loss.

    Runs with gradients disabled and the model in eval mode. The loss is the
    same cross-entropy used in training: logits against ``input_ids`` with
    target id 0 ignored.

    Args:
        model: Module called with ``input_ids``, ``pixel_values`` and
            ``attention_mask`` keyword arguments; its output is indexed with
            ``"logits"``.
        dataloader: Sized iterable yielding ``(input_ids, pixel_values,
            attention_mask)`` tuples of tensors.
        device: Device the batch tensors are moved to before the forward pass.
        use_amp: Enables ``torch.cuda.amp`` autocast for the forward pass.

    Returns:
        The sum of the batch losses divided by ``len(dataloader)``.
    """
    model.eval()
    total_loss = 0.0

    # No GradScaler here: nothing is back-propagated during validation, so
    # only autocast is relevant for mixed precision.
    for batch in tqdm(dataloader, desc="Validating"):
        input_ids, pixel_values, attention_mask = batch
        input_ids = input_ids.to(device)
        pixel_values = pixel_values.to(device)
        attention_mask = attention_mask.to(device)

        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask
            )
            logits = outputs["logits"]
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                input_ids.view(-1),
                ignore_index=0
            )
        total_loss += loss.item()

    return total_loss / len(dataloader)

In [4]:
def main(
    epochs=2,
    batch_size=4,
    learning_rate=3e-5,
    grad_accumulation_steps=1,
    use_amp=True,
    patience=3,
):
    """Train the PaliGemma model, checkpoint the best epoch and plot losses.

    Builds the configs, model, tokenizer and train/val loaders, trains for up
    to ``epochs`` epochs with early stopping on validation loss, writes the
    best weights to ``checkpoints/experiment.pt`` and the loss curves to
    ``plot/experiment.png``, then prints the elapsed wall-clock time.

    Args:
        epochs: Maximum number of training epochs.
        batch_size: Batch size for both the train and validation loaders.
        learning_rate: AdamW learning rate.
        grad_accumulation_steps: Batches accumulated per optimizer update.
        use_amp: Request mixed precision; only honoured when running on CUDA.
        patience: Consecutive non-improving epochs before early stopping.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using device:", device)

    # The training/validation loops use torch.cuda.amp, which only applies on
    # CUDA; turn the flag off explicitly on CPU instead of relying on
    # PyTorch to disable it.
    use_amp = use_amp and device == "cuda"

    # === Load Config ===
    text_config = DecoderConfig()
    vision_config = EncoderConfig()
    # Grow the vocabulary by one and use the new last id as the image token.
    text_config.vocab_size += 1
    image_token_index = text_config.vocab_size - 1
    config = PaliGemmaConfig(
        text_config=text_config,
        vision_config=vision_config,
        image_token_index=image_token_index,
    )

    # === Load Model ===
    model = PaliGemmaForConditionalGeneration(config).to(device)

    # Tokenizer
    tokenizer = Tokenizer(DataConfig())

    # === Dataloader ===
    train_loader = CustomDataLoader(
        split="train",
        batch_size=batch_size,
        num_workers=1,
        tokenizer=tokenizer,
        shuffle=True
    )
    val_loader = CustomDataLoader(
        split="val",
        batch_size=batch_size,
        num_workers=1,
        tokenizer=tokenizer,
        shuffle=False
    )

    # === Optimizer and Scheduler ===
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    num_training_steps = epochs * len(train_loader) // grad_accumulation_steps
    lr_scheduler = get_scheduler(
        "linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )

    best_val_loss = float("inf")
    patience_counter = 0

    # === Training Loop ===
    start_time = time.time()
    train_losses = []
    val_losses = []
    for epoch in range(1, epochs + 1):
        avg_train_loss = train(
            model,
            train_loader,
            optimizer,
            lr_scheduler,
            device,
            epoch,
            grad_accumulation_steps,
            use_amp=use_amp,
        )
        avg_val_loss = validate(model, val_loader, device, use_amp)
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        print(f"Epoch {epoch} - Train Loss: {avg_train_loss:.4f} - Val Loss: {avg_val_loss:.4f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            print("Validation loss improved. Saving model...")
            best_val_loss = avg_val_loss
            patience_counter = 0
            os.makedirs("checkpoints", exist_ok=True)
            torch.save(model.state_dict(), "checkpoints/experiment.pt")
        else:
            patience_counter += 1
            print(f"No improvement. Patience: {patience_counter}/{patience}")
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    save_dir = "plot"
    os.makedirs(save_dir, exist_ok=True)

    # Plotting: the x axis covers only the epochs that actually ran, which
    # may be fewer than `epochs` after early stopping.
    epoch_axis = list(range(1, len(train_losses) + 1))
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(epoch_axis, train_losses, label='Train Loss', marker='o')
    ax.plot(epoch_axis, val_losses, label='Validation Loss', marker='x')
    ax.set_title("Training and Validation Loss per Epoch")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid(True)

    # Save the plot
    plot_path = os.path.join(save_dir, "experiment.png")
    fig.savefig(plot_path)
    plt.close(fig)

    # Calculate elapsed time (training, validation, checkpointing, plotting)
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Elapsed time: {elapsed_time} seconds")

In [5]:
# Run the full training pipeline. Inside a notebook kernel __name__ is
# "__main__", so executing this cell starts training immediately.
if __name__ == "__main__":
    main()

Using device: cpu


Epoch 1: 100%|██████████| 518/518 [16:15<00:00,  1.88s/it]
Validating: 100%|██████████| 74/74 [01:04<00:00,  1.16it/s]


Epoch 1 - Train Loss: 4.4270 - Val Loss: 4.0784
Validation loss improved. Saving model...


Epoch 2: 100%|██████████| 518/518 [16:04<00:00,  1.86s/it]
Validating: 100%|██████████| 74/74 [01:01<00:00,  1.21it/s]


Epoch 2 - Train Loss: 4.0779 - Val Loss: 4.0283
Validation loss improved. Saving model...
Elapsed time: 2066.8475058078766 seconds
