### Load Data and Setup

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm

from helper_functions import train, create_block_stack, BlockDataset, LoadDataFrames

# Set device to GPU if available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so model weight initialisation (and any torch-side sampling,
# such as DataLoader shuffling) is repeatable under Restart & Run All. 42 matches the
# random_state used for the train/test split below.
# NOTE: some CUDA kernels remain nondeterministic even with a fixed seed.
torch.manual_seed(42)

# Load all data files using the project's data loader class
data_loader = LoadDataFrames()
data_loader.load_data()
train_df = data_loader.train_df
static_df = data_loader.static_df
test_df = data_loader.load_test_data()
creator_encodings = data_loader.creator_encoding_df

### Data Preprocessing

In [None]:
# Turn the training frames (plus static features and creator encodings) into
# block-stacked tensors consumed by BlockDataset below.
# NOTE(review): `stop` presumably caps how much of train_df is converted; whether it
# counts rows or blocks is defined in helper_functions.create_block_stack — confirm.
BLOCK_STACK_STOP = 20000
features, creator_idx, labels = create_block_stack(
    features_df=train_df,
    creator_encodings=creator_encodings,
    static_df=static_df,
    stop=BLOCK_STACK_STOP,
)

In [None]:
# Wrap the block tensors in a PyTorch dataset
dataset = BlockDataset(features, creator_idx, labels)

# Split into 80% training, 20% testing; the fixed random_state makes the split reproducible
train_blocks, test_blocks = train_test_split(dataset, test_size=0.2, shuffle=True, random_state=42)

# Training loader: reshuffle every epoch so batches are not seen in the same fixed order.
# The seeded generator keeps that shuffling reproducible across runs.
train_dataloader = DataLoader(
    train_blocks,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    generator=torch.Generator().manual_seed(42),
)
# Evaluation loader: fixed order.
# NOTE(review): test_dataloader is not passed to train() in this notebook — confirm how
# results['val_loss'] / results['val_metrics'] are produced.
test_dataloader = DataLoader(test_blocks, batch_size=32, shuffle=False, num_workers=2)

### Model Setup

In [None]:
from GRU_model import GRUEmbedding

# Model hyperparameters
num_features = 19      # input features per timestep
hidden_size = 512      # width of the GRU hidden state
num_layers = 2         # number of stacked GRU layers
embedding_size = 1000  # creator-embedding size — NOTE(review): never passed to GRUEmbedding below; confirm whether the model should receive it
num_classes = 1        # single output unit (binary classification)

# Build the network, move it to the selected device, then wrap it with torch.compile
gru_network = GRUEmbedding(
    num_features=num_features,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_classes=num_classes,
    device=device,
)
model = torch.compile(gru_network.to(device))

# Training-loss history; the training cell adds to it after each run
losses = []

### Model Training

In [None]:
# Training hyperparameters
epochs = 1
learning_rate = 3e-4
optimizer = torch.optim.AdamW  # the optimizer *class* (not an instance) is handed to train()

# Train the model and collect the results dict (train/val losses and val metrics).
# NOTE(review): only train_dataloader is passed; confirm where train() gets validation data.
results = train(
    model,
    train_dataloader,
    n_epoch=epochs,
    report_every=1,
    learning_rate=learning_rate,
    optimizer=optimizer,
    device=device,
)

# results['train_loss'] is a sequence (it is plotted directly later on), so extend the running
# history rather than appending the whole list as one nested element. A nested list makes
# plt.plot treat `losses` as a 2-D array (one series per column) instead of one loss curve.
# Requires the model-setup cell to have created `losses`; re-running this cell continues
# training the same model and keeps accumulating its history.
losses.extend(results['train_loss'])

### Results and Visualization

In [None]:
# Plot the accumulated training-loss history across epochs.
# Markers keep the plot visible when only a single epoch has been recorded.
fig, ax = plt.subplots()
ax.plot(losses, marker='o')
ax.set(title='Training loss history', xlabel='Epoch', ylabel='Loss')
plt.show()

In [None]:
# Compare validation and training losses from the latest training run.
# Markers keep single-point series (e.g. epochs = 1) visible; a plain line needs >= 2 points.
fig, ax = plt.subplots()
ax.plot(results['val_loss'], marker='o', label='Validation loss')
ax.plot(results['train_loss'], marker='o', label='Training loss')
ax.set(title='Training vs. validation loss', xlabel='Epoch', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Print each validation metric recorded for the most recent epoch,
# formatted to four decimal places (values are expected to be numeric).
print("\nFinal validation metrics:")
final_metrics = results['val_metrics'][-1]  # last entry = most recent epoch
for metric, value in final_metrics.items():
    print(f"  {metric}: {value:.4f}")

In [None]:
# Tabulate validation metrics: one row per entry in results['val_metrics'] (one per epoch),
# one column per metric name. Renders as a rich HTML table instead of a raw list repr.
pd.DataFrame(results['val_metrics'])

### Save Model

In [None]:
# Save only the learned parameters (the state_dict), not the whole pickled module.
# `model` is the torch.compile wrapper. Its state_dict keys carry an "_orig_mod." prefix,
# which would not load into a plain GRUEmbedding. Save the underlying module's weights
# instead; the getattr fallback also covers an uncompiled model.
model_to_save = getattr(model, '_orig_mod', model)
torch.save(obj=model_to_save.state_dict(), f='model.pth')