# SBERT Hyperparameter Tuning Experiment 4

**Experiment focus:** Testing different data augmentation techniques for SBERT fine-tuning.

This notebook explores the effect of different data augmentation methods on SBERT performance for scientific claim source retrieval.

In [47]:
import pandas as pd
import pickle
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
import torch
import numpy as np
import random
import nltk
from nltk.corpus import wordnet
from nltk.tokenize import word_tokenize
import re
import os

# Download required NLTK data
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('averaged_perceptron_tagger')

# Set device and environment variables
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Disable PyTorch compilation for stability
os.environ['TORCH_COMPILE_DISABLE'] = '1'
os.environ['TORCH_DYNAMO_DISABLE'] = '1'

Using device: cpu


[nltk_data] Downloading package punkt to /Users/mataonbas/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/mataonbas/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /Users/mataonbas/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


## Data Augmentation Functions

Define different data augmentation techniques.

In [48]:
def clean_text(text):
    # Remove URLs
    text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
    # Remove special characters and digits
    text = re.sub(r'[^\w\s]', '', text)
    # Remove extra whitespace
    text = ' '.join(text.split())
    return text

def get_synonyms(word):
    synonyms = set()
    for syn in wordnet.synsets(word):
        for lemma in syn.lemmas():
            synonym = lemma.name().replace('_', ' ')
            if synonym != word:
                synonyms.add(synonym)
    return synonyms

def synonym_replacement(text, n=1):
    words = text.split()
    new_words = words.copy()
    random_word_list = list(set([word for word in words if wordnet.synsets(word)]))
    random.shuffle(random_word_list)
    num_replaced = 0
    
    for random_word in random_word_list:
        synonyms = get_synonyms(random_word)
        if len(synonyms) >= 1:
            synonym = random.choice(list(synonyms))
            new_words = [synonym if word == random_word else word for word in new_words]
            num_replaced += 1
        if num_replaced >= n:
            break
    
    return ' '.join(new_words)

## Data Loading and Augmentation

Load the data and apply different augmentation techniques.

In [49]:
# Load train and dev data
train_df = pd.read_csv('/Users/mataonbas/AIR-CheckThat!-GroupProject/CheckThat-ScientificClaimSourceRetrieval/subtask4b_query_tweets_train.tsv', 
                      sep='\t', 
                      names=['post_id', 'tweet_text', 'cord_uid'])
dev_df = pd.read_csv('/Users/mataonbas/AIR-CheckThat!-GroupProject/CheckThat-ScientificClaimSourceRetrieval/subtask4b_query_tweets_dev.tsv', 
                    sep='\t', 
                    names=['post_id', 'tweet_text', 'cord_uid'])

# Load papers data
with open('/Users/mataonbas/AIR-CheckThat!-GroupProject/CheckThat-ScientificClaimSourceRetrieval/subtask4b_collection_data.pkl', 'rb') as f:
    papers_df = pickle.load(f)
papers_df['text'] = papers_df['title'] + '. ' + papers_df['abstract']
paper_dict = dict(zip(papers_df['cord_uid'], papers_df['text']))

# Define augmentation strategies
augmentation_strategies = {
    'baseline': lambda x: x,
    'cleaning': clean_text,
    'synonym_replacement': lambda x: synonym_replacement(x, n=2),
}

# Prepare training examples with different augmentation strategies
train_samples = {}
for strategy_name, strategy_func in augmentation_strategies.items():
    samples = []
    for _, row in train_df.iterrows():
        if row['cord_uid'] in paper_dict:
            tweet = strategy_func(row['tweet_text'])
            paper = paper_dict[row['cord_uid']]
            samples.append(InputExample(texts=[tweet, paper]))
    train_samples[strategy_name] = samples

# Prepare dev examples
dev_samples = []
for _, row in dev_df.iterrows():
    if row['cord_uid'] in paper_dict:
        tweet = row['tweet_text']
        paper = paper_dict[row['cord_uid']]
        dev_samples.append(InputExample(texts=[tweet, paper]))

## Training Configuration

Set up the training parameters.

In [50]:
# Fixed hyperparameters
learning_rate = 2e-5
batch_size = 8
epochs = 4
warmup_steps = 100
model_name = 'multi-qa-mpnet-base-cos-v1'

# Print configuration
print("Training Configuration:")
print(f"Model: {model_name}")
print(f"Learning rate: {learning_rate}")
print(f"Batch size: {batch_size}")
print(f"Epochs: {epochs}")
print(f"Warmup steps: {warmup_steps}")

Training Configuration:
Model: multi-qa-mpnet-base-cos-v1
Learning rate: 2e-05
Batch size: 8
Epochs: 4
Warmup steps: 100


## Training and Evaluation Loop

Train and evaluate models with different augmentation strategies.

In [51]:
def evaluate_mrr(model, dev_df, papers_df, top_k=5):
    # Encode dev queries
    query_embeddings = model.encode(dev_df['tweet_text'].tolist(), show_progress_bar=True, convert_to_tensor=True)
    # Encode papers
    paper_embeddings = model.encode(papers_df['text'].tolist(), show_progress_bar=True, convert_to_tensor=True)
    # Compute similarity
    paper_norm = torch.nn.functional.normalize(paper_embeddings, p=2, dim=1)
    paper_ids = papers_df['cord_uid'].tolist()
    predictions = []
    for query_emb in query_embeddings:
        query_norm = torch.nn.functional.normalize(query_emb.unsqueeze(0), p=2, dim=1)
        similarity = torch.matmul(query_norm, paper_norm.T).squeeze()
        top_indices = torch.topk(similarity, k=min(top_k, len(paper_norm))).indices.tolist()
        preds = [paper_ids[i] for i in top_indices]
        predictions.append(preds)
    # Calculate MRR
    scores = []
    for i, row in dev_df.iterrows():
        gold = row['cord_uid']
        preds = predictions[i]
        if gold in preds:
            rank = preds.index(gold) + 1
            scores.append(1.0 / rank)
        else:
            scores.append(0.0)
    return np.mean(scores) if scores else 0.0

In [52]:
# Training loop
print("Verifying data and model setup...")
print(f"Number of training samples for each strategy:")
for strategy_name, samples in train_samples.items():
    print(f"{strategy_name}: {len(samples)} samples")
print(f"\nNumber of dev samples: {len(dev_samples)}")
print(f"Number of papers: {len(paper_dict)}")
results = []
for strategy_name, samples in train_samples.items():
    print(f'\nTraining {model_name} with {strategy_name} augmentation')
    print(f'Total samples: {len(samples)}')
    print(f'Number of batches per epoch: {len(samples) // batch_size}')
    
    # Initialize model with specific device
    model = SentenceTransformer(model_name)
    model.to(device)
    
    # Create dataloader with minimal configuration
    train_dataloader = DataLoader(
        samples, 
        shuffle=True, 
        batch_size=batch_size,
        num_workers=0,
        pin_memory=False
    )
    
    train_loss = losses.MultipleNegativesRankingLoss(model)
    
    # Basic training configuration with progress tracking
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=epochs,
        warmup_steps=warmup_steps,
        show_progress_bar=True,
        optimizer_params={'lr': learning_rate},
        use_amp=False,
        checkpoint_path=None,
        checkpoint_save_steps=0,
        checkpoint_save_total_limit=0,
        callback=lambda epoch, steps, loss: print(f'Epoch {epoch+1}/{epochs}, Step {steps}, Loss: {loss:.4f}')
    )
    
    print(f'\nEvaluating {strategy_name} model...')
    mrr = evaluate_mrr(model, dev_df, papers_df, top_k=5)
    results.append({
        'model': model_name,
        'augmentation_strategy': strategy_name,
        'learning_rate': learning_rate,
        'batch_size': batch_size,
        'epochs': epochs,
        'warmup_steps': warmup_steps,
        'dev_mrr': mrr
    })
    print(f'Result for {strategy_name}: MRR={mrr:.4f}')
    print('-' * 50)


Training multi-qa-mpnet-base-cos-v1 with baseline augmentation
Total samples: 12853
Number of batches per epoch: 1606


Iteration:   3%|▎         | 41/1607 [08:35<5:27:59, 12.57s/it]
Epoch:   0%|          | 0/4 [08:35<?, ?it/s]


KeyboardInterrupt: 

## Save Results

Save the results to a CSV file for later analysis.

In [None]:
# Save results to CSV
pd.DataFrame(results).to_csv('hyperparam_results_4.csv', index=False)
print('Results saved to hyperparam_results_4.csv')

# Display results
print("\nFinal Results:")
print(pd.DataFrame(results))