In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
from torch.distributions import Categorical

import gymnasium as gym
import pygame
import torch


import numpy as np
import collections, random

# Device selection
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [28]:
device

device(type='mps')

In [29]:
class ICM(nn.Module):
    def __init__(self, state_size, action_size, icm_parameters):
        super(ICM, self).__init__()

        feature_hidden_sizes, feature_size, inverse_hidden_sizes, forward_hidden_sizes, self.β, icm_lr = icm_parameters
        
        self.state_size = state_size
        self.action_size = action_size
        
        # Feature NN
        self.feature_net = nn.Sequential(
            nn.Linear(state_size, feature_hidden_sizes[0]),
            nn.ReLU(),
            nn.Linear(feature_hidden_sizes[0], feature_hidden_sizes[1]),
            nn.ReLU(),
            nn.Linear(feature_hidden_sizes[1], feature_size)
        ).to(device)
        
        # Inverse NN
        self.inverse_net = nn.Sequential(
            nn.Linear(feature_size * 2, inverse_hidden_sizes[0]),
            nn.ReLU(),
            nn.Linear(inverse_hidden_sizes[0], inverse_hidden_sizes[1]),
            nn.ReLU(),
            nn.Linear(inverse_hidden_sizes[1], action_size)
        ).to(device)
        
        # Forward NN
        self.forward_net = nn.Sequential(
            nn.Linear(feature_size + action_size, forward_hidden_sizes[0]),
            nn.ReLU(),
            nn.Linear(forward_hidden_sizes[0], forward_hidden_sizes[1]),
            nn.ReLU(),
            nn.Linear(forward_hidden_sizes[1], feature_size)
        ).to(device)

        self.L_I = nn.MSELoss()
        self.L_F = nn.MSELoss()
        
        self.optimizer = optim.Adam(self.parameters(), lr=icm_lr)
        
    def predict(self, s, a, s_prime):
        '''Predicts next action and state.'''
        s_reshaped = s.view(-1,self.state_size)
        a_reshaped = a.view(-1,self.action_size)
        s_prime_reshaped = s_prime.view(-1,self.state_size)
        cat1 = torch.cat([self.feature_net(s_reshaped), self.feature_net(s_prime_reshaped)], 1)
        cat2 = torch.cat([self.feature_net(s_reshaped).clone().detach(), a_reshaped], 1)
        a_hat = self.inverse_net(cat1)
        s_hat = self.forward_net(cat2)
        return a_hat, s_hat

    def predict_once(self, s, a, s_prime):
        '''Predicts next action and state for a single observation.'''
        s = torch.from_numpy(s).float()
        s_prime = torch.from_numpy(s_prime).float()
        cat1 = torch.cat([self.feature_net(s), self.feature_net(s_prime)])
        cat2 = torch.cat([self.feature_net(s).clone().detach(), a])
        a_hat = self.inverse_net(cat1)
        s_hat = self.forward_net(cat2)
        return a_hat, s_hat

    def loss(self, s, a, a_hat, s_prime, s_hat):
        '''Calculates loss.'''
        L_I = self.L_I(a_hat, a.view(-1,self.action_size))
        L_F = self.L_F(s_hat, self.feature_net(s_prime.view(-1,self.state_size)).clone().detach())
        return (1 - self.β) * L_I + self.β * L_F

    def loss_once(self, s, a, a_hat, s_prime, s_hat):
        '''Calculates loss for a single prediction.'''
        s_prime = torch.from_numpy(s_prime).float()
        L_I = self.L_I(a_hat, a)
        L_F = self.L_F(s_hat, self.feature_net(s_prime).clone().detach())
        return (1 - self.β) * L_I + self.β * L_F

    def compute_intrinsic_reward(self, s, a, s_prime):
        """
        Computes intrinsic reward based on ICM forward model error.

        """
        # Forward model prediction
        _, s_hat = self.predict(s, a, s_prime)
        
        # Intrinsic reward is the forward model error
        intrinsic_reward = F.mse_loss(s_hat, self.feature_net(s_prime).clone().detach(), reduction='none')
        #print(f'intrinsic_reward.shape: {intrinsic_reward.shape}') #intrinsic_reward.shape : (32,16->latent_dim)
        return intrinsic_reward.sum(dim=1)
            

## Soft Actor-Critic

In [30]:
#Hyperparameters
lr_pi           = 0.0005
lr_q            = 0.001
init_alpha      = 0.01
gamma           = 0.98
batch_size      = 128
buffer_limit    = 50000
tau             = 0.01 # for target network soft update
target_entropy  = -1.0 # for automated alpha update
lr_alpha        = 0.001  # for automated alpha update

In [31]:
class ReplayBuffer():
    def __init__(self):
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        self.buffer.append(transition)
    
    def sample(self, n):
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []

        for transition in mini_batch:
            s, a, r, s_prime, done = transition
            s_lst.append(s)
            a_lst.append(a)
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            done_mask = 0.0 if done else 1.0 
            done_mask_lst.append([done_mask])
        
        #print(type(s_lst))
        #print(torch.tensor(a_lst).shape)
        return torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst, dtype=torch.float), \
                torch.tensor(r_lst, dtype=torch.float), torch.tensor(s_prime_lst, dtype=torch.float), \
                torch.tensor(done_mask_lst, dtype=torch.float)
    
    def size(self):
        return len(self.buffer)

In [32]:
class PolicyNet(nn.Module):
    def __init__(self, learning_rate, icm_params, state_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128).to(device)
        self.fc_mu = nn.Linear(128,action_dim).to(device)
        self.fc_std  = nn.Linear(128,action_dim).to(device)
        
        
        #self.icm = icm = ICM(state_size=state_dim, action_size=action_dim, icm_parameters=icm_params)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

        self.log_alpha = torch.tensor(np.log(init_alpha))
        self.log_alpha.requires_grad = True
        self.log_alpha_optimizer = optim.Adam([self.log_alpha], lr=lr_alpha)

    def forward(self, x):
        x = x.to(device)
        x = F.relu(self.fc1(x))
        mu = self.fc_mu(x)
        std = F.softplus(self.fc_std(x))
        dist = Normal(mu, std)
        action = dist.rsample()
        log_prob = dist.log_prob(action)
        real_action = torch.tanh(action)
        real_log_prob = log_prob - torch.log(1-torch.tanh(action).pow(2) + 1e-7)
        return real_action, real_log_prob

    def train_net(self, q1, q2, mini_batch, intrinsic_reward):
        s, _, _, _, _ = mini_batch
        a, log_prob = self.forward(s)
        entropy = -self.log_alpha.exp() * log_prob #Entropy loss term, weighted by alpha

        q1_val, q2_val = q1(s,a), q2(s,a)
        q1_q2 = torch.cat([q1_val, q2_val], dim=1)
        #print(f'q1_q2 cat shape: {q1_q2.shape}')
        min_q = torch.min(q1_q2, 1, keepdim=True)[0]

        '''
        Combined loss used before:
        # Update ICM module
        s_batch, a_batch, r_batch, s_prime_batch, _ = mini_batch
        a_hat_batch, s_hat_batch = self.icm.predict(s_batch.to(device), a_batch.to(device), s_prime_batch.to(device))
        intrinsic_loss = self.icm.loss(s_batch, a_batch, a_hat_batch, s_prime_batch, s_hat_batch)

        extrinsic_loss = -min_q - entropy # for gradient ascent
        
        
        print(f"Extrinsic Loss:{extrinsic_loss.mean()}, Intrinsic Loss:{intrinsic_loss.mean()}")
        
        #Combinng the losses:
        loss = extrinsic_loss + lambda_intrinsic*intrinsic_loss
        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()
        '''

        intrinsic_reward = intrinsic_reward.clone().detach()
        #print(f'min_q shape: {min_q.shape}, entropy.shape: {entropy.shape}')  #min_q shape : (32,1)    entropy.shape : (32,6 -> action_dim), 
        loss = -min_q - entropy +  lambda_intrinsic*intrinsic_reward.mean()# for gradient ascent (CHECK IF ADDING REWARD LIKE THIS IS OKAY OR NOT)
        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()

        self.log_alpha_optimizer.zero_grad()
        alpha_loss = -(self.log_alpha.exp() * (log_prob + target_entropy).detach()).mean()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        return -min_q - entropy
        


In [33]:
class QNet(nn.Module):
    def __init__(self, learning_rate, state_dim, action_dim):
        super(QNet, self).__init__()
        self.fc_s = nn.Linear(state_dim, 64).to(device)
        self.fc_a = nn.Linear(action_dim,64).to(device)
        self.fc_cat = nn.Linear(128,32).to(device)
        self.fc_out = nn.Linear(32,action_dim).to(device)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def forward(self, x, a):
        x = x.to(device)
        a = a.to(device)
        h1 = F.relu(self.fc_s(x))
        h2 = F.relu(self.fc_a(a))
        cat = torch.cat([h1,h2], dim=1)
        q = F.relu(self.fc_cat(cat))
        q = self.fc_out(q)
        return q

    def train_net(self, target, mini_batch):
        s, a, r, s_prime, done = mini_batch
        loss = F.smooth_l1_loss(self.forward(s, a) , target)
        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()

    def soft_update(self, net_target):
        for param_target, param in zip(net_target.parameters(), self.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - tau) + param.data * tau)

In [34]:
def calc_target(pi, q1, q2, mini_batch):
    s, a, r, s_prime, done = mini_batch

    with torch.no_grad():
        a_prime, log_prob= pi(s_prime)
        entropy = -pi.log_alpha.exp() * log_prob
        q1_val, q2_val = q1(s_prime,a_prime), q2(s_prime,a_prime)
        q1_q2 = torch.cat([q1_val, q2_val], dim=1)
        min_q = torch.min(q1_q2, 1, keepdim=True)[0]
        #print(r.device, done.device, min_q.device, entropy.device)
        target = r + gamma * done * (min_q + entropy)

    return target

## ICM

In [35]:
#SAC Hyperparameters
lr_pi           = 0.0005
lr_q            = 0.001
init_alpha      = 0.01
gamma           = 0.98
batch_size      = 256
buffer_limit    = 50000
tau             = 0.01 # for target network soft update
target_entropy  = -6.0 # for automated alpha update
lr_alpha        = 0.001  # for automated alpha update


# ICM parameters
feature_hidden_sizes = [128,128,]
feature_size = 16
inverse_hidden_sizes = [128,128,]
forward_hidden_sizes = [128,128,]
β = 0.5
icm_lr = 0.001
lambda_intrinsic = 0.01 #controls trade-off between extrinsic loss and intrinsic loss2.0, 2.0, 10.0, 10.0, 4,0 ,4,0]

In [36]:
def main():
    env = gym.make('HalfCheetah-v4', render_mode="rgb_array")
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
        
    
    memory = ReplayBuffer()
    q1, q2, q1_target, q2_target = QNet(lr_q, state_dim, action_dim).to(device), QNet(lr_q, state_dim, action_dim).to(device), QNet(lr_q, state_dim, action_dim).to(device), QNet(lr_q, state_dim, action_dim).to(device)
    icm_params = (feature_hidden_sizes, feature_size, inverse_hidden_sizes, forward_hidden_sizes, β, icm_lr)
    pi = PolicyNet(lr_pi, icm_params, state_dim, action_dim).to(device)
    icm = ICM(state_size=state_dim, action_size=action_dim, icm_parameters=icm_params).to(device)
    
    
    '''
    icm_params = (feature_hidden_sizes, feature_size, inverse_hidden_sizes, forward_hidden_sizes, β, icm_lr)
    icm = ICM(state_size=state_dim, action_size=action_dim, icm_parameters=icm_params)'''

    q1_target.load_state_dict(q1.state_dict())
    q2_target.load_state_dict(q2.state_dict())

    score = 0.0
    print_interval = 20

    for n_epi in range(10000):
        s, _ = env.reset()
        done = False
        count = 0
        
        # Only render every 100 episodes
        render_env = (n_epi % 100 == 0)

        while count < 1000 and not done:
            
            #if render_env:
                #env.render()  # Render the environment
            
            
            a, log_prob= pi(torch.from_numpy(s).float())
            a = a.cpu().detach().numpy()
            s_prime, r, done, truncated, info = env.step(a)

            memory.put((s, a, r, s_prime, done))  #Changed to r from r/10.0
            score +=r
            s = s_prime
            count += 1

        icm_losses = []
        sac_losses = []
        if memory.size()>10000:
            for i in range(100):
                mini_batch = memory.sample(batch_size)
                mini_batch = tuple(t.to(device) for t in mini_batch)
                
                td_target = calc_target(pi, q1_target, q2_target, mini_batch)
                
                
                q1.train_net(td_target, mini_batch)
                q2.train_net(td_target, mini_batch)

                # Update ICM module
                s_batch, a_batch, r_batch, s_prime_batch, _ = mini_batch
                intrinsic_reward = icm.compute_intrinsic_reward(s_batch, a_batch, s_prime_batch)
                a_hat_batch, s_hat_batch = icm.predict(s_batch, a_batch, s_prime_batch)
                intrinsic_loss = icm.loss(s_batch, a_batch, a_hat_batch, s_prime_batch, s_hat_batch) #intrinsic_loss.shape : [] -> rank 0 tensor
                icm_losses.append(intrinsic_loss.item())
                
                icm.optimizer.zero_grad()
                intrinsic_loss.backward()
                icm.optimizer.step()
                
                #Update policy net
                sac_loss = pi.train_net(q1, q2, mini_batch, intrinsic_reward).mean().item()  #sac_loss.shape : [32,6] 
                #print(sac_loss.shape)
                sac_losses.append(sac_loss)
                
                
                q1.soft_update(q1_target)
                q2.soft_update(q2_target)
                
                
                
                
                
        if n_epi%print_interval==0 and n_epi!=0:
            print("# of episode :{}, avg score : {:.1f} alpha:{:.4f}".format(n_epi, score/print_interval, pi.log_alpha.exp()))
            print(f'Avg ICM loss: {np.mean(icm_losses)}')
            print(f'Avg SAC loss: {np.mean(sac_losses)}')
            score = 0.0

    env.close()

if __name__ == '__main__':
    main()

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.