In [3]:
import gym
import d4rl # Import required to register environments
# import time 
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [4]:
env = gym.make("hopper-expert-v0")
dataset = env.get_dataset()

load datafile: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 19.72it/s]


In [5]:
len(dataset['observations'])

1000000

In [6]:
data_array = np.array(dataset["rewards"][:10000])
neg_array = data_array[data_array<0]
neg_array.size

0

In [7]:
observations = torch.tensor(dataset['observations'])[:10000]
mean = observations.mean(dim=0)
std = observations.std(dim=0)
observations = (observations - mean)/std
actions = torch.tensor(dataset['actions'])[:10000]
# data = TensorPairDataset(observations,actions)

In [8]:
observations

tensor([[-0.7744,  0.7826,  2.3278,  ...,  0.0355,  0.0036, -0.0113],
        [-0.7768,  0.8052,  2.3472,  ...,  0.2680, -0.2564,  0.1914],
        [-0.7831,  0.7728,  2.3604,  ...,  0.1356, -0.6247,  0.2166],
        ...,
        [ 1.3501, -0.4600,  0.7185,  ..., -0.8782,  0.6239,  1.1158],
        [ 1.3850, -0.5587,  0.6150,  ..., -1.5694,  0.8177,  1.1408],
        [ 1.4168, -0.6333,  0.5020,  ..., -1.1095,  1.2184,  1.2135]])

In [9]:
env.reset()

array([ 1.25249914e+00,  3.32176429e-03,  4.94544072e-03,  4.75809402e-03,
       -4.01548025e-03,  2.70609150e-03, -8.25083861e-04,  1.01427586e-03,
        4.88109332e-03, -4.73209735e-03,  3.06503894e-03])

In [29]:
dataset_N = {"observations":dataset["observations"][:10000],"actions":dataset["actions"][:10000]}
new_dataset = [dict(zip(dataset_N.keys(), values)) for values in zip(*dataset_N.values())]

# Print result
# print(new_dataset)


In [11]:
class PM_Network(nn.Module):
    def __init__(self,input_size,output_size):
        super().__init__()
        self.network_m = nn.Sequential(nn.Linear(input_size,24),
                                     nn.Tanh(),
                                     nn.Linear(24,8),
                                     nn.Tanh(),
                                     nn.Linear(8,3))                                     
        self.network_s = nn.Sequential(nn.Linear(input_size,24),
                                     nn.Tanh(),
                                     nn.Linear(24,8),
                                     nn.Tanh(),
                                     nn.Linear(8,3))      
    def forward(self,state):
        self.mean = self.network_m(state)
        self.std = F.softplus(self.network_s(state)) + 1e-6
        # print(self.mean,self.std)
        
        normal_dist = torch.distributions.Normal(self.mean,self.std)
        raw_action = normal_dist.rsample()  # Use reparameterization trick
        
        # Map action to [-1, 1] using tanh
        action = torch.clamp(raw_action,min=-1,max=1)
        # print(self.mean,self.std)
        # Compute log probability of the sampled action
        log_prob = normal_dist.log_prob(raw_action) 
        # print(log_prob)
        log_prob = log_prob.sum(dim=-1, keepdim=True)  # Sum over action dimensions
        # print(action)
        
        return action, log_prob

        
                    
           

In [12]:
class D_Network(nn.Module):
    def __init__(self,input_size):
        super(D_Network,self).__init__()
        self.network = nn.Sequential(nn.Linear(input_size,24),
                                     nn.Tanh(),
                                     nn.Linear(24,16),
                                     nn.Tanh(),
                                     nn.Linear(16,4),
                                     nn.Tanh(),
                                     nn.Linear(4,1),
                                     nn.Sigmoid())
    def forward(self,state,action):
        # print(state,action)
        # print("State shape:", state.shape, "Action shape:", action.shape)
        x = torch.cat((state, action.detach()), dim=-1)
        # print("combined input",x)
        self.output = self.network(x.float())

        return self.output
    
        



In [14]:
Policy = PM_Network(11,3)

In [15]:
Discriminator = D_Network(14)

In [18]:
import random

In [19]:

def train(dataset,epochs,traj_no,max_steps,P_network,D_network,lr_D,lr_P,lamda,opt_func=torch.optim.Adam):
    state = torch.from_numpy(env.reset())
    # print(state.dtype)
    expert_trajectories =[]
    for t in range(traj_no):
        idx = random.randint(0,len(dataset)-max_steps)
        expert_trajectories.append(dataset[idx:idx+max_steps])

    for epoch in range(epochs):

        Ex_aD_loss = 0
        Ex_eD_loss = 0
        Ex_P_loss =0 
        Ex_H_loss=0
        # Trajectory Sampling
        sample_trajectories =[]
        for path in range(traj_no):
            state = torch.from_numpy(env.reset())
            aD_loss=0
            eD_loss=0
            H_loss=0
            current_trajectory=[]
            for step in range(1,max_steps):
                
                action, log_prob = P_network(state.float())
                current_trajectory.append({"state":state,"actions":action,"log_prob":log_prob})
                temp_action = action
                next_state,reward,done,info = env.step(np.array(temp_action.detach()))

                exp_state = torch.Tensor(expert_trajectories[path][step]["observations"])
                exp_action = torch.Tensor(expert_trajectories[path][step]["actions"])
                
                aD_loss += torch.log(D_network(state.detach(),action.detach())+1e-8)
                eD_loss += torch.log(1- D_network(exp_state,exp_action)+1e-8)

                H_loss-=log_prob*torch.exp(log_prob).item()
                
                if done or step==max_steps-1:
                    # eD_loss /= step
                    # aD_loss /= step
                    # P_loss /= step
                    break
                state = torch.from_numpy(next_state)
                
            sample_trajectories.append(current_trajectory)
                    

            Ex_aD_loss+=aD_loss
            Ex_eD_loss+=eD_loss
            Ex_H_loss += H_loss
        
        ### Updating Discriminator 
        print("XX============UPDATE DISCRIMINATOR===============XX")
        Ex_aD_loss/=traj_no
        Ex_eD_loss/=traj_no
        total_D_loss = -(Ex_aD_loss + Ex_eD_loss)
        
        
        opt_D = opt_func(D_network.parameters(),lr_D)
        # print(Ex_aD_loss,Ex_eD_loss)
        print("total D loss",total_D_loss)
        # print(list(D_network.parameters()))
        total_D_loss.backward()
        opt_D.step()
        opt_D.zero_grad()

        
        ## Updating the Policy
        print("XX================UPDATE POLICY======================XX")
        Ex_P_loss=0
        
        for traj in sample_trajectories:
            traj_P_loss=0
            for idx,step in enumerate(traj):
                Q=0.1
                for jdx in range(idx+1,len(traj)):
                    temp = D_network(traj[jdx]['state'].detach(),traj[jdx]['actions'].detach())
                    # print("temp",temp)
                    Q+=torch.log(temp).item()
                    
                # print("Q:",Q)
                traj_P_loss+=step['log_prob']*Q
            
            Ex_P_loss+=traj_P_loss
        
                
        Ex_P_loss/=traj_no
        # Ex_H_loss/=traj_no
        # Ex_H_loss/=traj_no
        total_P_loss = (Ex_P_loss) - lamda*Ex_H_loss 
        print("PLOss:",total_P_loss)
        opt_P = opt_func(P_network.parameters(),lr_P)
        total_P_loss.backward()
        opt_P.step()
        opt_P.zero_grad()
        print("P descent happened!!")
        
    
    

In [21]:
train(new_dataset,100,10,300,Policy,Discriminator,1e-2,1e-3,0.01)

total D loss tensor([66.1064], grad_fn=<NegBackward0>)
PLOss: tensor([1381.8131], grad_fn=<SubBackward0>)
P descent happened!!
total D loss tensor([72.4726], grad_fn=<NegBackward0>)
PLOss: tensor([1587.6176], grad_fn=<SubBackward0>)
P descent happened!!
total D loss tensor([60.0060], grad_fn=<NegBackward0>)
PLOss: tensor([1137.4939], grad_fn=<SubBackward0>)
P descent happened!!
total D loss tensor([62.9477], grad_fn=<NegBackward0>)
PLOss: tensor([1319.0419], grad_fn=<SubBackward0>)
P descent happened!!
total D loss tensor([60.3831], grad_fn=<NegBackward0>)
PLOss: tensor([1233.6179], grad_fn=<SubBackward0>)
P descent happened!!
total D loss tensor([65.4535], grad_fn=<NegBackward0>)
PLOss: tensor([1655.8586], grad_fn=<SubBackward0>)
P descent happened!!
total D loss tensor([55.6413], grad_fn=<NegBackward0>)
PLOss: tensor([1120.8329], grad_fn=<SubBackward0>)
P descent happened!!
total D loss tensor([53.0847], grad_fn=<NegBackward0>)
PLOss: tensor([1099.3429], grad_fn=<SubBackward0>)
P des

In [23]:
env.action_space

Box(-1.0, 1.0, (3,), float32)

In [137]:
Policy(state.float())

(tensor([-0.0253, -0.0562, -0.0464], grad_fn=<ClampBackward1>),
 tensor([3.9466], grad_fn=<SumBackward1>))

In [22]:
def test(model,D_network,episodes):
    env = gym.make("hopper-expert-v0")
    # state,info = env.reset()
    state = torch.from_numpy(env.reset())
    total_reward=0
    D_loss =0
    for ep in range(episodes):
       
        action, log_prob = model(state.float())
        temp_action = action
        next_state,reward,done,info = env.step(np.array(temp_action.detach()))
        D_loss += torch.log(D_network(state.detach(),action.detach())+1e-8).item()
        total_reward+=reward
        if done or ep==episodes-1:
            print(total_reward,ep,D_loss)
            
            break
        state = torch.from_numpy(next_state)    
    env.close()    

In [24]:
for i in range(20):
    test(Policy,Discriminator,50)

65.1800582912413 39 -0.8850178867578506
78.16346447767653 49 -1.1062723584473133
73.42365866834382 44 -0.995645122602582
72.93741231893219 49 -1.1062723584473133
69.74051002100086 41 -0.9292687810957432
68.51393076983254 41 -0.9292687810957432
64.8223833661753 38 -0.8628924395889044
69.38194206856501 42 -0.9513942282646894
77.5262960162779 49 -1.1062723584473133
75.22054518555791 45 -1.0177705697715282
60.66736071436634 36 -0.8186415452510118
65.73002497388603 39 -0.8850178867578506
68.45210880422424 40 -0.9071433339267969
73.53251916926321 44 -0.995645122602582
72.1538682413839 43 -0.9735196754336357
73.3731325388231 44 -0.995645122602582
77.63000783666065 49 -1.1062723584473133
63.86802059181661 38 -0.8628924395889044
66.8414912263664 40 -0.9071433339267969
76.83695148143843 46 -1.0398960169404745


In [26]:
Policy_random = PM_Network(11,3)
D_random = D_Network(14)

In [28]:
for i in range(20):
    test(Policy_random,D_random,50)

7.881493748814688 8 -5.584661364555359
5.979921445232856 7 -5.030848741531372
13.379352816140864 15 -10.337769269943237
8.765036066981214 9 -6.272688925266266
9.90976770242666 9 -6.220690667629242
9.160820276921338 10 -7.0389769077301025
6.499856430808439 8 -5.742412090301514
13.452229521565847 14 -9.704015374183655
8.015553724005388 10 -7.068350195884705
5.04691883299593 11 -7.924397826194763
8.484104750842096 10 -7.06064110994339
6.275024420845617 10 -7.1368520855903625
4.320179777452697 6 -4.407621383666992
5.240018461710615 9 -6.480333924293518
8.197756770945944 9 -6.365648984909058
9.3885599189925 10 -6.9461177587509155
7.757151341306446 9 -6.3882155418396
10.730500608002185 11 -7.6920687556266785
7.384956126754444 10 -7.136242747306824
7.5382089593371795 8 -5.655667662620544
