In [1]:
import torch
import torch.nn as nn
import gym
import numpy as np
import random
import torch.nn.functional as F
%matplotlib inline
import matplotlib.pyplot as plt
from torch.distributions import Categorical

envs = ['CartPole-v1','Acrobot-v1','MountainCar-v0','Breakout-v0','BipedalWalker-v2','LunarLander-v2']
env_to_use = 0

discrete_actions = True

env = gym.make(envs[env_to_use]).unwrapped
env.seed(1)
if discrete_actions:
    action_size = env.action_space.n
else:
    action_size = env.action_space.shape[0]
    
state_size = 4
encoded_size = 4
gamma = 0.95
beta = 0.2

In [2]:
class Exploration_Policy(nn.Module):
    def __init__(self,state_size,action_size):
        super(Exploration_Policy, self).__init__()
        self.action_size =action_size
        self.state_size = state_size
        
        self.layer1 = nn.Linear(state_size , 6)
        self.layer2 = nn.Linear(6, 6)
        self.layer3 = nn.Linear(6, action_size)
        
    def forward(self, x):
        out = F.relu(self.layer1(x) )
        out = F.relu(self.layer2(out))
        out = F.sigmoid(self.layer3(out))
        return out
    
class Forward(nn.Module):
    def __init__(self,encoded_size,action_size):
        super(Forward, self).__init__()
        self.action_size =action_size
        self.encoded_size = encoded_size
        
        self.layer1 = nn.Linear(encoded_size + action_size , 6)
        self.layer2 = nn.Linear(6, 6)
        self.layer3 = nn.Linear(6, encoded_size)
        
    def forward(self, x,x1):
        x = torch.cat([x,x1],-1)
        out = F.relu(self.layer1(x) )
        out = F.relu(self.layer2(out))
        out = self.layer3(out)
        return out    

class Inverse(nn.Module):
    def __init__(self,encoded_size,action_size):
        super(Inverse, self).__init__()
        self.action_size =action_size
        self.encoded_size = encoded_size
        
        self.layer1 = nn.Linear(encoded_size*2 , 8)
        self.layer2 = nn.Linear(8, 8)
        self.layer3 = nn.Linear(8, 8)
        self.layer4 = nn.Linear(8, 8)
        self.layer5 = nn.Linear(8, action_size)
        
    def forward(self, x,x1):
        x = torch.cat([x,x1],-1)
        out = F.relu(self.layer1(x) )
        out = F.relu(self.layer2(out))
        out = F.relu(self.layer3(out))
        out = F.relu(self.layer4(out))
        #out = F.sigmoid(self.layer4(out))
        out = (self.layer5(out))
        
        return out 

class Encoder(nn.Module):
    def __init__(self,encoded_size,state_size):
        super(Encoder, self).__init__()
        self.action_size =action_size
        self.encoded_size = encoded_size
        
        self.layer1 = nn.Linear(state_size , 4)
        self.layer2 = nn.Linear(4, 4)
        self.layer3 = nn.Linear(4, 4)
        self.layer4 = nn.Linear(4, encoded_size)
        
    def forward(self, x):
        out = F.elu(self.layer1(x) )
        out = F.elu(self.layer2(out))
        out = F.elu(self.layer3(out))
        out = self.layer4(out)
        return out
    
enc_net = Encoder(encoded_size,state_size)#.to(device)
for_net = Forward(encoded_size,action_size)#.to(device)
inv_net = Inverse(encoded_size,action_size)#.to(device)    
explore_agent = Exploration_Policy(state_size,action_size)    

inv_optimizer = torch.optim.Adam(list(enc_net.parameters()) + 
                                 list(for_net.parameters()) + 
                                  list(inv_net.parameters()),lr = 0.001)

#for_optimizer = torch.optim.Adam(for_net.parameters(),lr = 0.01)


agent_optimizer = torch.optim.Adam(explore_agent.parameters(),lr = 0.01)


In [None]:
batch_size = 5
for e in range(1000):
    state = env.reset()
    state = torch.from_numpy(state).type("torch.FloatTensor")
    all_rewards = []
    all_actions = []
    all_probs = []
    all_states = []
    all_new_states = []
    all_for_loss = []
    total_reward = 0
    steps = 0
    first = True
    while True:
        all_states.append(state)
        steps+=1
        #env.render()
        agent_action = explore_agent(state)
        
        action_distribution = torch.softmax(agent_action,-1)
        
        
        m = Categorical(action_distribution)
        action = m.sample()
        
        all_actions.append(action)
        all_probs.append(action_distribution[action])
        
        new_state, reward, done, info = env.step(action.item())
        total_reward += reward
        all_rewards.append(reward)
        
        
        new_state = torch.from_numpy(new_state).type("torch.FloatTensor")
        all_new_states.append(new_state)
        if ((steps-1) % batch_size == 0 or done) and (steps-1)!=0:
            
        #if done:
            all_states = torch.stack(all_states).type("torch.FloatTensor")
            all_new_states = torch.stack(all_new_states).type("torch.FloatTensor")
            all_actions = torch.stack(all_actions)#.type("torch.LongTensor")
            if True:
                ######ICM#######
                enc_state = enc_net(all_states)
                enc_new_state = enc_net(all_new_states)
                
                #one hot
                action_o = torch.zeros([all_states.size(0),action_distribution.size(0)])
                action_o = action_o.scatter(1,all_actions.unsqueeze(-1),1)
                
                inv_out = inv_net(enc_state,enc_new_state)
                
                inv_out = F.softmax(inv_out)#.type("torch.FloatTensor")
                #print(inv_out[0],all_actions[0])
                
                
                #inv_loss = torch.mean(torch.pow(inv_out-action_o,2))
                

                inv_optimizer.zero_grad()
                inv_loss = torch.mean(nn.CrossEntropyLoss()(inv_out,all_actions)) * (1-beta)
                #inv_loss.backward(retain_graph=True)
                #inv_optimizer.step()
                
                #for_optimizer.zero_grad()
                for_out = for_net(action_o,enc_state)
                for_loss = torch.mean((torch.pow(for_out-enc_new_state,2)),-1)
                for_loss_mean = torch.mean(for_loss) * beta
                
                #for_loss_mean.backward(retain_graph=True)
                
                #for_optimizer.step()
                
                icm_loss =inv_loss + for_loss_mean
                icm_loss.backward(retain_graph=True)
                
                inv_optimizer.step() 
                
                all_for_loss.append(for_loss)
                all_actions = []
                all_states = []
                all_new_states = []
            ################
        first = False    
        if done:
            print(inv_loss.item(),for_loss_mean.item())
            all_probs = torch.stack(all_probs)#.type("torch.FloatTensor")
            all_for_loss = torch.cat(all_for_loss,0)    
            
            
            #Discout and normalize reward
            all_rewards = torch.from_numpy(np.array(all_rewards)).type("torch.FloatTensor")
            all_rewards += all_for_loss
            
            running_add = 0
            for i in reversed(range(all_rewards.size(0))):
                running_add = running_add * gamma + all_rewards[i]
                all_rewards[i] = running_add
               
            
            all_rewards = (all_rewards-torch.mean(all_rewards))/(torch.std(all_rewards)+0.0000001)
            
            
            agent_optimizer.zero_grad()
            
            #probs = torch.gather(all_distributions,-1,all_actions.unsqueeze(-1))
            loss = torch.mean(-torch.log(all_probs) * all_rewards)
            loss.backward()
            agent_optimizer.step() 
            
            all_rewards = []
            all_probs = []
            
            
            
        
            print("Episode : {}  Rewards : {}".format(e+1,np.sum(total_reward)))
            break
            
        state = new_state



0.5558045506477356 0.02187994122505188
Episode : 1  Rewards : 13.0
0.5794923305511475 0.03585883229970932
Episode : 2  Rewards : 10.0
0.5558602213859558 0.033102601766586304
Episode : 3  Rewards : 13.0
0.5778562426567078 0.02815191261470318
Episode : 4  Rewards : 30.0
0.513637900352478 0.0166670810431242
Episode : 5  Rewards : 23.0
0.5473519563674927 0.014659528620541096
Episode : 6  Rewards : 36.0
0.5474921464920044 0.014601901173591614
Episode : 7  Rewards : 11.0
0.5157783627510071 0.010570619255304337
Episode : 8  Rewards : 12.0
0.5753481984138489 0.01092276070266962
Episode : 9  Rewards : 10.0
0.5793377757072449 0.011939008720219135
Episode : 10  Rewards : 16.0
0.5749748945236206 0.00904905702918768
Episode : 11  Rewards : 15.0
0.5172165036201477 0.006013869773596525
Episode : 12  Rewards : 12.0
0.5741155743598938 0.005034265108406544
Episode : 13  Rewards : 40.0
0.5908892750740051 0.004303883295506239
Episode : 14  Rewards : 37.0
0.5345005989074707 0.003523237304762006
Episode : 1

0.5569114685058594 3.4826785849872977e-05
Episode : 128  Rewards : 29.0
0.553381085395813 7.721064321231097e-05
Episode : 129  Rewards : 41.0
0.5595453381538391 5.274338400340639e-05
Episode : 130  Rewards : 8.0
0.5545130372047424 2.787805351545103e-05
Episode : 131  Rewards : 20.0
0.5545070171356201 8.41151486383751e-05
Episode : 132  Rewards : 33.0
0.5547634363174438 2.267158379254397e-05
Episode : 133  Rewards : 36.0
0.5542078018188477 4.3784577428596094e-05
Episode : 134  Rewards : 18.0
0.5541655421257019 4.99920679430943e-05
Episode : 135  Rewards : 31.0
0.5564383268356323 7.820891187293455e-05
Episode : 136  Rewards : 22.0
0.5556120872497559 0.00011305041698506102
Episode : 137  Rewards : 12.0
0.5544055104255676 1.1724964679160621e-05
Episode : 138  Rewards : 26.0
0.5548701882362366 7.941194053273648e-05
Episode : 139  Rewards : 14.0
0.5535337328910828 5.527922257897444e-05
Episode : 140  Rewards : 56.0
0.5545045733451843 9.305930143455043e-05
Episode : 141  Rewards : 30.0
0.5550

Episode : 251  Rewards : 26.0
0.5546272397041321 2.8858243240392767e-05
Episode : 252  Rewards : 30.0
0.5590628981590271 9.80700715444982e-05
Episode : 253  Rewards : 34.0
0.5676791071891785 4.795495988219045e-05
Episode : 254  Rewards : 28.0
0.5665884017944336 5.718269221688388e-06
Episode : 255  Rewards : 22.0
0.5646999478340149 4.853721748077078e-06
Episode : 256  Rewards : 24.0
0.5527054667472839 1.8152590200770646e-05
Episode : 257  Rewards : 16.0
0.5617139339447021 0.00017636321717873216
Episode : 258  Rewards : 18.0
0.5531244874000549 2.4788399969111197e-05
Episode : 259  Rewards : 39.0
0.5505954623222351 5.096230233903043e-05
Episode : 260  Rewards : 24.0
0.549275815486908 7.312455181818223e-06
Episode : 261  Rewards : 27.0
0.5625381469726562 0.00011508647730806842
Episode : 262  Rewards : 27.0
0.563067615032196 2.83708686765749e-05
Episode : 263  Rewards : 28.0
0.5544711351394653 4.528969839157071e-06
Episode : 264  Rewards : 33.0
0.5544608235359192 6.897333605593303e-06
Episo

Episode : 368  Rewards : 83.0
0.2507193088531494 2.8217420549481176e-05
Episode : 369  Rewards : 77.0
0.25100260972976685 6.76790978104691e-06
Episode : 370  Rewards : 27.0
0.26018640398979187 0.00132435979321599
Episode : 371  Rewards : 198.0
0.2506321370601654 6.137391756055877e-05
Episode : 372  Rewards : 78.0
0.25075361132621765 2.103009683196433e-05
Episode : 373  Rewards : 63.0
0.2508048713207245 3.19449209200684e-05
Episode : 374  Rewards : 70.0
0.25346842408180237 0.0007141738897189498
Episode : 375  Rewards : 20.0
0.25168636441230774 0.0005975948297418654
Episode : 376  Rewards : 11.0
0.2541917562484741 0.0008237781003117561
Episode : 377  Rewards : 33.0
0.2509666979312897 0.0001243379811057821
Episode : 378  Rewards : 81.0
0.25095707178115845 3.043472497665789e-05
Episode : 379  Rewards : 23.0
0.25118568539619446 0.0002939617552328855
Episode : 380  Rewards : 26.0
0.25071582198143005 6.266302079893649e-05
Episode : 381  Rewards : 55.0
0.251064658164978 4.728636849904433e-05
E

0.2506181001663208 2.7839007088914514e-05
Episode : 483  Rewards : 114.0
0.25061988830566406 3.3406391594326124e-05
Episode : 484  Rewards : 34.0
0.2506171464920044 2.702035453694407e-05
Episode : 485  Rewards : 41.0
0.25061947107315063 3.112457488896325e-05
Episode : 486  Rewards : 118.0
0.25062140822410583 0.00023164767480921
Episode : 487  Rewards : 87.0
0.2506236732006073 1.0517720511415973e-05
Episode : 488  Rewards : 84.0
0.2507118880748749 6.145265797385946e-05
Episode : 489  Rewards : 31.0
0.25109514594078064 4.819730747840367e-05
Episode : 490  Rewards : 199.0
0.25069138407707214 0.0003151385753881186
Episode : 491  Rewards : 179.0
0.25065183639526367 1.2141740626248065e-05
Episode : 492  Rewards : 92.0
0.25061437487602234 1.2399587831168901e-05
Episode : 493  Rewards : 22.0
0.2506217658519745 1.4577431102225091e-05
Episode : 494  Rewards : 216.0
0.25061342120170593 6.146312807686627e-05
Episode : 495  Rewards : 89.0
0.2506166398525238 9.498923645878676e-06
Episode : 496  Rewa

Episode : 597  Rewards : 158.0
0.2510303556919098 0.00027539697475731373
Episode : 598  Rewards : 172.0
0.2506353557109833 1.8392105630482547e-05
Episode : 599  Rewards : 284.0
0.2506103813648224 9.144015348283574e-05
Episode : 600  Rewards : 239.0
0.2506115436553955 8.48576473799767e-06
Episode : 601  Rewards : 183.0
0.25061148405075073 5.562292062677443e-05
Episode : 602  Rewards : 149.0
0.2506096363067627 4.619318133336492e-05
Episode : 603  Rewards : 45.0
0.2506096065044403 1.6235206203418784e-05
Episode : 604  Rewards : 174.0
0.2506221532821655 2.0584948288160376e-05
Episode : 605  Rewards : 49.0
0.2506099343299866 9.160422450804617e-06
Episode : 606  Rewards : 93.0
0.25060972571372986 6.451393710449338e-05
Episode : 607  Rewards : 36.0
0.25060996413230896 8.914981299312785e-05
Episode : 608  Rewards : 129.0
0.250610888004303 3.200332503183745e-05
Episode : 609  Rewards : 85.0
0.2506096661090851 1.556621464260388e-05
Episode : 610  Rewards : 40.0
0.2506100833415985 6.5936378632613

0.25061139464378357 1.8415814338368364e-05
Episode : 714  Rewards : 72.0
0.250613272190094 1.2366846021905076e-05
Episode : 715  Rewards : 16.0
0.2506107985973358 1.0178084721701453e-06
Episode : 716  Rewards : 26.0
0.2506100833415985 1.3741310795012396e-05
Episode : 717  Rewards : 37.0
0.2506105899810791 2.4978614874271443e-06
Episode : 718  Rewards : 40.0
0.2506108283996582 1.8353908671997488e-05
Episode : 719  Rewards : 66.0
0.25066348910331726 6.286103143793298e-06
Episode : 720  Rewards : 123.0
0.2506141662597656 1.089787019736832e-05
Episode : 721  Rewards : 22.0
0.25061360001564026 1.0146823115064763e-05
Episode : 722  Rewards : 114.0
0.25060996413230896 9.92080367723247e-06
Episode : 723  Rewards : 20.0
0.2506170868873596 1.2636950486921705e-05
Episode : 724  Rewards : 41.0
0.2506100833415985 1.014190183923347e-05
Episode : 725  Rewards : 19.0
0.25062111020088196 1.293087552767247e-05
Episode : 726  Rewards : 19.0
0.25061142444610596 1.3389885680226143e-06
Episode : 727  Reward

0.2506093978881836 2.2122379959910177e-05
Episode : 828  Rewards : 18.0
0.2506106495857239 2.7291303013043944e-06
Episode : 829  Rewards : 27.0
0.2506108283996582 5.939045422564959e-06
Episode : 830  Rewards : 19.0
0.25060948729515076 8.780752068560105e-06
Episode : 831  Rewards : 231.0
0.25060948729515076 1.1843661013699602e-05
Episode : 832  Rewards : 105.0
0.2506101131439209 2.7279356800136156e-06
Episode : 833  Rewards : 550.0
0.25061047077178955 3.4405024962325115e-06
Episode : 834  Rewards : 16.0
0.25060948729515076 1.3117946764396038e-05
Episode : 835  Rewards : 336.0
0.25060969591140747 1.2545072422653902e-05
Episode : 836  Rewards : 26.0
0.25060975551605225 1.089546458388213e-05
Episode : 837  Rewards : 29.0
0.25060954689979553 3.0602070637542056e-06
Episode : 838  Rewards : 61.0
0.25060948729515076 8.170126761797292e-07
Episode : 839  Rewards : 99.0
0.25060969591140747 3.954571639042115e-06
Episode : 840  Rewards : 57.0
0.25060948729515076 6.3805541685724165e-06
Episode : 841