# Soft Actor Critic Implementation

In [25]:
import torch
import torch.nn as nn
import torch.optim as optim

import gymnasium as gym
import numpy as np
from tqdm import tqdm

from policy import CategoricalPolicy, GaussianPolicy
from network_utils import build_mlp

In [46]:
class SoftActorCritic(nn.Module):
    def __init__(self, obs_dim, act_dim):
        nn.Module.__init__(self)

        self.obs_dim = obs_dim
        self.act_dim = act_dim

        self.value = build_mlp(obs_dim, 1, 2, 64)
        self.value_optimizer = optim.AdamW(self.value.parameters())

        self.q = build_mlp(obs_dim + act_dim, 1, 2, 64)
        self.q_optimizer = optim.AdamW(self.q.parameters())

        self.network = build_mlp(obs_dim, act_dim, 2, 64)
        self.policy = GaussianPolicy(self.network, self.act_dim)
        self.policy_optimizer = optim.AdamW(self.policy.parameters())

        # In the form of (s_t, a_t, r(s_t, a-t), s_{t+1})
        self.replay_buffer = []

        self.action_sample_n = 100

    # Accpting input as torch
    def forward(self, obs):
        dist = self.policy.action_distribution(obs)
        return torch.tanh(dist.sample())

    # Accept numpy
    # def act(self, obs):
        

    def add_to_replay_buffer(self, obs, action, reward, next_obs, done):
        self.replay_buffer.append((obs, action, reward, next_obs, done))

    def update_value_approx(self):
        self.value_optimizer.zero_grad()

        # TODO: we can probably have this be better by doing something better
        # for s_t ~ D
        loss = 0
        for s in self.replay_buffer:
            dist = self.policy.action_distribution(self, torch.array([[s]]))
            t1 = self.value(s[0])
            t2 = 0
            for i in range(self.action_sample_n):
                a = dist.sample()
                t2 += self.q(torch.cat((s[0], a), 1)) - dist.log_prob(a)
            t2 /= self.action_sample_n

            loss += torch.square(t1 + t2)
        loss /= 2 * len(self.replay_buffer)

        loss.backward()
        self.value_optimizer.step()

    def update_value(self):
        self.value_optimizer.zero_grad()

        loss = 0
        for r in self.replay_buffer:
            s = r[0]
            dist = self.policy.action_distribution(self, torch.array([[s]]))
            a = dist.sample() 
            loss += self.value(s) * (self.value(s) - self.q(torch.cat((s, a), 1)) + dist.log_sample(a))

        loss /= len(self.replay_buffer)

        loss.backward()
        self.value_optimizer.step()

    def update_q(self):
        self.q_optimizer.zero_grad()

        loss = 0
        for r in self.replay_buffer:
            s = r[0]
            dist = self.policy.action_distribution(self, torch.array([[s]]))
            a = dist.sample() 
            loss += self.value(s) * (self.value(s) - self.q(torch.cat((s, a), 1)) + dist.log_sample(a))

        loss /= len(self.replay_buffer)

        loss.backward() 
        self.q_optimizer.step()

    # def update_q(self):


In [45]:
env = gym.make('Ant-v4')
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

# Define your Soft Actor-Critic agent
agent = SoftActorCritic(obs_dim, act_dim)

gradient_steps = 100

# Training loop
def train(num_episodes, max_steps):

    for episode in tqdm(range(num_episodes)):
        obs = env.reset()[0]
        done = False
        episode_reward = 0

        # Collect trajectory data into replay buffer
        # TODO: should replay buffer be a set or a list?
        for step in range(max_steps):
            obs_ = torch.tensor(obs, dtype=torch.float32)
            action = agent(obs_).detach().numpy()
            next_obs, reward, done, info, _ = env.step(action)
            
            agent.add_to_replay_buffer(obs, action, reward, next_obs, done)

            obs = next_obs
            episode_reward += reward

            if done:
                break

        # print(f"Episode {episode+1}: Reward = {episode_reward}")

        # Now do a bunch of gradient steps
        for step in range(gradient_steps):
            

    

# Run the training loop
train(num_episodes=100, max_steps=1000)

Collect:   0%|          | 0/100 [00:00<?, ?it/s]


TypeError: SoftActorCritic.add_to_replay_buffer() takes 2 positional arguments but 6 were given