# TNC Linear Evaluation Notebook

This notebook evaluates a pre-trained TNC (Temporal Neighborhood Coding) encoder with a linear probe on ECG waveform data. A logistic-regression classifier is trained on the frozen encoder's features. We report AUPRC (area under the precision-recall curve) and accuracy to assess the quality of the learned representations.

## Overview
1. Load pre-trained TNC encoder from checkpoint
2. Extract features from ECG waveform data
3. Train linear classifier on extracted features
4. Evaluate performance with AUPRC and Accuracy metrics
5. Visualize results with a confusion matrix and precision-recall curves

## 1. Mount Google Drive and Set Up Paths

In [None]:
# Mount Google Drive so the checkpoint, data, and plot folders are reachable
from google.colab import drive
import os
import sys

drive.mount('/content/drive')

# Root of the user's Drive; every other path below is built from it.
DRIVE_PATH = '/content/drive/MyDrive'  # Adjust this path as needed
CHECKPOINT_PATH, DATA_PATH, PLOTS_PATH = (
    os.path.join(DRIVE_PATH, sub) for sub in ('ckpt', 'data', 'plots')
)

# Later cells save figures here, so make sure the folder is present.
os.makedirs(PLOTS_PATH, exist_ok=True)

_path_table = (
    ("Checkpoint", CHECKPOINT_PATH),
    ("Data", DATA_PATH),
    ("Plots", PLOTS_PATH),
)

# Echo the resolved locations first ...
for _label, _path in _path_table:
    print(f"{_label} path: {_path}")

# ... then report whether each one actually exists on the mounted Drive.
for _label, _path in _path_table:
    print(f"{_label} exists: {os.path.exists(_path)}")

## 2. Import Required Libraries

In [None]:
import torch
import torch.nn as nn
import numpy as np
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (accuracy_score, precision_recall_curve, 
                           average_precision_score, roc_auc_score,
                           confusion_matrix, classification_report)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import warnings
# Silence every warning for cleaner notebook output.
# NOTE(review): this also hides scikit-learn ConvergenceWarning/FutureWarning
# (e.g. deprecations) — consider narrowing the filter.
warnings.filterwarnings('ignore')

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Fix the PyTorch and NumPy global RNGs so runs are repeatable.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(42)

## 3. Define TNC Model Architecture

In [None]:
class WFEncoder(nn.Module):
    """TNC waveform encoder for ECG data.

    A 1-D CNN feature extractor (``features``) followed by a fully connected
    projection (``fc``) to an ``encoding_size``-dimensional representation.
    When ``classify`` is True, a dropout + linear head (``classifier``) maps
    that encoding to ``n_classes`` logits, and ``forward`` returns the logits.

    Args:
        encoding_size: dimension of the output encoding.
        classify: if True, ``forward`` returns classifier logits instead of
            the encoding.
        n_classes: number of output classes; required when ``classify`` is True.

    Raises:
        ValueError: if ``classify`` is True and ``n_classes`` is None.

    NOTE: the attribute names and the order of layers inside each
    ``nn.Sequential`` determine the ``state_dict`` keys (``features.0.weight``,
    ``fc.1.weight``, ...). Keep them unchanged so existing checkpoints still load.
    """
    def __init__(self, encoding_size=64, classify=False, n_classes=None):
        super(WFEncoder, self).__init__()
        
        self.encoding_size = encoding_size
        self.n_classes = n_classes
        self.classify = classify
        self.classifier = None
        
        # Optional classification head on top of the encoding.
        if self.classify:
            if self.n_classes is None:
                raise ValueError('Need to specify the number of output classes for the encoder')
            else:
                self.classifier = nn.Sequential(
                    nn.Dropout(0.5),
                    nn.Linear(self.encoding_size, self.n_classes)
                )
                nn.init.xavier_uniform_(self.classifier[1].weight)

        # Convolutional feature extractor. The first Conv1d takes 2 input
        # channels. Three MaxPool1d(2, 2) stages shrink the time axis roughly 8x.
        self.features = nn.Sequential(
            nn.Conv1d(2, 64, kernel_size=4, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(64, eps=0.001),
            nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(64, eps=0.001),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(128, eps=0.001),
            nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(128, eps=0.001),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(256, eps=0.001),
            nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(256, eps=0.001),
            nn.MaxPool1d(kernel_size=2, stride=2)
        )

        # Fully connected projection from flattened CNN features to the encoding.
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            # 79872 = 256 channels x 312 time steps. That is the flattened size
            # `features` produces for an input length of 2500
            # (2500 -> 2499 -> 1249 -> 624 -> 312). Other input lengths need
            # a different in_features here.
            nn.Linear(79872, 2048),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(2048, eps=0.001),
            nn.Linear(2048, self.encoding_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of waveform windows.

        Args:
            x: input batch. Assumed shape is (batch, 2, 2500): 2 channels per
                the first Conv1d, and length 2500 to match ``fc``'s
                in_features. TODO confirm against the data pipeline.

        Returns:
            (batch, encoding_size) encodings, or (batch, n_classes) logits
            when ``classify`` is True.
        """
        x = self.features(x)
        # Flatten (batch, channels, time) -> (batch, channels * time) for fc.
        x = x.view(x.size(0), -1)
        encoding = self.fc(x)
        
        if self.classify:
            c = self.classifier(encoding)
            return c
        else:
            return encoding

print("TNC WFEncoder model defined successfully!")

## 4. Load Trained Encoder from Checkpoint

In [None]:
# Load the trained TNC encoder from its checkpoint and put it in eval mode.
waveform_ckpt_dir = os.path.join(CHECKPOINT_PATH, 'waveform')
checkpoint_file = os.path.join(waveform_ckpt_dir, 'checkpoint_0.pth.tar')

print(f"Loading checkpoint from: {checkpoint_file}")
print(f"Checkpoint exists: {os.path.exists(checkpoint_file)}")

if not os.path.exists(checkpoint_file):
    print("ERROR: Checkpoint file not found!")
    print("Make sure your checkpoint is saved as: ckpt/waveform/checkpoint_0.pth.tar")
    print("Available files in checkpoint directory:")
    if os.path.exists(waveform_ckpt_dir):
        print(os.listdir(waveform_ckpt_dir))
    else:
        print("Waveform directory doesn't exist")
    # Every later cell needs `encoder`. Stop here with a clear error instead
    # of letting the notebook fail further down with a NameError.
    raise FileNotFoundError(f"TNC checkpoint not found: {checkpoint_file}")

# map_location places the stored tensors on the current device (e.g. a
# GPU-saved checkpoint can be loaded on a CPU-only runtime).
checkpoint = torch.load(checkpoint_file, map_location=device)
print("Checkpoint loaded successfully!")
print(f"Available keys in checkpoint: {list(checkpoint.keys())}")

# Initialize encoder. load_state_dict raises if encoding_size=64 (or any
# layer shape) does not match what the checkpoint was trained with.
encoder = WFEncoder(encoding_size=64)
encoder.load_state_dict(checkpoint['encoder_state_dict'])
encoder = encoder.to(device)
encoder.eval()  # Set to evaluation mode (dropout off, BatchNorm uses running stats)

print("Encoder loaded and set to evaluation mode!")
print(f"Encoder is on device: {next(encoder.parameters()).device}")

# Print some checkpoint info if available
if 'best_accuracy' in checkpoint:
    print(f"Best training accuracy: {checkpoint['best_accuracy']:.3f}")
if 'epoch' in checkpoint:
    print(f"Training epoch: {checkpoint['epoch']}")

## 5. Load and Prepare Training and Test Data

In [None]:
# Load ECG waveform data (pickled arrays) from the processed data folder.
wf_datapath = os.path.join(DATA_PATH, 'waveform_data', 'processed')

# Check if data files exist
x_train_file = os.path.join(wf_datapath, 'x_train.pkl')
y_train_file = os.path.join(wf_datapath, 'state_train.pkl')
x_test_file = os.path.join(wf_datapath, 'x_test.pkl')
y_test_file = os.path.join(wf_datapath, 'state_test.pkl')

print(f"Data directory: {wf_datapath}")
print(f"x_train exists: {os.path.exists(x_train_file)}")
print(f"y_train exists: {os.path.exists(y_train_file)}")
print(f"x_test exists: {os.path.exists(x_test_file)}")
print(f"y_test exists: {os.path.exists(y_test_file)}")

# Test data is optional. Both stay None unless both test files load below,
# in which case the next cell falls back to splitting the training data.
x_test, y_test = None, None

# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# Only load pickles you created yourself or otherwise trust.
try:
    with open(x_train_file, 'rb') as f:
        x_train = pickle.load(f)
    with open(y_train_file, 'rb') as f:
        y_train = pickle.load(f)
    
    print("Training data loaded successfully!")
    print(f"x_train shape: {x_train.shape}")
    print(f"y_train shape: {y_train.shape}")
    
    # Also load test data if available
    if os.path.exists(x_test_file) and os.path.exists(y_test_file):
        with open(x_test_file, 'rb') as f:
            x_test = pickle.load(f)
        with open(y_test_file, 'rb') as f:
            y_test = pickle.load(f)
        print("Test data loaded successfully!")
        print(f"x_test shape: {x_test.shape}")
        print(f"y_test shape: {y_test.shape}")
    else:
        print("Test data not found, will use train data split")
        
except Exception as e:
    print(f"Error loading data: {e}")
    print("Please check your data file paths and formats")
    # Re-raise: later cells need x_train / y_train. Continuing would only
    # fail later with a less helpful NameError.
    raise

In [None]:
# Prepare data for linear classification
# Number of time steps per non-overlapping window fed to the encoder.
# NOTE: WFEncoder's first fc layer, nn.Linear(79872, 2048), only matches an
# input length of 2500. Changing this also requires changing the encoder.
window_size = 2500  # Standard window size for waveform data

def prepare_windowed_data(x_data, y_data, window_size):
    """Cut continuous recordings into non-overlapping fixed-length windows.

    Trailing time steps that do not fill a whole window are dropped. Windows
    are stacked window-major: every recording's first window, then every
    recording's second window, and so on.

    Args:
        x_data: array-like of shape (n_recordings, n_channels, T).
        y_data: per-time-step integer labels of shape (n_recordings, T).
        window_size: number of time steps per window.

    Returns:
        ``(x_windowed, y_windowed)``. ``x_windowed`` has shape
        ``(n_windows * n_recordings, n_channels, window_size)``. ``y_windowed``
        holds one label per window: the most frequent per-step label, with
        ties going to the smallest label.

    Raises:
        ValueError: if ``window_size`` is not positive or exceeds ``T``
            (previously this surfaced as a ZeroDivisionError from np.split).
    """
    x_data = np.asarray(x_data)
    y_data = np.asarray(y_data)
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")

    T = x_data.shape[-1]
    n_windows = T // window_size
    if n_windows == 0:
        raise ValueError(
            f"window_size ({window_size}) is longer than the recordings (T={T})"
        )
    usable = window_size * n_windows

    # Split the time axis into n_windows equal chunks, then stack the chunks
    # along the sample axis.
    x_windowed = np.concatenate(np.split(x_data[:, :, :usable], n_windows, -1), 0)
    y_steps = np.concatenate(np.split(y_data[:, :usable], n_windows, -1), 0)

    # Majority vote of the per-step labels inside each window.
    y_windowed = np.array([np.bincount(yy.astype(int)).argmax() for yy in y_steps])

    return x_windowed, y_windowed

# Window the training recordings and summarise the resulting labels.
x_train_windowed, y_train_windowed = prepare_windowed_data(x_train, y_train, window_size)

print(f"Windowed training data shape: {x_train_windowed.shape}")
print(f"Windowed training labels shape: {y_train_windowed.shape}")
print(f"Number of classes: {len(np.unique(y_train_windowed))}")
print(f"Class distribution: {np.bincount(y_train_windowed.astype(int))}")

have_test_split = x_test is not None and y_test is not None

if have_test_split:
    # Dedicated test recordings exist: window them the same way.
    x_test_windowed, y_test_windowed = prepare_windowed_data(x_test, y_test, window_size)
    print(f"Windowed test data shape: {x_test_windowed.shape}")
    print(f"Windowed test labels shape: {y_test_windowed.shape}")
else:
    # No test files: hold out the last 30% of windows, without shuffling.
    # Windows are window-major, so this is roughly a split along time.
    cut = int(0.7 * len(x_train_windowed))
    x_train_windowed, x_test_windowed = x_train_windowed[:cut], x_train_windowed[cut:]
    y_train_windowed, y_test_windowed = y_train_windowed[:cut], y_train_windowed[cut:]

    print(f"Split data - Train: {x_train_windowed.shape}, Test: {x_test_windowed.shape}")

## 6. Extract Features Using Trained Encoder

In [None]:
def extract_features(encoder, data, batch_size=32, device=None):
    """Encode ``data`` in mini-batches and return the stacked encodings.

    The encoder is switched to eval mode (and left there). It runs under
    ``torch.no_grad()``, so no autograd graph is built.

    Args:
        encoder: torch module mapping a float32 batch to a 2-D
            (batch, feature) tensor.
        data: array-like indexable along its first axis, e.g. a NumPy array
            of windows.
        batch_size: number of samples per forward pass. Must be positive.
        device: device for the input batches. Defaults to the device of the
            encoder's parameters (CPU for a parameterless module), so inputs
            always live where the encoder does.

    Returns:
        NumPy array with one row of features per sample in ``data``.

    Raises:
        ValueError: if ``data`` is empty or ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if len(data) == 0:
        raise ValueError("data is empty; nothing to encode")

    if device is None:
        first_param = next(encoder.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    encoder.eval()
    features_list = []

    with torch.no_grad():
        for start in range(0, len(data), batch_size):
            # as_tensor converts (e.g. float64 -> float32) and moves in one step.
            batch = torch.as_tensor(
                data[start:start + batch_size], dtype=torch.float32, device=device
            )
            features_list.append(encoder(batch).cpu().numpy())

    return np.vstack(features_list)

print("Extracting features from training data...")
train_features = extract_features(encoder, x_train_windowed)

print("Extracting features from test data...")
test_features = extract_features(encoder, x_test_windowed)

print(f"Training features shape: {train_features.shape}")
print(f"Test features shape: {test_features.shape}")
print(f"Feature dimension: {train_features.shape[1]}")

# Sanity check: non-finite values would break the scaler and the classifier.
for split_name, feats in (("Training", train_features), ("Test", test_features)):
    print(f"{split_name} features - NaN count: {np.isnan(feats).sum()}")
    print(f"{split_name} features - Inf count: {np.isinf(feats).sum()}")

# Basic statistics of the training encodings
print(f"Training features - Min: {train_features.min():.3f}, Max: {train_features.max():.3f}")
print(f"Training features - Mean: {train_features.mean():.3f}, Std: {train_features.std():.3f}")

# Left: histogram of all feature values pooled together.
# Right: the raw feature vectors of the first 100 training windows.
fig, (hist_ax, vec_ax) = plt.subplots(1, 2, figsize=(12, 4))

hist_ax.hist(train_features.flatten(), bins=50, alpha=0.7, label='Training Features')
hist_ax.set_xlabel('Feature Value')
hist_ax.set_ylabel('Frequency')
hist_ax.set_title('Distribution of Extracted Features')
hist_ax.legend()

vec_ax.plot(train_features[:100].T, alpha=0.3)
vec_ax.set_xlabel('Feature Dimension')
vec_ax.set_ylabel('Feature Value')
vec_ax.set_title('Feature Vectors (First 100 samples)')

fig.tight_layout()
fig.savefig(os.path.join(PLOTS_PATH, 'feature_distribution.png'), dpi=300, bbox_inches='tight')
plt.show()

## 7. Train Linear Classifier

In [None]:
# Normalize features for better classifier performance. The scaler is fit on
# the training features only, then applied unchanged to the test features.
scaler = StandardScaler()
train_features_scaled = scaler.fit_transform(train_features)
test_features_scaled = scaler.transform(test_features)

print(f"Features normalized - Train mean: {train_features_scaled.mean():.3f}, Test mean: {test_features_scaled.mean():.3f}")

# Train linear classifier (Logistic Regression)
print("Training linear classifier...")

from sklearn.model_selection import cross_val_score
from sklearn.multiclass import OneVsRestClassifier


def _make_linear_probe(C):
    """Return a one-vs-rest logistic regression with inverse regularisation C.

    Same model as ``LogisticRegression(multi_class='ovr')``. That argument is
    deprecated since scikit-learn 1.5, and the warning was hidden by the
    global warnings filter. The explicit wrapper avoids relying on it.
    """
    return OneVsRestClassifier(LogisticRegression(C=C, max_iter=1000, random_state=42))


# Pick the regularisation strength by 5-fold cross-validated accuracy.
C_values = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0]
best_score = -np.inf
best_C = 1.0

for C in C_values:
    classifier = _make_linear_probe(C)
    scores = cross_val_score(classifier, train_features_scaled, y_train_windowed, cv=5, scoring='accuracy')
    mean_score = scores.mean()
    print(f"C={C}: CV Accuracy = {mean_score:.4f} ± {scores.std():.4f}")
    
    if mean_score > best_score:
        best_score = mean_score
        best_C = C

print(f"\nBest regularization parameter: C = {best_C}")
print(f"Best cross-validation accuracy: {best_score:.4f}")

# Train final classifier on all training windows with the best C.
final_classifier = _make_linear_probe(best_C)
final_classifier.fit(train_features_scaled, y_train_windowed)

print("Linear classifier training completed!")

## 8. Make Predictions on Test Set

In [None]:
# Make predictions on test set
print("Making predictions on test set...")

# Hard class predictions (used for accuracy and the confusion matrix)
y_pred = final_classifier.predict(test_features_scaled)

# Per-class probabilities; column j corresponds to final_classifier.classes_[j].
# Used for AUPRC / ROC AUC.
y_pred_proba = final_classifier.predict_proba(test_features_scaled)

print("Predictions completed!")
print(f"Test set size: {len(y_test_windowed)}")
print(f"Prediction shape: {y_pred.shape}")
print(f"Probability shape: {y_pred_proba.shape}")

# Check prediction distribution
print("\nPrediction distribution:")
unique_pred, counts_pred = np.unique(y_pred, return_counts=True)
for class_idx, count in zip(unique_pred, counts_pred):
    print(f"Class {class_idx}: {count} predictions ({count/len(y_pred)*100:.1f}%)")

print("\nTrue label distribution:")
unique_true, counts_true = np.unique(y_test_windowed, return_counts=True)
for class_idx, count in zip(unique_true, counts_true):
    print(f"Class {class_idx}: {count} samples ({count/len(y_test_windowed)*100:.1f}%)")

## 9. Calculate AUPRC and Accuracy Scores

In [None]:
# Calculate performance metrics
print("=== PERFORMANCE EVALUATION ===\n")

# 1. Accuracy Score
accuracy = accuracy_score(y_test_windowed, y_pred)
print(f"🎯 ACCURACY: {accuracy:.4f} ({accuracy*100:.2f}%)")

# 2. AUPRC (Area Under Precision-Recall Curve)
# Column j of y_pred_proba scores final_classifier.classes_[j]. Take the class
# list from the classifier, not from the labels that happen to occur in the
# test set. Otherwise a class missing from the test windows (or non-contiguous
# labels) misaligns the probability columns or wrongly triggers the binary branch.
class_labels = final_classifier.classes_
n_classes = len(class_labels)

if n_classes == 2:
    # Binary classification: column 1 is the probability of class_labels[1].
    auprc = average_precision_score(y_test_windowed, y_pred_proba[:, 1],
                                    pos_label=class_labels[1])
    print(f"📈 AUPRC (Binary): {auprc:.4f}")
else:
    # Multi-class classification - one-vs-rest AUPRC, micro and macro averaged
    from sklearn.preprocessing import label_binarize

    # One indicator column per classifier class, aligned with y_pred_proba.
    y_test_bin = label_binarize(y_test_windowed, classes=class_labels)

    # Micro-average AUPRC
    auprc_micro = average_precision_score(y_test_bin, y_pred_proba, average='micro')

    # Macro-average AUPRC
    auprc_macro = average_precision_score(y_test_bin, y_pred_proba, average='macro')

    # Per-class AUPRC. A class with no positive test windows gets an sklearn
    # warning and an uninformative value.
    auprc_per_class = [
        average_precision_score(y_test_bin[:, i], y_pred_proba[:, i])
        for i in range(n_classes)
    ]

    print(f"📈 AUPRC (Micro-avg): {auprc_micro:.4f}")
    print(f"📈 AUPRC (Macro-avg): {auprc_macro:.4f}")

    for label, auprc_val in zip(class_labels, auprc_per_class):
        print(f"📈 AUPRC Class {label}: {auprc_val:.4f}")

# 3. Additional metrics
print("\n=== ADDITIONAL METRICS ===")

# ROC AUC; undefined (ValueError) e.g. when the test set holds a single class.
try:
    if n_classes == 2:
        roc_auc = roc_auc_score(y_test_windowed, y_pred_proba[:, 1])
        print(f"🔄 ROC AUC (Binary): {roc_auc:.4f}")
    else:
        roc_auc_micro = roc_auc_score(y_test_bin, y_pred_proba, average='micro', multi_class='ovr')
        roc_auc_macro = roc_auc_score(y_test_bin, y_pred_proba, average='macro', multi_class='ovr')
        print(f"🔄 ROC AUC (Micro-avg): {roc_auc_micro:.4f}")
        print(f"🔄 ROC AUC (Macro-avg): {roc_auc_macro:.4f}")
except Exception as e:
    print(f"⚠️ Could not calculate ROC AUC: {e}")

# Classification report
print("\n=== CLASSIFICATION REPORT ===")
print(classification_report(y_test_windowed, y_pred, digits=4))

## 10. Visualize Results and Performance Metrics

In [None]:
# Create comprehensive visualizations
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Labels occurring in either the true or the predicted test labels, sorted.
# This is the label order confusion_matrix uses by default, so tick labels
# show real class labels rather than matrix positions.
observed_labels = np.unique(np.concatenate([np.asarray(y_test_windowed), np.asarray(y_pred)]))

# 1. Confusion Matrix
cm = confusion_matrix(y_test_windowed, y_pred, labels=observed_labels)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[0,0],
            xticklabels=observed_labels, yticklabels=observed_labels)
axes[0,0].set_title('Confusion Matrix')
axes[0,0].set_xlabel('Predicted Label')
axes[0,0].set_ylabel('True Label')

# 2. Precision-Recall Curves
if n_classes == 2:
    # Binary case
    precision, recall, _ = precision_recall_curve(y_test_windowed, y_pred_proba[:, 1])
    axes[0,1].plot(recall, precision, linewidth=2, label=f'AUPRC = {auprc:.3f}')
    axes[0,1].set_xlabel('Recall')
    axes[0,1].set_ylabel('Precision')
    axes[0,1].set_title('Precision-Recall Curve')
    axes[0,1].legend()
    axes[0,1].grid(True, alpha=0.3)
else:
    # Multi-class case - one-vs-rest curve per probability column
    for i in range(min(n_classes, y_pred_proba.shape[1])):
        precision, recall, _ = precision_recall_curve(y_test_bin[:, i], y_pred_proba[:, i])
        axes[0,1].plot(recall, precision, linewidth=2, 
                      label=f'Class {i} (AUPRC = {auprc_per_class[i]:.3f})')
    axes[0,1].set_xlabel('Recall')
    axes[0,1].set_ylabel('Precision')
    axes[0,1].set_title('Precision-Recall Curves (Multi-class)')
    axes[0,1].legend()
    axes[0,1].grid(True, alpha=0.3)

# 3. Class Distribution Comparison (counted per actual label value, so
# non-contiguous labels are handled correctly)
x_pos = np.arange(len(observed_labels))
width = 0.35

true_counts = [np.sum(y_test_windowed == c) for c in observed_labels]
pred_counts = [np.sum(y_pred == c) for c in observed_labels]

axes[1,0].bar(x_pos - width/2, true_counts, width, label='True', alpha=0.8)
axes[1,0].bar(x_pos + width/2, pred_counts, width, label='Predicted', alpha=0.8)
axes[1,0].set_xlabel('Class')
axes[1,0].set_ylabel('Count')
axes[1,0].set_title('True vs Predicted Class Distribution')
axes[1,0].set_xticks(x_pos)
axes[1,0].set_xticklabels(observed_labels)
axes[1,0].legend()

# 4. Performance Metrics Summary
metrics_names = ['Accuracy']
metrics_values = [accuracy]

if n_classes == 2:
    metrics_names.extend(['AUPRC', 'ROC AUC'])
    # ROC AUC may be missing if the metrics cell could not compute it.
    metrics_values.extend([auprc, roc_auc if 'roc_auc' in locals() else 0])
else:
    metrics_names.extend(['AUPRC (Micro)', 'AUPRC (Macro)'])
    metrics_values.extend([auprc_micro, auprc_macro])

bars = axes[1,1].bar(metrics_names, metrics_values, color=['skyblue', 'lightcoral', 'lightgreen'][:len(metrics_values)])
axes[1,1].set_ylabel('Score')
axes[1,1].set_title('Performance Metrics Summary')
axes[1,1].set_ylim(0, 1)

# Add value labels on bars
for bar, value in zip(bars, metrics_values):
    height = bar.get_height()
    axes[1,1].text(bar.get_x() + bar.get_width()/2., height + 0.01,
                   f'{value:.3f}', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.savefig(os.path.join(PLOTS_PATH, 'linear_evaluation_results.png'), dpi=300, bbox_inches='tight')
plt.show()

print(f"\n📊 Plots saved to: {PLOTS_PATH}")

In [None]:
# Save detailed results to file (JSON needs plain Python floats, hence float()).
results_summary = {
    'model': 'TNC_Linear_Evaluation',
    'encoder_checkpoint': checkpoint_file,
    'test_samples': len(y_test_windowed),
    'feature_dimension': train_features.shape[1],
    'n_classes': n_classes,
    'accuracy': float(accuracy),
    'best_regularization_C': float(best_C),
    'cross_val_accuracy': float(best_score)
}

if n_classes == 2:
    results_summary['auprc'] = float(auprc)
    # ROC AUC is only present if the metrics cell managed to compute it.
    if 'roc_auc' in locals():
        results_summary['roc_auc'] = float(roc_auc)
else:
    results_summary['auprc_micro'] = float(auprc_micro)
    results_summary['auprc_macro'] = float(auprc_macro)
    results_summary['auprc_per_class'] = [float(x) for x in auprc_per_class]

# Save results
import json
results_file = os.path.join(PLOTS_PATH, 'linear_evaluation_results.json')
with open(results_file, 'w') as f:
    json.dump(results_summary, f, indent=2)

print(f"\n💾 Results saved to: {results_file}")

# Print final summary
print("\n" + "=" * 50)
print("🎯 FINAL EVALUATION SUMMARY")
print("=" * 50)
print("Model: TNC Linear Evaluation")
print(f"Test Samples: {len(y_test_windowed):,}")
print(f"Feature Dimension: {train_features.shape[1]}")
print(f"Number of Classes: {n_classes}")
print("\n📊 MAIN METRICS:")
print(f"   • Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
if n_classes == 2:
    print(f"   • AUPRC: {auprc:.4f}")
else:
    print(f"   • AUPRC (Micro): {auprc_micro:.4f}")
    print(f"   • AUPRC (Macro): {auprc_macro:.4f}")

print("\n✅ Evaluation completed successfully!")
print(f"📁 All results and plots saved to: {PLOTS_PATH}")