In [1]:
from player import NetworkPlayer
from environments.Connect4.Network import CNN, ViT, CNN_old
from environments.Connect4.env_cython import Env
from policy_value_net import PolicyValueNet
import torch
import torch.nn as nn
from torch.optim import NAdam
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from tqdm.auto import tqdm
from copy import deepcopy

In [2]:
class ReplayBuffer:
    """Fixed-capacity ring buffer of board states and their valid-action masks.

    Storage is pre-allocated on ``device`` as two tensors:
    ``state`` with shape ``(capacity, state_dim, row, col)`` (float32) and
    ``mask`` with shape ``(capacity, action_dim)`` (bool).  Once ``capacity``
    entries have been written, each new entry overwrites the oldest one.
    """

    def __init__(self, state_dim, capacity, action_dim, row, col, device='cpu'):
        """Allocate the backing tensors.

        Args:
            state_dim: number of planes per stored state.
            capacity: maximum number of entries kept at once.
            action_dim: length of each stored mask.
            row: second-to-last dimension of each stored state.
            col: last dimension of each stored state.
            device: torch device that holds the storage.
        """
        self.state = torch.full(
            (capacity, state_dim, row, col), torch.nan, dtype=torch.float32, device=device)
        # A bool tensor cannot hold NaN, so use an explicit placeholder instead.
        # Rows that were never written are never returned by sample().
        self.mask = torch.full(
            (capacity, action_dim), True, dtype=torch.bool, device=device)
        self.count = 0  # total number of store() calls since the last reset
        self.device = device

    def __len__(self):
        """Return the number of valid entries (never more than the capacity)."""
        return min(self.count, len(self.state))

    def is_full(self):
        """Return True once every slot has been written at least once."""
        return len(self) >= len(self.state)

    def reset(self):
        """Discard all entries and restore the placeholder contents."""
        self.state = torch.full_like(
            self.state, torch.nan, dtype=torch.float32)
        self.mask = torch.full_like(self.mask, True, dtype=torch.bool)
        self.count = 0

    def to(self, device='cpu'):
        """Move the storage to ``device`` and return the buffer for chaining."""
        self.state = self.state.to(device)
        self.mask = self.mask.to(device)
        self.device = device
        return self

    def store(self, state, mask):
        """Write one ``(state, mask)`` pair into the next slot.

        Both arguments may be a tensor, a numpy array or a (nested) list; they
        are converted to float32 / bool on the buffer's device before writing.

        Returns:
            The slot index that was written.
        """
        idx = self.count % len(self.state)
        # Convert and write first so a bad input cannot leave ``count`` advanced
        # past an entry that was never stored.
        self.state[idx] = torch.as_tensor(
            state, dtype=torch.float32, device=self.device)
        self.mask[idx] = torch.as_tensor(
            mask, dtype=torch.bool, device=self.device)
        self.count += 1
        return idx

    def sample(self, batch_size):
        """Draw ``batch_size`` entries uniformly at random, with replacement.

        Returns:
            ``(states, masks)`` gathered from the valid part of the buffer.

        Raises:
            ValueError: if the buffer is empty.
        """
        size = len(self)
        if size == 0:
            raise ValueError('cannot sample from an empty ReplayBuffer')
        idx = torch.from_numpy(np.random.randint(
            0, size, batch_size, dtype=np.int64))
        # Index with a tensor that lives next to the storage.
        idx = idx.to(self.state.device)
        return self.state[idx], self.mask[idx]



class Game:
    """Plays complete self-play episodes on a wrapped environment."""

    def __init__(self, env):
        """Keep a reference to ``env``; it is reset at the start of each episode."""
        self.env = env

    def start_self_play(self, player, first_n_steps=5):
        """Play one episode from a fresh reset and record every visited position.

        Args:
            player: object whose ``get_action(env)`` returns a pair; the first
                element is the action to play.
            first_n_steps: kept for backward compatibility only.  It currently
                has no effect on how actions are chosen.
                TODO(review): decide what the opening moves should do
                differently, or drop the parameter.

        Returns:
            ``(states, masks)``: two lists with one entry per move, holding
            ``env.current_state()`` and ``env.valid_mask()`` as read just
            before the move was applied.
        """
        self.env.reset()
        states, masks = [], []
        while True:
            # The action is chosen before the position is recorded.
            action, _ = player.get_action(self.env)
            states.append(self.env.current_state())
            masks.append(self.env.valid_mask())
            self.env.step(action)
            if self.env.done():
                return states, masks


def instant_augment(state_original, mask_original=None):
    """Double a batch by adding left-right mirrored copies of every sample.

    Args:
        state_original: tensor with at least four dimensions, indexed as
            ``(batch, plane, row, col)``; it is mirrored along dimension 3.
        mask_original: optional tensor with at least two dimensions, indexed as
            ``(batch, action)``; it is mirrored along dimension 1.

    Returns:
        ``(state, mask)`` where the mirrored samples come first and the
        untouched originals follow, so the batch size doubles.  ``mask`` is
        ``None`` when no mask was given.  The inputs are not modified.
    """
    # One vectorised flip replaces a Python loop over every sample and plane.
    state = torch.cat([torch.flip(state_original, dims=(3,)), state_original])
    if mask_original is None:
        return state, None
    mask = torch.cat([torch.flip(mask_original, dims=(1,)), mask_original])
    return state, mask


def quantile_huber_loss(pred, target, tau, kappa=1.0):
    """Mean quantile-weighted Huber loss between ``pred`` and ``target``.

    Args:
        pred: predictions whose dimension 1 has one entry per quantile.
        target: tensor expandable to the shape of ``pred``.
        tau: 1-D tensor of quantile levels, one per column of ``pred``.
        kappa: threshold where the Huber term switches from quadratic to linear.

    Returns:
        Scalar tensor: the mean over all elements of
        ``|tau - 1[target - pred < 0]| * huber(target - pred)``.
    """
    assert pred.shape[1] == tau.shape[0], "pred and tau must have compatible shapes"
    error = target.expand_as(pred) - pred
    abs_error = error.abs()
    quadratic = 0.5 * error.pow(2)
    linear = kappa * (abs_error - 0.5 * kappa)
    huber = torch.where(abs_error <= kappa, quadratic, linear)
    # The indicator is detached so it acts as a constant weight in backprop.
    is_negative = (error.detach() < 0).float()
    weight = torch.abs(tau.view(1, -1) - is_negative)
    return (weight * huber).mean()

In [None]:
# Fall back to the CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
game = Game(Env())
# Room for 100000 positions: 3 planes of 6x7 per state, 7-entry action mask.
buffer = ReplayBuffer(state_dim=3, capacity=100000, action_dim=7, row=6, col=7, device=device)

# NOTE(review): the output recorded under this cell reports
# "Failed to load parameters" with state_dict key mismatches for CNN, so the
# teacher may be running without the checkpoint's weights. CNN_old is imported
# but unused; presumably it is the architecture this checkpoint was saved
# from — confirm before relying on the teacher's play strength.
teacher_net = CNN(0, device=device)
teacher_policy = PolicyValueNet(teacher_net, 0.99, './params/AZ2_Connect4_CNN_best.pt')
teacher_player = NetworkPlayer(teacher_policy, False)
teacher_player.eval()

Failed to load parameters.
Error(s) in loading state_dict for CNN:
	Missing key(s) in state_dict: "hidden.5.weight", "hidden.5.bias", "hidden.5.running_mean", "hidden.5.running_var", "hidden.8.weight", "hidden.8.bias", "hidden.9.weight", "hidden.9.bias", "hidden.9.running_mean", "hidden.9.running_var", "policy_head.1.weight", "policy_head.1.bias", "policy_head.4.weight", "policy_head.4.bias", "value_head.4.weight", "value_head.4.bias". 
	Unexpected key(s) in state_dict: "hidden.3.weight", "hidden.3.bias", "hidden.4.running_mean", "hidden.4.running_var", "hidden.4.num_batches_tracked", "hidden.6.weight", "hidden.6.bias", "hidden.7.weight", "hidden.7.bias", "hidden.7.running_mean", "hidden.7.running_var", "hidden.7.num_batches_tracked", "value_head.1.running_mean", "value_head.1.running_var", "value_head.1.num_batches_tracked", "value_head.3.weight", "value_head.3.bias". 
	size mismatch for hidden.4.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in curren

In [4]:
# Fill the replay buffer with the positions visited in 2000 teacher self-play games.
plt.ion()
for epoch in range(1, 2001):
    states, masks = game.start_self_play(teacher_player, first_n_steps=0)
    for game_state, game_mask in zip(states, masks):
        buffer.store(game_state, game_mask)
    # '\r' rewrites the same output line instead of printing 2000 of them.
    print(f'\repoch: {epoch}, num_example: {len(buffer)}', end='')

epoch: 2000, num_example: 42894

In [None]:
class Pretrain(nn.Module):
    """Autoencoder-style wrapper used to pretrain the trunk of a ``CNN``.

    The student's ``hidden`` module serves as the encoder, and a small MLP
    decoder maps its output to a tensor viewed as ``(batch, 2, 6, 7)``.
    The class reads the notebook-level ``device`` variable when instantiated.
    """

    def __init__(self):
        """Build the student network, the decoder and the optimizer."""
        super().__init__()
        self.student = CNN(0, device=device)
        # The encoder is the student's own trunk (the same module, not a copy),
        # so training it here trains the student.
        self.encoder = self.student.hidden
        # assumes the encoder emits flat (batch, 64 * 4) features — TODO confirm
        # against CNN.hidden.
        self.decoder = nn.Sequential(nn.Linear(64 * 4, 64 * 4),
                                     nn.BatchNorm1d(64 * 4),
                                     nn.SiLU(True),
                                     nn.Linear(64 * 4, 42 * 2),
                                     nn.Sigmoid())
        self.device = device
        # Move the module first and only then build the optimizer, as the
        # torch.optim documentation recommends.
        self.to(self.device)
        self.opt = NAdam(self.parameters(), lr=1e-3, weight_decay=0.01, decoupled_weight_decay=True)

    def save(self, path='CNN_pretrained.pt'):
        """Write the encoder's ``state_dict`` (not the decoder's) to ``path``."""
        torch.save(self.encoder.state_dict(), path)

    def forward(self, state):
        """Encode ``state``, decode it, and view the result as ``(-1, 2, 6, 7)``."""
        latent = self.encoder(state)
        decoded = self.decoder(latent)
        return decoded.view(-1, 2, 6, 7)


In [9]:
# Build the pretraining model: student CNN, decoder and NAdam optimizer.
pt = Pretrain()

In [None]:
# Pretraining loop: the model reconstructs every plane of the state except the
# last one from mirrored-augmented replay samples.
losses = []
for j in tqdm(range(100000)):
    # Only states are trained on, so the sampled masks are dropped rather than
    # augmented and carried along unused.
    state = buffer.sample(256)[0]
    state = instant_augment(state)[0]
    pt.opt.zero_grad()
    state_pred = pt(state)
    loss = F.binary_cross_entropy(state_pred, state[:, :-1])
    loss.backward()
    pt.opt.step()
    losses.append(loss.item())
    # Every 20 steps: redraw the loss curve and checkpoint the encoder.
    if j % 20 == 0 and j != 0:
        plt.clf()
        plt.plot(losses, label='P Loss', alpha=0.5)
        plt.legend()
        plt.title(f'P Loss: {losses[-1]: .5f}')
        plt.tight_layout()
        plt.pause(0.1)
        pt.save()
# Checkpoint once more so the steps after the last periodic save are kept.
pt.save()