In [46]:
import gym
import d4rl # Import required to register environments
# import time 
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [48]:
env = gym.make("hopper-expert-v0")
dataset = env.get_dataset()

load datafile: 100%|██████████████████████████████| 5/5 [00:00<00:00, 19.75it/s]


In [49]:
len(dataset['observations'])

1000000

In [50]:
data_array = np.array(dataset["rewards"][:10000])
neg_array = data_array[data_array<0]
neg_array.size

0

In [51]:
observations = torch.tensor(dataset['observations'])[:10000]
mean = observations.mean(dim=0)
std = observations.std(dim=0)
observations = (observations - mean)/std
actions = torch.tensor(dataset['actions'])[:10000]
# data = TensorPairDataset(observations,actions)

In [52]:
observations

tensor([[-0.7744,  0.7826,  2.3278,  ...,  0.0355,  0.0036, -0.0113],
        [-0.7768,  0.8052,  2.3472,  ...,  0.2680, -0.2564,  0.1914],
        [-0.7831,  0.7728,  2.3604,  ...,  0.1356, -0.6247,  0.2166],
        ...,
        [ 1.3501, -0.4600,  0.7185,  ..., -0.8782,  0.6239,  1.1158],
        [ 1.3850, -0.5587,  0.6150,  ..., -1.5694,  0.8177,  1.1408],
        [ 1.4168, -0.6333,  0.5020,  ..., -1.1095,  1.2184,  1.2135]])

In [53]:
env.reset()

array([ 1.24900130e+00, -1.92332705e-03, -5.94129118e-04,  4.82180722e-03,
       -1.59802033e-03, -3.18701423e-03, -3.22828108e-03, -2.66117682e-04,
       -4.43989664e-03,  4.80747718e-03,  2.83557979e-03])

In [54]:
dataset_N = {"observations":dataset["observations"][:10000],"actions":dataset["actions"][:10000]}
new_dataset = [dict(zip(dataset_N.keys(), values)) for values in zip(*dataset_N.values())]

# Print result
print(new_dataset)


[{'observations': array([ 1.2508533e+00,  1.6305586e-03, -7.9514465e-04,  3.3486076e-04,
       -4.9994583e-03,  2.8168827e-03, -2.1439972e-03, -1.2370852e-03,
        1.4964999e-03,  1.6224863e-03, -4.4636223e-03], dtype=float32), 'actions': array([ 0.8259195 , -0.43341178,  0.8120992 ], dtype=float32)}, {'observations': array([ 1.2504431e+00,  3.1245863e-03,  2.3545863e-03, -2.4580609e-03,
       -2.3094496e-04, -1.3907926e-02, -9.9610195e-02,  5.3787638e-02,
        3.9008829e-01, -6.6784775e-01,  1.1991799e+00], dtype=float32), 'actions': array([ 0.8244549 , -0.7164865 ,  0.07432321], dtype=float32)}, {'observations': array([ 1.2493614e+00,  9.8507630e-04,  4.5010685e-03, -1.1589271e-02,
        9.9630365e-03, -1.1045936e-01, -1.7149104e-01, -5.7005167e-01,
        1.6879366e-01, -1.6157442e+00,  1.3486822e+00], dtype=float32), 'actions': array([ 0.8499072 , -0.20028582,  0.13544567], dtype=float32)}, {'observations': array([ 1.247657  , -0.00447613,  0.00532527, -0.02554708,  0.02

In [92]:
class PM_Network(nn.Module):
    def __init__(self,input_size,output_size):
        super().__init__()
        self.network_m = nn.Sequential(nn.Linear(input_size,4),
                                     nn.Tanh(),
                                     nn.Linear(4,6),
                                     nn.Tanh(),
                                     nn.Linear(6,3))                                     
        self.network_s = nn.Sequential(nn.Linear(input_size,4),
                                     nn.Tanh(),
                                     nn.Linear(4,6),
                                     nn.Tanh(),
                                     nn.Linear(6,3))      
    def forward(self,state):
        self.mean = self.network_m(state)
        self.std = F.softplus(self.network_s(state)) + 1e-6
        print(self.mean,self.std)
        
        normal_dist = torch.distributions.Normal(self.mean,self.std)
        raw_action = normal_dist.rsample()  # Use reparameterization trick
        
        # Map action to [-1, 1] using tanh
        action = torch.clamp(raw_action,min=-1,max=1)
        # print(self.mean,self.std)
        # Compute log probability of the sampled action
        log_prob = normal_dist.log_prob(raw_action) 
        # print(log_prob)
        log_prob = log_prob.sum(dim=-1, keepdim=True)  # Sum over action dimensions
        # print(action)
        
        return action, log_prob

        
                    
           

In [241]:
class D_Network(nn.Module):
    def __init__(self,input_size):
        super(D_Network,self).__init__()
        self.network = nn.Sequential(nn.Linear(input_size,8),
                                     nn.Tanh(),
                                     nn.Linear(8,4),
                                     nn.Tanh(),
                                     nn.Linear(4,1),
                                     nn.Sigmoid())
    def forward(self,state,action):
        print(state,action)
        print("State shape:", state.shape, "Action shape:", action.shape)
        # x = torch.cat((state, action.detach()), dim=-1)
        # print("combined input",x)
        self.output = self.network(state.float())

        return self.output
    
        



In [242]:
class P_Network(nn.Module):
    def __init__(self,input_size,output_size):
        super().__init__()
        self.network = nn.Sequential(nn.Linear(input_size,64),
                                     nn.Tanh(),
                                     nn.Linear(64,16),
                                     nn.Tanh(),
                                     nn.Linear(16,3),
                                     nn.Sigmoid)                                     
        
    def forward(self,state):
        # self.mean = self.network_m(state)
        # self.std = torch.exp(self.network_s(state))
        # normal_dist = torch.distributions.Normal(self.mean,self.std)
        # raw_action = normal_dist.rsample()  # Use reparameterization trick
        
        # Map action to [-1, 1] using tanh
        # action = torch.clamp(raw_action,min=-1,max=1)
        # print(self.mean,self.std)
        # Compute log probability of the sampled action
        # log_prob = normal_dist.log_prob(raw_action) 
        # # print(log_prob)
        # log_prob = log_prob.sum(dim=-1, keepdim=True)  # Sum over action dimensions
        # print(action)

        action = self.network(state.float())
        
        log_prob = torch.log(action + 1e-10)
        
        return action, log_prob

In [243]:
Policy = PM_Network(11,3)

In [244]:
Discriminator = D_Network(11)

In [245]:
state = torch.from_numpy(env.reset())

In [246]:
Policy(state.float())

tensor([-0.0751,  0.2758, -0.1405], grad_fn=<ViewBackward0>) tensor([0.8406, 0.5698, 0.4598], grad_fn=<AddBackward0>)


(tensor([0.0669, 0.0078, 0.0012], grad_fn=<ClampBackward1>),
 tensor([-1.4162], grad_fn=<SumBackward1>))

In [247]:
import random

In [248]:

def train(dataset,epochs,traj_no,max_steps,P_network,D_network,lr_D,lr_P,lamda,opt_func=torch.optim.Adam):
    state = torch.from_numpy(env.reset())
    # print(state.dtype)
    expert_trajectories =[]
    for t in range(traj_no):
        idx = random.randint(0,len(dataset)-max_steps)
        expert_trajectories.append(dataset[idx:idx+max_steps])

    for epoch in range(epochs):

        Ex_aD_loss = 0
        Ex_eD_loss = 0
        Ex_P_loss =0 
        Ex_H_loss=0
        # Trajectory Sampling
        sample_trajectories =[]
        for path in range(traj_no):
            aD_loss=0
            eD_loss=0
            current_trajectory=[]
            for step in range(max_steps):
                print(state)
                action, log_prob = P_network(state.float())
                current_trajectory.append({"state":state,"actions":action,"log_prob":log_prob})
                temp_action = action
                next_state,reward,done,info = env.step(np.array(temp_action.detach()))
                # print(expert_trajectories[path])
                exp_state = torch.Tensor(expert_trajectories[path][step]["observations"])
                exp_action = torch.Tensor(expert_trajectories[path][step]["actions"])
                
                # print(action,state)
                aD_loss += (D_network(state.detach(),action.detach())+1e-8)
                eD_loss += (1- D_network(exp_state,exp_action)+1e-8)

                # P_loss+=log_prob*(D_network(state,action).item())
                # H_loss+=log_prob
                
                if done or step==max_steps-1:
                    eD_loss /= step
                    aD_loss /= step
                    # P_loss /= step
                    break
                state = torch.from_numpy(next_state)
                
            sample_trajectories.append(current_trajectory)
                    

            Ex_aD_loss+=aD_loss
            Ex_eD_loss+=eD_loss

            # Ex_P_loss +=P_loss
            # Ex_H_loss +=H_loss
        
        ### Updating Discriminator 
        print("============UPDATE DISCRI===============")
        Ex_aD_loss/=traj_no
        Ex_eD_loss/=traj_no
        total_D_loss = Ex_aD_loss + Ex_eD_loss
        
        
        opt_D = opt_func(D_network.parameters(),lr_D)
        print(Ex_aD_loss,Ex_eD_loss)
        print("total D loss",total_D_loss)
        print(list(D_network.parameters()))
        total_D_loss.backward()
        opt_D.step()
        opt_D.zero_grad()
        print(list(D_network.parameters()))
        print(D_network(state.detach(),action.detach()))
        
        ### Updating the Policy
        # print("================UPDATE POLICY======================")
        # Ex_P_loss=0
        # for traj in sample_trajectories:
        #     traj_P_loss=0
        #     for idx,step in enumerate(traj):
        #         Q=0
        #         for jdx in range(idx+1,len(traj)):
        #             print(traj[jdx]['state'],traj[jdx]['actions'])
        #             temp = D_network(traj[jdx]['state'].detach(),traj[jdx]['actions'].detach())
                    
        #             print("temp:",temp)
        #             Q+=torch.log(temp).item()
                    
        #         # print("Q:",Q)
        #         traj_P_loss+=step['log_prob']*Q
            
        #     Ex_P_loss+=traj_P_loss
        
                
        # Ex_P_loss/=traj_no
        # # Ex_H_loss/=traj_no
        # total_P_loss = (Ex_P_loss)
        # print("PLOss:",total_P_loss)
        # opt_P = opt_func(P_network.parameters(),lr_P)
        # total_P_loss.backward()
        # opt_P.step()
        # opt_P.zero_grad()
        # print("P descent happened!!")
        
    
    

In [249]:
train(new_dataset,5,2,20,Policy,Discriminator,1e-5,1e-2,0.01)

tensor([ 1.2515e+00,  2.7997e-03, -2.1417e-03, -3.5030e-03,  4.4028e-03,
        -3.5820e-03,  2.8444e-03, -4.7819e-03, -1.2851e-04,  1.4149e-03,
        -4.9405e-03], dtype=torch.float64)
tensor([-0.0751,  0.2746, -0.1408], grad_fn=<ViewBackward0>) tensor([0.8404, 0.5697, 0.4597], grad_fn=<AddBackward0>)
tensor([ 1.2515e+00,  2.7997e-03, -2.1417e-03, -3.5030e-03,  4.4028e-03,
        -3.5820e-03,  2.8444e-03, -4.7819e-03, -1.2851e-04,  1.4149e-03,
        -4.9405e-03], dtype=torch.float64) tensor([-0.3172, -0.8940, -0.6683])
State shape: torch.Size([11]) Action shape: torch.Size([3])
tensor([ 1.5171,  0.0101, -0.2016, -0.4037, -0.6284,  2.3313,  0.7056, -0.2988,
         0.7314, -0.9135,  5.2155]) tensor([-0.4769,  0.7096,  0.2156])
State shape: torch.Size([11]) Action shape: torch.Size([3])
tensor([ 1.2513e+00, -6.5371e-04, -3.5572e-03, -8.0201e-03,  7.3932e-04,
        -1.4502e-01, -4.3472e-02, -8.5790e-01, -3.5342e-01, -1.1295e+00,
        -9.1001e-01], dtype=torch.float64)
tensor(

In [229]:
env.action_space

Box(-1.0, 1.0, (3,), float32)