# Imports

In [1]:
# General + Pytorch training
import os
import torch
import wandb
import pickle
import numpy as np
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import torch.nn.functional as F
from pytorch_lightning.loggers import WandbLogger
from scipy.spatial.transform import Rotation as R
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

In [2]:
seed = 42
pl.seed_everything(seed)

Seed set to 42


42

# Settings

In [3]:
# Set the float32 matmul precision for Tensor Cores
torch.set_float32_matmul_precision('high')

# Load Data

In [4]:
from pathlib import Path
import pickle

ep_path = Path("../episode_data/")
episodes = []

for file in ep_path.glob("*.pkl"):  # Use "*.pickle" if your files use that extension
    with open(file, "rb") as f:
        episode = pickle.load(f)
        episodes.append(episode)

In [5]:
len(episodes[0]["knots"])

3971

In [6]:
episodes[0]["knots"]

[{'knots': array([[ 4.30552214e-02, -8.77010729e-03,  2.45102808e-01,
           2.96669751e-02, -3.11908871e-02,  2.99022999e-02,
           9.08855125e-02,  1.06557943e-01, -2.11720332e-03,
           9.36844051e-02,  7.79056773e-02, -1.98710710e-02,
           8.26807171e-02, -2.09094379e-02,  2.42876872e-01,
           9.14765615e-03,  3.98620330e-02, -3.45996432e-02,
          -2.37988457e-01,  1.50555298e-01, -3.82377356e-01,
          -1.27633354e-02,  1.44044474e-01, -2.95621008e-01,
          -1.31469294e-01, -2.06641793e-01, -2.87596524e-01,
          -2.97058113e-02,  1.56118751e-01, -2.01018095e-01,
           5.12083471e-02, -1.55928537e-01, -7.15920478e-02,
          -1.18688997e-02,  5.03562689e-02,  1.72168035e-02,
          -1.38804927e-01,  1.43858880e-01,  2.93762088e-01,
           2.76415318e-01,  1.02661930e-01],
         [ 1.10030139e-03,  1.29094943e-01, -1.95780903e-01,
           1.66269839e-01,  2.13320598e-01, -9.31192562e-02,
           7.48682246e-02, -3.4

In [7]:
import numpy as np

X = []
Y = []

for episode in episodes:  # assuming episodes is a list of loaded pickles
    knots_list = episode["knots"]
    
    for kinfo in knots_list:
        # Build input (concatenate position and velocity)
        x = np.concatenate([kinfo["qpos"], kinfo["qvel"]], axis=0)

        # Build output (flattened knot vector)
        y = kinfo["knots"].reshape(-1)

        X.append(x)
        Y.append(y)

X = np.array(X)
Y = np.array(Y)

In [8]:
X.shape

(3971, 95)

In [9]:
Y.shape

(3971, 164)

In [12]:
# Convert to torch
X_t = torch.from_numpy(X).float()
Y_t = torch.from_numpy(Y).float()

# MLP Regressor

In [20]:
class MLPRegressor(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim1, hidden_dim2, hidden_dim3, output_dim=1, learning_rate=1e-3):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.GELU(),
            
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.GELU(),
            
            nn.Linear(hidden_dim2, hidden_dim3),
            nn.GELU(),
            
            nn.Linear(hidden_dim3, output_dim)
        )
        self.test_outputs = []
        self.criterion = nn.MSELoss()
        self.learning_rate = learning_rate

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        self.test_outputs.append({'y_hat': y_hat, 'y': y})
        return {'y_hat': y_hat, 'y': y}

    def on_test_epoch_end(self):
        y_hat = torch.cat([o['y_hat'] for o in self.test_outputs], dim=0).cpu().numpy()
        y = torch.cat([o['y'] for o in self.test_outputs], dim=0).cpu().numpy()

        # Flatten if needed: remove extra dimensions
        y_hat = y_hat.squeeze()
        y = y.squeeze()

        # If they're still 3D, reshape explicitly
        if y_hat.ndim > 2:
            y_hat = y_hat.reshape(y_hat.shape[0], -1)
        if y.ndim > 2:
            y = y.reshape(y.shape[0], -1)
        
        mse = mean_squared_error(y, y_hat)
        mae = mean_absolute_error(y, y_hat)
        r2 = r2_score(y, y_hat)
        self.log('test_mse', mse)
        self.log('test_mae', mae)
        self.log('test_r2', r2)
        print(f'Test MSE: {mse}, MAE: {mae}, R2: {r2}')
        self.test_outputs.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.75, patience=5, mode = 'min', cooldown = 2, min_lr = 10e-7)
        optimizer_dict = {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
        return optimizer_dict

# Training

In [21]:
def load_data(inputs, outputs, batch_size=32, num_workers=10):
    X_train, X_temp, y_train, y_temp = train_test_split(inputs, outputs, test_size=0.1, random_state=seed)
    X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=seed)

    train_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train).unsqueeze(1))
    val_dataset = TensorDataset(torch.Tensor(X_val), torch.Tensor(y_val).unsqueeze(1))
    test_dataset = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test).unsqueeze(1))
    
    # train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    # val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True)
    # test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)
    
    return train_loader, val_loader, test_loader

In [22]:
# Define sweep configuration
sweep_config = {
    'method': 'grid',
    'metric': {
        'name': 'val_loss',
        'goal': 'minimize'
    },
    'parameters': {
        'hidden_dim1': {'values': [1024]},
        'hidden_dim2': {'values': [1024]},
        'hidden_dim3': {'values': [1024]},
        'learning_rate': {'values': [1e-3]}
    }
}

In [23]:
# def save_model_as_pt(model, checkpoint_path, save_path):
#     # Load the best model checkpoint
#     checkpoint = torch.load(checkpoint_path)
#     model.load_state_dict(checkpoint['state_dict'])
    
#     # Save the model with torch script
#     model.eval()
#     model_compiled = torch.jit.script(model)
#     torch.jit.save(model_compiled, save_path)
#     print(f'Model saved to {save_path}')

def save_model_state_dict(model, checkpoint_path, save_path):
    # Load the best model checkpoint
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])
    
    # Save only the state_dict (weights)
    torch.save(model.state_dict(), save_path)
    print(f'Model state_dict saved to {save_path}')

In [24]:
def train_model(config=None):
    with wandb.init(config=config):
        config = wandb.config

        model = MLPRegressor(
            input_dim=X_t.shape[1],
            hidden_dim1=config.hidden_dim1,
            hidden_dim2=config.hidden_dim2,
            hidden_dim3=config.hidden_dim3,
            output_dim=Y_t.shape[1],
            learning_rate=config.learning_rate
        )
        
        train_loader, val_loader, test_loader = load_data(X_t, Y_t, batch_size=512, num_workers=10)

        # Define ModelCheckpoint callback to save the best model
        checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',
            dirpath='./value_function_checkpoints',
            filename='model-{epoch:02d}-{val_loss:.6f}',
            save_top_k=1,
            mode='min'
        )

        trainer = pl.Trainer(
            logger=WandbLogger(),
            max_epochs=100,
            accelerator='gpu',
            devices=1,
            enable_progress_bar=False,  # Suppress the progress bar
            log_every_n_steps=5,  # Adjust logging frequency
            callbacks=[checkpoint_callback, EarlyStopping(monitor='val_loss', patience=5)]
        )
        
        trainer.fit(model, train_loader, val_loader)
        trainer.test(model, dataloaders=test_loader)

        # Save the best model as .pt file
        best_checkpoint_path = checkpoint_callback.best_model_path
        save_path = best_checkpoint_path.replace('.ckpt', '.pt')
        save_model_state_dict(model, best_checkpoint_path, save_path)

In [25]:
# Initialize wandb
wandb.init(project='value-function-regression')

# # Shut down previous sessions
# wandb.finish()

sweep_id = wandb.sweep(sweep_config, project='value-function-regression')

# Run the sweep
wandb.agent(sweep_id, function=train_model)

[34m[1mwandb[0m: Currently logged in as: [33mitaouil[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Create sweep with ID: cpt262ku
Sweep URL: https://wandb.ai/itaouil/value-function-regression/sweeps/cpt262ku


[34m[1mwandb[0m: Agent Starting Run: c6nhegwn with config:
[34m[1mwandb[0m: 	hidden_dim1: 1024
[34m[1mwandb[0m: 	hidden_dim2: 1024
[34m[1mwandb[0m: 	hidden_dim3: 1024
[34m[1mwandb[0m: 	learning_rate: 0.001


[1;34mwandb[0m: 
[1;34mwandb[0m: 🚀 View run [33mastral-darkness-703[0m at: [34mhttps://wandb.ai/itaouil/value-function-regression/runs/olicobjl[0m
[1;34mwandb[0m: Find logs at: [1;35mwandb/run-20250625_092515-olicobjl/logs[0m


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/ilyass/mambaforge/envs/hydrax/lib/python3.12/site-packages/pytorch_lightning/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params | Mode 
-------------------------------------------------
0 | model     | Sequential | 2.4 M  | train
1 | criterion | MSELoss    | 0      | train
-------------------------------------------------
2.4 M     Trainable params
0         Non-trainable params
2.4 M     Total params
9.462     Total estimated model params size (MB)
9         Modules in train mode
0         Modules in eval mode
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Test MSE: 0.10102978348731995, MAE: 0.1655392348766327, R2: -0.017823172733187675


Model state_dict saved to /home/ilyass/workspace/Text2Motion/notebooks/value_function_checkpoints/model-epoch=11-val_loss=0.099763.pt


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇██
test_mae,▁
test_mse,▁
test_r2,▁
train_loss,█▅▃▃▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
val_loss,█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,17.0
test_mae,0.16554
test_mse,0.10103
test_r2,-0.01782
train_loss,0.10037
trainer/global_step,119.0
val_loss,0.1002


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Sweep Agent: Exiting.
