# Imports

In [1]:
# General + Pytorch training
import os
import torch
import wandb
import pickle
import numpy as np
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import torch.nn.functional as F
from pytorch_lightning.loggers import WandbLogger
from scipy.spatial.transform import Rotation as R
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

In [2]:
seed = 42
pl.seed_everything(seed)

Seed set to 42


42

# Settings

In [3]:
# Set the float32 matmul precision for Tensor Cores
torch.set_float32_matmul_precision('high')

# Load Data

In [4]:
from pathlib import Path
import pickle

ep_path = Path("../../episode_data")
episodes = []

for file in ep_path.glob("*.pkl"):  # Use "*.pickle" if your files use that extension
    with open(file, "rb") as f:
        episode = pickle.load(f)
        episodes.append(episode)

In [5]:
len(episodes)

3

In [6]:
import numpy as np

X = []
Y = []

for episode in episodes:  # assuming episodes is a list of loaded pickles
    knots_list = episode["knots"]
    
    for kinfo in knots_list:
        # Build input (concatenate position and velocity)
        x = np.concatenate([kinfo["qpos"], kinfo["qvel"]], axis=0)

        # Build output (flattened knot vector)
        y = kinfo["knots"].reshape(-1)

        X.append(x)
        Y.append(y)

X = np.array(X)
Y = np.array(Y)

In [7]:
X.shape

(18465, 95)

In [8]:
Y.shape

(18465, 164)

In [9]:
# Convert to torch
X_t = torch.from_numpy(X).float()
Y_t = torch.from_numpy(Y).float()

# MLP Regressor

In [10]:
class MLPRegressor(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim1, hidden_dim2, hidden_dim3, output_dim=1, learning_rate=1e-3):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.BatchNorm1d(hidden_dim1),
            nn.ReLU(),
            
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.BatchNorm1d(hidden_dim2),
            nn.ReLU(),
            
            nn.Linear(hidden_dim2, hidden_dim3),
            nn.BatchNorm1d(hidden_dim3),
            nn.ReLU(),
            
            nn.Linear(hidden_dim3, output_dim)
        )
        self.test_outputs = []
        self.criterion = nn.MSELoss()
        self.learning_rate = learning_rate

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y = y.squeeze(1)
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y = y.squeeze(1)
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y = y.squeeze(1)
        y_hat = self(x)
        self.test_outputs.append({'y_hat': y_hat, 'y': y})
        return {'y_hat': y_hat, 'y': y}

    def on_test_epoch_end(self):
        y_hat = torch.cat([o['y_hat'] for o in self.test_outputs], dim=0).cpu().numpy()
        y = torch.cat([o['y'] for o in self.test_outputs], dim=0).cpu().numpy()

        # Flatten if needed: remove extra dimensions
        y_hat = y_hat.squeeze()
        y = y.squeeze()

        # If they're still 3D, reshape explicitly
        if y_hat.ndim > 2:
            y_hat = y_hat.reshape(y_hat.shape[0], -1)
        if y.ndim > 2:
            y = y.reshape(y.shape[0], -1)
        
        mse = mean_squared_error(y, y_hat)
        mae = mean_absolute_error(y, y_hat)
        r2 = r2_score(y, y_hat)
        self.log('test_mse', mse)
        self.log('test_mae', mae)
        self.log('test_r2', r2)
        print(f'Test MSE: {mse}, MAE: {mae}, R2: {r2}')
        self.test_outputs.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.75, patience=5, mode = 'min', cooldown = 2, min_lr = 10e-7)
        optimizer_dict = {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
        return optimizer_dict

# Training

In [11]:
def load_data(inputs, outputs, batch_size=32, num_workers=10):
    X_train, X_temp, y_train, y_temp = train_test_split(inputs, outputs, test_size=0.1, random_state=seed)
    X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=seed)

    train_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train).unsqueeze(1))
    val_dataset = TensorDataset(torch.Tensor(X_val), torch.Tensor(y_val).unsqueeze(1))
    test_dataset = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test).unsqueeze(1))
    
    # train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    # val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True)
    # test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)
    
    return train_loader, val_loader, test_loader

In [12]:
# Define sweep configuration
sweep_config = {
    'method': 'grid',
    'metric': {
        'name': 'val_loss',
        'goal': 'minimize'
    },
    'parameters': {
        'hidden_dim1': {'values': [512]},
        'hidden_dim2': {'values': [512]},
        'hidden_dim3': {'values': [512]},
        'learning_rate': {'values': [1e-3]}
    }
}

In [13]:
# def save_model_as_pt(model, checkpoint_path, save_path):
#     # Load the best model checkpoint
#     checkpoint = torch.load(checkpoint_path)
#     model.load_state_dict(checkpoint['state_dict'])
    
#     # Save the model with torch script
#     model.eval()
#     model_compiled = torch.jit.script(model)
#     torch.jit.save(model_compiled, save_path)
#     print(f'Model saved to {save_path}')

def save_model_state_dict(model, checkpoint_path, save_path):
    # Load the best model checkpoint
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])
    
    # Save only the state_dict (weights)
    torch.save(model.state_dict(), save_path)
    print(f'Model state_dict saved to {save_path}')

In [14]:
def train_model(config=None):
    with wandb.init(config=config):
        config = wandb.config

        model = MLPRegressor(
            input_dim=X_t.shape[1],
            hidden_dim1=config.hidden_dim1,
            hidden_dim2=config.hidden_dim2,
            hidden_dim3=config.hidden_dim3,
            output_dim=Y_t.shape[1],
            learning_rate=config.learning_rate
        )
        
        train_loader, val_loader, test_loader = load_data(X_t, Y_t, batch_size=512, num_workers=10)

        # Define ModelCheckpoint callback to save the best model
        checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',
            dirpath='./knots_prediction_checkpoints',
            filename='model-{epoch:02d}-{val_loss:.6f}',
            save_top_k=1,
            mode='min'
        )

        trainer = pl.Trainer(
            logger=WandbLogger(),
            max_epochs=1000,
            accelerator='gpu',
            devices=1,
            enable_progress_bar=False,  # Suppress the progress bar
            log_every_n_steps=5,  # Adjust logging frequency
            callbacks=[checkpoint_callback, EarlyStopping(monitor='val_loss', patience=5)]
        )
        
        trainer.fit(model, train_loader, val_loader)
        trainer.test(model, dataloaders=test_loader)

        # Save the best model as .pt file
        best_checkpoint_path = checkpoint_callback.best_model_path
        save_path = best_checkpoint_path.replace('.ckpt', '.pt')
        save_model_state_dict(model, best_checkpoint_path, save_path)

In [15]:
# Initialize wandb
wandb.init(project='knots_prediction')

# # Shut down previous sessions
# wandb.finish()

sweep_id = wandb.sweep(sweep_config, project='knots_prediction')

# Run the sweep
wandb.agent(sweep_id, function=train_model)

[34m[1mwandb[0m: Currently logged in as: [33mitaouil[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Create sweep with ID: i802zqwg
Sweep URL: https://wandb.ai/itaouil/knots_prediction/sweeps/i802zqwg


[34m[1mwandb[0m: Agent Starting Run: 61uepuif with config:
[34m[1mwandb[0m: 	hidden_dim1: 512
[34m[1mwandb[0m: 	hidden_dim2: 512
[34m[1mwandb[0m: 	hidden_dim3: 512
[34m[1mwandb[0m: 	learning_rate: 0.001


[1;34mwandb[0m: 
[1;34mwandb[0m: 🚀 View run [33mclear-aardvark-13[0m at: [34mhttps://wandb.ai/itaouil/knots_prediction/runs/pi40me45[0m
[1;34mwandb[0m: Find logs at: [1;35mwandb/run-20250703_133144-pi40me45/logs[0m


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/ilyass/mambaforge/envs/hydrax/lib/python3.12/site-packages/pytorch_lightning/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params | Mode 
-------------------------------------------------
0 | model     | Sequential | 661 K  | train
1 | criterion | MSELoss    | 0      | train
-------------------------------------------------
661 K     Trainable params
0         Non-trainable params
661 K     Total params
2.647     Total estimated model params size (MB)
12        Modules in train mode
0         Modules in eval mode
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Test MSE: 0.003671003971248865, MAE: 0.044652439653873444, R2: 0.7124336361885071


Model state_dict saved to /home/ilyass/workspace/Text2Motion/learning/notebooks/knots_prediction_checkpoints/model-epoch=72-val_loss=0.003503.pt


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▇▇█████
test_mae,▁
test_mse,▁
test_r2,▁
train_loss,█▇▅▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇███
val_loss,█▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,78.0
test_mae,0.04465
test_mse,0.00367
test_r2,0.71243
train_loss,0.00285
trainer/global_step,2574.0
val_loss,0.00369


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Sweep Agent: Exiting.
