In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import gymnasium as gym
from d4rl.infos import REF_MIN_SCORE, REF_MAX_SCORE

from examples.offline.utils import load_buffer_d4rl
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

/data/user/R901105/.conda/envs/dev/lib/python3.11/site-packages/glfw/__init__.py:916: GLFWError: (65544) b'X11: The DISPLAY environment variable is missing'
  File "/data/user/R901105/.conda/envs/dev/lib/python3.11/site-packages/gymnasium/envs/registration.py", line 594, in load_plugin_envs
    fn()
  File "/data/user/R901105/.conda/envs/dev/lib/python3.11/site-packages/shimmy/registration.py", line 262, in register_gymnasium_envs
    _register_dm_control_envs()
  File "/data/user/R901105/.conda/envs/dev/lib/python3.11/site-packages/shimmy/registration.py", line 26, in _register_dm_control_envs
    from shimmy.dm_control_compatibility import DmControlCompatibilityV0
  File "/data/user/R901105/.conda/envs/dev/lib/python3.11/site-packages/shimmy/dm_control_compatibility.py", line 12, in <module>
    import dm_env
ModuleNotFoundError: No module named 'dm_env'
[0m
  logger.warn(f"plugin: {plugin.value} raised {traceback.format_exc()}")


In [2]:
task = "HalfCheetah-v3"
task_data = "halfcheetah-medium-v0"
device = "cuda:1"
learning_rate = 1e-3
batch_size = 1024
hidden_sizes = [512, 512, 512]

In [4]:
REF_MAX_SCORE["halfcheetah-medium-v0"]*66.4/100

8057.640000000001

In [40]:
IDbuffer = load_buffer_d4rl(task_data)

load datafile: 100%|██████████| 5/5 [00:00<00:00,  7.70it/s]


In [41]:
IDdata = np.concatenate((IDbuffer.obs, IDbuffer.act), axis=1)

In [42]:
# load numpy file
# ODdata = np.load("/data/user/R901105/dev/my_fork/tianshou/ODdata_halfcheetah_medium.npy")

In [43]:
import gym
env = gym.make(task)

In [44]:
env.action_space.high, env.action_space.low

(array([1., 1., 1., 1., 1., 1.], dtype=float32),
 array([-1., -1., -1., -1., -1., -1.], dtype=float32))

In [45]:
# generate num_repeat random actions for each observation in IDbuffer
# to use as a negative sample
num_repeat = 10
num_samples = num_repeat*len(IDbuffer)
rand_actions = np.random.uniform(low=env.action_space.low, high=env.action_space.high, size=(num_samples, env.action_space.shape[0])).astype(np.float32)
ODdata = np.concatenate((np.repeat(IDbuffer.obs, num_repeat, axis=0), rand_actions), axis=1)

In [46]:
class MyModel(nn.Module):
    def __init__(self, input_size, hidden_sizes):
        super(MyModel, self).__init__()
        self.input_layer = nn.Linear(input_size, hidden_sizes[0])
        self.input_norm = nn.LayerNorm(hidden_sizes[0])
        self.hidden_layers = nn.ModuleList([
            nn.Linear(hidden_sizes[i], hidden_sizes[i+1])
            for i in range(len(hidden_sizes) - 1)
        ])
        self.hidden_norms = nn.ModuleList([
            nn.LayerNorm(hidden_sizes[i+1])
            for i in range(len(hidden_sizes) - 1)
        ])
        self.output_layer = nn.Linear(hidden_sizes[-1], 1)  # 1 output for energy score
    
    def forward(self, x):
        x = self.input_layer(x)
        x = self.input_norm(x)
        x = torch.relu(x)
        for hidden_layer, hidden_norm in zip(self.hidden_layers, self.hidden_norms):
            x = hidden_layer(x)
            x = hidden_norm(x)
            x = torch.relu(x)
        output = self.output_layer(x)
        return torch.sigmoid(output)

In [47]:
class MyData(Dataset):
    def __init__(self, ID, OD, device=device):
        self.ID = torch.from_numpy(ID)
        self.OD = torch.from_numpy(OD)
        self.X = torch.concatenate((self.ID, self.OD)).to(device)
        self.y = torch.concatenate((torch.ones(len(self.ID)), torch.zeros(len(self.OD)))).unsqueeze(-1).to(device)

    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, index):
        return self.X[index], self.y[index]

In [48]:
train_ID, test_ID = train_test_split(IDdata, test_size=0.2)
train_ID, val_ID = train_test_split(train_ID, test_size=0.2)
train_OD, test_OD = train_test_split(ODdata, test_size=0.2)
train_OD, val_OD = train_test_split(train_OD, test_size=0.2)

In [49]:
train_data = MyData(train_ID, train_OD)
val_data = MyData(val_ID, val_OD)
test_data = MyData(test_ID, test_OD)

In [50]:
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

In [51]:
train_features, train_labels = next(iter(train_dataloader))

In [52]:
input_size = list(train_features[0].shape)[0]

In [53]:
model = MyModel(input_size, hidden_sizes).to(device)
print(f"Model structure: {model}")

Model structure: MyModel(
  (input_layer): Linear(in_features=23, out_features=512, bias=True)
  (input_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (hidden_layers): ModuleList(
    (0-1): 2 x Linear(in_features=512, out_features=512, bias=True)
  )
  (hidden_norms): ModuleList(
    (0-1): 2 x LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (output_layer): Linear(in_features=512, out_features=1, bias=True)
)


In [92]:
loss_fn = nn.BCELoss()

In [93]:
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

In [94]:
from tqdm import tqdm
import time

def train_loop(dataloader, model, loss_fn, optimizer, val_dataloader=None, epochs=1):
    train_losses = []
    val_losses = []
    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        start_time = time.time()
        for X, y in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            optimizer.zero_grad()  # Zero gradients before forward pass
            pred = model(X)
            loss = loss_fn(pred, y)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(dataloader)
        end_time = time.time()
        elapsed_time = end_time - start_time

        if val_dataloader:
            model.eval()
            total_val_loss = 0
            with torch.no_grad():
                for X_val, y_val in val_dataloader:
                    pred_val = model(X_val)
                    val_loss = loss_fn(pred_val, y_val).item()
                    total_val_loss += val_loss

            avg_val_loss = total_val_loss / len(val_dataloader)
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, Time: {elapsed_time:.2f}s")
    
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)

    return train_losses, val_losses

def plot_losses(train_losses, val_losses):
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Training and Validation Losses')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

In [95]:
train_losses, val_losses = train_loop(train_dataloader, model, loss_fn, optimizer, val_dataloader, epochs=200)

Epoch 1/200: 100%|██████████| 6871/6871 [01:15<00:00, 91.16it/s] 


Epoch 1/200, Train Loss: 0.1871, Validation Loss: 0.1267, Time: 75.38s


Epoch 2/200: 100%|██████████| 6871/6871 [01:14<00:00, 92.84it/s] 


Epoch 2/200, Train Loss: 0.1067, Validation Loss: 0.0911, Time: 74.01s


Epoch 3/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.88it/s] 


Epoch 3/200, Train Loss: 0.0825, Validation Loss: 0.0753, Time: 77.31s


Epoch 4/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.32it/s] 


Epoch 4/200, Train Loss: 0.0704, Validation Loss: 0.0662, Time: 76.93s


Epoch 5/200: 100%|██████████| 6871/6871 [01:20<00:00, 84.88it/s] 


Epoch 5/200, Train Loss: 0.0629, Validation Loss: 0.0605, Time: 80.96s


Epoch 6/200: 100%|██████████| 6871/6871 [01:13<00:00, 93.54it/s] 


Epoch 6/200, Train Loss: 0.0577, Validation Loss: 0.0556, Time: 73.46s


Epoch 7/200: 100%|██████████| 6871/6871 [01:12<00:00, 94.74it/s] 


Epoch 7/200, Train Loss: 0.0537, Validation Loss: 0.0542, Time: 72.53s


Epoch 8/200: 100%|██████████| 6871/6871 [01:15<00:00, 91.17it/s] 


Epoch 8/200, Train Loss: 0.0506, Validation Loss: 0.0498, Time: 75.36s


Epoch 9/200: 100%|██████████| 6871/6871 [01:12<00:00, 95.37it/s] 


Epoch 9/200, Train Loss: 0.0480, Validation Loss: 0.0471, Time: 72.05s


Epoch 10/200: 100%|██████████| 6871/6871 [01:11<00:00, 96.50it/s] 


Epoch 10/200, Train Loss: 0.0459, Validation Loss: 0.0454, Time: 71.21s


Epoch 11/200: 100%|██████████| 6871/6871 [01:12<00:00, 94.75it/s] 


Epoch 11/200, Train Loss: 0.0441, Validation Loss: 0.0435, Time: 72.52s


Epoch 12/200: 100%|██████████| 6871/6871 [01:13<00:00, 93.28it/s] 


Epoch 12/200, Train Loss: 0.0425, Validation Loss: 0.0423, Time: 73.67s


Epoch 13/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.62it/s] 


Epoch 13/200, Train Loss: 0.0411, Validation Loss: 0.0417, Time: 76.67s


Epoch 14/200: 100%|██████████| 6871/6871 [01:13<00:00, 93.15it/s] 


Epoch 14/200, Train Loss: 0.0399, Validation Loss: 0.0398, Time: 73.76s


Epoch 15/200: 100%|██████████| 6871/6871 [01:13<00:00, 92.95it/s] 


Epoch 15/200, Train Loss: 0.0388, Validation Loss: 0.0388, Time: 73.93s


Epoch 16/200: 100%|██████████| 6871/6871 [01:15<00:00, 90.83it/s] 


Epoch 16/200, Train Loss: 0.0378, Validation Loss: 0.0378, Time: 75.65s


Epoch 17/200: 100%|██████████| 6871/6871 [01:14<00:00, 91.67it/s] 


Epoch 17/200, Train Loss: 0.0369, Validation Loss: 0.0370, Time: 74.96s


Epoch 18/200: 100%|██████████| 6871/6871 [01:22<00:00, 83.68it/s] 


Epoch 18/200, Train Loss: 0.0361, Validation Loss: 0.0365, Time: 82.11s


Epoch 19/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.47it/s] 


Epoch 19/200, Train Loss: 0.0353, Validation Loss: 0.0357, Time: 76.80s


Epoch 20/200: 100%|██████████| 6871/6871 [01:14<00:00, 91.67it/s] 


Epoch 20/200, Train Loss: 0.0347, Validation Loss: 0.0352, Time: 74.95s


Epoch 21/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.24it/s] 


Epoch 21/200, Train Loss: 0.0340, Validation Loss: 0.0344, Time: 79.68s


Epoch 22/200: 100%|██████████| 6871/6871 [01:15<00:00, 91.27it/s] 


Epoch 22/200, Train Loss: 0.0334, Validation Loss: 0.0344, Time: 75.28s


Epoch 23/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.26it/s] 


Epoch 23/200, Train Loss: 0.0329, Validation Loss: 0.0340, Time: 80.59s


Epoch 24/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.38it/s] 


Epoch 24/200, Train Loss: 0.0323, Validation Loss: 0.0334, Time: 79.55s


Epoch 25/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.81it/s] 


Epoch 25/200, Train Loss: 0.0318, Validation Loss: 0.0332, Time: 76.51s


Epoch 26/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.57it/s] 


Epoch 26/200, Train Loss: 0.0314, Validation Loss: 0.0325, Time: 77.58s


Epoch 27/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.76it/s] 


Epoch 27/200, Train Loss: 0.0309, Validation Loss: 0.0324, Time: 78.29s


Epoch 28/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.49it/s] 


Epoch 28/200, Train Loss: 0.0305, Validation Loss: 0.0312, Time: 79.45s


Epoch 29/200: 100%|██████████| 6871/6871 [01:19<00:00, 85.99it/s] 


Epoch 29/200, Train Loss: 0.0301, Validation Loss: 0.0308, Time: 79.91s


Epoch 30/200: 100%|██████████| 6871/6871 [01:14<00:00, 91.86it/s] 


Epoch 30/200, Train Loss: 0.0298, Validation Loss: 0.0305, Time: 74.80s


Epoch 31/200: 100%|██████████| 6871/6871 [01:14<00:00, 92.80it/s] 


Epoch 31/200, Train Loss: 0.0294, Validation Loss: 0.0302, Time: 74.05s


Epoch 32/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.71it/s] 


Epoch 32/200, Train Loss: 0.0291, Validation Loss: 0.0309, Time: 80.17s


Epoch 33/200: 100%|██████████| 6871/6871 [01:15<00:00, 90.68it/s] 


Epoch 33/200, Train Loss: 0.0288, Validation Loss: 0.0298, Time: 75.77s


Epoch 34/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.44it/s] 


Epoch 34/200, Train Loss: 0.0285, Validation Loss: 0.0303, Time: 78.58s


Epoch 35/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.30it/s] 


Epoch 35/200, Train Loss: 0.0282, Validation Loss: 0.0302, Time: 77.82s


Epoch 36/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.66it/s] 


Epoch 36/200, Train Loss: 0.0279, Validation Loss: 0.0287, Time: 80.21s


Epoch 37/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.31it/s] 


Epoch 37/200, Train Loss: 0.0276, Validation Loss: 0.0290, Time: 80.55s


Epoch 38/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.72it/s] 


Epoch 38/200, Train Loss: 0.0274, Validation Loss: 0.0286, Time: 76.58s


Epoch 39/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.61it/s] 


Epoch 39/200, Train Loss: 0.0271, Validation Loss: 0.0286, Time: 79.34s


Epoch 40/200: 100%|██████████| 6871/6871 [01:18<00:00, 88.06it/s] 


Epoch 40/200, Train Loss: 0.0269, Validation Loss: 0.0279, Time: 78.03s


Epoch 41/200: 100%|██████████| 6871/6871 [01:18<00:00, 88.04it/s] 


Epoch 41/200, Train Loss: 0.0266, Validation Loss: 0.0280, Time: 78.05s


Epoch 42/200: 100%|██████████| 6871/6871 [01:19<00:00, 85.96it/s] 


Epoch 42/200, Train Loss: 0.0264, Validation Loss: 0.0274, Time: 79.93s


Epoch 43/200: 100%|██████████| 6871/6871 [01:15<00:00, 90.87it/s] 


Epoch 43/200, Train Loss: 0.0262, Validation Loss: 0.0277, Time: 75.61s


Epoch 44/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.92it/s] 


Epoch 44/200, Train Loss: 0.0260, Validation Loss: 0.0272, Time: 77.27s


Epoch 45/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.55it/s] 


Epoch 45/200, Train Loss: 0.0258, Validation Loss: 0.0271, Time: 78.48s


Epoch 46/200: 100%|██████████| 6871/6871 [01:14<00:00, 91.77it/s] 


Epoch 46/200, Train Loss: 0.0256, Validation Loss: 0.0269, Time: 74.88s


Epoch 47/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.94it/s] 


Epoch 47/200, Train Loss: 0.0254, Validation Loss: 0.0267, Time: 79.04s


Epoch 48/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.93it/s] 


Epoch 48/200, Train Loss: 0.0252, Validation Loss: 0.0269, Time: 78.14s


Epoch 49/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.38it/s] 


Epoch 49/200, Train Loss: 0.0251, Validation Loss: 0.0261, Time: 78.64s


Epoch 50/200: 100%|██████████| 6871/6871 [01:23<00:00, 82.63it/s] 


Epoch 50/200, Train Loss: 0.0249, Validation Loss: 0.0267, Time: 83.16s


Epoch 51/200: 100%|██████████| 6871/6871 [01:25<00:00, 80.56it/s] 


Epoch 51/200, Train Loss: 0.0247, Validation Loss: 0.0259, Time: 85.29s


Epoch 52/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.22it/s] 


Epoch 52/200, Train Loss: 0.0246, Validation Loss: 0.0268, Time: 80.63s


Epoch 53/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.35it/s] 


Epoch 53/200, Train Loss: 0.0244, Validation Loss: 0.0258, Time: 80.50s


Epoch 54/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.13it/s] 


Epoch 54/200, Train Loss: 0.0243, Validation Loss: 0.0254, Time: 80.72s


Epoch 55/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.36it/s] 


Epoch 55/200, Train Loss: 0.0241, Validation Loss: 0.0261, Time: 79.56s


Epoch 56/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.60it/s] 


Epoch 56/200, Train Loss: 0.0240, Validation Loss: 0.0253, Time: 80.27s


Epoch 57/200: 100%|██████████| 6871/6871 [01:15<00:00, 90.43it/s] 


Epoch 57/200, Train Loss: 0.0238, Validation Loss: 0.0251, Time: 75.98s


Epoch 58/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.20it/s] 


Epoch 58/200, Train Loss: 0.0237, Validation Loss: 0.0249, Time: 78.80s


Epoch 59/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.32it/s] 


Epoch 59/200, Train Loss: 0.0235, Validation Loss: 0.0248, Time: 76.93s


Epoch 60/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.22it/s] 


Epoch 60/200, Train Loss: 0.0234, Validation Loss: 0.0247, Time: 77.89s


Epoch 61/200: 100%|██████████| 6871/6871 [01:14<00:00, 92.11it/s] 


Epoch 61/200, Train Loss: 0.0233, Validation Loss: 0.0249, Time: 74.59s


Epoch 62/200: 100%|██████████| 6871/6871 [01:23<00:00, 82.12it/s] 


Epoch 62/200, Train Loss: 0.0232, Validation Loss: 0.0247, Time: 83.67s


Epoch 63/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.73it/s] 


Epoch 63/200, Train Loss: 0.0230, Validation Loss: 0.0245, Time: 77.44s


Epoch 64/200: 100%|██████████| 6871/6871 [01:19<00:00, 85.96it/s] 


Epoch 64/200, Train Loss: 0.0229, Validation Loss: 0.0264, Time: 79.93s


Epoch 65/200: 100%|██████████| 6871/6871 [01:21<00:00, 84.52it/s] 


Epoch 65/200, Train Loss: 0.0228, Validation Loss: 0.0242, Time: 81.30s


Epoch 66/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.35it/s] 


Epoch 66/200, Train Loss: 0.0227, Validation Loss: 0.0244, Time: 79.57s


Epoch 67/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.86it/s] 


Epoch 67/200, Train Loss: 0.0226, Validation Loss: 0.0253, Time: 78.21s


Epoch 68/200: 100%|██████████| 6871/6871 [01:21<00:00, 83.88it/s] 


Epoch 68/200, Train Loss: 0.0225, Validation Loss: 0.0240, Time: 81.92s


Epoch 69/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.29it/s] 


Epoch 69/200, Train Loss: 0.0224, Validation Loss: 0.0241, Time: 78.72s


Epoch 70/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.53it/s] 


Epoch 70/200, Train Loss: 0.0223, Validation Loss: 0.0239, Time: 77.61s


Epoch 71/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.68it/s] 


Epoch 71/200, Train Loss: 0.0222, Validation Loss: 0.0237, Time: 77.48s


Epoch 72/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.14it/s] 


Epoch 72/200, Train Loss: 0.0220, Validation Loss: 0.0251, Time: 80.70s


Epoch 73/200: 100%|██████████| 6871/6871 [01:17<00:00, 89.13it/s] 


Epoch 73/200, Train Loss: 0.0220, Validation Loss: 0.0240, Time: 77.09s


Epoch 74/200: 100%|██████████| 6871/6871 [01:21<00:00, 84.47it/s] 


Epoch 74/200, Train Loss: 0.0219, Validation Loss: 0.0251, Time: 81.34s


Epoch 75/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.20it/s] 


Epoch 75/200, Train Loss: 0.0218, Validation Loss: 0.0247, Time: 79.71s


Epoch 76/200: 100%|██████████| 6871/6871 [01:18<00:00, 88.06it/s] 


Epoch 76/200, Train Loss: 0.0217, Validation Loss: 0.0237, Time: 78.02s


Epoch 77/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.27it/s] 


Epoch 77/200, Train Loss: 0.0216, Validation Loss: 0.0235, Time: 80.58s


Epoch 78/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.21it/s] 


Epoch 78/200, Train Loss: 0.0215, Validation Loss: 0.0233, Time: 79.71s


Epoch 79/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.69it/s] 


Epoch 79/200, Train Loss: 0.0214, Validation Loss: 0.0231, Time: 77.48s


Epoch 80/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.62it/s] 


Epoch 80/200, Train Loss: 0.0213, Validation Loss: 0.0234, Time: 80.25s


Epoch 81/200: 100%|██████████| 6871/6871 [01:23<00:00, 81.96it/s] 


Epoch 81/200, Train Loss: 0.0212, Validation Loss: 0.0249, Time: 83.84s


Epoch 82/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.35it/s] 


Epoch 82/200, Train Loss: 0.0212, Validation Loss: 0.0230, Time: 77.77s


Epoch 83/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.86it/s] 


Epoch 83/200, Train Loss: 0.0211, Validation Loss: 0.0230, Time: 80.02s


Epoch 84/200: 100%|██████████| 6871/6871 [01:23<00:00, 81.84it/s] 


Epoch 84/200, Train Loss: 0.0210, Validation Loss: 0.0231, Time: 83.96s


Epoch 85/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.94it/s] 


Epoch 85/200, Train Loss: 0.0209, Validation Loss: 0.0232, Time: 76.42s


Epoch 86/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.99it/s] 


Epoch 86/200, Train Loss: 0.0208, Validation Loss: 0.0247, Time: 78.09s


Epoch 87/200: 100%|██████████| 6871/6871 [01:15<00:00, 90.71it/s] 


Epoch 87/200, Train Loss: 0.0207, Validation Loss: 0.0227, Time: 75.75s


Epoch 88/200: 100%|██████████| 6871/6871 [01:22<00:00, 83.07it/s] 


Epoch 88/200, Train Loss: 0.0207, Validation Loss: 0.0236, Time: 82.72s


Epoch 89/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.80it/s] 


Epoch 89/200, Train Loss: 0.0206, Validation Loss: 0.0227, Time: 78.26s


Epoch 90/200: 100%|██████████| 6871/6871 [01:21<00:00, 84.42it/s] 


Epoch 90/200, Train Loss: 0.0205, Validation Loss: 0.0236, Time: 81.39s


Epoch 91/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.22it/s] 


Epoch 91/200, Train Loss: 0.0204, Validation Loss: 0.0229, Time: 80.63s


Epoch 92/200: 100%|██████████| 6871/6871 [01:15<00:00, 90.76it/s] 


Epoch 92/200, Train Loss: 0.0204, Validation Loss: 0.0223, Time: 75.71s


Epoch 93/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.78it/s] 


Epoch 93/200, Train Loss: 0.0203, Validation Loss: 0.0240, Time: 79.18s


Epoch 94/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.69it/s] 


Epoch 94/200, Train Loss: 0.0202, Validation Loss: 0.0236, Time: 77.47s


Epoch 95/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.17it/s] 


Epoch 95/200, Train Loss: 0.0202, Validation Loss: 0.0228, Time: 78.83s


Epoch 96/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.86it/s] 


Epoch 96/200, Train Loss: 0.0201, Validation Loss: 0.0222, Time: 79.11s


Epoch 97/200: 100%|██████████| 6871/6871 [01:21<00:00, 84.79it/s] 


Epoch 97/200, Train Loss: 0.0200, Validation Loss: 0.0225, Time: 81.04s


Epoch 98/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.24it/s] 


Epoch 98/200, Train Loss: 0.0200, Validation Loss: 0.0233, Time: 78.77s


Epoch 99/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.45it/s] 


Epoch 99/200, Train Loss: 0.0199, Validation Loss: 0.0246, Time: 80.41s


Epoch 100/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.75it/s] 


Epoch 100/200, Train Loss: 0.0198, Validation Loss: 0.0224, Time: 80.13s


Epoch 101/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.27it/s] 


Epoch 101/200, Train Loss: 0.0198, Validation Loss: 0.0224, Time: 79.64s


Epoch 102/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.61it/s] 


Epoch 102/200, Train Loss: 0.0197, Validation Loss: 0.0222, Time: 78.43s


Epoch 103/200: 100%|██████████| 6871/6871 [01:15<00:00, 91.09it/s] 


Epoch 103/200, Train Loss: 0.0196, Validation Loss: 0.0221, Time: 75.43s


Epoch 104/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.65it/s] 


Epoch 104/200, Train Loss: 0.0196, Validation Loss: 0.0217, Time: 78.40s


Epoch 105/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.85it/s] 


Epoch 105/200, Train Loss: 0.0195, Validation Loss: 0.0226, Time: 80.04s


Epoch 106/200: 100%|██████████| 6871/6871 [01:15<00:00, 91.06it/s] 


Epoch 106/200, Train Loss: 0.0195, Validation Loss: 0.0220, Time: 75.46s


Epoch 107/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.08it/s] 


Epoch 107/200, Train Loss: 0.0194, Validation Loss: 0.0229, Time: 78.91s


Epoch 108/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.60it/s] 


Epoch 108/200, Train Loss: 0.0193, Validation Loss: 0.0216, Time: 78.44s


Epoch 109/200: 100%|██████████| 6871/6871 [01:23<00:00, 82.41it/s] 


Epoch 109/200, Train Loss: 0.0193, Validation Loss: 0.0215, Time: 83.37s


Epoch 110/200: 100%|██████████| 6871/6871 [01:21<00:00, 84.20it/s] 


Epoch 110/200, Train Loss: 0.0192, Validation Loss: 0.0220, Time: 81.60s


Epoch 111/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.67it/s] 


Epoch 111/200, Train Loss: 0.0192, Validation Loss: 0.0218, Time: 78.38s


Epoch 112/200: 100%|██████████| 6871/6871 [01:21<00:00, 84.34it/s] 


Epoch 112/200, Train Loss: 0.0191, Validation Loss: 0.0215, Time: 81.47s


Epoch 113/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.37it/s] 


Epoch 113/200, Train Loss: 0.0191, Validation Loss: 0.0216, Time: 79.55s


Epoch 114/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.63it/s] 


Epoch 114/200, Train Loss: 0.0190, Validation Loss: 0.0221, Time: 79.32s


Epoch 115/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.35it/s] 


Epoch 115/200, Train Loss: 0.0190, Validation Loss: 0.0224, Time: 78.67s


Epoch 116/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.50it/s] 


Epoch 116/200, Train Loss: 0.0189, Validation Loss: 0.0218, Time: 77.64s


Epoch 117/200: 100%|██████████| 6871/6871 [01:21<00:00, 84.65it/s] 


Epoch 117/200, Train Loss: 0.0188, Validation Loss: 0.0211, Time: 81.17s


Epoch 118/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.31it/s] 


Epoch 118/200, Train Loss: 0.0188, Validation Loss: 0.0212, Time: 76.93s


Epoch 119/200: 100%|██████████| 6871/6871 [01:17<00:00, 89.01it/s] 


Epoch 119/200, Train Loss: 0.0187, Validation Loss: 0.0212, Time: 77.20s


Epoch 120/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.66it/s] 


Epoch 120/200, Train Loss: 0.0187, Validation Loss: 0.0211, Time: 78.38s


Epoch 121/200: 100%|██████████| 6871/6871 [01:21<00:00, 84.77it/s] 


Epoch 121/200, Train Loss: 0.0186, Validation Loss: 0.0212, Time: 81.06s


Epoch 122/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.77it/s] 


Epoch 122/200, Train Loss: 0.0186, Validation Loss: 0.0214, Time: 80.11s


Epoch 123/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.34it/s] 


Epoch 123/200, Train Loss: 0.0185, Validation Loss: 0.0219, Time: 79.58s


Epoch 124/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.36it/s] 


Epoch 124/200, Train Loss: 0.0185, Validation Loss: 0.0210, Time: 80.50s


Epoch 125/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.38it/s] 


Epoch 125/200, Train Loss: 0.0184, Validation Loss: 0.0209, Time: 77.74s


Epoch 126/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.40it/s] 


Epoch 126/200, Train Loss: 0.0184, Validation Loss: 0.0211, Time: 78.62s


Epoch 127/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.26it/s] 


Epoch 127/200, Train Loss: 0.0184, Validation Loss: 0.0208, Time: 76.98s


Epoch 128/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.90it/s] 


Epoch 128/200, Train Loss: 0.0183, Validation Loss: 0.0219, Time: 76.43s


Epoch 129/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.77it/s] 


Epoch 129/200, Train Loss: 0.0183, Validation Loss: 0.0242, Time: 76.55s


Epoch 130/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.35it/s] 


Epoch 130/200, Train Loss: 0.0182, Validation Loss: 0.0208, Time: 76.90s


Epoch 131/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.98it/s] 


Epoch 131/200, Train Loss: 0.0182, Validation Loss: 0.0207, Time: 76.37s


Epoch 132/200: 100%|██████████| 6871/6871 [01:13<00:00, 92.95it/s] 


Epoch 132/200, Train Loss: 0.0181, Validation Loss: 0.0211, Time: 73.93s


Epoch 133/200: 100%|██████████| 6871/6871 [01:17<00:00, 89.05it/s] 


Epoch 133/200, Train Loss: 0.0181, Validation Loss: 0.0236, Time: 77.16s


Epoch 134/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.58it/s] 


Epoch 134/200, Train Loss: 0.0181, Validation Loss: 0.0217, Time: 76.70s


Epoch 135/200: 100%|██████████| 6871/6871 [01:14<00:00, 92.71it/s] 


Epoch 135/200, Train Loss: 0.0180, Validation Loss: 0.0215, Time: 74.12s


Epoch 136/200: 100%|██████████| 6871/6871 [01:14<00:00, 91.77it/s] 


Epoch 136/200, Train Loss: 0.0179, Validation Loss: 0.0217, Time: 74.88s


Epoch 137/200: 100%|██████████| 6871/6871 [01:14<00:00, 92.33it/s] 


Epoch 137/200, Train Loss: 0.0179, Validation Loss: 0.0208, Time: 74.42s


Epoch 138/200: 100%|██████████| 6871/6871 [01:17<00:00, 89.02it/s] 


Epoch 138/200, Train Loss: 0.0179, Validation Loss: 0.0205, Time: 77.19s


Epoch 139/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.74it/s] 


Epoch 139/200, Train Loss: 0.0178, Validation Loss: 0.0213, Time: 78.31s


Epoch 140/200: 100%|██████████| 6871/6871 [01:19<00:00, 85.95it/s] 


Epoch 140/200, Train Loss: 0.0178, Validation Loss: 0.0203, Time: 79.95s


Epoch 141/200: 100%|██████████| 6871/6871 [01:17<00:00, 89.16it/s] 


Epoch 141/200, Train Loss: 0.0177, Validation Loss: 0.0204, Time: 77.07s


Epoch 142/200: 100%|██████████| 6871/6871 [01:14<00:00, 92.73it/s] 


Epoch 142/200, Train Loss: 0.0177, Validation Loss: 0.0214, Time: 74.10s


Epoch 143/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.96it/s] 


Epoch 143/200, Train Loss: 0.0176, Validation Loss: 0.0206, Time: 77.24s


Epoch 144/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.87it/s] 


Epoch 144/200, Train Loss: 0.0176, Validation Loss: 0.0203, Time: 79.10s


Epoch 145/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.40it/s] 


Epoch 145/200, Train Loss: 0.0176, Validation Loss: 0.0202, Time: 77.73s


Epoch 146/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.70it/s] 


Epoch 146/200, Train Loss: 0.0175, Validation Loss: 0.0204, Time: 76.60s


Epoch 147/200: 100%|██████████| 6871/6871 [01:14<00:00, 92.07it/s] 


Epoch 147/200, Train Loss: 0.0175, Validation Loss: 0.0202, Time: 74.63s


Epoch 148/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.16it/s] 


Epoch 148/200, Train Loss: 0.0175, Validation Loss: 0.0203, Time: 78.84s


Epoch 149/200: 100%|██████████| 6871/6871 [01:14<00:00, 92.63it/s] 


Epoch 149/200, Train Loss: 0.0174, Validation Loss: 0.0206, Time: 74.18s


Epoch 150/200: 100%|██████████| 6871/6871 [01:14<00:00, 92.39it/s] 


Epoch 150/200, Train Loss: 0.0174, Validation Loss: 0.0215, Time: 74.38s


Epoch 151/200: 100%|██████████| 6871/6871 [01:14<00:00, 92.05it/s] 


Epoch 151/200, Train Loss: 0.0173, Validation Loss: 0.0202, Time: 74.65s


Epoch 152/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.07it/s] 


Epoch 152/200, Train Loss: 0.0173, Validation Loss: 0.0214, Time: 79.84s


Epoch 153/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.17it/s] 


Epoch 153/200, Train Loss: 0.0173, Validation Loss: 0.0219, Time: 78.83s


Epoch 154/200: 100%|██████████| 6871/6871 [01:13<00:00, 92.96it/s] 


Epoch 154/200, Train Loss: 0.0172, Validation Loss: 0.0199, Time: 73.91s


Epoch 155/200: 100%|██████████| 6871/6871 [01:14<00:00, 92.50it/s] 


Epoch 155/200, Train Loss: 0.0172, Validation Loss: 0.0202, Time: 74.28s


Epoch 156/200: 100%|██████████| 6871/6871 [01:13<00:00, 93.35it/s] 


Epoch 156/200, Train Loss: 0.0171, Validation Loss: 0.0202, Time: 73.60s


Epoch 157/200: 100%|██████████| 6871/6871 [01:13<00:00, 93.86it/s] 


Epoch 157/200, Train Loss: 0.0171, Validation Loss: 0.0203, Time: 73.20s


Epoch 158/200: 100%|██████████| 6871/6871 [01:12<00:00, 94.28it/s] 


Epoch 158/200, Train Loss: 0.0171, Validation Loss: 0.0209, Time: 72.88s


Epoch 159/200: 100%|██████████| 6871/6871 [01:13<00:00, 92.93it/s] 


Epoch 159/200, Train Loss: 0.0170, Validation Loss: 0.0201, Time: 73.94s


Epoch 160/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.92it/s] 


Epoch 160/200, Train Loss: 0.0170, Validation Loss: 0.0201, Time: 77.29s


Epoch 161/200: 100%|██████████| 6871/6871 [01:15<00:00, 91.59it/s] 


Epoch 161/200, Train Loss: 0.0170, Validation Loss: 0.0200, Time: 75.03s


Epoch 162/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.94it/s] 


Epoch 162/200, Train Loss: 0.0169, Validation Loss: 0.0206, Time: 79.03s


Epoch 163/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.71it/s] 


Epoch 163/200, Train Loss: 0.0169, Validation Loss: 0.0215, Time: 79.24s


Epoch 164/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.69it/s] 


Epoch 164/200, Train Loss: 0.0169, Validation Loss: 0.0214, Time: 77.47s


Epoch 165/200: 100%|██████████| 6871/6871 [01:15<00:00, 90.55it/s] 


Epoch 165/200, Train Loss: 0.0168, Validation Loss: 0.0199, Time: 75.89s


Epoch 166/200: 100%|██████████| 6871/6871 [01:14<00:00, 92.76it/s] 


Epoch 166/200, Train Loss: 0.0168, Validation Loss: 0.0199, Time: 74.07s


Epoch 167/200: 100%|██████████| 6871/6871 [01:15<00:00, 91.58it/s] 


Epoch 167/200, Train Loss: 0.0168, Validation Loss: 0.0206, Time: 75.03s


Epoch 168/200: 100%|██████████| 6871/6871 [01:14<00:00, 92.61it/s] 


Epoch 168/200, Train Loss: 0.0167, Validation Loss: 0.0198, Time: 74.20s


Epoch 169/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.21it/s] 


Epoch 169/200, Train Loss: 0.0167, Validation Loss: 0.0201, Time: 77.90s


Epoch 170/200: 100%|██████████| 6871/6871 [01:12<00:00, 94.53it/s] 


Epoch 170/200, Train Loss: 0.0167, Validation Loss: 0.0200, Time: 72.69s


Epoch 171/200: 100%|██████████| 6871/6871 [01:09<00:00, 98.28it/s] 


Epoch 171/200, Train Loss: 0.0166, Validation Loss: 0.0224, Time: 69.92s


Epoch 172/200: 100%|██████████| 6871/6871 [01:11<00:00, 96.76it/s] 


Epoch 172/200, Train Loss: 0.0166, Validation Loss: 0.0199, Time: 71.01s


Epoch 173/200: 100%|██████████| 6871/6871 [01:11<00:00, 95.51it/s] 


Epoch 173/200, Train Loss: 0.0166, Validation Loss: 0.0198, Time: 71.94s


Epoch 174/200: 100%|██████████| 6871/6871 [01:15<00:00, 91.25it/s] 


Epoch 174/200, Train Loss: 0.0165, Validation Loss: 0.0204, Time: 75.30s


Epoch 175/200: 100%|██████████| 6871/6871 [01:14<00:00, 91.74it/s] 


Epoch 175/200, Train Loss: 0.0165, Validation Loss: 0.0216, Time: 74.90s


Epoch 176/200: 100%|██████████| 6871/6871 [01:15<00:00, 91.48it/s] 


Epoch 176/200, Train Loss: 0.0165, Validation Loss: 0.0198, Time: 75.11s


Epoch 177/200: 100%|██████████| 6871/6871 [01:15<00:00, 91.40it/s] 


Epoch 177/200, Train Loss: 0.0164, Validation Loss: 0.0200, Time: 75.18s


Epoch 178/200: 100%|██████████| 6871/6871 [01:15<00:00, 90.48it/s] 


Epoch 178/200, Train Loss: 0.0164, Validation Loss: 0.0198, Time: 75.94s


Epoch 179/200: 100%|██████████| 6871/6871 [01:14<00:00, 91.73it/s] 


Epoch 179/200, Train Loss: 0.0164, Validation Loss: 0.0197, Time: 74.91s


Epoch 180/200: 100%|██████████| 6871/6871 [01:15<00:00, 90.89it/s] 


Epoch 180/200, Train Loss: 0.0163, Validation Loss: 0.0197, Time: 75.60s


Epoch 181/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.37it/s] 


Epoch 181/200, Train Loss: 0.0163, Validation Loss: 0.0216, Time: 76.89s


Epoch 182/200: 100%|██████████| 6871/6871 [01:16<00:00, 90.30it/s] 


Epoch 182/200, Train Loss: 0.0163, Validation Loss: 0.0195, Time: 76.09s


Epoch 183/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.77it/s] 


Epoch 183/200, Train Loss: 0.0163, Validation Loss: 0.0203, Time: 76.54s


Epoch 184/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.19it/s] 


Epoch 184/200, Train Loss: 0.0162, Validation Loss: 0.0196, Time: 77.91s


Epoch 185/200: 100%|██████████| 6871/6871 [01:15<00:00, 90.80it/s] 


Epoch 185/200, Train Loss: 0.0162, Validation Loss: 0.0195, Time: 75.68s


Epoch 186/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.93it/s] 


Epoch 186/200, Train Loss: 0.0161, Validation Loss: 0.0195, Time: 77.27s


Epoch 187/200: 100%|██████████| 6871/6871 [01:18<00:00, 88.05it/s] 


Epoch 187/200, Train Loss: 0.0161, Validation Loss: 0.0197, Time: 78.03s


Epoch 188/200: 100%|██████████| 6871/6871 [01:20<00:00, 84.91it/s] 


Epoch 188/200, Train Loss: 0.0161, Validation Loss: 0.0205, Time: 80.93s


Epoch 189/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.56it/s] 


Epoch 189/200, Train Loss: 0.0161, Validation Loss: 0.0195, Time: 77.59s


Epoch 190/200: 100%|██████████| 6871/6871 [01:18<00:00, 86.99it/s] 


Epoch 190/200, Train Loss: 0.0160, Validation Loss: 0.0194, Time: 78.99s


Epoch 191/200: 100%|██████████| 6871/6871 [01:17<00:00, 89.19it/s] 


Epoch 191/200, Train Loss: 0.0160, Validation Loss: 0.0242, Time: 77.04s


Epoch 192/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.92it/s] 


Epoch 192/200, Train Loss: 0.0160, Validation Loss: 0.0197, Time: 79.05s


Epoch 193/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.89it/s] 


Epoch 193/200, Train Loss: 0.0160, Validation Loss: 0.0199, Time: 79.08s


Epoch 194/200: 100%|██████████| 6871/6871 [01:15<00:00, 90.58it/s] 


Epoch 194/200, Train Loss: 0.0159, Validation Loss: 0.0204, Time: 75.86s


Epoch 195/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.21it/s] 


Epoch 195/200, Train Loss: 0.0159, Validation Loss: 0.0201, Time: 80.64s


Epoch 196/200: 100%|██████████| 6871/6871 [01:20<00:00, 85.75it/s] 


Epoch 196/200, Train Loss: 0.0159, Validation Loss: 0.0214, Time: 80.13s


Epoch 197/200: 100%|██████████| 6871/6871 [01:19<00:00, 86.14it/s] 


Epoch 197/200, Train Loss: 0.0158, Validation Loss: 0.0201, Time: 79.77s


Epoch 198/200: 100%|██████████| 6871/6871 [01:18<00:00, 87.51it/s] 


Epoch 198/200, Train Loss: 0.0158, Validation Loss: 0.0194, Time: 78.52s


Epoch 199/200: 100%|██████████| 6871/6871 [01:17<00:00, 88.75it/s] 


Epoch 199/200, Train Loss: 0.0158, Validation Loss: 0.0198, Time: 77.42s


Epoch 200/200: 100%|██████████| 6871/6871 [01:16<00:00, 89.30it/s] 


Epoch 200/200, Train Loss: 0.0158, Validation Loss: 0.0204, Time: 76.94s


In [19]:
# plot_losses(train_losses, val_losses)

In [54]:
# load model
model.load_state_dict(torch.load("/data/user/R901105/dev/log/HalfCheetah-v3/action_classification/0/240504-201628/model.pt"))

<All keys matched successfully>

In [55]:
with torch.no_grad():
    test_pred = model(test_dataloader.dataset.X).round().cpu()
    test_labels = test_dataloader.dataset.y.cpu()

cm = confusion_matrix(test_labels, test_pred)

In [56]:
cm

array([[1958117,   39881],
       [   1467,  198333]])

In [57]:
cm / cm.sum(axis=1)[:, np.newaxis]

array([[0.98003952, 0.01996048],
       [0.00734234, 0.99265766]])

In [63]:
np.infty

inf

In [49]:
# save the model
# torch.save(model.state_dict(), "hopper_medium_sa.pt")