In [1]:
import torch
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split

# Import the MHC dataset and LSTM model
from torch_dataset import BaseMhcDataset
from models.lstm import AutoencoderLSTM, LSTMTrainer

In [None]:
# Standardization parameters exported next to the dataset splits.
standardization_df = pd.read_csv("/scratch/users/schuetzn/data/mhc_dataset_out/standardization_params.csv")

# Row label -> (mean, std_dev), restricted to the first six rows of the table.
scaler_stats = {
    feature_key: (params["mean"], params["std_dev"])
    for feature_key, params in standardization_df.iloc[:6].iterrows()
}

# Bare expression so the notebook renders the mapping.
scaler_stats

In [None]:
# Fix the torch and numpy global seeds before anything random happens.
torch.manual_seed(42)
np.random.seed(42)

# Hard-coded locations of the training split and of the raw data root.
dataset_parquet_path = "/scratch/users/schuetzn/data/mhc_dataset_out/splits/train_dataset.parquet"
root_dir = "/scratch/groups/euan/mhc/mhc_dataset"

print(f"Loading dataset from {dataset_parquet_path}")
print(f"Using root directory: {root_dir}")

# Read the split and turn the stringified `file_uris` column back into Python objects.
# NOTE(review): eval() executes whatever text is stored in that column. If the
# values are plain Python literals, ast.literal_eval would be the safer parser —
# confirm the stored format before switching.
df = pd.read_parquet(dataset_parquet_path).assign(
    file_uris=lambda frame: frame["file_uris"].apply(eval)
)

print(f"Loaded dataset with {len(df)} samples")

# Label columns are recognised purely by their '_value' suffix.
label_cols = [name for name in df.columns if name.endswith('_value')]
print(f"Available label columns: {label_cols}")

# Build the project dataset with masks enabled, restricted to feature indices 0-5.
# NOTE(review): `feature_stats` presumably drives per-feature standardization —
# confirm against BaseMhcDataset.
dataset = BaseMhcDataset(
    df,
    root_dir,
    include_mask=True,
    feature_stats=scaler_stats,
    feature_indices=list(range(6)),
)

# 80/20 random split; validation receives whatever the integer truncation leaves over.
num_samples = len(dataset)
train_size = int(0.8 * num_samples)
val_size = num_samples - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
print(f"Split dataset into {train_size} training and {val_size} validation samples")

# Both loaders share batch size and worker count; only the training loader shuffles.
batch_size = 128
loader_options = {"batch_size": batch_size, "num_workers": 16}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, **loader_options)

# Model configuration.
# An empty list of target labels is passed, i.e. no label names are requested.
# To predict labels, list their names without the '_value' suffix.
target_labels = []

lstm_config = dict(
    num_features=6,
    hidden_size=256,
    encoding_dim=256,
    num_layers=5,
    dropout=0.1,
    bidirectional=False,
    target_labels=target_labels,
    prediction_horizon=1,
    use_masked_loss=True,
    teacher_forcing_ratio=1,
)
model = AutoencoderLSTM(**lstm_config)
print(f"Initialized LSTM model with target labels: {target_labels}")

# Count all parameters and, separately, those that require gradients, in one pass.
total_params = 0
trainable_params = 0
for param in model.parameters():
    param_count = param.numel()
    total_params += param_count
    if param.requires_grad:
        trainable_params += param_count

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Adam over every model parameter.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
trainer = LSTMTrainer(model, optimizer, device)

# --- Teacher forcing schedule ---
# The ratio decays linearly from initial_tf to final_tf over the first
# decay_epochs epochs and is held at final_tf afterwards.
initial_tf = 1    # Start with 100% teacher forcing
final_tf = 0.0    # Decay to 0%
decay_epochs = 15 # Number of epochs over which the decay is spread
tf_decay_step = (initial_tf - final_tf) / max(1, decay_epochs)  # max() avoids division by zero

# --- Training ---
num_epochs = 20
print(f"Training for {num_epochs} epochs...")
print(f"Teacher forcing decay: {initial_tf*100:.0f}% -> {final_tf*100:.0f}% over {decay_epochs} epochs.")

# One entry per epoch; consumed by the loss plot further down.
train_losses = []
val_losses = []

for epoch in range(num_epochs):  # epoch is 0-based

    # Pick this epoch's ratio. The explicit else-branch pins the value to exactly
    # final_tf once the decay window is over (no floating-point residue).
    if epoch < decay_epochs:
        current_tf_ratio = max(final_tf, initial_tf - tf_decay_step * epoch)
    else:
        current_tf_ratio = final_tf
    trainer.model.teacher_forcing_ratio = current_tf_ratio

    # Read the value back from the model so the log shows what was actually set.
    applied_tf_ratio = trainer.model.teacher_forcing_ratio

    train_loss = trainer.train_epoch(train_loader)

    # NOTE(review): the original comment stated that teacher forcing is disabled
    # in eval mode — confirm in AutoencoderLSTM / LSTMTrainer.validate.
    val_loss = trainer.validate(val_loader)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    print(
        f"Epoch {epoch+1}/{num_epochs}, TF Ratio: {applied_tf_ratio:.2f}, "
        f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}"
    )

# Loss curves for both splits; the x axis is the 0-based epoch index.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_losses, label='Training Loss')
ax.plot(val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
ax.grid(True)
plt.show()

# Qualitative check: run the trained model on one validation sample and compare
# the first values of one predicted segment with its target.
print("Making predictions on a validation sample...")

sample_idx = 0
segment_idx = 0
# NOTE(review): feature_idx only appears in the plot title; the slices below
# never index by feature. Confirm the layout of 'sequence_output' before
# trusting the "Feature" label.
feature_idx = 0

sample = val_dataset[sample_idx]
print(f"Sample metadata: {sample['metadata']}")

# The model consumes a batch dict, so a leading batch dimension of size 1 is added.
batch = {key: sample[key].unsqueeze(0).to(device) for key in ('data', 'mask')}

# Inference only: switch to eval mode and skip gradient tracking for the forward pass.
model.eval()
with torch.no_grad():
    output = model(batch)

# Drop the batch dimension again and move to numpy for plotting.
predicted_segments = output['sequence_output'][0].cpu().numpy()
target_segments = output['target_segments'][0].cpu().numpy()

# First 30 values of the chosen segment.
predicted_values = predicted_segments[segment_idx, :30]
target_values = target_segments[segment_idx, :30]

plt.figure(figsize=(12, 4))
for series, series_label, series_marker in (
    (predicted_values, 'Predicted', 'o'),
    (target_values, 'Target', 'x'),
):
    plt.plot(series, label=series_label, marker=series_marker)
plt.xlabel('Time index')
plt.ylabel('Value')
plt.title(f'Prediction vs Target for Feature {feature_idx}, Segment {segment_idx}')
plt.legend()
plt.grid(True)
plt.show()

model, trainer  # Last expression of the cell: the notebook displays both objects

In [None]:
def plot_sample_prediction(model, dataset, device, sample_idx, segment_idx, feature_idx, num_values=30):
    """Run the model on one dataset sample and plot predicted vs. target values.

    Args:
        model: Model called with a batch dict holding 'data' and 'mask'; its
            output must provide 'sequence_output' and 'target_segments'.
        dataset: Indexable dataset whose items are dicts with 'data', 'mask'
            and 'metadata' entries.
        device: Device the single-sample batch is moved to.
        sample_idx: Index of the sample within `dataset`.
        segment_idx: Index of the segment (first axis after the batch
            dimension is removed) to plot.
        feature_idx: Feature number shown in the plot title.
            NOTE(review): it is not used to slice the outputs — confirm the
            layout of 'sequence_output' before relying on the label.
        num_values: How many leading values of the segment to plot.

    Returns:
        None. Prints the sample metadata and shows a matplotlib figure.
    """
    sample = dataset[sample_idx]
    print(f"Sample metadata: {sample['metadata']}")

    # The model consumes a batch dict, so add a leading batch dimension of size 1.
    batch = {
        'data': sample['data'].unsqueeze(0).to(device),
        'mask': sample['mask'].unsqueeze(0).to(device),
    }

    # Inference only: eval mode, no gradient tracking.
    model.eval()
    with torch.no_grad():
        output = model(batch)

    # Drop the batch dimension and move to numpy for plotting.
    predicted_segments = output['sequence_output'][0].cpu().numpy()
    target_segments = output['target_segments'][0].cpu().numpy()

    predicted_values = predicted_segments[segment_idx, :num_values]
    target_values = target_segments[segment_idx, :num_values]

    plt.figure(figsize=(12, 4))
    plt.plot(predicted_values, label='Predicted', marker='o')
    plt.plot(target_values, label='Target', marker='x')
    plt.xlabel('Time index')
    plt.ylabel('Value')
    plt.title(f'Prediction vs Target for Feature {feature_idx}, Segment {segment_idx}')
    plt.legend()
    plt.grid(True)
    plt.show()


# Inspect a different validation sample and segment than the first check.
print("Making predictions on a validation sample...")
plot_sample_prediction(
    model,
    val_dataset,
    device,
    sample_idx=1000,
    segment_idx=116,
    feature_idx=3,
)

model, trainer  # Last expression of the cell: the notebook displays both objects

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, Subset


def visualize_sample_day_feature(
    dataset, 
    sample_idx: int, 
    day_index: int, 
    feature_index: int,
    include_mask_overlay: bool = True
):
    """
    Plot one feature channel of one day from a dataset sample.

    Problems (wrong type, out-of-range indices, retrieval failures) are reported
    with a printed message and the function returns without plotting.

    Args:
        dataset: A PyTorch Dataset, or a Subset (possibly nested) wrapping one.
            Items must be dicts with a 3-D 'data' tensor laid out as
            (days, features, time points); 'mask' and 'metadata' are optional.
        sample_idx: Index of the sample, relative to the dataset/subset passed in.
        day_index: 0-based day within the sample.
        feature_index: 0-based feature channel; must be smaller than the
            sample's feature dimension.
        include_mask_overlay: If True, and the underlying dataset has a truthy
            `include_mask` attribute and the sample carries a 'mask', time
            points whose mask value is 0 are highlighted.

    Returns:
        None. Shows a matplotlib figure on success.
    """

    # --- Input validation ---
    if not isinstance(dataset, (Dataset, Subset)):
         print("Error: 'dataset' must be a PyTorch Dataset or Subset.")
         return

    if sample_idx < 0 or sample_idx >= len(dataset):
        print(f"Error: sample_idx {sample_idx} is out of bounds for dataset/subset of size {len(dataset)}.")
        return

    # --- Locate the underlying dataset ---
    # Peel off every Subset layer (e.g. a Subset of a random_split part) so the
    # `include_mask` flag is read from the dataset that actually owns it.
    original_dataset = dataset
    while isinstance(original_dataset, Subset):
        original_dataset = original_dataset.dataset
    dataset_includes_mask = getattr(original_dataset, 'include_mask', False)

    # --- Retrieve the sample ---
    try:
        # Indexing works the same for a Dataset and a Subset.
        sample = dataset[sample_idx]
        data_tensor = sample['data']
        metadata = sample.get('metadata', {})
        health_code = metadata.get('healthCode', 'N/A')
        time_range = metadata.get('time_range', 'N/A')

        # A mask is used only if the sample has one AND the dataset was configured to include masks.
        has_mask = 'mask' in sample and dataset_includes_mask
        mask_tensor = sample['mask'] if has_mask else None

    except IndexError:
        print(f"Error: Could not retrieve sample at index {sample_idx}.")
        return
    except Exception as e:
        print(f"An error occurred while retrieving sample {sample_idx}: {e}")
        return

    # --- Shape and index validation ---
    if len(data_tensor.shape) != 3:
        print(f"Error: expected a 3-D data tensor (days, features, time points), got shape {tuple(data_tensor.shape)}.")
        return
    num_days, num_features, num_minutes = data_tensor.shape

    if day_index < 0 or day_index >= num_days:
        print(f"Error: day_index {day_index} is out of bounds. Sample has {num_days} days.")
        return

    if feature_index < 0 or feature_index >= num_features:
        print(f"Error: feature_index {feature_index} is out of bounds. Sample has {num_features} features.")
        return

    # A different resolution is only warned about; plotting still proceeds.
    if num_minutes != 1440:
         print(f"Warning: Expected 1440 minutes (time points), but found {num_minutes}.")

    # --- Extract the series (and mask) to plot ---
    feature_data = data_tensor[day_index, feature_index, :].cpu().numpy()
    time_axis = np.arange(num_minutes)

    masked_indices = None
    if include_mask_overlay and mask_tensor is not None:
        feature_mask = mask_tensor[day_index, feature_index, :].cpu().numpy()
        # Mask == 0 marks a missing/masked time point.
        masked_indices = np.where(feature_mask == 0)[0]
        print(f"Found {len(masked_indices)} masked points for this feature/day.")

    # --- Plotting ---
    plt.figure(figsize=(15, 5))

    plt.plot(time_axis, feature_data, label=f'Feature {feature_index} Data', color='dodgerblue', linewidth=1)

    # Overlay the masked points on top of the line when there are any.
    if masked_indices is not None and len(masked_indices) > 0:
        plt.scatter(time_axis[masked_indices], feature_data[masked_indices], 
                    color='red', marker='x', s=20, label='Masked/Missing (Mask=0)', zorder=5)

    plt.title(f'Sample {sample_idx} (HC: {health_code}) - Day {day_index} - Feature {feature_index}\nTime Range: {time_range}')
    plt.xlabel('Minute of Day')
    plt.ylabel('Feature Value (potentially standardized)')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.legend()
    plt.xlim(0, num_minutes - 1)
    plt.tight_layout()
    plt.show()

In [None]:
# Inspect feature 5 on day 3 of validation sample 1005, with masked points highlighted.
visualize_sample_day_feature(
    val_dataset,
    sample_idx=1005,
    day_index=3,
    feature_index=5,
    include_mask_overlay=True,
)