# Refactored BitrateLSTM Colab Notebook
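This notebook clones the project repository, installs its dependencies, generates a dummy parquet dataset, and trains a 3-layer BitrateLSTM on it, with an optional final step that saves the trained weights to Google Drive.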

In [None]:
# Clone your project repository
!git clone https://github.com/YourUsername/split-inference-grpc-demo.git
%cd split-inference-grpc-demo

In [None]:
# Install Python dependencies
!pip install -r requirements.txt
# Ensure pyarrow/parquet support
!pip install pandas pyarrow

## Generate Dummy Data

In [None]:
# Generate a dummy training dataset (parquet).
# Invokes the repository's scripts/prepare_data.py with the --dummy flag and
# passes data/training_data.parquet as the --output path — the same path the
# TraceDataset below is pointed at.
# NOTE(review): whether the generated data is deterministic between runs
# depends on prepare_data.py, which is not visible here — confirm.
!python scripts/prepare_data.py --dummy --output data/training_data.parquet

## Load Dataset & DataLoader

In [None]:
import torch
from torch.utils.data import DataLoader
from core.dataset import TraceDataset

# Hyperparameters
SEED = 42        # seed for PyTorch's global RNG (see below)
SEQ_LEN = 10     # forwarded to TraceDataset as seq_len
BATCH_SIZE = 16  # samples per DataLoader batch

# Seed PyTorch before anything stochastic happens so that a fresh
# "Restart & Run All" reproduces the same shuffle order — and, because the
# global RNG state carries forward, the same downstream random draws.
torch.manual_seed(SEED)

# Initialize dataset & loader
dataset = TraceDataset(
    parquet_path='data/training_data.parquet',
    seq_len=SEQ_LEN,
    normalize=True
)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
# Draw a single batch only to report the shape of its first element (X).
print(f'Dataset size: {len(dataset)}, Example X shape: {next(iter(loader))[0].shape}')

## Define & Instantiate Model

In [None]:
from core.model import BitrateLSTM, quantile_loss

# Settings forwarded to the BitrateLSTM constructor
INPUT_SIZE = 1
HIDDEN_SIZE = 128
NUM_LAYERS = 3
NUM_OUTPUTS = 3
DROPOUT = 0.2
# Learning rate handed to the Adam optimizer
LR = 1e-3

# Gather the constructor arguments in one place, then instantiate the network
model_kwargs = dict(
    input_size=INPUT_SIZE,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
    num_outputs=NUM_OUTPUTS,
    dropout=DROPOUT,
)
model = BitrateLSTM(**model_kwargs)

# Adam optimizes every parameter the model exposes
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Show the module structure as a quick sanity check
print(model)

## Training Loop

In [None]:
EPOCHS = 3

# Switch the model to training mode once, before the first epoch
model.train()

for epoch in range(1, EPOCHS + 1):
    # Sum of per-batch losses, each weighted by its batch size
    total_loss = 0.0

    for X_batch, y_batch in loader:
        # Clear gradients from the previous step before computing new ones
        optimizer.zero_grad()

        # The model returns a pair; only the first element is used here
        preds, _ = model(X_batch)  # (B, 3)
        loss = quantile_loss(preds, y_batch)

        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch figure is a per-sample average
        # even when the final batch is smaller than the others
        batch_size = X_batch.size(0)
        total_loss = total_loss + batch_size * loss.item()

    avg_loss = total_loss / len(dataset)
    print(f'Epoch {epoch:02d}, Avg Loss: {avg_loss:.4f}')

## (Optional) Save Model to Drive

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive (Colab only)
drive.mount('/content/drive')

# Persist the model's state_dict (not the module object itself) to Drive
weights = model.state_dict()
torch.save(weights, '/content/drive/MyDrive/bitrate_model.pt')