In [1]:
import torch
import torch.optim as optim
from RLAlg.buffer.replay_buffer import ReplayBuffer
from RLAlg.alg.gan import GAN
from model import Discriminator

In [2]:
expert_motion_buffer = ReplayBuffer(
            4000,
            50
        )

expert_motion_buffer.load("expert_motion_buffer.pth")

[ReplayBuffer] Loaded buffer from 'expert_motion_buffer.pth' to device 'cpu'


In [3]:
agent_motion_buffer = ReplayBuffer(
            4096,
            100
        )

agent_motion_buffer.load("agent_motion_buffer.pth")

[ReplayBuffer] Loaded buffer from 'agent_motion_buffer.pth' to device 'cpu'


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
discriminator = Discriminator(162).to(device)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=5e-5, weight_decay=1e-3)

In [5]:
for _ in range(10):
    for _ in range(100):
        expert_batch = expert_motion_buffer.sample_tensor("motion_observations", 512).to(device)
        agent_batch = agent_motion_buffer.sample_tensor("motion_observations", 512).to(device)
        d_loss = GAN.compute_bce_loss(discriminator, expert_batch, agent_batch, r1_gamma=10.0)
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
    print(f"D Loss: {d_loss.item()}")

D Loss: 0.6218315362930298
D Loss: 0.37726879119873047
D Loss: 0.30637335777282715
D Loss: 0.25039976835250854
D Loss: 0.24920417368412018
D Loss: 0.20451880991458893
D Loss: 0.20734509825706482
D Loss: 0.17789261043071747
D Loss: 0.1865866631269455
D Loss: 0.1929100751876831


In [6]:
discriminator(expert_batch).value

tensor([3.3458, 3.2041, 3.2586, 3.4311, 3.3480, 3.0870, 2.9051, 3.2744, 3.4602,
        3.2038, 2.9910, 3.0244, 2.7139, 3.0946, 3.2552, 3.3256, 2.4170, 3.1778,
        2.4522, 3.2474, 2.9685, 3.0264, 3.3553, 2.9239, 2.8648, 3.2907, 2.5992,
        3.2912, 2.7473, 2.7147, 3.2496, 3.2725, 2.9023, 3.2700, 3.3540, 2.8471,
        3.3533, 3.3875, 3.1990, 3.0708, 3.1803, 2.7478, 2.7571, 3.2145, 3.0765,
        3.5315, 3.2304, 3.3504, 3.1620, 3.1696, 2.9260, 2.6476, 3.0744, 3.2035,
        3.2984, 3.2674, 3.4514, 3.1752, 3.1358, 3.5312, 3.0930, 3.3094, 3.4476,
        3.3822, 3.2906, 3.3051, 3.3247, 3.2567, 3.0303, 2.5080, 1.2556, 3.1886,
        3.2615, 2.8301, 3.0085, 3.0241, 3.3184, 3.3778, 2.7850, 3.1001, 3.3020,
        3.5498, 3.3863, 3.1668, 2.7846, 3.0536, 3.4325, 2.9898, 3.4032, 3.2030,
        3.2536, 2.8868, 2.4072, 3.2087, 3.5455, 3.2593, 3.2646, 2.9787, 3.2888,
        3.1218, 3.2812, 3.2989, 3.1794, 2.8076, 3.5271, 3.2395, 3.2735, 3.5454,
        2.6076, 3.3423, 3.3833, 3.3731, 

In [7]:
discriminator(agent_batch).value

tensor([-6.0699, -5.8882, -6.5883, -6.2400, -6.1303, -6.4090, -6.4638, -6.4100,
        -5.3395, -6.6100, -6.8150, -6.7138, -4.6272, -6.8212, -6.3236, -6.4284,
        -6.3121, -7.0367, -6.1543, -5.8994, -6.1916, -6.2250, -5.2479, -5.8638,
        -6.4845, -5.7563,  2.0855,  2.5283, -5.4110, -5.8330, -6.7026, -6.0699,
        -6.5025, -6.6413, -5.2223, -6.4798, -6.4142, -5.3408, -6.4076, -6.1247,
         2.1195, -6.7174, -7.0353, -6.3076, -6.0100, -6.8725, -6.6936, -5.9182,
        -6.3658, -6.1500, -6.3888, -5.7856, -5.5573, -6.0058, -5.9831, -7.0124,
        -6.3184, -6.4986, -6.0438, -6.1974, -5.4265, -6.8575,  1.8093, -6.4482,
        -6.6157, -6.3222, -6.4163, -6.6540, -5.7059, -6.7192, -6.7004, -6.4311,
        -6.3334, -6.4189, -6.6007, -6.6673, -6.7685, -6.1694, -6.2645, -6.7084,
        -6.3977, -6.0023, -6.7638, -6.8113, -6.3739, -6.0576, -6.9409, -6.0534,
        -5.9949, -4.2281, -6.2391, -6.3282, -6.1411, -5.5535, -6.3428, -0.0543,
         2.0727, -6.7326, -6.4241, -6.33

In [9]:
-torch.log(1 - 1 / (1 + torch.exp(-discriminator(expert_batch).value)) + 1e-5)

tensor([3.3802, 3.2437, 3.2961, 3.4626, 3.3822, 3.1314, 2.9582, 3.3112, 3.4909,
        3.2433, 3.0398, 3.0717, 2.7780, 3.1387, 3.2927, 3.3606, 2.5023, 3.2184,
        2.5347, 3.2853, 3.0184, 3.0735, 3.3893, 2.9760, 2.9201, 3.3269, 2.6707,
        3.3275, 2.8093, 2.7787, 3.2873, 3.3094, 2.9555, 3.3071, 3.3881, 2.9033,
        3.3874, 3.4204, 3.2387, 3.1159, 3.2208, 2.8098, 2.8184, 3.2536, 3.1214,
        3.5600, 3.2689, 3.3846, 3.2032, 3.2106, 2.9781, 2.7158, 3.1193, 3.2430,
        3.3344, 3.3045, 3.4823, 3.2159, 3.1781, 3.5597, 3.1371, 3.3450, 3.4786,
        3.4153, 3.3269, 3.3408, 3.3598, 3.2942, 3.0772, 2.5862, 1.5062, 3.2287,
        3.2989, 2.8873, 3.0565, 3.0714, 3.3537, 3.4110, 2.8447, 3.1439, 3.3379,
        3.5778, 3.4192, 3.2078, 2.8444, 3.0994, 3.4640, 3.0387, 3.4356, 3.2425,
        3.2913, 2.9409, 2.4933, 3.2481, 3.5736, 3.2967, 3.3018, 3.0281, 3.3251,
        3.1647, 3.3178, 3.3349, 3.2199, 2.8660, 3.5557, 3.2777, 3.3104, 3.5735,
        2.6786, 3.3767, 3.4164, 3.4065, 