In [11]:
!pip install datasets



In [17]:
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model, TaskType
from transformers import (
    BigBirdTokenizerFast,
    BigBirdConfig,
    BigBirdModel,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)
import pandas as pd
from datasets import Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    precision_score,
    recall_score,
    f1_score
)
import numpy as np
from safetensors.torch import save_file
import os
import json

# =============================================================================
# Constants and File Paths
# =============================================================================
LABEL_COLUMN = 'label'
TEXT_COLUMN = 'text'
CLUE_COLUMN = 'clue'
LABEL_COLUMN_gen = 'generated'
TRAIN_FILE = '/content/RAG_results_train.parquet'
TEST_FILE = '/content/RAG_results_test1.parquet'
MODEL_NAME = 'google/bigbird-roberta-base'
OUTPUT_DIR = './output/'

# =============================================================================
# Data Loading and Preprocessing
# =============================================================================
def load_data(file_path):
    df = pd.read_parquet(file_path)
    df = df.copy()
    df[LABEL_COLUMN] = df[LABEL_COLUMN_gen].astype(int)
    df[TEXT_COLUMN] = df[TEXT_COLUMN].fillna('').astype(str)
    df[CLUE_COLUMN] = df[CLUE_COLUMN].fillna('').astype(str)
    return df

# Load data and create train/validation splits
train_df = load_data(TRAIN_FILE)
test_df = load_data(TEST_FILE)
train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)
display(train_df.head())

# Convert dataframes to Hugging Face Datasets
train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)

# Load Tokenizer
tokenizer = BigBirdTokenizerFast.from_pretrained(
    MODEL_NAME,
    model_max_length=4096,
    padding_side="right",
    pad_to_multiple_of=64
)

def preprocess_function(examples):
    text_inputs = tokenizer(
        examples[TEXT_COLUMN],
        padding="max_length",
        truncation=True,
        max_length=4096,
        return_tensors=None
    )
    clue_inputs = tokenizer(
        examples[CLUE_COLUMN],
        padding="max_length",
        truncation=True,
        max_length=768,
        return_tensors=None
    )
    return {
        "input_ids": text_inputs["input_ids"],
        "attention_mask": text_inputs["attention_mask"],
        "clue_input_ids": clue_inputs["input_ids"],
        "clue_attention_mask": clue_inputs["attention_mask"],
        "labels": examples["label"]
    }

# Apply preprocessing on datasets
train_dataset = train_dataset.map(preprocess_function, batched=True, batch_size=32)
val_dataset = val_dataset.map(preprocess_function, batched=True, batch_size=32)

def collate_function(features):
    batch = {
        "input_ids": torch.stack([torch.tensor(f["input_ids"]) for f in features]),
        "attention_mask": torch.stack([torch.tensor(f["attention_mask"]) for f in features]),
        "clue_input_ids": torch.stack([torch.tensor(f["clue_input_ids"]) for f in features]),
        "clue_attention_mask": torch.stack([torch.tensor(f["clue_attention_mask"]) for f in features]),
        "labels": torch.tensor([f["labels"] for f in features]),
    }
    return batch

Unnamed: 0,index,text,generated,clue,bm25,label
2576,2577,"Canberra, Australia - In a surprise move, the ...",1,['Gov. Gavin Newsom of California signed legis...,"[130.84947754501428, 130.84947754501428, 150.5...",1
799,800,Wes Streeting has defended the growing use of ...,0,['The changes were intended to encourage more ...,"[186.17401266066173, 173.02897845778216, 173.0...",0
5135,5136,"At about 3.15am on New Year¬ís Day, Caroline Mc...",0,['Cars sped past the 51-year-old man as he tru...,"[309.77877873299434, 293.8986670266023, 283.18...",0
1608,1609,"INDIANAPOLIS, IN - The Indiana Fever's 2023 WN...",1,"['Alarmed and angry, 80 experts published a ma...","[122.29280634066899, 132.45712275633673, 121.1...",1
6454,6455,"In a shocking turn of events, Detroit Lions wi...",1,['As he waited for a call from his agent in Se...,"[107.4627853104494, 101.00252285268488, 121.72...",1


Map:   0%|          | 0/10497 [00:00<?, ? examples/s]

Map:   0%|          | 0/1167 [00:00<?, ? examples/s]

In [18]:

# =============================================================================
# Model Definition
# =============================================================================
class DualChannelModel(nn.Module):
    def __init__(self, model_name, num_labels, lora_config=None):
        super(DualChannelModel, self).__init__()
        # Load the base BigBird model
        self.bigbird = BigBirdModel.from_pretrained(model_name)

        # Apply LoRA if a configuration is provided
        if lora_config is not None:
            self.bigbird = get_peft_model(self.bigbird, lora_config)  # Only apply LoRA to the BigBird model

        # Update the classifier to accept the concatenated features.
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(self.bigbird.config.hidden_size * 2, self.bigbird.config.hidden_size),
            nn.GELU(),
            nn.Linear(self.bigbird.config.hidden_size, num_labels)
        )

    def forward(self, input_ids, attention_mask, clue_input_ids, clue_attention_mask, labels=None):
        # Extract [CLS] token from main input
        text_outputs = self.bigbird(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        # Compute mean pooling for the clue input
        clue_outputs = self.bigbird(input_ids=clue_input_ids, attention_mask=clue_attention_mask).last_hidden_state.mean(dim=1)
        # print(text_outputs.shape)
        # print(clue_outputs.shape)

        # Concatenate text_outputs and clue_outputs along the feature dimension
        fused_features = torch.cat([text_outputs, clue_outputs], dim=1)
        # print(fused_features.shape)

        logits = self.classifier(fused_features)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
        return {"loss": loss, "logits": logits}

    def save(self, output_dir):
        """Custom method to save the model and LoRA config"""
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        # Save the model's state_dict
        torch.save(self.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))

        # Save the LoRA config if available
        if hasattr(self, 'bigbird') and hasattr(self.bigbird, 'config') and self.bigbird.config:
            lora_config_path = os.path.join(output_dir, "adapter_config.json")
            config = self.bigbird.config.to_dict()  # Assuming LoRA config is part of the BigBird model config
            with open(lora_config_path, "w") as f:
                json.dump(config, f)

        print(f"‚úÖ Model and LoRA adapter saved to {output_dir}")

# Utility function to compute and print the model size in MB
def print_model_size(model):
    param_size = sum(param.nelement() * param.element_size() for param in model.parameters())
    buffer_size = sum(buffer.nelement() * buffer.element_size() for buffer in model.buffers())
    size_mb = (param_size + buffer_size) / 1024 ** 2
    print(f"Model size: {size_mb:.3f} MB")

# Example model name and configuration
model_name = "google/bigbird-roberta-base"  # Replace with your model name if needed
num_labels = 2
model = DualChannelModel(model_name, num_labels)
model.eval()  # Set the model to evaluation mode

batch_size = 2
seq_length = 16

# Create dummy inputs
input_ids = torch.randint(0, 1000, (batch_size, seq_length))
attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
clue_input_ids = torch.randint(0, 1000, (batch_size, seq_length))
clue_attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
labels = torch.randint(0, num_labels, (batch_size,))

# Forward pass
with torch.no_grad():
    outputs = model(input_ids, attention_mask, clue_input_ids, clue_attention_mask, labels)
print("Test output:", outputs)

# Debug: print out the model size
print_model_size(model)

Attention type 'block_sparse' is not possible if sequence_length: 16 <= num global tokens: 2 * config.block_size + min. num sliding tokens: 3 * config.block_size + config.num_random_blocks * config.block_size + additional buffer: config.num_random_blocks * config.block_size = 704 with config.block_size = 64, config.num_random_blocks = 3. Changing attention type to 'original_full'...


Test output: {'loss': tensor(0.6913), 'logits': tensor([[-0.0962, -0.1493],
        [-0.0738, -0.1362]])}
Model size: 490.826 MB


In [19]:
# =============================================================================
# Training Setup
# =============================================================================
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    learning_rate=2e-5,
    per_device_train_batch_size=1,      # Further reduced batch size
    per_device_eval_batch_size=1,       # Further reduced batch size
    gradient_accumulation_steps=4,       # Effective batch size of 4
    num_train_epochs=3,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    logging_steps=50,
    warmup_ratio=0.1,
    report_to="none",
    fp16=True,                         # Enable mixed precision training
)

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=[
        "query", "key", "value",
        "output.dense",
        "classifier.out_proj"
    ],
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_CLS",
)

model = DualChannelModel(MODEL_NAME, num_labels=2, lora_config=lora_config)

# Enable gradient checkpointing to reduce memory usage
model.bigbird.config.gradient_checkpointing = True
# ËÆ°ÁÆóÂèØËÆ≠ÁªÉÂèÇÊï∞ÁöÑÊÄªÊï∞
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())

print(f"Trainable Parameters: {trainable_params} / {total_params} ({trainable_params / total_params:.2%})")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average="binary")
    accuracy = accuracy_score(labels, predictions)
    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

class CustomTrainer(Trainer):
    def _save(self, output_dir: str, state_dict=None):
        """Ëá™ÂÆö‰πâ‰øùÂ≠òÊñπÊ≥ïÔºåÁ°Æ‰øùÂêåÊó∂‰øùÂ≠òÊ®°ÂûãÊùÉÈáçÂíå LoRA ÈÄÇÈÖçÂô®"""

        print(f"\nüîç Saving model and adapter to {output_dir} ...")

        # 1. Á°Æ‰øù `state_dict` Â§Ñ‰∫é contiguous Áä∂ÊÄÅ
        if state_dict is None:
            state_dict = {
                key: value.contiguous() if not value.is_contiguous() else value
                for key, value in self.model.state_dict().items()
            }

        # 2. Ë∞ÉÁî® Trainer ÁöÑÂéüÂßã `_save()` ÊñπÊ≥ïÔºå‰øùÂ≠òÂü∫Á°ÄÊ®°ÂûãÊùÉÈáç
        super()._save(output_dir, state_dict=state_dict)

        # 3. ‰øùÂ≠ò LoRA ÈÄÇÈÖçÂô®
        adapter_output_dir = os.path.join(output_dir, "checkpoint-lora")
        self.model.save(adapter_output_dir)

        # 4. ‰øùÂ≠ò `safetensors` ÊùÉÈáç
        safetensors_path = os.path.join(adapter_output_dir, "model.safetensors")
        save_file(state_dict, safetensors_path)

        print(f"‚úÖ Model and adapter saved successfully to {adapter_output_dir}")

        # 5. Á°Æ‰øù `adapter_config.json` Â≠òÂú®
        adapter_config_path = os.path.join(adapter_output_dir, "adapter_config.json")
        if not os.path.exists(adapter_config_path):
            raise FileNotFoundError(f"‚ùå {adapter_config_path} Ê≤°ÊúâÊ≠£Á°Æ‰øùÂ≠òÔºåËØ∑Ê£ÄÊü•ËÆ≠ÁªÉËøáÁ®ãÔºÅ")

        print(f"‚úÖ Adapter config verified at {adapter_config_path}.")


trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_function,
    compute_metrics=compute_metrics
)

# =============================================================================
# Training
# =============================================================================
trainer.train()

# =============================================================================
# Prediction and Evaluation
# =============================================================================
def predict(model, texts, clues, batch_size=8):
    model.eval()
    predictions = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i+batch_size]
        batch_clues = clues[i:i+batch_size]

        text_inputs = tokenizer(
            batch_texts,
            padding="max_length",
            truncation=True,
            max_length=4096,
            return_tensors="pt"
        ).to(device)

        clue_inputs = tokenizer(
            batch_clues,
            padding="max_length",
            truncation=True,
            max_length=4096,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            outputs = model(
                input_ids=text_inputs["input_ids"],
                attention_mask=text_inputs["attention_mask"],
                clue_input_ids=clue_inputs["input_ids"],
                clue_attention_mask=clue_inputs["attention_mask"]
            )
        probs = torch.softmax(outputs["logits"], dim=-1)[:, 1].cpu().numpy()
        predictions.extend(probs)
    return np.array(predictions)

# Prepare test set lists
texts = test_df[TEXT_COLUMN].tolist()
clues = test_df[CLUE_COLUMN].tolist()
labels = test_df["label"].tolist()

predictions = predict(model, texts, clues, batch_size=4)

def evaluate(y_true, y_pred_probs, threshold=0.5):
    y_pred = (y_pred_probs >= threshold).astype(int)
    print("\nClassification report:")
    print(classification_report(y_true, y_pred, target_names=["ÁúüÂÆû", "ËôöÂÅá"], zero_division=0))
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, zero_division=0),
        'recall': recall_score(y_true, y_pred, zero_division=0),
        'f1': f1_score(y_true, y_pred, zero_division=0)
    }

results = evaluate(labels, predictions)
for k, v in results.items():
    print(f"{k.upper():<10}: {v:.4f}")

Trainable Parameters: 2140418 / 129609218 (1.65%)


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.0483,0.095869,0.978578,0.962025,0.992537,0.977043
2,0.0089,0.03659,0.993145,0.988889,0.996269,0.992565



üîç Saving model and adapter to ./output/checkpoint-2625 ...
‚úÖ Model and LoRA adapter saved to ./output/checkpoint-2625/checkpoint-lora
‚úÖ Model and adapter saved successfully to ./output/checkpoint-2625/checkpoint-lora
‚úÖ Adapter config verified at ./output/checkpoint-2625/checkpoint-lora/adapter_config.json.

üîç Saving model and adapter to ./output/checkpoint-5250 ...
‚úÖ Model and LoRA adapter saved to ./output/checkpoint-5250/checkpoint-lora
‚úÖ Model and adapter saved successfully to ./output/checkpoint-5250/checkpoint-lora
‚úÖ Adapter config verified at ./output/checkpoint-5250/checkpoint-lora/adapter_config.json.

üîç Saving model and adapter to ./output/checkpoint-7872 ...
‚úÖ Model and LoRA adapter saved to ./output/checkpoint-7872/checkpoint-lora
‚úÖ Model and adapter saved successfully to ./output/checkpoint-7872/checkpoint-lora
‚úÖ Adapter config verified at ./output/checkpoint-7872/checkpoint-lora/adapter_config.json.

Classification report:
              precisio

In [20]:
texts[0]

'Agony for Kane Williamson, ecstasy for Matthew Potts. Granted, it may not trip off the tongue like Ian Smith‚Äôs famous commentary at Lord‚Äôs five years ago, but both emotions were very much on show after the moment that changed day one in Hamilton. Williamson had looked indelible during the opening exchanges of this third Test, cruising to 44 and driving New Zealand to an apparent position of strength at 185 for three after tea. But when he was bowled by Potts, having deflected the ball on to his stumps playing late to one that bounced, England‚Äôs fightback was sparked. ‚ÄúI didn‚Äôt have a great view of it,‚Äù said Potts as New Zealand closed on 315 for nine at stumps, the seamer having claimed three for 75 on his return to England‚Äôs XI. ‚ÄúI was a bit confused but then I saw a bail drop down by his feet and it was pure elation after that.‚Äù Asked about his personal success against Williamson, a hold that has returned four dismissals from five encounters, Potts said: ‚ÄúIt was 

In [21]:
# =============================================================================
# Save Prediction Results to File
# =============================================================================
prediction_df = pd.DataFrame({
    "text": texts,
    "clue": clues,
    "true_label": labels,
    "predicted_probability": predictions
})
prediction_df.to_csv("predictions_test1.csv", index=False)
print("Prediction results have been stored in 'predictions_test1.csv'.")

Prediction results have been stored in 'predictions_test1.csv'.


In [25]:
import os
import torch
from safetensors.torch import load_file

# 1. ÂàùÂßãÂåñÊ®°Âûã
model = DualChannelModel(MODEL_NAME, num_labels=2, lora_config=lora_config)
model.bigbird.config.gradient_checkpointing = True  # ÂèØÈÄâ

# 2. ËØªÂèñÊùÉÈáç
checkpoint_path = "output/checkpoint-7872"
checkpoint_file = os.path.join(checkpoint_path, "model.safetensors")

if not os.path.exists(checkpoint_file):
    raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_file}")

state_dict = load_file(checkpoint_file)

# 3. ËÆæÂ§áÂåπÈÖç
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict = {k: v.to(device) for k, v in state_dict.items()}  # ËΩ¨Êç¢ÊùÉÈáçËÆæÂ§á

# 4. Âä†ËΩΩÊùÉÈáçÔºàÂÖÅËÆ∏ÈÉ®ÂàÜ‰∏çÂåπÈÖçÔºâ
model.load_state_dict(state_dict, strict=False)
model.to(device)
model.eval()

# 5. Âä†ËΩΩÊµãËØïÊï∞ÊçÆ
TEST_FILE2 = '/content/RAG_results_test2.parquet'
test_df2 = load_data(TEST_FILE2)

texts2 = test_df2[TEXT_COLUMN].tolist()
clues2 = test_df2[CLUE_COLUMN].tolist()
labels2 = test_df2["generated"].tolist()

# 6. È¢ÑÊµã
predictions2 = predict(model, texts2, clues2, batch_size=4)

# 7. ËØÑ‰º∞
print("\n=== Inference on Test2 Data ===")
results2 = evaluate(labels2, predictions2)
for k, v in results2.items():
    print(f"TEST2 {k.upper():<10}: {v:.4f}")


=== Inference on Test2 Data ===

Classification report:
              precision    recall  f1-score   support

          ÁúüÂÆû       1.00      0.96      0.98       750
          ËôöÂÅá       0.97      1.00      0.98       868

    accuracy                           0.98      1618
   macro avg       0.98      0.98      0.98      1618
weighted avg       0.98      0.98      0.98      1618

TEST2 ACCURACY  : 0.9802
TEST2 PRECISION : 0.9665
TEST2 RECALL    : 0.9977
TEST2 F1        : 0.9819


In [26]:
# =============================================================================
# Save Prediction Results to File
# =============================================================================
prediction_df = pd.DataFrame({
    "text": texts2,
    "clue": clues2,
    "true_label": labels2,
    "predicted_probability": predictions2
})
prediction_df.to_csv("predictions_test2.csv", index=False)
print("Prediction results have been stored in 'predictions_test2.csv'.")

Prediction results have been stored in 'predictions_test2.csv'.


In [27]:
!zip -r checkpoint.zip /content/output/checkpoint-7872

  adding: content/output/checkpoint-7872/ (stored 0%)
  adding: content/output/checkpoint-7872/checkpoint-lora/ (stored 0%)
  adding: content/output/checkpoint-7872/checkpoint-lora/pytorch_model.bin (deflated 7%)
  adding: content/output/checkpoint-7872/checkpoint-lora/model.safetensors (deflated 7%)
  adding: content/output/checkpoint-7872/checkpoint-lora/adapter_config.json (deflated 58%)
  adding: content/output/checkpoint-7872/trainer_state.json (deflated 78%)
  adding: content/output/checkpoint-7872/scaler.pt (deflated 60%)
  adding: content/output/checkpoint-7872/scheduler.pt (deflated 55%)
  adding: content/output/checkpoint-7872/model.safetensors (deflated 7%)
  adding: content/output/checkpoint-7872/training_args.bin (deflated 51%)
  adding: content/output/checkpoint-7872/optimizer.pt (deflated 7%)
  adding: content/output/checkpoint-7872/rng_state.pth (deflated 25%)


In [28]:
from google.colab import files
files.download("checkpoint.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [29]:
from google.colab import files
files.download("predictions_test1.csv")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [30]:
from google.colab import files
files.download("predictions_test2.csv")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
from google.colab import drive
drive.mount('/content/drive')