# CartPole-v1 Simple Solution

## Value Function Approximation (Semi-gradient TD(λ))

In [211]:
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch
import torch.optim as optim

value_limit = 300

def softclamp(x):
    exp = torch.exp(x)
    return (exp - 1) / (exp + 1)

class ValueNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.fc1 = nn.Linear(4, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)
        
        for p in self.parameters():
            if isinstance(p, nn.Linear):
                init.savier_uniform_(p.weight)
                init.zeros_(p.bias)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class Value:
    def __init__(self, learning_rate, trace_decay_rate):
        self.value_net = ValueNet()
        self.learning_rate = learning_rate
        self.trace_decay_rate = trace_decay_rate
        self.optim = optim.SGD(self.value_net.parameters(), lr=learning_rate)
    
    def reset(self):
        self.value_net.zero_grad()
    
    def begin_log(self, logger_function):
        self.logger_function = logger_function
    
    def end_log(self):
        self.logger_function = None
        
    def step(self, previous_state, state, reward):
        #for p in self.value_net.parameters():
        #    if p.grad is not None:
        #        p.grad.data *= self.trace_decay_rate
        
        previous_state_value = self.value_net(previous_state)
        #previous_state_value.backward()
        
        state_value = self.value_net(state).detach()
        td_error = reward + state_value - previous_state_value
        loss = td_error.pow(2).mean()
        
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        
        for p in self.value_net.parameters():
            assert not torch.isnan(p.grad.data).any()
            #p.data += torch.clamp(self.learning_rate * td_error * p.grad.data, -10, 10)
        
        if self.logger_function is not None:
            self.logger_function(loss, self.value_net)
    
    def batch(self, trajectory_states, trajectory_rewards):
        self.value_net.zero_grad()
        state_values = self.value_net(trajectory_states)
        T = len(trajectory_states) - 1
        total_td_error = 0
        
        for t in range(T):
            for p in self.value_net.parameters():
                if p.grad is not None:
                    p.grad.data *= self.trace_decay_rate
            
            previous_state_value = state_values[t]
            previous_state_value.backward(retain_graph=True)
            state_value = state_values[t + 1]
            reward = trajectory_rewards[t]
            td_error = reward + state_value - previous_state_value
            
            total_td_error += td_error
            
            assert td_error == td_error
            
            for p in self.value_net.parameters():
                assert not torch.isnan(p.grad.data).any()
                p.data += self.learning_rate * td_error * p.grad.data
        
        if self.logger_function is not None:
            self.logger_function(total_td_error / T, self.value_net)

In [212]:
import math

class Policy:
    def __init__(self, env):
        self.env = env
        
    def sample_action(self):
        return 0

def prepare(state):
    return torch.tensor(state, dtype=torch.float)

def train_episode(env, policy, value):
    #value.reset()
    previous_state = prepare(env.reset())
    
    while True:
        action = policy.sample_action()
        state, reward, done, _ = env.step(action)
        state = prepare(state)
        value.step(previous_state, state, reward)
        previous_state = state
        
        if done:
            break

def train(env, policy, value, episodes=100, batch=False):
    for episode in range(episodes):
        if batch:
            train_episode_batch(env, policy, value)
        else:
            train_episode(env, policy, value)

def train_episode_batch(env, policy, value):
    trajectory_states = [prepare(env.reset())]
    trajectory_rewards = []
    
    while True:
        action = policy.sample_action()
        state, reward, done, _ = env.step(action)
        state = prepare(state)
        trajectory_states.append(state)
        trajectory_rewards.append(reward)
        
        if done:
            break
    
    trajectory_states = torch.stack(trajectory_states)
    trajectory_rewards = torch.tensor(trajectory_rewards, dtype=torch.float)
    td_error = value.batch(trajectory_states, trajectory_rewards)

In [213]:
import gym
from torch.utils.tensorboard import SummaryWriter

env = gym.make('CartPole-v1')
value = Value(1e-3, 0)
policy = Policy(env)
writer = SummaryWriter()

def make_logger():
    log_step = 0
    
    def log_td_error(average_td_error, value_net):
        nonlocal log_step
        writer.add_scalar('critic/td_error', average_td_error, log_step)
        grads = torch.cat([torch.flatten(p.grad.data) for p in value_net.parameters()])
        writer.add_scalar('critic/mean_grad', torch.mean(grads), log_step)
        writer.add_scalar('critic/mean_value', torch.mean(value_net(
            torch.randn(1000, 4)
        )), log_step)
        log_step += 1
    
    return log_td_error

value.begin_log(make_logger())

In [214]:
td_errors = train(env, policy, value, episodes=4000, batch=False)

AssertionError: 