# TNC Original Codebase Classification Evaluation

This notebook reproduces the **same** classification evaluation as the original TNC codebase. It follows the `WFClassificationExperiment` approach: the pre-trained `WFEncoder` is frozen, and the original `WFClassifier` is trained on top of its encodings.

## Approach
- Uses original `WFEncoder` and `WFClassifier` from the codebase
- Follows the exact data preprocessing and windowing approach
- Uses the same training loop and metrics (accuracy, AUC, AUPRC)
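- Implements the original evaluation framework for a fair comparison
- Adds a prototypical-network (few-shot) evaluation in Sections 8–9, which is not part of the original codebase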

## 1. Mount Google Drive and Setup

In [None]:
# Mount Google Drive
from google.colab import drive
import os
import sys

drive.mount('/content/drive')

# Everything lives under the root of "My Drive"; change DRIVE_PATH if your
# ckpt/, data/ and plots/ folders are somewhere else.
DRIVE_PATH = '/content/drive/MyDrive'
CHECKPOINT_PATH, DATA_PATH, PLOTS_PATH = (
    os.path.join(DRIVE_PATH, folder) for folder in ('ckpt', 'data', 'plots')
)

# Figures are saved later in the notebook, so make sure their folder exists now.
os.makedirs(PLOTS_PATH, exist_ok=True)

print(f"Checkpoint path: {CHECKPOINT_PATH}")
print(f"Data path: {DATA_PATH}")
print(f"Plots path: {PLOTS_PATH}")

# Sanity check: report which of the expected locations are actually present.
print(f"Checkpoint exists: {os.path.exists(CHECKPOINT_PATH)}")
print(f"Data exists: {os.path.exists(DATA_PATH)}")
print(f"Plots exists: {os.path.exists(PLOTS_PATH)}")

## 2. Import Original Libraries and Define Models

In [None]:
# Import libraries exactly as in original codebase
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pickle
import pandas as pd
import random

from sklearn.metrics import roc_auc_score, confusion_matrix, average_precision_score

# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Seed every RNG the notebook draws from (PyTorch, NumPy, Python's `random`)
# so shuffling, weight initialisation and episode sampling are repeatable.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(42)

In [None]:
# Define original model classes exactly as in the codebase

class WFEncoder(nn.Module):
    """Original WFEncoder from TNC codebase.

    Convolutional encoder for 2-channel 1-D signals. ``self.features`` is a
    stack of Conv1d / ELU / BatchNorm1d blocks with three 2x max-pools, and
    ``self.fc`` maps the flattened features to an ``encoding_size`` vector.
    With ``classify=True`` an extra Dropout + Linear head (``self.classifier``)
    turns that encoding into ``n_classes`` logits.

    Args:
        encoding_size: Size of the encoding produced by ``self.fc``.
        classify: If True, ``forward`` returns class logits instead of the encoding.
        n_classes: Number of output classes; required when ``classify`` is True.

    Raises:
        ValueError: If ``classify`` is True and ``n_classes`` is None.

    NOTE: layer order and the attribute names ``features``, ``fc`` and
    ``classifier`` determine the state_dict keys (e.g. ``features.0.weight``,
    ``fc.1.weight``), so they must not change or saved checkpoints will no
    longer load.
    """
    def __init__(self, encoding_size=64, classify=False, n_classes=None):
        super(WFEncoder, self).__init__()
        
        self.encoding_size = encoding_size
        self.n_classes = n_classes
        self.classify = classify
        self.classifier = None
        
        if self.classify:
            if self.n_classes is None:
                raise ValueError('Need to specify the number of output classes for the encoder')
            else:
                self.classifier = nn.Sequential(
                    nn.Dropout(0.5),
                    nn.Linear(self.encoding_size, self.n_classes)
                )
                # Index 1 is the Linear layer (index 0 is the Dropout).
                nn.init.xavier_uniform_(self.classifier[1].weight)

        # Original convolutional layers: 2 input channels -> 256 channels,
        # time axis halved three times by the MaxPool1d layers.
        self.features = nn.Sequential(
            nn.Conv1d(2, 64, kernel_size=4, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(64, eps=0.001),
            nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(64, eps=0.001),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(128, eps=0.001),
            nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(128, eps=0.001),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(256, eps=0.001),
            nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(256, eps=0.001),
            nn.MaxPool1d(kernel_size=2, stride=2)
        )

        # Original fully connected layers. The input size is hard-coded:
        # 79872 = 256 channels * 312 time steps, which is what self.features
        # produces for an input of length 2500. Other lengths fail here.
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(79872, 2048),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(2048, eps=0.001),
            nn.Linear(2048, self.encoding_size)
        )

    def forward(self, x):
        """Encode ``x`` (or classify it when ``classify=True``).

        Args:
            x: (batch, 2, length) tensor; with the hard-coded fc input size the
                length must be 2500.

        Returns:
            (batch, encoding_size) encodings, or (batch, n_classes) logits
            when ``classify`` is True.
        """
        x = self.features(x)
        # Flatten (batch, channels, time) -> (batch, channels * time) for self.fc.
        x = x.view(x.size(0), -1)
        encoding = self.fc(x)
        
        if self.classify:
            c = self.classifier(encoding)
            return c
        else:
            return encoding


class WFClassifier(torch.nn.Module):
    """Linear read-out head from the original TNC codebase.

    Maps an ``encoding_size``-dimensional encoding to ``output_size`` logits
    using one fully connected layer with a Xavier-uniform initialised weight.
    """

    def __init__(self, encoding_size, output_size):
        super().__init__()
        self.encoding_size, self.output_size = encoding_size, output_size
        head = nn.Linear(encoding_size, output_size)
        torch.nn.init.xavier_uniform_(head.weight)
        # The attribute name is part of the state_dict key ("classifier.weight").
        self.classifier = head

    def forward(self, x):
        """Return unnormalised class scores for the batch of encodings ``x``."""
        return self.classifier(x)

print("Original TNC models defined successfully!")

## 3. Load Pre-trained TNC Encoder

In [None]:
# Load the trained TNC encoder exactly as in original codebase
data = 'waveform'
cv = 0  # Cross-validation index
encoding_size = 64
n_classes = 4

checkpoint_file = os.path.join(CHECKPOINT_PATH, data, f'checkpoint_{cv}.pth.tar')

print(f"Loading checkpoint from: {checkpoint_file}")
print(f"Checkpoint exists: {os.path.exists(checkpoint_file)}")

if not os.path.exists(checkpoint_file):
    print("ERROR: Checkpoint file not found!")
    print(f"Make sure your checkpoint is saved as: ckpt/{data}/checkpoint_{cv}.pth.tar")
    print("Available files in checkpoint directory:")
    if os.path.exists(os.path.join(CHECKPOINT_PATH, data)):
        print(os.listdir(os.path.join(CHECKPOINT_PATH, data)))
    else:
        print(f"{data} directory doesn't exist")
    # Stop here. Every later cell uses `encoder`/`classifier`, so continuing
    # would only resurface as a confusing NameError further down.
    raise FileNotFoundError(f"TNC checkpoint not found: {checkpoint_file}")

# Load checkpoint exactly as in original WFClassificationExperiment.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load(checkpoint_file, map_location=device)
print("Checkpoint loaded successfully!")
print(f"Available keys in checkpoint: {list(checkpoint.keys())}")

# Initialize encoder exactly as in original code (frozen: eval mode, and its
# parameters are never handed to an optimizer in this notebook).
encoder = WFEncoder(encoding_size=encoding_size)
encoder.load_state_dict(checkpoint['encoder_state_dict'])
encoder = encoder.to(device)
encoder.eval()

# Initialize classifier exactly as in original code
classifier = WFClassifier(encoding_size=encoding_size, output_size=n_classes).to(device)

# Also create e2e model for comparison (as in original)
e2e_model = WFEncoder(encoding_size=encoding_size, classify=True, n_classes=n_classes).to(device)

print("Models initialized exactly as in original codebase!")
print(f"Encoder device: {next(encoder.parameters()).device}")
print(f"Classifier device: {next(classifier.parameters()).device}")

# Print checkpoint info if available
if 'best_accuracy' in checkpoint:
    print(f"Best training accuracy: {checkpoint['best_accuracy']:.3f}")
if 'epoch' in checkpoint:
    print(f"Training epoch: {checkpoint['epoch']}")

## 4. Load and Preprocess Data (Original Method)

In [None]:
# Load data exactly as in original WFClassificationExperiment
window_size = 2500  # Original window size

wf_datapath = os.path.join(DATA_PATH, 'waveform_data', 'processed')

# Check if data files exist
x_train_file = os.path.join(wf_datapath, 'x_train.pkl')
y_train_file = os.path.join(wf_datapath, 'state_train.pkl')

print(f"Data directory: {wf_datapath}")
print(f"x_train exists: {os.path.exists(x_train_file)}")
print(f"y_train exists: {os.path.exists(y_train_file)}")

# Load data exactly as in original codebase.
# NOTE: pickle.load can execute arbitrary code, so only load files you trust.
try:
    with open(x_train_file, 'rb') as f:
        x = pickle.load(f)
    with open(y_train_file, 'rb') as f:
        y = pickle.load(f)
except Exception as e:
    print(f"Error loading data: {e}")
    print("Please check your data file paths and formats")
    # Re-raise: all later cells need the data, so swallowing the error would
    # only defer the failure to an unrelated NameError.
    raise

print("Original data loaded successfully!")
print(f"x shape: {x.shape}")
print(f"y shape: {y.shape}")

# Data preprocessing exactly as in original WFClassificationExperiment:
# cut every series into non-overlapping windows of `window_size` steps along
# the last axis (dropping the remainder), and label each window with the most
# frequent state inside it.
T = x.shape[-1]
n_windows = T // window_size
if n_windows == 0:
    raise ValueError(f"Series length {T} is shorter than window_size={window_size}")
x_window = np.split(x[:, :, :window_size * n_windows], n_windows, -1)
y_window = np.concatenate(np.split(y[:, :window_size * n_windows], n_windows, -1), 0).astype(int)
y_window = torch.Tensor(np.array([np.bincount(yy).argmax() for yy in y_window]))

# Shuffle exactly as in original code (Python `random`, seeded earlier)
shuffled_inds = list(range(len(y_window)))
random.shuffle(shuffled_inds)
x_window = torch.Tensor(np.concatenate(x_window, 0))
x_window = x_window[shuffled_inds]
y_window = y_window[shuffled_inds]

# Split exactly as in original (60% train, 40% validation)
n_train = int(0.6*len(x_window))
trainset = torch.utils.data.TensorDataset(x_window[:n_train], y_window[:n_train])
validset = torch.utils.data.TensorDataset(x_window[n_train:], y_window[n_train:])

# Create dataloaders exactly as in original
train_loader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True)
valid_loader = torch.utils.data.DataLoader(validset, batch_size=100, shuffle=True)

print(f"Windowed data shape: {x_window.shape}")
print(f"Windowed labels shape: {y_window.shape}")
print(f"Number of classes: {len(torch.unique(y_window))}")
print(f"Training samples: {n_train}")
print(f"Validation samples: {len(x_window) - n_train}")
print(f"Class distribution: {torch.bincount(y_window.long())}")

## 5. Original Training Functions

In [None]:
# Training functions exactly as in original codebase


def _epoch_metrics(y_all, prediction_all):
    """Compute AUC, AUPRC and the confusion matrix for one pass over a loader.

    Args:
        y_all: (n_samples,) array of integer-valued class labels.
        prediction_all: (n_samples, n_classes) array of classifier scores.

    Returns:
        (auc, auprc, confusion). auc/auprc are macro averages over the classes
        that have both positive and negative samples in ``y_all`` (NaN if none
        do). confusion is an n_classes x n_classes matrix indexed by class id.
    """
    n_classes = prediction_all.shape[-1]
    y_int = y_all.astype(int)
    prediction_class_all = np.argmax(prediction_all, -1)
    y_onehot_all = np.zeros(prediction_all.shape)
    y_onehot_all[np.arange(len(y_onehot_all)), y_int] = 1

    # roc_auc_score raises ValueError when a one-hot column is constant, i.e.
    # a class is absent from this pass, which easily happens with the rare
    # waveform classes. Score only the columns where the metrics are defined.
    # When every class is present, this is identical to scoring all columns.
    positives = y_onehot_all.sum(0)
    valid = (positives > 0) & (positives < len(y_onehot_all))
    if valid.any():
        epoch_auc = roc_auc_score(y_onehot_all[:, valid], prediction_all[:, valid])
        epoch_auprc = average_precision_score(y_onehot_all[:, valid], prediction_all[:, valid])
    else:
        epoch_auc = epoch_auprc = float('nan')

    # Pin the label set so the matrix is always n_classes x n_classes, even
    # when a class never appears in either the labels or the predictions.
    c = confusion_matrix(y_int, prediction_class_all, labels=np.arange(n_classes))
    return epoch_auc, epoch_auprc, c


def _train_tnc_classifier(encoder, classifier, train_loader, lr):
    """Train ``classifier`` for one epoch on top of the frozen ``encoder``.

    Adapted from _train_tnc_classifier in the original evaluations.py. As in
    the original, a fresh Adam optimizer is created on every call, and at most
    31 batches of ``train_loader`` are used.

    Args:
        encoder: Pre-trained encoder; kept in eval mode and never updated.
        classifier: Module mapping encodings to class logits; trained in place.
        train_loader: Iterable of ``(x, y)`` batches where ``y`` holds class ids.
        lr: Adam learning rate.

    Returns:
        (mean loss per batch, mean per-batch accuracy, AUC, AUPRC, confusion matrix)
    """
    classifier.train()
    encoder.eval()
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)

    epoch_loss, epoch_acc, batch_count = 0, 0, 0
    y_all, prediction_all = [], []

    for i, (x, y) in enumerate(train_loader):
        if i > 30:  # Original limit from codebase
            break
        optimizer.zero_grad()
        x, y = x.to(device), y.to(device)
        # The encoder is frozen (its parameters are not in the optimizer), so
        # run it without autograd. The classifier's gradients are unchanged,
        # but we no longer back-propagate through the encoder or accumulate
        # unused .grad tensors on its parameters.
        with torch.no_grad():
            encodings = encoder(x)
        prediction = classifier(encodings)
        state_prediction = torch.argmax(prediction, dim=1)
        loss = loss_fn(prediction, y.long())
        loss.backward()
        optimizer.step()
        y_all.append(y.cpu())
        prediction_all.append(prediction.detach().cpu().numpy())

        epoch_acc += torch.eq(state_prediction, y).sum().item()/len(x)
        epoch_loss += loss.item()
        batch_count += 1

    y_all = torch.cat(y_all, 0).numpy()
    prediction_all = np.concatenate(prediction_all, 0)
    epoch_auc, epoch_auprc, c = _epoch_metrics(y_all, prediction_all)
    return epoch_loss / batch_count, epoch_acc / batch_count, epoch_auc, epoch_auprc, c


def _test_model(model, valid_loader):
    """Evaluate ``model`` on every batch of ``valid_loader`` without gradients.

    Adapted from the _test function in the original evaluations.py.

    Args:
        model: Module mapping raw inputs to class logits (e.g. encoder + classifier).
        valid_loader: Iterable of ``(x, y)`` batches where ``y`` holds class ids.

    Returns:
        (mean loss per batch, mean per-batch accuracy, AUC, AUPRC, confusion matrix)
    """
    model.eval()
    loss_fn = torch.nn.CrossEntropyLoss()

    epoch_loss, epoch_acc, batch_count = 0, 0, 0
    y_all, prediction_all = [], []

    with torch.no_grad():
        for x, y in valid_loader:
            x, y = x.to(device), y.to(device)
            prediction = model(x)
            state_prediction = torch.argmax(prediction, -1)
            loss = loss_fn(prediction, y.long())
            y_all.append(y.cpu())
            prediction_all.append(prediction.detach().cpu().numpy())

            epoch_acc += torch.eq(state_prediction, y).sum().item()/len(x)
            epoch_loss += loss.item()
            batch_count += 1

    y_all = torch.cat(y_all, 0).numpy()
    prediction_all = np.concatenate(prediction_all, 0)
    epoch_auc, epoch_auprc, c = _epoch_metrics(y_all, prediction_all)
    return epoch_loss / batch_count, epoch_acc / batch_count, epoch_auc, epoch_auprc, c

print("Original training functions defined!")

## 6. Run Original Classification Experiment

In [None]:
# Run the experiment exactly as in original codebase
n_epochs = 8  # Original value for waveform data
lr_cls = 0.01  # Original learning rate for classifier

print("Starting TNC classification training (original method)...")

# Track metrics exactly as in original run() function
tnc_acc, tnc_loss, tnc_auc, tnc_auprc = [], [], [], []
tnc_acc_test, tnc_loss_test, tnc_auc_test, tnc_auprc_test = [], [], [], []

# Encoder + classifier as a single model for evaluation. Sequential wraps the
# same module objects, so it always sees the classifier's latest weights.
tnc_model = torch.nn.Sequential(encoder, classifier)

for epoch in range(n_epochs):
    # Train TNC classifier (frozen encoder + trainable classifier)
    loss, acc, auc, auprc, _ = _train_tnc_classifier(encoder, classifier, train_loader, lr_cls)
    tnc_acc.append(acc)
    tnc_loss.append(loss)
    tnc_auc.append(auc)
    tnc_auprc.append(auprc)

    # Test on validation set
    loss, acc, auc, auprc, c_mtx_enc = _test_model(tnc_model, valid_loader)
    tnc_acc_test.append(acc)
    tnc_loss_test.append(loss)
    tnc_auc_test.append(auc)
    tnc_auprc_test.append(auprc)

    # Print progress as in original (every 5 epochs), plus the final epoch
    if epoch % 5 == 0 or epoch == n_epochs - 1:
        print('***** Epoch %d *****' % epoch)
        print('TNC =====> Training Loss: %.3f \t Training Acc: %.3f \t Training AUC: %.3f \t Training AUPRC: %.3f'
              '\t Test Loss: %.3f \t Test Acc: %.3f \t Test AUC: %.3f \t Test AUPRC: %.3f'
              % (tnc_loss[-1], tnc_acc[-1], tnc_auc[-1], tnc_auprc[-1],
                 tnc_loss_test[-1], tnc_acc_test[-1], tnc_auc_test[-1], tnc_auprc_test[-1]))

print("\n" + "="*80)
print("🎯 FINAL RESULTS (Original TNC Method)")
print("="*80)
print(f"✅ Final Test Accuracy: {tnc_acc_test[-1]:.4f} ({tnc_acc_test[-1]*100:.2f}%)")
print(f"📈 Final Test AUPRC: {tnc_auprc_test[-1]:.4f}")
print(f"🔄 Final Test AUC: {tnc_auc_test[-1]:.4f}")
print(f"📉 Final Test Loss: {tnc_loss_test[-1]:.4f}")
print("="*80)

## 7. Original Visualization and Results

In [None]:
# Create plots exactly as in original run() function

# Create plots directory if needed
plots_dir = os.path.join(PLOTS_PATH, data)
os.makedirs(plots_dir, exist_ok=True)

# Take the x-axis from the recorded history, so the plots stay valid even if
# n_epochs is edited after the training cell has run.
epoch_axis = np.arange(len(tnc_acc))

# 1-3. Train/test trend plots. Accuracy and AUC are as in the original; AUPRC
# is added for completeness. Titles and file names are unchanged.
trend_specs = (
    ("Accuracy", tnc_acc, tnc_acc_test, "classification_accuracy_comparison"),
    ("AUC", tnc_auc, tnc_auc_test, "classification_auc_comparison"),
    ("AUPRC", tnc_auprc, tnc_auprc_test, "classification_auprc_comparison"),
)
for metric_name, train_curve, test_curve, file_stem in trend_specs:
    plt.figure(figsize=(10, 6))
    plt.plot(epoch_axis, train_curve, label="TNC train", linewidth=2)
    plt.plot(epoch_axis, test_curve, label="TNC test", linewidth=2)
    plt.title(f"{metric_name} trend for the TNC model (Original Method)", fontsize=14)
    plt.xlabel("Epoch")
    plt.ylabel(metric_name)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(plots_dir, f"{file_stem}_{cv}.png"), dpi=300, bbox_inches='tight')
    plt.show()

# 4. Confusion matrix (rows = true label, columns = predicted label).
# sklearn's confusion_matrix only includes labels seen in y_true/y_pred unless
# `labels=` is passed, so size the tick labels from the matrix itself instead
# of assuming n_classes (which would make the DataFrame constructor fail).
# Class ids are used as ticks only when the matrix covers all n_classes.
cm_size = c_mtx_enc.shape[0]
cm_ticks = list(range(n_classes)) if cm_size == n_classes else [''] * cm_size
df_cm = pd.DataFrame(c_mtx_enc, index=cm_ticks, columns=cm_ticks)
plt.figure(figsize=(8, 6))
sns.heatmap(df_cm, annot=True, cmap='Blues', fmt='d')
plt.title("TNC Encoder Confusion Matrix (Original Method)", fontsize=14)
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.savefig(os.path.join(plots_dir, "encoder_cf_matrix.png"), dpi=300, bbox_inches='tight')
plt.show()

print(f"\n📊 All plots saved to: {plots_dir}")
print(f"📈 Accuracy plot: classification_accuracy_comparison_{cv}.png")
print(f"📈 AUC plot: classification_auc_comparison_{cv}.png")
print(f"📈 AUPRC plot: classification_auprc_comparison_{cv}.png")
print("🎯 Confusion matrix: encoder_cf_matrix.png")

In [None]:
# Save results summary exactly as needed
n_valid = len(x_window) - n_train

results_summary = {
    'model': 'TNC_Original_Classification',
    'method': 'WFClassificationExperiment (Original Codebase)',
    'encoder_checkpoint': checkpoint_file,
    'data_type': data,
    'cv': cv,
    'n_epochs': n_epochs,
    'learning_rate': lr_cls,
    'encoding_size': encoding_size,
    'n_classes': n_classes,
    'window_size': window_size,
    'training_samples': n_train,
    'validation_samples': n_valid,

    # Final metrics (exactly as original returns)
    'final_test_accuracy': float(tnc_acc_test[-1]),
    'final_test_auc': float(tnc_auc_test[-1]),
    'final_test_auprc': float(tnc_auprc_test[-1]),
    'final_test_loss': float(tnc_loss_test[-1]),

    # Training progression
    'training_accuracy_progression': [float(v) for v in tnc_acc],
    'training_auc_progression': [float(v) for v in tnc_auc],
    'training_auprc_progression': [float(v) for v in tnc_auprc],
    'test_accuracy_progression': [float(v) for v in tnc_acc_test],
    'test_auc_progression': [float(v) for v in tnc_auc_test],
    'test_auprc_progression': [float(v) for v in tnc_auprc_test],

    # Confusion matrix (nested lists of ints, so it is JSON-serialisable)
    'confusion_matrix': c_mtx_enc.tolist()
}

# Save results
import json
results_file = os.path.join(plots_dir, 'original_classification_results.json')
with open(results_file, 'w') as f:
    json.dump(results_summary, f, indent=2)

print(f"\n💾 Complete results saved to: {results_file}")

# Print final summary exactly as original would show
print("\n" + "="*80)
print("📋 ORIGINAL TNC CLASSIFICATION SUMMARY")
print("="*80)
print(f"📁 Checkpoint: {checkpoint_file}")
print(f"📊 Data: {data} (CV={cv})")
print("🔧 Method: WFClassificationExperiment (Original Codebase)")
print(f"⚙️  Encoding Size: {encoding_size}, Classes: {n_classes}")
print(f"📏 Window Size: {window_size}")
print(f"🎓 Training Samples: {n_train:,}, Validation: {n_valid:,}")
print("\n🎯 FINAL METRICS:")
print(f"   • Accuracy: {tnc_acc_test[-1]:.4f} ({tnc_acc_test[-1]*100:.2f}%)")
print(f"   • AUPRC: {tnc_auprc_test[-1]:.4f}")
print(f"   • AUC: {tnc_auc_test[-1]:.4f}")
print(f"   • Loss: {tnc_loss_test[-1]:.4f}")
print("\n✅ Original method evaluation completed successfully!")
print(f"📁 All results and plots saved to: {plots_dir}")
print("="*80)

## 8. Prototypical Network Implementation

In [None]:
class PrototypicalNetwork(nn.Module):
    """Prototypical Network for few-shot classification on top of an encoder.

    Each class is represented by the mean encoding ("prototype") of its
    support examples. Queries are scored by the negative distance to every
    prototype. The network adds no trainable parameters of its own.

    Args:
        encoder: Module mapping a batch of inputs to (batch, encoding_size) encodings.
        distance_metric: ``'euclidean'`` or ``'cosine'``; any other value makes
            ``compute_distances`` raise ValueError.
    """
    def __init__(self, encoder, distance_metric='euclidean'):
        super(PrototypicalNetwork, self).__init__()
        self.encoder = encoder
        self.distance_metric = distance_metric

    def compute_prototypes(self, support_set, support_labels):
        """
        Compute class prototypes from support set
        Args:
            support_set: [n_support, ...] inputs for the encoder (e.g. [n_support, 2, 2500] ECG windows)
            support_labels: [n_support] class labels
        Returns:
            prototypes: [n_classes, encoding_size] class prototypes, one per
                label in ascending order
            unique_classes: [n_classes] the sorted labels matching each row
        """
        # Get encodings for support set
        support_encodings = self.encoder(support_set)

        # torch.unique returns the labels sorted, which fixes the prototype order
        unique_classes = torch.unique(support_labels)
        prototypes = [
            support_encodings[support_labels == class_id].mean(dim=0)
            for class_id in unique_classes
        ]
        return torch.stack(prototypes), unique_classes

    def compute_distances(self, query_encodings, prototypes):
        """Return the [n_query, n_prototypes] distance matrix between queries and prototypes.

        Raises:
            ValueError: if ``self.distance_metric`` is not supported.
        """
        if self.distance_metric == 'euclidean':
            # Plain (not squared) Euclidean distance ||q - p||_2.
            distances = torch.cdist(query_encodings, prototypes, p=2)
        elif self.distance_metric == 'cosine':
            # Cosine distance: 1 - cosine_similarity. Use the fully qualified
            # functional API; no `F` alias is bound in this notebook.
            query_norm = torch.nn.functional.normalize(query_encodings, dim=1)
            proto_norm = torch.nn.functional.normalize(prototypes, dim=1)
            similarities = torch.mm(query_norm, proto_norm.t())
            distances = 1 - similarities
        else:
            raise ValueError(f"Unknown distance metric: {self.distance_metric}")

        return distances

    def forward(self, support_set, support_labels, query_set):
        """
        Forward pass for prototypical network
        Args:
            support_set: [n_support, ...] support examples (e.g. [n_support, 2, 2500])
            support_labels: [n_support] support labels
            query_set: [n_query, ...] query examples
        Returns:
            logits: [n_query, n_classes] negative distances; column j refers to class_ids[j]
            class_ids: [n_classes] sorted labels used to map argmax indices back to labels
        """
        # Compute prototypes from support set
        prototypes, class_ids = self.compute_prototypes(support_set, support_labels)

        # Get query encodings
        query_encodings = self.encoder(query_set)

        # Smaller distance -> larger logit
        logits = -self.compute_distances(query_encodings, prototypes)

        return logits, class_ids


def create_few_shot_episode(x_window, y_window, n_way=4, k_shot=5, n_query=15):
    """
    Create a few-shot learning episode from the ECG data.

    Only classes with at least ``k_shot + n_query`` examples are eligible. Of
    those, the first ``n_way`` in ascending label order are used, so an
    episode may contain fewer than ``n_way`` classes. Within each class,
    examples are drawn without replacement via ``torch.randperm``, so support
    and query sets never share an example.

    Args:
        x_window: All ECG windows (indexed along dim 0)
        y_window: All labels, aligned with ``x_window``
        n_way: Maximum number of classes in the episode
        k_shot: Number of support examples per class
        n_query: Number of query examples per class

    Returns:
        (support_x, support_y, query_x, query_y, selected_classes), with the
        support/query tensors grouped class by class in ``selected_classes`` order.

    Raises:
        ValueError: if no class has at least ``k_shot + n_query`` examples.
    """
    needed = k_shot + n_query

    # Handle class imbalance: skip classes that cannot fill support + query.
    available_classes = [
        class_id for class_id in torch.unique(y_window)
        if (y_window == class_id).sum().item() >= needed
    ]
    if not available_classes:
        raise ValueError(
            f"No class has at least k_shot + n_query = {needed} examples; "
            "cannot build a few-shot episode"
        )

    selected_classes = available_classes[:n_way]

    episode_support_x, episode_support_y = [], []
    episode_query_x, episode_query_y = [], []

    for class_id in selected_classes:
        # Random subset of this class, then split into support / query
        class_indices = torch.where(y_window == class_id)[0]
        perm = torch.randperm(len(class_indices))
        selected_indices = class_indices[perm[:needed]]

        support_indices = selected_indices[:k_shot]
        query_indices = selected_indices[k_shot:needed]

        episode_support_x.append(x_window[support_indices])
        episode_support_y.append(y_window[support_indices])
        episode_query_x.append(x_window[query_indices])
        episode_query_y.append(y_window[query_indices])

    # Concatenate all classes
    support_x = torch.cat(episode_support_x, dim=0)
    support_y = torch.cat(episode_support_y, dim=0)
    query_x = torch.cat(episode_query_x, dim=0)
    query_y = torch.cat(episode_query_y, dim=0)

    return support_x, support_y, query_x, query_y, selected_classes


# Initialize prototypical network (eval mode: BatchNorm/Dropout in the encoder
# must behave deterministically during inference)
proto_net = PrototypicalNetwork(encoder, distance_metric='euclidean')
proto_net.eval()
print("Prototypical Network initialized!")

# Test with a few-shot episode
print("\nTesting prototypical network with imbalanced data...")

# Create episode considering class imbalance
support_x, support_y, query_x, query_y, episode_classes = create_few_shot_episode(
    x_window, y_window, n_way=4, k_shot=3, n_query=10
)

print(f"Episode classes: {episode_classes}")
print(f"Support set: {support_x.shape}, labels: {support_y.shape}")
print(f"Query set: {query_x.shape}, labels: {query_y.shape}")
print(f"Support class distribution: {torch.bincount(support_y.long())}")
print(f"Query class distribution: {torch.bincount(query_y.long())}")

# Move to device
support_x, support_y = support_x.to(device), support_y.to(device)
query_x, query_y = query_x.to(device), query_y.to(device)

# Forward pass (inference only, nothing is trained here)
with torch.no_grad():
    logits, class_ids = proto_net(support_x, support_y, query_x)
    predictions = torch.argmax(logits, dim=1)

    # `predictions` index the sorted prototype list; map them back to labels
    pred_classes = class_ids[predictions]

    # Calculate accuracy
    accuracy = (pred_classes == query_y).float().mean().item()

print("\nPrototypical Network Results:")
print(f"Query accuracy: {accuracy:.3f} ({accuracy*100:.1f}%)")
print(f"Predicted classes: {pred_classes}")
print(f"True classes: {query_y}")

# Summarise the class balance from the data itself rather than hard-coding it
label_values, label_counts = torch.unique(y_window, return_counts=True)

print("\n" + "="*60)
print("🎯 KEY INSIGHTS:")
print("="*60)
print("📊 Class distribution of the windowed dataset:")
for label_value, label_count in zip(label_values.tolist(), label_counts.tolist()):
    print(f"   • Class {int(label_value)}: {label_count:,} samples")
print("\n🔬 Prototypical networks help by:")
print("   • Better handling of rare classes")
print("   • Distance-based classification instead of linear")
print("   • Can work with very few examples per class")
print("   • More robust to class imbalance")
print("="*60)

## 9. Compare Linear vs Prototypical Classification

In [None]:
def evaluate_prototypical_network(encoder, x_data, y_data, n_episodes=100, k_shot=5, n_query=15, n_way=4):
    """
    Evaluate prototypical network performance across multiple episodes.

    Each episode draws a fresh support/query split with
    ``create_few_shot_episode`` and classifies the queries by their nearest
    prototype. No parameters are trained. Episodes that raise are reported and
    skipped.

    Args:
        encoder: Encoder used to embed support and query windows. It is put in
            eval mode as a side effect of ``proto_net.eval()``.
        x_data: Windows to sample episodes from.
        y_data: Class labels aligned with ``x_data``.
        n_episodes: Number of episodes to run.
        k_shot: Support examples per class.
        n_query: Query examples per class.
        n_way: Maximum number of classes per episode.

    Returns:
        (accuracies, class_accuracies): the list of per-episode query
        accuracies, and a dict mapping every label present in ``y_data`` (as
        an int) to the list of that class's per-episode query accuracies.
    """
    proto_net = PrototypicalNetwork(encoder, distance_metric='euclidean')
    proto_net.eval()

    accuracies = []
    # One bucket per label actually present in the data. A hard-coded 0..3
    # would raise KeyError (mid-episode) for any other label.
    class_accuracies = {int(c): [] for c in torch.unique(y_data).tolist()}

    print(f"Evaluating prototypical network over {n_episodes} episodes...")

    for episode in range(n_episodes):
        try:
            # Create episode
            support_x, support_y, query_x, query_y, episode_classes = create_few_shot_episode(
                x_data, y_data, n_way=n_way, k_shot=k_shot, n_query=n_query
            )

            # Move to device
            support_x = support_x.to(device)
            support_y = support_y.to(device)
            query_x = query_x.to(device)
            query_y = query_y.to(device)

            with torch.no_grad():
                # Forward pass; map prototype indices back to labels
                logits, class_ids = proto_net(support_x, support_y, query_x)
                predictions = torch.argmax(logits, dim=1)
                pred_classes = class_ids[predictions]

                # Overall accuracy
                accuracy = (pred_classes == query_y).float().mean().item()
                accuracies.append(accuracy)

                # Per-class accuracy
                for class_id in episode_classes:
                    class_mask = (query_y == class_id)
                    if class_mask.sum() > 0:
                        class_acc = (pred_classes[class_mask] == query_y[class_mask]).float().mean().item()
                        class_accuracies[int(class_id.item())].append(class_acc)

        except Exception as e:
            # Best-effort evaluation: report the failing episode and move on.
            print(f"Episode {episode} failed: {e}")
            continue

        if (episode + 1) % 20 == 0:
            running_mean = np.mean(accuracies) if accuracies else float('nan')
            print(f"Episode {episode + 1}/{n_episodes}, Mean accuracy: {running_mean:.3f}")

    return accuracies, class_accuracies

# Few-shot settings for this evaluation; also reused in the summary below
# (previously `k_shot` was referenced there without ever being defined).
k_shot = 3
n_query = 10
n_proto_episodes = 50

# Evaluate prototypical network.
# NOTE: episodes are sampled from all windows (x_window / y_window), not only
# from the validation split used for the linear classifier above.
print("🚀 Starting prototypical network evaluation...")
proto_accuracies, proto_class_accs = evaluate_prototypical_network(
    encoder, x_window, y_window, n_episodes=n_proto_episodes, k_shot=k_shot, n_query=n_query
)

# Calculate statistics
proto_mean_acc = np.mean(proto_accuracies)
proto_std_acc = np.std(proto_accuracies)

print("\n" + "="*70)
print("📊 PROTOTYPICAL NETWORK RESULTS")
print("="*70)
print(f"📈 Overall Accuracy: {proto_mean_acc:.3f} ± {proto_std_acc:.3f}")
print(f"📈 Overall Accuracy: {proto_mean_acc*100:.1f}% ± {proto_std_acc*100:.1f}%")

print("\n📋 Per-Class Performance:")
for class_id, class_accs in proto_class_accs.items():
    if class_accs:
        class_mean = np.mean(class_accs)
        class_std = np.std(class_accs)
        class_count = (y_window == class_id).sum().item()
        print(f"   • Class {class_id}: {class_mean:.3f} ± {class_std:.3f} ({class_mean*100:.1f}% ± {class_std*100:.1f}%) [{class_count:,} total samples]")
    else:
        print(f"   • Class {class_id}: No episodes (insufficient samples)")

# Compare with linear classifier results
print("\n" + "="*70)
print("🔄 COMPARISON: Linear vs Prototypical")
print("="*70)
print("📌 Linear Classifier:")
print(f"   • Test Accuracy: {tnc_acc_test[-1]:.3f} ({tnc_acc_test[-1]*100:.1f}%)")
print("   • Method: Traditional supervised learning")
print(f"   • Training: {n_train:,} samples")
print("\n📌 Prototypical Network:")
print(f"   • Test Accuracy: {proto_mean_acc:.3f} ± {proto_std_acc:.3f}")
print(f"   • Method: Few-shot learning with {k_shot}-shot episodes")
print(f"   • Training: {k_shot} samples per class per episode")

print("\n💡 Key Advantages of Prototypical Networks:")
print("   ✅ Better handling of class imbalance")
print(f"   ✅ Works with very few examples ({k_shot}-shot learning)")
print("   ✅ Distance-based similarity matching")
print("   ✅ Can adapt to new patients/conditions quickly")
print("   ✅ More robust to rare ECG patterns")
print("="*70)