## Load the Model

The model we will be using is GenerativeImage2Text (GIT) from Microsoft. Check out the model on [Huggingface](https://huggingface.co/microsoft/git-base). 

In [1]:
from transformers import AutoProcessor, AutoModelForCausalLM
import torch

processor = AutoProcessor.from_pretrained("microsoft/git-base")
model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")

model.to("cuda" if torch.cuda.is_available() else "cpu")
print(model.device)

  from .autonotebook import tqdm as notebook_tqdm
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


cuda:0


In [None]:
from datasets import load_dataset
dataset = load_dataset("vipulmaheshwari/GTA-Image-Captioning-Dataset")['train']

train_ds, test_ds = dataset.train_test_split(test_size=0.2).values()


In [3]:
print(f"Train dataset size: {len(train_ds)}")
print(f"Test dataset size: {len(test_ds)}")

Train dataset size: 3200
Test dataset size: 800


### Preprocess Data

Now, we will preprocess the data to get it ready for the model. This involves tokenizing the text and processing the images.

In [None]:
def preprocess_function(examples):
    # Process images
    image_inputs = processor(images=[x.convert("RGB") for x in examples["image"]], return_tensors="pt")
    
    # Process captions
    text_inputs = processor.tokenizer(
        text=examples["text"],
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt"
    )
    
    # Set labels, ignoring padding
    labels = text_inputs.input_ids.clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    
    return {
        "pixel_values": image_inputs.pixel_values,
        "input_ids": text_inputs.input_ids,
        "attention_mask": text_inputs.attention_mask,
        "labels": labels
    }

In [5]:
import gc
import multiprocessing

try:
	multiprocessing.set_start_method('spawn')
except RuntimeError:
	pass  # start method has already been set

processed_train_ds = train_ds.map(preprocess_function, batched=True, batch_size=4, num_proc=4, remove_columns=train_ds.column_names)
processed_test_ds = test_ds.map(preprocess_function, batched=True, batch_size=4, num_proc=4, remove_columns=test_ds.column_names)

del train_ds, test_ds
gc.collect()

Map (num_proc=4): 100%|██████████| 3200/3200 [00:13<00:00, 231.63 examples/s]
Map (num_proc=4): 100%|██████████| 800/800 [00:03<00:00, 212.05 examples/s]


48

In [6]:
print(processed_train_ds.column_names)

['pixel_values', 'input_ids', 'attention_mask', 'labels']


### Set up Training

In [None]:
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from tqdm.auto import tqdm
import numpy as np

# --- Training Configuration ---
num_epochs = 4
train_batch_size = 16
eval_batch_size = 16
learning_rate = 2e-4
gradient_accumulation_steps = 2

# --- DataLoaders ---
processed_train_ds.set_format(type='torch', columns=['pixel_values', 'input_ids', 'attention_mask', 'labels'])
processed_test_ds.set_format(type='torch', columns=['pixel_values', 'input_ids', 'attention_mask', 'labels'])

train_dataloader = DataLoader(processed_train_ds, shuffle=True, batch_size=train_batch_size)
eval_dataloader = DataLoader(processed_test_ds, batch_size=eval_batch_size)

# --- Optimizer and Scheduler ---
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=20, num_training_steps=num_training_steps
)

# --- Mixed-Precision Training Scaler ---
scaler = torch.amp.GradScaler('cuda')

# --- Training Loop ---
best_eval_loss = float('inf')
progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(train_dataloader):
        inputs = {k: v.to(model.device) for k, v in batch.items()}
        
        with torch.amp.autocast('cuda'):
            outputs = model(**inputs)
            loss = outputs.loss
            loss = loss / gradient_accumulation_steps
        
        scaler.scale(loss).backward()
        
        if (step + 1) % gradient_accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            optimizer.zero_grad()
        
        total_loss += loss.item() * gradient_accumulation_steps
        progress_bar.update(1)
        progress_bar.set_description(f"Epoch {epoch+1}, Loss: {total_loss / (step + 1):.4f}")

    # --- Evaluation Loop ---
    model.eval()
    eval_loss = 0
    with torch.no_grad():
        for batch in eval_dataloader:
            inputs = {k: v.to(model.device) for k, v in batch.items()}
            with torch.amp.autocast('cuda'):
                outputs = model(**inputs)
            eval_loss += outputs.loss.item()
    
    avg_eval_loss = eval_loss / len(eval_dataloader)
    perplexity = np.exp(avg_eval_loss)
    print(f"\n--- Epoch {epoch+1} Evaluation ---")
    print(f"  Average Loss: {avg_eval_loss:.4f}")
    print(f"  Perplexity: {perplexity:.4f}")
    print("--------------------------")

print("Training complete.")

Epoch 1, Loss: 4.4767:   8%|▊         | 200/2400 [01:22<15:25,  2.38it/s]


--- Epoch 1 Evaluation ---
  Average Loss: 3.0915
  Perplexity: 22.0091
--------------------------


Epoch 2, Loss: 2.8385:  17%|█▋        | 400/2400 [02:57<13:46,  2.42it/s]  


--- Epoch 2 Evaluation ---
  Average Loss: 2.6427
  Perplexity: 14.0509
--------------------------


Epoch 3, Loss: 2.4067:  25%|██▌       | 600/2400 [04:31<12:33,  2.39it/s]  


--- Epoch 3 Evaluation ---
  Average Loss: 2.4365
  Perplexity: 11.4327
--------------------------


Epoch 4, Loss: 2.1317:  33%|███▎      | 800/2400 [06:05<11:18,  2.36it/s]  


--- Epoch 4 Evaluation ---
  Average Loss: 2.3308
  Perplexity: 10.2862
--------------------------


Epoch 5, Loss: 1.9237:  42%|████▏     | 1000/2400 [07:40<09:41,  2.41it/s] 


--- Epoch 5 Evaluation ---
  Average Loss: 2.2787
  Perplexity: 9.7644
--------------------------


Epoch 6, Loss: 1.7397:  50%|█████     | 1200/2400 [09:14<08:15,  2.42it/s]  


--- Epoch 6 Evaluation ---
  Average Loss: 2.2614
  Perplexity: 9.5961
--------------------------


Epoch 7, Loss: 1.5877:  58%|█████▊    | 1400/2400 [10:48<06:50,  2.43it/s]  


--- Epoch 7 Evaluation ---
  Average Loss: 2.2600
  Perplexity: 9.5827
--------------------------


Epoch 8, Loss: 1.4359:  67%|██████▋   | 1600/2400 [12:22<05:27,  2.44it/s]  


--- Epoch 8 Evaluation ---
  Average Loss: 2.2739
  Perplexity: 9.7176
--------------------------


Epoch 9, Loss: 1.3030:  75%|███████▌  | 1800/2400 [13:56<04:08,  2.41it/s]


--- Epoch 9 Evaluation ---
  Average Loss: 2.3054
  Perplexity: 10.0283
--------------------------


Epoch 10, Loss: 1.1716:  83%|████████▎ | 2000/2400 [15:30<02:44,  2.43it/s]


--- Epoch 10 Evaluation ---
  Average Loss: 2.3439
  Perplexity: 10.4218
--------------------------


Epoch 11, Loss: 1.0538:  92%|█████████▏| 2200/2400 [17:05<01:22,  2.43it/s]


--- Epoch 11 Evaluation ---
  Average Loss: 2.3827
  Perplexity: 10.8337
--------------------------


Epoch 12, Loss: 0.9443: 100%|██████████| 2400/2400 [18:39<00:00,  2.41it/s]


--- Epoch 12 Evaluation ---
  Average Loss: 2.4282
  Perplexity: 11.3388
--------------------------
Training complete.


### Save the Model

Finally, we will save the model to a directory so we can use it later.

In [None]:
torch.save(model.state_dict(), "captioning-model.pth")