In [None]:
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.models import efficientnet_b0
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
from tqdm import tqdm
import time
from scipy.signal import savgol_filter
import math

In [None]:
# Root of the competition data and the offline EfficientNet-B0 checkpoint.
KAGGLE_DIR = "/kaggle/input/physionet-ecg-image-digitization"
EFFICIENTNET_PATH = "/kaggle/input/efficientnet-b0/efficientnet_b0_rwightman-7f5810bc.pth"

# Every dataset artefact sits directly under KAGGLE_DIR.
# Note: despite its name, SUBMISSION_CSV points at a Parquet file.
TRAIN_CSV, TEST_CSV, SUBMISSION_CSV, TRAIN_DIR, TEST_DIR = (
    os.path.join(KAGGLE_DIR, entry)
    for entry in ("train.csv", "test.csv", "sample_submission.parquet", "train", "test")
)

# Metadata tables for both splits, plus the sample submission.
train_meta, test_meta = pd.read_csv(TRAIN_CSV), pd.read_csv(TEST_CSV)
submission_template = pd.read_parquet(SUBMISSION_CSV)

print(f"Training samples: {len(train_meta)}")
print(f"Test samples: {len(test_meta)}")

In [None]:
def compute_snr(original, reconstructed, max_shift=50):
    """Compute the best signal-to-noise ratio (in dB) over a range of time shifts.

    The two signals are truncated to their common length, then the
    reconstruction is slid against the original by every integer shift in
    ``[-max_shift, max_shift]``. The highest SNR over all shifts whose overlap
    is longer than 100 samples is returned.

    Args:
        original: 1-D reference signal (array-like).
        reconstructed: 1-D predicted signal (array-like).
        max_shift: Largest absolute shift, in samples, to try.

    Returns:
        The best SNR in dB, or 0 when no shift could be scored (empty input,
        overlap of 100 samples or fewer, negligible signal power, or
        non-finite values).
    """
    # Accept lists as well as arrays; float64 keeps the power sums accurate.
    original = np.asarray(original, dtype=np.float64)
    reconstructed = np.asarray(reconstructed, dtype=np.float64)

    n = min(len(original), len(reconstructed))
    original, reconstructed = original[:n], reconstructed[:n]
    if n == 0:
        return 0

    min_overlap = 100   # overlaps of this many samples or fewer are not scored
    power_floor = 1e-10

    # Shifts of n or more leave no overlap. Clamping also keeps the negative
    # branch from wrapping around with a negative slice bound, which would
    # produce operands of different lengths.
    limit = min(max_shift, n - 1)

    best_snr = -np.inf
    for shift in range(-limit, limit + 1):
        if shift >= 0:
            rec_shifted = reconstructed[shift:]
            orig_shifted = original[:len(rec_shifted)]
        else:
            rec_shifted = reconstructed[:n + shift]
            orig_shifted = original[-shift:]

        if len(rec_shifted) <= min_overlap:
            continue

        signal_power = np.mean(orig_shifted ** 2)
        noise_power = np.mean((rec_shifted - orig_shifted) ** 2)

        # Written as negated comparisons so NaN powers are rejected too.
        if not (signal_power > power_floor) or not np.isfinite(noise_power):
            continue

        # Floor the noise instead of discarding the shift: a (near-)perfect
        # alignment must yield the highest score, not be ignored.
        snr = 10 * np.log10(signal_power / max(noise_power, power_floor))
        if snr > best_snr:
            best_snr = snr

    return best_snr if best_snr != -np.inf else 0

def compute_rmse(original, reconstructed):
    """Root-mean-square error between two signals over their finite samples.

    Signals of different length are truncated to the shorter one. Positions
    where either signal is NaN or infinite are ignored. Returns 0 when there
    is nothing to compare.
    """
    n_orig, n_rec = len(original), len(reconstructed)
    if n_orig != n_rec:
        shared = min(n_orig, n_rec)
        original, reconstructed = original[:shared], reconstructed[:shared]

    if len(original) == 0:
        return 0

    # A position is usable only when both samples are finite.
    usable = np.isfinite(original) & np.isfinite(reconstructed)
    if not usable.any():
        return 0

    residual = original[usable] - reconstructed[usable]
    return np.sqrt(np.mean(residual ** 2))

def compute_correlation(original, reconstructed):
    """Pearson correlation between two signals over their finite samples.

    Signals of different length are truncated to the shorter one. Positions
    where either signal is NaN or infinite are ignored. Returns 0 when fewer
    than two usable samples remain.
    """
    shared = min(len(original), len(reconstructed))
    if len(original) != len(reconstructed):
        original, reconstructed = original[:shared], reconstructed[:shared]

    if shared < 2:
        return 0

    # A position is usable only when both samples are finite.
    usable = np.isfinite(original) & np.isfinite(reconstructed)
    if np.count_nonzero(usable) < 2:
        return 0

    # Off-diagonal entry of the 2x2 correlation matrix.
    return np.corrcoef(original[usable], reconstructed[usable])[0, 1]

In [None]:
class StableECGDataset(Dataset):
    """ECG image dataset yielding an image plus per-lead signals or a target lead.

    For ``split='test'`` each item holds the image, the record id, the
    sampling rate and the requested lead read from the metadata row.
    For any other split each item holds the image, the record id, the
    sampling rate and a dict with one normalised float32 signal per lead.

    Args:
        meta_df: Metadata frame with at least ``id`` and ``fs`` columns (and
            ``lead`` for the test split). It is copied, never modified.
        data_dir: Directory containing the images / ground-truth CSV files.
        split: ``'test'`` selects the test layout, anything else the train one.
        transform: Optional callable invoked as ``transform(image=...)`` and
            returning a mapping with an ``'image'`` entry.
        max_samples: If truthy, randomly keep at most this many rows.
        img_size: Side length images are resized to.
    """

    # Length of the all-zero placeholder used when a lead has no ground truth.
    MISSING_LEAD_LENGTH = 2500

    def __init__(self, meta_df, data_dir, split='train', transform=None, max_samples=None, img_size=224):
        self.meta_df = meta_df.copy()
        self.data_dir = data_dir
        self.split = split
        self.transform = transform
        self.img_size = img_size

        if max_samples:
            self.meta_df = self.meta_df.sample(n=min(max_samples, len(self.meta_df))).reset_index(drop=True)

        self.leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
        # Lead II is treated as a 10 s strip, every other lead as 2.5 s.
        self.lead_durations = {'II': 10.0, **{lead: 2.5 for lead in self.leads if lead != 'II'}}

    def __len__(self):
        return len(self.meta_df)

    def _read_image(self, image_file):
        """Read one image as a resized RGB array, or return None if unreadable."""
        if not os.path.exists(image_file):
            return None
        image = cv2.imread(image_file, cv2.IMREAD_COLOR)
        if image is None:
            return None
        # OpenCV decodes to BGR; the transforms in this notebook normalise
        # with RGB-ordered ImageNet statistics, so convert first.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return cv2.resize(image, (self.img_size, self.img_size))

    def load_image(self, ecg_id):
        """Return the record's image (RGB, uint8), or a mid-grey placeholder.

        The test split looks for ``<data_dir>/<id>.png``. Other splits return
        the first readable file among ``<data_dir>/<id>/<id>-0001.png`` ..
        ``-0012.png``.
        """
        if self.split == 'test':
            candidates = [f"{self.data_dir}/{ecg_id}.png"]
        else:
            candidates = [f"{self.data_dir}/{ecg_id}/{ecg_id}-{idx:04d}.png" for idx in range(1, 13)]

        for image_file in candidates:
            image = self._read_image(image_file)
            if image is not None:
                return image

        return np.full((self.img_size, self.img_size, 3), 128, dtype=np.uint8)

    def load_ground_truth(self, ecg_id):
        """Return ``<data_dir>/<id>/<id>.csv`` as a DataFrame, or None if absent."""
        gt_file = f"{self.data_dir}/{ecg_id}/{ecg_id}.csv"
        if os.path.exists(gt_file):
            return pd.read_csv(gt_file)
        return None

    def _to_tensor(self, image):
        """Apply the transform, or fall back to a CHW float tensor in [0, 1]."""
        if self.transform:
            return self.transform(image=image)['image']
        return torch.from_numpy(image).permute(2, 0, 1).float() / 255.0

    def _prepare_lead(self, ground_truth, lead, fs):
        """Return the normalised float32 signal for one lead.

        The column is cut or zero-padded to ``floor(fs * duration)`` samples,
        NaNs become 0, and the result is standardised to zero mean and unit
        standard deviation. A missing lead yields a zero placeholder.
        """
        if ground_truth is None or lead not in ground_truth.columns:
            return np.zeros(self.MISSING_LEAD_LENGTH, dtype=np.float32)

        expected_length = int(math.floor(fs * self.lead_durations[lead]))
        signal = ground_truth[lead].values.astype(np.float32)

        # NOTE(review): truncation happens before NaNs are zeroed. If the CSV
        # stores a short lead at a row offset with NaN elsewhere, those
        # samples would be cut off here -- confirm against the data layout.
        if len(signal) > expected_length:
            signal = signal[:expected_length]
        elif len(signal) < expected_length:
            signal = np.pad(signal, (0, expected_length - len(signal)), mode='constant')

        signal = np.nan_to_num(signal, nan=0.0)
        # The epsilon keeps a constant signal from dividing by zero.
        return (signal - signal.mean()) / (signal.std() + 1e-8)

    def __getitem__(self, idx):
        row = self.meta_df.iloc[idx]
        ecg_id = row['id']
        fs = row['fs']

        image = self.load_image(ecg_id)

        if self.split == 'test':
            return {
                'image': self._to_tensor(image),
                'ecg_id': ecg_id,
                'fs': fs,
                'target_lead': row['lead']
            }

        ground_truth = self.load_ground_truth(ecg_id)
        lead_signals = {lead: self._prepare_lead(ground_truth, lead, fs) for lead in self.leads}

        return {
            'image': self._to_tensor(image),
            'signals': lead_signals,
            'ecg_id': ecg_id,
            'fs': fs
        }

In [None]:
class StableECGModel(nn.Module):
    """EfficientNet-B0 image encoder with one linear regression head per ECG lead.

    The backbone is initialised from the checkpoint at ``EFFICIENTNET_PATH``,
    its classifier is replaced by an identity, and the pooled features pass
    through a shared two-layer MLP before twelve independent linear heads.

    Args:
        num_leads: Kept for interface compatibility; currently unused. The
            set of heads is fixed by ``LEADS``.
        signal_length: Number of samples each head predicts.
    """

    LEADS = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')

    def __init__(self, num_leads=12, signal_length=2500):
        super().__init__()

        self.backbone = efficientnet_b0()

        # Load onto the CPU so a checkpoint saved from a GPU also loads on a
        # CPU-only host; the caller moves the model with `.to(device)`.
        checkpoint = torch.load(EFFICIENTNET_PATH, map_location='cpu')
        self.backbone.load_state_dict(checkpoint)

        # The pretrained stem convolution is deliberately left in place.
        # torchvision's EfficientNet-B0 stem is already
        # Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False), so
        # swapping in a fresh layer of that shape only discards its weights.
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()

        self.shared_encoder = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        )

        self.lead_heads = nn.ModuleDict(
            {lead: nn.Linear(512, signal_length) for lead in self.LEADS}
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-initialise every nn.Linear in the module tree and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return a dict mapping each lead name to the output of its head."""
        features = self.backbone(x)
        encoded = self.shared_encoder(features)
        return {lead: head(encoded) for lead, head in self.lead_heads.items()}

In [None]:
def get_stable_transforms(img_size=224):
    """Build the Albumentations pipelines for training and inference.

    Both pipelines resize to ``img_size`` x ``img_size``, normalise with the
    ImageNet mean/std and convert to a CHW tensor. The training pipeline adds
    a mild random brightness/contrast jitter.

    No horizontal flip is applied: the regression targets are time series
    read from the image, so mirroring the image would reverse the trace while
    the target stays unchanged.

    Args:
        img_size: Side length of the square output image.

    Returns:
        ``(train_transform, test_transform)``.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    train_transform = A.Compose([
        A.Resize(img_size, img_size),
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2()
    ])

    test_transform = A.Compose([
        A.Resize(img_size, img_size),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2()
    ])

    return train_transform, test_transform

In [None]:
class StableECGLoss(nn.Module):
    """Mean over leads of the MSE between predicted and target signals.

    Only leads present in both dicts contribute. Within a lead, target
    positions that are NaN or infinite are excluded; a lead with no finite
    target sample is skipped entirely.
    """

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, predictions, targets):
        """Return the per-lead MSE averaged over contributing leads (0 if none)."""
        accumulated = 0
        leads_used = 0

        for lead, pred_signal in predictions.items():
            if lead not in targets:
                continue

            target_signal = targets[lead]
            finite = torch.isfinite(target_signal)
            if not finite.any():
                continue

            accumulated += self.mse_loss(pred_signal[finite], target_signal[finite])
            leads_used += 1

        # Guard the division for the case where no lead contributed.
        return accumulated / max(leads_used, 1)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Build the network (each lead head predicts 2500 samples) and move it to `device`.
model = StableECGModel(num_leads=12, signal_length=2500).to(device)

# Adam with weight decay over every model parameter.
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

criterion = StableECGLoss()

# Total parameter count.
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
def stable_collate_fn(batch):
    """Collate dataset items into a batch, forcing every signal to 2500 samples.

    The ``'signals'`` entry becomes a dict of ``(batch, 2500)`` tensors, one
    per lead: NaNs are zeroed, short signals are zero-padded and long ones
    are truncated. Any other entry is stacked when its values are tensors and
    kept as a plain list otherwise.
    """

    def _fit(signal):
        # Work on a copy so the dataset's own array is never touched.
        cleaned = np.nan_to_num(signal.copy(), nan=0.0)
        length = len(cleaned)
        if length < 2500:
            return np.pad(cleaned, (0, 2500 - length), mode='constant')
        if length > 2500:
            return cleaned[:2500]
        return cleaned

    first = batch[0]
    collated = {}

    for key in first.keys():
        if key == 'signals':
            collated[key] = {
                lead: torch.stack([torch.from_numpy(_fit(item['signals'][lead])) for item in batch])
                for lead in first['signals'].keys()
            }
            continue

        values = [item[key] for item in batch]
        collated[key] = torch.stack(values) if isinstance(values[0], torch.Tensor) else values

    return collated

In [None]:
def train_epoch_stable(model, train_loader, criterion, optimizer, device):
    """Run one training epoch and return the mean loss over applied batches.

    Batches whose loss is NaN or infinite are skipped without an optimiser
    step. Gradients are clipped to a global norm of 1.0 before each step.
    Returns ``float('inf')`` when no batch could be applied.
    """
    model.train()
    running_loss = 0
    steps_taken = 0

    progress = tqdm(train_loader, desc="Training")

    for step, batch in enumerate(progress):
        images = batch['image'].to(device)
        lead_signals = batch['signals']

        optimizer.zero_grad()

        predictions = model(images)
        # Move only the leads the model actually predicts onto the device.
        signal_targets = {lead: lead_signals[lead].to(device) for lead in predictions.keys()}

        loss = criterion(predictions, signal_targets)

        if not (torch.isnan(loss) or torch.isinf(loss)):
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            running_loss += loss.item()
            steps_taken += 1

            # Refresh the progress bar every fifth batch index.
            if step % 5 == 0:
                progress.set_postfix({
                    'Loss': f"{loss.item():.4f}",
                    'AvgLoss': f"{running_loss/steps_taken:.4f}"
                })

    if steps_taken > 0:
        return running_loss / steps_taken
    return float('inf')

In [None]:
def evaluate_model_stable(model, val_loader, device):
    """Score the model on a loader with SNR, RMSE and Pearson correlation.

    Every (sample, lead) pair present in both the predictions and the batch
    signals is scored after NaNs are replaced by 0. Scores that are NaN or
    infinite are dropped. Returns a dict with the mean and standard deviation
    of each metric (0 when a metric collected no score).
    """
    model.eval()
    snr_scores, rmse_scores, corr_scores = [], [], []

    def _keep(bucket, score):
        # Only finite scores enter the statistics.
        if np.isfinite(score):
            bucket.append(score)

    def _summary(scores):
        return (np.mean(scores), np.std(scores)) if scores else (0, 0)

    with torch.no_grad():
        for batch in val_loader:
            predictions = model(batch['image'].to(device))
            signals = batch['signals']

            for lead, lead_prediction in predictions.items():
                if lead not in signals:
                    continue

                pred_batch = lead_prediction.cpu().numpy()
                true_batch = signals[lead].cpu().numpy()

                for row in range(pred_batch.shape[0]):
                    pred = np.nan_to_num(pred_batch[row], nan=0.0)
                    true = np.nan_to_num(true_batch[row], nan=0.0)

                    _keep(snr_scores, compute_snr(true, pred))
                    _keep(rmse_scores, compute_rmse(true, pred))
                    _keep(corr_scores, compute_correlation(true, pred))

    snr_mean, snr_std = _summary(snr_scores)
    rmse_mean, rmse_std = _summary(rmse_scores)
    corr_mean, corr_std = _summary(corr_scores)

    return {
        'snr_mean': snr_mean,
        'snr_std': snr_std,
        'rmse_mean': rmse_mean,
        'rmse_std': rmse_std,
        'correlation_mean': corr_mean,
        'correlation_std': corr_std
    }

In [None]:
# Both pipelines resize to 512x512, matching the img_size given to the datasets.
train_transform, test_transform = get_stable_transforms(img_size=512)

# Training set: at most 50 randomly sampled records from the metadata.
train_dataset = StableECGDataset(
    meta_df=train_meta,
    data_dir=TRAIN_DIR,
    split='train',
    transform=train_transform,
    max_samples=50,
    img_size=512
)

# Test set: every row of the test metadata.
test_dataset = StableECGDataset(
    meta_df=test_meta,
    data_dir=TEST_DIR,
    split='test',
    transform=test_transform,
    img_size=512
)

# The custom collate function pads/truncates every lead signal to a fixed length.
train_loader = DataLoader(
    train_dataset, 
    batch_size=12, 
    shuffle=True, 
    num_workers=2,
    collate_fn=stable_collate_fn,
    pin_memory=True
)

# No collate_fn is passed here, so the DataLoader's default collation is used.
test_loader = DataLoader(
    test_dataset, 
    batch_size=2, 
    shuffle=False, 
    num_workers=2,
    pin_memory=True
)

print(f"Training batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")

In [None]:
# Training driver: one pass over train_loader per epoch, followed by an
# evaluation pass and a checkpoint whenever the mean SNR improves.
num_epochs = 2500
best_loss = float('inf')
best_snr = -float('inf')
train_losses = []
performance_metrics = []

print("Starting stable training...")

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("=" * 50)
    
    start_time = time.time()
    
    avg_loss = train_epoch_stable(model, train_loader, criterion, optimizer, device)
    
    # Timing covers the training pass only, not the evaluation below.
    epoch_time = time.time() - start_time
    
    # train_epoch_stable returns inf when every batch was skipped.
    if not np.isinf(avg_loss):
        train_losses.append(avg_loss)
        
        # NOTE: metrics are computed on train_loader, i.e. on the data the
        # model was just trained on; there is no held-out validation split.
        metrics = evaluate_model_stable(model, train_loader, device)
        performance_metrics.append(metrics)
        
        print(f"Average Loss: {avg_loss:.4f} | Time: {epoch_time:.2f}s")
        print(f"SNR: {metrics['snr_mean']:.2f} ± {metrics['snr_std']:.2f} dB")
        print(f"RMSE: {metrics['rmse_mean']:.4f} ± {metrics['rmse_std']:.4f}")
        print(f"Correlation: {metrics['correlation_mean']:.4f} ± {metrics['correlation_std']:.4f}")
        
        # Checkpoint selection is driven by mean SNR only.
        if metrics['snr_mean'] > best_snr:
            best_snr = metrics['snr_mean']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
                'snr': best_snr,
                'metrics': metrics
            }, 'best_stable_model.pth')
            print(f"New best model saved with SNR: {best_snr:.2f} dB")
        
        # best_loss is tracked for reference; it does not affect checkpointing.
        if avg_loss < best_loss:
            best_loss = avg_loss
    else:
        print(f"Epoch failed - skipping")

In [None]:
def safe_signal_processing(signal, fs, target_lead):
    """Fit a predicted signal to its expected length, smooth it and clip it.

    Steps: replace NaN/inf with 0; cut or edge-pad to
    ``floor(fs * duration)`` samples (10 s for lead ``'II'``, 2.5 s
    otherwise); apply a Savitzky-Golay filter (window 11, order 3) when the
    signal is longer than the window; clip to ``[-2, 2]``.

    Args:
        signal: 1-D array-like of predicted samples; may be empty.
        fs: Sampling rate. Any value accepted by ``float()`` works, including
            NumPy scalars and single-element tensors.
        target_lead: Lead name; only ``'II'`` changes the duration.

    Returns:
        1-D NumPy array with the expected number of samples.
    """
    signal = np.nan_to_num(np.asarray(signal), nan=0.0, posinf=0.0, neginf=0.0)

    duration = 10.0 if target_lead == 'II' else 2.5
    expected_length = max(int(math.floor(float(fs) * duration)), 0)

    if len(signal) == 0:
        # Edge padding needs at least one sample; fall back to a flat line.
        signal = np.zeros(expected_length, dtype=signal.dtype)
    elif len(signal) > expected_length:
        signal = signal[:expected_length]
    elif len(signal) < expected_length:
        signal = np.pad(signal, (0, expected_length - len(signal)), mode='edge')

    window_length, poly_order = 11, 3
    if len(signal) > window_length:
        try:
            signal = savgol_filter(signal, window_length, poly_order)
        except (ValueError, np.linalg.LinAlgError):
            # Smoothing is best-effort: keep the unsmoothed signal on failure.
            pass

    return np.clip(signal, -2.0, 2.0)

def predict_signal_stable(model, image, target_lead, fs, device):
    """Predict one lead for a single image and post-process the result.

    ``image`` may be an HWC NumPy array (scaled to [0, 1] and converted to
    CHW), a CHW tensor (given a batch dimension) or an already batched
    tensor. The first row of the requested lead's prediction is NaN-cleaned
    and passed through ``safe_signal_processing``.
    """
    model.eval()

    with torch.no_grad():
        if isinstance(image, np.ndarray):
            batch = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
        elif image.dim() == 3:
            batch = image.unsqueeze(0)
        else:
            batch = image

        predictions = model(batch.to(device))

        raw = np.nan_to_num(predictions[target_lead][0].cpu().numpy(), nan=0.0)

        return safe_signal_processing(raw, fs, target_lead)

In [None]:
def _to_python_scalar(value):
    """Unwrap a single-element tensor or NumPy scalar into a plain Python value."""
    if isinstance(value, (torch.Tensor, np.generic)):
        return value.item()
    return value


def generate_stable_submission(model, test_loader, device):
    """Predict every test item and return the submission as a long DataFrame.

    Each predicted sample becomes one row with ``id`` formatted as
    ``"<ecg_id>_<row>_<lead>"`` and a float ``value``. If prediction fails
    for an item, the error is printed and the item is filled with zeros of
    the expected length instead.

    Args:
        model: Network passed through to ``predict_signal_stable``.
        test_loader: Iterable of batches with ``'image'``, ``'ecg_id'``,
            ``'fs'`` and ``'target_lead'`` entries.
        device: Device passed through to ``predict_signal_stable``.

    Returns:
        DataFrame with columns ``id`` and ``value`` (empty, but with both
        columns, when the loader yields nothing).
    """
    model.eval()
    all_predictions = []

    print("Generating stable submission...")

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Processing"):
            images = batch['image']
            ecg_ids = batch['ecg_id']
            fs_values = batch['fs']
            target_leads = batch['target_lead']

            for i in range(len(images)):
                # Default collation turns numeric columns into tensors; unwrap
                # them so the id reads "123_0_II", not "tensor(123)_0_II".
                ecg_id = _to_python_scalar(ecg_ids[i])
                fs = _to_python_scalar(fs_values[i])
                target_lead = target_leads[i]
                image = images[i]

                try:
                    signal = predict_signal_stable(model, image, target_lead, fs, device)

                    for row_id, value in enumerate(signal):
                        all_predictions.append({
                            'id': f"{ecg_id}_{row_id}_{target_lead}",
                            'value': float(value)
                        })
                except Exception as e:
                    # Best-effort: keep the submission complete with a flat line.
                    print(f"Error processing {ecg_id}, {target_lead}: {e}")
                    duration = 10.0 if target_lead == 'II' else 2.5
                    expected_length = int(math.floor(fs * duration))
                    for row_id in range(expected_length):
                        all_predictions.append({
                            'id': f"{ecg_id}_{row_id}_{target_lead}",
                            'value': 0.0
                        })

    return pd.DataFrame(all_predictions, columns=['id', 'value'])

In [None]:
# Run inference over the test loader and collect the long-format submission.
print("Creating stable submission...")
submission_df = generate_stable_submission(model, test_loader, device)

print(f"Generated {len(submission_df)} predictions")

# NOTE: these metrics are computed on train_loader (the training data), with
# the model as it stands after the last epoch.
final_metrics = evaluate_model_stable(model, train_loader, device)
print(f"\nFinal Performance Metrics:")
print(f"SNR: {final_metrics['snr_mean']:.2f} ± {final_metrics['snr_std']:.2f} dB")
print(f"RMSE: {final_metrics['rmse_mean']:.4f} ± {final_metrics['rmse_std']:.4f}")
print(f"Correlation: {final_metrics['correlation_mean']:.4f} ± {final_metrics['correlation_std']:.4f}")

# Summary statistics of the submitted values.
print(f"\nValue statistics:")
print(f"Min: {submission_df['value'].min():.6f}")
print(f"Max: {submission_df['value'].max():.6f}")
print(f"Mean: {submission_df['value'].mean():.6f}")
print(f"Std: {submission_df['value'].std():.6f}")

In [None]:
# Save submission as both Parquet and CSV
submission_df.to_parquet('submission.parquet', index=False)
submission_df.to_csv('submission.csv', index=False)
print("Submission saved as 'submission.parquet' and 'submission.csv'")

# Show lead distribution
# Ids have the form "<ecg_id>_<row>_<lead>". Taking the text after the LAST
# underscore stays correct even if the ECG id itself contains underscores,
# and avoids expanding every id into a temporary multi-column frame.
lead_counts = submission_df['id'].str.rsplit('_', n=1).str[-1].value_counts()
print("\nLead distribution:")
for lead, count in lead_counts.items():
    print(f"  {lead}: {count} samples")

print(f"\nSubmission completed successfully!")
print(f"Best SNR achieved: {best_snr:.2f} dB")
