In [2]:
!ls

0_setup.ipynb            3_convert_to_onnx.ipynb  6_fine_tune.ipynb
1_data_processing.ipynb  4_benchmarks.ipynb       6_fine_tune_backup.ipynb
2_train_models.ipynb     5_explainability.ipynb


In [3]:
%cd ..

/Users/matthew/Documents/deepmind_internship


In [4]:
# ONNX Conversion Code - Run This First!

# Cell 1: Imports
import gc
import json
from pathlib import Path
import pandas as pd
from sklearn.model_selection import train_test_split
from typing import List

# Data & ML
import numpy as np
import torch
import onnx
import onnxruntime as ort
from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType

# Hugging Face
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Suppress ONNX Runtime logging
import logging
logging.getLogger("onnxruntime").setLevel(logging.ERROR)

# Cell 2: Configuration & Model Discovery
# Model & ONNX Configuration
BASE_DIR = Path("models")
ONNX_OPSET_VERSION = 17

# Data & Split Configuration
DATA_FILE_PATH = Path("data/FinancialPhraseBank/all-data.csv")
RANDOM_SEED = 42
TEST_SIZE = 0.25 # 25% for the test set

def is_valid_model_dir(d: Path) -> bool:
    """Checks if a directory contains a valid Hugging Face model."""
    config_file = d / "config.json"
    model_file_exists = (d / "pytorch_model.bin").exists() or (d / "model.safetensors").exists()
    
    if not config_file.exists() or not model_file_exists:
        return False
    
    # Check if the config follows Hugging Face format
    try:
        with open(config_file, 'r') as f:
            config = json.load(f)
        
        # Valid HF configs should have either 'model_type' with a known architecture
        # or 'architectures' field with valid architecture names
        has_valid_model_type = config.get('model_type') in [
            'bert', 'distilbert', 'roberta', 'albert', 'electra', 'deberta', 
            'deberta-v2', 'xlnet', 'xlm-roberta', 'camembert', 'flaubert'
        ]
        
        has_valid_architectures = 'architectures' in config and any(
            arch.endswith(('ForSequenceClassification', 'Model')) 
            for arch in config['architectures']
        )
        
        return has_valid_model_type or has_valid_architectures
        
    except (json.JSONDecodeError, Exception):
        return False

def can_load_with_transformers(model_dir: Path) -> bool:
    """Test if a model can actually be loaded by transformers library."""
    try:
        # Try to load tokenizer and model without actually loading the weights
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # Just check if we can initialize the model class without loading weights
        model = AutoModelForSequenceClassification.from_pretrained(
            model_dir, 
            torch_dtype=torch.float32,
            device_map=None
        )
        del tokenizer, model
        gc.collect()
        return True
    except Exception as e:
        print(f"   - ⚠️  Cannot load with transformers: {e}")
        return False

def prepare_calibration_data(data_path, test_size, random_seed, num_samples=100):
    """Loads, splits, and samples the data to create a calibration set."""
    print(f"Loading data from {data_path}...")
    df = pd.read_csv(
        data_path,
        header=None,
        names=['sentiment', 'text'],
        encoding='latin-1')

    # Split data to get the test set
    _, test_df = train_test_split(
        df, test_size=test_size, random_state=random_seed, stratify=df['sentiment'])

    # Sample the calibration set from the test data
    calibration_df = test_df.sample(n=num_samples, random_state=random_seed)
    print(f"✅ Created a calibration dataset with {len(calibration_df)} samples.")
    return calibration_df

# Find all valid model directories
print("🔍 Discovering and validating models...")
all_model_dirs = [d for d in BASE_DIR.iterdir() if d.is_dir()]
valid_model_dirs = []

for model_dir in all_model_dirs:
    if is_valid_model_dir(model_dir):
        print(f"   - Checking {model_dir.name}...")
        if can_load_with_transformers(model_dir):
            valid_model_dirs.append(model_dir)
            print(f"   - ✅ {model_dir.name} - Valid and loadable")
        else:
            print(f"   - ❌ {model_dir.name} - Valid format but not loadable with transformers")
    else:
        print(f"   - ❌ {model_dir.name} - Invalid Hugging Face model format")

model_dirs = valid_model_dirs
print(f"\n✅ Found {len(model_dirs)} valid and loadable models.")

# Call the function to prepare data
calibration_df = prepare_calibration_data(DATA_FILE_PATH, TEST_SIZE, RANDOM_SEED)


# Cell 3: Automated Node Finder
def find_final_nodes_to_exclude(onnx_model_path: Path) -> List[str]:
    """
    Analyzes an ONNX model to find the names of the final MatMul or Add nodes
    right before the output, tracing backwards past common post-processing nodes.
    """
    nodes_to_exclude = []
    try:
        model = onnx.load(str(onnx_model_path))
        
        # Create maps of all node inputs/outputs
        output_to_node_map = {out: node for node in model.graph.node for out in node.output}

        # Find the final output of the graph
        graph_outputs = [output.name for output in model.graph.output]
        
        for graph_output in graph_outputs:
            # Start tracing backwards from the graph's output
            current_node = output_to_node_map.get(graph_output)
            
            # Trace backwards past common non-computational nodes
            while current_node and current_node.op_type in ['Softmax', 'LogSoftmax', 'Identity']:
                parent_node_output = current_node.input[0]
                current_node = output_to_node_map.get(parent_node_output)

            # Check if the traced-back node is a good candidate for exclusion
            if current_node and (current_node.op_type == 'MatMul' or current_node.op_type == 'Add' or current_node.op_type == 'Gemm'):
                nodes_to_exclude.append(current_node.name)
                print(f"   -> 🎯 Automatically identified final node to exclude: '{current_node.name}' ({current_node.op_type})")
            else:
                print(f"   -> ⚠️  Could not automatically identify a final MatMul/Add/Gemm node to exclude for output '{graph_output}'.")

    except Exception as e:
        print(f"   -> ❌ Error analyzing ONNX graph: {e}. Proceeding without exclusions.")

    return nodes_to_exclude


# Cell 4: ONNX Helper Classes
class ONNXExportWrapper(torch.nn.Module):
    """A wrapper to ensure model output is a simple tensor for ONNX compatibility."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        outputs = self.model(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=False
        )
        return outputs[0]

class TextCalibrationDataReader(CalibrationDataReader):
    """A robust data reader that adapts to the model's specific inputs."""
    def __init__(self, data_df: pd.DataFrame, tokenizer, onnx_model_path: Path):
        self.tokenizer = tokenizer
        self.data_list = data_df["text"].tolist()
        self.index = 0

        # Find the model's required inputs
        session = ort.InferenceSession(str(onnx_model_path), providers=["CPUExecutionProvider"])
        model_inputs = {input.name for input in session.get_inputs()}

        # Tokenize all data and filter to only include the model's inputs
        tokenized_data = self.tokenizer(
            self.data_list, padding="max_length", truncation=True, max_length=128, return_tensors="np"
        )
        self.feed = {
            key: tokenized_data[key] for key in tokenized_data if key in model_inputs
        }
        self.input_names = list(self.feed.keys())

    def get_next(self):
        if self.index >= len(self.data_list):
            return None

        item = {name: self.feed[name][self.index:self.index+1] for name in self.input_names}
        self.index += 1
        return item

# Cell 5: Main Processing & Export Loop
def export_model_to_onnx(model, tokenizer, onnx_path: Path, opset_version: int):
    """Exports a PyTorch model to the ONNX format."""
    print("   - Wrapping model for ONNX export...")
    wrapped_model = ONNXExportWrapper(model)
    wrapped_model.eval()
    dummy_input = tokenizer("This is a sample sentence.", return_tensors="pt")
    print(f"   - 🚀 Exporting to ONNX (Opset {opset_version})...")
    torch.onnx.export(
        model=wrapped_model,
        args=(dummy_input["input_ids"], dummy_input["attention_mask"]),
        f=str(onnx_path), input_names=["input_ids", "attention_mask"], output_names=["output"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
            "output": {0: "batch_size"},
        },
        opset_version=opset_version, do_constant_folding=True,
    )
    print(f"   - ✅ Model successfully exported to {onnx_path.name}")

for model_dir in model_dirs:
    print("-" * 70)
    print(f"⏳ Processing model: {model_dir.name}")

    onnx_dir = model_dir / "onnx"
    onnx_dir.mkdir(exist_ok=True)
    onnx_model_path = onnx_dir / "model.onnx"
    quantised_model_path = onnx_dir / "model-quantised.onnx"

    # --- Step 1: Export to ONNX if needed ---
    if not onnx_model_path.exists():
        print("   - 📦 ONNX model not found. Starting export...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_dir)
            model = AutoModelForSequenceClassification.from_pretrained(model_dir)
            export_model_to_onnx(model, tokenizer, onnx_model_path, ONNX_OPSET_VERSION)
            del model, tokenizer
            gc.collect()
        except Exception as e:
            print(f"   - ❌ Export failed for {model_dir.name}: {e}")
            continue
    else:
        print(f"   - ✅ Standard ONNX model already exists.")

   
print("-" * 70)
print("🎉 All models have been processed.")

🔍 Discovering and validating models...
   - Checking tinybert-financial-classifier-fine-tuned...
   - ✅ tinybert-financial-classifier-fine-tuned - Valid and loadable
   - Checking all-MiniLM-L6-v2-financial-sentiment...
   - ✅ all-MiniLM-L6-v2-financial-sentiment - Valid and loadable
   - Checking distilbert-financial-sentiment...
   - ✅ tinybert-financial-classifier-fine-tuned - Valid and loadable
   - Checking all-MiniLM-L6-v2-financial-sentiment...
   - ✅ all-MiniLM-L6-v2-financial-sentiment - Valid and loadable
   - Checking distilbert-financial-sentiment...
   - ✅ distilbert-financial-sentiment - Valid and loadable
   - Checking finbert-tone-financial-sentiment...
   - ✅ finbert-tone-financial-sentiment - Valid and loadable
   - ✅ distilbert-financial-sentiment - Valid and loadable
   - Checking finbert-tone-financial-sentiment...
   - ✅ finbert-tone-financial-sentiment - Valid and loadable
   - Checking SmolLM2-360M-Instruct-financial-sentiment...
   - Checking SmolLM2-360M-Instr