# ECGBERT Fine-tuning for Heartbeat Classification

This notebook fine-tunes the pretrained ECGBERT model for heartbeat classification using the MIT-BIH Arrhythmia Database.


## Step 1: Setup and Installation


In [None]:
# Clone repository (if using git)
# !git clone https://github.com/your-username/ECGBERT-reproduce-project-.git

# Or, if the project was uploaded as a Kaggle dataset, its files will be under /kaggle/input/
# (in that case, adjust PROJECT_DIR in Step 2 to point there).
import os
import sys

# Install dependencies
!pip install -q wfdb neurokit2 pywavelets

# Verify GPU
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


## Step 2: Configure Paths


In [None]:
# Configure paths - ADJUST THESE TO MATCH YOUR DATASET NAMES
MITDB_PATH = '/kaggle/input/mit-bih-arrhythmia-database/mitdb'  # Adjust to your dataset
PRETRAINED_MODEL_PATH = '/kaggle/input/ecgbert-pretrained-models'  # Adjust to your dataset
CLUSTERING_DIR = '/kaggle/input/ecgbert-clustering-models/preprocessed/clustering_models/0.1'  # Adjust
OUTPUT_DIR = '/kaggle/working/fine_tune_output'

# Root of the project checkout.
PROJECT_DIR = '/kaggle/working/ECGBERT-reproduce-project-'  # Adjust if different

# Fail early with an actionable message: without this guard, os.chdir below
# raises a bare FileNotFoundError before any of the diagnostics are printed.
if not os.path.isdir(PROJECT_DIR):
    raise FileNotFoundError(
        f"Project directory not found: {PROJECT_DIR}. "
        "Clone or copy the repository there, or adjust PROJECT_DIR."
    )

# Put the project root on sys.path so that packages inside it can be imported
# (presumably this is where the `fine_tune` package used in Step 3 lives).
# The membership test keeps the cell idempotent when it is re-run.
if PROJECT_DIR not in sys.path:
    sys.path.insert(0, PROJECT_DIR)

# Change to the project directory so that relative paths resolve from its root.
os.chdir(PROJECT_DIR)

# Report which of the configured paths exist. These are diagnostics only:
# a missing data path is printed as False rather than raised.
print("Checking paths...")
print(f"Project directory: {PROJECT_DIR}")
print(f"Current working directory: {os.getcwd()}")
print(f"MIT-BIH exists: {os.path.exists(MITDB_PATH)}")
print(f"Pretrained model exists: {os.path.exists(PRETRAINED_MODEL_PATH)}")
print(f"Clustering models exist: {os.path.exists(CLUSTERING_DIR)}")
print(f"Python path includes project: {PROJECT_DIR in sys.path}")


## Step 3: Run Fine-tuning


In [None]:
# Imported here rather than at the top of the notebook because the project
# directory is only added to sys.path in Step 2.
from fine_tune.Fine_tune_heartbeat_main_kaggle import run_finetuning

# Collect the arguments for the fine-tuning run in one place.
finetune_kwargs = dict(
    mitdb_path=MITDB_PATH,
    pretrained_model_path=PRETRAINED_MODEL_PATH,
    clustering_dir=CLUSTERING_DIR,
    output_dir=OUTPUT_DIR,
    binary_classification=False,  # Set to True for binary classification
    task_name='Heartbeat_Classification',
)

# Run fine-tuning
run_finetuning(**finetune_kwargs)


## Step 4: Check Results


In [None]:
# List the model checkpoint files (.pth) found in the results directory.
# NOTE(review): presumably run_finetuning writes its checkpoints here - confirm
# against the project code if the directory turns out to be empty.
results_dir = os.path.join(OUTPUT_DIR, 'Heartbeat_Classification', 'results')

# isdir (rather than exists) so that a stray file at this path is reported as
# "not found" instead of making os.listdir raise NotADirectoryError.
if os.path.isdir(results_dir):
    # os.listdir returns entries in arbitrary order; sort for stable output.
    checkpoint_names = sorted(
        name for name in os.listdir(results_dir) if name.endswith('.pth')
    )
    print("\nTrained models:")
    for name in checkpoint_names:
        size_mb = os.path.getsize(os.path.join(results_dir, name)) / (1024 * 1024)
        print(f"  - {name} ({size_mb:.2f} MB)")
    if not checkpoint_names:
        # Make an empty result explicit instead of printing a bare header.
        print("  (no .pth files found)")
else:
    print("Results directory not found")
