In [4]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, BertForSequenceClassification
from peft import get_peft_model, LoraConfig, TaskType
from sklearn.model_selection import train_test_split
import numpy as np
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# 1. Load and prepare the data
def load_data(tweets_file, papers_file):
    """Load the query tweets and the paper collection.

    Args:
        tweets_file: Path to a tab-separated file of tweets.
        papers_file: Path to a pickled DataFrame with ``cord_uid``, ``title``
            and ``abstract`` columns.

    Returns:
        ``(tweets_df, papers_dict)`` where ``papers_dict`` maps each
        ``cord_uid`` to ``{'title': str, 'abstract': str}``. Missing titles
        and abstracts become ``''`` (not the string ``'nan'``). If a
        ``cord_uid`` appears more than once, the last row wins.
    """
    tweets_df = pd.read_csv(tweets_file, sep='\t')
    # NOTE: read_pickle can execute arbitrary code; only load trusted files.
    papers_df = pd.read_pickle(papers_file)

    def _as_text(value):
        # str(NaN) would produce the literal token "nan", which would then be
        # fed to the model as if it were real text; use empty text instead.
        return str(value) if pd.notna(value) else ''

    # Map cord_uid -> paper fields (zip is much faster than iterrows).
    papers_dict = {
        cord_uid: {'title': _as_text(title), 'abstract': _as_text(abstract)}
        for cord_uid, title, abstract in zip(
            papers_df['cord_uid'], papers_df['title'], papers_df['abstract']
        )
    }

    return tweets_df, papers_dict

In [14]:
class TweetPaperMatchDataset(Dataset):
    """Binary tweet/paper pair-classification dataset.

    Construction builds two groups of examples:
      * one positive pair (label 1) for every tweet whose ``cord_uid`` is a key
        of ``papers_dict``;
      * one negative pair (label 0) for every tweet, pairing it with a paper
        drawn uniformly at random from the papers whose id differs from the
        tweet's ``cord_uid``.
    Positives are stored first, negatives after. Pairs are tokenized lazily in
    ``__getitem__``.

    Args:
        tweets_df: DataFrame with ``tweet_text`` and ``cord_uid`` columns.
        papers_dict: Mapping ``cord_uid -> {'title': ..., 'abstract': ...}``.
        tokenizer: Tokenizer called as ``tokenizer(text, text_pair, ...)`` and
            returning a mapping of ``(1, max_length)`` tensors.
        max_length: Pad/truncate every pair to this many tokens.
        seed: Optional seed for negative sampling. ``None`` (default) draws
            from NumPy's global RNG as before, so ``np.random.seed(...)``
            still controls reproducibility.
    """

    def __init__(self, tweets_df, papers_dict, tokenizer, max_length=512, seed=None):
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.tweet_texts = []
        self.paper_texts = []
        self.labels = []  # 1 for matching pairs, 0 for non-matching

        tweets = list(zip(tweets_df['tweet_text'], tweets_df['cord_uid']))

        # Positive samples: each tweet with the paper it references.
        for tweet_text, cord_uid in tweets:
            if cord_uid in papers_dict:
                self._add(tweet_text, papers_dict[cord_uid], 1)

        # Negative samples: each tweet with one random non-matching paper.
        # Rejection sampling over indices replaces rebuilding an O(#papers)
        # candidate list per tweet. Dict keys are unique, so at most one
        # paper is ever rejected and the loop ends quickly.
        rng = np.random if seed is None else np.random.RandomState(seed)
        all_paper_ids = list(papers_dict.keys())
        n_papers = len(all_paper_ids)
        for tweet_text, cord_uid in tweets:
            if n_papers == 0 or (n_papers == 1 and all_paper_ids[0] == cord_uid):
                continue  # no non-matching paper exists for this tweet
            while True:
                paper_id = all_paper_ids[rng.randint(n_papers)]
                if paper_id != cord_uid:
                    break
            self._add(tweet_text, papers_dict[paper_id], 0)

    def _add(self, tweet_text, paper, label):
        """Append one (tweet, "title abstract", label) example."""
        self.tweet_texts.append(tweet_text)
        self.paper_texts.append(f"{paper['title']} {paper['abstract']}")
        self.labels.append(label)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # NOTE(review): padding every pair to max_length (512 by default)
        # dominates CPU cost; dynamic padding in a collate_fn would be cheaper.
        encoding = self.tokenizer(
            self.tweet_texts[idx],
            self.paper_texts[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        # Drop only the leading batch dimension; a bare squeeze() would also
        # collapse the sequence axis when max_length == 1.
        item = {key: value.squeeze(0) for key, value in encoding.items()}

        # Some tokenizers do not emit token_type_ids; fall back to all zeros.
        token_type_ids = item.get('token_type_ids')
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(item['input_ids'])

        return {
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'token_type_ids': token_type_ids,
            'labels': torch.tensor(self.labels[idx], dtype=torch.long),
        }


In [17]:
# 3. Set up LoRA configuration
def setup_lora_model(
    model_name="bert-base-uncased",
    num_labels=2,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=("query", "key", "value"),
):
    """Load a BERT sequence classifier and wrap it with LoRA adapters.

    Args:
        model_name: Hugging Face checkpoint name passed to
            ``BertForSequenceClassification.from_pretrained``.
        num_labels: Number of output classes of the classification head.
        r: Rank of the LoRA update matrices.
        lora_alpha: LoRA scaling parameter.
        lora_dropout: Dropout probability applied inside the LoRA layers.
        target_modules: Names of the submodules that receive LoRA adapters
            (defaults to BERT's attention projections).

    Returns:
        The PEFT-wrapped model returned by ``get_peft_model``.

    NOTE(review): PEFT's SEQ_CLS task type is documented to keep the newly
    initialized classifier head trainable; confirm with
    ``model.print_trainable_parameters()``.
    """
    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

    lora_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        inference_mode=False,
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        # Copy into a fresh list so the immutable default is never shared.
        target_modules=list(target_modules),
    )

    return get_peft_model(model, lora_config)

In [19]:
# 4. Training function
def train_model(model, train_dataloader, val_dataloader, device, epochs=3, lr=2e-5,
                save_dir="./best_lora_tweet_paper_model"):
    """Fine-tune ``model`` and checkpoint the epoch with the best val accuracy.

    Each batch must provide ``input_ids``, ``attention_mask``,
    ``token_type_ids`` and ``labels`` tensors. The model must return an object
    with ``.loss`` and ``.logits`` and must expose ``save_pretrained(path)``.

    Args:
        model: Model to train (moved to ``device`` in place).
        train_dataloader: Iterable of training batches.
        val_dataloader: Iterable of validation batches.
        device: ``torch.device`` to run on.
        epochs: Number of passes over the training data.
        lr: AdamW learning rate.
        save_dir: Directory passed to ``model.save_pretrained`` whenever the
            validation accuracy improves (for a PEFT model this saves only the
            adapter weights).

    Returns:
        The model after the final epoch. This is not necessarily the best
        checkpoint; reload from ``save_dir`` for that.
    """
    model.to(device)
    # Only optimize trainable parameters (with LoRA most weights are frozen).
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad], lr=lr
    )

    # -inf ensures the first epoch is always checkpointed, even at 0.0 accuracy.
    best_accuracy = float("-inf")

    for epoch in range(epochs):
        tag = f"Epoch {epoch+1}/{epochs}"
        train_loss, train_accuracy = _run_epoch(
            model, train_dataloader, device, optimizer, desc=f"{tag} - Training"
        )
        val_loss, val_accuracy = _run_epoch(
            model, val_dataloader, device, None, desc=f"{tag} - Validation"
        )

        print(tag)
        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            model.save_pretrained(save_dir)

    return model


def _run_epoch(model, dataloader, device, optimizer=None, desc=""):
    """Run one pass over ``dataloader``: train if ``optimizer`` is given, else evaluate.

    Returns:
        ``(mean batch loss, accuracy)``; both are 0.0 for an empty dataloader.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    total_loss, correct, seen, n_batches = 0.0, 0, 0, 0
    with torch.set_grad_enabled(training):
        for batch in tqdm(dataloader, desc=desc):
            inputs = {
                key: batch[key].to(device)
                for key in ('input_ids', 'attention_mask', 'token_type_ids', 'labels')
            }

            if training:
                optimizer.zero_grad()
            outputs = model(**inputs)
            if training:
                outputs.loss.backward()
                optimizer.step()

            total_loss += outputs.loss.item()
            predictions = torch.argmax(outputs.logits, dim=-1)
            labels = inputs['labels']
            correct += (predictions == labels).sum().item()
            seen += labels.size(0)
            n_batches += 1

    mean_loss = total_loss / n_batches if n_batches else 0.0
    accuracy = correct / seen if seen else 0.0
    return mean_loss, accuracy

In [8]:
# NOTE(review): this loads the *dev* query tweets and then trains on them;
# confirm that a separate training file is not intended here.
tweets_df, papers_dict = load_data(
    tweets_file="../subtask4b_query_tweets_dev.tsv",
    papers_file="../subtask4b_collection_data.pkl",
)

In [15]:
# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# The tokenizer checkpoint matches setup_lora_model()'s default model_name.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Positive and negative tweet/paper pairs built from every loaded tweet.
full_dataset = TweetPaperMatchDataset(
    tweets_df=tweets_df,
    papers_dict=papers_dict,
    tokenizer=tokenizer,
)



Using device: cpu


In [16]:
# Split into train and validation *by tweet*, before pairs are built.
# Splitting the already-paired `full_dataset` (e.g. with random_split) puts a
# tweet's positive pair in one split and its negative pair in the other, so
# validation sees tweets the model was trained on and accuracy is inflated.
# random_state makes the split reproducible.
train_tweets_df, val_tweets_df = train_test_split(tweets_df, test_size=0.2, random_state=42)
train_dataset = TweetPaperMatchDataset(train_tweets_df, papers_dict, tokenizer)
val_dataset = TweetPaperMatchDataset(val_tweets_df, papers_dict, tokenizer)

# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=8)

In [18]:
# BERT sequence classifier with LoRA adapters on the query/key/value projections.
model = setup_lora_model(model_name="bert-base-uncased", num_labels=2)

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [20]:
# Train model.
# NOTE(review): the recorded CPU run took ~7.7 s per batch (280 batches per
# epoch, roughly 36 min) and was interrupted during epoch 1. Consider a GPU,
# fewer epochs, or a shorter max_length before running all 5 epochs. The
# default lr (2e-5) may also be low for LoRA; TODO tune.
model = train_model(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    device=device,
    epochs=5,
)

Epoch 1/5 - Training:  41%|████      | 114/280 [14:42<21:24,  7.74s/it]


KeyboardInterrupt: 