# PhysioNet - Digitization of ECG Images
- Extract the ECG time-series data from scans and photographs of paper printouts of the ECGs.
### https://www.kaggle.com/competitions/physionet-ecg-image-digitization/

## Import Required Libraries
Loading all necessary packages for data processing, visualization, and deep learning with PyTorch.

In [None]:
# Import libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

## Configuration and Device Setup
Setting up data directories and detecting available GPU/CPU for training.

In [None]:
# Paths and config
# ---- Data locations --------------------------------------------------------
DATA_DIR = '/kaggle/input/physionet-ecg-image-digitization'
TRAIN_DIR = f'{DATA_DIR}/train'
TEST_DIR = f'{DATA_DIR}/test'

# ---- Compute device: prefer the GPU when one is visible --------------------
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {DEVICE}')

# ---- Training hyper-parameters ---------------------------------------------
BATCH_SIZE = 8   # images per optimization step
EPOCHS = 1       # full passes over the training split
LR = 1e-4        # Adam learning rate

## Load Metadata and Split Dataset

In [None]:
# Load metadata
from sklearn.model_selection import train_test_split

train_df = pd.read_csv(f'{DATA_DIR}/train.csv')
test_df = pd.read_csv(f'{DATA_DIR}/test.csv')

# Record IDs actually present on disk: train records are numeric sub-directories,
# test records are `<id>.png` files. Names that are not purely numeric (stray
# folders, notebook checkpoints, ...) are skipped instead of crashing int().
available_train_ids = [
    int(name) for name in os.listdir(TRAIN_DIR)
    if name.isdecimal() and os.path.isdir(os.path.join(TRAIN_DIR, name))
]
available_test_ids = [
    int(name[:-len('.png')]) for name in os.listdir(TEST_DIR)
    if name.endswith('.png') and name[:-len('.png')].isdecimal()
]

print(f'Original train samples in csv: {len(train_df)}, available on disk: {len(available_train_ids)}')
print(f'Original test samples in csv: {len(test_df)}, available on disk: {len(available_test_ids)}')

# Filter to only use available samples
train_df = train_df[train_df['id'].isin(available_train_ids)].reset_index(drop=True)
test_df = test_df[test_df['id'].isin(available_test_ids)].reset_index(drop=True)

print(f'\nFiltered train samples: {len(train_df)}')
print(f'Filtered test samples: {len(test_df)}')

# Split train into train and validation (80/20). train_test_split needs at
# least two rows (one per side), so smaller sets are kept whole.
if len(train_df) >= 2:
    train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42)
    train_df = train_df.reset_index(drop=True)
    val_df = val_df.reset_index(drop=True)
    print(f'\nAfter 80/20 split:')
    print(f'Train samples: {len(train_df)}')
    print(f'Validation samples: {len(val_df)}')
else:
    # Empty frame that keeps train_df's columns, so downstream code sees the same schema.
    val_df = train_df.iloc[0:0].copy()
    if len(train_df) == 0:
        print('\nWARNING: No training data available!')
    else:
        print('\nWARNING: Only one training sample available; skipping the validation split.')

train_df.head()

## ECG Dataset Class
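Custom PyTorch dataset that loads ECG images and ground-truth signals. For training records only the first page (`<id>-0001.png`) is used. Images are resized from their original (varying) size to 256x256 pixels and then normalized to [0, 1]. Each lead's signal is linearly interpolated to a standard length so that batches have a consistent shape.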

In [None]:
# Dataset class for ECG images
class ECGDataset(Dataset):
    """Dataset of resized ECG images plus, for training, 12-lead target signals.

    Train records are read from ``<data_dir>/<id>/<id>-0001.png`` (first page
    only) and ``<data_dir>/<id>/<id>.csv``; test records from
    ``<data_dir>/<id>.png``.

    Args:
        df: DataFrame with an ``id`` column; when ``is_train`` is True it must
            also have ``fs`` and ``sig_len`` columns.
        data_dir: Root directory of the split.
        is_train: If True, also load the ground-truth CSV and return signals.
        img_size: Size passed to ``PIL.Image.resize``.
        standard_len: Number of samples every lead is linearly resampled to.

    Items are ``(img, signal, fs, sig_len, record_id)`` when ``is_train`` is
    True and ``(img, record_id)`` otherwise; ``img`` is a float tensor of shape
    ``(3, H, W)`` in [0, 1] and ``signal`` has shape ``(12, standard_len)``.
    """

    def __init__(self, df, data_dir, is_train=True, img_size=(256, 256), standard_len=10000):
        self.df = df
        self.data_dir = data_dir
        self.is_train = is_train
        self.img_size = img_size
        self.standard_len = standard_len  # Standard length for all leads via interpolation
        self.leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

    def __len__(self):
        return len(self.df)

    def interpolate_signal(self, signal, target_length):
        """Linearly resample ``signal`` to ``target_length`` points.

        The first and last samples stay at the ends of the new grid. When the
        length already matches, ``signal`` itself is returned (no copy).
        """
        if len(signal) == target_length:
            return signal

        old_indices = np.linspace(0, len(signal) - 1, len(signal))
        new_indices = np.linspace(0, len(signal) - 1, target_length)
        return np.interp(new_indices, old_indices, signal)

    def _load_image(self, record_id):
        """Load one record's image as a float ``(3, H, W)`` tensor scaled to [0, 1]."""
        if self.is_train:
            # Train: images are in subdirectories with multiple pages; use page 1.
            img_path = f'{self.data_dir}/{record_id}/{record_id}-0001.png'
        else:
            # Test: images are directly in the test directory.
            img_path = f'{self.data_dir}/{record_id}.png'

        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        # The context manager releases the file handle once the pixels are copied.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB').resize(self.img_size)
        img = np.array(img) / 255.0  # Normalize to [0, 1]
        return torch.FloatTensor(img).permute(2, 0, 1)  # (C, H, W)

    def _load_signal(self, record_id):
        """Load the ground-truth CSV and return a ``(12, standard_len)`` tensor."""
        csv_path = f'{self.data_dir}/{record_id}/{record_id}.csv'
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV not found: {csv_path}")

        signal_df = pd.read_csv(csv_path)

        processed_leads = []
        for lead in self.leads:
            if lead not in signal_df.columns:
                raise ValueError(f"Lead {lead} not found in CSV for record {record_id}")
            # Leads are recorded over different spans; NaNs pad the unused rows.
            lead_signal = signal_df[lead].dropna().values
            if lead_signal.size == 0:
                # np.interp would otherwise fail with an opaque "empty sample points" error.
                raise ValueError(f"Lead {lead} has no samples in CSV for record {record_id}")
            processed_leads.append(self.interpolate_signal(lead_signal, self.standard_len))

        return torch.FloatTensor(np.stack(processed_leads, axis=0))

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Read the id from its own column: a row Series from iloc is upcast to
        # float when any column is float, which would turn id 123 into '123.0'.
        record_id = str(self.df['id'].iloc[idx])

        try:
            img = self._load_image(record_id)
            if not self.is_train:
                # For test, just return image and record_id
                return img, record_id

            signal = self._load_signal(record_id)
            return img, signal, row['fs'], row['sig_len'], record_id

        except Exception as e:
            # Report which record failed, then propagate the original error.
            print(f"Error loading record {record_id}: {str(e)}")
            raise

## Simple CNN Model Architecture


In [None]:
# Simple CNN model
class SimpleCNN(nn.Module):
    """Small CNN that regresses every lead's samples directly from an ECG image.

    Three ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2)`` stages feed a
    two-layer MLP whose output is reshaped to ``(batch, num_leads, standard_len)``.

    Args:
        standard_len: Number of output samples per lead.
        num_leads: Number of output leads.
        img_size: Spatial size of the input images, used to size the first
            linear layer. Only the product of the pooled height and width
            matters, so ``(height, width)`` and ``(width, height)`` are
            equivalent. Defaults to 256x256, the ``ECGDataset`` default.
    """

    def __init__(self, standard_len=2500, num_leads=12, img_size=(256, 256)):
        super().__init__()
        self.standard_len = standard_len
        self.num_leads = num_leads

        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # padding=1 keeps every conv size-preserving and each of the three
        # MaxPool2d(2) layers floors H and W by 2, e.g. 256x256 -> 32x32x128.
        height, width = img_size
        flat_features = 128 * (height // 8) * (width // 8)
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048, num_leads * standard_len)
        )

    def forward(self, x):
        """Map a ``(batch, 3, H, W)`` image batch to ``(batch, num_leads, standard_len)``."""
        x = self.conv(x)
        x = self.fc(x)
        return x.view(-1, self.num_leads, self.standard_len)

## Verify Dataset Loading
Creating train/validation datasets in which every lead is interpolated to a standard length of 2500 samples, then loading one sample from each to check that the data loads and the shapes are correct.

In [None]:
# Test the updated dataset with interpolation
# Every lead is resampled to this many samples (most efficient for training).
STANDARD_LEN = 2500

train_subset, val_subset = train_df.copy(), val_df.copy()
print(f'Training on {len(train_subset)} samples')
print(f'Validating on {len(val_subset)} samples')

# Validation rows were split off the training CSV, so both datasets read from TRAIN_DIR.
train_dataset, val_dataset = (
    ECGDataset(frame, TRAIN_DIR, is_train=True, standard_len=STANDARD_LEN)
    for frame in (train_subset, val_subset)
)

# Inspect the first training sample.
if not len(train_dataset):
    print('No training samples available!')
else:
    img, signal, fs, sig_len, record_id = train_dataset[0]
    print(f'\nTrain Sample:')
    print(f'Image shape: {img.shape}')
    print(f'Signal shape: {signal.shape}')  # expected: (12, STANDARD_LEN)
    print(f'Sampling frequency: {fs} Hz')
    print(f'Original signal length: {sig_len}')
    print(f'Record ID: {record_id}')
    print(f'\nNote: All leads interpolated to {STANDARD_LEN} for efficient training')
    print(f'      Short leads (2500) kept as-is, long leads (10000) downsampled')

# Inspect the first validation sample.
if not len(val_dataset):
    print('No validation samples available!')
else:
    img, signal, fs, sig_len, record_id = val_dataset[0]
    print(f'\nVal Sample:')
    print(f'Image shape: {img.shape}')
    print(f'Signal shape: {signal.shape}')
    print(f'Record ID: {record_id}')

## Visualize Training Sample
Displaying input ECG image (256x256 resized) along with ground truth signals for all 12 leads to verify data quality.

In [None]:
# Visualize one training sample - all 12 leads from CSV with input image
if len(train_dataset) > 0:
    img, signal, fs, sig_len, record_id = train_dataset[0]

    print(f'Training Sample Visualization')
    print(f'Record ID: {record_id}')
    print(f'Input Image: {record_id}-0001.png (first page)')
    print(f'Sampling frequency: {fs} Hz')
    print(f'Original signal length: {sig_len}')
    print(f'Interpolated to: {STANDARD_LEN}')
    print(f'Signal shape: {signal.shape} (12 leads, {STANDARD_LEN} samples each)\n')

    # Figure layout: the input image on top, one thin row per lead below it.
    leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

    fig = plt.figure(figsize=(12, 12))
    gs = fig.add_gridspec(13, 1, height_ratios=[10] + [1]*12, hspace=0.3)

    # Plot the input image at the top
    ax_img = fig.add_subplot(gs[0])
    img_display = img.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C) for imshow
    ax_img.imshow(img_display)
    ax_img.set_title(f'Input ECG Image - Record {record_id}-0001.png',
                     fontsize=10, fontweight='bold', pad=8)
    ax_img.axis('off')

    # Every lead was resampled to STANDARD_LEN points whatever its original
    # duration (long and short leads differ), so sample_index / fs is not real
    # time. Plot against the resampled sample index instead.
    sample_axis = np.arange(signal.shape[1])

    # Plot all 12 leads
    for i, lead_name in enumerate(leads):
        ax = fig.add_subplot(gs[i+1])
        ax.plot(sample_axis, signal[i].numpy(), linewidth=0.8, color='blue')
        ax.set_ylabel(lead_name, fontweight='bold', fontsize=8)
        ax.grid(True, alpha=0.3)

        if i == len(leads) - 1:
            ax.set_xlabel(f'Sample index (resampled to {STANDARD_LEN})', fontsize=8)

    plt.suptitle(f'Training Sample - Record {record_id}\nInput Image (0001) and Ground Truth Signals (all 12 leads)',
                 fontsize=12, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.show()

    print(f'\nSignal statistics:')
    for i, lead_name in enumerate(leads):
        lead_signal = signal[i].numpy()
        print(f'{lead_name:>3}: min={lead_signal.min():7.2f}, max={lead_signal.max():7.2f}, mean={lead_signal.mean():7.2f}, std={lead_signal.std():7.2f}')
else:
    print('No training samples available for visualization!')

## Training Configuration
Setting up the data loaders, the model, the loss function (MSE), and the Adam optimizer with learning rate `LR` (1e-4).

In [None]:
# Training setup
# Shuffle only the training split; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model, regression loss and optimizer.
model = SimpleCNN(standard_len=STANDARD_LEN, num_leads=12).to(DEVICE)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

total_params = sum(param.numel() for param in model.parameters())
print(f'Model parameters: {total_params:,}')
print(f'Train batches: {len(train_loader)}')
print(f'Val batches: {len(val_loader)}')

## Train the Model

In [None]:
# Training loop with validation
for epoch in range(EPOCHS):
    # ---- Training pass: one optimizer step per mini-batch ----
    model.train()
    running_train, n_train = 0, 0

    for images, signals, fs_batch, sig_len_batch, record_ids in train_loader:
        # signals: (batch, 12, STANDARD_LEN) - every lead already resampled
        images, signals = images.to(DEVICE), signals.to(DEVICE)

        optimizer.zero_grad()
        loss = criterion(model(images), signals)
        loss.backward()
        optimizer.step()

        running_train += loss.item()
        n_train += 1

    avg_train_loss = running_train / n_train if n_train else 0

    # ---- Validation pass: forward only, no gradients ----
    model.eval()
    running_val, n_val = 0, 0

    with torch.no_grad():
        for images, signals, fs_batch, sig_len_batch, record_ids in val_loader:
            images, signals = images.to(DEVICE), signals.to(DEVICE)
            loss = criterion(model(images), signals)
            running_val += loss.item()
            n_val += 1

    avg_val_loss = running_val / n_val if n_val else 0

    print(f'Epoch {epoch+1}/{EPOCHS}, Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}')

## Prepare Test Dataset
Creating the test dataset from the unique record IDs (test.csv has one row per lead). The model outputs `STANDARD_LEN` samples per lead, which are later interpolated to the lengths required by test.csv.

In [None]:
# Create test dataset and predict
# test.csv holds one row per (record, lead); the image dataset needs one row per record.
test_df_unique = pd.DataFrame({'id': test_df['id'].unique()})

# Keep only records whose PNG is present in TEST_DIR.
has_image = test_df_unique['id'].isin(available_test_ids)
test_df_unique = test_df_unique[has_image].reset_index(drop=True)

print(f'Unique test records: {len(test_df_unique)}')
test_dataset = ECGDataset(test_df_unique, TEST_DIR, is_train=False, standard_len=STANDARD_LEN)
# One image per forward pass at inference time, in file order.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

print(f'Generating predictions for {len(test_dataset)} test samples...')
print(f'Model will output {STANDARD_LEN} values per lead')
print(f'Then interpolate to required lengths per test.csv (flexible for any dimension)')

## Generate Predictions

In [None]:
# Generate predictions
# Maps record_id (str) -> predicted signal array of shape (12, STANDARD_LEN).
predictions = {}
model.eval()

with torch.no_grad():
    for batch_images, batch_ids in tqdm(test_loader):
        # (batch, 12, STANDARD_LEN) on the CPU; iterating yields one (12, STANDARD_LEN) row per record.
        batch_preds = model(batch_images.to(DEVICE)).cpu().numpy()
        predictions.update(zip(batch_ids, batch_preds))

print(f'Generated predictions for {len(predictions)} records')
print(f'Each prediction has shape (12, {STANDARD_LEN})')
print(f'Will be interpolated to required dimensions during submission generation')

## Compare Ground Truth vs Prediction
Visualizing one validation sample showing how well the model predictions match the actual ECG signals across all 12 leads.

In [None]:
# Validation sample comparison: Ground Truth vs Prediction (all 12 leads)
if len(val_dataset) > 0:
    # Get one validation sample
    img, gt_signal, fs, sig_len, record_id = val_dataset[0]

    # Predict it on its own: add a batch dimension, then strip it again.
    model.eval()
    with torch.no_grad():
        img_batch = img.unsqueeze(0).to(DEVICE)
        pred_signal = model(img_batch).cpu().squeeze(0)

    print(f'Validation Sample Comparison - GT vs Prediction')
    print(f'Record ID: {record_id}')
    print(f'Sampling frequency: {fs} Hz')
    print(f'Signal length: {STANDARD_LEN}')
    print(f'GT shape: {gt_signal.shape}')
    print(f'Prediction shape: {pred_signal.shape}\n')

    # Create a figure with 12 subplots (one per lead)
    leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

    fig, axes = plt.subplots(12, 1, figsize=(18, 22))
    fig.suptitle(f'Validation Sample Comparison - Record {record_id}\nGround Truth (Blue) vs Prediction (Red)',
                 fontsize=16, fontweight='bold')

    # Both signals hold STANDARD_LEN resampled points per lead, whatever the
    # lead's original duration, so sample_index / fs is not real time.
    sample_axis = np.arange(gt_signal.shape[1])

    for i, (ax, lead_name) in enumerate(zip(axes, leads)):
        ax.plot(sample_axis, gt_signal[i].numpy(), linewidth=1.2, color='blue', alpha=0.7, label='Ground Truth')
        ax.plot(sample_axis, pred_signal[i].numpy(), linewidth=1.0, color='red', alpha=0.6, label='Prediction')

        ax.set_ylabel(lead_name, fontweight='bold', fontsize=12)
        ax.grid(True, alpha=0.3)

        # Add legend only to first subplot
        if i == 0:
            ax.legend(loc='upper right')

        if i == len(axes) - 1:
            ax.set_xlabel(f'Sample index (resampled to {STANDARD_LEN})', fontsize=12)

    plt.tight_layout()
    plt.show()

    # Mean squared error per lead on the resampled signals.
    print(f'\nMean Squared Error (MSE) per lead:')
    for i, lead_name in enumerate(leads):
        mse = np.mean((gt_signal[i].numpy() - pred_signal[i].numpy()) ** 2)
        print(f'{lead_name:>3}: MSE = {mse:.6f}')
else:
    print('No validation samples available for comparison!')

## Calculate SNR Metrics
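Computing the per-lead Signal-to-Noise Ratio (SNR, in dB) on the train and validation sets to measure prediction quality; a higher SNR means better signal extraction. This is a simple SNR on the resampled signals and may not match the official competition metric exactly.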

In [None]:
# Calculate SNR metrics for Train and Validation sets
# Reference: https://www.kaggle.com/code/metric/physionet-ecg-signal-extraction-metric/
# SNR (dB) = 10 * log10(P_signal / P_noise), where P_signal is the mean power of
# the ground truth and P_noise the mean power of (ground truth - prediction).

def calculate_snr(ground_truth, prediction):
    """
    Return the signal-to-noise ratio of ``prediction`` against ``ground_truth`` in dB.

    Args:
        ground_truth: numpy array with the reference signal
        prediction: numpy array with the predicted signal (same shape)

    Returns:
        ``10 * log10(mean(gt**2) / mean((gt - pred)**2))``, or ``inf`` when the
        noise power is below 1e-10 (an essentially perfect prediction).
    """
    residual = ground_truth - prediction
    noise_power = np.mean(residual ** 2)

    # Treat a vanishing residual as a perfect prediction instead of dividing by ~0.
    if noise_power < 1e-10:
        return float('inf')

    signal_power = np.mean(ground_truth ** 2)
    return 10 * np.log10(signal_power / noise_power)


def _finite_values(values):
    """Return the entries of ``values`` that are neither +inf nor -inf."""
    return [v for v in values if not np.isinf(v)]


def evaluate_snr_on_dataset(dataset, model, device, dataset_name="Dataset"):
    """
    Evaluate per-lead and per-sample SNR of ``model`` over every item of ``dataset``.

    Each item must unpack as ``(img, gt_signal, fs, sig_len, record_id)`` (the
    training-mode ECGDataset layout). Samples are predicted one at a time, and
    the model is switched to eval mode as a side effect.

    Args:
        dataset: ECGDataset instance created with ``is_train=True``
        model: trained model; its output for one image is indexed per lead
        device: torch device the model lives on
        dataset_name: label for progress and print output (e.g. "Train")

    Returns:
        Dictionary with ``mean_snr``, ``median_snr``, ``std_snr``, ``min_snr``
        and ``max_snr`` (statistics of each sample's mean SNR across leads), plus
        ``lead_snr`` mapping each lead to its ``mean``/``median``/``std``.
        Infinite SNR values are excluded from every statistic; when none are
        finite the statistic is ``inf`` (``0`` for standard deviations).
    """
    leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
    model.eval()

    all_snr = []  # Mean SNR per sample (averaged across its finite lead SNRs)
    lead_snr = {lead: [] for lead in leads}

    print(f'\nEvaluating SNR on {dataset_name} set ({len(dataset)} samples)...')

    with torch.no_grad():
        for idx in tqdm(range(len(dataset)), desc=f"{dataset_name} SNR"):
            img, gt_signal, fs, sig_len, record_id = dataset[idx]

            # Predict this single sample (add, then remove, the batch dimension).
            pred_signal = model(img.unsqueeze(0).to(device)).cpu().squeeze(0)

            gt_np = gt_signal.numpy()
            pred_np = pred_signal.numpy()

            sample_snr_list = []
            for lead_idx, lead_name in enumerate(leads):
                snr = calculate_snr(gt_np[lead_idx], pred_np[lead_idx])
                lead_snr[lead_name].append(snr)
                sample_snr_list.append(snr)

            # Perfect (infinite) leads would swamp the average, so skip them.
            finite_snrs = _finite_values(sample_snr_list)
            all_snr.append(np.mean(finite_snrs) if finite_snrs else float('inf'))

    finite_all_snr = _finite_values(all_snr)
    if finite_all_snr:
        results = {
            'mean_snr': np.mean(finite_all_snr),
            'median_snr': np.median(finite_all_snr),
            'std_snr': np.std(finite_all_snr),
            'min_snr': np.min(finite_all_snr),
            'max_snr': np.max(finite_all_snr),
        }
    else:
        results = {
            'mean_snr': float('inf'),
            'median_snr': float('inf'),
            'std_snr': 0,
            'min_snr': float('inf'),
            'max_snr': float('inf'),
        }
    results['lead_snr'] = {}

    # Per-lead statistics
    for lead_name, snr_list in lead_snr.items():
        finite_lead_snr = _finite_values(snr_list)
        if finite_lead_snr:
            results['lead_snr'][lead_name] = {
                'mean': np.mean(finite_lead_snr),
                'median': np.median(finite_lead_snr),
                'std': np.std(finite_lead_snr)
            }
        else:
            results['lead_snr'][lead_name] = {
                'mean': float('inf'),
                'median': float('inf'),
                'std': 0
            }

    return results


# Evaluate the train and validation sets and print a short report for each.

def _print_snr_report(title, results):
    """Print the overall and per-lead mean SNR from an ``evaluate_snr_on_dataset`` result."""
    leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
    print(f'\n{"="*60}')
    print(title)
    print(f'{"="*60}')
    print('Overall Statistics:')
    print(f'  Mean SNR:   {results["mean_snr"]:.2f} dB')
    print('\nPer-Lead Mean SNR:')
    for lead_name in leads:
        mean_snr = results['lead_snr'][lead_name]['mean']
        print(f'  {lead_name:>3}: {mean_snr:7.2f} dB')


# Evaluate on Train set
train_results = evaluate_snr_on_dataset(train_dataset, model, DEVICE, "Train")
_print_snr_report('TRAIN SET SNR METRICS', train_results)

# Evaluate on Validation set
val_results = evaluate_snr_on_dataset(val_dataset, model, DEVICE, "Validation")
_print_snr_report('VALIDATION SET SNR METRICS', val_results)

print(f'\n{"="*60}')

## Create Submission File

In [None]:
# Create submission dataframe with flexible interpolation
def interpolate_signal(signal, target_length):
    """Linearly resample a 1-D ``signal`` to ``target_length`` points.

    The first and last samples are kept at the ends of the new grid. When the
    length already matches, the input object itself is returned (no copy).
    This is the same resampling as ``ECGDataset.interpolate_signal``, without
    needing a dataset instance.
    """
    n_src = len(signal)
    if n_src == target_length:
        return signal

    # Source samples sit at positions 0..n_src-1; the target grid spans the same
    # interval with target_length evenly spaced points.
    src_grid = np.arange(n_src, dtype=float)
    dst_grid = np.linspace(0, n_src - 1, target_length)
    return np.interp(dst_grid, src_grid, signal)

# Build the submission as two flat columns (id, value): one row per time step per lead.
submission_ids = []
submission_values = []
leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
target_lengths = {}  # lead name -> set of row counts requested by test.csv
skipped_rows = 0

for idx, row in tqdm(test_df.iterrows(), total=len(test_df)):
    record_id = str(row['id'])
    lead_name = row['lead']
    num_rows = int(row['number_of_rows'])

    # Records without a prediction (e.g. missing image) cannot be filled; count them.
    if record_id not in predictions:
        skipped_rows += 1
        continue

    # Prediction for this record has shape (12, STANDARD_LEN); pick this lead's row.
    lead_signal = predictions[record_id][leads.index(lead_name)]

    # Resample from STANDARD_LEN to whatever length test.csv asks for.
    if len(lead_signal) != num_rows:
        lead_values = interpolate_signal(lead_signal, num_rows)
    else:
        lead_values = lead_signal
    target_lengths.setdefault(lead_name, set()).add(num_rows)

    submission_ids.extend(f'{record_id}_{timestep}_{lead_name}' for timestep in range(len(lead_values)))
    submission_values.extend(float(value) for value in lead_values)

submission_df = pd.DataFrame({'id': submission_ids, 'value': submission_values})
print(f'Submission shape: {submission_df.shape}')
if skipped_rows:
    print(f'WARNING: skipped {skipped_rows} test.csv rows with no prediction')

# Summarize the lengths actually requested by test.csv (they depend on each record's fs).
print(f'\nFlexible interpolation summary:')
print(f'- All predictions generated at standard length: {STANDARD_LEN}')
print(f'- Interpolated to required dimensions per test.csv:')
for lead_name in leads:
    if lead_name in target_lengths:
        lengths = ', '.join(str(n) for n in sorted(target_lengths[lead_name]))
        print(f'  - Lead {lead_name}: {STANDARD_LEN} → {lengths}')
print(f'- This approach works for ANY dimension specified in test.csv')
submission_df.head()

In [None]:
# Save submission: write the CSV without the DataFrame index column.
submission_df.to_csv(path_or_buf='submission.csv', index=False)
print('Submission saved to submission.csv')

In [None]:
submission_df.head(n=5)  # preview the first rows of the submission