In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
import chess
from time import time
import tqdm

HEADERS = ("bitmaps", "movePlayed", "validMoves")
BATCH_SIZE = 64

In [2]:
import torch.nn as nn
import torch.nn.functional as F
class PositionEvaluatorNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(13, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        return torch.tanh(self.fc2(x))  # Output between -1 and 1
    

class ChessBotNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # 13 channels: 12 pieces + side to play (tensor[true's|false's])
        self.conv1 = nn.Conv2d(13,64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(8 * 8 * 128, 256)
        self.fc2 = nn.Linear(256, 64)
        self.relu = nn.ReLU()

        # Initialize weights
        nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
        nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)


    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)  # Output raw logits
        return x

class _CompleteChessBotNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # 13 channels: 12 pieces + side to play (tensor[true's|false's])
        self.conv1 = nn.Conv2d(13, 64, kernel_size=3, padding='same')
        self.batchnorm1 = nn.BatchNorm2d(64)

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding='same')
        self.batchnorm2 = nn.BatchNorm2d(128)

        self.fc1 = nn.Linear(8 * 8 * 128, 256) # 8 * 8 (num of squares) * 128 (num of channels)
        self.fc2 = nn.Linear(256, 64 * 63) # (Choose 2 squares from the board where the order matters)
        
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.flatten = nn.Flatten()

        # Initialize weights
        nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
        nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)


    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.batchnorm1(x)

        x = self.relu(self.conv2(x))
        x = self.batchnorm2(x)

        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)  # Output raw logits
        return x

class CompleteChessBotNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # 13 channels: 12 pieces + side to play (tensor[true's|false's])

        self.inputLayer = nn.Conv2d(13, 64, kernel_size=3, padding='same')
        self.batchnorm0 = nn.BatchNorm2d(64)

        self.conv1 = nn.Conv2d(64, 64, kernel_size=3, padding='same')
        self.batchnorm1 = nn.BatchNorm2d(64)

        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding='same')
        self.batchnorm2 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding='same')
        self.batchnorm3 = nn.BatchNorm2d(64)


        self.convLayers = nn.Sequential(
            self.inputLayer,
            self.batchnorm0,
            nn.ReLU(),
            self.conv1,
            self.batchnorm1,
            nn.ReLU(),
            self.conv2,
            self.batchnorm2,
            nn.ReLU(),
            self.conv3,
            self.batchnorm3,
            nn.ReLU()
        )

        self.fc1 = nn.Linear(8 * 8 * 64, 256) # 8 * 8 (num of squares) * 128 (num of channels)
        self.fc2 = nn.Linear(256, 64 * 63) # (Choose 2 squares from the board where the order matters)
        
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

        # Initialize weights
        nn.init.kaiming_uniform_(self.inputLayer.weight, nonlinearity='relu')
        nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
        nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
        nn.init.kaiming_uniform_(self.conv3.weight, nonlinearity='relu')
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)


    def forward(self, x):
        x = self.convLayers(x)
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)  # Output raw logits
        return x
## Idea:
##   One model that answers "best piece to move in this position"
##   Then another model that answers "best square to move piece X to"

# Add normalization after conv
# Switch from relu to sigmoid or smth
# Add more preprocessing by making the tensors there and avoid pre processing before training

In [3]:
import json

piece_to_idx = {
    'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
    'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
}

def board_to_tensor(board):
    tensor = np.zeros((12, 8, 8), dtype=np.uint8)
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            idx = piece_to_idx[piece.symbol()]
            row = 7 - square // 8
            col = square % 8
            tensor[idx, row, col] = 1
    return tensor

def convert_to_array_bitmap(row: str):
    """"
    Converts board into array

    :param str row: Board with side to play;
                        This must represent a np.ndarray[np.int64, shape=(13)] 
    :return: np.ndarray[shape(13, 8, 8), dtype=np.float32]]:
    """
    boards = np.array(json.loads(row), dtype=np.uint64) # shape = (13,)
    array = np.empty((13, 8, 8), dtype=np.uint8)

    # Set pieces boards
    board_int8_view = boards.view(dtype=np.uint8).reshape((13, 8)) # shape = (13, 8)
    board_as_int = np.unpackbits(board_int8_view, axis=1).reshape((13, 8, 8))
    array[:] = board_as_int

    # Set side to play board
    array[12] = np.ones(shape = (1, 8, 8)) if boards[12] == 1 else np.zeros(shape=(1,8,8))
    return array


letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
numbers = list(range(1, 10)) # [1..9]
MOVE_DICTIONARY = {}
cumulative = 0
for i in range(8):
    for j in range(8):
        for k in range(8):
            for w in range (8):
                if (i == k and j == w):
                    cumulative += 1
                    continue
                from_square = f"{letters[i]}{numbers[j]}"
                to_square = f"{letters[k]}{numbers[w]}"
                MOVE_DICTIONARY[f"{from_square}{to_square}"] = (i * 8**3) + (j * 8**2) + (k * 8) + w - cumulative
REVERSE_MOVE_DICTIONARY = {
    value: key for key,value in MOVE_DICTIONARY.items()
}

In [4]:
from typing import Literal
import polars as pl
import random

class ChessEvalDataset(Dataset):
# class ChessEvalDataset(IterableDataset):
    def __init__(self, file: str, model: Literal["pieces", "moves"] = "pieces", load_batch_size = 6_400):
        self.model = model
        self.lazy_dataset = pl.scan_csv(file, has_header=False, new_columns=HEADERS)
        self.batch_size = load_batch_size
        self.feature_col = "bitmaps"
        self.target_col = "movePlayed"
        self.total_rows = self.lazy_dataset.select(pl.len()).collect().item()

        self.cached_batches: dict[int, tuple] = {}
        self.cached_batch_id: int | None = None

    def __len__(self):
        return self.total_rows
    
    def _get_batch(self, batch_id):
        """Load a specific batch of data, wrapped with lru_cache for memory management"""
        # Calculate batch range
        if batch_id == self.cached_batch_id:
            return self.cached_batches[batch_id]
        

        self.cached_batches.pop(self.cached_batch_id, None) # Delete old batch
        
        # Calculate batch range
        start_idx = batch_id * self.batch_size
        end_idx = min((batch_id + 1) * self.batch_size, self.total_rows)
        self.idxs = [i for i in range(end_idx - start_idx)]
        random.shuffle(self.idxs) # Shuffle the indices for the batch
        
        # Fetch only this batch of data using offset and limit
        batch_dataset = (self.lazy_dataset
                    .slice(start_idx, end_idx - start_idx)
                    .collect())
        
        # Process features and target
        features = batch_dataset.select(self.feature_col)
        features = torch.tensor(np.array([convert_to_array_bitmap(bitmaps) for bitmaps in features["bitmaps"]]), dtype=torch.float32)
        
        played_moves = batch_dataset.select(self.target_col).to_numpy()
        targets = torch.tensor(played_moves, dtype=torch.long)

        valid_moves = batch_dataset.select("validMoves")
        valid_moves = torch.tensor(np.array([json.loads(row) for row in valid_moves["validMoves"]], dtype=np.uint64))
        # valid_moves = np.zeros(shape=np.shape(played_moves))


        self.cached_batch_id = batch_id
        self.cached_batches[self.cached_batch_id] = (features, targets, valid_moves)
        
        return self.cached_batches[batch_id]

    def __getitem__(self, idx):
        # Calculate which batch this index belongs to
        batch_id = idx // self.batch_size
        # Get the batch
        features, targets, valid_moves = self._get_batch(batch_id)

        # Get the item from the batch
        idx_in_batch = self.idxs[idx % self.batch_size]  # Get the next index from the shuffled list

        return features[idx_in_batch], targets[idx_in_batch], valid_moves[idx_in_batch]

In [5]:
VALID_MOVE_LOSS = -0.1/1024
INVALID_MOVE_LOSS = +10/1024

piece_to_idx = {
    'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
    'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
}

def bitmaps_to_board(bitmaps):
    board = chess.Board(fen = "8/8/8/8/8/8/8/8 w KQkq - 0 1")
    for piece_name, piece_idx in piece_to_idx.items():
        for row in range(8):
            for col in range(8):
                if bitmaps[piece_idx][row][col] == 1:
                    piece = chess.Piece(piece_idx, chess.WHITE) if piece_idx < 6 else chess.Piece(piece_idx - 6, chess.BLACK)
                    board.set_piece_at(row * 8 + col, piece)
    return board

# LOSS FUNCTION
cross_entropy_loss = torch.nn.CrossEntropyLoss()
def loss_fn(outputs: torch.Tensor, targets, valid_moves): # targets tensor with twp elements: [targets, possible_moves]
    loss = cross_entropy_loss(outputs, targets)

    played_move = [output.argmax().item() for output in outputs]
    penalization = 0
    invalid_played = False
    for move, valid_moves in zip(played_move, valid_moves):
        num_idx = move // 64
        bit_idx = move % 64
        num = valid_moves[num_idx]

        shift_tensor = torch.tensor(1 << bit_idx, dtype=num.dtype) 

        if num & shift_tensor != 0:
            # invalid_played = True
            penalization += INVALID_MOVE_LOSS
        else:
            penalization += VALID_MOVE_LOSS
            
    # if invalid_played:
        # penalization += INVALID_MOVE_LOSS
    return torch.add(loss, penalization)

In [None]:
DATASET_PATH = '../dataset/processed/results_ELO_1500.csv'
NUM_EPOCHS = 400

TRAINING_MODE = "pieces" # "pieces" or "moves"
MODEL_WEIGHTS_OUTPUT_PATH = "./models/CompleteModel_new_architecture_skipmoves_FINISHED.pth"

## !IMPORTANT: This dictates how much ram will be used, and how much data will be loaded
# 1_280_000 loads around 5gb, dont push this too high as it will crash if ram deplects
# NUM_EXAMPLES_TO_LOAD_PER_FETCH = 1_280_000 
NUM_EXAMPLES_TO_LOAD_PER_FETCH = 640_000 
BATCH_SIZE = 1024

test = ChessEvalDataset(file = DATASET_PATH, model=TRAINING_MODE, load_batch_size = NUM_EXAMPLES_TO_LOAD_PER_FETCH)
loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CompleteChessBotNetwork().to(device)

# Continue with pretrained weights
model.load_state_dict(torch.load('./models/CompleteModel_new_architecture_skipmoves_epoch-30.pth'))
# model.load_state_dict(torch.load('./models/CompleteModel_noskip_version2_epoch-170.pth'))

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

print(torch.cuda.is_available())
print("Using device: ", device)
for epoch in range(7, NUM_EPOCHS+1):
    model.train()
    t0 = time()
    avg_loss = 0.0
    i = 1
    correct = 0
    for board_tensor, target_eval, valid_moves in tqdm.tqdm(loader):        
        board_tensor_gpu, target_eval_gpu = board_tensor.to(device), target_eval.to(device)  # Move data to GPU
        optimizer.zero_grad()
        pred = model(board_tensor_gpu)

        # Compute loss with valid move vlaidaiton
        # loss = loss_fn(board_tensor, pred, target_eval_gpu.squeeze(1))
        loss = loss_fn(pred, target_eval_gpu.squeeze(1), valid_moves)
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        avg_loss += loss.item()
        i+=1
        correct += (pred.argmax(dim=1) == target_eval_gpu[:, 0]).sum().item()
    accuracy = 100 * correct / (len(loader) * BATCH_SIZE)
    tf = time()
    print(f"Epoch {epoch} - {avg_loss / len(loader):.4f} | Accuracy: {accuracy:.2f}% | Time: {tf-t0}")

    if epoch % 5 == 0:
        torch.save(model.state_dict(), f"./models/CompleteModel_new_architecture_skipmoves_epoch-{epoch}.pth")
# Save the trained model
torch.save(model.state_dict(), MODEL_WEIGHTS_OUTPUT_PATH)

# filtra 1100 - 14000

True
Using device:  cuda


100%|██████████| 2744/2744 [03:45<00:00, 12.18it/s]


Epoch 7 - 1.4678 | Accuracy: 52.09% | Time: 225.31490898132324


100%|██████████| 2744/2744 [03:44<00:00, 12.25it/s]


Epoch 8 - 1.4615 | Accuracy: 52.23% | Time: 224.08702063560486


100%|██████████| 2744/2744 [03:43<00:00, 12.26it/s] 


Epoch 9 - 1.4574 | Accuracy: 52.34% | Time: 223.81604194641113


100%|██████████| 2744/2744 [03:43<00:00, 12.26it/s]


Epoch 10 - 1.4530 | Accuracy: 52.46% | Time: 223.85280752182007


100%|██████████| 2744/2744 [03:44<00:00, 12.24it/s]


Epoch 11 - 1.4490 | Accuracy: 52.57% | Time: 224.1455385684967


100%|██████████| 2744/2744 [03:43<00:00, 12.26it/s]


Epoch 12 - 1.4453 | Accuracy: 52.64% | Time: 223.872878074646


100%|██████████| 2744/2744 [03:43<00:00, 12.25it/s] 


Epoch 13 - 1.4413 | Accuracy: 52.77% | Time: 223.9711172580719


100%|██████████| 2744/2744 [03:44<00:00, 12.25it/s]


Epoch 14 - 1.4377 | Accuracy: 52.87% | Time: 224.0687985420227


100%|██████████| 2744/2744 [03:43<00:00, 12.26it/s]


Epoch 15 - 1.4343 | Accuracy: 52.93% | Time: 223.81940007209778


100%|██████████| 2744/2744 [03:43<00:00, 12.25it/s] 


Epoch 16 - 1.4312 | Accuracy: 53.04% | Time: 223.97946333885193


100%|██████████| 2744/2744 [03:43<00:00, 12.27it/s]


Epoch 17 - 1.4277 | Accuracy: 53.12% | Time: 223.65896201133728


100%|██████████| 2744/2744 [03:43<00:00, 12.27it/s]


Epoch 18 - 1.4250 | Accuracy: 53.19% | Time: 223.61775946617126


100%|██████████| 2744/2744 [03:43<00:00, 12.27it/s] 


Epoch 19 - 1.4216 | Accuracy: 53.31% | Time: 223.72099041938782


100%|██████████| 2744/2744 [03:44<00:00, 12.25it/s]


Epoch 20 - 1.4192 | Accuracy: 53.35% | Time: 224.0343108177185


100%|██████████| 2744/2744 [03:43<00:00, 12.26it/s]


Epoch 21 - 1.4160 | Accuracy: 53.41% | Time: 223.76253604888916


100%|██████████| 2744/2744 [03:43<00:00, 12.26it/s]


Epoch 22 - 1.4137 | Accuracy: 53.51% | Time: 223.74419045448303


100%|██████████| 2744/2744 [03:44<00:00, 12.24it/s]


Epoch 23 - 1.4107 | Accuracy: 53.58% | Time: 224.11111402511597


100%|██████████| 2744/2744 [03:43<00:00, 12.27it/s]


Epoch 24 - 1.4083 | Accuracy: 53.62% | Time: 223.70377802848816


100%|██████████| 2744/2744 [03:43<00:00, 12.26it/s]


Epoch 25 - 1.4058 | Accuracy: 53.70% | Time: 223.761785030365


100%|██████████| 2744/2744 [03:41<00:00, 12.38it/s]


Epoch 26 - 1.4032 | Accuracy: 53.77% | Time: 221.59964752197266


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s]


Epoch 27 - 1.4008 | Accuracy: 53.83% | Time: 221.45677256584167


100%|██████████| 2744/2744 [03:41<00:00, 12.38it/s]


Epoch 28 - 1.3985 | Accuracy: 53.92% | Time: 221.63443303108215


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s]


Epoch 29 - 1.3967 | Accuracy: 53.94% | Time: 221.54545307159424


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s]


Epoch 30 - 1.3941 | Accuracy: 54.01% | Time: 221.3382499217987


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s]


Epoch 31 - 1.3923 | Accuracy: 54.05% | Time: 221.44609475135803


100%|██████████| 2744/2744 [03:41<00:00, 12.37it/s]


Epoch 32 - 1.3902 | Accuracy: 54.12% | Time: 221.80693459510803


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 33 - 1.3882 | Accuracy: 54.17% | Time: 221.52601075172424


100%|██████████| 2744/2744 [03:40<00:00, 12.42it/s] 


Epoch 34 - 1.3859 | Accuracy: 54.23% | Time: 220.9802758693695


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 35 - 1.3839 | Accuracy: 54.28% | Time: 221.4046139717102


100%|██████████| 2744/2744 [03:41<00:00, 12.37it/s] 


Epoch 36 - 1.3824 | Accuracy: 54.31% | Time: 221.75119352340698


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 37 - 1.3804 | Accuracy: 54.39% | Time: 221.03418636322021


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 38 - 1.3786 | Accuracy: 54.41% | Time: 221.15270733833313


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 39 - 1.3764 | Accuracy: 54.47% | Time: 221.41660928726196


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 40 - 1.3750 | Accuracy: 54.54% | Time: 221.09187746047974


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 41 - 1.3732 | Accuracy: 54.58% | Time: 221.40109872817993


100%|██████████| 2744/2744 [03:41<00:00, 12.38it/s] 


Epoch 42 - 1.3720 | Accuracy: 54.60% | Time: 221.5828766822815


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 43 - 1.3704 | Accuracy: 54.63% | Time: 221.21905136108398


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 44 - 1.3683 | Accuracy: 54.68% | Time: 221.43519020080566


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 45 - 1.3675 | Accuracy: 54.75% | Time: 221.51767992973328


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 46 - 1.3655 | Accuracy: 54.80% | Time: 221.2453155517578


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 47 - 1.3642 | Accuracy: 54.80% | Time: 221.4122667312622


100%|██████████| 2744/2744 [03:41<00:00, 12.37it/s] 


Epoch 48 - 1.3627 | Accuracy: 54.86% | Time: 221.86811113357544


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 49 - 1.3610 | Accuracy: 54.91% | Time: 221.53465366363525


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 50 - 1.3598 | Accuracy: 54.94% | Time: 221.5258936882019


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 51 - 1.3583 | Accuracy: 54.98% | Time: 221.55766010284424


100%|██████████| 2744/2744 [03:41<00:00, 12.38it/s] 


Epoch 52 - 1.3573 | Accuracy: 54.98% | Time: 221.73397517204285


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 53 - 1.3555 | Accuracy: 55.03% | Time: 221.41699314117432


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 54 - 1.3543 | Accuracy: 55.07% | Time: 221.52252745628357


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 55 - 1.3532 | Accuracy: 55.09% | Time: 221.34529995918274


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 56 - 1.3521 | Accuracy: 55.13% | Time: 221.426043510437


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 57 - 1.3504 | Accuracy: 55.19% | Time: 221.28192734718323


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 58 - 1.3495 | Accuracy: 55.19% | Time: 221.3320062160492


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 59 - 1.3481 | Accuracy: 55.25% | Time: 221.47092604637146


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 60 - 1.3470 | Accuracy: 55.29% | Time: 221.28486490249634


100%|██████████| 2744/2744 [03:41<00:00, 12.42it/s] 


Epoch 61 - 1.3456 | Accuracy: 55.33% | Time: 221.0172290802002


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 62 - 1.3448 | Accuracy: 55.34% | Time: 221.28021907806396


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 63 - 1.3435 | Accuracy: 55.35% | Time: 221.44171500205994


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 64 - 1.3423 | Accuracy: 55.43% | Time: 221.08451962471008


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 65 - 1.3413 | Accuracy: 55.43% | Time: 221.20680689811707


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 66 - 1.3405 | Accuracy: 55.47% | Time: 221.4134681224823


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 67 - 1.3390 | Accuracy: 55.49% | Time: 221.2620723247528


100%|██████████| 2744/2744 [03:40<00:00, 12.42it/s] 


Epoch 68 - 1.3384 | Accuracy: 55.53% | Time: 220.89212250709534


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 69 - 1.3373 | Accuracy: 55.53% | Time: 221.1697256565094


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 70 - 1.3363 | Accuracy: 55.57% | Time: 221.46353960037231


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 71 - 1.3353 | Accuracy: 55.59% | Time: 221.26050281524658


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 72 - 1.3344 | Accuracy: 55.63% | Time: 221.18033838272095


100%|██████████| 2744/2744 [03:40<00:00, 12.42it/s] 


Epoch 73 - 1.3333 | Accuracy: 55.66% | Time: 220.86745953559875


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 74 - 1.3324 | Accuracy: 55.67% | Time: 221.3142445087433


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 75 - 1.3317 | Accuracy: 55.70% | Time: 221.17776918411255


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 76 - 1.3306 | Accuracy: 55.73% | Time: 221.41621041297913


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 77 - 1.3296 | Accuracy: 55.75% | Time: 221.09048438072205


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 78 - 1.3287 | Accuracy: 55.76% | Time: 221.0482051372528


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 79 - 1.3277 | Accuracy: 55.79% | Time: 221.0294599533081


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 80 - 1.3269 | Accuracy: 55.84% | Time: 221.55922675132751


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 81 - 1.3263 | Accuracy: 55.84% | Time: 221.12894201278687


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 82 - 1.3255 | Accuracy: 55.87% | Time: 221.4976634979248


100%|██████████| 2744/2744 [03:40<00:00, 12.42it/s] 


Epoch 83 - 1.3243 | Accuracy: 55.89% | Time: 220.88272166252136


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 84 - 1.3238 | Accuracy: 55.91% | Time: 221.28628492355347


100%|██████████| 2744/2744 [03:41<00:00, 12.38it/s] 


Epoch 85 - 1.3230 | Accuracy: 55.91% | Time: 221.73495507240295


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 86 - 1.3224 | Accuracy: 55.95% | Time: 221.17509174346924


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 87 - 1.3213 | Accuracy: 55.97% | Time: 221.08336210250854


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 88 - 1.3202 | Accuracy: 56.01% | Time: 221.1987144947052


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 89 - 1.3199 | Accuracy: 56.02% | Time: 221.06498622894287


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 90 - 1.3192 | Accuracy: 56.03% | Time: 221.1177544593811


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 91 - 1.3183 | Accuracy: 56.06% | Time: 221.04199123382568


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 92 - 1.3177 | Accuracy: 56.09% | Time: 221.2211949825287


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 93 - 1.3170 | Accuracy: 56.11% | Time: 221.34997367858887


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 94 - 1.3161 | Accuracy: 56.12% | Time: 221.1547975540161


100%|██████████| 2744/2744 [03:40<00:00, 12.42it/s] 


Epoch 95 - 1.3151 | Accuracy: 56.15% | Time: 220.93915033340454


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 96 - 1.3148 | Accuracy: 56.15% | Time: 221.46981072425842


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 97 - 1.3140 | Accuracy: 56.17% | Time: 221.41094255447388


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 98 - 1.3135 | Accuracy: 56.19% | Time: 221.22131872177124


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 99 - 1.3127 | Accuracy: 56.22% | Time: 221.06176662445068


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 100 - 1.3118 | Accuracy: 56.24% | Time: 221.3724548816681


100%|██████████| 2744/2744 [03:41<00:00, 12.42it/s] 


Epoch 101 - 1.3116 | Accuracy: 56.25% | Time: 221.0142741203308


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 102 - 1.3108 | Accuracy: 56.27% | Time: 221.14248776435852


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 103 - 1.3100 | Accuracy: 56.31% | Time: 221.26588106155396


100%|██████████| 2744/2744 [03:41<00:00, 12.38it/s]


Epoch 104 - 1.3091 | Accuracy: 56.31% | Time: 221.5730438232422


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 105 - 1.3089 | Accuracy: 56.32% | Time: 221.13218593597412


100%|██████████| 2744/2744 [03:40<00:00, 12.42it/s] 


Epoch 106 - 1.3079 | Accuracy: 56.32% | Time: 220.90293526649475


100%|██████████| 2744/2744 [03:40<00:00, 12.42it/s] 


Epoch 107 - 1.3077 | Accuracy: 56.34% | Time: 220.91570711135864


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 108 - 1.3070 | Accuracy: 56.37% | Time: 221.41190695762634


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 109 - 1.3062 | Accuracy: 56.38% | Time: 221.132337808609


100%|██████████| 2744/2744 [03:40<00:00, 12.42it/s] 


Epoch 110 - 1.3056 | Accuracy: 56.40% | Time: 220.94219040870667


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 111 - 1.3048 | Accuracy: 56.43% | Time: 221.02834630012512


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 112 - 1.3044 | Accuracy: 56.44% | Time: 221.17348885536194


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 113 - 1.3037 | Accuracy: 56.46% | Time: 221.27929043769836


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s]


Epoch 114 - 1.3031 | Accuracy: 56.47% | Time: 221.20513653755188


100%|██████████| 2744/2744 [03:40<00:00, 12.43it/s] 


Epoch 115 - 1.3029 | Accuracy: 56.48% | Time: 220.7624945640564


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 116 - 1.3020 | Accuracy: 56.49% | Time: 221.03318119049072


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 117 - 1.3018 | Accuracy: 56.49% | Time: 221.10520768165588


100%|██████████| 2744/2744 [03:40<00:00, 12.42it/s] 


Epoch 118 - 1.3009 | Accuracy: 56.55% | Time: 220.95498895645142


100%|██████████| 2744/2744 [03:40<00:00, 12.42it/s] 


Epoch 119 - 1.3005 | Accuracy: 56.56% | Time: 220.97241973876953


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 120 - 1.3001 | Accuracy: 56.53% | Time: 221.05061769485474


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 121 - 1.2993 | Accuracy: 56.60% | Time: 221.38500928878784


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 122 - 1.2991 | Accuracy: 56.58% | Time: 221.13979530334473


100%|██████████| 2744/2744 [03:40<00:00, 12.42it/s] 


Epoch 123 - 1.2982 | Accuracy: 56.61% | Time: 220.98925828933716


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 124 - 1.2976 | Accuracy: 56.62% | Time: 221.05432057380676


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 125 - 1.2972 | Accuracy: 56.62% | Time: 221.28706169128418


100%|██████████| 2744/2744 [03:40<00:00, 12.42it/s] 


Epoch 126 - 1.2969 | Accuracy: 56.62% | Time: 220.99677658081055


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 127 - 1.2965 | Accuracy: 56.66% | Time: 221.07191681861877


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 128 - 1.2957 | Accuracy: 56.69% | Time: 221.51365756988525


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 129 - 1.2948 | Accuracy: 56.72% | Time: 221.3183467388153


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 130 - 1.2948 | Accuracy: 56.70% | Time: 221.26679849624634


100%|██████████| 2744/2744 [03:41<00:00, 12.37it/s] 


Epoch 131 - 1.2944 | Accuracy: 56.71% | Time: 221.81186699867249


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 132 - 1.2937 | Accuracy: 56.72% | Time: 221.18114066123962


100%|██████████| 2744/2744 [03:41<00:00, 12.38it/s] 


Epoch 133 - 1.2932 | Accuracy: 56.74% | Time: 221.65952491760254


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 134 - 1.2930 | Accuracy: 56.76% | Time: 221.46853232383728


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 135 - 1.2923 | Accuracy: 56.76% | Time: 221.27719068527222


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 136 - 1.2920 | Accuracy: 56.77% | Time: 221.1147496700287


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 137 - 1.2912 | Accuracy: 56.80% | Time: 221.24831581115723


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 138 - 1.2913 | Accuracy: 56.83% | Time: 221.14724445343018


100%|██████████| 2744/2744 [03:41<00:00, 12.42it/s] 


Epoch 139 - 1.2907 | Accuracy: 56.83% | Time: 221.0099127292633


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 140 - 1.2901 | Accuracy: 56.82% | Time: 221.1374795436859


100%|██████████| 2744/2744 [03:40<00:00, 12.42it/s] 


Epoch 141 - 1.2898 | Accuracy: 56.84% | Time: 220.9501712322235


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 142 - 1.2892 | Accuracy: 56.86% | Time: 221.12146425247192


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 143 - 1.2889 | Accuracy: 56.88% | Time: 221.10291075706482


100%|██████████| 2744/2744 [03:41<00:00, 12.42it/s] 


Epoch 144 - 1.2887 | Accuracy: 56.86% | Time: 221.01215720176697


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 145 - 1.2880 | Accuracy: 56.89% | Time: 221.2523171901703


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 146 - 1.2876 | Accuracy: 56.91% | Time: 221.04236316680908


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 147 - 1.2872 | Accuracy: 56.90% | Time: 221.03390979766846


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 148 - 1.2867 | Accuracy: 56.93% | Time: 221.18651032447815


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 149 - 1.2865 | Accuracy: 56.91% | Time: 221.21582627296448


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 150 - 1.2861 | Accuracy: 56.90% | Time: 221.1130301952362


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 151 - 1.2856 | Accuracy: 56.96% | Time: 221.0703432559967


100%|██████████| 2744/2744 [03:41<00:00, 12.40it/s] 


Epoch 152 - 1.2852 | Accuracy: 56.97% | Time: 221.24461841583252


100%|██████████| 2744/2744 [03:41<00:00, 12.38it/s] 


Epoch 153 - 1.2845 | Accuracy: 56.99% | Time: 221.57728385925293


100%|██████████| 2744/2744 [03:40<00:00, 12.42it/s] 


Epoch 154 - 1.2844 | Accuracy: 56.98% | Time: 220.89800262451172


100%|██████████| 2744/2744 [03:41<00:00, 12.39it/s] 


Epoch 155 - 1.2841 | Accuracy: 56.98% | Time: 221.49611282348633


100%|██████████| 2744/2744 [03:41<00:00, 12.38it/s] 


Epoch 156 - 1.2835 | Accuracy: 57.03% | Time: 221.60513615608215


100%|██████████| 2744/2744 [03:41<00:00, 12.41it/s] 


Epoch 157 - 1.2830 | Accuracy: 57.04% | Time: 221.04370403289795


100%|██████████| 2744/2744 [03:51<00:00, 11.83it/s] 


Epoch 158 - 1.2831 | Accuracy: 57.01% | Time: 231.92207431793213


100%|██████████| 2744/2744 [05:07<00:00,  8.91it/s]  


Epoch 159 - 1.2825 | Accuracy: 57.04% | Time: 307.81230425834656


100%|██████████| 2744/2744 [05:08<00:00,  8.89it/s]  


Epoch 160 - 1.2820 | Accuracy: 57.05% | Time: 308.5661563873291


 97%|█████████▋| 2654/2744 [04:59<00:10,  8.86it/s]  


KeyboardInterrupt: 