In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset


In [None]:
num_seeds = 30
seed = 0
# Load fullstate
data_fullstate = np.empty(num_seeds, dtype=object)
data_no_joint_pos = np.empty(num_seeds, dtype=object)
data_no_joint_vel = np.empty(num_seeds, dtype=object)
data_no_action = np.empty(num_seeds, dtype=object)
data_no_imu = np.empty(num_seeds, dtype=object)
data_no_fc = np.empty(num_seeds, dtype=object)
for i in range(num_seeds):
    data_fullstate[i] = np.load(f"data/HEBB-FULL-STATE_seed-{seed}-fullstate-rand-{i}.npz")    
    data_no_joint_pos[i] = np.load(f"data/HEBB-FULL-STATE_seed-{seed}-no_joint_pos-rand-{i}.npz")
    data_no_joint_vel[i] = np.load(f"data/HEBB-FULL-STATE_seed-{seed}-no_joint_vel-rand-{i}.npz")
    data_no_action[i] = np.load(f"data/HEBB-FULL-STATE_seed-{seed}-no_action-rand-{i}.npz")
    data_no_imu[i] = np.load(f"data/HEBB-FULL-STATE_seed-{seed}-no_imu-rand-{i}.npz")
    data_no_fc[i] = np.load(f"data/HEBB-FULL-STATE_seed-{seed}-no_fc-rand-{i}.npz")

## **Neural network model**

In [148]:
class Discriminator(nn.Module):
    def __init__(self, state_dim=64, action_dim=19, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.mu_head = nn.Linear(hidden_dim, action_dim)
        self.var_head = nn.Sequential(
        nn.Linear(hidden_dim, action_dim),
        nn.ReLU(),
        )
    def forward(self, state):
        h = self.net(state)
        
        mu = self.mu_head(h)
        var = self.var_head(h)
        return mu, var


In [166]:
states = torch.tensor(data_fullstate[0]["state"].reshape(-1, 64))
actions = torch.tensor(data_fullstate[0]["action_lowpass"].reshape(-1, 19))

In [165]:
dataset = TensorDataset(states, actions)
loader = DataLoader(dataset, batch_size=64, shuffle=True)
discriminator = Discriminator(state_dim=64, action_dim=19)
optimizer = optim.Adam(discriminator.parameters(), lr=1e-5)

In [159]:
def discriminator_loss(discriminator, states, actions):
    mu, var = discriminator(states)
    var = torch.clamp(var, min=1e-6)  # ป้องกัน var เป็น 0
    dist = torch.distributions.Normal(mu, var.sqrt())
    log_prob = dist.log_prob(actions).sum(dim=-1)
    return -log_prob.mean()

In [162]:
def discriminator_loss(discriminator, states, actions):
    mu, var = discriminator(states)
    var = torch.clamp(var, min=1e-6)
    sigma = var.sqrt()
    dist  = torch.distributions.Normal(mu, sigma)
    log_prob = dist.log_prob(actions).sum(dim=-1)
    return -log_prob.mean()

In [164]:
criterion = nn.MSELoss()          # creates an MSE loss module
def discriminator_loss(discriminator, states, actions):
    mu, var = discriminator(states)
    loss  = criterion(mu, actions)
    return loss

In [167]:
# Train discriminator
# for epoch in tqdm.tqdm(range(100)):
for epoch in range(100):
    for s, a in loader:
        # s, a = s.cuda(), a.cuda()
        loss = discriminator_loss(discriminator, s, a)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}: Discriminator Loss = {loss.item():.4f}")
# กำหนด path ที่จะบันทึก เช่น
MODEL_PATH = "discriminator.pth"
# เซฟ state_dict ของโมเดล
torch.save(discriminator.state_dict(), MODEL_PATH)
print(f"Saved model to {MODEL_PATH}")

Epoch 1: Discriminator Loss = 0.3663
Epoch 2: Discriminator Loss = 0.3493
Epoch 3: Discriminator Loss = 0.3654
Epoch 4: Discriminator Loss = 0.3271
Epoch 5: Discriminator Loss = 0.3590
Epoch 6: Discriminator Loss = 0.3480
Epoch 7: Discriminator Loss = 0.3165
Epoch 8: Discriminator Loss = 0.3413
Epoch 9: Discriminator Loss = 0.3019
Epoch 10: Discriminator Loss = 0.3149
Epoch 11: Discriminator Loss = 0.2948
Epoch 12: Discriminator Loss = 0.2847
Epoch 13: Discriminator Loss = 0.2878
Epoch 14: Discriminator Loss = 0.2695
Epoch 15: Discriminator Loss = 0.2399
Epoch 16: Discriminator Loss = 0.2167
Epoch 17: Discriminator Loss = 0.2318
Epoch 18: Discriminator Loss = 0.2408
Epoch 19: Discriminator Loss = 0.1940
Epoch 20: Discriminator Loss = 0.2000
Epoch 21: Discriminator Loss = 0.1748
Epoch 22: Discriminator Loss = 0.1561
Epoch 23: Discriminator Loss = 0.1624
Epoch 24: Discriminator Loss = 0.1421
Epoch 25: Discriminator Loss = 0.1208
Epoch 26: Discriminator Loss = 0.1210
Epoch 27: Discriminat

In [168]:
discriminator = Discriminator(state_dim=64, action_dim=19)
discriminator.load_state_dict(torch.load("discriminator.pth"))
discriminator.eval()   # หรือ .train() ตามกรณี

Discriminator(
  (net): Sequential(
    (0): Linear(in_features=64, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
  )
  (mu_head): Linear(in_features=256, out_features=19, bias=True)
  (var_head): Sequential(
    (0): Linear(in_features=256, out_features=19, bias=True)
    (1): ReLU()
  )
)

In [170]:
print(discriminator(torch.tensor(data_fullstate[0]["state"].reshape(-1, 64))[0]))

(tensor([ 0.0857,  0.0869,  0.0627, -0.0790, -0.0326,  0.0158,  0.0961,  0.0087,
         0.0049, -0.0790,  0.0517, -0.0123, -0.0076,  0.0139,  0.0646, -0.1493,
         0.0455,  0.1211, -0.0767], grad_fn=<ViewBackward0>), tensor([0.1988, 0.0971, 0.1328, 0.0000, 0.0217, 0.0204, 0.0371, 0.0000, 0.0000,
        0.0193, 0.0215, 0.0000, 0.0523, 0.0000, 0.0000, 0.0207, 0.0766, 0.0990,
        0.0967], grad_fn=<ReluBackward0>))


In [171]:
data_fullstate[0]["action_lowpass"].reshape(-1, 19)[0]

array([-1.7569959e-05, -2.5523271e-04,  5.2417530e-04,  5.4440787e-04,
        1.3592407e-04,  2.6334359e-04, -1.3004772e-04, -6.8443049e-05,
       -1.4505014e-04,  1.3148678e-04, -6.6024315e-04, -5.8590475e-04,
        1.9936003e-04, -3.0740726e-04, -3.6030388e-04, -4.2632362e-04,
        1.9244241e-04,  4.3898248e-04,  7.6421195e-05], dtype=float32)