# I M P L I C I T - Q - L E A R N I N G - IQN

In [1]:
import torch
torch.set_float32_matmul_precision('highest')
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

from torch.utils.tensorboard import SummaryWriter
import numpy as np

import h5py
import tqdm


### D E V I C E

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')


Device: cuda


### D A T A 

In [3]:
file = 'C:\OFFLINE RL\hopper_medium-v2.hdf5'

with h5py.File(file, mode = 'r') as f:
    
    observations = np.array(f['observations'])
    actions = np.array(f['actions'])
    rewards = np.array(f['rewards'])
    next_obs = np.array(f['next_observations'])
    
state_dim = observations.shape[1]
action_dim = actions.shape[1]
max_action = abs(actions).max()

print(f'len of data: {len(observations)}')
print(f'state dim: {state_dim} | action dim: {action_dim} | max action: {max_action}')


len of data: 1000000
state dim: 11 | action dim: 3 | max action: 0.9999945163726807


### D A T A - E N G I N E E R I N G

In [4]:
obs_tensor = torch.from_numpy(observations).float().to(device)
act_tensor = torch.from_numpy(actions).float().to(device)
rew_tensor = torch.from_numpy(rewards).float().to(device)
next_obs_tensor = torch.from_numpy(next_obs).float().to(device)

class Hopper_Dataset(Dataset):
    
    def __init__(self, obs, act, rew, next_obs_t):
        
        self.states = obs
        self.actions = act
        self.rewards = rew
        self.next_states = next_obs_t
        
    def __len__(self):
        
        return len(self.states)
    
    def __getitem__(self, index):
        return self.states[index], self.actions[index], self.rewards[index], self.next_states[index]
    
dataset = Hopper_Dataset(obs_tensor, act_tensor, rew_tensor, next_obs_tensor)

train_data, test_data = train_test_split(dataset, test_size = 0.1, shuffle = False)

train_loader, test_loader = DataLoader(train_data, batch_size = 256, shuffle = True, drop_last = True), DataLoader(test_data, batch_size = 256, shuffle = True, drop_last = True)


### L O G G I N G

In [5]:
writer = SummaryWriter(log_dir = './runs/IQL')


### A S S E M B L Y

In [6]:
head_1 = 128
head_2 = 256
head_3 = 256
head_4 = 128

hidden_size = 128
hidden_size_2 = 256


### F E A T U R E 

In [7]:
class Feature_Extractor(nn.Module):
    
    def __init__(self, input_dim, output_dim, hidden_size = hidden_size, hidden_size_2 = hidden_size_2):
        super(Feature_Extractor, self).__init__()
        
        self.cal = nn.Sequential(
            
            nn.Linear(input_dim, hidden_size),
            nn.SiLU(),
            
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size_2),
            nn.SiLU(),
            
            nn.Linear(hidden_size_2, hidden_size),
            nn.SiLU(),
            
            nn.Linear(hidden_size, output_dim),
            nn.SiLU()
        )
        
    def forward(self, x):
    
        return self.cal(x)


### P O L I C Y

In [8]:
class policy_net(nn.Module):
    
    def __init__(self, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4, max_action = max_action, state_dim = state_dim, action_dim = action_dim):
        super(policy_net, self).__init__()
        
        # feature 
        
        self.feature = Feature_Extractor(state_dim, head_1)
        
        # norm
        
        self.norm  = nn.LayerNorm(head_1)
        
        # pos process
        
        self.pos_process = nn.Sequential(
            
            nn.Linear(head_1, head_2),
            nn.SiLU(),
            
            nn.Linear(head_2, head_3),
            nn.SiLU(),
            
            nn.Linear(head_3, head_4),
            nn.SiLU()
        ) 
        
        # mu and log std head
        
        self.mu = nn.Linear(head_4, action_dim)
        self.log_std = nn.Linear(head_4, action_dim)
        
        # max action
        
        self.max_action = max_action
        
    def forward(self, state, deterministic = False, return_log_probs = False):
        
        # feature
        
        feature = self.feature(state)
        
        # norm
        
        norm = self.norm(feature)
        
        # pos
        
        pos = self.pos_process(norm)
        
        # mu and log std
        
        mu = self.mu(pos)
        log_std = self.log_std(pos)
        
        log_std = torch.clamp(log_std, -10, 2)
        std = torch.exp(log_std)
        
        dist = torch.distributions.Normal(mu, std)
        z = mu if deterministic == True else dist.rsample()
        tanh_z = torch.tanh(z)
        action = tanh_z * self.max_action
        
        log_prob = None
        
        if return_log_probs:
            
            log_prob = dist.log_prob(z).sum(dim = -1, keepdim = True)
            log_prob -= torch.log(1 - tanh_z.pow(2) + 1e-6).sum(dim = -1, keepdim = True)
        
        return action, mu, log_std, log_prob
        

### V A L U E 

In [9]:
class value_net(nn.Module):
    
    def __init__(self, state_dim = state_dim, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4):
        super(value_net, self).__init__()
        
        # process
        
        self.process = nn.Sequential(
            
            nn.Linear(state_dim, head_1),
            nn.SiLU(),
            
            nn.Linear(head_1, head_2),
            nn.SiLU(),
            
            
            nn.LayerNorm(head_2),
            nn.Linear(head_2, head_3),
            nn.SiLU(),
            
            
            nn.LayerNorm(head_3),
            nn.Linear(head_3, head_4),
            nn.SiLU(),
            
            nn.Linear(head_4, 1)
        )
        
    def forward(self, state):
        
        value = self.process(state)
        
        return value


### Q - F U N C T I O N 

In [10]:
class q_net(nn.Module):
    
    def __init__(self, state_dim = state_dim, action_dim = action_dim, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4):
        super(q_net, self).__init__()
        
        self.q_1 = nn.Sequential(
            
            nn.Linear(state_dim + action_dim, head_1),
            nn.SiLU(),
            
            nn.Linear(head_1, head_2),
            nn.SiLU(),
            
            nn.LayerNorm(head_2),
            nn.Linear(head_2, head_3),
            nn.SiLU(),
            
            nn.LayerNorm(head_3),
            nn.Linear(head_3, head_4),
            nn.SiLU(),
            
            nn.Linear(head_4, 1)
        )
        
        self.q_2 = nn.Sequential(
            
            nn.Linear(state_dim + action_dim, head_1),
            nn.SiLU(),
            
            nn.Linear(head_1, head_2),
            nn.SiLU(),
            
            nn.LayerNorm(head_2),
            nn.Linear(head_2, head_3),
            nn.SiLU(),
            
            nn.LayerNorm(head_3),
            nn.Linear(head_3, head_4),
            nn.SiLU(),
            
            nn.Linear(head_4, 1)
        )
        
    def forward(self, state, action):
        
        # cat
        
        cat = torch.cat([state, action], dim = -1)
        
        # prepare
        
        q_1 = self.q_1(cat)
        q_2 = self.q_2(cat)
        
        return q_1, q_2


### S E T U P 

In [11]:
POLICY_NETWORK = policy_net().to(device)
print(POLICY_NETWORK)
print('-' * 100)

VALUE_NETWORK = value_net().to(device)
print(VALUE_NETWORK)
print('-' * 100)

Q_NETWORK = q_net().to(device)
print(Q_NETWORK)
print('-' * 100)


policy_net(
  (feature): Feature_Extractor(
    (cal): Sequential(
      (0): Linear(in_features=11, out_features=128, bias=True)
      (1): SiLU()
      (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (3): Linear(in_features=128, out_features=256, bias=True)
      (4): SiLU()
      (5): Linear(in_features=256, out_features=128, bias=True)
      (6): SiLU()
      (7): Linear(in_features=128, out_features=128, bias=True)
      (8): SiLU()
    )
  )
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (pos_process): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): SiLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): SiLU()
    (4): Linear(in_features=256, out_features=128, bias=True)
    (5): SiLU()
  )
  (mu): Linear(in_features=128, out_features=3, bias=True)
  (log_std): Linear(in_features=128, out_features=3, bias=True)
)
---------------------------------------------------------------------------

### O P T I M I Z E R - S C H E D U L E R

In [12]:
# lr

T_max = 100
policy_lr = 1e-4
value_lr = 1e-3
q_lr = 1e-3

# O P T I M I Z E R

def get_optimizer(network, lr):
    
    return optim.AdamW(network.parameters(), lr, weight_decay = 0)

policy_optimizer = get_optimizer(POLICY_NETWORK, policy_lr)
value_optimizer = get_optimizer(VALUE_NETWORK, value_lr)
q_net_optimizer = get_optimizer(Q_NETWORK, q_lr)

# S C H E D U L E R

def get_scheduler(optimizer, T_max = T_max):
    
    return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min = 1e-5)

policy_scheduler = get_scheduler(policy_optimizer)
value_scheduler = get_scheduler(value_optimizer)
q_net_scheduler = get_scheduler(q_net_optimizer)


### L O S S - F U N C

In [13]:
class loss_func:
    
    def __init__(self, gamma, tau, beta, POLICY_NETWORK = POLICY_NETWORK, VALUE_NETWORK = VALUE_NETWORK, Q_NETWORK = Q_NETWORK, policy_optimizer = policy_optimizer, value_optimizer = value_optimizer, q_net_optimizer = q_net_optimizer, policy_scheduler = policy_scheduler, value_scheduler = value_scheduler, q_net_scheduler = q_net_scheduler):
        
        # network
        
        self.policy = POLICY_NETWORK
        self.value = VALUE_NETWORK
        self.q_network = Q_NETWORK
        
        # optimizer
        
        self.policy_optimizer = policy_optimizer
        self.value_optimizer = value_optimizer
        self.q_optimizer = q_net_optimizer
        
        # scheduler
        
        self.policy_scheduler = policy_scheduler
        self.value_scheduler = value_scheduler
        self.q_scheduler = q_net_scheduler
        
        # hyper params
        
        self.gamma = gamma
        self.tau = tau
        self.beta = beta
        
    def value_loss(self, states, rewards, next_states):
        
        # current v
        
        v = self.value(states)
        
        # next v
        
        with torch.no_grad():
            
            next_v = self.value(next_states)
            
        delta = rewards + self.gamma * next_v - v
        weight = torch.where(delta > 0, self.tau, 1 - self.tau)
        
        value_loss = (weight * delta.pow(2)).mean()
        
        return value_loss
    
    def q_loss(self, states, actions, rewards, next_states):
        
        # cal target value
        
        with torch.no_grad():
            
            next_v = self.value(next_states)
            target = rewards + self.gamma * next_v
            
        q1, q2 = self.q_network(states, actions)
        l1 = F.mse_loss(q1, target)
        l2 = F.mse_loss(q2, target)
        
        loss = 0.5 * (l1 + l2)
        
        return loss
    
    def policy_loss(self, states):
        
        # cal log probs and actions
        
        actions, _, _, log_probs = self.policy(states, return_log_probs = True)
        
        # cal q vals
        
        with torch.no_grad():
            
            q1, q2 = self.q_network(states, actions)
            
            q = torch.min(q1, q2)
            
        # cal v
        
            v = self.value(states)
            
            advantages = q - v
        
        # cal weights
        
        weights = torch.exp(self.beta * advantages)
        weights = torch.clamp(weights, max = 100.0)
        
        # cal policy loss
        
        policy_loss = -(weights * log_probs).mean()
        
        
        return policy_loss, advantages
    
    def update(self, data_loader, epoch):
        
        running_policy_loss = 0.0
        running_value_loss = 0.0
        running_q_loss = 0.0
        all_advantages = []
        
        
        for states, actions, rewards, next_states in data_loader:
            
            if rewards.dim() == 1: rewards = rewards.unsqueeze(1)
        
            # cal value loss
            
            value_loss = self.value_loss(states, rewards, next_states)
            running_value_loss += value_loss.item()
            
            # update value network
            
            self.value_optimizer.zero_grad()
            value_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.value.parameters(), max_norm = 0.5)
            self.value_optimizer.step()
            self.value_scheduler.step()
            
            # cal q _ loss
            
            q_loss = self.q_loss(states, actions, rewards, next_states)
            running_q_loss += q_loss.item()
            
            # update q network
            
            self.q_optimizer.zero_grad()
            q_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), max_norm = 0.5)
            self.q_optimizer.step()
            self.q_scheduler.step()
            
            # policy loss
            
            policy_loss, advantages = self.policy_loss(states)
            running_policy_loss += policy_loss.item()
            all_advantages.append(advantages.detach().cpu())
            
            # update
            
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), max_norm = 0.5)
            self.policy_optimizer.step()
            self.policy_scheduler.step()
            
        avg_value_loss = running_value_loss / len(data_loader)
        avg_q_loss = running_q_loss / len(data_loader)
        avg_policy_loss = running_policy_loss / len(data_loader)
        
        all_advantages = torch.cat(all_advantages, dim = 0)
        writer.add_histogram('Advantages', all_advantages, epoch)
        
        return avg_value_loss, avg_q_loss, avg_policy_loss


### S E T U P 


In [14]:
# hyper params

gamma = 0.97
tau = 0.7
beta = 2.0

# setup

LOSS_FUNCTION = loss_func(gamma, tau, beta)


### T R A I N I N G

In [15]:
epochs = 100

for epoch in tqdm.tqdm(range(epochs), desc = 'Training IQL'):
    
    avg_value_loss, avg_q_loss, avg_policy_loss = LOSS_FUNCTION.update(train_loader, epoch)
    
    writer.add_scalar('Value loss', avg_value_loss, epoch)
    writer.add_scalar('Q Loss', avg_q_loss, epoch)
    writer.add_scalar('Policy loss', avg_policy_loss, epoch)
    
    writer.flush()
    
    tqdm.tqdm.write(f'Epoch: {epoch} / {epochs} | '
                    f'avg v loss: {avg_value_loss:.4f} | '
                    f'avg q loss: {avg_q_loss:.4f} | '
                    f'avg policy loss: {avg_policy_loss:.4f}')
    

Training IQL:   1%|          | 1/100 [01:28<2:26:50, 89.00s/it]

Epoch: 0 / 100 | avg v loss: 1.3859 | avg q loss: 84.1923 | avg policy loss: -948.2040


Training IQL:   2%|▏         | 2/100 [02:33<2:02:00, 74.70s/it]

Epoch: 1 / 100 | avg v loss: 0.4220 | avg q loss: 26.8879 | avg policy loss: -1035.5075


Training IQL:   3%|▎         | 3/100 [03:35<1:50:57, 68.64s/it]

Epoch: 2 / 100 | avg v loss: 0.3363 | avg q loss: 16.9098 | avg policy loss: -1007.7772


Training IQL:   4%|▍         | 4/100 [04:35<1:44:36, 65.38s/it]

Epoch: 3 / 100 | avg v loss: 0.3244 | avg q loss: 16.3538 | avg policy loss: -983.1977


Training IQL:   5%|▌         | 5/100 [05:54<1:51:10, 70.22s/it]

Epoch: 4 / 100 | avg v loss: 0.2972 | avg q loss: 15.5811 | avg policy loss: -957.3058


Training IQL:   6%|▌         | 6/100 [07:23<1:59:56, 76.55s/it]

Epoch: 5 / 100 | avg v loss: 0.2784 | avg q loss: 13.8399 | avg policy loss: -913.3911


Training IQL:   7%|▋         | 7/100 [08:50<2:04:18, 80.20s/it]

Epoch: 6 / 100 | avg v loss: 0.2698 | avg q loss: 13.1250 | avg policy loss: -878.2010


Training IQL:   8%|▊         | 8/100 [10:16<2:05:32, 81.87s/it]

Epoch: 7 / 100 | avg v loss: 0.2401 | avg q loss: 10.4078 | avg policy loss: -809.1322


Training IQL:   9%|▉         | 9/100 [11:44<2:07:07, 83.82s/it]

Epoch: 8 / 100 | avg v loss: 0.2385 | avg q loss: 12.4915 | avg policy loss: -789.0014


Training IQL:  10%|█         | 10/100 [13:09<2:06:28, 84.32s/it]

Epoch: 9 / 100 | avg v loss: 0.2343 | avg q loss: 11.3595 | avg policy loss: -708.0381


Training IQL:  11%|█         | 11/100 [14:33<2:04:50, 84.16s/it]

Epoch: 10 / 100 | avg v loss: 0.2415 | avg q loss: 12.4804 | avg policy loss: -728.6002


Training IQL:  12%|█▏        | 12/100 [15:54<2:02:01, 83.20s/it]

Epoch: 11 / 100 | avg v loss: 0.2294 | avg q loss: 10.7120 | avg policy loss: -656.9594


Training IQL:  13%|█▎        | 13/100 [17:00<1:52:56, 77.89s/it]

Epoch: 12 / 100 | avg v loss: 0.2333 | avg q loss: 10.2143 | avg policy loss: -616.1512


Training IQL:  14%|█▍        | 14/100 [18:05<1:46:02, 73.98s/it]

Epoch: 13 / 100 | avg v loss: 0.2202 | avg q loss: 10.1577 | avg policy loss: -641.1785


Training IQL:  15%|█▌        | 15/100 [19:12<1:41:58, 71.98s/it]

Epoch: 14 / 100 | avg v loss: 0.2134 | avg q loss: 10.7072 | avg policy loss: -646.0024


Training IQL:  16%|█▌        | 16/100 [20:13<1:36:18, 68.79s/it]

Epoch: 15 / 100 | avg v loss: 0.2113 | avg q loss: 10.4909 | avg policy loss: -615.0130


Training IQL:  17%|█▋        | 17/100 [21:19<1:33:59, 67.94s/it]

Epoch: 16 / 100 | avg v loss: 0.2111 | avg q loss: 9.4724 | avg policy loss: -602.8571


Training IQL:  18%|█▊        | 18/100 [22:33<1:35:09, 69.62s/it]

Epoch: 17 / 100 | avg v loss: 0.2077 | avg q loss: 8.5919 | avg policy loss: -586.5915


Training IQL:  19%|█▉        | 19/100 [23:38<1:32:06, 68.22s/it]

Epoch: 18 / 100 | avg v loss: 0.1979 | avg q loss: 8.5397 | avg policy loss: -590.9516


Training IQL:  20%|██        | 20/100 [24:42<1:29:22, 67.03s/it]

Epoch: 19 / 100 | avg v loss: 0.2024 | avg q loss: 9.3672 | avg policy loss: -586.6571


Training IQL:  21%|██        | 21/100 [25:47<1:27:25, 66.40s/it]

Epoch: 20 / 100 | avg v loss: 0.1929 | avg q loss: 8.1910 | avg policy loss: -603.4874


Training IQL:  22%|██▏       | 22/100 [27:10<1:32:52, 71.44s/it]

Epoch: 21 / 100 | avg v loss: 0.1794 | avg q loss: 7.5313 | avg policy loss: -572.0566


Training IQL:  23%|██▎       | 23/100 [28:20<1:31:03, 70.96s/it]

Epoch: 22 / 100 | avg v loss: 0.1812 | avg q loss: 7.6855 | avg policy loss: -581.7412


Training IQL:  24%|██▍       | 24/100 [29:25<1:27:21, 68.97s/it]

Epoch: 23 / 100 | avg v loss: 0.1802 | avg q loss: 7.5148 | avg policy loss: -574.3077


Training IQL:  25%|██▌       | 25/100 [30:28<1:24:12, 67.37s/it]

Epoch: 24 / 100 | avg v loss: 0.1715 | avg q loss: 7.5830 | avg policy loss: -555.4828


Training IQL:  26%|██▌       | 26/100 [31:36<1:23:20, 67.57s/it]

Epoch: 25 / 100 | avg v loss: 0.1726 | avg q loss: 6.4475 | avg policy loss: -584.0136


Training IQL:  27%|██▋       | 27/100 [32:48<1:23:39, 68.76s/it]

Epoch: 26 / 100 | avg v loss: 0.1651 | avg q loss: 7.4918 | avg policy loss: -590.4853


Training IQL:  28%|██▊       | 28/100 [33:52<1:20:57, 67.47s/it]

Epoch: 27 / 100 | avg v loss: 0.1544 | avg q loss: 6.6567 | avg policy loss: -537.9736


Training IQL:  29%|██▉       | 29/100 [35:05<1:21:50, 69.16s/it]

Epoch: 28 / 100 | avg v loss: 0.1614 | avg q loss: 6.0530 | avg policy loss: -571.1491


Training IQL:  30%|███       | 30/100 [36:07<1:17:57, 66.83s/it]

Epoch: 29 / 100 | avg v loss: 0.1665 | avg q loss: 7.1615 | avg policy loss: -585.1828


Training IQL:  31%|███       | 31/100 [37:07<1:14:44, 64.99s/it]

Epoch: 30 / 100 | avg v loss: 0.1900 | avg q loss: 7.3615 | avg policy loss: -587.9151


Training IQL:  32%|███▏      | 32/100 [38:16<1:14:43, 65.94s/it]

Epoch: 31 / 100 | avg v loss: 0.1700 | avg q loss: 6.7824 | avg policy loss: -558.4485


Training IQL:  33%|███▎      | 33/100 [39:23<1:14:05, 66.35s/it]

Epoch: 32 / 100 | avg v loss: 0.1698 | avg q loss: 6.6882 | avg policy loss: -582.4580


Training IQL:  34%|███▍      | 34/100 [40:30<1:13:22, 66.70s/it]

Epoch: 33 / 100 | avg v loss: 0.1647 | avg q loss: 6.5308 | avg policy loss: -530.4409


Training IQL:  35%|███▌      | 35/100 [41:38<1:12:37, 67.03s/it]

Epoch: 34 / 100 | avg v loss: 0.1686 | avg q loss: 5.9968 | avg policy loss: -526.2778


Training IQL:  36%|███▌      | 36/100 [42:40<1:09:52, 65.50s/it]

Epoch: 35 / 100 | avg v loss: 0.1577 | avg q loss: 6.1140 | avg policy loss: -558.8939


Training IQL:  37%|███▋      | 37/100 [43:40<1:07:08, 63.94s/it]

Epoch: 36 / 100 | avg v loss: 0.1507 | avg q loss: 6.0059 | avg policy loss: -567.6307


Training IQL:  38%|███▊      | 38/100 [44:49<1:07:32, 65.36s/it]

Epoch: 37 / 100 | avg v loss: 0.1585 | avg q loss: 5.6388 | avg policy loss: -556.3631


Training IQL:  39%|███▉      | 39/100 [45:53<1:06:06, 65.03s/it]

Epoch: 38 / 100 | avg v loss: 0.1586 | avg q loss: 5.5372 | avg policy loss: -554.9001


Training IQL:  40%|████      | 40/100 [46:57<1:04:37, 64.63s/it]

Epoch: 39 / 100 | avg v loss: 0.1544 | avg q loss: 5.7016 | avg policy loss: -555.8178


Training IQL:  41%|████      | 41/100 [48:03<1:03:49, 64.91s/it]

Epoch: 40 / 100 | avg v loss: 0.1520 | avg q loss: 5.8189 | avg policy loss: -538.4936


Training IQL:  42%|████▏     | 42/100 [49:05<1:02:07, 64.26s/it]

Epoch: 41 / 100 | avg v loss: 0.1631 | avg q loss: 5.7670 | avg policy loss: -556.3157


Training IQL:  43%|████▎     | 43/100 [50:12<1:01:44, 65.00s/it]

Epoch: 42 / 100 | avg v loss: 0.1514 | avg q loss: 5.6906 | avg policy loss: -567.3556


Training IQL:  44%|████▍     | 44/100 [51:17<1:00:31, 64.84s/it]

Epoch: 43 / 100 | avg v loss: 0.1594 | avg q loss: 5.8813 | avg policy loss: -575.6659


Training IQL:  45%|████▌     | 45/100 [52:22<59:40, 65.10s/it]  

Epoch: 44 / 100 | avg v loss: 0.1422 | avg q loss: 5.0768 | avg policy loss: -547.9258


Training IQL:  46%|████▌     | 46/100 [53:26<58:09, 64.61s/it]

Epoch: 45 / 100 | avg v loss: 0.1610 | avg q loss: 6.1169 | avg policy loss: -570.6060


Training IQL:  47%|████▋     | 47/100 [54:25<55:40, 63.02s/it]

Epoch: 46 / 100 | avg v loss: 0.1523 | avg q loss: 5.8694 | avg policy loss: -549.3955


Training IQL:  48%|████▊     | 48/100 [55:27<54:17, 62.65s/it]

Epoch: 47 / 100 | avg v loss: 0.1692 | avg q loss: 5.5244 | avg policy loss: -573.0074


Training IQL:  49%|████▉     | 49/100 [56:38<55:22, 65.14s/it]

Epoch: 48 / 100 | avg v loss: 0.1540 | avg q loss: 5.5848 | avg policy loss: -557.9532


Training IQL:  50%|█████     | 50/100 [57:58<58:00, 69.60s/it]

Epoch: 49 / 100 | avg v loss: 0.1480 | avg q loss: 4.9120 | avg policy loss: -546.7591


Training IQL:  51%|█████     | 51/100 [59:03<55:39, 68.14s/it]

Epoch: 50 / 100 | avg v loss: 0.1536 | avg q loss: 5.2622 | avg policy loss: -542.1519


Training IQL:  52%|█████▏    | 52/100 [1:00:05<53:05, 66.37s/it]

Epoch: 51 / 100 | avg v loss: 0.1477 | avg q loss: 4.8330 | avg policy loss: -561.9044


Training IQL:  53%|█████▎    | 53/100 [1:01:09<51:33, 65.81s/it]

Epoch: 52 / 100 | avg v loss: 0.1474 | avg q loss: 4.9551 | avg policy loss: -564.5768


Training IQL:  54%|█████▍    | 54/100 [1:02:15<50:20, 65.67s/it]

Epoch: 53 / 100 | avg v loss: 0.1594 | avg q loss: 4.6637 | avg policy loss: -546.2826


Training IQL:  55%|█████▌    | 55/100 [1:03:28<51:02, 68.06s/it]

Epoch: 54 / 100 | avg v loss: 0.1375 | avg q loss: 4.3982 | avg policy loss: -545.7865


Training IQL:  56%|█████▌    | 56/100 [1:04:37<50:01, 68.21s/it]

Epoch: 55 / 100 | avg v loss: 0.1474 | avg q loss: 4.7570 | avg policy loss: -544.5940


Training IQL:  57%|█████▋    | 57/100 [1:05:39<47:31, 66.30s/it]

Epoch: 56 / 100 | avg v loss: 0.1399 | avg q loss: 4.7195 | avg policy loss: -552.2599


Training IQL:  58%|█████▊    | 58/100 [1:06:40<45:24, 64.88s/it]

Epoch: 57 / 100 | avg v loss: 0.1374 | avg q loss: 4.5671 | avg policy loss: -537.8728


Training IQL:  59%|█████▉    | 59/100 [1:07:42<43:36, 63.81s/it]

Epoch: 58 / 100 | avg v loss: 0.1511 | avg q loss: 4.3374 | avg policy loss: -506.9408


Training IQL:  60%|██████    | 60/100 [1:08:46<42:39, 63.98s/it]

Epoch: 59 / 100 | avg v loss: 0.1384 | avg q loss: 4.6709 | avg policy loss: -536.3736


Training IQL:  61%|██████    | 61/100 [1:09:47<41:05, 63.21s/it]

Epoch: 60 / 100 | avg v loss: 0.1363 | avg q loss: 4.2096 | avg policy loss: -527.0669


Training IQL:  62%|██████▏   | 62/100 [1:10:47<39:25, 62.25s/it]

Epoch: 61 / 100 | avg v loss: 0.1384 | avg q loss: 4.2015 | avg policy loss: -522.4280


Training IQL:  63%|██████▎   | 63/100 [1:11:48<38:05, 61.77s/it]

Epoch: 62 / 100 | avg v loss: 0.1291 | avg q loss: 4.1989 | avg policy loss: -506.2030


Training IQL:  64%|██████▍   | 64/100 [1:12:49<36:53, 61.48s/it]

Epoch: 63 / 100 | avg v loss: 0.1319 | avg q loss: 3.7546 | avg policy loss: -485.6794


Training IQL:  65%|██████▌   | 65/100 [1:14:04<38:13, 65.53s/it]

Epoch: 64 / 100 | avg v loss: 0.1384 | avg q loss: 4.3040 | avg policy loss: -480.9979


Training IQL:  66%|██████▌   | 66/100 [1:15:05<36:26, 64.30s/it]

Epoch: 65 / 100 | avg v loss: 0.1365 | avg q loss: 4.1112 | avg policy loss: -488.4565


Training IQL:  67%|██████▋   | 67/100 [1:16:05<34:37, 62.95s/it]

Epoch: 66 / 100 | avg v loss: 0.1286 | avg q loss: 3.9397 | avg policy loss: -476.9576


Training IQL:  68%|██████▊   | 68/100 [1:17:06<33:15, 62.35s/it]

Epoch: 67 / 100 | avg v loss: 0.1306 | avg q loss: 4.4507 | avg policy loss: -490.4714


Training IQL:  69%|██████▉   | 69/100 [1:18:06<31:53, 61.72s/it]

Epoch: 68 / 100 | avg v loss: 0.1371 | avg q loss: 3.9386 | avg policy loss: -474.0506


Training IQL:  70%|███████   | 70/100 [1:19:07<30:42, 61.41s/it]

Epoch: 69 / 100 | avg v loss: 0.1301 | avg q loss: 3.9215 | avg policy loss: -491.3894


Training IQL:  71%|███████   | 71/100 [1:20:10<29:59, 62.07s/it]

Epoch: 70 / 100 | avg v loss: 0.1325 | avg q loss: 4.2452 | avg policy loss: -478.6582


Training IQL:  72%|███████▏  | 72/100 [1:21:10<28:38, 61.38s/it]

Epoch: 71 / 100 | avg v loss: 0.1301 | avg q loss: 4.0651 | avg policy loss: -485.4012


Training IQL:  73%|███████▎  | 73/100 [1:22:13<27:50, 61.88s/it]

Epoch: 72 / 100 | avg v loss: 0.1387 | avg q loss: 4.1550 | avg policy loss: -478.0781


Training IQL:  74%|███████▍  | 74/100 [1:23:17<27:04, 62.47s/it]

Epoch: 73 / 100 | avg v loss: 0.1367 | avg q loss: 4.2156 | avg policy loss: -494.1085


Training IQL:  75%|███████▌  | 75/100 [1:24:26<26:47, 64.31s/it]

Epoch: 74 / 100 | avg v loss: 0.1356 | avg q loss: 3.9420 | avg policy loss: -474.7700


Training IQL:  76%|███████▌  | 76/100 [1:25:33<26:06, 65.27s/it]

Epoch: 75 / 100 | avg v loss: 0.1299 | avg q loss: 3.4255 | avg policy loss: -450.1687


Training IQL:  77%|███████▋  | 77/100 [1:26:36<24:45, 64.57s/it]

Epoch: 76 / 100 | avg v loss: 0.1405 | avg q loss: 3.8855 | avg policy loss: -454.4845


Training IQL:  78%|███████▊  | 78/100 [1:27:38<23:24, 63.82s/it]

Epoch: 77 / 100 | avg v loss: 0.1279 | avg q loss: 3.4315 | avg policy loss: -453.5228


Training IQL:  79%|███████▉  | 79/100 [1:28:39<22:02, 62.98s/it]

Epoch: 78 / 100 | avg v loss: 0.1305 | avg q loss: 3.7105 | avg policy loss: -453.9908


Training IQL:  80%|████████  | 80/100 [1:29:42<21:01, 63.05s/it]

Epoch: 79 / 100 | avg v loss: 0.1239 | avg q loss: 3.5652 | avg policy loss: -446.6634


Training IQL:  81%|████████  | 81/100 [1:30:45<19:52, 62.77s/it]

Epoch: 80 / 100 | avg v loss: 0.1260 | avg q loss: 3.5769 | avg policy loss: -445.2010


Training IQL:  82%|████████▏ | 82/100 [1:31:47<18:46, 62.57s/it]

Epoch: 81 / 100 | avg v loss: 0.1210 | avg q loss: 3.1479 | avg policy loss: -427.7220


Training IQL:  83%|████████▎ | 83/100 [1:32:48<17:37, 62.21s/it]

Epoch: 82 / 100 | avg v loss: 0.1125 | avg q loss: 3.2369 | avg policy loss: -466.4597


Training IQL:  84%|████████▍ | 84/100 [1:33:49<16:28, 61.80s/it]

Epoch: 83 / 100 | avg v loss: 0.1204 | avg q loss: 3.4668 | avg policy loss: -441.0920


Training IQL:  85%|████████▌ | 85/100 [1:34:51<15:26, 61.75s/it]

Epoch: 84 / 100 | avg v loss: 0.1250 | avg q loss: 3.5433 | avg policy loss: -437.3539


Training IQL:  86%|████████▌ | 86/100 [1:35:52<14:24, 61.72s/it]

Epoch: 85 / 100 | avg v loss: 0.1215 | avg q loss: 2.9917 | avg policy loss: -426.5437


Training IQL:  87%|████████▋ | 87/100 [1:36:55<13:26, 62.02s/it]

Epoch: 86 / 100 | avg v loss: 0.1156 | avg q loss: 3.8092 | avg policy loss: -450.3662


Training IQL:  88%|████████▊ | 88/100 [1:37:57<12:23, 61.94s/it]

Epoch: 87 / 100 | avg v loss: 0.1198 | avg q loss: 3.4020 | avg policy loss: -437.3512


Training IQL:  89%|████████▉ | 89/100 [1:38:59<11:21, 61.98s/it]

Epoch: 88 / 100 | avg v loss: 0.1191 | avg q loss: 3.2987 | avg policy loss: -459.9882


Training IQL:  90%|█████████ | 90/100 [1:40:01<10:19, 61.94s/it]

Epoch: 89 / 100 | avg v loss: 0.1129 | avg q loss: 2.8400 | avg policy loss: -412.0021


Training IQL:  91%|█████████ | 91/100 [1:41:10<09:36, 64.09s/it]

Epoch: 90 / 100 | avg v loss: 0.1259 | avg q loss: 3.3485 | avg policy loss: -439.5145


Training IQL:  92%|█████████▏| 92/100 [1:42:17<08:40, 65.11s/it]

Epoch: 91 / 100 | avg v loss: 0.1081 | avg q loss: 2.8144 | avg policy loss: -419.3894


Training IQL:  93%|█████████▎| 93/100 [1:43:36<08:05, 69.34s/it]

Epoch: 92 / 100 | avg v loss: 0.1113 | avg q loss: 3.2158 | avg policy loss: -432.3067


Training IQL:  94%|█████████▍| 94/100 [1:44:50<07:03, 70.55s/it]

Epoch: 93 / 100 | avg v loss: 0.1066 | avg q loss: 2.9034 | avg policy loss: -410.2417


Training IQL:  95%|█████████▌| 95/100 [1:46:00<05:52, 70.47s/it]

Epoch: 94 / 100 | avg v loss: 0.1197 | avg q loss: 3.4898 | avg policy loss: -417.3898


Training IQL:  96%|█████████▌| 96/100 [1:47:12<04:44, 71.01s/it]

Epoch: 95 / 100 | avg v loss: 0.1099 | avg q loss: 3.2162 | avg policy loss: -422.7069


Training IQL:  97%|█████████▋| 97/100 [1:48:23<03:32, 70.92s/it]

Epoch: 96 / 100 | avg v loss: 0.1138 | avg q loss: 3.2863 | avg policy loss: -433.9003


Training IQL:  98%|█████████▊| 98/100 [1:49:32<02:20, 70.47s/it]

Epoch: 97 / 100 | avg v loss: 0.1120 | avg q loss: 2.9351 | avg policy loss: -397.9808


Training IQL:  99%|█████████▉| 99/100 [1:50:42<01:10, 70.31s/it]

Epoch: 98 / 100 | avg v loss: 0.1137 | avg q loss: 2.8402 | avg policy loss: -412.6640


Training IQL: 100%|██████████| 100/100 [1:52:11<00:00, 67.32s/it]

Epoch: 99 / 100 | avg v loss: 0.1150 | avg q loss: 2.9650 | avg policy loss: -380.7476



