**Warning:** there is a memory leak somewhere in this notebook or in `dqn_qwop_agent`.

In [1]:
import torch
from torch import nn
import numpy as np

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
import copy
from collections import deque
import random

class QWOP_Agent:
    """Deep Q-Network agent with a target network and an experience replay buffer.

    The online network (``q_net``) selects actions and is optimised with Adam
    against an MSE loss; a periodically synchronised copy (``target_net``)
    provides the bootstrapped targets.

    NOTE(review): stored transitions carry no terminal flag, so the target
    always bootstraps from the next state, including at episode end.
    """

    def __init__(self, q_net, lr, sync_freq, exp_replay_size):
        """Build the agent around an existing Q-network.

        Args:
            q_net: module mapping a batch of states to per-action Q-values.
            lr: learning rate for the Adam optimiser.
            sync_freq: number of ``train`` calls between target-network syncs.
            exp_replay_size: maximum number of transitions kept in the buffer.
        """
        self.q_net = q_net
        self.target_net = copy.deepcopy(self.q_net)

        self.loss_fn = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)

        self.network_sync_freq = sync_freq
        self.network_sync_counter = 0
        self.gamma = 0.9  # discount factor applied to the next-state value
        self.experience_replay = deque(maxlen=exp_replay_size)
        self.max_exp_replay_size = exp_replay_size

        self.device = torch.device("cpu")

    def to(self, device):
        """Move both networks to ``device`` and remember it; returns ``self``."""
        self.device = device
        self.q_net.to(device)
        self.target_net.to(device)
        return self

    def load(self, model_path="models/dqn.pth"):
        """Load online-network weights from ``model_path``; returns ``self``.

        The checkpoint is mapped onto the agent's current device so a file
        saved on a GPU can be restored on a CPU-only machine. The target
        network is refreshed as well; otherwise it would keep the random
        weights it was deep-copied with in ``__init__``.
        """
        state_dict = torch.load(model_path, map_location=self.device)
        self.q_net.load_state_dict(state_dict)
        self.target_net.load_state_dict(self.q_net.state_dict())
        return self

    def save(self, model_path="models/dqn.pth"):
        """Write the online network's state dict to ``model_path``; returns ``self``."""
        torch.save(self.q_net.state_dict(), model_path)
        return self

    def get_q(self, state):
        """Return the target network's maximum Q-value for each state in the batch."""
        with torch.no_grad():
            qp = self.target_net(state)

            return torch.max(qp, dim=1)[0]

    def get_action(self, state, temperature=0, epsilon=0):
        """Choose an action for ``state`` (a batch of states) without tracking gradients.

        Args:
            state: batched input for ``q_net``; the output is expected to be
                ``(batch, n_actions)``.
            temperature: if > 0, sample from ``softmax(Q / temperature)``
                (Boltzmann exploration). Takes precedence over ``epsilon``.
            epsilon: if > 0, probability of returning one uniformly random action.

        Returns:
            A tensor of action indices on the same device as the Q-values.
        """
        with torch.no_grad():
            Qp = self.q_net(state)

            if temperature > 0:
                # Normalise over the action axis. Normalising over the batch
                # axis would turn every row into all-ones (uniform sampling).
                probs = torch.softmax(Qp / temperature, dim=1)
                A = torch.multinomial(probs, num_samples=1)
            elif epsilon > 0 and torch.rand(1).item() < epsilon:
                # epsilon-greedy: a single uniformly random action index
                A = torch.randint(0, Qp.shape[1], (1,), device=Qp.device)
            else:
                # greedy: index of the largest Q-value in each row
                A = torch.max(Qp, dim=1)[1]

            return A

    def collect_experience(self, experience):
        """Append one transition ``[state, action, reward, next_state]`` to the buffer."""
        self.experience_replay.append(experience)

    def sample_experience(self, sample_size):
        """Draw up to ``sample_size`` transitions uniformly without replacement.

        Returns:
            ``(s, a, rn, sn)``: stacked states (float), actions (long),
            rewards (float) and stacked next states (float). Fewer than
            ``sample_size`` rows are returned when the buffer is smaller.
        """
        sample_size = min(sample_size, len(self.experience_replay))
        sample = random.sample(self.experience_replay, sample_size)
        s = torch.stack([exp[0] for exp in sample]).float()
        a = torch.tensor([exp[1] for exp in sample]).long()
        rn = torch.tensor([exp[2] for exp in sample]).float()
        sn = torch.stack([exp[3] for exp in sample]).float()
        return s, a, rn, sn

    def train(self, batch_size):
        """Run one optimisation step on a sampled mini-batch and return the loss.

        The target network is re-synchronised with the online network once
        ``network_sync_freq`` steps have elapsed since the last sync.

        Args:
            batch_size: requested mini-batch size; the actual batch may be
                smaller while the replay buffer is still filling up.

        Returns:
            The MSE loss of this step as a Python float.
        """
        s, a, rn, sn = self.sample_experience(batch_size)
        if self.network_sync_counter >= self.network_sync_freq:
            self.target_net.load_state_dict(self.q_net.state_dict())
            self.network_sync_counter = 0

        # Predicted return of the action actually taken, from the online network.
        # gather() follows the real batch size, which may be below batch_size.
        qp = self.q_net(s.to(self.device))
        actions = a.to(self.device).unsqueeze(1)
        pred_return = qp.gather(1, actions).squeeze(1)

        # Bootstrapped target from the (frozen) target network.
        q_next = self.get_q(sn.to(self.device))
        target_return = rn.to(self.device) + q_next * self.gamma

        loss = self.loss_fn(pred_return, target_return)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.network_sync_counter += 1
        return loss.item()

In [3]:
def _build_q_net():
    """Assemble the Q-network: four stride-2 conv stages followed by an MLP head with 9 outputs."""
    conv_widths = (4, 16, 32, 64, 128)
    # 7680 = 128 channels x 10 x 6, the conv output for a 4 x 160 x 100 input.
    dense_widths = (7680, 3840, 1920, 960)

    layers = []
    for c_in, c_out in zip(conv_widths[:-1], conv_widths[1:]):
        layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU())
    layers.append(nn.Flatten(start_dim=1))
    for n_in, n_out in zip(dense_widths[:-1], dense_widths[1:]):
        layers.append(nn.Linear(n_in, n_out))
        layers.append(nn.LeakyReLU())
    layers.append(nn.Linear(dense_widths[-1], 9))

    # Layer order (and therefore state_dict keys 0..15) matches the saved checkpoint.
    return nn.Sequential(*layers)


q_net = _build_q_net()
agent = QWOP_Agent(q_net, lr=1e-4, sync_freq=100, exp_replay_size=20000)
agent.load().to(device)

# Smoke test: one greedy action for a random frame stack.
state = torch.randn((1, 4, 160, 100)).to(device)
agent.get_action(state, epsilon=0.0)

tensor([7], device='cuda:0')

In [4]:
from qwop_env import QWOP_Env

# Create the QWOP environment in headless mode and hand it the compute device.
# NOTE(review): `.to(device)` is defined in qwop_env; presumably it returns the environment itself — confirm.
env = QWOP_Env(headless=True).to(device)

In [5]:
from tqdm import tqdm
from collections import deque

def train(env, agent, episodes=20000, epsilons=(0.2, 0.05, 1e-4)):
    """Train ``agent`` on ``env`` with epsilon-greedy exploration.

    Each state fed to the agent is a stack of the four most recent
    observations. Transitions are stored in the agent's replay buffer and the
    agent takes one optimisation step every 129 environment steps. Both
    checkpoints are rewritten after every episode.

    Replay-buffer states are kept as CPU tensors; only the single state passed
    to ``agent.get_action`` is moved to ``device``. Storing device tensors in
    the buffer would pin one frame stack per transition in GPU memory. The
    next-state tensor of one step is reused as the state of the following
    step, so consecutive transitions share storage.

    Args:
        env: environment exposing ``reset()`` and ``step(action)``, where
            ``step`` returns ``(observation, reward, done, info)``.
        agent: object exposing ``get_action``, ``collect_experience``,
            ``train`` and ``save``.
        episodes: number of episodes to run.
        epsilons: ``(start, floor, decrement)``; epsilon starts at ``start``
            and is reduced by ``decrement`` after each episode while it is
            still above ``floor``.
    """
    epsilon = epsilons[0]

    # Rolling window of the four latest observations; maxlen evicts the oldest on append.
    obslist = deque(maxlen=4)

    index = 0  # environment steps since the last optimisation step
    pbar = tqdm(range(episodes))
    for _ in pbar:
        obs, rew = env.reset(), 0
        # Fill the window with the first observation so the initial state has four frames.
        for _ in range(4):
            obslist.append(obs)
        state = torch.tensor(np.array(obslist))  # stays on the CPU for the replay buffer

        for _ in range(1000):  # hard cap on episode length
            A = agent.get_action(state.unsqueeze(0).to(device), epsilon=epsilon)
            action = A.item()
            obs_next, reward, done, _ = env.step(action)

            obslist.append(obs_next)
            next_state = torch.tensor(np.array(obslist))
            agent.collect_experience([state, action, reward, next_state])
            state = next_state

            rew += reward
            index += 1

            if index > 128:
                index = 0
                agent.train(64)

            if done:
                break

            pbar.set_postfix({"rew": rew, "cur_rew": reward})

        if epsilon > epsilons[1]:
            epsilon -= epsilons[2]

        # Two copies, so one checkpoint survives if the run dies while the other is being written.
        agent.save()
        agent.save("models/dqn_backup.pth")

        pbar.set_postfix({"rew": rew})

# Start training with the default episode count and epsilon schedule.
train(env, agent)

 42%|████▏     | 8368/20000 [14:27:14<20:05:30,  6.22s/it, rew=3, cur_rew=3]        


WebDriverException: Message: disconnected: not connected to DevTools
  (failed to check if window was closed: disconnected: not connected to DevTools)
  (Session info: headless chrome=116.0.5845.141)
Stacktrace:
	GetHandleVerifier [0x00007FF758D952A2+57122]
	(No symbol) [0x00007FF758D0EA92]
	(No symbol) [0x00007FF758BDE3AB]
	(No symbol) [0x00007FF758BCBA47]
	(No symbol) [0x00007FF758BCB6C0]
	(No symbol) [0x00007FF758BDFA71]
	(No symbol) [0x00007FF758C4E27F]
	(No symbol) [0x00007FF758C36DB3]
	(No symbol) [0x00007FF758C0D2B1]
	(No symbol) [0x00007FF758C0E494]
	GetHandleVerifier [0x00007FF75903EF82+2849794]
	GetHandleVerifier [0x00007FF759091D24+3189156]
	GetHandleVerifier [0x00007FF75908ACAF+3160367]
	GetHandleVerifier [0x00007FF758E26D06+653702]
	(No symbol) [0x00007FF758D1A208]
	(No symbol) [0x00007FF758D162C4]
	(No symbol) [0x00007FF758D163F6]
	(No symbol) [0x00007FF758D067A3]
	BaseThreadInitThunk [0x00007FFD5FDC7614+20]
	RtlUserThreadStart [0x00007FFD618C26B1+33]
