In [1]:
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
import numpy as np
import tqdm as tqdm
import torch.optim as optim
import os

Kaggle Data set
https://www.kaggle.com/datasets/ronakbadhe/chess-evaluations

In [2]:
df = pd.read_csv('Data/chessData.csv')

In [3]:
print(df.head())
df = df[:(len(df)//32)].copy()
df = df.dropna()
print(len(df))


                                                 FEN Evaluation
0  rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR ...        -10
1  rnbqkbnr/pppp1ppp/4p3/8/4P3/8/PPPP1PPP/RNBQKBN...        +56
2  rnbqkbnr/pppp1ppp/4p3/8/3PP3/8/PPP2PPP/RNBQKBN...         -9
3  rnbqkbnr/ppp2ppp/4p3/3p4/3PP3/8/PPP2PPP/RNBQKB...        +52
4  rnbqkbnr/ppp2ppp/4p3/3p4/3PP3/8/PPPN1PPP/R1BQK...        -26
404938


In [4]:
def clean_evals(entry:str):
    if entry[0] == '#':
        side = 1000 if (entry[1] == '+') else -1000
        return int(entry[2:]) + side
    else:
        return int(entry)



In [5]:
df['Evaluation'] = df['Evaluation'].apply(clean_evals)

In [6]:
def convert_to_index(pos: str) -> int:
    if pos == "-":
        return 0
    file_map = {value: key for key, value in enumerate("abcdefgh")}
    col = file_map[pos[0]]
    row = int(pos[1]) - 1
    # print(f"col: {col}, row: {row}")
    return 1 << (row * 8) + col


In [7]:

# Convert binary vectors to PyTorch tensors
# tensor_data = torch.tensor(binary_vectors, dtype=torch.float32)
def convert_pieces_to_vectors(piece_map: dict[str: int]) -> None:
    for piece, board in piece_map.items():
        piece_map[piece] = np.array([(board >> i) & 1 for i in range(64)], dtype=np.float32)

In [8]:

def convert_fen(fen: str):
    piece_map = {value: 0 for key, value in enumerate("KQRBNPkqrbnp")}
    # print(f"Piece_Map: {piece_map}")
    ranks = fen.split("/")
    # print(f"ranks: {ranks}")
    curr_rank = 8
    while curr_rank > 0:
        r = ranks[8-curr_rank]
        file_index = 0
        # print(f"rank: {curr_rank-1} | {r}")
        for i in range(len(r)):
            c = r[i]
            if file_index >= 8:
                convert_pieces_to_vectors(piece_map)
                states = r[i:].split(" ")[1:]
                # print(f"States: {states}")
                piece_map["moves"] = 0 if states[0] == 'w' else 1
                piece_map['WhiteKing'] = 1 if "K" in states[1] else 0
                piece_map['WhiteQueen'] = 1 if 'Q' in states[1] else 0
                piece_map['BlackKing'] = 1 if 'k' in states[1] else 0
                piece_map['BlackQueen'] = 1 if 'q' in states[1] else 0
                piece_map["en_passant"] =  np.array([(convert_to_index(states[2]) >> i) & 1 for i in range(64)], dtype=np.float32)
                piece_map["half_clock"] = int(states[3])
                piece_map["full_clock"] = int(states[4])
                return piece_map
            else:
                if c.isdigit():
                    file_index += int(c)
                else:
                    # print(f"c: {c}")
                    piece_map[c] += (1 << ((curr_rank * 8) + file_index))
                    file_index += 1
        curr_rank -= 1
    return piece_map
                


In [9]:
df['Longs'] = df['FEN'].apply(convert_fen)
df = pd.concat([df, pd.json_normalize(df['Longs'])], axis=1)
df = df.drop(['FEN', 'Longs'], axis=1)
print(df.head())

   Evaluation                                                  K  \
0         -10  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
1          56  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
2          -9  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
3          52  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
4         -26  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   

                                                   Q  \
0  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
1  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
2  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
3  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
4  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   

                                                   R  \
0  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ...   
1  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ...   
2  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ...   
3  [0.0, 0.0, 0.0, 0.0, 0.0, 0

In [10]:
print(len(df['K']))
print(len(df['en_passant']))

404938
404938


In [12]:
print(df.columns)

Index(['Evaluation', 'K', 'Q', 'R', 'B', 'N', 'P', 'k', 'q', 'r', 'b', 'n',
       'p', 'moves', 'WhiteKing', 'WhiteQueen', 'BlackKing', 'BlackQueen',
       'en_passant', 'half_clock', 'full_clock'],
      dtype='object')


In [11]:
# Print out the dimensions and types of all relevant columns
for col in ['K', 'Q', 'R', 'B', 'N', 'P', 'k', 'q', 'r', 'b', 'n', 'p', 'en_passant']:
    print(f"{col}: Type: {type(df[col].iloc[0])}, Shape: {np.array(df[col].iloc[0]).shape}")

# Ensure 'en_passant' is an array
df['en_passant'] = df['en_passant'].apply(lambda x: np.array(x, dtype=np.float32) if not isinstance(x, np.ndarray) else x)

# Recheck dimensions and types after conversion
for col in ['K', 'Q', 'R', 'B', 'N', 'P', 'k', 'q', 'r', 'b', 'n', 'p', 'en_passant']:
    print(f"{col}: Type: {type(df[col].iloc[0])}, Shape: {np.array(df[col].iloc[0]).shape}")


K: Type: <class 'numpy.ndarray'>, Shape: (64,)
Q: Type: <class 'numpy.ndarray'>, Shape: (64,)
R: Type: <class 'numpy.ndarray'>, Shape: (64,)
B: Type: <class 'numpy.ndarray'>, Shape: (64,)
N: Type: <class 'numpy.ndarray'>, Shape: (64,)
P: Type: <class 'numpy.ndarray'>, Shape: (64,)
k: Type: <class 'numpy.ndarray'>, Shape: (64,)
q: Type: <class 'numpy.ndarray'>, Shape: (64,)
r: Type: <class 'numpy.ndarray'>, Shape: (64,)
b: Type: <class 'numpy.ndarray'>, Shape: (64,)
n: Type: <class 'numpy.ndarray'>, Shape: (64,)
p: Type: <class 'numpy.ndarray'>, Shape: (64,)
en_passant: Type: <class 'numpy.ndarray'>, Shape: (64,)
K: Type: <class 'numpy.ndarray'>, Shape: (64,)
Q: Type: <class 'numpy.ndarray'>, Shape: (64,)
R: Type: <class 'numpy.ndarray'>, Shape: (64,)
B: Type: <class 'numpy.ndarray'>, Shape: (64,)
N: Type: <class 'numpy.ndarray'>, Shape: (64,)
P: Type: <class 'numpy.ndarray'>, Shape: (64,)
k: Type: <class 'numpy.ndarray'>, Shape: (64,)
q: Type: <class 'numpy.ndarray'>, Shape: (64,)
r: T

In [13]:
# Example DataFrame structure
# df contains columns like 'K', 'Q', 'R', etc. for piece positions
# and additional columns for 'moves', 'half_clock', 'full_clock', and castling rights

def prepare_input(row):
    # Extract the binary vectors for each piece type
    channels = []
    for piece in ['K', 'Q', 'R', 'B', 'N', 'P', 'k', 'q', 'r', 'b', 'n', 'p', 'en_passant']:
        piece_vector = np.array(row[piece]).reshape(8, 8)
        channels.append(piece_vector)
    
    # Additional channel for castling rights
    castling_rights_channel = np.zeros((8, 8))
    if row['WhiteKing']:
        castling_rights_channel += 1
    if row['WhiteQueen']:
        castling_rights_channel += 2
    if row['BlackKing']:
        castling_rights_channel += 4
    if row['BlackQueen']:
        castling_rights_channel += 8
    channels.append(castling_rights_channel)
    
    # Additional channels for 'moves', 'half_clock', and 'full_clock'
    moves_channel = np.full((8, 8), row['moves'])
    half_clock_channel = np.full((8, 8), row['half_clock'])
    full_clock_channel = np.full((8, 8), row['full_clock'])
    channels.extend([moves_channel, half_clock_channel, full_clock_channel])
    
    input_tensor = np.stack(channels)
    return input_tensor

# Apply this function to all rows to get the input data for the CNN
df['input_tensor'] = df.apply(prepare_input, axis=1)


In [12]:
# def concatenate_vectors(row):
#     return np.concatenate([row['K'], row['Q'], row['R'], row['B'], row['N'], row['P'], row['k'], row['q'], row['r'], row['b'], row['n'],row['p'], row['en_passant']])

# df['board_state'] = df.apply(concatenate_vectors, axis=1)

In [16]:
# X = pd.concat([df['board_state'].apply(pd.Series), df[['moves', 'half_clock', 'full_clock', 
#       'WhiteKing', 'WhiteQueen', 'BlackKing', 'BlackQueen']]], axis=1)
X = [torch.tensor(x, dtype=torch.float32) for x in df['input_tensor']]
y = df['Evaluation']
eval_scaler = MinMaxScaler(feature_range=(-1, 1))
y = eval_scaler.fit_transform(y.values.reshape(-1, 1))
# nfeats = len(X.columns)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [17]:
class Chess_Dataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = torch.tensor(y, dtype=torch.float32)
    def __len__(self):
        return len(self.y)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [36]:
BATCH_SIZE = 32
CKPT_DIR = "CKPT_CNN"
LR = 0.001
EPOCHS = 1000

In [19]:
print(X_train)
print(y_train)

KeyboardInterrupt: 

In [37]:
train_dataset = Chess_Dataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = Chess_Dataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [21]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [38]:
class Chess_CNN(nn.Module):
    def __init__(self, input_channels=17, num_classes=1):
        super(Chess_CNN, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 1 * 1, 512)  # Corrected input size
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.relu(self.conv3(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [49]:
class Chess_Model(nn.Module):
    def __init__(self, nfeats):
        super(Chess_Model, self).__init__()
        self.input_layer = nn.Linear(nfeats, 64)
        self.lyr1 = nn.Linear(64, 128)
        self.lyr2 = nn.Linear(128, 64)
        self.lyr3 = nn.Linear(64, 32)
        self.layer4 = nn.Linear(32, 16)
        self.layer5 - nn.
    
    def forward(self, x):
        x = torch.relu(self.input_layer(x))
        x = torch.relu(self.lyr1(x))
        x = torch.relu(self.lyr2(x))
        x = self.lyr3(x)
        return x
    
model = Chess_Model(nfeats).to(device)

In [117]:
print(df['Evaluation'].head())
z = eval_scaler.inverse_transform(y_test).astype(int)
print(z)

0   -10
1    56
2    -9
3    52
4   -26
Name: Evaluation, dtype: int64
[[ 332]
 [-110]
 [-797]
 ...
 [ -66]
 [ 150]
 [   0]]


In [23]:
def train(dataloader, model, optimizer, epoch):
    model = model.to(device)
    loss_fn = nn.MSELoss()
    model.train()
    with tqdm.tqdm(dataloader, unit="batch") as tbatch:
        for X, y in tbatch:
            X = X.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            pred = model(X)
            loss = loss_fn(pred, y)
            loss.backward()
            optimizer.step()
    torch.save(
        {
            "epoch" : epoch,
            "model_state_dict" : model.state_dict(),
            "optimizer_state_dict" : optimizer.state_dict(),
            "loss" : loss
        },
        f"{CKPT_DIR}/ckpt{epoch}.pt",
    )

In [34]:
def test(dataloader, model, dataset_name):
    loss_fn = nn.MSELoss()
    model = model.to(device)
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)    
            y = y.to(device)
            
            # Forward pass
            pred = model(X)
           
            # Compute loss
            curr_loss = loss_fn(pred, y).item()
            test_loss += curr_loss
            
            # Append predictions and actual values
            pred = eval_scaler.inverse_transform(pred.cpu().numpy()).astype(int)
            y = eval_scaler.inverse_transform(y.cpu().numpy()).astype(int)
            
            # Round predictions and count correct predictions
            
            correct += (pred == y).sum()
    # Compute accuracy
    accuracy = correct / len(dataloader.dataset)
    print(dataset_name)
    print(f'Accuracy: {accuracy} | Mean Squared Error: {test_loss / len(dataloader.dataset) * 1000}\n')


In [119]:
# test(train_loader, model, "TRAIN")
preds = model(torch.tensor(X_test.values, dtype=torch.float32).to(device))
preds = eval_scaler.inverse_transform(preds.cpu().detach().numpy()).astype(int)
for pred, actual, in zip(preds, z):
    print(f"Prediciton: {pred}, Actual: {actual}")


Prediciton: [57], Actual: [332]
Prediciton: [-140], Actual: [-110]
Prediciton: [69], Actual: [-797]
Prediciton: [50], Actual: [128]
Prediciton: [95], Actual: [71]
Prediciton: [53], Actual: [913]
Prediciton: [34], Actual: [90]
Prediciton: [59], Actual: [31]
Prediciton: [-21], Actual: [297]
Prediciton: [141], Actual: [0]
Prediciton: [71], Actual: [0]
Prediciton: [37], Actual: [928]
Prediciton: [-239], Actual: [72]
Prediciton: [-13], Actual: [126]
Prediciton: [169], Actual: [187]
Prediciton: [-239], Actual: [-330]
Prediciton: [-41], Actual: [-383]
Prediciton: [-44], Actual: [-358]
Prediciton: [-239], Actual: [350]
Prediciton: [104], Actual: [506]
Prediciton: [62], Actual: [15]
Prediciton: [-41], Actual: [-72]
Prediciton: [41], Actual: [-339]
Prediciton: [54], Actual: [31]
Prediciton: [72], Actual: [-12]
Prediciton: [43], Actual: [5]
Prediciton: [48], Actual: [0]
Prediciton: [83], Actual: [156]
Prediciton: [59], Actual: [50]
Prediciton: [-41], Actual: [60]
Prediciton: [48], Actual: [38]
Pr

In [25]:
def make_or_restore_model():
    model = Chess_CNN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    checkpoints = [
        CKPT_DIR + "/" + name
        for name in os.listdir(CKPT_DIR)
        if name[-1] == "t"
    ]
    if checkpoints:
        latest_checkpoint = max(checkpoints, key=os.path.getctime)
        print("Restoring from", latest_checkpoint)
        ckpt = torch.load(latest_checkpoint)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        epoch = ckpt["epoch"]
        return model, optimizer, epoch+1
    else: 
        print("Creating new model")
        return model, optimizer, 0

In [39]:
model, optimizer, epoch_start = make_or_restore_model()
model = model.to(device)
for e in range(epoch_start, EPOCHS):
    print()
    print("Epoch", e)
    print("-------")
    model.train()
    model.to(device)
    train(train_loader, model, optimizer, e)
    print()
    model.eval()
    test(train_loader, model, "Train")
    test(test_loader, model, "Test")

Creating new model

Epoch 0
-------


100%|██████████| 10124/10124 [00:32<00:00, 312.44batch/s]



Train
Accuracy: 0.0022472603796882236 | Mean Squared Error: 0.09363615496473121

Test
Accuracy: 0.0022842890304736506 | Mean Squared Error: 0.09990957902534549


Epoch 1
-------


100%|██████████| 10124/10124 [00:32<00:00, 310.88batch/s]



Train
Accuracy: 0.0029016823583886404 | Mean Squared Error: 0.09460180107344968

Test
Accuracy: 0.002815231886205364 | Mean Squared Error: 0.10083735936223763


Epoch 2
-------


100%|██████████| 10124/10124 [00:33<00:00, 306.11batch/s]



Train
Accuracy: 0.001765704584040747 | Mean Squared Error: 0.09405309116006287

Test
Accuracy: 0.001975601323652887 | Mean Squared Error: 0.10033282340916523


Epoch 3
-------


100%|██████████| 10124/10124 [00:32<00:00, 312.35batch/s]



Train
Accuracy: 0.0 | Mean Squared Error: 0.09317190995366836

Test
Accuracy: 0.0 | Mean Squared Error: 0.09942808663043759


Epoch 4
-------


100%|██████████| 10124/10124 [00:32<00:00, 312.65batch/s]



Train
Accuracy: 0.002293563821577404 | Mean Squared Error: 0.0935270063724166

Test
Accuracy: 0.002481849162838939 | Mean Squared Error: 0.09982918969585956


Epoch 5
-------


100%|██████████| 10124/10124 [00:32<00:00, 315.86batch/s]



Train
Accuracy: 0.003309152647013428 | Mean Squared Error: 0.09452814398748619

Test
Accuracy: 0.0033832172667555686 | Mean Squared Error: 0.10076426861835328


Epoch 6
-------


100%|██████████| 10124/10124 [00:31<00:00, 316.92batch/s]



Train
Accuracy: 0.0022379996913103875 | Mean Squared Error: 0.0932944986543474

Test
Accuracy: 0.0023954166049291255 | Mean Squared Error: 0.09954192831458461


Epoch 7
-------


100%|██████████| 10124/10124 [00:32<00:00, 312.10batch/s]



Train
Accuracy: 0.005436024077789783 | Mean Squared Error: 0.09317676042741892

Test
Accuracy: 0.005111868424951845 | Mean Squared Error: 0.09944639376012343


Epoch 8
-------


100%|██████████| 10124/10124 [00:32<00:00, 312.51batch/s]



Train
Accuracy: 0.0045994752276585895 | Mean Squared Error: 0.09323721602590444

Test
Accuracy: 0.004741443176766928 | Mean Squared Error: 0.09950569904939166


Epoch 9
-------


100%|██████████| 10124/10124 [00:32<00:00, 316.35batch/s]



Train
Accuracy: 0.005176724803210372 | Mean Squared Error: 0.09320275296422452

Test
Accuracy: 0.005173605966315997 | Mean Squared Error: 0.0994540650571659


Epoch 10
-------


100%|██████████| 10124/10124 [00:32<00:00, 315.56batch/s]



Train
Accuracy: 0.004121006328137058 | Mean Squared Error: 0.0932221302123947

Test
Accuracy: 0.003988245172124265 | Mean Squared Error: 0.09948472349809696


Epoch 11
-------


100%|██████████| 10124/10124 [00:32<00:00, 316.10batch/s]



Train
Accuracy: 0.10249421206976385 | Mean Squared Error: 0.09349836987322487

Test
Accuracy: 0.10305230404504372 | Mean Squared Error: 0.09974344136853518


Epoch 12
-------


100%|██████████| 10124/10124 [00:32<00:00, 313.90batch/s]



Train
Accuracy: 0.0 | Mean Squared Error: 0.09318866241326842

Test
Accuracy: 0.0 | Mean Squared Error: 0.09944858640806058


Epoch 13
-------


100%|██████████| 10124/10124 [00:31<00:00, 320.40batch/s]



Train
Accuracy: 0.0027072079024540825 | Mean Squared Error: 0.09324850277200607

Test
Accuracy: 0.002407764113201956 | Mean Squared Error: 0.09946075017679966


Epoch 14
-------


100%|██████████| 10124/10124 [00:31<00:00, 320.59batch/s]



Train
Accuracy: 0.0030714616453156353 | Mean Squared Error: 0.09317516580226984

Test
Accuracy: 0.0030127920185706526 | Mean Squared Error: 0.09944470593367327


Epoch 15
-------


100%|██████████| 10124/10124 [00:32<00:00, 314.93batch/s]



Train
Accuracy: 0.002620774810927612 | Mean Squared Error: 0.0933275755913328

Test
Accuracy: 0.002802884377932533 | Mean Squared Error: 0.09960812655445808


Epoch 16
-------


100%|██████████| 10124/10124 [00:32<00:00, 313.85batch/s]



Train
Accuracy: 0.002858465812625405 | Mean Squared Error: 0.09320199324675543

Test
Accuracy: 0.0025312391959302614 | Mean Squared Error: 0.09945390057657412


Epoch 17
-------


100%|██████████| 10124/10124 [00:32<00:00, 315.90batch/s]



Train
Accuracy: 0.0 | Mean Squared Error: 0.09317261988035237

Test
Accuracy: 0.0 | Mean Squared Error: 0.09942905792487088


Epoch 18
-------


100%|██████████| 10124/10124 [00:32<00:00, 314.37batch/s]



Train
Accuracy: 0.0026917734218243556 | Mean Squared Error: 0.09324907213287136

Test
Accuracy: 0.0026794092952042276 | Mean Squared Error: 0.09951142191734215


Epoch 19
-------


100%|██████████| 10124/10124 [00:32<00:00, 316.05batch/s]



Train
Accuracy: 0.002889334773884859 | Mean Squared Error: 0.09318053261314524

Test
Accuracy: 0.002864621919296686 | Mean Squared Error: 0.09943855542636301


Epoch 20
-------


100%|██████████| 10124/10124 [00:31<00:00, 320.14batch/s]



Train
Accuracy: 0.002012656274116376 | Mean Squared Error: 0.09332957755380973

Test
Accuracy: 0.0020620338815627005 | Mean Squared Error: 0.09957733276901824


Epoch 21
-------


100%|██████████| 10124/10124 [00:31<00:00, 319.12batch/s]



Train
Accuracy: 0.0 | Mean Squared Error: 0.09330867391532838

Test
Accuracy: 0.0 | Mean Squared Error: 0.09955662302411968


Epoch 22
-------


100%|██████████| 10124/10124 [00:31<00:00, 318.61batch/s]



Train
Accuracy: 0.00281524926686217 | Mean Squared Error: 0.09320553387587747

Test
Accuracy: 0.003074529559934805 | Mean Squared Error: 0.09943937687647945


Epoch 23
-------


100%|██████████| 10124/10124 [00:31<00:00, 322.02batch/s]



Train
Accuracy: 0.006501003241240933 | Mean Squared Error: 0.0932827185625181

Test
Accuracy: 0.006494789351508865 | Mean Squared Error: 0.09954836168858246


Epoch 24
-------


100%|██████████| 10124/10124 [00:31<00:00, 321.97batch/s]



Train
Accuracy: 0.0027874672017286616 | Mean Squared Error: 0.09323768827360106

Test
Accuracy: 0.0031239195930261275 | Mean Squared Error: 0.0994534611904167


Epoch 25
-------


100%|██████████| 10124/10124 [00:31<00:00, 318.79batch/s]



Train
Accuracy: 0.0 | Mean Squared Error: 0.09338844548840378

Test
Accuracy: 0.0 | Mean Squared Error: 0.09963156520921347


Epoch 26
-------


100%|██████████| 10124/10124 [00:31<00:00, 326.54batch/s]



Train
Accuracy: 0.0015619694397283532 | Mean Squared Error: 0.09431410107768917

Test
Accuracy: 0.0016545661085592927 | Mean Squared Error: 0.10059602080707071


Epoch 27
-------


100%|██████████| 10124/10124 [00:31<00:00, 324.55batch/s]



Train
Accuracy: 0.0028183361629881155 | Mean Squared Error: 0.09329949821782618

Test
Accuracy: 0.002815231886205364 | Mean Squared Error: 0.0995642016447885


Epoch 28
-------


100%|██████████| 10124/10124 [00:31<00:00, 324.94batch/s]



Train
Accuracy: 0.0025281679271492515 | Mean Squared Error: 0.09345356674329816

Test
Accuracy: 0.002321331555292142 | Mean Squared Error: 0.09973181679049806


Epoch 29
-------


100%|██████████| 10124/10124 [00:31<00:00, 321.24batch/s]



Train
Accuracy: 0.0015557956474764624 | Mean Squared Error: 0.09341597192178427

Test
Accuracy: 0.0016422186002864622 | Mean Squared Error: 0.09967544332934034


Epoch 30
-------


100%|██████████| 10124/10124 [00:32<00:00, 315.47batch/s]



Train
Accuracy: 0.0026238617070535575 | Mean Squared Error: 0.09317084193256052

Test
Accuracy: 0.0031486146095717885 | Mean Squared Error: 0.09942819089427773


Epoch 31
-------


100%|██████████| 10124/10124 [00:31<00:00, 323.66batch/s]



Train
Accuracy: 0.0 | Mean Squared Error: 0.09325334986339646

Test
Accuracy: 0.0 | Mean Squared Error: 0.09951599860410652


Epoch 32
-------


100%|██████████| 10124/10124 [00:31<00:00, 325.72batch/s]



Train
Accuracy: 0.0017842259607964192 | Mean Squared Error: 0.09336617139370332

Test
Accuracy: 0.0017163036499234455 | Mean Squared Error: 0.09962260249312149


Epoch 33
-------


100%|██████████| 10124/10124 [00:31<00:00, 322.77batch/s]



Train
Accuracy: 0.001722488038277512 | Mean Squared Error: 0.09339635180105962

Test
Accuracy: 0.0017409986664691066 | Mean Squared Error: 0.09963987169085124


Epoch 34
-------


100%|██████████| 10124/10124 [00:31<00:00, 319.59batch/s]



Train
Accuracy: 0.0053465040901373665 | Mean Squared Error: 0.09335433878031953

Test
Accuracy: 0.005185953474588828 | Mean Squared Error: 0.09961976263955497


Epoch 35
-------


100%|██████████| 10124/10124 [00:31<00:00, 321.88batch/s]



Train
Accuracy: 0.0 | Mean Squared Error: 0.09378721389251843

Test
Accuracy: 0.0 | Mean Squared Error: 0.10002445544718928


Epoch 36
-------


100%|██████████| 10124/10124 [00:32<00:00, 312.50batch/s]



Train
Accuracy: 0.005639759222102176 | Mean Squared Error: 0.09317903131717277

Test
Accuracy: 0.005260038524225811 | Mean Squared Error: 0.0994357526885534


Epoch 37
-------


100%|██████████| 10124/10124 [00:32<00:00, 307.36batch/s]



Train
Accuracy: 0.004121006328137058 | Mean Squared Error: 0.0932279113527045

Test
Accuracy: 0.003988245172124265 | Mean Squared Error: 0.0994893508774107


Epoch 38
-------


100%|██████████| 10124/10124 [00:33<00:00, 303.70batch/s]



Train
Accuracy: 0.002867726501003241 | Mean Squared Error: 0.09326574394704416

Test
Accuracy: 0.0025559342124759224 | Mean Squared Error: 0.09952777210339996


Epoch 39
-------


100%|██████████| 10124/10124 [00:33<00:00, 300.32batch/s]



Train
Accuracy: 0.0027596851365951536 | Mean Squared Error: 0.09317251301862456

Test
Accuracy: 0.0028275793944781942 | Mean Squared Error: 0.09942697939468877


Epoch 40
-------


100%|██████████| 10124/10124 [00:34<00:00, 296.64batch/s]



Train
Accuracy: 0.001935483870967742 | Mean Squared Error: 0.09371550010498121

Test
Accuracy: 0.0020620338815627005 | Mean Squared Error: 0.09995283901115136


Epoch 41
-------


100%|██████████| 10124/10124 [00:34<00:00, 295.03batch/s]



Train
Accuracy: 0.0 | Mean Squared Error: 0.09318854840554851

Test
Accuracy: 0.0 | Mean Squared Error: 0.09945482965180245


Epoch 42
-------


100%|██████████| 10124/10124 [00:33<00:00, 301.49batch/s]



Train
Accuracy: 0.005436024077789783 | Mean Squared Error: 0.09318326902670797

Test
Accuracy: 0.005111868424951845 | Mean Squared Error: 0.09944527690691846


Epoch 43
-------


100%|██████████| 10124/10124 [00:34<00:00, 295.08batch/s]



Train
Accuracy: 0.0027874672017286616 | Mean Squared Error: 0.09319103148545016

Test
Accuracy: 0.0031239195930261275 | Mean Squared Error: 0.09944999673459204


Epoch 44
-------


100%|██████████| 10124/10124 [00:34<00:00, 294.25batch/s]



Train
Accuracy: 0.004596388331532644 | Mean Squared Error: 0.09418281210458235

Test
Accuracy: 0.0045809255692201315 | Mean Squared Error: 0.10041428756787342


Epoch 45
-------


100%|██████████| 10124/10124 [00:33<00:00, 301.18batch/s]



Train
Accuracy: 0.0025621237845346503 | Mean Squared Error: 0.09349310085035509

Test
Accuracy: 0.0023830690966562948 | Mean Squared Error: 0.09975507476966365


Epoch 46
-------


100%|██████████| 10124/10124 [00:33<00:00, 298.20batch/s]



Train
Accuracy: 0.00281524926686217 | Mean Squared Error: 0.0931718544235913

Test
Accuracy: 0.003074529559934805 | Mean Squared Error: 0.09943079843180383


Epoch 47
-------


100%|██████████| 10124/10124 [00:34<00:00, 297.30batch/s]



Train
Accuracy: 0.003969748417965735 | Mean Squared Error: 0.09318247138153267

Test
Accuracy: 0.003691904973576332 | Mean Squared Error: 0.0994369629712603


Epoch 48
-------


100%|██████████| 10124/10124 [00:33<00:00, 298.17batch/s]



Train
Accuracy: 0.0053465040901373665 | Mean Squared Error: 0.09335109156669294

Test
Accuracy: 0.005185953474588828 | Mean Squared Error: 0.09961889937038071


Epoch 49
-------


100%|██████████| 10124/10124 [00:33<00:00, 298.51batch/s]



Train
Accuracy: 0.10249421206976385 | Mean Squared Error: 0.09348670896289521

Test
Accuracy: 0.10305230404504372 | Mean Squared Error: 0.09972659099628246


Epoch 50
-------


100%|██████████| 10124/10124 [00:34<00:00, 295.88batch/s]



Train
Accuracy: 0.004121006328137058 | Mean Squared Error: 0.09322403736697653

Test
Accuracy: 0.003988245172124265 | Mean Squared Error: 0.09948528699975803


Epoch 51
-------


100%|██████████| 10124/10124 [00:33<00:00, 299.01batch/s]



Train
Accuracy: 0.005639759222102176 | Mean Squared Error: 0.09317905842596012

Test
Accuracy: 0.005260038524225811 | Mean Squared Error: 0.09943735684477016


Epoch 52
-------


100%|██████████| 10124/10124 [00:33<00:00, 299.83batch/s]



Train
Accuracy: 0.002620774810927612 | Mean Squared Error: 0.09332599174246205

Test
Accuracy: 0.002802884377932533 | Mean Squared Error: 0.09959137786144005


Epoch 53
-------


100%|██████████| 10124/10124 [00:34<00:00, 296.79batch/s]



Train
Accuracy: 0.0 | Mean Squared Error: 0.09318710511115408

Test
Accuracy: 0.0 | Mean Squared Error: 0.09945134607893552


Epoch 54
-------


100%|██████████| 10124/10124 [00:34<00:00, 296.23batch/s]



Train
Accuracy: 0.00281524926686217 | Mean Squared Error: 0.09317573843686687

Test
Accuracy: 0.003074529559934805 | Mean Squared Error: 0.09942874431675026


Epoch 55
-------


100%|██████████| 10124/10124 [00:34<00:00, 294.99batch/s]



Train
Accuracy: 0.005639759222102176 | Mean Squared Error: 0.09317886400889136

Test
Accuracy: 0.005260038524225811 | Mean Squared Error: 0.09943706447782262


Epoch 56
-------


100%|██████████| 10124/10124 [00:34<00:00, 296.88batch/s]



Train
Accuracy: 0.0 | Mean Squared Error: 0.09325497123866715

Test
Accuracy: 0.0 | Mean Squared Error: 0.09950459823396479


Epoch 57
-------


100%|██████████| 10124/10124 [00:34<00:00, 297.11batch/s]



Train
Accuracy: 0.0 | Mean Squared Error: 0.09321766644962044

Test
Accuracy: 0.0 | Mean Squared Error: 0.09948004955168423


Epoch 58
-------


100%|██████████| 10124/10124 [00:34<00:00, 291.86batch/s]



Train
Accuracy: 0.0037907084426609043 | Mean Squared Error: 0.09336840126628211

Test
Accuracy: 0.0038771175976687906 | Mean Squared Error: 0.09963510866013546


Epoch 59
-------


100%|██████████| 10124/10124 [00:33<00:00, 303.46batch/s]



Train
Accuracy: 0.0030714616453156353 | Mean Squared Error: 0.09319565452484922

Test
Accuracy: 0.0030127920185706526 | Mean Squared Error: 0.09945838723522189


Epoch 60
-------


100%|██████████| 10124/10124 [00:34<00:00, 294.72batch/s]



Train
Accuracy: 0.002858465812625405 | Mean Squared Error: 0.09320388088876644

Test
Accuracy: 0.0025312391959302614 | Mean Squared Error: 0.09945556636883401


Epoch 61
-------


100%|██████████| 10124/10124 [00:33<00:00, 299.82batch/s]



Train
Accuracy: 0.0 | Mean Squared Error: 0.0942348746334295

Test
Accuracy: 0.0 | Mean Squared Error: 0.10046485366607162


Epoch 62
-------


100%|██████████| 10124/10124 [00:33<00:00, 299.98batch/s]



Train
Accuracy: 0.005562586818953543 | Mean Squared Error: 0.0932329780157517

Test
Accuracy: 0.0054205561317726084 | Mean Squared Error: 0.09950803123565614


Epoch 63
-------


100%|██████████| 10124/10124 [00:34<00:00, 297.52batch/s]



Train
Accuracy: 0.006964037660132737 | Mean Squared Error: 0.09317155485371828

Test
Accuracy: 0.006050279053686966 | Mean Squared Error: 0.09942691664660677


Epoch 64
-------


100%|██████████| 10124/10124 [00:34<00:00, 297.59batch/s]



Train
Accuracy: 0.004716777280444513 | Mean Squared Error: 0.09350599521358668

Test
Accuracy: 0.004642663110584284 | Mean Squared Error: 0.09978995438290114


Epoch 65
-------


100%|██████████| 10124/10124 [00:34<00:00, 297.69batch/s]



Train
Accuracy: 0.0022379996913103875 | Mean Squared Error: 0.093289264293257

Test
Accuracy: 0.0023954166049291255 | Mean Squared Error: 0.09955134949389752


Epoch 66
-------


100%|██████████| 10124/10124 [00:33<00:00, 298.60batch/s]



Train
Accuracy: 0.002694860317950301 | Mean Squared Error: 0.0933367454659352

Test
Accuracy: 0.0026423667703857363 | Mean Squared Error: 0.09959726044451218


Epoch 67
-------


100%|██████████| 10124/10124 [00:33<00:00, 299.00batch/s]



Train
Accuracy: 0.002694860317950301 | Mean Squared Error: 0.09337143533609274

Test
Accuracy: 0.0026423667703857363 | Mean Squared Error: 0.09960772155463016


Epoch 68
-------


100%|██████████| 10124/10124 [00:34<00:00, 295.56batch/s]



Train
Accuracy: 0.00417965735453002 | Mean Squared Error: 0.09318220736291614

Test
Accuracy: 0.004000592680397096 | Mean Squared Error: 0.09946022574518384


Epoch 69
-------


100%|██████████| 10124/10124 [00:34<00:00, 297.05batch/s]



Train
Accuracy: 0.003969748417965735 | Mean Squared Error: 0.09318138859005118

Test
Accuracy: 0.003691904973576332 | Mean Squared Error: 0.09943568451216214


Epoch 70
-------


100%|██████████| 10124/10124 [00:33<00:00, 303.00batch/s]



Train
Accuracy: 0.0 | Mean Squared Error: 0.09321930133991782

Test
Accuracy: 0.0 | Mean Squared Error: 0.09948365755629145


Epoch 71
-------


100%|██████████| 10124/10124 [00:33<00:00, 298.62batch/s]



Train
Accuracy: 0.0026485568760611206 | Mean Squared Error: 0.09321660454259108

Test
Accuracy: 0.0029510544772065 | Mean Squared Error: 0.09946408910878964


Epoch 72
-------


100%|██████████| 10124/10124 [00:33<00:00, 297.99batch/s]



Train
Accuracy: 0.002864639604877296 | Mean Squared Error: 0.09329614163445064

Test
Accuracy: 0.003025139526843483 | Mean Squared Error: 0.09957303814464935


Epoch 73
-------


100%|██████████| 10124/10124 [00:33<00:00, 299.58batch/s]



Train
Accuracy: 0.0 | Mean Squared Error: 0.09317241027190615

Test
Accuracy: 0.0 | Mean Squared Error: 0.09944279025420663


Epoch 74
-------


100%|██████████| 10124/10124 [00:33<00:00, 303.09batch/s]



Train
Accuracy: 0.0026238617070535575 | Mean Squared Error: 0.09317259195152339

Test
Accuracy: 0.0031486146095717885 | Mean Squared Error: 0.0994276789712574


Epoch 75
-------


100%|██████████| 10124/10124 [00:33<00:00, 298.26batch/s]



Train
Accuracy: 0.00579719092452539 | Mean Squared Error: 0.09334947591319534

Test
Accuracy: 0.005963846495777152 | Mean Squared Error: 0.09959218231716263


Epoch 76
-------


100%|██████████| 10124/10124 [00:32<00:00, 307.76batch/s]



Train
Accuracy: 0.002849205124247569 | Mean Squared Error: 0.09317517514698469

Test
Accuracy: 0.0030374870351163137 | Mean Squared Error: 0.09944393921014441


Epoch 77
-------


100%|██████████| 10124/10124 [00:33<00:00, 301.12batch/s]



Train
Accuracy: 0.0028245099552400062 | Mean Squared Error: 0.09397670246745268

Test
Accuracy: 0.0025682817207487527 | Mean Squared Error: 0.10021106666309186


Epoch 78
-------


100%|██████████| 10124/10124 [00:33<00:00, 304.86batch/s]



Train
Accuracy: 0.002012656274116376 | Mean Squared Error: 0.09332903152532

Test
Accuracy: 0.0020620338815627005 | Mean Squared Error: 0.09957561264447831


Epoch 79
-------


100%|██████████| 10124/10124 [00:33<00:00, 298.42batch/s]



Train
Accuracy: 0.0041858311467819105 | Mean Squared Error: 0.09373377485442595

Test
Accuracy: 0.003963550155578604 | Mean Squared Error: 0.10000907121471468


Epoch 80
-------


100%|██████████| 10124/10124 [00:33<00:00, 300.11batch/s]



Train
Accuracy: 0.005176724803210372 | Mean Squared Error: 0.09322999931856142

Test
Accuracy: 0.005173605966315997 | Mean Squared Error: 0.09944965904653572


Epoch 81
-------


100%|██████████| 10124/10124 [00:34<00:00, 295.94batch/s]



Train
Accuracy: 0.0026238617070535575 | Mean Squared Error: 0.09317115224088902

Test
Accuracy: 0.0031486146095717885 | Mean Squared Error: 0.09942909328838946


Epoch 82
-------


100%|██████████| 10124/10124 [00:34<00:00, 296.77batch/s]



Train
Accuracy: 0.0018953542213304521 | Mean Squared Error: 0.09399785109251411

Test
Accuracy: 0.0018521262409245814 | Mean Squared Error: 0.10027695420255951


Epoch 83
-------


100%|██████████| 10124/10124 [00:33<00:00, 305.34batch/s]



Train
Accuracy: 0.0057447136903843185 | Mean Squared Error: 0.09320035123265094

Test
Accuracy: 0.005951498987504322 | Mean Squared Error: 0.09947694463838135


Epoch 84
-------


100%|██████████| 10124/10124 [00:33<00:00, 298.91batch/s]



Train
Accuracy: 0.00417965735453002 | Mean Squared Error: 0.09318386607560727

Test
Accuracy: 0.004000592680397096 | Mean Squared Error: 0.09944231868580805


Epoch 85
-------


 72%|███████▏  | 7306/10124 [00:23<00:09, 296.26batch/s]