# ReesSaver Discriminator Agent

## Notes:
- Roughly 90% of our boards are unique
- Every time you call generate_data it gets new games/games in a different order?

In [1]:
import chess

import chess.svg
import cv2
from IPython.display import display, SVG

import numpy as np
import random
from tqdm import tqdm
from importlib import reload
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import utils
reload(utils)
import utils


from sklearn.model_selection import train_test_split

  _C._set_default_tensor_type(t)


In [2]:
if torch.cuda.is_available():
    # Set default tensor type to CUDA tensors
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    
else:
    
    torch.set_default_tensor_type(torch.FloatTensor)
 
print(torch.cuda.is_available())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

made_loader = False

True


In [3]:
def clear_cuda():
    
    input()
    
    for obj in gc.get_objects():
        if torch.is_tensor(obj):
            if obj.is_cuda:
                print(type(obj), obj.size(), obj.device)
                del obj
    torch.cuda.empty_cache()
    gc.collect()

In [4]:
clear_cuda()

<class 'torch.Tensor'> torch.Size([64]) cuda:0




In [32]:
from utils.Datasets import *

import utils.Dataloading
reload(utils.Dataloading)
from utils.Dataloading import *
from utils.Game_playing import *

import utils.Playing_agents
reload(utils.Playing_agents)
from utils.Playing_agents import *

from utils.CSV_data import *
from utils.Puzzles import *

In [34]:
class MLPv2_1(nn.Module):

    def __init__(self):

        super().__init__()
        self.conv1 = nn.Conv2d(14, 64, 3, 1, padding=1, padding_mode = 'zeros')

        self.layers = nn.ModuleList()

        self.depth = 6

        for _ in range(self.depth):
            self.layers.append(nn.Conv2d(64,64, 3, 1, padding=1, padding_mode = 'zeros'))
            self.layers.append(nn.BatchNorm2d(64))
            self.layers.append(nn.Conv2d(64,64, 3, 1, padding=1, padding_mode = 'zeros'))
            self.layers.append(nn.BatchNorm2d(64))

        self.linear = nn.Linear(4096, 128)

    def forward(self, x):

        x = self.conv1(x)

        for i in range(self.depth):
            j = i*4
            ph = x.clone()
            ph = self.layers[j](ph)
            ph = self.layers[j+1](ph)
            ph = F.relu(ph)
            ph = self.layers[j+2](ph)
            ph = self.layers[j+3](ph)

            x = x + ph
            x = F.relu(x)


        x = torch.flatten(x, start_dim=1)

        x = self.linear(x)

        minn, ila = x[:,:64], x[:,64:]

        return minn, ila

In [35]:
RDv2 = torch.load("Models/RDv2.3 CB.pt", map_location= device)

In [6]:
boards, meta, elo, moves, _, _, fens = generate_data("./Data/GAN_human_data.pgn", N = 40_000)
elo = [int(x) for x in elo]

0it [00:00, ?it/s]
100%|██████████| 40000/40000 [18:08<00:00, 36.74it/s]  


In [7]:
class generator_1(nn.Module):

    def __init__(self, conv_depth, hidden_size, hidden_depth):

        super().__init__()
        self.conv1 = nn.Conv2d(14, 64, 3, 1, padding=1, padding_mode = 'zeros')
        
        self.conv_layers = nn.ModuleList()
        self.hidden_layers = nn.ModuleList()
        
        self.conv_depth = conv_depth
        self.hidden_depth = hidden_depth
        self.hidden_size = hidden_size
        
        for _ in range(self.conv_depth):
            self.conv_layers.append(nn.Conv2d(64,64, 3, 1, padding=1, padding_mode = 'zeros'))
            self.conv_layers.append(nn.BatchNorm2d(64))
            self.conv_layers.append(nn.Conv2d(64,64, 3, 1, padding=1, padding_mode = 'zeros'))
            self.conv_layers.append(nn.BatchNorm2d(64))

        self.hidden_layers.append(nn.Linear(4096, hidden_size))
        
        for _ in range(self.hidden_depth - 2):

            self.hidden_layers.append(nn.Linear(hidden_size, hidden_size))

        self.hidden_layers.append(nn.Linear(hidden_size, 128))

    def forward(self, x):

        x = self.conv1(x)
        
        for i in range(self.conv_depth):
            j = i*4
            ph = x.clone()
            ph = self.conv_layers[j](ph)
            ph = self.conv_layers[j+1](ph)
            ph = F.relu(ph)
            ph = self.conv_layers[j+2](ph)
            ph = self.conv_layers[j+3](ph)
            
            x = x + ph
            x = F.relu(x)   
        
        x = torch.flatten(x, start_dim=1)
        
        for i in range(self.hidden_depth - 1):
            
            x = self.hidden_layers[i](x)
            x = F.leaky_relu(x)

        x = self.hidden_layers[-1](x)
        minn, ila = x[:,:64], x[:,64:]

        minn = F.softmax(minn, dim=1)
        ila = F.softmax(ila, dim=1)

        return torch.cat([minn, ila], dim=1)

In [9]:
class discriminator_1(nn.Module):

    def __init__(self, conv_depth, hidden_size, hidden_depth):

        super().__init__()
        self.conv1 = nn.Conv2d(14, 64, 3, 1, padding=1, padding_mode = 'zeros')

        self.conv_layers = nn.ModuleList()
        self.hidden_layers = nn.ModuleList()

        self.conv_depth = conv_depth
        self.hidden_depth = hidden_depth
        self.hidden_size = hidden_size

        for _ in range(self.conv_depth):
            self.conv_layers.append(nn.Conv2d(64,64, 3, 1, padding=1, padding_mode = 'zeros'))
            self.conv_layers.append(nn.BatchNorm2d(64))
            self.conv_layers.append(nn.Conv2d(64,64, 3, 1, padding=1, padding_mode = 'zeros'))
            self.conv_layers.append(nn.BatchNorm2d(64))

        self.hidden_layers.append(nn.Linear(4096, hidden_size))
        
        self.hidden_layers.append(nn.Linear(hidden_size + 128, hidden_size))

        for _ in range(self.hidden_depth - 3):

            self.hidden_layers.append(nn.Linear(hidden_size, hidden_size))

        self.hidden_layers.append(nn.Linear(hidden_size, 1))
    

    def forward(self, board, move):

        x = self.conv1(board)
        
        for i in range(self.conv_depth):
            j = i*4
            ph = x.clone()
            ph = self.conv_layers[j](ph)
            ph = self.conv_layers[j+1](ph)
            ph = F.relu(ph)
            ph = self.conv_layers[j+2](ph)
            ph = self.conv_layers[j+3](ph)
            
            x = x + ph
            x = F.relu(x)
                  
                  
        x = torch.flatten(x, start_dim=1)

        for i in range(self.hidden_depth - 1):

            if i == 1:
                x = torch.cat((x, move), dim=1)
                
            x = self.hidden_layers[i](x)
            x = F.leaky_relu(x)

        x = self.hidden_layers[-1](x)
        x = F.sigmoid(x)
        
        return x

In [126]:
class GAN_1(nn.Module):
    #AI: 0, Human: 1
    def __init__(self, g_conv_depth, g_hidden_size, g_hidden_depth, d_conv_depth, d_hidden_size, d_hidden_depth, lr):
        
        super().__init__()
        
        print(device)
        
        self.generator = generator_1(g_conv_depth, g_hidden_size, g_hidden_depth)
        self.discriminator = discriminator_1(d_conv_depth, d_hidden_size, d_hidden_depth)
        
        self.logs = {"g_acc": [], "d_acc_r": [], "d_acc_f": [], 
                     "g_loss": [], "d_loss": [], 
                     "cur_g_loss": 0, "cur_d_loss": 0}
        
        self.made_loader = False
        
        self.configure_optimizers(lr)
        
    def forward(self, x):
        return self.generator(x)
    
    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)
    
    def train_step(self, train_boards, real_moves, train_generator):
        
        #train generator
        if train_generator:

            self.opt_g.zero_grad()
            
            fake_moves = self(train_boards)
                
            y_hat = self.discriminator(train_boards, fake_moves)
            
            y = torch.ones(real_moves.size(0), 1).to(device)
            
            g_loss = self.adversarial_loss(y_hat, y)
            self.logs["cur_g_loss"] += g_loss.item()
            
            g_loss.backward()

            self.opt_g.step()
              
            
        else:

            self.opt_d.zero_grad()
            
            y_hat_real = self.discriminator(train_boards, real_moves)
            y_real = torch.ones(real_moves.size(0), 1).to(device)
            
            d_real_loss = self.adversarial_loss(y_hat_real, y_real)
            
            y_hat_fake = self.discriminator(train_boards, self(train_boards).detach())
            y_fake = torch.zeros(real_moves.size(0), 1).to(device)
            
            d_fake_loss = self.adversarial_loss(y_hat_fake, y_fake)
            
            d_loss = d_real_loss + d_fake_loss
            self.logs["cur_d_loss"] += d_loss.item()
            
            d_loss.backward()

            self.opt_d.step()
            
    
    def configure_optimizers(self, lr):
        self.lr = lr
        self.opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, weight_decay=0.0001)
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr,weight_decay=0.0001)
    
    def on_epoch_end(self, epoch, G, val_data=None):
        
        self.logs["g_loss"].append(self.logs["cur_g_loss"] / G)
        self.logs["d_loss"].append(self.logs["cur_d_loss"] / G)
        
        print(f'Epoch {epoch+1} with g_loss: {self.logs["cur_g_loss"] / G} and d_loss: {self.logs["cur_d_loss"] / G}')
        
        self.logs["cur_g_loss"] = 0
        self.logs["cur_d_loss"] = 0
        
        if epoch % 1 == 0:
            
            if val_data is not None:
                
                val_boards = val_data.bitboards
                real_moves = val_data.moves
                
                fake_moves = self(val_boards)
                #fake_moves_ind = torch.argmax(fake_moves, dim=1)
                
                with torch.no_grad():
                    f_pred = self.discriminator(val_boards, fake_moves)
                    r_pred = self.discriminator(val_boards, real_moves)
                    
                    d_acc_f = torch.mean(torch.round(f_pred) == 0, dtype=torch.float).item()
                    d_acc_r = torch.mean(torch.round(r_pred) == 1, dtype=torch.float).item()
                    
                    d_dist_f = torch.mean(torch.abs(f_pred))
                    d_dist_r = torch.mean(torch.abs(1 - r_pred))
                    
                    g_acc = torch.mean((real_moves == torch.round(fake_moves)).all(dim=1), dtype=torch.float).item()
                
                print(f'Epoch: {epoch+1}, {g_acc=}, {d_acc_f=}, {d_acc_r=}')
                print(f"{d_dist_f=}, {d_dist_r=}")
                
                self.logs["d_acc_f"].append(d_acc_f)
                self.logs["d_acc_r"].append(d_acc_r)
                self.logs["d_dist_f"].append(d_dist_f)
                self.logs["d_dist_r"].append(d_dist_r)
                self.logs["g_acc"].append(g_acc)
                
        
            torch.save(self.generator, f"generator {epoch}.pt")
            torch.save(self.discriminator, f"discriminator {epoch}.pt")
            
    def create_dataloader(self, boards, meta, moves, B, N, N_val):

        if self.made_loader:

            clear_cuda()
            
        loader = DataLoader(GANData(boards[:N], meta[:N], moves[:N]), batch_size = B, shuffle = True, generator=torch.Generator(device=device))
        val_loader = GANData(boards[N:N+N_val], meta[N:N+N_val], moves[N:N+N_val])
        
        self.made_loader = True
        
        return loader, val_loader
        

        

In [11]:
class GANData(Dataset):

    def __init__(self, bitboards, white_turn, moves):

        self.bitboards = torch.tensor(bitboards, dtype = torch.float).to(device)

        self.moves = torch.zeros((self.bitboards.size(dim=0), 128), dtype = torch.float).to(device)

        for ind, move in tqdm(enumerate(moves), total=len(moves)):

            minn = move.from_square
            ila = move.to_square
        
            if not white_turn[ind]:
                minn = (63 - minn) // 8 * 8 + minn % 8
                ila = (63 - ila) // 8 * 8 + ila % 8

            self.moves[ind,minn] = 1
            self.moves[ind, ila + 64] = 1


    def __len__(self):

        return self.moves.size(dim=0)


    def __getitem__(self, idx):

        return self.bitboards[idx], self.moves[idx]
    

In [16]:
clear_cuda()

<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([]) cuda:0
<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([]) cuda:0
<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([]) cuda:0
<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([]) cuda:0
<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([]) cuda:0
<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([]) cuda:0
<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> torch.Size([64]) cuda:0
<class 'torch.Tensor'> to

In [93]:
del RSv1

In [94]:
RSv1 = GAN_1(g_conv_depth=6, g_hidden_size=1024, g_hidden_depth=2, 
             d_conv_depth=4, d_hidden_size=512, d_hidden_depth=3,
             lr=0.001).to(device)

cuda


In [17]:
loader, val_data = RSv1.create_dataloader(boards, meta, moves, B = 512, N=2_300_000, N_val=1_000)
G = len(loader)

100%|██████████| 2300000/2300000 [02:18<00:00, 16665.51it/s]
100%|██████████| 1000/1000 [00:00<00:00, 13151.30it/s]


In [80]:
del RSv1.discriminator

RSv1.discriminator = discriminator_1(conv_depth=4, hidden_size=512, hidden_depth=3)
RSv1.configure_optimizers(0.001)

In [84]:
del RSv1.generator

RSv1.generator = generator_1(conv_depth=6, hidden_size=1024, hidden_depth=2)
RSv1.configure_optimizers(0.001)

In [75]:
generator_1?

In [110]:
train_discriminator = False
train_all = True

In [142]:
for epoch in range(51,150):

    reps = 0

    if train_all or train_discriminator:

        while RSv1.logs['d_acc_f'][-1] < 0.5:
            reps += 1
            if reps > 5:
                train_all = False
                train_discriminator = False
                break
            for bitboards, mvs in tqdm(loader):

                RSv1.train_step(bitboards, mvs, train_generator=False)


            RSv1.on_epoch_end(epoch, G, val_data)

    reps = 0
    if train_all or not train_discriminator:

        while RSv1.logs['d_acc_f'][-1] > 0.5:
            reps += 1
            if reps > 12:
                train_all = False
                train_discriminator = True
                break

            i=0
            for bitboards, mvs in tqdm(loader):

                if i > G // 8:
                    break

                RSv1.train_step(bitboards, mvs, train_generator=True)
                i += 1

            RSv1.on_epoch_end(epoch, G, val_data)

100%|██████████| 4493/4493 [03:03<00:00, 24.45it/s]


Epoch 52 with g_loss: 0.0 and d_loss: 0.028107999774310815
Epoch: 52, g_acc=0.003000000026077032, d_acc_f=1.0, d_acc_r=0.9970000386238098


 13%|█▎        | 562/4493 [00:23<02:41, 24.37it/s]


Epoch 52 with g_loss: 0.002368966074145966 and d_loss: 0.0
Epoch: 52, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.9970000386238098


100%|██████████| 4493/4493 [03:02<00:00, 24.59it/s]


Epoch 53 with g_loss: 0.0 and d_loss: 0.010856575276686837
Epoch: 53, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:49, 23.13it/s]


Epoch 53 with g_loss: 0.004162909363405023 and d_loss: 0.0
Epoch: 53, g_acc=0.0010000000474974513, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.62it/s]


Epoch 54 with g_loss: 0.0 and d_loss: 0.009535298077133086
Epoch: 54, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:22<02:40, 24.44it/s]


Epoch 54 with g_loss: 0.012177055214498347 and d_loss: 0.0
Epoch: 54, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.999000072479248


100%|██████████| 4493/4493 [03:03<00:00, 24.43it/s]


Epoch 55 with g_loss: 0.0 and d_loss: 0.005395235253465298
Epoch: 55, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:22<02:40, 24.44it/s]


Epoch 55 with g_loss: 0.047325344803845044 and d_loss: 0.0
Epoch: 55, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.58it/s]


Epoch 56 with g_loss: 0.0 and d_loss: 0.009966958911728717
Epoch: 56, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:22<02:40, 24.48it/s]


Epoch 56 with g_loss: 0.001928559547894389 and d_loss: 0.0
Epoch: 56, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:04<00:00, 24.40it/s]


Epoch 57 with g_loss: 0.0 and d_loss: 0.006461976194049765
Epoch: 57, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.34it/s]


Epoch 57 with g_loss: 0.007140628914133907 and d_loss: 0.0
Epoch: 57, g_acc=0.0010000000474974513, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.58it/s]


Epoch 58 with g_loss: 0.0 and d_loss: 0.010481936920442216
Epoch: 58, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:41, 24.38it/s]


Epoch 58 with g_loss: 1.1081478962409 and d_loss: 0.0
Epoch: 58, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:24<02:50, 23.04it/s]


Epoch 58 with g_loss: 0.41066834716664846 and d_loss: 0.0
Epoch: 58, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.999000072479248


100%|██████████| 4493/4493 [03:02<00:00, 24.59it/s]


Epoch 59 with g_loss: 0.0 and d_loss: 0.01521147516108586
Epoch: 59, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.41it/s]


Epoch 59 with g_loss: 0.05443762750500293 and d_loss: 0.0
Epoch: 59, g_acc=0.0010000000474974513, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.57it/s]


Epoch 60 with g_loss: 0.0 and d_loss: 0.015906701095104607
Epoch: 60, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:24<02:49, 23.18it/s]


Epoch 60 with g_loss: 0.6081862187570507 and d_loss: 0.0
Epoch: 60, g_acc=0.0020000000949949026, d_acc_f=0.0, d_acc_r=0.999000072479248


100%|██████████| 4493/4493 [03:02<00:00, 24.59it/s]


Epoch 61 with g_loss: 0.0 and d_loss: 0.018672624747590922
Epoch: 61, g_acc=0.0020000000949949026, d_acc_f=1.0, d_acc_r=0.9980000257492065


 13%|█▎        | 562/4493 [00:22<02:40, 24.49it/s]


Epoch 61 with g_loss: 1.0704600671445132 and d_loss: 0.0
Epoch: 61, g_acc=0.0020000000949949026, d_acc_f=1.0, d_acc_r=0.9980000257492065


 13%|█▎        | 562/4493 [00:24<02:50, 23.08it/s]


Epoch 61 with g_loss: 1.070576541067735 and d_loss: 0.0
Epoch: 61, g_acc=0.0020000000949949026, d_acc_f=1.0, d_acc_r=0.9980000257492065


 13%|█▎        | 562/4493 [00:23<02:41, 24.38it/s]


Epoch 61 with g_loss: 0.17362352871768202 and d_loss: 0.0
Epoch: 61, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.9980000257492065


100%|██████████| 4493/4493 [03:02<00:00, 24.57it/s]


Epoch 62 with g_loss: 0.0 and d_loss: 0.008569490114174495
Epoch: 62, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.34it/s]


Epoch 62 with g_loss: 0.8775896160689637 and d_loss: 0.0
Epoch: 62, g_acc=0.003000000026077032, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:03<00:00, 24.42it/s]


Epoch 63 with g_loss: 0.0 and d_loss: 0.008683272759934352
Epoch: 63, g_acc=0.003000000026077032, d_acc_f=1.0, d_acc_r=0.9970000386238098


 13%|█▎        | 562/4493 [00:23<02:40, 24.42it/s]


Epoch 63 with g_loss: 0.9456287759590223 and d_loss: 0.0
Epoch: 63, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.9970000386238098


100%|██████████| 4493/4493 [03:03<00:00, 24.53it/s]


Epoch 64 with g_loss: 0.0 and d_loss: 0.0023455811389773234
Epoch: 64, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:22<02:40, 24.45it/s]


Epoch 64 with g_loss: 1.2729282939510995 and d_loss: 0.0
Epoch: 64, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.38it/s]


Epoch 64 with g_loss: 0.7203488752850931 and d_loss: 0.0
Epoch: 64, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:04<00:00, 24.40it/s]


Epoch 65 with g_loss: 0.0 and d_loss: 0.010984087497807283
Epoch: 65, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.29it/s]


Epoch 65 with g_loss: 0.006057196499681569 and d_loss: 0.0
Epoch: 65, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.57it/s]


Epoch 66 with g_loss: 0.0 and d_loss: 0.165339949204678
Epoch: 66, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:50, 23.02it/s]


Epoch 66 with g_loss: 1.1637175398893884 and d_loss: 0.0
Epoch: 66, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:42, 24.20it/s]


Epoch 66 with g_loss: 1.1630678382769952 and d_loss: 0.0
Epoch: 66, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.32it/s]


Epoch 66 with g_loss: 1.1633727831217537 and d_loss: 0.0
Epoch: 66, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.40it/s]


Epoch 66 with g_loss: 1.1635279382705264 and d_loss: 0.0
Epoch: 66, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:49, 23.15it/s]


Epoch 66 with g_loss: 1.1640877094349458 and d_loss: 0.0
Epoch: 66, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:22<02:40, 24.48it/s]


Epoch 66 with g_loss: 1.1634356197099922 and d_loss: 0.0
Epoch: 66, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:22<02:40, 24.44it/s]


Epoch 66 with g_loss: 0.8753813482901318 and d_loss: 0.0
Epoch: 66, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.56it/s]


Epoch 67 with g_loss: 0.0 and d_loss: 0.009774592416976735
Epoch: 67, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:50, 23.11it/s]


Epoch 67 with g_loss: 0.0028655421839889955 and d_loss: 0.0
Epoch: 67, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.57it/s]


Epoch 68 with g_loss: 0.0 and d_loss: 0.010596726745413662
Epoch: 68, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:22<02:40, 24.49it/s]


Epoch 68 with g_loss: 0.7622906274676088 and d_loss: 0.0
Epoch: 68, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.57it/s]


Epoch 69 with g_loss: 0.0 and d_loss: 0.005158627062757869
Epoch: 69, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:50, 23.10it/s]


Epoch 69 with g_loss: 1.017652838501475 and d_loss: 0.0
Epoch: 69, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:03<00:00, 24.55it/s]


Epoch 70 with g_loss: 0.0 and d_loss: 0.012547929510349706
Epoch: 70, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:22<02:40, 24.45it/s]


Epoch 70 with g_loss: 0.36643778097513874 and d_loss: 0.0
Epoch: 70, g_acc=0.003000000026077032, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:04<00:00, 24.35it/s]


Epoch 71 with g_loss: 0.0 and d_loss: 0.02535790947568608
Epoch: 71, g_acc=0.003000000026077032, d_acc_f=1.0, d_acc_r=0.9970000386238098


 13%|█▎        | 562/4493 [00:23<02:41, 24.41it/s]


Epoch 71 with g_loss: 1.0398685318203724 and d_loss: 0.0
Epoch: 71, g_acc=0.003000000026077032, d_acc_f=1.0, d_acc_r=0.9970000386238098


 13%|█▎        | 562/4493 [00:22<02:40, 24.45it/s]


Epoch 71 with g_loss: 1.0406282666051625 and d_loss: 0.0
Epoch: 71, g_acc=0.003000000026077032, d_acc_f=1.0, d_acc_r=0.9970000386238098


 13%|█▎        | 562/4493 [00:23<02:41, 24.40it/s]


Epoch 71 with g_loss: 1.0402255201562591 and d_loss: 0.0
Epoch: 71, g_acc=0.003000000026077032, d_acc_f=1.0, d_acc_r=0.9970000386238098


 13%|█▎        | 562/4493 [00:22<02:40, 24.45it/s]


Epoch 71 with g_loss: 0.4623212639384192 and d_loss: 0.0
Epoch: 71, g_acc=0.0, d_acc_f=0.01600000075995922, d_acc_r=0.9970000386238098


100%|██████████| 4493/4493 [03:03<00:00, 24.43it/s]


Epoch 72 with g_loss: 0.0 and d_loss: 0.001835285983805973
Epoch: 72, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.40it/s]


Epoch 72 with g_loss: 1.1972540013381003 and d_loss: 0.0
Epoch: 72, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.30it/s]


Epoch 72 with g_loss: 1.1972559933775972 and d_loss: 0.0
Epoch: 72, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.38it/s]


Epoch 72 with g_loss: 1.1974541566272687 and d_loss: 0.0
Epoch: 72, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:50, 23.05it/s]


Epoch 72 with g_loss: 1.1971344303612503 and d_loss: 0.0
Epoch: 72, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.39it/s]


Epoch 72 with g_loss: 1.1972563911487233 and d_loss: 0.0
Epoch: 72, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.40it/s]


Epoch 72 with g_loss: 0.971340102064768 and d_loss: 0.0
Epoch: 72, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:03<00:00, 24.53it/s]


Epoch 73 with g_loss: 0.0 and d_loss: 0.008394144573948633
Epoch: 73, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:50, 23.08it/s]


Epoch 73 with g_loss: 1.1411118769523643 and d_loss: 0.0
Epoch: 73, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:22<02:40, 24.46it/s]


Epoch 73 with g_loss: 0.5178125482000744 and d_loss: 0.0
Epoch: 73, g_acc=0.0010000000474974513, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.60it/s]


Epoch 74 with g_loss: 0.0 and d_loss: 0.012654568673651573
Epoch: 74, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:41, 24.37it/s]


Epoch 74 with g_loss: 0.98869633488896 and d_loss: 0.0
Epoch: 74, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:24<02:49, 23.17it/s]


Epoch 74 with g_loss: 0.9886507930666467 and d_loss: 0.0
Epoch: 74, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:22<02:40, 24.46it/s]


Epoch 74 with g_loss: 0.9885679722999693 and d_loss: 0.0
Epoch: 74, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:22<02:40, 24.47it/s]


Epoch 74 with g_loss: 0.9886072188741827 and d_loss: 0.0
Epoch: 74, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:22<02:40, 24.45it/s]


Epoch 74 with g_loss: 0.21473103944185665 and d_loss: 0.0
Epoch: 74, g_acc=0.005000000353902578, d_acc_f=0.0, d_acc_r=0.999000072479248


100%|██████████| 4493/4493 [03:02<00:00, 24.56it/s]


Epoch 75 with g_loss: 0.0 and d_loss: 0.03154092860322001
Epoch: 75, g_acc=0.005000000353902578, d_acc_f=1.0, d_acc_r=0.9950000643730164


 13%|█▎        | 562/4493 [00:24<02:49, 23.20it/s]


Epoch 75 with g_loss: 0.0031398099924384807 and d_loss: 0.0
Epoch: 75, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.9950000643730164


100%|██████████| 4493/4493 [03:02<00:00, 24.58it/s]


Epoch 76 with g_loss: 0.0 and d_loss: 0.03893615529266071
Epoch: 76, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.42it/s]


Epoch 76 with g_loss: 0.0017171251076108598 and d_loss: 0.0
Epoch: 76, g_acc=0.0010000000474974513, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:11<00:00, 23.52it/s]


Epoch 77 with g_loss: 0.0 and d_loss: 0.018082237938021935
Epoch: 77, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:46, 23.66it/s]


Epoch 77 with g_loss: 0.0018778061369892792 and d_loss: 0.0
Epoch: 77, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.999000072479248


100%|██████████| 4493/4493 [03:09<00:00, 23.71it/s]


Epoch 78 with g_loss: 0.0 and d_loss: 0.028661496748077362
Epoch: 78, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:47, 23.53it/s]


Epoch 78 with g_loss: 0.0043685871073641655 and d_loss: 0.0
Epoch: 78, g_acc=0.003000000026077032, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:11<00:00, 23.51it/s]


Epoch 79 with g_loss: 0.0 and d_loss: 0.021499786390246186
Epoch: 79, g_acc=0.003000000026077032, d_acc_f=1.0, d_acc_r=0.9970000386238098


 13%|█▎        | 562/4493 [00:23<02:47, 23.53it/s]


Epoch 79 with g_loss: 0.0023052012772752676 and d_loss: 0.0
Epoch: 79, g_acc=0.01100000087171793, d_acc_f=0.0, d_acc_r=0.9970000386238098


100%|██████████| 4493/4493 [03:09<00:00, 23.69it/s]


Epoch 80 with g_loss: 0.0 and d_loss: 0.01743149398224683
Epoch: 80, g_acc=0.01100000087171793, d_acc_f=1.0, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:23<02:46, 23.63it/s]


Epoch 80 with g_loss: 0.11852390842141006 and d_loss: 0.0
Epoch: 80, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.9890000224113464


100%|██████████| 4493/4493 [03:11<00:00, 23.51it/s]


Epoch 81 with g_loss: 0.0 and d_loss: 0.01882851815650228
Epoch: 81, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:47, 23.43it/s]


Epoch 81 with g_loss: 1.0672715575716794 and d_loss: 0.0
Epoch: 81, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:47, 23.48it/s]


Epoch 81 with g_loss: 1.0678032990317023 and d_loss: 0.0
Epoch: 81, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:46, 23.59it/s]


Epoch 81 with g_loss: 1.0676724701131501 and d_loss: 0.0
Epoch: 81, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:47, 23.52it/s]


Epoch 81 with g_loss: 1.066735910992566 and d_loss: 0.0
Epoch: 81, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:46, 23.62it/s]


Epoch 81 with g_loss: 0.30132065877221376 and d_loss: 0.0
Epoch: 81, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:10<00:00, 23.53it/s]


Epoch 82 with g_loss: 0.0 and d_loss: 0.0033020996924686217
Epoch: 82, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:46, 23.56it/s]


Epoch 82 with g_loss: 1.0804126899968853 and d_loss: 0.0
Epoch: 82, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:47, 23.48it/s]


Epoch 82 with g_loss: 1.0808719532701503 and d_loss: 0.0
Epoch: 82, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:25<02:56, 22.32it/s]


Epoch 82 with g_loss: 1.0804153695394116 and d_loss: 0.0
Epoch: 82, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:46, 23.65it/s]


Epoch 82 with g_loss: 0.7468417561023303 and d_loss: 0.0
Epoch: 82, g_acc=0.0010000000474974513, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:09<00:00, 23.72it/s]


Epoch 83 with g_loss: 0.0 and d_loss: 0.006217157644477738
Epoch: 83, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:24<02:48, 23.30it/s]


Epoch 83 with g_loss: 1.0872535503921492 and d_loss: 0.0
Epoch: 83, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:25<02:54, 22.47it/s]


Epoch 83 with g_loss: 1.0874167649697652 and d_loss: 0.0
Epoch: 83, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:46, 23.59it/s]


Epoch 83 with g_loss: 1.0873067037793382 and d_loss: 0.0
Epoch: 83, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:46, 23.65it/s]


Epoch 83 with g_loss: 1.0871356263236056 and d_loss: 0.0
Epoch: 83, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:46, 23.67it/s]


Epoch 83 with g_loss: 1.0872459335208284 and d_loss: 0.0
Epoch: 83, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:25<02:56, 22.30it/s]


Epoch 83 with g_loss: 1.087109001129103 and d_loss: 0.0
Epoch: 83, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:46, 23.54it/s]


Epoch 83 with g_loss: 1.0873113728142676 and d_loss: 0.0
Epoch: 83, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:47, 23.48it/s]


Epoch 83 with g_loss: 1.0871853311582422 and d_loss: 0.0
Epoch: 83, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:46, 23.60it/s]


Epoch 83 with g_loss: 1.0869357917559803 and d_loss: 0.0
Epoch: 83, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:25<02:55, 22.37it/s]


Epoch 83 with g_loss: 1.0867480709535253 and d_loss: 0.0
Epoch: 83, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:46, 23.59it/s]


Epoch 83 with g_loss: 1.0869287343962326 and d_loss: 0.0
Epoch: 83, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:46, 23.62it/s]


Epoch 83 with g_loss: 0.021263865336180663 and d_loss: 0.0
Epoch: 83, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.999000072479248


100%|██████████| 4493/4493 [03:08<00:00, 23.87it/s]


Epoch 84 with g_loss: 0.0 and d_loss: 0.03480817238615124
Epoch: 84, g_acc=0.0, d_acc_f=1.0, d_acc_r=0.9950000643730164


 13%|█▎        | 562/4493 [00:25<02:55, 22.43it/s]


Epoch 84 with g_loss: 0.0028362867327362644 and d_loss: 0.0
Epoch: 84, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.9950000643730164


100%|██████████| 4493/4493 [03:08<00:00, 23.82it/s]


Epoch 85 with g_loss: 0.0 and d_loss: 47.409138822482156
Epoch: 85, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:46, 23.59it/s]


Epoch 85 with g_loss: 0.0023367475056544195 and d_loss: 0.0
Epoch: 85, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:08<00:00, 23.78it/s]


Epoch 86 with g_loss: 0.0 and d_loss: 0.018808709044100418
Epoch: 86, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:25<02:55, 22.36it/s]


Epoch 86 with g_loss: 0.006770601535380557 and d_loss: 0.0
Epoch: 86, g_acc=0.005000000353902578, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:08<00:00, 23.88it/s]


Epoch 87 with g_loss: 0.0 and d_loss: 0.024509974021720372
Epoch: 87, g_acc=0.005000000353902578, d_acc_f=1.0, d_acc_r=0.9950000643730164


 13%|█▎        | 562/4493 [00:23<02:45, 23.69it/s]


Epoch 87 with g_loss: 0.09390177066014421 and d_loss: 0.0
Epoch: 87, g_acc=0.0020000000949949026, d_acc_f=0.0, d_acc_r=0.9950000643730164


100%|██████████| 4493/4493 [03:05<00:00, 24.17it/s]


Epoch 88 with g_loss: 0.0 and d_loss: 0.013084765430537255
Epoch: 88, g_acc=0.0020000000949949026, d_acc_f=1.0, d_acc_r=0.9980000257492065


 13%|█▎        | 562/4493 [00:25<02:58, 22.04it/s]


Epoch 88 with g_loss: 0.99856780914164 and d_loss: 0.0
Epoch: 88, g_acc=0.0020000000949949026, d_acc_f=1.0, d_acc_r=0.9980000257492065


 13%|█▎        | 562/4493 [00:23<02:47, 23.44it/s]


Epoch 88 with g_loss: 0.9988680680769736 and d_loss: 0.0
Epoch: 88, g_acc=0.0020000000949949026, d_acc_f=1.0, d_acc_r=0.9980000257492065


 13%|█▎        | 562/4493 [00:23<02:45, 23.70it/s]


Epoch 88 with g_loss: 0.9987560406363729 and d_loss: 0.0
Epoch: 88, g_acc=0.0020000000949949026, d_acc_f=1.0, d_acc_r=0.9980000257492065


 13%|█▎        | 562/4493 [00:23<02:45, 23.81it/s]


Epoch 88 with g_loss: 0.5331499634983548 and d_loss: 0.0
Epoch: 88, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.9980000257492065


100%|██████████| 4493/4493 [03:11<00:00, 23.47it/s]


Epoch 89 with g_loss: 0.0 and d_loss: 0.06962539549681095
Epoch: 89, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:46, 23.61it/s]


Epoch 89 with g_loss: 0.014897658247434328 and d_loss: 0.0
Epoch: 89, g_acc=0.004000000189989805, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:11<00:00, 23.43it/s]


Epoch 90 with g_loss: 0.0 and d_loss: 0.03257294662090059
Epoch: 90, g_acc=0.004000000189989805, d_acc_f=1.0, d_acc_r=0.9960000514984131


 13%|█▎        | 562/4493 [00:23<02:46, 23.60it/s]


Epoch 90 with g_loss: 0.0036342408468751364 and d_loss: 0.0
Epoch: 90, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.9960000514984131


  2%|▏         | 87/4493 [00:03<03:14, 22.61it/s]


KeyboardInterrupt: 

In [150]:
old_generator = torch.load("generator 85.pt")
torch.mean(RSv1.discriminator(val_boards, old_generator(val_boards)))

OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 MiB. GPU 0 has a total capacity of 8.00 GiB of which 0 bytes is free. Of the allocated memory 14.42 GiB is allocated by PyTorch, and 67.98 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [140]:
for epoch in range(50):
    
    reps = 0
    
    if train_all or train_discriminator:
        
        while RSv1.logs['d_acc_f'][-1] < 0.5:
            reps += 1
            if reps > 5:
                train_all = False
                train_discriminator = False
                break
            for bitboards, mvs in tqdm(loader):
        
                RSv1.train_step(bitboards, mvs, train_generator=False)
            

            RSv1.on_epoch_end(epoch, G, val_data)

    reps = 0
    if train_all or not train_discriminator:
        
        while RSv1.logs['d_acc_f'][-1] > 0.5:
            reps += 1
            if reps > 12:
                train_all = False
                train_discriminator = True
                break
                
            i=0
            for bitboards, mvs in tqdm(loader):
                
                if i > G // 8:
                    break
                
                RSv1.train_step(bitboards, mvs, train_generator=True)
                i += 1
            
            RSv1.on_epoch_end(epoch, G, val_data)

100%|██████████| 4493/4493 [03:02<00:00, 24.58it/s]


Epoch 1 with g_loss: 0.0 and d_loss: 0.024132794972129612
Epoch: 1, g_acc=0.0020000000949949026, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.40it/s]


Epoch 1 with g_loss: 0.0024971719830587076 and d_loss: 0.0
Epoch: 1, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:03<00:00, 24.51it/s]


Epoch 2 with g_loss: 0.0 and d_loss: 0.013789552377934151
Epoch: 2, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:40, 24.42it/s]


Epoch 2 with g_loss: 0.0026067342491099383 and d_loss: 0.0
Epoch: 2, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.60it/s]


Epoch 3 with g_loss: 0.0 and d_loss: 0.011044516429126192
Epoch: 3, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:50, 23.08it/s]


Epoch 3 with g_loss: 0.001956753350941522 and d_loss: 0.0
Epoch: 3, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.58it/s]


Epoch 4 with g_loss: 0.0 and d_loss: 0.017382621447475502
Epoch: 4, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:22<02:40, 24.44it/s]


Epoch 4 with g_loss: 0.0024146829054086227 and d_loss: 0.0
Epoch: 4, g_acc=0.003000000026077032, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.62it/s]


Epoch 5 with g_loss: 0.0 and d_loss: 0.02116424934301673
Epoch: 5, g_acc=0.003000000026077032, d_acc_f=1.0, d_acc_r=0.9970000386238098


 13%|█▎        | 562/4493 [00:24<02:50, 23.02it/s]


Epoch 5 with g_loss: 0.0022389527927109136 and d_loss: 0.0
Epoch: 5, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.9970000386238098


100%|██████████| 4493/4493 [03:02<00:00, 24.62it/s]


Epoch 6 with g_loss: 0.0 and d_loss: 0.025064310492302424
Epoch: 6, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.36it/s]


Epoch 6 with g_loss: 0.002354529493762095 and d_loss: 0.0
Epoch: 6, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:11<00:00, 23.47it/s]


Epoch 7 with g_loss: 0.0 and d_loss: 0.012076646195229303
Epoch: 7, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:47, 23.48it/s]


Epoch 7 with g_loss: 0.003113638971924814 and d_loss: 0.0
Epoch: 7, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:09<00:00, 23.75it/s]


Epoch 8 with g_loss: 0.0 and d_loss: 0.02602581964305289
Epoch: 8, g_acc=0.0, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:40, 24.43it/s]


Epoch 8 with g_loss: 0.0025406566680193803 and d_loss: 0.0
Epoch: 8, g_acc=0.0010000000474974513, d_acc_f=0.0, d_acc_r=0.999000072479248


100%|██████████| 4493/4493 [03:04<00:00, 24.39it/s]


Epoch 9 with g_loss: 0.0 and d_loss: 0.017290843504149573
Epoch: 9, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:42, 24.20it/s]


Epoch 9 with g_loss: 0.004489525197342968 and d_loss: 0.0
Epoch: 9, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.999000072479248


100%|██████████| 4493/4493 [03:02<00:00, 24.61it/s]


Epoch 10 with g_loss: 0.0 and d_loss: 0.03192948539286519
Epoch: 10, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.37it/s]


Epoch 10 with g_loss: 0.004548603995298355 and d_loss: 0.0
Epoch: 10, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:03<00:00, 24.46it/s]


Epoch 11 with g_loss: 0.0 and d_loss: 0.02545475482570881
Epoch: 11, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.38it/s]


Epoch 11 with g_loss: 0.0071322112835725225 and d_loss: 0.0
Epoch: 11, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.60it/s]


Epoch 12 with g_loss: 0.0 and d_loss: 0.021561579191836235
Epoch: 12, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.40it/s]


Epoch 12 with g_loss: 0.0030548788558641365 and d_loss: 0.0
Epoch: 12, g_acc=0.0010000000474974513, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:07<00:00, 23.95it/s]


Epoch 13 with g_loss: 0.0 and d_loss: 0.008166748338256332
Epoch: 13, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:40, 24.42it/s]


Epoch 13 with g_loss: 1.14613275402722 and d_loss: 0.0
Epoch: 13, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:22<02:40, 24.48it/s]


Epoch 13 with g_loss: 1.1462635239381025 and d_loss: 0.0
Epoch: 13, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:24<02:50, 23.08it/s]


Epoch 13 with g_loss: 0.6434604772776485 and d_loss: 0.0
Epoch: 13, g_acc=0.007000000216066837, d_acc_f=0.0, d_acc_r=0.999000072479248


100%|██████████| 4493/4493 [03:02<00:00, 24.60it/s]


Epoch 14 with g_loss: 0.0 and d_loss: 0.08564183620863015
Epoch: 14, g_acc=0.007000000216066837, d_acc_f=1.0, d_acc_r=0.9930000305175781


 13%|█▎        | 562/4493 [00:23<02:41, 24.37it/s]


Epoch 14 with g_loss: 1.0083900795942315 and d_loss: 0.0
Epoch: 14, g_acc=0.007000000216066837, d_acc_f=1.0, d_acc_r=0.9930000305175781


 13%|█▎        | 562/4493 [00:24<02:50, 23.06it/s]


Epoch 14 with g_loss: 1.0084392576049968 and d_loss: 0.0
Epoch: 14, g_acc=0.007000000216066837, d_acc_f=1.0, d_acc_r=0.9930000305175781


 13%|█▎        | 562/4493 [00:23<02:41, 24.41it/s]


Epoch 14 with g_loss: 1.0086065882859612 and d_loss: 0.0
Epoch: 14, g_acc=0.007000000216066837, d_acc_f=1.0, d_acc_r=0.9930000305175781


 13%|█▎        | 562/4493 [00:23<02:41, 24.29it/s]


Epoch 14 with g_loss: 1.0078608365146986 and d_loss: 0.0
Epoch: 14, g_acc=0.007000000216066837, d_acc_f=1.0, d_acc_r=0.9930000305175781


 13%|█▎        | 562/4493 [00:23<02:41, 24.40it/s]


Epoch 14 with g_loss: 1.008536771812074 and d_loss: 0.0
Epoch: 14, g_acc=0.007000000216066837, d_acc_f=1.0, d_acc_r=0.9930000305175781


 13%|█▎        | 562/4493 [00:24<02:50, 23.09it/s]


Epoch 14 with g_loss: 0.49203410669700104 and d_loss: 0.0
Epoch: 14, g_acc=0.0010000000474974513, d_acc_f=0.0, d_acc_r=0.9930000305175781


100%|██████████| 4493/4493 [03:02<00:00, 24.62it/s]


Epoch 15 with g_loss: 0.0 and d_loss: 0.007852844052392461
Epoch: 15, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:41, 24.29it/s]


Epoch 15 with g_loss: 0.3792542241055219 and d_loss: 0.0
Epoch: 15, g_acc=0.0010000000474974513, d_acc_f=0.0, d_acc_r=0.999000072479248


100%|██████████| 4493/4493 [03:02<00:00, 24.65it/s]


Epoch 16 with g_loss: 0.0 and d_loss: 0.005216305374649326
Epoch: 16, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:24<02:50, 23.01it/s]


Epoch 16 with g_loss: 0.9095760940205353 and d_loss: 0.0
Epoch: 16, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:41, 24.32it/s]


Epoch 16 with g_loss: 0.909494537036827 and d_loss: 0.0
Epoch: 16, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:41, 24.36it/s]


Epoch 16 with g_loss: 0.4869868645703765 and d_loss: 0.0
Epoch: 16, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.999000072479248


100%|██████████| 4493/4493 [03:02<00:00, 24.61it/s]


Epoch 17 with g_loss: 0.0 and d_loss: 0.020298042552250393
Epoch: 17, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:50, 23.12it/s]


Epoch 17 with g_loss: 0.0019280314855860746 and d_loss: 0.0
Epoch: 17, g_acc=0.0010000000474974513, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.63it/s]


Epoch 18 with g_loss: 0.0 and d_loss: 0.02613615630448376
Epoch: 18, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:41, 24.38it/s]


Epoch 18 with g_loss: 0.0022720197435810866 and d_loss: 0.0
Epoch: 18, g_acc=0.007000000216066837, d_acc_f=0.0, d_acc_r=0.999000072479248


100%|██████████| 4493/4493 [03:02<00:00, 24.58it/s]


Epoch 19 with g_loss: 0.0 and d_loss: 0.03952874378361315
Epoch: 19, g_acc=0.007000000216066837, d_acc_f=1.0, d_acc_r=0.9930000305175781


 13%|█▎        | 562/4493 [00:25<02:55, 22.41it/s]


Epoch 19 with g_loss: 1.0567799240662477 and d_loss: 0.0
Epoch: 19, g_acc=0.007000000216066837, d_acc_f=1.0, d_acc_r=0.9930000305175781


 13%|█▎        | 562/4493 [00:23<02:47, 23.44it/s]


Epoch 19 with g_loss: 1.0576880620790434 and d_loss: 0.0
Epoch: 19, g_acc=0.007000000216066837, d_acc_f=1.0, d_acc_r=0.9930000305175781


 13%|█▎        | 562/4493 [00:23<02:47, 23.47it/s]


Epoch 19 with g_loss: 1.0577225030304302 and d_loss: 0.0
Epoch: 19, g_acc=0.007000000216066837, d_acc_f=1.0, d_acc_r=0.9930000305175781


 13%|█▎        | 562/4493 [00:23<02:47, 23.47it/s]


Epoch 19 with g_loss: 1.057060945546099 and d_loss: 0.0
Epoch: 19, g_acc=0.007000000216066837, d_acc_f=1.0, d_acc_r=0.9930000305175781


 13%|█▎        | 562/4493 [00:25<02:55, 22.34it/s]


Epoch 19 with g_loss: 0.857387409039782 and d_loss: 0.0
Epoch: 19, g_acc=0.007000000216066837, d_acc_f=0.0, d_acc_r=0.9930000305175781


100%|██████████| 4493/4493 [03:10<00:00, 23.59it/s]


Epoch 20 with g_loss: 0.0 and d_loss: 0.024749410654262293
Epoch: 20, g_acc=0.007000000216066837, d_acc_f=0.9970000386238098, d_acc_r=0.9960000514984131


 13%|█▎        | 562/4493 [00:23<02:47, 23.52it/s]


Epoch 20 with g_loss: 0.24784160958378665 and d_loss: 0.0
Epoch: 20, g_acc=0.007000000216066837, d_acc_f=0.0, d_acc_r=0.9960000514984131


100%|██████████| 4493/4493 [03:09<00:00, 23.74it/s]


Epoch 21 with g_loss: 0.0 and d_loss: 0.029546770543393164
Epoch: 21, g_acc=0.007000000216066837, d_acc_f=1.0, d_acc_r=0.9930000305175781


 13%|█▎        | 562/4493 [00:25<02:56, 22.26it/s]


Epoch 21 with g_loss: 0.1482545676324645 and d_loss: 0.0
Epoch: 21, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.9930000305175781


100%|██████████| 4493/4493 [03:05<00:00, 24.18it/s]


Epoch 22 with g_loss: 0.0 and d_loss: 0.0021309153491904584
Epoch: 22, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:22<02:40, 24.46it/s]


Epoch 22 with g_loss: 1.1494790569010807 and d_loss: 0.0
Epoch: 22, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:50, 23.05it/s]


Epoch 22 with g_loss: 0.49715727468460785 and d_loss: 0.0
Epoch: 22, g_acc=0.0010000000474974513, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.62it/s]


Epoch 23 with g_loss: 0.0 and d_loss: 0.008832896417877698
Epoch: 23, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:40, 24.42it/s]


Epoch 23 with g_loss: 0.6402830028497702 and d_loss: 0.0
Epoch: 23, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.999000072479248


100%|██████████| 4493/4493 [03:02<00:00, 24.61it/s]


Epoch 24 with g_loss: 0.0 and d_loss: 0.004233283480006336
Epoch: 24, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:50, 23.10it/s]


Epoch 24 with g_loss: 0.23528336050444834 and d_loss: 0.0
Epoch: 24, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.57it/s]


Epoch 25 with g_loss: 0.0 and d_loss: 0.01180190011063971
Epoch: 25, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.40it/s]


Epoch 25 with g_loss: 0.42441664595341805 and d_loss: 0.0
Epoch: 25, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:50, 23.03it/s]


Epoch 25 with g_loss: 0.4231362898949921 and d_loss: 0.0
Epoch: 25, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:22<02:40, 24.49it/s]


Epoch 25 with g_loss: 0.42293300937497214 and d_loss: 0.0
Epoch: 25, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:40, 24.42it/s]


Epoch 25 with g_loss: 0.42309014138680856 and d_loss: 0.0
Epoch: 25, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.35it/s]


Epoch 25 with g_loss: 0.3306266656112207 and d_loss: 0.0
Epoch: 25, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.64it/s]


Epoch 26 with g_loss: 0.0 and d_loss: 0.0053197652965731425
Epoch: 26, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:50, 23.11it/s]


Epoch 26 with g_loss: 1.0851617696793074 and d_loss: 0.0
Epoch: 26, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.39it/s]


Epoch 26 with g_loss: 1.0851141625885545 and d_loss: 0.0
Epoch: 26, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.38it/s]


Epoch 26 with g_loss: 1.0850176746479632 and d_loss: 0.0
Epoch: 26, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.38it/s]


Epoch 26 with g_loss: 1.0851203571203263 and d_loss: 0.0
Epoch: 26, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:50, 23.11it/s]


Epoch 26 with g_loss: 1.0850722145031322 and d_loss: 0.0
Epoch: 26, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:40, 24.43it/s]


Epoch 26 with g_loss: 0.34364054295389834 and d_loss: 0.0
Epoch: 26, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.62it/s]


Epoch 27 with g_loss: 0.0 and d_loss: 0.02810128170069002
Epoch: 27, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.35it/s]


Epoch 27 with g_loss: 0.004730754672665122 and d_loss: 0.0
Epoch: 27, g_acc=0.01100000087171793, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:03<00:00, 24.47it/s]


Epoch 28 with g_loss: 0.0 and d_loss: 0.0701304057754693
Epoch: 28, g_acc=0.01100000087171793, d_acc_f=0.999000072479248, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:23<02:41, 24.40it/s]


Epoch 28 with g_loss: 1.0104810746348092 and d_loss: 0.0
Epoch: 28, g_acc=0.01100000087171793, d_acc_f=0.999000072479248, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:23<02:45, 23.70it/s]


Epoch 28 with g_loss: 1.0109130541837323 and d_loss: 0.0
Epoch: 28, g_acc=0.01100000087171793, d_acc_f=0.999000072479248, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:23<02:46, 23.56it/s]


Epoch 28 with g_loss: 1.0110644866489977 and d_loss: 0.0
Epoch: 28, g_acc=0.01100000087171793, d_acc_f=0.999000072479248, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:23<02:46, 23.68it/s]


Epoch 28 with g_loss: 1.0102967549452877 and d_loss: 0.0
Epoch: 28, g_acc=0.01100000087171793, d_acc_f=0.999000072479248, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:23<02:46, 23.61it/s]


Epoch 28 with g_loss: 1.0108271996162528 and d_loss: 0.0
Epoch: 28, g_acc=0.01100000087171793, d_acc_f=0.999000072479248, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:23<02:45, 23.78it/s]


Epoch 28 with g_loss: 0.319415412789852 and d_loss: 0.0
Epoch: 28, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.9890000224113464


100%|██████████| 4493/4493 [03:09<00:00, 23.70it/s]


Epoch 29 with g_loss: 0.0 and d_loss: 0.025883360263627645
Epoch: 29, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:46, 23.66it/s]


Epoch 29 with g_loss: 0.0015045398033764943 and d_loss: 0.0
Epoch: 29, g_acc=0.01100000087171793, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:08<00:00, 23.80it/s]


Epoch 30 with g_loss: 0.0 and d_loss: 0.013897468978089305
Epoch: 30, g_acc=0.01100000087171793, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:25<02:55, 22.44it/s]


Epoch 30 with g_loss: 0.0019653787617225004 and d_loss: 0.0
Epoch: 30, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:08<00:00, 23.83it/s]


Epoch 31 with g_loss: 0.0 and d_loss: 0.014311085127042421
Epoch: 31, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:47, 23.44it/s]


Epoch 31 with g_loss: 0.00571343115088333 and d_loss: 0.0
Epoch: 31, g_acc=0.010000000707805157, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:08<00:00, 23.84it/s]


Epoch 32 with g_loss: 0.0 and d_loss: 0.0377351770585002
Epoch: 32, g_acc=0.010000000707805157, d_acc_f=1.0, d_acc_r=0.9900000691413879


 13%|█▎        | 562/4493 [00:25<02:55, 22.35it/s]


Epoch 32 with g_loss: 0.0020329938528218967 and d_loss: 0.0
Epoch: 32, g_acc=0.0010000000474974513, d_acc_f=0.0, d_acc_r=0.9900000691413879


100%|██████████| 4493/4493 [03:08<00:00, 23.79it/s]


Epoch 33 with g_loss: 0.0 and d_loss: 0.00889352647667435
Epoch: 33, g_acc=0.0010000000474974513, d_acc_f=1.0, d_acc_r=0.999000072479248


 13%|█▎        | 562/4493 [00:23<02:46, 23.57it/s]


Epoch 33 with g_loss: 0.7105380524813301 and d_loss: 0.0
Epoch: 33, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.999000072479248


100%|██████████| 4493/4493 [03:09<00:00, 23.69it/s]


Epoch 34 with g_loss: 0.0 and d_loss: 0.019446368198690493
Epoch: 34, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:46, 23.67it/s]


Epoch 34 with g_loss: 0.0018333214690047005 and d_loss: 0.0
Epoch: 34, g_acc=0.005000000353902578, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:08<00:00, 23.82it/s]


Epoch 35 with g_loss: 0.0 and d_loss: 0.016625313673626206
Epoch: 35, g_acc=0.005000000353902578, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.32it/s]


Epoch 35 with g_loss: 0.006040684454103249 and d_loss: 0.0
Epoch: 35, g_acc=0.005000000353902578, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:04<00:00, 24.37it/s]


Epoch 36 with g_loss: 0.0 and d_loss: 0.013459019428842482
Epoch: 36, g_acc=0.005000000353902578, d_acc_f=1.0, d_acc_r=0.9950000643730164


 13%|█▎        | 562/4493 [00:23<02:42, 24.23it/s]


Epoch 36 with g_loss: 0.046957266955801644 and d_loss: 0.0
Epoch: 36, g_acc=0.0020000000949949026, d_acc_f=0.0, d_acc_r=0.9950000643730164


100%|██████████| 4493/4493 [03:03<00:00, 24.47it/s]


Epoch 37 with g_loss: 0.0 and d_loss: 62.37730184793092
Epoch: 37, g_acc=0.0020000000949949026, d_acc_f=1.0, d_acc_r=0.9980000257492065


 13%|█▎        | 562/4493 [00:23<02:41, 24.32it/s]


Epoch 37 with g_loss: 0.3587451279818375 and d_loss: 0.0
Epoch: 37, g_acc=0.0020000000949949026, d_acc_f=0.0010000000474974513, d_acc_r=0.9980000257492065


100%|██████████| 4493/4493 [03:03<00:00, 24.43it/s]


Epoch 38 with g_loss: 0.0 and d_loss: 0.006774450282417731
Epoch: 38, g_acc=0.0020000000949949026, d_acc_f=1.0, d_acc_r=0.9980000257492065


 13%|█▎        | 562/4493 [00:23<02:41, 24.35it/s]


Epoch 38 with g_loss: 1.0585761667757185 and d_loss: 0.0
Epoch: 38, g_acc=0.0020000000949949026, d_acc_f=1.0, d_acc_r=0.9980000257492065


 13%|█▎        | 562/4493 [00:23<02:41, 24.37it/s]


Epoch 38 with g_loss: 1.0579749531239675 and d_loss: 0.0
Epoch: 38, g_acc=0.0020000000949949026, d_acc_f=1.0, d_acc_r=0.9980000257492065


 13%|█▎        | 562/4493 [00:23<02:42, 24.26it/s]


Epoch 38 with g_loss: 1.0579145006153914 and d_loss: 0.0
Epoch: 38, g_acc=0.0020000000949949026, d_acc_f=1.0, d_acc_r=0.9980000257492065


 13%|█▎        | 562/4493 [00:23<02:41, 24.39it/s]


Epoch 38 with g_loss: 0.3598136633385102 and d_loss: 0.0
Epoch: 38, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.9980000257492065


100%|██████████| 4493/4493 [03:03<00:00, 24.42it/s]


Epoch 39 with g_loss: 0.0 and d_loss: 0.06503237561745502
Epoch: 39, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:22<02:40, 24.45it/s]


Epoch 39 with g_loss: 0.0038040497419236765 and d_loss: 0.0
Epoch: 39, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:02<00:00, 24.56it/s]


Epoch 40 with g_loss: 0.0 and d_loss: 0.0285007918934809
Epoch: 40, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:51, 22.92it/s]


Epoch 40 with g_loss: 0.8459185938407444 and d_loss: 0.0
Epoch: 40, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:48, 23.37it/s]


Epoch 40 with g_loss: 0.8459076930227987 and d_loss: 0.0
Epoch: 40, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:47, 23.48it/s]


Epoch 40 with g_loss: 0.5057048317482056 and d_loss: 0.0
Epoch: 40, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:10<00:00, 23.54it/s]


Epoch 41 with g_loss: 0.0 and d_loss: 0.010005825766747122
Epoch: 41, g_acc=0.0, d_acc_f=0.999000072479248, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:25<02:58, 22.07it/s]


Epoch 41 with g_loss: 0.07064528664954159 and d_loss: 0.0
Epoch: 41, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:11<00:00, 23.52it/s]


Epoch 42 with g_loss: 0.0 and d_loss: 0.016110362718400006
Epoch: 42, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:24<02:47, 23.41it/s]


Epoch 42 with g_loss: 0.0020489374464775743 and d_loss: 0.0
Epoch: 42, g_acc=0.0, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:11<00:00, 23.51it/s]


Epoch 43 with g_loss: 0.0 and d_loss: 0.23114775911245297
Epoch: 43, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:25<02:56, 22.23it/s]


Epoch 43 with g_loss: 0.09901157946243133 and d_loss: 0.0
Epoch: 43, g_acc=0.003000000026077032, d_acc_f=0.0, d_acc_r=1.0


100%|██████████| 4493/4493 [03:11<00:00, 23.50it/s]


Epoch 44 with g_loss: 0.0 and d_loss: 0.034431312724988554
Epoch: 44, g_acc=0.003000000026077032, d_acc_f=1.0, d_acc_r=0.9970000386238098


 13%|█▎        | 562/4493 [00:24<02:48, 23.35it/s]


Epoch 44 with g_loss: 0.002903821849011941 and d_loss: 0.0
Epoch: 44, g_acc=0.01100000087171793, d_acc_f=0.0, d_acc_r=0.9970000386238098


100%|██████████| 4493/4493 [03:10<00:00, 23.52it/s]


Epoch 45 with g_loss: 0.0 and d_loss: 0.016043013366815136
Epoch: 45, g_acc=0.01100000087171793, d_acc_f=1.0, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:25<02:58, 22.07it/s]


Epoch 45 with g_loss: 1.0727303874484686 and d_loss: 0.0
Epoch: 45, g_acc=0.01100000087171793, d_acc_f=1.0, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:24<02:48, 23.32it/s]


Epoch 45 with g_loss: 1.072683942681454 and d_loss: 0.0
Epoch: 45, g_acc=0.01100000087171793, d_acc_f=1.0, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:24<02:48, 23.32it/s]


Epoch 45 with g_loss: 1.072877934853006 and d_loss: 0.0
Epoch: 45, g_acc=0.01100000087171793, d_acc_f=1.0, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:24<02:47, 23.41it/s]


Epoch 45 with g_loss: 1.0728727275322838 and d_loss: 0.0
Epoch: 45, g_acc=0.01100000087171793, d_acc_f=1.0, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:25<02:57, 22.16it/s]


Epoch 45 with g_loss: 1.0728035828543485 and d_loss: 0.0
Epoch: 45, g_acc=0.01100000087171793, d_acc_f=1.0, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:24<02:48, 23.37it/s]


Epoch 45 with g_loss: 0.38786240676831474 and d_loss: 0.0
Epoch: 45, g_acc=0.003000000026077032, d_acc_f=0.0, d_acc_r=0.9890000224113464


100%|██████████| 4493/4493 [03:10<00:00, 23.55it/s]


Epoch 46 with g_loss: 0.0 and d_loss: 0.029451147682765084
Epoch: 46, g_acc=0.003000000026077032, d_acc_f=1.0, d_acc_r=0.9970000386238098


 13%|█▎        | 562/4493 [00:23<02:44, 23.85it/s]


Epoch 46 with g_loss: 0.0017781945799185172 and d_loss: 0.0
Epoch: 46, g_acc=0.003000000026077032, d_acc_f=0.0, d_acc_r=0.9970000386238098


100%|██████████| 4493/4493 [03:04<00:00, 24.38it/s]


Epoch 47 with g_loss: 0.0 and d_loss: 0.03428712216891328
Epoch: 47, g_acc=0.003000000026077032, d_acc_f=1.0, d_acc_r=0.9970000386238098


 13%|█▎        | 562/4493 [00:23<02:42, 24.24it/s]


Epoch 47 with g_loss: 0.004009029440807323 and d_loss: 0.0
Epoch: 47, g_acc=0.01100000087171793, d_acc_f=0.0, d_acc_r=0.9970000386238098


100%|██████████| 4493/4493 [03:03<00:00, 24.51it/s]


Epoch 48 with g_loss: 0.0 and d_loss: 0.016665056084429142
Epoch: 48, g_acc=0.01100000087171793, d_acc_f=1.0, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:23<02:41, 24.27it/s]


Epoch 48 with g_loss: 1.2500380750172604 and d_loss: 0.0
Epoch: 48, g_acc=0.01100000087171793, d_acc_f=1.0, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:23<02:41, 24.42it/s]


Epoch 48 with g_loss: 1.2503132982506615 and d_loss: 0.0
Epoch: 48, g_acc=0.01100000087171793, d_acc_f=1.0, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:23<02:41, 24.31it/s]


Epoch 48 with g_loss: 1.2504820937227359 and d_loss: 0.0
Epoch: 48, g_acc=0.01100000087171793, d_acc_f=1.0, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:24<02:50, 23.05it/s]


Epoch 48 with g_loss: 1.2503875228946362 and d_loss: 0.0
Epoch: 48, g_acc=0.01100000087171793, d_acc_f=1.0, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:23<02:41, 24.31it/s]


Epoch 48 with g_loss: 1.2506287702335006 and d_loss: 0.0
Epoch: 48, g_acc=0.01100000087171793, d_acc_f=1.0, d_acc_r=0.9890000224113464


 13%|█▎        | 562/4493 [00:23<02:41, 24.33it/s]


Epoch 48 with g_loss: 0.8199879871894834 and d_loss: 0.0
Epoch: 48, g_acc=0.0, d_acc_f=0.0, d_acc_r=0.9890000224113464


100%|██████████| 4493/4493 [03:04<00:00, 24.38it/s]


Epoch 49 with g_loss: 0.0 and d_loss: 0.010932574784350531
Epoch: 49, g_acc=0.0, d_acc_f=1.0, d_acc_r=1.0


 13%|█▎        | 562/4493 [00:23<02:41, 24.40it/s]


Epoch 49 with g_loss: 0.03563844770916103 and d_loss: 0.0
Epoch: 49, g_acc=0.0020000000949949026, d_acc_f=0.009000000543892384, d_acc_r=1.0


100%|██████████| 4493/4493 [03:03<00:00, 24.51it/s]


Epoch 50 with g_loss: 0.0 and d_loss: 0.010246130882807146
Epoch: 50, g_acc=0.0020000000949949026, d_acc_f=1.0, d_acc_r=0.9980000257492065


 13%|█▎        | 562/4493 [00:23<02:42, 24.22it/s]


Epoch 50 with g_loss: 1.012568757286216 and d_loss: 0.0
Epoch: 50, g_acc=0.0020000000949949026, d_acc_f=1.0, d_acc_r=0.9980000257492065


 13%|█▎        | 562/4493 [00:24<02:50, 23.07it/s]

Epoch 50 with g_loss: 0.39470355253507655 and d_loss: 0.0
Epoch: 50, g_acc=0.003000000026077032, d_acc_f=0.0, d_acc_r=0.9980000257492065





In [135]:
val_boards = val_data.bitboards
real_moves = val_data.moves

minn, ila = RDv2(val_boards)
fake_moves = torch.cat([F.softmax(minn, dim=1), F.softmax(ila,dim=1)], dim=1)
torch.mean((real_moves == torch.round(fake_moves)).all(dim=1), dtype=torch.float).item()

0.2510000169277191

In [137]:
fake_moves = RSv1(val_boards)
print(fake_moves[0])

#fake_moves_ind = torch.argmax(fake_moves, dim=1)
print(torch.round(fake_moves[0]))

g_acc = torch.mean((real_moves == torch.round(fake_moves)).all(dim=1), dtype=torch.float).item()

tensor([1.7404e-42, 0.0000e+00, 2.2697e-41, 2.9514e-19, 2.4398e-16, 4.8634e-12,
        9.2207e-27, 3.0531e-33, 0.0000e+00, 0.0000e+00, 2.2381e-24, 2.3693e-12,
        1.0543e-04, 2.4633e-19, 4.6248e-19, 9.9989e-01, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 1.0044e-09, 2.2264e-06, 1.8982e-10, 8.1555e-16, 6.4270e-11,
        1.2278e-31, 0.0000e+00, 1.0315e-31, 1.0168e-20, 5.1734e-15, 8.6261e-07,
        3.3822e-08, 5.3728e-09, 0.0000e+00, 0.0000e+00, 8.4442e-37, 1.1300e-11,
        1.6480e-18, 1.1017e-12, 9.6544e-10, 5.9234e-10, 0.0000e+00, 0.0000e+00,
        3.2209e-39, 1.5248e-36, 2.6022e-42, 5.4401e-41, 0.0000e+00, 4.2039e-45,
        0.0000e+00, 0.0000e+00, 7.8557e-42, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+

In [141]:
test_against(lambda x: network_agent_prob_conv(x, RDv2), lambda x: network_agent_prob_conv(x, RSv1.generator), N=100)

100%|██████████| 50/50 [00:28<00:00,  1.77it/s]
100%|██████████| 50/50 [00:27<00:00,  1.80it/s]


(100, 0, 0, 1.0)

In [15]:
torch.cuda.memory_allocated() 

7274496