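### B A T C H - C L O N I N G

Behavior cloning on a fixed (batch / offline) dataset: a deterministic MLP policy is regressed onto the logged actions of `hopper_medium-v2` with an MSE loss.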

In [1]:
import h5py
import numpy as np

import torch
torch.set_float32_matmul_precision('high')
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt

from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler

from torch_setup import compile_model


### D E V I C E 

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')


Device: cuda


### E N V - D A T A S E T

In [None]:
# Raw string: in a plain literal '\O' and '\h' are invalid escape sequences
# (SyntaxWarning on Python 3.12+); the resulting path value is unchanged.
# NOTE(review): absolute local path — consider making it configurable.
file = r'C:\OFFLINE RL\hopper_medium-v2.hdf5'

with h5py.File(file, 'r') as f:

    # Materialise each HDF5 dataset as an in-memory NumPy array before the
    # file handle is closed.
    observations = np.array(f['observations'])
    actions = np.array(f['actions'])
    rewards = np.array(f['rewards'])

# Sanity check: the arrays are expected to be aligned along axis 0
# (one row per logged transition).
if not (len(observations) == len(actions) == len(rewards)):
    raise ValueError(
        f"Misaligned dataset: {len(observations)} observations, "
        f"{len(actions)} actions, {len(rewards)} rewards"
    )

print("Observations shape:", observations.shape) 
print("Actions shape:", actions.shape)            
print("Rewards shape:", rewards.shape)

state_dim = observations.shape[1]
action_dim = actions.shape[1]
# Largest absolute action component; used as the policy's default tanh scale.
# Cast to a Python float so downstream torch code never sees a NumPy scalar.
max_action = float(np.abs(actions).max())

print(f"State dim: {state_dim}, Action dim: {action_dim}, Max action: {max_action}")


Observations shape: (1000000, 11)
Actions shape: (1000000, 3)
Rewards shape: (1000000,)
State dim: 11, Action dim: 3, Max action: 0.9999945163726807


### P R E P A R E - D A T A 

In [4]:
# Convert the full arrays to float32 tensors and place them on `device` once,
# so training batches are sliced directly from device memory.
obs_tensor = torch.as_tensor(observations, dtype=torch.float32).to(device)
act_tensor = torch.as_tensor(actions, dtype=torch.float32).to(device)

class HOPPER_DATASET(Dataset):
    """Map-style dataset pairing each observation with its logged action.

    Args:
        observations: indexable container of states; its length defines the
            dataset length (assumed sample-first, e.g. an (N, state_dim)
            tensor — TODO confirm for other callers).
        actions: indexable container of actions, same length as
            ``observations``.

    Raises:
        ValueError: if the two containers differ in length. Without this
            check, surplus actions would be silently ignored, or a shorter
            ``actions`` would only fail later with an IndexError.
    """

    def __init__(self, observations, actions):

        if len(observations) != len(actions):
            raise ValueError(
                "observations and actions must have the same length, "
                f"got {len(observations)} and {len(actions)}"
            )

        self.observations = observations
        self.actions = actions

    def __len__(self):
        """Number of (observation, action) pairs."""
        return len(self.observations)

    def __getitem__(self, index):
        """Return ``(observations[index], actions[index])``.

        For tensor storage ``index`` may also be a list of indices, in which
        case a whole batch is returned in a single indexing call.
        """
        return self.observations[index], self.actions[index]
    
# create dataset

# Mini-batch size for behavior cloning.
BATCH_SIZE = 256
# Dedicated, seeded generator so the shuffling order is reproducible.
loader_generator = torch.Generator().manual_seed(0)

dataset = HOPPER_DATASET(obs_tensor, act_tensor)

# obs_tensor / act_tensor already live on `device`, so fetch each mini-batch
# with ONE fancy-indexing call (dataset[list_of_indices]) instead of
# BATCH_SIZE per-sample lookups followed by a stack. A BatchSampler over a
# RandomSampler is how DataLoader(shuffle=True, batch_size=..., drop_last=True)
# draws batches internally, so batch size, drop_last behaviour and
# len(Data_loader) are unchanged; batch_size=None turns off the DataLoader's
# own per-sample collation.
batch_sampler = BatchSampler(
    RandomSampler(dataset, generator=loader_generator),
    batch_size=BATCH_SIZE,
    drop_last=True,
)
Data_loader = DataLoader(dataset, batch_size=None, sampler=batch_sampler)
        

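### A S S E M B L Y

Layer widths used to assemble the networks: `head_1`–`head_4` for the policy trunk, `hidden_size` / `hidden_size_2` for the feature extractor.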


In [None]:
# Policy widths: head_1 is the feature-extractor output / LayerNorm width,
# head_2..head_4 are the successive hidden layers of the trunk after it.
head_1, head_2, head_3, head_4 = 128, 256, 256, 256

# Default hidden widths of Feature_Extractor.
hidden_size, hidden_size_2 = 128, 256


### F E A T U R E 

In [6]:
class Feature_Extractor(nn.Module):
    """MLP that maps an input vector to an ``output_dim``-wide feature vector.

    Layout of ``self.cal`` (by index): Linear, SiLU, Linear, SiLU, LayerNorm,
    Linear, SiLU, LayerNorm, Linear, SiLU, Linear, SiLU. The trailing SiLU
    means the output is activated, not a raw linear projection.

    Args:
        input_dim: size of the input's last dimension.
        output_dim: size of the output's last dimension.
        hidden_size: width of the first, second and fourth hidden layers.
        hidden_size_2: width of the third hidden layer.
    """

    def __init__(self, input_dim, output_dim, hidden_size = hidden_size, hidden_size_2 = hidden_size_2):
        super().__init__()

        # Construction order fixes both the RNG draws used for weight
        # initialisation and the state_dict keys (cal.0, cal.2, ...); keep it.
        layers = [
            nn.Linear(input_dim, hidden_size), nn.SiLU(),
            nn.Linear(hidden_size, hidden_size), nn.SiLU(),
            nn.LayerNorm(hidden_size), nn.Linear(hidden_size, hidden_size_2), nn.SiLU(),
            nn.LayerNorm(hidden_size_2), nn.Linear(hidden_size_2, hidden_size), nn.SiLU(),
            nn.Linear(hidden_size, output_dim), nn.SiLU(),
        ]
        self.cal = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` along its last dimension."""
        out = self.cal(x)
        return out


### P O L I C Y 

In [7]:
class policy_net(nn.Module):
    """Deterministic policy mapping a state to a bounded action.

    The state passes through ``Feature_Extractor`` (output width ``head_1``),
    a LayerNorm, a three-layer SiLU MLP (``head_2`` -> ``head_3`` ->
    ``head_4``) and a linear head; ``tanh`` then squashes each component,
    which is scaled into ``[-max_action, max_action]``.

    Args:
        state_dim: size of the state's last dimension.
        action_dim: number of action components produced.
        head_1, head_2, head_3, head_4: layer widths described above.
        max_action: scale applied after ``tanh``. A NumPy scalar is converted
            to a Python float; other values are stored as given.
    """

    def __init__(self, state_dim = state_dim, action_dim = action_dim, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4, max_action = max_action):
        super().__init__()

        # Action bound. Normalise NumPy scalars (e.g. np.abs(arr).max()) to a
        # plain Python float so the forward pass holds no NumPy objects.
        if isinstance(max_action, np.generic):
            max_action = max_action.item()
        self.max_action = max_action

        # State -> head_1-wide features.
        self.feature = Feature_Extractor(state_dim, head_1)

        # Normalise the extracted features before the trunk.
        self.norm = nn.LayerNorm(head_1)

        # Trunk MLP on top of the normalised features.
        self.pos_feature = nn.Sequential(

            nn.Linear(head_1, head_2),
            nn.SiLU(),

            nn.Linear(head_2, head_3),
            nn.SiLU(),

            nn.Linear(head_3, head_4),
            nn.SiLU()
        )

        # Pre-tanh action head. Only this single output exists — there is no
        # log-std head; the policy is deterministic.
        self.mu = nn.Linear(head_4, action_dim)

    def forward(self, state):
        """Return actions in ``[-max_action, max_action]`` for ``state``."""

        feature = self.feature(state)
        norm = self.norm(feature)
        pos = self.pos_feature(norm)

        # tanh bounds each component to (-1, 1) before scaling.
        mu = self.mu(pos)
        action = torch.tanh(mu) * self.max_action

        return action


### S E T U P 

In [None]:
# actor net

# Seed torch's global RNG before building the network so the initial weights
# are identical on every Restart & Run All.
torch.manual_seed(0)

# Deterministic behavior-cloning actor, placed on the selected device.
ACTOR_NETWORK = policy_net().to(device)
print(ACTOR_NETWORK)


policy_net(
  (feature): Feature_Extractor(
    (cal): Sequential(
      (0): Linear(in_features=11, out_features=128, bias=True)
      (1): SiLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): SiLU()
      (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (5): Linear(in_features=128, out_features=256, bias=True)
      (6): SiLU()
      (7): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (8): Linear(in_features=256, out_features=128, bias=True)
      (9): SiLU()
      (10): Linear(in_features=128, out_features=128, bias=True)
      (11): SiLU()
    )
  )
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (pos_feature): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): SiLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): SiLU()
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): SiLU()
  )
  (mu): Linear(in_features=256, out_features=3, bia

### O P T I M I Z E R - S C H E D U L E R

In [None]:
# Base learning rate for AdamW.
lr = 3e-4
# Cosine period measured in SCHEDULER.step() calls; the training cell steps
# once per epoch over 30 epochs, so the LR anneals across the whole run.
T_max = 30

OPTIMIZER = optim.AdamW(ACTOR_NETWORK.parameters(), lr=lr, weight_decay=0.001)
SCHEDULER = optim.lr_scheduler.CosineAnnealingLR(OPTIMIZER, T_max=T_max)


### T R A I N

In [10]:
# Behavior cloning objective: mean-squared error between predicted and
# logged actions.
loss_func = nn.MSELoss()
epochs = 30

# Explicitly enter training mode.
ACTOR_NETWORK.train()

for epoch in range(epochs):

    # Accumulate on-device (float64 to match Python-float summation) so there
    # is no host/device sync via loss.item() on every batch.
    running_loss = torch.zeros((), device=device, dtype=torch.float64)

    # Loop names are prefixed with `batch_` so they do not overwrite the
    # global NumPy `actions` array that the data-preparation cell reads;
    # reusing that name broke re-running that cell after training.
    for batch_states, batch_actions in Data_loader:

        pred_actions = ACTOR_NETWORK(batch_states)

        loss = loss_func(pred_actions, batch_actions)

        OPTIMIZER.zero_grad()
        loss.backward()
        # Cap the global gradient norm to keep updates stable.
        torch.nn.utils.clip_grad_norm_(ACTOR_NETWORK.parameters(), max_norm = 0.5)
        OPTIMIZER.step()

        running_loss += loss.detach()

    # One cosine-annealing step per epoch.
    SCHEDULER.step()

    avg_loss = running_loss.item() / len(Data_loader)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}")
    

Epoch 1/30, Loss: 0.091180
Epoch 2/30, Loss: 0.067434
Epoch 3/30, Loss: 0.062793
Epoch 4/30, Loss: 0.060402
Epoch 5/30, Loss: 0.058840
Epoch 6/30, Loss: 0.057648
Epoch 7/30, Loss: 0.056782
Epoch 8/30, Loss: 0.056107
Epoch 9/30, Loss: 0.055370
Epoch 10/30, Loss: 0.054927
Epoch 11/30, Loss: 0.054458
Epoch 12/30, Loss: 0.054038
Epoch 13/30, Loss: 0.053556
Epoch 14/30, Loss: 0.053216
Epoch 15/30, Loss: 0.052839
Epoch 16/30, Loss: 0.052515
Epoch 17/30, Loss: 0.052191
Epoch 18/30, Loss: 0.051891
Epoch 19/30, Loss: 0.051595
Epoch 20/30, Loss: 0.051309
Epoch 21/30, Loss: 0.051051
Epoch 22/30, Loss: 0.050820
Epoch 23/30, Loss: 0.050586
Epoch 24/30, Loss: 0.050384
Epoch 25/30, Loss: 0.050204
Epoch 26/30, Loss: 0.050050
Epoch 27/30, Loss: 0.049911
Epoch 28/30, Loss: 0.049804
Epoch 29/30, Loss: 0.049733
Epoch 30/30, Loss: 0.049682
