In [1]:
import torch
from torch import nn
from torch.distributions import Categorical

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
class Policy(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(4, 16, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Flatten(start_dim=1),
            nn.Linear(7680, 3840),
            nn.LeakyReLU(),
            nn.Linear(3840, 1920),
            nn.LeakyReLU(),
            nn.Linear(1920, 960),
            nn.LeakyReLU(),
            nn.Linear(960, 9),
            nn.Softmax()
        )

    def forward(self, state):
        probabilities = self.model(state)
        distribution = Categorical(probabilities)
        action = distribution.sample()
        log_prob = distribution.log_prob(action)
        return action.item(), log_prob
    
    def load(self, model_path="reinforce_qwop.pth"):
        self.model.load_state_dict(torch.load(model_path))
        return self

    def save(self, model_path="reinforce_qwop.pth"):
        torch.save(self.model.state_dict(), model_path)
        return self


In [3]:
from tqdm import tqdm

def train(env, policy, episodes, max_t, gamma=0.9):
    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-5)

    pbar = tqdm(range(episodes))
    for i in pbar:
        
        log_probs = []
        rewards = []

        state = env.reset()
        last_states = [torch.tensor(state) for _ in range(4)]

        for t in range(max_t):
            observation = torch.stack(last_states).to(device).unsqueeze(0)
            action, log_prob = policy(observation)
            state, reward, done, _ = env.step(action)

            last_states.pop(0)
            last_states.append(torch.tensor(state))

            log_probs.append(log_prob)
            rewards.append(reward)

            if done:
                break
        episode_len = len(rewards)
        
        cum_return = torch.zeros(episode_len, device=device)
        cum_return[-1] = rewards[-1]
        for i in range(episode_len - 2, -1, -1):
            cum_return[i] = gamma * cum_return[i + 1] + rewards[i]
        
        loss = - torch.sum(torch.concat(log_probs) * cum_return)
        # print(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix({"loss": loss.item(), "reward": cum_return[0].item()})

        if i % 10 == 0:
            policy.save("reinforce_qwop.pth")
            policy.save("reinforce_qwop_backup.pth")


In [4]:
from qwop_env import QWOP_Env

env = QWOP_Env(headless=False).to(device)
policy = Policy().load("reinforce_qwop.pth").to(device)
train(env, policy, 1000, 1000)

  input = module(input)
  1%|          | 10/1000 [01:02<1:42:38,  6.22s/it, loss=0.41, reward=2.31]


KeyboardInterrupt: 