# CP Predictor

In [1]:
from typing import Callable, Iterable
import functools
import numpy as np
import os

import gym

import torch
# torch.multiprocessing.set_start_method('spawn')
import torch.nn as nn
from torch.utils.data import IterableDataset, DataLoader
from torch import optim

from a2c_ppo_acktr import algo, utils
from a2c_ppo_acktr.envs import make_vec_envs

## Predictors

In [20]:
class PredictorMLP(nn.Module):
    def __init__(self, in_size=128, hidden_size=256, out_size=128):
        super().__init__()
        self.fc_grp1 = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )

        self.action_fc = nn.Sequential(nn.Linear(1, hidden_size), nn.ReLU())

        self.fc_grp2 = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, out_size),
        )

    def forward(self, obs, action):
        action = action.squeeze(-1)
        obs_encoding = self.fc_grp1(obs)
        action_encoding = self.action_fc(action)
        prod = obs_encoding * action_encoding
        out = self.fc_grp2(prod)
        return out

In [23]:
class PredictorMLPTwo(nn.Module):
    def __init__(self, in_size=128, hidden_size=256, out_size=128):
        super().__init__()
        self.fc_grp1 = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(p=0.4),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(p=0.4),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )

        self.action_fc = nn.Sequential(nn.Linear(1, hidden_size), nn.ReLU())

        self.fc_grp2 = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(p=0.4),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(p=0.4),
            nn.Linear(hidden_size, out_size),
        )

    def forward(self, obs, action):
        action = action.squeeze(-1)
        obs_encoding = self.fc_grp1(obs)
        action_encoding = self.action_fc(action)
        prod = obs_encoding * action_encoding
        out = self.fc_grp2(prod)
        return out

## Data

In [3]:
class GeneratorDataset(IterableDataset):
    """Uses a generator function to generate batches of data."""

    def __init__(self, generator_fn: Callable):
        self.generator_fn = generator_fn

    def __iter__(self) -> Iterable:
        return self.generator_fn()

In [4]:
def experience_gen(env, agent):
    vis_obs = env.reset()
    prev_obs = env.get_attr("unwrapped")[0]._get_ram()
    
    if np.random.rand() < 0.1:
        action = torch.randint(0, 6, (1,)).unsqueeze(-1)
    else:
        _, action, _, _ = agent.act(vis_obs, None, None)
        action = action.cpu()

    while True:
        if np.random.rand() < 0.1:
            action = torch.randint(0, 6, (1,)).unsqueeze(-1)
        else:
            _, action, _, _ = agent.act(vis_obs, None, None)
            action = action.cpu()
        
        vis_obs, reward, done, _ = env.step(action)
        cur_obs = env.get_attr("unwrapped")[0]._get_ram()
        if done:
            yield prev_obs, action, cur_obs
            vis_obs = env.reset()
            prev_obs = env.get_attr("unwrapped")[0]._get_ram()
        else:
            yield prev_obs, action, cur_obs
            prev_obs = cur_obs

## Training

In [6]:
env_name = 'PongNoFrameskip-v4'
seed = 1
num_procs = 1
gamma = 0.99
log_dir = '/tmp/gym'
device = 'cuda:0'
envs = make_vec_envs(env_name, seed, num_procs,
                     gamma, log_dir, device, False)

In [8]:
trained_agent, _ = torch.load(
            os.path.join("trained_models/a2c", env_name + ".pt"), map_location=device
        )

In [9]:
trained_agent = trained_agent.eval()

In [10]:
envs.observation_space.shape

(4, 84, 84)

In [11]:
batch_size = 64

In [12]:
dataset = GeneratorDataset(functools.partial(experience_gen, env=envs, agent=trained_agent))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

In [13]:
print(best_loss)

NameError: name 'best_loss' is not defined

In [42]:
lr = 1e-3

In [43]:
mask = torch.zeros(batch_size, 128)
mask[:, 13] = 1
mask[:, 14] = 1
mask[:, 21] = 1
mask[:, 51] = 2
mask[:, 49] = 5
mask[:, 54] = 5
mask = mask.to(device)

In [44]:
def RAMLoss(pred, target):
    loss = nn.MSELoss()
    mse = loss(pred, target)
    
    ram_loss = (pred - target) * mask
    ram_loss = ram_loss.pow(2).mean()
#     print(ram_loss.shape)
    
#     cpu_score_loss = (pred[:, 13] - target[:, 13]).pow(2).mean()
#     player_score_loss = (pred[:, 14] - target[:, 14]).pow(2).mean()
#     cpu_paddle_loss = (pred[:, 21] - target[:, 21]).pow(2).mean()
#     player_paddle_loss = (pred[:, 51] - target[:, 51]).pow(2).mean()
#     ball_pos_loss = (pred[:, 49] - target[:, 49]).pow(2).mean() + (pred[:, 54] - target[:, 54]).pow(2).mean()
    
    return mse + ram_loss

In [45]:
RAMLoss(pred, cur)

tensor(1086.2400, device='cuda:0', grad_fn=<AddBackward0>)

In [51]:
try:
    net, best_loss = torch.load('trained_models/cp/predictor.pt', map_location=device)
except (FileNotFoundError, TypeError):
    net, best_loss = PredictorMLPTwo().to(device), 1000000

In [52]:
print(best_loss)

322.7232971191406


In [48]:
optimizer = optim.Adam(net.parameters(), lr=lr)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

In [49]:
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [50]:
for i, (prev, action, cur) in zip(range(20000), dataloader):
    prev = prev.float().to(device)
    action = action.float().to(device)
    cur = cur.float().to(device)
    pred = net(prev, action)
    
    optimizer.zero_grad()
    out = RAMLoss(pred, cur)
    out.backward()
    optimizer.step()
#     scheduler.step(out)
    
    if i % 100 == 0:
        print(f'{i}, {out.item()}')
    if out.item() < best_loss:
        best_loss = out.item()
        torch.save([net, best_loss], 'trained_models/cp/predictor.pt')

0, 910.559814453125
100, 1328.6968994140625
200, 918.6124877929688
300, 1159.4410400390625
400, 891.9005126953125
500, 1536.2138671875
600, 974.6658935546875
700, 700.4344482421875
800, 1389.30908203125
900, 1273.97705078125
1000, 738.9677124023438
1100, 927.365966796875
1200, 815.876708984375
1300, 1796.7689208984375
1400, 809.6547241210938
1500, 752.522216796875
1600, 811.7977294921875
1700, 1351.9534912109375
1800, 667.17822265625
1900, 850.1029052734375
2000, 998.177490234375
2100, 640.401123046875
2200, 2020.424072265625
2300, 758.96533203125
2400, 1732.5469970703125
2500, 822.7368774414062
2600, 635.6709594726562
2700, 790.884033203125
2800, 1797.23828125
2900, 1147.185791015625
3000, 1233.4232177734375
3100, 911.3038330078125
3200, 926.6458740234375
3300, 627.89990234375
3400, 559.9838256835938
3500, 1088.983154296875
3600, 1068.4588623046875
3700, 1086.847900390625
3800, 533.4752807617188
3900, 564.2099609375
4000, 822.1962890625
4100, 903.390625
4200, 726.7750244140625
4300, 1

KeyboardInterrupt: 

In [134]:
for i, (prev, action, cur) in zip(range(20000), dataloader):
    prev = prev.float().to(device)
    action = action.float().to(device)
    cur = cur.float().to(device)
    pred = net(prev, action)
    
    optimizer.zero_grad()
    out = RAMLoss(pred, cur)
    out.backward()
    optimizer.step()
#     scheduler.step(out)
    
    if i % 100 == 0:
        print(f'{i}, {out.item()}')
    if out.item() < best_loss:
        best_loss = out.item()
        torch.save([net, best_loss], 'trained_models/cp/predictor.pt')

0, 84.97222900390625
100, 244.71353149414062
200, 122.55136108398438
300, 77.8275146484375
400, 61.221317291259766
500, 82.46733093261719
600, 113.71481323242188
700, 103.4535903930664
800, 92.49735260009766
900, 100.73908996582031
1000, 114.33695220947266
1100, 98.61327362060547
1200, 98.59130096435547
1300, 89.2513427734375
1400, 104.36087036132812
1500, 89.0654296875
1600, 100.53943634033203
1700, 133.02536010742188
1800, 121.65409088134766
1900, 90.86433410644531
2000, 98.9177474975586
2100, 135.02392578125
2200, 150.87948608398438
2300, 100.12068176269531
2400, 109.56149291992188
2500, 108.75056457519531
2600, 94.17854309082031
2700, 110.979248046875
2800, 127.96937561035156
2900, 106.25205993652344
3000, 106.06056213378906
3100, 104.81507873535156
3200, 103.19715881347656
3300, 95.39431762695312
3400, 205.64013671875
3500, 91.30162048339844
3600, 97.04886627197266
3700, 109.35940551757812
3800, 101.76792907714844
3900, 262.92572021484375
4000, 98.69830322265625
4100, 95.427993774

KeyboardInterrupt: 

In [69]:
pred[:, 49]

tensor([186.3287, 186.1859, 185.9084, 185.6812, 185.3882, 185.1511, 184.8581,
        184.6210, 184.3173, 184.0700, 183.7642, 183.5169, 183.2110, 174.4593,
        174.8849, 171.5093, 171.7436, 175.7138, 175.9613, 172.4207, 172.6434,
        176.7037, 176.9512, 173.2953, 173.5764, 177.6667, 174.1609, 174.3833,
        178.4554, 178.7112, 175.0382, 175.2431, 182.2193, 186.8036, 186.6889,
        186.6277, 186.4290, 186.2093, 185.9328, 185.7132, 185.4308, 185.2050,
        184.9176, 184.6914, 184.3986, 184.1566, 183.8544, 183.6124, 174.6291,
        175.0713, 171.7400, 172.0045, 175.9480, 176.2103, 172.7980, 173.0378,
        176.9944, 177.2484, 173.7389, 174.0315, 178.0106, 174.3979, 174.5976,
        178.7258], device='cuda:0', grad_fn=<SelectBackward>)

In [37]:
def get_ball_pos(ram):
    ball_x = ram[49]  # X coordinate of ball
    ball_y = ram[54]
    return ball_x, ball_y

In [56]:
print(get_ball_pos(pred[13].int()))
print(get_ball_pos(cur[13]))

(tensor(82, device='cuda:0', dtype=torch.int32), tensor(44, device='cuda:0', dtype=torch.int32))
(tensor(52., device='cuda:0'), tensor(0., device='cuda:0'))


In [None]:
# ram = env.unwrapped._get_ram()  # get emulator RAM state
# cpu_score = ram[13]  # computer/ai opponent score 
# player_score = ram[14]  # your score
# cpu_paddle_y = ram[21]  # Y coordinate of computer paddle
# player_paddle_y = ram[51]  # Y coordinate of your paddle
# ball_x = ram[49]  # X coordinate of ball
# ball_y = ram[54]  