In [1]:
import ast
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, random_split

# Import the MHC dataset and LSTM model
from torch_dataset import BaseMhcDataset
from models.lstm import AutoencoderLSTM, LSTMTrainer

In [None]:
# Per-feature standardization parameters, read from a CSV with at least the
# columns "mean" and "std_dev".
STANDARDIZATION_PARAMS_PATH = "/scratch/users/schuetzn/data/mhc_dataset_out/standardization_params.csv"
# Number of leading rows (features) whose stats are used; rows are presumably in
# the same order as the dataset's feature channels -- TODO confirm.
N_FEATURES = 6

standardization_df = pd.read_csv(STANDARDIZATION_PARAMS_PATH)

# .iloc[:N_FEATURES] silently returns fewer rows when the file is short, which
# would leave some features without stats; fail loudly instead.
assert len(standardization_df) >= N_FEATURES, (
    f"Expected at least {N_FEATURES} rows of standardization params, "
    f"got {len(standardization_df)}"
)

# Map row index -> (mean, std_dev); this dict is what gets passed to
# BaseMhcDataset as `feature_stats`.
scaler_stats = {
    f_idx: (row["mean"], row["std_dev"])
    for f_idx, row in standardization_df.iloc[:N_FEATURES].iterrows()
}

scaler_stats

In [None]:
# Set random seeds for reproducibility. The torch seed also fixes the
# train/validation split below (random_split draws from torch's default
# generator) and the model's weight initialisation.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Input locations (hardcoded; edit here to point at a different dataset)
dataset_parquet_path = "../test_dataset.parquet"
root_dir = "/scratch/groups/euan/mhc/mhc_dataset"

print(f"Loading dataset from {dataset_parquet_path}")
print(f"Using root directory: {root_dir}")

# Load the denormalized dataset from parquet.
# `file_uris` is stored as the text repr of a Python literal (presumably a list
# of paths -- TODO confirm). ast.literal_eval only parses literals, whereas
# eval() would execute any code embedded in the data file.
df = pd.read_parquet(dataset_parquet_path)
df["file_uris"] = df["file_uris"].apply(ast.literal_eval)

print(f"Loaded dataset with {len(df)} samples")

# Label columns are identified by their '_value' suffix
label_cols = [col for col in df.columns if col.endswith('_value')]
print(f"Available label columns: {label_cols}")

# Create the dataset with mask, standardized with `scaler_stats` from the cell above
dataset = BaseMhcDataset(df, root_dir, include_mask=True, feature_stats=scaler_stats)

# Split into train and validation sets (80% train, 20% validation)
train_fraction = 0.8
train_size = int(train_fraction * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
print(f"Split dataset into {train_size} training and {val_size} validation samples")

# Create data loaders. Validation is not shuffled, but uses the same number of
# worker processes as training for consistency.
batch_size = 128
num_workers = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers)

# Labels to predict, given without the '_value' suffix, e.g.
#   [label.replace('_value', '') for label in label_cols[:2]]
# Left empty here. How AutoencoderLSTM treats [] versus None is not visible in
# this notebook -- confirm before switching between them.
target_labels = []

# Choose the device first so the model is moved onto it before the optimizer
# is constructed (PyTorch recommends moving a model before building its optimizer).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = AutoencoderLSTM(
    hidden_size=256,
    encoding_dim=256,
    num_layers=5,
    dropout=0.1,
    bidirectional=False,
    target_labels=target_labels,
    prediction_horizon=1,
    use_masked_loss=True
).to(device)
print(f"Initialized LSTM model with target labels: {target_labels}")

# Set up optimizer and trainer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
trainer = LSTMTrainer(model, optimizer, device)

# Train the model
num_epochs = 5  # Reduced for demonstration
print(f"Training for {num_epochs} epochs...")
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    train_loss = trainer.train_epoch(train_loader)
    val_loss = trainer.validate(val_loader)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

# Plot training and validation loss; the x-axis is 1-based to match the epoch log above
epochs = range(1, num_epochs + 1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs, train_losses, label='Training Loss')
ax.plot(epochs, val_losses, label='Validation Loss')
ax.set(xlabel='Epoch', ylabel='Loss', title='Training and Validation Loss')
ax.legend()
ax.grid(True)
plt.show()

# Make predictions on a single validation sample
num_plot_steps = 30  # number of leading values drawn in the comparison plot
print("Making predictions on a validation sample...")
sample_idx = 0
sample = val_dataset[sample_idx]

# Print metadata for the sample
print(f"Sample metadata: {sample['metadata']}")

# Add a batch dimension and move the inputs to the model's device
batch = {
    'data': sample['data'].unsqueeze(0).to(device),
    'mask': sample['mask'].unsqueeze(0).to(device),
}

# Only the forward pass needs gradient tracking disabled
model.eval()
with torch.no_grad():
    output = model(batch)

# Remove the batch dimension and convert to numpy for plotting
predicted_segments = output['sequence_output'][0].cpu().numpy()
target_segments = output['target_segments'][0].cpu().numpy()
print(f"Predicted shape: {predicted_segments.shape}, target shape: {target_segments.shape}")

feature_idx = 0  # First feature
segment_idx = 0  # First segment

# NOTE(review): feature_idx is not applied to the arrays. The slice below takes
# segment `segment_idx` and the first `num_plot_steps` entries of the next axis.
# Whether that axis is time or feature depends on AutoencoderLSTM's output layout,
# which is not visible here. Confirm the layout and index the feature axis
# explicitly so the plot title is accurate.
predicted_values = predicted_segments[segment_idx, :num_plot_steps]
target_values = target_segments[segment_idx, :num_plot_steps]

fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(predicted_values, label='Predicted', marker='o')
ax.plot(target_values, label='Target', marker='x')
ax.set(
    xlabel='Time index',
    ylabel='Value',
    title=f'Prediction vs Target for Feature {feature_idx}, Segment {segment_idx}',
)
ax.legend()
ax.grid(True)
plt.show()

model, trainer  # Display the trained model and trainer