In [1]:
import json
import os
import pickle
import numpy as np
import torch
import evaluate
import faiss
from datasets import Dataset, load_from_disk, concatenate_datasets
from transformers import Trainer, EarlyStoppingCallback
from transformers import (
    RagTokenizer,
    RagRetriever,
    RagSequenceForGeneration,
    RagConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    pipeline
)
from transformers import AutoTokenizer, AutoModel
from transformers import DPRQuestionEncoderTokenizerFast, T5TokenizerFast, T5ForConditionalGeneration
from transformers import RagSequenceForGeneration, T5ForConditionalGeneration
from collections import defaultdict
from sklearn.model_selection import train_test_split
from typing import List, Dict
import nltk
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.meteor_score import meteor_score
from rouge import Rouge
import textstat
import gc
from sentence_transformers import SentenceTransformer, util


nltk.download('punkt')
nltk.download('wordnet')
nltk.download('omw-1.4')

[nltk_data] Downloading package punkt to
[nltk_data]     /users/PCS0289/myosc24/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /users/PCS0289/myosc24/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /users/PCS0289/myosc24/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


True

In [2]:

# Step 1: Data Loading, Cleaning, and Augmentation


def load_dataset(file_path):
    """
    Load and parse the JSONL dataset.
    """
    dataset = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_number, line in enumerate(f, 1):
            try:
                data = json.loads(line.strip())
                dataset.append(data)
            except json.JSONDecodeError as e:
                print(f"[Line {line_number}] JSON decoding error: {e}")
    return dataset

def inspect_dataset(dataset, sample_size=5):
    """
    Inspect the dataset for common issues.
    """
    print("\n--- Dataset Inspection ---\n")
    print(f"Total entries: {len(dataset)}\n")
    
    
    missing_fields = defaultdict(int)
    for entry in dataset:
        for field in ['diff', 'msg']:
            if field not in entry or not entry[field].strip():
                missing_fields[field] += 1
    
    for field, count in missing_fields.items():
        print(f"Entries missing or empty '{field}': {count}")
    
    
    diff_id_counts = defaultdict(int)
    for entry in dataset:
        diff_id = entry.get('diff_id')
        if diff_id is not None:
            diff_id_counts[diff_id] += 1
    
    duplicates = {k: v for k, v in diff_id_counts.items() if v > 1}
    if duplicates:
        print(f"\nDuplicate diff_ids found: {len(duplicates)}")
        for k, v in list(duplicates.items())[:5]:  
            print(f"diff_id: {k}, count: {v}")
    else:
        print("\nNo duplicate diff_ids found.")
    
    
    print(f"\nSample {sample_size} entries:")
    for i, entry in enumerate(dataset[:sample_size], 1):
        print(f"\n--- Entry {i} ---")
        print(f"diff_id: {entry.get('diff_id')}")
        print(f"Diff:\n{entry.get('diff')[:200]}...")  
        print(f"Message: {entry.get('msg')[:100]}...")
    
    print("\n--- End of Inspection ---\n")

def clean_and_validate_dataset_lenient(dataset):
    """
    Clean and validate the dataset with lenient criteria.
    """
    cleaned_dataset = []
    excluded_entries = []
    
    for entry in dataset:
        diff = entry.get('diff', '').strip()
        msg = entry.get('msg', '').strip()
        diff_id = entry.get('diff_id')
        
        exclusion_reasons = []
        
        
        if not diff:
            exclusion_reasons.append('Missing diff')
        if not msg:
            exclusion_reasons.append('Missing commit message')
        
        if exclusion_reasons:
            excluded_entry = {
                'diff_id': diff_id,
                'diff': diff,
                'msg': msg,
                'reasons': exclusion_reasons
            }
            excluded_entries.append(excluded_entry)
        else:
            
            entry['diff'] = diff.replace('<nl>', '\n')
            entry['msg'] = ' '.join(msg.replace('<nl>', '\n').split())
            cleaned_dataset.append(entry)
    
    print(f"Total valid entries after lenient cleaning: {len(cleaned_dataset)}")
    print(f"Total excluded entries after lenient cleaning: {len(excluded_entries)}")
    
    
    excluded_path = 'excluded_entries_lenient.pkl'
    with open(excluded_path, 'wb') as f:
        pickle.dump(excluded_entries, f)
    print(f"Excluded entries saved to {excluded_path}")
    
    return cleaned_dataset, excluded_entries

def augment_messages(messages: List[str], num_aug=1) -> List[str]:
    """
    Augment commit messages by paraphrasing.
    """
    augmented = []
    paraphraser = pipeline("text2text-generation", model="Vamsi/T5_Paraphrase_Paws", device=0 if torch.cuda.is_available() else -1)
    for msg in messages:
        for _ in range(num_aug):
            paraphrased = paraphraser(msg, max_length=128, num_return_sequences=1, do_sample=True)[0]['generated_text']
            augmented.append(paraphrased)
    return augmented

def split_dataset(diffs, messages, diff_ids, train_size=0.7, val_size=0.15, test_size=0.15, random_state=42):
    """
    Split the dataset into training, validation, and test sets.
    """
    assert train_size + val_size + test_size == 1.0, "Train, validation and test sizes must sum to 1."
  
    train_diffs, temp_diffs, train_msgs, temp_msgs, train_ids, temp_ids = train_test_split(
        diffs, messages, diff_ids, train_size=train_size, random_state=random_state)
    
    val_ratio = val_size / (val_size + test_size)
    val_diffs, test_diffs, val_msgs, test_msgs, val_ids, test_ids = train_test_split(
        temp_diffs, temp_msgs, temp_ids, test_size=1 - val_ratio, random_state=random_state)
    
    print(f"Training set: {len(train_diffs)} entries")
    print(f"Validation set: {len(val_diffs)} entries")
    print(f"Test set: {len(test_diffs)} entries")
    
    return {
        'train': {'diffs': train_diffs, 'messages': train_msgs, 'diff_ids': train_ids},
        'validation': {'diffs': val_diffs, 'messages': val_msgs, 'diff_ids': val_ids},
        'test': {'diffs': test_diffs, 'messages': test_msgs, 'diff_ids': test_ids},
    }

def serialize_splits(dataset_splits, output_dir='prepared_data'):
    """
    Serialize the dataset splits into files.
    """
    os.makedirs(output_dir, exist_ok=True)

    for split_name, data in dataset_splits.items():
        file_path = os.path.join(output_dir, f"{split_name}_data.pkl")
        with open(file_path, 'wb') as f:
            pickle.dump(data, f)
        print(f"Serialized {split_name} set to {file_path}")

def verify_data_splits(dataset_splits, expected_total):
    """
    Verify the integrity of the dataset splits.
    """
    total = sum(len(data['diffs']) for data in dataset_splits.values())
    if total != expected_total:
        print(f"Verification failed: Total entries {total} != Expected {expected_total}")
        return False
  
    for split_name, data in dataset_splits.items():
        if len(data['diffs']) != len(data['messages']) or len(data['diffs']) != len(data['diff_ids']):
            print(f"Verification failed for split {split_name}: Inconsistent lengths.")
            return False
  
    print("All data splits verified successfully.")
    return True


dataset_path = 'py.jsonl'  
raw_dataset = load_dataset(dataset_path)
print(f"Total entries loaded: {len(raw_dataset)}")

inspect_dataset(raw_dataset, sample_size=3)

cleaned_dataset_lenient, excluded_entries_lenient = clean_and_validate_dataset_lenient(raw_dataset)

diffs = [entry['diff'] for entry in cleaned_dataset_lenient]
messages = [entry['msg'] for entry in cleaned_dataset_lenient]
diff_ids = [entry['diff_id'] for entry in cleaned_dataset_lenient]


augmented_messages = augment_messages(messages)
augmented_diffs = diffs * 1  
augmented_diff_ids = diff_ids * 1  


diffs.extend(augmented_diffs)
messages.extend(augmented_messages)
diff_ids.extend(augmented_diff_ids)

expected_total = len(diffs)

dataset_splits = split_dataset(diffs, messages, diff_ids)

serialize_splits(dataset_splits)

verify_data_splits(dataset_splits, expected_total=expected_total)

Total entries loaded: 500

--- Dataset Inspection ---

Total entries: 500


No duplicate diff_ids found.

Sample 3 entries:

--- Entry 1 ---
diff_id: 100
Diff:
mmm a / IPython / lib / irunner . py <nl> ppp b / IPython / lib / irunner . py <nl> <nl> import os <nl> import sys <nl> <nl> - # Third - party modules . <nl> - import pexpect <nl> + # Third - party mo...
Message: Use IPython . external for pexpect import .
...

--- Entry 2 ---
diff_id: 103
Diff:
mmm a / tornado / locks . py <nl> ppp b / tornado / locks . py <nl> <nl> __all__ = [ ' Condition ' , ' Event ' , ' Semaphore ' ] <nl> <nl> import collections <nl> - import contextlib <nl> <nl> from to...
Message: Simpler code for Semaphore . acquire ( ) as a context manager .
...

--- Entry 3 ---
diff_id: 105
Diff:
mmm a / Lib / test / test_resource . py <nl> ppp b / Lib / test / test_resource . py <nl> <nl> limit_set = 0 <nl> f = open ( TESTFN , " wb " ) <nl> f . write ( " X " * 1024 ) <nl> + f . flush ( ) <nl>...
Message: Try harder to

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


Training set: 700 entries
Validation set: 150 entries
Test set: 150 entries
Serialized train set to prepared_data/train_data.pkl
Serialized validation set to prepared_data/validation_data.pkl
Serialized test set to prepared_data/test_data.pkl
All data splits verified successfully.


True

In [3]:

# Step 2: Embedding Generation and FAISS Indexing


def load_embedding_model(model_name: str = 'microsoft/codebert-base') -> (AutoTokenizer, AutoModel):
    """
    Load the tokenizer and model for embedding generation.

    Args:
        model_name (str): Hugging Face model name.

    Returns:
        Tuple[AutoTokenizer, AutoModel]: The tokenizer and model instances.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    
    
    if torch.cuda.is_available():
        model = model.to('cuda')
        print("Model moved to GPU.")
    else:
        print("GPU not available. Using CPU.")
    
    model.eval()  
    return tokenizer, model


tokenizer_embed, model_embed = load_embedding_model()

def generate_embeddings(diffs: List[str],
                        tokenizer: AutoTokenizer,
                        model: AutoModel,
                        batch_size: int = 32,
                        device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> np.ndarray:
    """
    Generate embeddings for a list of code diffs.

    Args:
        diffs (List[str]): List of code diffs.
        tokenizer (AutoTokenizer): Tokenizer for the model.
        model (AutoModel): Pre-trained model for embedding generation.
        batch_size (int): Number of samples per batch.
        device (str): Device to run the model on ('cuda' or 'cpu').

    Returns:
        np.ndarray: Array of embeddings.
    """
    embeddings = []
    num_batches = (len(diffs) + batch_size - 1) // batch_size

    for i in range(num_batches):
        batch_diffs = diffs[i*batch_size : (i+1)*batch_size]
        encoded_input = tokenizer(batch_diffs, padding=True, truncation=True, return_tensors='pt', max_length=512)
        encoded_input = {key: val.to(device) for key, val in encoded_input.items()}
        
        with torch.no_grad():
            model_output = model(**encoded_input)
        
        
        token_embeddings = model_output.last_hidden_state  
        attention_mask = encoded_input['attention_mask']  
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        batch_embeddings = (sum_embeddings / sum_mask).cpu().numpy()
        embeddings.append(batch_embeddings)
        
        print(f"Processed batch {i+1}/{num_batches}")

    embeddings = np.vstack(embeddings)
    print(f"Generated embeddings with shape: {embeddings.shape}")
    return embeddings


embeddings = generate_embeddings(diffs, tokenizer_embed, model_embed)

def build_faiss_index(embeddings: np.ndarray, index_path: str = 'kb_index.faiss') -> faiss.IndexFlatIP:
    """
    Build a FAISS index from the embeddings.

    Args:
        embeddings (np.ndarray): Array of embeddings.
        index_path (str): Path to save the FAISS index.

    Returns:
        faiss.IndexFlatIP: The FAISS index instance.
    """
    
    faiss.normalize_L2(embeddings)
    
    
    index = faiss.IndexFlatIP(embeddings.shape[1])  
    
    
    index.add(embeddings)
    print(f"FAISS index built with {index.ntotal} vectors.")
    
    
    faiss.write_index(index, index_path)
    print(f"FAISS index saved to {index_path}")
    
    return index


faiss_index = build_faiss_index(embeddings, index_path='kb_index.faiss')

def create_metadata_mapping(diff_ids: List[int],
                            diffs: List[str],
                            messages: List[str],
                            mapping_path: str = 'metadata_mapping.pkl') -> Dict[int, Dict]:
    """
    Create a metadata mapping linking FAISS index positions to their data.

    Args:
        diff_ids (List[int]): List of unique diff IDs.
        diffs (List[str]): List of code diffs.
        messages (List[str]): List of commit messages.
        mapping_path (str): Path to save the metadata mapping.

    Returns:
        Dict[int, Dict]: A dictionary mapping index positions to their data.
    """
    assert len(diff_ids) == len(diffs) == len(messages), "All input lists must have the same length."
    
    metadata_mapping = {}
    for idx, (diff_id, diff, message) in enumerate(zip(diff_ids, diffs, messages)):
        metadata_mapping[idx] = {
            'diff_id': diff_id,
            'diff': diff,
            'message': message
        }
    
    
    with open(mapping_path, 'wb') as f:
        pickle.dump(metadata_mapping, f)
    print(f"Metadata mapping saved to {mapping_path}")
    
    return metadata_mapping


metadata_mapping = create_metadata_mapping(diff_ids, diffs, messages)

def create_knowledge_base(diff_ids: List[int], diffs: List[str], embeddings: np.ndarray, 
                         mapping: Dict[int, Dict], output_dir: str = 'kb_dataset') -> Dataset:
    """
    Create and save the knowledge base dataset with required columns.

    Args:
        diff_ids (List[int]): List of unique diff IDs.
        diffs (List[str]): List of code diffs.
        embeddings (np.ndarray): Precomputed embeddings for each diff.
        mapping (Dict[int, Dict]): Metadata mapping.
        output_dir (str): Directory to save the knowledge base dataset.

    Returns:
        Dataset: The knowledge base dataset.
    """
    os.makedirs(output_dir, exist_ok=True)
    
    titles = [f"diff_{diff_id}" for diff_id in diff_ids]  
    texts = diffs
    
    embeddings = embeddings.astype('float32')
    
    kb_dataset = Dataset.from_dict({
        'title': titles,
        'text': texts,
        'embeddings': [embedding for embedding in embeddings]
    })
    
    
    kb_dataset.save_to_disk(output_dir)
    print(f"Knowledge base dataset saved to {output_dir}")
    
    return kb_dataset


kb_dataset = create_knowledge_base(diff_ids, diffs, embeddings, metadata_mapping)

def load_faiss_index(index_path: str = 'kb_index.faiss') -> faiss.IndexFlatIP:
    """
    Load the FAISS index from the saved file.

    Args:
        index_path (str): Path to the FAISS index file.

    Returns:
        faiss.IndexFlatIP: The loaded FAISS index.
    """
    index = faiss.read_index(index_path)
    print(f"FAISS index loaded from {index_path} with {index.ntotal} vectors.")
    return index

def load_metadata_mapping(mapping_path: str = 'metadata_mapping.pkl') -> Dict[int, Dict]:
    """
    Load the metadata mapping from the saved file.

    Args:
        mapping_path (str): Path to the metadata mapping file.

    Returns:
        Dict[int, Dict]: The loaded metadata mapping.
    """
    with open(mapping_path, 'rb') as f:
        metadata_mapping = pickle.load(f)
    print(f"Metadata mapping loaded from {mapping_path} with {len(metadata_mapping)} entries.")
    return metadata_mapping


reloaded_faiss_index = load_faiss_index()
reloaded_metadata_mapping = load_metadata_mapping()



Model moved to GPU.
Processed batch 1/32
Processed batch 2/32
Processed batch 3/32
Processed batch 4/32
Processed batch 5/32
Processed batch 6/32
Processed batch 7/32
Processed batch 8/32
Processed batch 9/32
Processed batch 10/32
Processed batch 11/32
Processed batch 12/32
Processed batch 13/32
Processed batch 14/32
Processed batch 15/32
Processed batch 16/32
Processed batch 17/32
Processed batch 18/32
Processed batch 19/32
Processed batch 20/32
Processed batch 21/32
Processed batch 22/32
Processed batch 23/32
Processed batch 24/32
Processed batch 25/32
Processed batch 26/32
Processed batch 27/32
Processed batch 28/32
Processed batch 29/32
Processed batch 30/32
Processed batch 31/32
Processed batch 32/32
Generated embeddings with shape: (1000, 768)
FAISS index built with 1000 vectors.
FAISS index saved to kb_index.faiss
Metadata mapping saved to metadata_mapping.pkl


Saving the dataset (0/1 shards):   0%|          | 0/1000 [00:00<?, ? examples/s]

Knowledge base dataset saved to kb_dataset
FAISS index loaded from kb_index.faiss with 1000 vectors.
Metadata mapping loaded from metadata_mapping.pkl with 1000 entries.


In [4]:

# Step 3: Training the RAG Model with Joint Training and Hyperparameter Tuning


import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def load_training_data(prepared_data_dir: str = 'prepared_data') -> Dict[str, List]:
    dataset_splits = {}
    for split in ['train', 'validation', 'test']:
        file_path = os.path.join(prepared_data_dir, f"{split}_data.pkl")
        with open(file_path, 'rb') as f:
            split_data = pickle.load(f)
        dataset_splits[split] = split_data
    return dataset_splits

dataset_splits = load_training_data()


question_encoder_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained('facebook/dpr-question_encoder-single-nq-base')

generator_tokenizer = T5TokenizerFast.from_pretrained('t5-large')




train_dataset = Dataset.from_dict({
    'input_texts': dataset_splits['train']['diffs'],
    'target_texts': dataset_splits['train']['messages']
})
val_dataset = Dataset.from_dict({
    'input_texts': dataset_splits['validation']['diffs'],
    'target_texts': dataset_splits['validation']['messages']
})

def tokenize_function(examples):
    
    inputs = question_encoder_tokenizer(
        examples['input_texts'],
        max_length=512,
        truncation=True,
        padding='max_length',
    )
    
    labels = generator_tokenizer(
        examples['target_texts'],
        max_length=128,
        truncation=True,
        padding='max_length',
    )
    
    inputs['labels'] = labels['input_ids']
    return inputs


tokenized_train_dataset = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=['input_texts', 'target_texts']
)

tokenized_val_dataset = val_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=['input_texts', 'target_texts']
)


kb_dataset_path = 'kb_dataset'  
kb_index_path = 'kb_index.faiss'  

retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-base",
    index_name="custom",
    passages_path=kb_dataset_path,
    index_path=kb_index_path,
    use_dummy_dataset=False,
    question_encoder_tokenizer=question_encoder_tokenizer,
    generator_tokenizer=generator_tokenizer,
)


rag_model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-sequence-base",
    retriever=retriever
)


t5_model = T5ForConditionalGeneration.from_pretrained('t5-large')


rag_model.generator = t5_model



device = 'cuda' if torch.cuda.is_available() else 'cpu'
rag_model.to(device)


for param in rag_model.question_encoder.parameters():
    param.requires_grad = True


rag_model.config.use_cache = False


data_collator = DataCollatorForSeq2Seq(
    tokenizer=generator_tokenizer,
    model=rag_model,
    label_pad_token_id=generator_tokenizer.pad_token_id,
    padding='longest',
)


training_args = TrainingArguments(
    output_dir='./rag_model_output',
    overwrite_output_dir=True,
    num_train_epochs=40,  
    per_device_train_batch_size=1,  
    gradient_accumulation_steps=8,
    learning_rate=1e-5,
    fp16=torch.cuda.is_available(),
    save_strategy="epoch",
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model='loss',
    greater_is_better=False,
    warmup_steps=100,
    weight_decay=0.01,
    logging_steps=10,
    save_total_limit=2,
    report_to="none",
    remove_unused_columns=False,
)


from transformers import Trainer, EarlyStoppingCallback

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss


trainer = CustomTrainer(
    model=rag_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)


trainer.train()


spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

Map:   0%|          | 0/700 [00:00<?, ? examples/s]

Map:   0%|          | 0/150 [00:00<?, ? examples/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called fr

model.safetensors:   0%|          | 0.00/2.95G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
0,161.6684,158.467041
2,102.6384,111.843674
4,91.0903,103.949219
6,86.8726,98.122093
8,76.7761,93.833023
10,64.3725,89.515648
12,59.0141,86.209648
14,49.0875,83.160515
16,40.2094,81.268097
18,35.3397,79.636993


There were missing keys in the checkpoint model loaded: ['rag.generator.model.encoder.embed_tokens.weight', 'rag.generator.model.decoder.embed_tokens.weight', 'generator.encoder.embed_tokens.weight', 'generator.decoder.embed_tokens.weight', 'generator.lm_head.weight'].


TrainOutput(global_step=3480, training_loss=57.64258039189481, metrics={'train_runtime': 7316.8982, 'train_samples_per_second': 3.827, 'train_steps_per_second': 0.476, 'total_flos': 1.0211744425181184e+17, 'train_loss': 57.64258039189481, 'epoch': 39.77142857142857})

In [5]:

# Step 4: Evaluating the Model

from tqdm import tqdm
import gc


def load_test_dataset(prepared_data_dir='prepared_data') -> Dataset:
    file_path = os.path.join(prepared_data_dir, "test_data.pkl")
    with open(file_path, 'rb') as f:
        split_data = pickle.load(f)
    test_dataset = Dataset.from_dict({
        'input_texts': split_data['diffs'],     
        'target_texts': split_data['messages'], 
        'diff_ids': split_data['diff_ids']     
    })
    return test_dataset


test_dataset = load_test_dataset()


def generate_commit_messages(model, test_dataset, generator_tokenizer, device='cuda', output_file='generated_messages.jsonl'):
    """
    Generate commit messages for the test dataset and save them in JSONL format.

    Args:
        model: The trained RAG model.
        test_dataset: The test dataset.
        generator_tokenizer: The generator tokenizer (T5TokenizerFast).
        device (str): Device to run the model on ('cuda' or 'cpu').
        output_file (str): Path to save the generated messages.
    """
    model.eval()
    model.to(device)

    
    import gc
    torch.cuda.empty_cache()
    gc.collect()

    with open(output_file, 'w', encoding='utf-8') as f:
        for example in tqdm(test_dataset):
            input_text = example['input_texts']
            diff_id = example['diff_ids']

            
            inputs = model.retriever.question_encoder_tokenizer(
                input_text,
                return_tensors='pt',
                truncation=True,
                max_length=512
            ).to(device)

            
            with torch.no_grad():
                generated_ids = model.generate(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    num_beams=2,       
                    max_length=64,     
                    early_stopping=True,
                    use_cache=False    
                )

            
            generated_message = generator_tokenizer.decode(generated_ids[0], skip_special_tokens=True)

            
            json_obj = {
                'diff_id': diff_id,
                'diff': input_text,
                'generated_message': generated_message
            }

            
            f.write(json.dumps(json_obj) + '\n')

            
            torch.cuda.empty_cache()
            gc.collect()

    print(f"Generated messages saved to {output_file}")


generate_commit_messages(
    model=rag_model,
    test_dataset=test_dataset,
    generator_tokenizer=generator_tokenizer,
    device=device,
    output_file='generated_messages.jsonl'
)



def load_generated_messages(generated_file='generated_messages.jsonl') -> List[Dict]:
    generated = []
    with open(generated_file, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line.strip())
            generated.append(data)
    return generated


def extract_messages(generated_data, test_dataset):
    generated = []
    references = []
    
    
    ref_mapping = {entry['diff_ids']: entry['target_texts'] for entry in test_dataset}
    
    for entry in generated_data:
        diff_id = entry['diff_id']
        gen_msg = entry['generated_message']
        ref_msg = ref_mapping.get(diff_id, "")
        
        if ref_msg:
            generated.append(gen_msg)
            references.append([ref_msg])  
        else:
            print(f"Warning: No reference message found for diff_id {diff_id}")
    
    return generated, references


generated_data = load_generated_messages('generated_messages.jsonl')
generated_messages, reference_messages = extract_messages(generated_data, test_dataset)


bleu = evaluate.load('bleu')
rouge_metric = evaluate.load('rouge')
meteor = evaluate.load('meteor')
bertscore = evaluate.load('bertscore')


def compute_bleu(generated: List[str], references: List[List[str]]):
    results = bleu.compute(predictions=generated, references=references, smooth=True)
    print(f"BLEU Score: {results['bleu']:.4f}")

def compute_rouge_scores(generated: List[str], references: List[List[str]]):
    results = rouge_metric.compute(predictions=generated, references=references, use_stemmer=True)
    for key, value in results.items():
        print(f"{key.upper()} Score: {value:.4f}")

def compute_meteor(generated: List[str], references: List[List[str]]):
    results = meteor.compute(predictions=generated, references=references)
    print(f"METEOR Score: {results['meteor']:.4f}")

def compute_bertscore(generated: List[str], references: List[List[str]]):
    results = bertscore.compute(
        predictions=generated,
        references=[ref[0] for ref in references],
        lang='en'
    )
    f1 = np.mean(results['f1'])
    print(f"BERTScore F1: {f1:.4f}")

def compute_accuracy(generated: List[str], references: List[List[str]]):
    correct = 0
    total = len(generated)
    for gen, ref in zip(generated, references):
        if gen.strip().lower() == ref[0].strip().lower():
            correct += 1
    accuracy = correct / total if total > 0 else 0
    print(f"Accuracy: {accuracy * 100:.2f}%")

def compute_identifier_matches(generated: List[str], diffs: List[str]):
    identifier_matches = []
    for gen_msg, diff in zip(generated, diffs):
       
        identifiers = set([word.strip('`') for word in diff.split() if word.isidentifier()])
        match_count = sum([1 for idf in identifiers if idf in gen_msg])
        identifier_matches.append(match_count / len(identifiers) if identifiers else 0)
    average_match = sum(identifier_matches) / len(identifier_matches)
    print(f"Average Identifier Match Rate: {average_match:.4f}")
    return average_match

def compute_readability(generated: List[str]):
    scores = [textstat.flesch_reading_ease(text) for text in generated]
    average_score = sum(scores) / len(scores) if scores else 0
    print(f"Average Readability (Flesch Reading Ease): {average_score:.2f}")


compute_bleu(generated_messages, reference_messages)
compute_rouge_scores(generated_messages, reference_messages)
compute_meteor(generated_messages, reference_messages)
compute_bertscore(generated_messages, reference_messages)
compute_accuracy(generated_messages, reference_messages)
compute_identifier_matches(generated_messages, test_dataset['input_texts'])
compute_readability(generated_messages)


100%|██████████| 150/150 [06:30<00:00,  2.61s/it]


Generated messages saved to generated_messages.jsonl


[nltk_data] Downloading package wordnet to
[nltk_data]     /users/PCS0289/myosc24/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /users/PCS0289/myosc24/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /users/PCS0289/myosc24/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


BLEU Score: 0.1078
ROUGE1 Score: 0.3602
ROUGE2 Score: 0.2336
ROUGEL Score: 0.3340
ROUGELSUM Score: 0.3327
METEOR Score: 0.3697


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BERTScore F1: 0.8401
Accuracy: 0.00%
Average Identifier Match Rate: 0.1297
Average Readability (Flesch Reading Ease): 80.15


