In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import glob
import os
from tqdm import tqdm

class RolloutDataset(Dataset):
    """Flat collection of ``(obs, act)`` pairs read from saved rollout files.

    Every ``*.pt`` file directly inside ``data_dir`` is loaded with
    ``torch.load`` and iterated as a sequence of steps. Each step is indexed
    with the keys ``'obs'`` and ``'act'`` and both tensors are flattened.

    A file is skipped when it cannot be loaded or when it holds fewer than
    ``min_rollout_len`` steps. A step is skipped when its observation or
    action contains NaN/Inf, or when any action component is larger than
    ``max_abs_action`` in magnitude.

    Args:
        data_dir: Directory searched (non-recursively) for ``*.pt`` files.
        min_rollout_len: Rollouts with fewer steps than this are ignored.
        max_abs_action: Steps whose largest absolute action value exceeds
            this are ignored.

    Attributes:
        samples: List of ``(obs, act)`` tuples of flattened tensors.
        length: Number of samples collected during construction.
    """

    def __init__(self, data_dir, min_rollout_len=100, max_abs_action=10):
        self.samples = []
        filepaths = glob.glob(os.path.join(data_dir, '*.pt'))
        for filepath in tqdm(filepaths, desc="Loading rollouts", unit="file"):
            # tqdm.write keeps the progress bar intact, unlike a bare print().
            tqdm.write(filepath)
            try:
                # NOTE(review): torch.load unpickles the file; only use on trusted data.
                rollout = torch.load(filepath)
                num_steps = len(rollout)
            except Exception as exc:
                # Corrupt or foreign files do not always raise RuntimeError
                # (e.g. UnpicklingError, EOFError); one bad file must not
                # abort the whole load.
                tqdm.write(
                    f"Skipping unreadable file {os.path.basename(filepath)}"
                    f" ({type(exc).__name__})"
                )
                continue

            if num_steps < min_rollout_len:
                continue

            tqdm.write(f"rollout length: {num_steps}")
            for step in rollout:
                obs = step['obs'].flatten()
                act = step['act'].flatten()

                # Drop steps containing NaN or +/-Inf in either tensor.
                if not (torch.isfinite(obs).all() and torch.isfinite(act).all()):
                    continue

                # Drop steps with out-of-range actions; `.any()` is also
                # well-defined for an empty action tensor, unlike `.max()`.
                if (act.abs() > max_abs_action).any():
                    continue

                self.samples.append((obs, act))

        self.length = len(self.samples)

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``self.samples[idx]`` (a tuple, or a list when ``idx`` is a slice)."""
        return self.samples[idx]
    
def kl_divergence_fixed_std(mean_s, mean_t, std_t):
    """Mean of ``(mean_s - mean_t)^2 / (2 * std_t^2)`` over all elements.

    This is the form the KL divergence takes between two Gaussians that
    share the standard deviation ``std_t`` and differ only in their means.

    Args:
        mean_s: Tensor of student means.
        mean_t: Tensor of teacher means, broadcastable against ``mean_s``.
        std_t: Tensor of standard deviations, broadcastable against the means.

    Returns:
        A scalar tensor.
    """
    squared_error = torch.square(mean_s - mean_t)
    twice_variance = 2 * torch.square(std_t)
    return torch.mean(squared_error / twice_variance)


In [None]:
def train_and_export(dataset,
                     teacher_std_np,
                     epochs=50,
                     batch_size=128,
                     lr=3e-4,
                     in_dim=None,
                     base_model_name="models/tqw_student_policy_base_jit_275.pt",
                     device=None,
                     output_dir='models'):
    """Fine-tune a TorchScript policy on ``(obs, act)`` pairs and export it.

    The module stored at ``base_model_name`` is loaded with ``torch.jit.load``
    and optimised with Adam so that its output for ``obs`` matches ``act``
    under ``kl_divergence_fixed_std``. Afterwards the state dict is written to
    ``<output_dir>/distilled_policy_test.pth`` and the module itself to
    ``<output_dir>/distilled_policy.pt``.

    Args:
        dataset: Sized, indexable collection of ``(obs, act)`` tensor pairs.
        teacher_std_np: Array-like passed to the loss as ``std_t`` after a
            leading axis is added with ``unsqueeze(0)``.
        epochs: Number of passes over ``dataset``.
        batch_size: Mini-batch size for the ``DataLoader``.
        lr: Adam learning rate.
        in_dim: Size of the example input used for export. When ``None`` it
            is taken from the number of elements of the first observation.
        base_model_name: Path of the TorchScript file to start from.
        device: Device string; defaults to ``'cuda'`` when available, else
            ``'cpu'``.
        output_dir: Directory for the exported files (created if missing).

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

    num_samples = len(dataset)
    if num_samples == 0:
        # Fail before training rather than with a ZeroDivisionError later.
        raise ValueError("dataset is empty; nothing to train on")
    if in_dim is None:
        # Resolve this up front so a missing value cannot crash the export
        # step after the whole training run has already finished.
        in_dim = torch.as_tensor(dataset[0][0]).numel()

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # map_location puts the weights on the same device as the batches.
    policy = torch.jit.load(base_model_name, map_location=device)

    optimizer = optim.Adam(policy.parameters(), lr=lr)
    teacher_std = torch.as_tensor(teacher_std_np, device=device).unsqueeze(0)

    for ep in range(1, epochs+1):
        policy.train()
        total_loss = 0.0
        for obs_batch, act_batch in loader:
            obs_batch = obs_batch.to(device=device)
            act_batch = act_batch.to(device=device)
            optimizer.zero_grad()
            mean_s = policy(obs_batch)
            loss = kl_divergence_fixed_std(mean_s, act_batch, teacher_std)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample mean
            # even when the last batch is smaller.
            total_loss += loss.item() * obs_batch.size(0)

        print(f"Epoch {ep}/{epochs}, Loss: {total_loss/num_samples:.6f}")

    # The loop leaves the module in train mode; switch to eval before
    # exporting so the saved module is in inference mode.
    policy.eval()

    os.makedirs(output_dir, exist_ok=True)

    state_path = os.path.join(output_dir, 'distilled_policy_test.pth')
    torch.save(policy.state_dict(), state_path)
    # Report the file that was actually written.
    print(f"Saved {state_path}")

    example_input = torch.randn(1, in_dim).to(device)
    # NOTE(review): `policy` is already a ScriptModule; torch.jit.trace may
    # simply return it unchanged -- confirm whether policy.save() would do.
    with torch.no_grad():
        traced = torch.jit.trace(policy, example_input)

    traced.save(os.path.join(output_dir, 'distilled_policy.pt'))
    # NOTE(review): no input normalization is applied in this function --
    # confirm that the wording of this message is still accurate.
    print("Saved JIT policy with normalization.")

In [None]:
# Directories holding the wheeled-mode and walking-mode rollout files.
wheeled_data_dir = '/root/workspace/personal/akamisaka/quadruped_wheelchairs/D2/datasets/wheeled_mode_dynamics'
walking_data_dir = '/root/workspace/personal/akamisaka/quadruped_wheelchairs/D2/datasets/walking_mode_dynamics'

# Build one dataset per directory (walking first, then wheeled).
walking_dataset = RolloutDataset(data_dir=walking_data_dir)
wheeled_dataset = RolloutDataset(data_dir=wheeled_data_dir)

In [None]:
# Report how many (obs, act) samples each dataset holds.
print("num of walking_dataset:", len(walking_dataset))
print("num of wheeled_dataset:", len(wheeled_dataset))

In [None]:
# Train on the wheeled-mode rollouts only.
dataset = wheeled_dataset

In [None]:
# Twenty integer ones: the std passed to the loss is 1 for every entry.
teacher_std_np = np.ones(20, dtype=int)

# Fine-tune the default (275-input) base model on `dataset` and export it.
train_and_export(
    dataset=dataset,
    teacher_std_np=teacher_std_np,
    in_dim=275,
    epochs=500,
    batch_size=1024,
    lr=3e-4,
    output_dir='models',
    device=None,
)

In [None]:
import torch
import glob
import os

def check_rollout_file(filepath):
    """Print a report if the rollout at ``filepath`` contains NaN or Inf.

    The file is loaded with ``torch.load`` and iterated as a sequence of
    steps indexed with ``'obs'`` and ``'act'``. Nothing is printed for a
    rollout whose observations and actions are all finite.

    Args:
        filepath: Path of the rollout file to inspect.

    Returns:
        None. Results are reported on stdout only:
        ``[ERROR]`` when the file cannot be loaded,
        ``[EMPTY]`` when the rollout has no steps,
        ``[BAD]`` followed by the kinds of non-finite values found.
    """
    name = os.path.basename(filepath)
    try:
        # NOTE(review): torch.load unpickles the file; only use on trusted data.
        rollout = torch.load(filepath)
    except Exception as e:
        print(f"[ERROR] 読み込み失敗: {name} → {e}")
        return

    obs_list = [step['obs'].flatten() for step in rollout]
    act_list = [step['act'].flatten() for step in rollout]
    if not obs_list:
        # torch.stack raises on an empty list; report instead of aborting
        # the scan over the remaining files.
        print(f"[EMPTY] {name}: rollout has no steps")
        return

    obs = torch.stack(obs_list, dim=0)
    act = torch.stack(act_list, dim=0)

    # Label -> whether that kind of non-finite value occurs anywhere.
    findings = [
        ("NaN in obs", torch.isnan(obs).any().item()),
        ("Inf in obs", torch.isinf(obs).any().item()),
        ("NaN in act", torch.isnan(act).any().item()),
        ("Inf in act", torch.isinf(act).any().item()),
    ]
    tags = [label for label, found in findings if found]
    if tags:
        print(f"[BAD]  {name}: " + ", ".join(tags))

if __name__ == "__main__":
    # Scan every rollout file in the wheeled-mode directory for NaN/Inf.
    rollout_pattern = os.path.join(wheeled_data_dir, "*.pt")
    for path in glob.glob(rollout_pattern):
        check_rollout_file(path)

In [None]:
class RolloutDataset(Dataset):
    """``(obs, act)`` pairs from rollout files, with a fixed vector appended to ``obs``.

    Every ``*.pt`` file directly inside ``data_dir`` is loaded with
    ``torch.load`` and iterated as a sequence of steps. Each step is indexed
    with the keys ``'obs'`` and ``'act'``; both tensors are flattened and
    ``onehot`` is concatenated to the end of the observation.

    A file is skipped when it cannot be loaded or when it holds fewer than
    ``min_rollout_len`` steps. A step is skipped when its observation or
    action contains NaN/Inf, or when any action component is larger than
    ``max_abs_action`` in magnitude.

    Args:
        data_dir: Directory searched (non-recursively) for ``*.pt`` files.
        onehot: Tensor or array-like appended to every observation; it is
            stored as a flat float32 tensor.
        min_rollout_len: Rollouts with fewer steps than this are ignored.
        max_abs_action: Steps whose largest absolute action value exceeds
            this are ignored.

    Attributes:
        onehot: The float32 tensor appended to each observation.
        samples: List of ``(obs, act)`` tuples of flattened tensors.
        length: Number of samples collected during construction.
    """

    def __init__(self, data_dir, onehot, min_rollout_len=100, max_abs_action=10):
        # as_tensor covers both the tensor and the list/array case.
        self.onehot = torch.as_tensor(onehot, dtype=torch.float32).flatten()

        self.samples = []
        filepaths = glob.glob(os.path.join(data_dir, '*.pt'))
        for filepath in tqdm(filepaths, desc="Loading rollouts", unit="file"):
            try:
                # NOTE(review): torch.load unpickles the file; only use on trusted data.
                rollout = torch.load(filepath)
                num_steps = len(rollout)
            except Exception as exc:
                # Corrupt or foreign files do not always raise RuntimeError
                # (e.g. UnpicklingError, EOFError); one bad file must not
                # abort the whole load.
                tqdm.write(
                    f"Skipping unreadable file {os.path.basename(filepath)}"
                    f" ({type(exc).__name__})"
                )
                continue

            if num_steps < min_rollout_len:
                continue

            for step in rollout:
                obs = step['obs'].flatten()
                # Match the observation's device so the concatenation also
                # works for rollouts that were saved from another device.
                obs = torch.cat([obs, self.onehot.to(obs.device)], dim=0)

                act = step['act'].flatten()

                # Drop steps containing NaN or +/-Inf in either tensor.
                if not (torch.isfinite(obs).all() and torch.isfinite(act).all()):
                    continue

                # Drop steps with out-of-range actions; `.any()` is also
                # well-defined for an empty action tensor, unlike `.max()`.
                if (act.abs() > max_abs_action).any():
                    continue

                self.samples.append((obs, act))

        self.length = len(self.samples)

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``self.samples[idx]`` (a tuple, or a list when ``idx`` is a slice)."""
        return self.samples[idx]

In [None]:
# Directories holding the wheeled-mode and walking-mode rollout files.
wheeled_data_dir = '/root/workspace/personal/akamisaka/quadruped_wheelchairs/D2/datasets/wheeled_mode_dynamics'
walking_data_dir = '/root/workspace/personal/akamisaka/quadruped_wheelchairs/D2/datasets/walking_mode_dynamics'

# Each dataset appends its own two-element tag to every observation:
# [0, 1] for walking rollouts, [1, 0] for wheeled rollouts.
walking_dataset = RolloutDataset(walking_data_dir, onehot=[0, 1])
wheeled_dataset = RolloutDataset(wheeled_data_dir, onehot=[1, 0])

In [None]:
# Report how many (obs, act) samples each dataset holds.
print("num of walking_dataset:", len(walking_dataset))
print("num of wheeled_dataset:", len(wheeled_dataset))

In [None]:
# Slicing goes through RolloutDataset.__getitem__, which returns
# self.samples[idx]; each slice is therefore a plain list of (obs, act)
# tuples and `+` concatenates the two lists. At most 1,400,000 samples are
# taken from each mode, so `dataset` is a list, not a RolloutDataset.
dataset = walking_dataset[:1400000] + wheeled_dataset[:1400000]

In [None]:
# Twenty integer ones: the std passed to the loss is 1 for every entry.
teacher_std_np = np.ones(20, dtype=int)

# Fine-tune the 277-input base model on the combined dataset and export it.
train_and_export(
    dataset=dataset,
    teacher_std_np=teacher_std_np,
    in_dim=277,
    base_model_name="models/tqw_student_policy_base_jit_277.pt",
    epochs=5000,
    batch_size=10000,
    lr=3e-4,
    output_dir='models',
    device=None,
)