## GRU Model Training for Git Commit Message Generation

This notebook trains a **Gated Recurrent Unit (GRU)** language model to generate Git commit messages from Git diffs. It uses a custom-trained Byte Pair Encoding (BPE) tokenizer and follows a causal language modeling objective.

---

### Overview

1. **BPE tokenizer Loading**  
   Loads a trained `custom_bpe_tokenizer.json` file and registers special tokens:
   - `<sos>` (start-of-sequence)
   - `<endOfDiff>` (separator between diff and message)
   - `<endOfCommitMessage>` (end token)
   - `<pad>` (used for batching)

2. **Model Architecture**  
   - 4-layer GRU with:
     - 512-dimensional embeddings
     - 512 hidden units per layer
     - Dropout of 0.2
     - Weight tying between embedding and output layer

3. **Dataset Setup**  
   Each input sample is tokenized as: 
   `<sos> + diff + <endOfDiff> + commit_message + <endOfCommitMessage>`


4. **Training Loop**
- Uses `CrossEntropyLoss` (ignoring padding)
- `AdamW` optimizer with weight decay
- Clip gradient norm to 1.0
- Logs both training and validation loss
- Plots learning curves as `gru_loss_curve.png`

---

### File Requirements

- `custom_bpe_tokenizer.json`: Path to your trained tokenizer
- `data/cleaned_python_commit_dataset.csv`: Must contain `diff` and `commit_message` columns

> **Make sure the CSV path is correct** and matches your project structure.

---

### Output

- Trained model weights saved to:  
`trained_model/gru_model.pt`

- Loss plot image saved to:  
`gru_loss_curve.png`

This notebook serves as the baseline RNN-based model in comparison to the GPT-2 transformer model.



In [2]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import PreTrainedTokenizerFast
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
import numpy as np
import matplotlib.pyplot as plt

# Set random seeds for reproducibility
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)

# Load custom BPE tokenizer
custom_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="custom_bpe_tokenizer.json"
)
custom_tokenizer.add_special_tokens({
    "pad_token": "<pad>",
    "eos_token": "<endOfCommitMessage>",
    "bos_token": "<sos>"
})
custom_tokenizer.add_tokens(["<endOfDiff>"])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# GRU Language Model
class GRULanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=512, hidden_dim=512, num_layers=4, dropout=0.2, pad_id=0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)
        self.drop = nn.Dropout(dropout)
        self.gru = nn.GRU(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True,
                          dropout=dropout if num_layers > 1 else 0.0)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight  # weight tying

    def forward(self, input_ids, hidden=None):
        x = self.drop(self.embed(input_ids))
        out, new_h = self.gru(x, hidden)
        logits = self.lm_head(out)
        return logits, new_h

# Dataset class
class GitDiffDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
        self.sep = "<endOfDiff>"
        self.eos = "<endOfCommitMessage>"
        self.bos = "<sos>"

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        diff, msg = map(str, self.data[idx])
        full = self.bos + diff + self.sep + msg + self.eos
        ids = self.tokenizer.encode(full)
        return torch.tensor(ids, dtype=torch.long)

# Collate for batching
def collate_fn(batch):
    max_len = min(512, max(len(seq) for seq in batch))  # truncate long samples
    pad_id = custom_tokenizer.pad_token_id
    padded = [F.pad(seq, (0, max_len - len(seq)), value=pad_id) for seq in batch]
    return torch.stack(padded)

# Training loop
def train_model(model, train_loader, val_loader, epochs=30, lr=2e-4, device='cuda'):
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    criterion = nn.CrossEntropyLoss(ignore_index=custom_tokenizer.pad_token_id)

    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} training"):
            batch = batch.to(device)
            optimizer.zero_grad()
            logits, _ = model(batch[:, :-1])
            loss = criterion(logits.reshape(-1, logits.size(-1)), batch[:, 1:].reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        print(f"Epoch {epoch+1} Train Loss: {avg_train_loss:.4f}")

        # Validation loop
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                logits, _ = model(batch[:, :-1])
                loss = criterion(logits.reshape(-1, logits.size(-1)), batch[:, 1:].reshape(-1))
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        print(f"Epoch {epoch+1} Val Loss: {avg_val_loss:.4f}")

    # Plot learning curves
    plt.figure()
    plt.plot(range(1, epochs + 1), train_losses, label="Training Loss")
    plt.plot(range(1, epochs + 1), val_losses, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("GRU Training vs Validation Loss")
    plt.legend()
    plt.tight_layout()
    plt.savefig("gru_loss_curve.png")
    print("Saved plot to gru_loss_curve.png")
    plt.close()

# Load data
df = pd.read_csv("data/cleaned_python_commit_dataset.csv")
data = list(zip(df["diff"].astype(str), df["commit_message"].astype(str)))
random.shuffle(data)
split_idx = int(0.8 * len(data))
train_data, val_data = data[:split_idx], data[split_idx:]
# train_data, val_data = data[:split_idx], data[:split_idx]

train_loader = DataLoader(GitDiffDataset(train_data, custom_tokenizer), batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(GitDiffDataset(val_data, custom_tokenizer), batch_size=32, shuffle=False, collate_fn=collate_fn)
vocab_size = len(custom_tokenizer)
# Initialize and train the model
model = GRULanguageModel(
    vocab_size=vocab_size,
    embed_dim=512,
    hidden_dim=512,
    num_layers=4,
    dropout=0.2,
    pad_id=custom_tokenizer.pad_token_id
)

train_model(model, train_loader, val_loader, epochs=30, device=device)

# Save the trained weights
os.makedirs("trained_model", exist_ok=True)
torch.save(model.state_dict(), "trained_model/gru_model.pt")
print("trained_model/gru_model.pt")


Epoch 1 training: 100%|██████████| 1914/1914 [11:08<00:00,  2.86it/s]


Epoch 1 Train Loss: 4.4200
Epoch 1 Val Loss: 3.4640


Epoch 2 training: 100%|██████████| 1914/1914 [11:05<00:00,  2.87it/s]


Epoch 2 Train Loss: 3.3897
Epoch 2 Val Loss: 3.0875


Epoch 3 training: 100%|██████████| 1914/1914 [11:05<00:00,  2.88it/s]


Epoch 3 Train Loss: 3.1197
Epoch 3 Val Loss: 2.8997


Epoch 4 training: 100%|██████████| 1914/1914 [11:05<00:00,  2.88it/s]


Epoch 4 Train Loss: 2.9634
Epoch 4 Val Loss: 2.7786


Epoch 5 training: 100%|██████████| 1914/1914 [11:05<00:00,  2.87it/s]


Epoch 5 Train Loss: 2.8521
Epoch 5 Val Loss: 2.6920


Epoch 6 training: 100%|██████████| 1914/1914 [11:06<00:00,  2.87it/s]


Epoch 6 Train Loss: 2.7654
Epoch 6 Val Loss: 2.6172


Epoch 7 training: 100%|██████████| 1914/1914 [11:05<00:00,  2.87it/s]


Epoch 7 Train Loss: 2.6941
Epoch 7 Val Loss: 2.5575


Epoch 8 training: 100%|██████████| 1914/1914 [11:07<00:00,  2.87it/s]


Epoch 8 Train Loss: 2.6334
Epoch 8 Val Loss: 2.5054


Epoch 9 training: 100%|██████████| 1914/1914 [11:06<00:00,  2.87it/s]


Epoch 9 Train Loss: 2.5808
Epoch 9 Val Loss: 2.4601


Epoch 10 training: 100%|██████████| 1914/1914 [11:06<00:00,  2.87it/s]


Epoch 10 Train Loss: 2.5319
Epoch 10 Val Loss: 2.4185


Epoch 11 training: 100%|██████████| 1914/1914 [11:06<00:00,  2.87it/s]


Epoch 11 Train Loss: 2.4889
Epoch 11 Val Loss: 2.3821


Epoch 12 training: 100%|██████████| 1914/1914 [11:06<00:00,  2.87it/s]


Epoch 12 Train Loss: 2.4509
Epoch 12 Val Loss: 2.3511


Epoch 13 training: 100%|██████████| 1914/1914 [11:06<00:00,  2.87it/s]


Epoch 13 Train Loss: 2.4166
Epoch 13 Val Loss: 2.3232


Epoch 14 training: 100%|██████████| 1914/1914 [11:07<00:00,  2.87it/s]


Epoch 14 Train Loss: 2.3850
Epoch 14 Val Loss: 2.2960


Epoch 15 training: 100%|██████████| 1914/1914 [11:06<00:00,  2.87it/s]


Epoch 15 Train Loss: 2.3560
Epoch 15 Val Loss: 2.2713


Epoch 16 training: 100%|██████████| 1914/1914 [11:05<00:00,  2.87it/s]


Epoch 16 Train Loss: 2.3289
Epoch 16 Val Loss: 2.2500


Epoch 17 training: 100%|██████████| 1914/1914 [11:06<00:00,  2.87it/s]


Epoch 17 Train Loss: 2.3046
Epoch 17 Val Loss: 2.2263


Epoch 18 training: 100%|██████████| 1914/1914 [11:04<00:00,  2.88it/s]


Epoch 18 Train Loss: 2.2811
Epoch 18 Val Loss: 2.2096


Epoch 19 training: 100%|██████████| 1914/1914 [11:04<00:00,  2.88it/s]


Epoch 19 Train Loss: 2.2596
Epoch 19 Val Loss: 2.1935


Epoch 20 training: 100%|██████████| 1914/1914 [11:04<00:00,  2.88it/s]


Epoch 20 Train Loss: 2.2394
Epoch 20 Val Loss: 2.1750


Epoch 21 training: 100%|██████████| 1914/1914 [11:04<00:00,  2.88it/s]


Epoch 21 Train Loss: 2.2198
Epoch 21 Val Loss: 2.1612


Epoch 22 training: 100%|██████████| 1914/1914 [11:04<00:00,  2.88it/s]


Epoch 22 Train Loss: 2.2017
Epoch 22 Val Loss: 2.1454


Epoch 23 training: 100%|██████████| 1914/1914 [11:03<00:00,  2.88it/s]


Epoch 23 Train Loss: 2.1847
Epoch 23 Val Loss: 2.1311


Epoch 24 training: 100%|██████████| 1914/1914 [11:05<00:00,  2.88it/s]


Epoch 24 Train Loss: 2.1686
Epoch 24 Val Loss: 2.1189


Epoch 25 training: 100%|██████████| 1914/1914 [11:05<00:00,  2.88it/s]


Epoch 25 Train Loss: 2.1529
Epoch 25 Val Loss: 2.1066


Epoch 26 training: 100%|██████████| 1914/1914 [11:04<00:00,  2.88it/s]


Epoch 26 Train Loss: 2.1382
Epoch 26 Val Loss: 2.0980


Epoch 27 training: 100%|██████████| 1914/1914 [11:20<00:00,  2.81it/s]


Epoch 27 Train Loss: 2.1239
Epoch 27 Val Loss: 2.0851


Epoch 28 training: 100%|██████████| 1914/1914 [11:16<00:00,  2.83it/s]


Epoch 28 Train Loss: 2.1107
Epoch 28 Val Loss: 2.0759


Epoch 29 training: 100%|██████████| 1914/1914 [11:05<00:00,  2.88it/s]


Epoch 29 Train Loss: 2.0974
Epoch 29 Val Loss: 2.0653


Epoch 30 training: 100%|██████████| 1914/1914 [11:09<00:00,  2.86it/s]


Epoch 30 Train Loss: 2.0848
Epoch 30 Val Loss: 2.0576
Saved plot to gru_loss_curve.png
trained_model/gru_model.pt


## Run an inference

In [3]:
# Inference function for GRU model
def generate_commit_message(model, tokenizer, git_diff, max_new_tokens=50, device='cuda'):
    model.eval()
    sep_token = "<endOfDiff>"
    eos_token = "<endOfCommitMessage>"
    bos_token = "<sos>"

    input_text = bos_token + git_diff + sep_token
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    eos_id = tokenizer.encode(eos_token)[0]

    generated = input_ids
    hidden = None

    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits, hidden = model(generated[:, -1:], hidden)
            next_id = logits[0, -1].argmax(dim=-1).item()
            generated = torch.cat([generated, torch.tensor([[next_id]], device=device)], dim=1)
            if next_id == eos_id:
                break

    decoded = tokenizer.decode(generated[0].tolist())
    return decoded.split(sep_token)[1].replace(eos_token, "").strip()

# Example Git diff input
sample_diff = """diff --git a/config.py b/config.py
index abc123..def456 100644
--- a/config.py
+++ b/config.py
@@ -1,5 +1,6 @@
 DEBUG = False
+LOGGING_ENABLED = True
"""

# Load model and run inference
model.load_state_dict(torch.load("trained_model/gru_model.pt", map_location=device))
model.to(device)

generated_msg = generate_commit_message(model, custom_tokenizer, sample_diff, device=device)
print("Generated Commit Message:\n", generated_msg)


Generated Commit Message:
 't be used for the current_run_and_run_run.py

This is a bug to avoid a bug that is not used in the
and the main function.

This reverts commit 3a4a0a0


## Download weights and loss curve graph

In [None]:
from google.colab import files

# Download cleaned dataset
files.download('trained_model/gru_model.pt')
files.download('gru_loss_curve.png')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>