# In [26]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import gym

class CriticNet(nn.Module):
    """State-value network: a 3-layer ReLU MLP mapping an observation to one scalar.

    Args:
        state_dim: Size of the last dimension of the input. Defaults to 4, the
            size previously hard-coded here (CartPole observations, as used by
            this script, have 4 components).
        hidden_sizes: Widths of the two hidden layers, ``(64, 32)`` by default.

    The layers keep the names ``fc1``/``fc2``/``fc3`` and are created in the
    same order as before, so saved ``state_dict``s and seeded initializations
    stay compatible.
    """

    def __init__(self, state_dim=4, hidden_sizes=(64, 32)):
        super().__init__()
        first_hidden, second_hidden = hidden_sizes

        self.fc1 = nn.Linear(state_dim, first_hidden)
        self.fc2 = nn.Linear(first_hidden, second_hidden)
        self.fc3 = nn.Linear(second_hidden, 1)

    def forward(self, x):
        """Return the value estimate; output shape is ``x.shape[:-1] + (1,)``."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Shared objects used by the training loop: the critic being trained, a
# TensorBoard writer using its default log directory, and the environment.
critic_net = CriticNet()
writer = SummaryWriter(log_dir=None)
env = gym.make('CartPole-v1')

def t(x):
    """Copy *x* into a new ``torch.float`` (float32) tensor.

    The training loop uses this to turn environment observations into
    network inputs.
    """
    float_dtype = torch.float
    converted = torch.tensor(x, dtype=float_dtype)
    return converted

# Hyperparameters. GAMMA = 1.0 reproduces the original undiscounted target and
# GAMMA * TRACE_DECAY = 0.8 reproduces the original trace decay exactly.
NUM_EPISODES = 4000
GAMMA = 1.0          # discount factor
TRACE_DECAY = 0.8    # lambda of the eligibility trace
LEARNING_RATE = 5e-4


def _reset(environment):
    """Reset *environment* and return only the initial observation.

    gym>=0.26 returns ``(observation, info)``; older releases return the
    observation alone.
    """
    result = environment.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step(environment, action):
    """Step *environment* and return ``(observation, reward, terminal, done)``.

    ``terminal`` is True only when the episode genuinely ended, i.e. the TD
    target must not bootstrap. ``done`` is True whenever the episode stops,
    including time-limit truncation. Handles both the 5-tuple (gym>=0.26) and
    the 4-tuple (older gym) ``step`` APIs.
    """
    result = environment.step(action)
    if len(result) == 5:
        observation, reward, terminated, truncated, _ = result
        return observation, reward, terminated, terminated or truncated
    observation, reward, done, info = result
    # Older gym's TimeLimit wrapper flags truncation in the info dict.
    truncated = bool(info.get('TimeLimit.truncated', False))
    return observation, reward, done and not truncated, done


try:
    for episode in range(NUM_EPISODES):
        previous_state = _reset(env)
        done = False
        # p.grad doubles as the eligibility trace; it must start at zero each episode.
        critic_net.zero_grad()

        while not done:
            action = 0  # fixed policy: the critic evaluates "always take action 0"
            state, reward, terminal, done = _step(env, action)

            previous_state_value = critic_net(t(previous_state))
            with torch.no_grad():
                # Semi-gradient TD: the bootstrap target and the TD error are
                # plain values, so no autograd graph is built or kept alive.
                state_value = 0 if terminal else critic_net(t(state))
                td_error = reward + GAMMA * state_value - previous_state_value.detach()
            previous_state = state

            if done:
                writer.add_scalar('critic/td_error', td_error, episode)

            # Eligibility trace: z <- gamma * lambda * z, then backward() adds grad V(previous_state).
            with torch.no_grad():
                for p in critic_net.parameters():
                    if p.grad is not None:
                        p.grad.mul_(GAMMA * TRACE_DECAY)

            previous_state_value.backward()

            # Weight update: w <- w + alpha * td_error * z.
            with torch.no_grad():
                for p in critic_net.parameters():
                    p.add_(LEARNING_RATE * td_error * p.grad)
finally:
    # Release the environment and flush TensorBoard even if training is interrupted.
    env.close()
    writer.close()