# ReesSaver Discriminator Agent

## Notes:
- Roughly 90% of our boards are unique
- Every time you call generate_data it gets new games/games in a different order?

In [1]:
import chess

import chess.svg
import cv2
from IPython.display import display, SVG

import numpy as np
import random
from tqdm import tqdm
from importlib import reload
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import utils
reload(utils)
import utils


from sklearn.model_selection import train_test_split

  _C._set_default_tensor_type(t)


In [2]:
if torch.cuda.is_available():
    # Set default tensor type to CUDA tensors
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    
else:
    
    torch.set_default_tensor_type(torch.FloatTensor)
 
print(torch.cuda.is_available())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

made_loader = False

True


In [3]:
def clear_cuda():
    
    input()
    
    for obj in gc.get_objects():
        if torch.is_tensor(obj):
            if obj.is_cuda:
                print(type(obj), obj.size(), obj.device)
                del obj
    torch.cuda.empty_cache()
    gc.collect()

In [4]:
clear_cuda()

<class 'torch.Tensor'> torch.Size([64]) cuda:0




In [4]:
from utils.Datasets import *

import utils.Dataloading
reload(utils.Dataloading)
from utils.Dataloading import *
from utils.Game_playing import *

import utils.Playing_agents
reload(utils.Playing_agents)
from utils.Playing_agents import *

from utils.CSV_data import *
from utils.Puzzles import *

In [5]:
class MLPv2_1(nn.Module):

    def __init__(self):

        super().__init__()
        self.conv1 = nn.Conv2d(14, 64, 3, 1, padding=1, padding_mode = 'zeros')

        self.layers = nn.ModuleList()

        self.depth = 6

        for _ in range(self.depth):
            self.layers.append(nn.Conv2d(64,64, 3, 1, padding=1, padding_mode = 'zeros'))
            self.layers.append(nn.BatchNorm2d(64))
            self.layers.append(nn.Conv2d(64,64, 3, 1, padding=1, padding_mode = 'zeros'))
            self.layers.append(nn.BatchNorm2d(64))

        self.linear = nn.Linear(4096, 128)

    def forward(self, x):

        x = self.conv1(x)

        for i in range(self.depth):
            j = i*4
            ph = x.clone()
            ph = self.layers[j](ph)
            ph = self.layers[j+1](ph)
            ph = F.relu(ph)
            ph = self.layers[j+2](ph)
            ph = self.layers[j+3](ph)

            x = x + ph
            x = F.relu(x)


        x = torch.flatten(x, start_dim=1)

        x = self.linear(x)

        minn, ila = x[:,:64], x[:,64:]

        return minn, ila

In [6]:
RDv2 = torch.load("Models/RDv2.3 CB.pt", map_location= device)

In [18]:
boards, meta, elo, moves, _, _, fens = generate_data("./Data/GAN_human_data.pgn", N = 40_000)
elo = [int(x) for x in elo]

0it [00:00, ?it/s]
100%|██████████| 500/500 [00:20<00:00, 23.94it/s]


In [8]:
class generator_1(nn.Module):

    def __init__(self, conv_depth):

        super().__init__()
        self.conv1 = nn.Conv2d(14, 64, 3, 1, padding=1, padding_mode = 'zeros')
        
        self.conv_layers = nn.ModuleList()
        self.conv_depth = conv_depth
        
        for i in range(self.conv_depth):
            self.conv_layers.append(nn.Conv2d(64,64, 3, 1, padding=1, padding_mode = 'zeros'))
            self.conv_layers.append(nn.BatchNorm2d(64))
            self.conv_layers.append(nn.Conv2d(64,64, 3, 1, padding=1, padding_mode = 'zeros'))
            if i < self.conv_depth - 1:
                self.conv_layers.append(nn.BatchNorm2d(64))

        self.linear = nn.Linear(4096, 128)

    def forward(self, x):

        x = self.conv1(x)
        
        for i in range(self.conv_depth):
            j = i*4
            ph = x.clone()
            ph = self.conv_layers[j](ph)
            ph = self.conv_layers[j+1](ph)
            ph = F.relu(ph)
            ph = self.conv_layers[j+2](ph)
            if i < self.conv_depth - 1:
                ph = self.conv_layers[j+3](ph)
            
            x = x + ph
            x = F.relu(x)   
        
        x = torch.flatten(x, start_dim=1)

        x = self.linear(x)
        minn, ila = x[:,:64], x[:,64:]

        minn = F.softmax(minn, dim=1)
        ila = F.softmax(ila, dim=1)

        return torch.cat([minn, ila], dim=1).view(-1, 2, 8, 8)

In [9]:
class discriminator_1(nn.Module):

    def __init__(self, conv_depth):

        super().__init__()
        self.conv1 = nn.Conv2d(16, 64, 3, 1, padding=1, padding_mode = 'zeros')

        self.conv_layers = nn.ModuleList()
        self.conv_depth = conv_depth

        for _ in range(self.conv_depth):
            self.conv_layers.append(nn.Conv2d(64,64, 3, 1, padding=1, padding_mode = 'zeros'))
            self.conv_layers.append(nn.BatchNorm2d(64))
            self.conv_layers.append(nn.Conv2d(64,64, 3, 1, padding=1, padding_mode = 'zeros'))
            self.conv_layers.append(nn.BatchNorm2d(64))
            
        self.linear = nn.Linear(4096, 1)
    

    def forward(self, board, move):
        
        x = torch.cat((board, move), dim = 1)

        x = self.conv1(x)
        
        for i in range(self.conv_depth):
            j = i*4
            ph = x.clone()
            ph = self.conv_layers[j](ph)
            ph = self.conv_layers[j+1](ph)
            ph = F.leaky_relu(ph)
            ph = self.conv_layers[j+2](ph)
            ph = self.conv_layers[j+3](ph)
            
            x = x + ph
            x = F.leaky_relu(x)
                  
                  
        x = torch.flatten(x, start_dim=1)
        
        x = self.linear(x)
        x = F.sigmoid(x)
        
        return x

In [21]:
class GAN_1(nn.Module):
    #AI: 0, Human: 1
    def __init__(self, g_conv_depth, d_conv_depth, lr):
        
        super().__init__()
        
        print(device)
        
        self.generator = generator_1(g_conv_depth)
        self.discriminator = discriminator_1(d_conv_depth)
        
        self.logs = {"g_acc": [0], "d_acc_r": [0], "d_acc_f": [0], 
                     "g_loss": [0], "d_loss": [0],"d_dist_f": [0], "d_dist_r": [0], 
                     "cur_g_loss": 0, "cur_d_loss": 0}
        
        self.made_loader = False
        
        self.configure_optimizers(lr)
        
    def forward(self, x):
        return self.generator(x)
    
    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)
    
    def train_step(self, train_boards, real_moves, train_generator):
        
        #train generator
        if train_generator:

            self.opt_g.zero_grad()
            
            fake_moves = self(train_boards)
                
            y_hat = self.discriminator(train_boards, fake_moves)
            
            y = torch.ones(real_moves.size(0), 1).to(device)
            
            g_loss = self.adversarial_loss(y_hat, y)
            self.logs["cur_g_loss"] += g_loss.item()
            
            g_loss.backward()

            self.opt_g.step()
              
            
        else:

            self.opt_d.zero_grad()
            
            y_hat_real = self.discriminator(train_boards, real_moves)
            y_real = torch.ones(real_moves.size(0), 1).to(device)
            
            d_real_loss = self.adversarial_loss(y_hat_real, y_real)
            
            y_hat_fake = self.discriminator(train_boards, self(train_boards).detach())
            y_fake = torch.zeros(real_moves.size(0), 1).to(device)
            
            d_fake_loss = self.adversarial_loss(y_hat_fake, y_fake)
            
            d_loss = d_real_loss + d_fake_loss
            self.logs["cur_d_loss"] += d_loss.item()
            
            d_loss.backward()

            self.opt_d.step()
            
    
    def configure_optimizers(self, lr):
        self.lr = lr
        self.opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas = (0.5, 0.999),  weight_decay=0.0)
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr,  betas = (0.5, 0.999), weight_decay=0.0)
    
    def on_epoch_end(self, epoch, G, val_data=None):
        
        self.logs["g_loss"].append(self.logs["cur_g_loss"] / G)
        self.logs["d_loss"].append(self.logs["cur_d_loss"] / G)
        
        print(f'Epoch {epoch+1} with g_loss: {self.logs["cur_g_loss"] / G} and d_loss: {self.logs["cur_d_loss"] / G}')
        
        self.logs["cur_g_loss"] = 0
        self.logs["cur_d_loss"] = 0
        
        if epoch % 1 == 0:
            
            if val_data is not None:
                
                val_boards = val_data.bitboards
                real_moves = val_data.moves
                
                #fake_moves_ind = torch.argmax(fake_moves, dim=1)
                values = np.zeros(5)
                
                for i in range(0,len(val_boards), 100):
                
                    with torch.no_grad():
                        
                        curr_boards, curr_moves = val_boards[i:i+100], real_moves[i:i+100]
    
                        fake_moves = self(curr_boards)
                        
                        f_pred = self.discriminator(curr_boards, fake_moves)
                        r_pred = self.discriminator(curr_boards, curr_moves) 
                        
                        values[0] += torch.mean(torch.round(f_pred) == 0, dtype=torch.float).item() #d_acc_f
                        values[1] += torch.mean(torch.round(r_pred) == 1, dtype=torch.float).item() #d_acc_r
                        
                        values[2] += torch.mean(torch.abs(f_pred)) #d_dist_f
                        values[3] += torch.mean(torch.abs(1 - r_pred)) #d_dist_r
                        
                        values[4] = torch.mean((curr_moves == torch.round(fake_moves)).all(dim=1), dtype=torch.float).item() # g_acc
                
                n = int(len(val_boards) // 100)
                d_acc_f, d_acc_r = values[0] / n, values[1] / n
                d_dist_f, d_dist_r = values[2] / n, values[3] / n
                g_acc = values[4] / n
                
                print(f'Epoch: {epoch+1}, {g_acc=}, {d_acc_f=}, {d_acc_r=}')
                print(f"{d_dist_f=}, {d_dist_r=}")
                
                self.logs["d_acc_f"].append(d_acc_f)
                self.logs["d_acc_r"].append(d_acc_r)
                self.logs["d_dist_f"].append(d_dist_f)
                self.logs["d_dist_r"].append(d_dist_r)
                self.logs["g_acc"].append(g_acc)
                
                
            if epoch % 5 == 0:
            
                torch.save(self.generator, f"generator {epoch}.pt")
                torch.save(self.discriminator, f"discriminator {epoch}.pt")
            
    def create_dataloader(self, boards, meta, moves, B, N, N_val):

        if self.made_loader:

            clear_cuda()
            
        loader = DataLoader(GANData(boards[:N], meta[:N], moves[:N]), batch_size = B, shuffle = True, generator=torch.Generator(device=device))
        val_loader = GANData(boards[N:N+N_val], meta[N:N+N_val], moves[N:N+N_val])
        
        self.made_loader = True
        
        return loader, val_loader
        

        

In [24]:
class GANData(Dataset):

    def __init__(self, bitboards, white_turn, moves):

        self.bitboards = torch.tensor(bitboards, dtype = torch.float).to(device)

        self.moves = np.zeros((self.bitboards.size(dim=0), 128))

        for ind, move in tqdm(enumerate(moves), total=len(moves)):

            minn = move.from_square
            ila = move.to_square
        
            if not white_turn[ind]:
                minn = (63 - minn) // 8 * 8 + minn % 8
                ila = (63 - ila) // 8 * 8 + ila % 8

            self.moves[ind,minn] = 1
            self.moves[ind, ila + 64] = 1
            
            
        self.moves = torch.tensor(self.moves.reshape(-1,2,8,8), dtype = torch.float).to(device)

    def __len__(self):

        return self.moves.size(dim=0)


    def __getitem__(self, idx):

        return self.bitboards[idx], self.moves[idx]
    

In [67]:
torch.tensor?

In [73]:
clear_cuda()

<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.nn.parameter.Parameter'> torch.Size([64, 14, 3, 3]) cuda:0
<class 'torch.nn.parameter.Parameter'> torch.Size([64]) cuda:0
<class 'torch.nn.parameter.Parameter'> torch.Size([128, 4096]) cuda:0
<class 'torch.nn.parameter.Parameter'> torch.Size([128]) cuda:0
<class 'torch.nn.parameter.Parameter'> torch.Size([64, 64, 3, 3]) cuda:0
<class 'torch.nn.parameter.Parameter'> torch.Size([64]) cuda:0
<class 'torch.nn.parameter.Parameter'> torch.Size([64]) cuda:0
<class 'torch.nn.parameter.Parameter'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([]) cuda:0
<class 'torch.nn.parameter.Parameter'> torch.Size([64, 64, 3, 3]) cuda:0
<class 'torch.nn.parameter.Parameter'> torch.Size([64]) cuda:0
<class 'torch.nn.parameter.Parameter'> torch.Size([64]) cuda:0
<class 'torch.nn.parameter.Parameter'> torch.Size([64]) cuda:0
<class 'torch.Tensor

In [78]:
del RSv1

In [70]:
del loader

NameError: name 'loader' is not defined

In [22]:
RSv1 = GAN_1(g_conv_depth=6, d_conv_depth=6, lr=0.0002).to(device)

cuda


In [26]:
loader, val_data = RSv1.create_dataloader(boards, meta, moves, B = 512, N=1, N_val=5_000) # try B = 128
G = len(loader)

100%|██████████| 1/1 [00:00<?, ?it/s]
100%|██████████| 5000/5000 [00:00<00:00, 384523.37it/s]


In [80]:
del RSv1.discriminator

RSv1.discriminator = discriminator_1(conv_depth=4)
RSv1.configure_optimizers(0.001)

In [84]:
del RSv1.generator

RSv1.generator = generator_1(conv_depth=6)
RSv1.configure_optimizers(0.001)

In [82]:
train_discriminator = False
train_all = True

In [83]:
for epoch in range(0,50):

    reps = 0

    if train_all or train_discriminator:

        while RSv1.logs['d_acc_f'][-1] < 0.5:
            reps += 1
            if reps > 5:
                train_all = False
                train_discriminator = False
                break
            for bitboards, mvs in tqdm(loader):

                RSv1.train_step(bitboards, mvs, train_generator=False)


            RSv1.on_epoch_end(epoch, G, val_data)

    reps = 0
    if train_all or not train_discriminator:

        while RSv1.logs['d_acc_f'][-1] > 0.5:
            reps += 1
            if reps > 12:
                train_all = False
                train_discriminator = True
                break

            #i=0
            for bitboards, mvs in tqdm(loader):

               # if i > G // 8:
                #    break

                RSv1.train_step(bitboards, mvs, train_generator=True)
                #i += 1

            RSv1.on_epoch_end(epoch, G, val_data)

100%|██████████| 4493/4493 [02:59<00:00, 25.01it/s]


Epoch 1 with g_loss: 0.004287760811470996 and d_loss: 0.0
Epoch: 1, g_acc=0.01899999976158142, d_acc_f=0.0, d_acc_r=1.0
d_dist_f=0.9999998474121093, d_dist_r=2.0320058683864773e-06


100%|██████████| 4493/4493 [03:29<00:00, 21.49it/s]


Epoch 2 with g_loss: 0.0 and d_loss: 0.06635002356950542
Epoch: 2, g_acc=0.01899999976158142, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=7.648794911801815e-05, d_dist_r=0.00019360657781362535


100%|██████████| 4493/4493 [02:59<00:00, 25.10it/s]


Epoch 2 with g_loss: 0.002508534597285314 and d_loss: 0.0
Epoch: 2, g_acc=0.01882499933242798, d_acc_f=0.0, d_acc_r=1.0
d_dist_f=0.9999996185302734, d_dist_r=0.00019360657781362535


100%|██████████| 4493/4493 [03:29<00:00, 21.45it/s]


Epoch 3 with g_loss: 0.0 and d_loss: 0.015537992884172711
Epoch: 3, g_acc=0.01882499933242798, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=8.52902012411505e-07, d_dist_r=3.1524055521003902e-06


100%|██████████| 4493/4493 [02:59<00:00, 25.01it/s]


Epoch 3 with g_loss: 0.003629609857825479 and d_loss: 0.0
Epoch: 3, g_acc=0.018799999952316283, d_acc_f=0.0, d_acc_r=1.0
d_dist_f=0.9999446105957032, d_dist_r=3.1524055521003902e-06


100%|██████████| 4493/4493 [03:29<00:00, 21.46it/s]


Epoch 4 with g_loss: 0.0 and d_loss: 0.006927627948433602
Epoch: 4, g_acc=0.018799999952316283, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=8.595296822022647e-07, d_dist_r=2.106428146362305e-06


100%|██████████| 4493/4493 [02:58<00:00, 25.10it/s]


Epoch 4 with g_loss: 0.006885015155700993 and d_loss: 0.0
Epoch: 4, g_acc=0.01879687428474426, d_acc_f=0.0, d_acc_r=1.0
d_dist_f=0.9999990844726563, d_dist_r=2.106428146362305e-06


100%|██████████| 4493/4493 [03:29<00:00, 21.43it/s]


Epoch 5 with g_loss: 0.0 and d_loss: 0.011699299567960559
Epoch: 5, g_acc=0.01879687428474426, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=6.455757829826325e-07, d_dist_r=2.2595284099224953e-06


100%|██████████| 4493/4493 [02:59<00:00, 25.09it/s]


Epoch 5 with g_loss: 0.35884690350180903 and d_loss: 0.0
Epoch: 5, g_acc=0.018784373998641968, d_acc_f=0.0, d_acc_r=1.0
d_dist_f=1.0, d_dist_r=2.2595284099224953e-06


100%|██████████| 4493/4493 [03:28<00:00, 21.58it/s]


Epoch 6 with g_loss: 0.0 and d_loss: 0.02918795471404049
Epoch: 6, g_acc=0.018784373998641968, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=3.9477188693126666e-07, d_dist_r=3.194073913618922e-05


100%|██████████| 4493/4493 [03:00<00:00, 24.92it/s]


Epoch 6 with g_loss: 0.005034048886309246 and d_loss: 0.0
Epoch: 6, g_acc=0.01913749933242798, d_acc_f=0.0, d_acc_r=1.0
d_dist_f=0.9999996948242188, d_dist_r=3.194073913618922e-05


100%|██████████| 4493/4493 [03:28<00:00, 21.57it/s]


Epoch 7 with g_loss: 0.0 and d_loss: 0.02362753376944128
Epoch: 7, g_acc=0.01913749933242798, d_acc_f=0.9969999992847443, d_acc_r=1.0
d_dist_f=0.00353040874004364, d_dist_r=7.224798173410818e-07


100%|██████████| 4493/4493 [02:56<00:00, 25.39it/s]


Epoch 7 with g_loss: 0.004881928902387458 and d_loss: 0.0
Epoch: 7, g_acc=0.018928124904632568, d_acc_f=0.0, d_acc_r=1.0
d_dist_f=1.0, d_dist_r=7.224798173410818e-07


100%|██████████| 4493/4493 [03:26<00:00, 21.75it/s]


Epoch 8 with g_loss: 0.0 and d_loss: 0.015690438082670035
Epoch: 8, g_acc=0.018928124904632568, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=2.3744440113659946e-06, d_dist_r=3.36871191393584e-06


100%|██████████| 4493/4493 [02:57<00:00, 25.33it/s]


Epoch 8 with g_loss: 0.0038431868050956466 and d_loss: 0.0
Epoch: 8, g_acc=0.019128124713897705, d_acc_f=0.0, d_acc_r=1.0
d_dist_f=0.9999954223632812, d_dist_r=3.36871191393584e-06


100%|██████████| 4493/4493 [04:01<00:00, 18.64it/s]


Epoch 9 with g_loss: 0.0 and d_loss: 0.026913287200428602
Epoch: 9, g_acc=0.019128124713897705, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=4.929437767714262e-05, d_dist_r=2.8015492716804147e-06


100%|██████████| 4493/4493 [03:30<00:00, 21.31it/s]


Epoch 9 with g_loss: 0.007653867956165284 and d_loss: 0.0
Epoch: 9, g_acc=0.018868749141693116, d_acc_f=0.0, d_acc_r=1.0
d_dist_f=0.9999945831298828, d_dist_r=2.8015492716804147e-06


100%|██████████| 4493/4493 [04:05<00:00, 18.33it/s]


Epoch 10 with g_loss: 0.0 and d_loss: 0.005906360721313015
Epoch: 10, g_acc=0.018868749141693116, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=1.1656393326120451e-06, d_dist_r=5.790483555756509e-06


100%|██████████| 4493/4493 [03:31<00:00, 21.19it/s]


Epoch 10 with g_loss: 0.004655286997634937 and d_loss: 0.0
Epoch: 10, g_acc=0.01897812485694885, d_acc_f=0.0, d_acc_r=1.0
d_dist_f=0.9998670959472656, d_dist_r=5.790483555756509e-06


100%|██████████| 4493/4493 [04:03<00:00, 18.42it/s]


Epoch 11 with g_loss: 0.0 and d_loss: 0.03925020124004992
Epoch: 11, g_acc=0.01897812485694885, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=1.5041069127619267e-05, d_dist_r=2.826615236699581e-05


100%|██████████| 4493/4493 [03:30<00:00, 21.33it/s]


Epoch 11 with g_loss: 0.0029628585411492674 and d_loss: 0.0
Epoch: 11, g_acc=0.018784373998641968, d_acc_f=0.0, d_acc_r=1.0
d_dist_f=0.9999999237060547, d_dist_r=2.826615236699581e-05


100%|██████████| 4493/4493 [04:04<00:00, 18.34it/s]


Epoch 12 with g_loss: 0.0 and d_loss: 0.020135992072099484
Epoch: 12, g_acc=0.018784373998641968, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=2.5582491070963444e-06, d_dist_r=1.3029394904151558e-05


100%|██████████| 4493/4493 [03:31<00:00, 21.22it/s]


Epoch 12 with g_loss: 0.009459425657585875 and d_loss: 0.0
Epoch: 12, g_acc=0.018934375047683714, d_acc_f=0.0, d_acc_r=1.0
d_dist_f=1.0, d_dist_r=1.3029394904151558e-05


100%|██████████| 4493/4493 [04:03<00:00, 18.42it/s]


Epoch 13 with g_loss: 0.0 and d_loss: 0.02194856524546265
Epoch: 13, g_acc=0.018934375047683714, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=0.00015341638587415218, d_dist_r=1.783411717042327e-05


100%|██████████| 4493/4493 [03:30<00:00, 21.33it/s]


Epoch 13 with g_loss: 0.007010040686463505 and d_loss: 0.0
Epoch: 13, g_acc=0.018793749809265136, d_acc_f=0.0, d_acc_r=1.0
d_dist_f=0.9999958038330078, d_dist_r=1.783411717042327e-05


100%|██████████| 4493/4493 [04:05<00:00, 18.32it/s]


Epoch 14 with g_loss: 0.0 and d_loss: 0.01652132191709836
Epoch: 14, g_acc=0.018793749809265136, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=1.5006354078650475e-06, d_dist_r=3.877340932376683e-06


100%|██████████| 4493/4493 [03:31<00:00, 21.22it/s]


Epoch 14 with g_loss: 0.009511534265145833 and d_loss: 0.0
Epoch: 14, g_acc=0.018790624141693114, d_acc_f=0.0, d_acc_r=1.0
d_dist_f=0.9998432922363282, d_dist_r=3.877340932376683e-06


100%|██████████| 4493/4493 [04:03<00:00, 18.44it/s]


Epoch 15 with g_loss: 0.0 and d_loss: 0.007855906917346842
Epoch: 15, g_acc=0.018790624141693114, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=2.0268000662326815e-06, d_dist_r=2.829945005942136e-06


100%|██████████| 4493/4493 [03:30<00:00, 21.31it/s]


Epoch 15 with g_loss: 0.0069315352704833666 and d_loss: 0.0
Epoch: 15, g_acc=0.018793749809265136, d_acc_f=0.0, d_acc_r=1.0
d_dist_f=0.9992104339599609, d_dist_r=2.829945005942136e-06


100%|██████████| 4493/4493 [04:05<00:00, 18.31it/s]


Epoch 16 with g_loss: 0.0 and d_loss: 0.0023110087825357942
Epoch: 16, g_acc=0.018793749809265136, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=1.1533544602571054e-06, d_dist_r=4.20694297645241e-06


100%|██████████| 4493/4493 [03:31<00:00, 21.28it/s]


Epoch 16 with g_loss: 0.009541713942848728 and d_loss: 0.0
Epoch: 16, g_acc=0.018774999380111693, d_acc_f=0.0, d_acc_r=1.0
d_dist_f=0.9998307800292969, d_dist_r=4.20694297645241e-06


100%|██████████| 4493/4493 [04:05<00:00, 18.31it/s]


Epoch 17 with g_loss: 0.0 and d_loss: 0.0028175072504440807
Epoch: 17, g_acc=0.018774999380111693, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=3.256241370763746e-08, d_dist_r=1.2934565893374384e-06


100%|██████████| 4493/4493 [03:30<00:00, 21.33it/s]


Epoch 17 with g_loss: 17.426822918336427 and d_loss: 0.0
Epoch: 17, g_acc=0.018774999380111693, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=3.255334149798727e-08, d_dist_r=1.2934565893374384e-06


100%|██████████| 4493/4493 [03:32<00:00, 21.16it/s]


Epoch 17 with g_loss: 1.780792792977897 and d_loss: 0.0
Epoch: 17, g_acc=0.018806250095367433, d_acc_f=0.015599999837577342, d_acc_r=1.0
d_dist_f=0.9692221832275391, d_dist_r=1.2934565893374384e-06


100%|██████████| 4493/4493 [04:04<00:00, 18.41it/s]


Epoch 18 with g_loss: 0.0 and d_loss: 0.0021880602313824266
Epoch: 18, g_acc=0.018806250095367433, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=2.1156677394174038e-06, d_dist_r=4.462921351660043e-06


100%|██████████| 4493/4493 [03:30<00:00, 21.31it/s]


Epoch 18 with g_loss: 6.574983815745683 and d_loss: 0.0
Epoch: 18, g_acc=0.018793749809265136, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=0.004409245252609253, d_dist_r=4.462921351660043e-06


100%|██████████| 4493/4493 [03:32<00:00, 21.16it/s]


Epoch 18 with g_loss: 6.521637100539811 and d_loss: 0.0
Epoch: 18, g_acc=0.018793749809265136, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=0.004733724594116211, d_dist_r=4.462921351660043e-06


100%|██████████| 4493/4493 [03:32<00:00, 21.17it/s]


Epoch 18 with g_loss: 6.489567905657811 and d_loss: 0.0
Epoch: 18, g_acc=0.018793749809265136, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=0.004991126358509064, d_dist_r=4.462921351660043e-06


100%|██████████| 4493/4493 [03:31<00:00, 21.29it/s]


Epoch 18 with g_loss: 6.445286603168154 and d_loss: 0.0
Epoch: 18, g_acc=0.018793749809265136, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=0.004489080905914307, d_dist_r=4.462921351660043e-06


100%|██████████| 4493/4493 [03:30<00:00, 21.30it/s]


Epoch 18 with g_loss: 6.466447405472328 and d_loss: 0.0
Epoch: 18, g_acc=0.018793749809265136, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=0.005122327208518982, d_dist_r=4.462921351660043e-06


100%|██████████| 4493/4493 [03:32<00:00, 21.17it/s]


Epoch 18 with g_loss: 6.413994958080748 and d_loss: 0.0
Epoch: 18, g_acc=0.018793749809265136, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=0.005326629877090454, d_dist_r=4.462921351660043e-06


100%|██████████| 4493/4493 [03:31<00:00, 21.20it/s]


Epoch 18 with g_loss: 6.331321767993369 and d_loss: 0.0
Epoch: 18, g_acc=0.018793749809265136, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=0.005419407486915589, d_dist_r=4.462921351660043e-06


100%|██████████| 4493/4493 [03:30<00:00, 21.30it/s]


Epoch 18 with g_loss: 6.284683216607785 and d_loss: 0.0
Epoch: 18, g_acc=0.018793749809265136, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=0.005441796183586121, d_dist_r=4.462921351660043e-06


100%|██████████| 4493/4493 [03:31<00:00, 21.27it/s]


Epoch 18 with g_loss: 6.255509234843794 and d_loss: 0.0
Epoch: 18, g_acc=0.018793749809265136, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=0.00540615439414978, d_dist_r=4.462921351660043e-06


100%|██████████| 4493/4493 [03:33<00:00, 21.04it/s]


Epoch 18 with g_loss: 6.241311721644687 and d_loss: 0.0
Epoch: 18, g_acc=0.018793749809265136, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=0.005747881531715393, d_dist_r=4.462921351660043e-06


100%|██████████| 4493/4493 [03:31<00:00, 21.27it/s]


Epoch 18 with g_loss: 6.184393769299668 and d_loss: 0.0
Epoch: 18, g_acc=0.018793749809265136, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=0.005931388139724731, d_dist_r=4.462921351660043e-06


100%|██████████| 4493/4493 [03:32<00:00, 21.15it/s]


Epoch 18 with g_loss: 6.1839377320266475 and d_loss: 0.0
Epoch: 18, g_acc=0.018793749809265136, d_acc_f=1.0, d_acc_r=1.0
d_dist_f=0.005398737788200378, d_dist_r=4.462921351660043e-06


In [135]:
val_boards = val_data.bitboards
real_moves = val_data.moves

minn, ila = RDv2(val_boards)
fake_moves = torch.cat([F.softmax(minn, dim=1), F.softmax(ila,dim=1)], dim=1)
torch.mean((real_moves == torch.round(fake_moves)).all(dim=1), dtype=torch.float).item()

0.2510000169277191

In [30]:
old_generator = torch.load("generator 5.pt")
#old_discriminator = torch.load("discriminator 15.pt")


In [31]:
torch.mean(old_discriminator(val_data.bitboards[:1000], old_generator(val_data.bitboards[:1000])))

tensor(1.0000, grad_fn=<MeanBackward0>)

In [137]:
fake_moves = RSv1(val_boards)
print(fake_moves[0])

#fake_moves_ind = torch.argmax(fake_moves, dim=1)
print(torch.round(fake_moves[0]))

g_acc = torch.mean((real_moves == torch.round(fake_moves)).all(dim=1), dtype=torch.float).item()

tensor([1.7404e-42, 0.0000e+00, 2.2697e-41, 2.9514e-19, 2.4398e-16, 4.8634e-12,
        9.2207e-27, 3.0531e-33, 0.0000e+00, 0.0000e+00, 2.2381e-24, 2.3693e-12,
        1.0543e-04, 2.4633e-19, 4.6248e-19, 9.9989e-01, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 1.0044e-09, 2.2264e-06, 1.8982e-10, 8.1555e-16, 6.4270e-11,
        1.2278e-31, 0.0000e+00, 1.0315e-31, 1.0168e-20, 5.1734e-15, 8.6261e-07,
        3.3822e-08, 5.3728e-09, 0.0000e+00, 0.0000e+00, 8.4442e-37, 1.1300e-11,
        1.6480e-18, 1.1017e-12, 9.6544e-10, 5.9234e-10, 0.0000e+00, 0.0000e+00,
        3.2209e-39, 1.5248e-36, 2.6022e-42, 5.4401e-41, 0.0000e+00, 4.2039e-45,
        0.0000e+00, 0.0000e+00, 7.8557e-42, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+

In [15]:
test_against(lambda x: network_agent_prob_conv(x, RDv2), generator_model, N=100)

100%|██████████| 50/50 [00:33<00:00,  1.51it/s]
100%|██████████| 50/50 [00:33<00:00,  1.50it/s]


(97, 0, 3, 0.97)

In [14]:
generator_model = lambda x: network_agent_prob_conv(x, lambda y: old_generator(y).reshape(-1,128))

In [15]:
torch.cuda.memory_allocated() 

7274496