# MADVRL Poker VAE Training Notebook

## Setup and Imports

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from vae_preprocessing import create_vae_dataloader
from vae import PokerHiddenStateVAE, train_poker_vae

## Hyperparameters and Configuration

In [None]:
# Network dimensions handed to the VAE constructor further down.
input_dim, hidden_state_dim = 10, 5
hidden_dims = [64, 128]
latent_dim = 32

# Optimisation settings.
batch_size, learning_rate, epochs = 256, 1e-3, 1000

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## Data Preparation

In [None]:
# Build the training DataLoader from the parquet file of game metrics.
# NOTE(review): `batch_size` from the configuration cell is not forwarded
# here; confirm that create_vae_dataloader's own default is what is intended.
parquet_path = 'poker_game_metrics_full.parquet'
dataloader = create_vae_dataloader(parquet_path)

## Model Initialization

In [None]:
# Collect the constructor arguments first, then build the VAE and move it
# onto the selected device.
vae_kwargs = {
    'input_dim': input_dim,
    'hidden_state_dim': hidden_state_dim,
    'hidden_dims': hidden_dims,
    'latent_dim': latent_dim,
    'beta': 0.1,   # KL divergence weight
    'alpha': 1.0,  # Hidden state supervision weight
}
model = PokerHiddenStateVAE(**vae_kwargs)
model = model.to(device)

## Training

In [None]:
# Optimizer and Scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Halve the learning rate once the monitored value has stopped decreasing
# for `patience` scheduler steps (what is monitored is decided inside
# train_poker_vae -- presumably a loss; verify there).
# `verbose` is intentionally not passed: it is deprecated since PyTorch 2.2
# and no longer accepted by newer releases. Use scheduler.get_last_lr() to
# inspect the current learning rate instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=50
)

# Train Model
# `metrics` is later indexed by 'total_loss', 'reconstruction_loss',
# 'kl_divergence' and 'hidden_state_loss' in the plotting cell.
metrics = train_poker_vae(
    model,
    dataloader,
    optimizer,
    device=device,
    epochs=epochs,
    log_interval=50,
    scheduler=scheduler,
)

## Visualization and Analysis

In [None]:
# Plot Training Metrics
# One panel per tracked quantity on a 2x2 grid, every panel on a log y-axis.
# Each entry: (key in `metrics`, curve name, y-axis label, extra line styling).
panel_specs = [
    ('total_loss', 'Total Loss', 'Loss', {}),
    ('reconstruction_loss', 'Reconstruction Loss', 'Loss', {'color': 'green'}),
    ('kl_divergence', 'KL Divergence', 'KL Divergence', {'color': 'red'}),
    ('hidden_state_loss', 'Hidden State Loss', 'Loss', {'color': 'purple'}),
]

plt.figure(figsize=(15, 10))

for panel_no, (metric_key, curve_name, y_label, line_style) in enumerate(panel_specs, start=1):
    plt.subplot(2, 2, panel_no)
    plt.plot(metrics[metric_key], label=curve_name, **line_style)
    plt.title(f'{curve_name} over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.yscale('log')
    plt.legend()

# Save before show(): showing may clear the figure in some backends.
plt.tight_layout()
plt.savefig('training_metrics.png')
plt.show()

## Model Saving

In [None]:
# Write the trained model's state_dict to disk.
checkpoint_path = 'poker_hidden_state_vae.pth'
trained_weights = model.state_dict()
torch.save(trained_weights, checkpoint_path)
print("Model saved successfully!")