In [3]:
from google.colab import drive
!pip install torch_scatter
!pip install torch_geometric

drive.mount('/content/drive')

# cd into where the data file is stored
%cd /content/drive/MyDrive/cs224w-project/

!pip install python-chess

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/cs224w-project
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/cs224w-project


In [20]:
#Extract Features
import chess
import torch

def calculate_material_balance(board):
    piece_values = {chess.PAWN: 1, chess.KNIGHT: 3, chess.BISHOP: 3, chess.ROOK: 5, chess.QUEEN: 9, chess.KING: 0}
    material_balance = 0
    for piece_type in piece_values:
        material_balance += len(board.pieces(piece_type, chess.WHITE)) * piece_values[piece_type]
        material_balance -= len(board.pieces(piece_type, chess.BLACK)) * piece_values[piece_type]
    return material_balance

def calculate_pawn_structure(board):
    # For simplicity, just counting the number of pawns for now.
    white_pawns = len(board.pieces(chess.PAWN, chess.WHITE))
    black_pawns = len(board.pieces(chess.PAWN, chess.BLACK))
    return white_pawns - black_pawns

def calculate_mobility(board):
    return board.legal_moves.count()

def get_squares_around(square):
    """ Get the squares around a given square. """
    surrounding_squares = []
    file = chess.square_file(square)
    rank = chess.square_rank(square)

    for f in range(file - 1, file + 2):  # File of the square and its adjacent files
        for r in range(rank - 1, rank + 2):  # Rank of the square and its adjacent ranks
            if 0 <= f <= 7 and 0 <= r <= 7:  # Ensure the file and rank are within bounds
                adjacent_square = chess.square(f, r)
                if adjacent_square != square:
                    surrounding_squares.append(adjacent_square)

    return surrounding_squares

def evaluate_king_safety(board):
    """ Evaluate the safety of the king based on surrounding pawns. """
    king_safety = 0
    for color in [chess.WHITE, chess.BLACK]:
        king_square = board.king(color)
        pawn_squares = board.pieces(chess.PAWN, color)
        safety_count = sum(1 for sq in get_squares_around(king_square) if sq in pawn_squares)
        king_safety += safety_count if color == chess.WHITE else -safety_count
    return king_safety



def control_of_center(board):
    center_squares = [chess.E4, chess.D4, chess.E5, chess.D5]
    center_control = 0
    for square in center_squares:
        if board.is_attacked_by(chess.WHITE, square):
            center_control += 1
        if board.is_attacked_by(chess.BLACK, square):
            center_control -= 1
    return center_control

def piece_development(board):
    """ Calculate the development of pieces from their starting positions. """
    development = 0

    # Initial positions for non-pawn pieces except the king.
    initial_positions = {
        chess.WHITE: [chess.A1, chess.B1, chess.C1, chess.D1, chess.E1, chess.F1, chess.G1, chess.H1],
        chess.BLACK: [chess.A8, chess.B8, chess.C8, chess.D8, chess.E8, chess.F8, chess.G8, chess.H8]
    }

    # Count developed pieces (those not in their initial positions).
    for color in [chess.WHITE, chess.BLACK]:
        for piece_type in [chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN]:
            for square in board.pieces(piece_type, color):
                if square not in initial_positions[color]:
                    development += 1 if color == chess.WHITE else -1

    return development


def extract_features(fen):
    board = chess.Board(fen)
    features = {
        "material_balance": calculate_material_balance(board),
        "pawn_structure": calculate_pawn_structure(board),
        "mobility": calculate_mobility(board),
        "king_safety": evaluate_king_safety(board),
        "center_control": control_of_center(board),
        "piece_development": piece_development(board)
    }

    features_vector = [features["material_balance"], features["pawn_structure"],
                    features["mobility"], features["king_safety"],
                    features["center_control"], features["piece_development"]]

    return torch.tensor(features_vector)

In [11]:
#pre-process
import chess
import pickle
import warnings
import concurrent.futures

#from extract_features import extract_features
import requests

from tqdm import tqdm

DEPTH = 5
MODE = 'eval'

def convert_moves_to_san(moves_str):
    # Split the string into individual moves, ignoring the white and black move indicators
    moves = moves_str.split()
    san_moves = []

    # Extract the standard algebraic notation (SAN) part of each move (ignoring the move number and color indicator)
    for move in moves:
        san_move = move.split('.')[1]  # Take only the move part after the dot
        san_moves.append(san_move)

    # Combine the moves into a single space-separated string
    san_moves_str = ' '.join(san_moves)
    return san_moves_str

def moves_to_fen(moves):
    board = chess.Board()
    fen_list = [board.fen()]  # Include the starting position

    for move_san in moves.split():
        move = board.parse_san(move_san)
        board.push(move)
        fen_list.append(board.fen())

    return fen_list

def preprocess_chess_data(file_path):
    # List to hold all games moves
    games_moves = []

    # Open the file and skip the first five lines of metadata
    with open(file_path, 'r') as file:
        for _ in range(5):
            next(file)  # Skip metadata lines

        # Process each game entry line
        for line in file:
            if '###' in line:
                # Split the line at '###' and take the second part, which contains the moves
                moves = line.split('###')[1].strip()
                # Store the moves in the list
                games_moves.append(moves)

    return games_moves

# Process the data to convert moves to FEN for each game
def games_to_fen(games_moves):
    all_games_fen = []
    for moves_str in games_moves:
        san_moves = convert_moves_to_san(moves_str)
        game_fen = moves_to_fen(san_moves)
        all_games_fen.append(game_fen)
    return all_games_fen


def create_feature_vector(fen):
    features = extract_features(fen)
    features_vector = [features["material_balance"], features["pawn_structure"],
                       features["mobility"], features["king_safety"],
                       features["center_control"], features["piece_development"]]

    return features_vector


def fetch_label(fen):
    """
    Fetch the label for a given FEN from the API.
    """
    try:
        response = requests.get(f'https://stockfish.online/api/stockfish.php?fen={fen}&depth={DEPTH}&mode={MODE}')
        if response.status_code == 200:
            return response.json().get('data').split()[2], False
        else:
            print(f"Error fetching data for FEN: {fen}. Status Code: {response.status_code}")
            return None, True
    except requests.RequestException as e:
        print(f"Request failed for FEN: {fen}. Error: {e}")
        return None, True


def process_chess_data(file_path):
    chess_games_moves = preprocess_chess_data(file_path)
    subset_games_fen = games_to_fen(chess_games_moves)

    # all_games = []
    all_labels = []
    exception_indices = []

    for game in tqdm(subset_games_fen):
        for i, fen in enumerate(game):
    #         x = [create_feature_vector(fen) for fen in game]
            labels, exception_occurred = fetch_label(fen)

    #         all_games.append(x)
            all_labels.append(labels)

            if exception_occurred:
                exception_indices.append(i)

    # data = {'x': all_games, 'y': all_labels}

    data = {'fen': subset_games_fen, 'y': all_labels}

    return data, exception_indices


#if __name__ == '__main__':
    file_path = 'data/subset.txt'
    data, exception_indices = process_chess_data(file_path)
    pickle.dump(data, open('/content/drive/MyDrive/cs224w-project/subset_games_test.pkl', 'wb'))

In [17]:
#graph
import torch
from torch_geometric.data import Data
import pickle

games_fen = pickle.load(open('/content/drive/MyDrive/cs224w-project/subset_games_fen.pkl', 'rb'))

# A function to encode a FEN string into a numerical feature vector
def encode_fen(fen):
    # Initialize a binary vector for the board
    # 12 piece types (6 for white, 6 for black) x 64 squares
    board_vector = [0] * (12 * 64)

    # Define a mapping from piece symbols to their index offsets
    piece_to_index = {
        'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
        'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
    }

    # Fill the vector with the presence of pieces
    pieces = fen.split(' ')[0]  # Get the piece placement part of the FEN
    square = 0
    for char in pieces:
        if char.isdigit():  # Empty squares
            square += int(char)
        elif char in piece_to_index:  # Occupied square
            index = piece_to_index[char] * 64 + square
            board_vector[index] = 1
            square += 1
        elif char == '/':  # New rank
            continue

    return board_vector


# Encode all FENs and create a mapping from FEN to node index
fen_to_index = {}
node_features = []

for game_fens in games_fen:
    for fen in game_fens:
        if fen not in fen_to_index:
            fen_to_index[fen] = len(fen_to_index)
            encoded_fen = encode_fen(fen)
            node_features.append(encoded_fen)

# Convert node features to a tensor
node_features_tensor = torch.tensor(node_features, dtype=torch.float)

# Create edges based on the sequence of moves
edge_indices = []

for game_fens in games_fen:
    for i in range(len(game_fens) - 1):
        source = fen_to_index[game_fens[i]]
        target = fen_to_index[game_fens[i + 1]]
        edge_indices.append((source, target))

# Convert edge indices to a tensor
edge_index_tensor = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()

# Create a PyG Data object
data = Data(x=node_features_tensor, edge_index=edge_index_tensor)


import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

# Convert PyG Data object to NetworkX graph
G = to_networkx(data, to_undirected=True)

# Draw the graph using NetworkX
#plt.figure(figsize=(12, 8))
#nx.draw(G, with_labels=False, node_color='lightblue', node_size=25, edge_color='gray')
#plt.title('Graph Representation of Chess Games')
#plt.show()

In [37]:
#model
import pickle
import torch
from torch_geometric.data import Data
import torch.nn.functional as F
from torch_geometric.nn import GATConv
import torch_scatter

#from preprocess import process_chess_data
#from extract_features import extract_features

import time
from tqdm import tqdm

class GATChessModel(torch.nn.Module):
    def __init__(self, num_features, heads=8, dropout=0.6):
        super(GATChessModel, self).__init__()
        self.conv1 = GATConv(num_features, 32, heads=heads, dropout=dropout)
        self.conv2 = GATConv(32 * heads, 1, heads=1, dropout=dropout)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)

        return x


def get_edge_indices():
    file_path = '/content/drive/MyDrive/cs224w-project/subset.txt'
    # _, exception_indices = process_chess_data(file_path)
    exception_indices = []

    # Read exception indices from file
    with open('/content/drive/MyDrive/cs224w-project/exception_indices.txt', 'r') as file:
        for line in file:
            exception_indices.append(int(line.strip()))

    print("exceptions", exception_indices)

    with open('/content/drive/MyDrive/cs224w-project/subset_games_test.pkl', 'rb') as f:
        node_id = 0
        fen_to_node = {}

        games = pickle.load(f)
        print("NUM GAME", len(games))
        for i, game in enumerate(games):
            if i in exception_indices:  # skip games with exceptions
                continue
            for state in game:
                if state not in fen_to_node:
                    fen_to_node[state] = node_id
                    node_id += 1

        print("NODE ID", node_id)

        src, target = [], []
        for j, game in enumerate(games):
            for i in range(len(game) - 1):
                if j in exception_indices:  # skip games with exceptions
                    continue
                else:
                    src.append(fen_to_node[game[i]])
                    target.append(fen_to_node[game[i+1]])

        edge_index = torch.tensor([src, target], dtype=torch.long)

        return edge_index


def get_data():
    from sklearn.model_selection import train_test_split
    with open('/content/drive/MyDrive/cs224w-project/subset_games_test.pkl', 'rb') as f:
        data = pickle.load(f)
        print(extract_features(data['fen'][0][0]))
        x = [extract_features(fen) for game in data['fen'] for fen in game]
        x = torch.tensor(torch.stack(x), dtype=torch.float32)
        y = [float(score) for game in data['y'] for score in game]
        edge_index = torch.tensor(get_edge_indices(), dtype=torch.long)

        y = torch.tensor(y, dtype=torch.float32)

        dataset = Data(x=x, edge_index=edge_index, y=y)

        # Split the dataset into training, validation, and testing sets
        x_train, x_temp, y_train, y_temp = train_test_split(x, y, test_size=0.3, random_state=42)
        x_val, x_test, y_val, y_test = train_test_split(x_temp, y_temp, test_size=0.5, random_state=42)

        # Create Data objects for each set
        train_data = Data(x=x_train, edge_index=edge_index, y=y_train)
        val_data = Data(x=x_val, edge_index=edge_index, y=y_val)
        test_data = Data(x=x_test, edge_index=edge_index, y=y_test)

        return train_data, val_data, test_data


if __name__ == '__main__':
    num_features = 6
    learning_rate = 0.01
    num_epochs = 100

    model = GATChessModel(num_features)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_func = torch.nn.MSELoss()  # Mean Squared Error for regression

    data, val_data, test_data = get_data()

    model.train()
    for epoch in tqdm(range(num_epochs), desc="Epochs"):
        start_time = time.time()

        optimizer.zero_grad()
        out = model(data)
        loss = loss_func(out, data.y)
        loss.backward()
        optimizer.step()

        end_time = time.time()
        epoch_duration = end_time - start_time
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}, Duration: {epoch_duration:.2f} sec")

        # print(f"Epoch {epoch+1}, Loss: {loss.item()}")

    # Save the trained model
    torch.save(model.state_dict(), '/content/drive/MyDrive/cs224w-project/trained_model.pth')

    model = GATChessModel(num_features)
    model.load_state_dict(torch.load('/content/drive/MyDrive/cs224w-project/trained_model.pth'))
    model.eval()

    with torch.no_grad():
        model_output = model(test_data)



tensor([ 0,  0, 20,  0,  0,  0])


  x = torch.tensor(torch.stack(x), dtype=torch.float32)
  edge_index = torch.tensor(get_edge_indices(), dtype=torch.long)


exceptions []
NUM GAME 2
NODE ID 4


  return F.mse_loss(input, target, reduction=self.reduction)
Epochs:   1%|          | 1/100 [00:00<00:11,  8.54it/s]

Epoch 1/100, Loss: 268.2765, Duration: 0.12 sec


Epochs:   2%|▏         | 2/100 [00:00<00:11,  8.41it/s]

Epoch 2/100, Loss: 272.9636, Duration: 0.12 sec


Epochs:   3%|▎         | 3/100 [00:00<00:11,  8.14it/s]

Epoch 3/100, Loss: 206.0652, Duration: 0.13 sec


Epochs:   4%|▍         | 4/100 [00:00<00:11,  8.18it/s]

Epoch 4/100, Loss: 167.6808, Duration: 0.12 sec


Epochs:   5%|▌         | 5/100 [00:00<00:12,  7.83it/s]

Epoch 5/100, Loss: 189.2035, Duration: 0.14 sec


Epochs:   6%|▌         | 6/100 [00:00<00:11,  7.89it/s]

Epoch 6/100, Loss: 114.4422, Duration: 0.12 sec


Epochs:   7%|▋         | 7/100 [00:00<00:11,  8.06it/s]

Epoch 7/100, Loss: 78.8734, Duration: 0.12 sec


Epochs:   8%|▊         | 8/100 [00:00<00:11,  8.02it/s]

Epoch 8/100, Loss: 91.5570, Duration: 0.12 sec


Epochs:   9%|▉         | 9/100 [00:01<00:11,  8.06it/s]

Epoch 9/100, Loss: 63.5009, Duration: 0.12 sec


Epochs:  10%|█         | 10/100 [00:01<00:10,  8.19it/s]

Epoch 10/100, Loss: 63.7400, Duration: 0.12 sec


Epochs:  11%|█         | 11/100 [00:01<00:10,  8.26it/s]

Epoch 11/100, Loss: 51.3735, Duration: 0.12 sec


Epochs:  12%|█▏        | 12/100 [00:01<00:10,  8.32it/s]

Epoch 12/100, Loss: 38.4438, Duration: 0.12 sec


Epochs:  13%|█▎        | 13/100 [00:01<00:10,  8.26it/s]

Epoch 13/100, Loss: 35.1093, Duration: 0.12 sec


Epochs:  14%|█▍        | 14/100 [00:01<00:10,  8.11it/s]

Epoch 14/100, Loss: 27.1464, Duration: 0.13 sec


Epochs:  15%|█▌        | 15/100 [00:01<00:10,  8.00it/s]

Epoch 15/100, Loss: 23.7194, Duration: 0.13 sec


Epochs:  16%|█▌        | 16/100 [00:01<00:10,  8.10it/s]

Epoch 16/100, Loss: 24.9415, Duration: 0.12 sec


Epochs:  17%|█▋        | 17/100 [00:02<00:10,  8.24it/s]

Epoch 17/100, Loss: 21.4966, Duration: 0.11 sec


Epochs:  18%|█▊        | 18/100 [00:02<00:09,  8.26it/s]

Epoch 18/100, Loss: 17.8917, Duration: 0.12 sec


Epochs:  19%|█▉        | 19/100 [00:02<00:09,  8.34it/s]

Epoch 19/100, Loss: 15.7367, Duration: 0.12 sec


Epochs:  20%|██        | 20/100 [00:02<00:09,  8.28it/s]

Epoch 20/100, Loss: 13.2357, Duration: 0.12 sec


Epochs:  21%|██        | 21/100 [00:02<00:09,  8.30it/s]

Epoch 21/100, Loss: 13.1729, Duration: 0.12 sec


Epochs:  22%|██▏       | 22/100 [00:02<00:09,  8.09it/s]

Epoch 22/100, Loss: 13.2603, Duration: 0.13 sec


Epochs:  23%|██▎       | 23/100 [00:02<00:09,  8.20it/s]

Epoch 23/100, Loss: 12.7224, Duration: 0.11 sec


Epochs:  24%|██▍       | 24/100 [00:02<00:09,  8.33it/s]

Epoch 24/100, Loss: 12.9091, Duration: 0.11 sec


Epochs:  25%|██▌       | 25/100 [00:03<00:08,  8.40it/s]

Epoch 25/100, Loss: 11.5011, Duration: 0.11 sec


Epochs:  26%|██▌       | 26/100 [00:03<00:08,  8.42it/s]

Epoch 26/100, Loss: 11.0046, Duration: 0.12 sec


Epochs:  27%|██▋       | 27/100 [00:03<00:08,  8.42it/s]

Epoch 27/100, Loss: 10.1681, Duration: 0.12 sec


Epochs:  28%|██▊       | 28/100 [00:03<00:08,  8.24it/s]

Epoch 28/100, Loss: 9.8627, Duration: 0.12 sec


Epochs:  29%|██▉       | 29/100 [00:03<00:08,  8.30it/s]

Epoch 29/100, Loss: 9.4660, Duration: 0.12 sec


Epochs:  30%|███       | 30/100 [00:03<00:08,  8.30it/s]

Epoch 30/100, Loss: 9.3926, Duration: 0.12 sec


Epochs:  31%|███       | 31/100 [00:03<00:08,  8.09it/s]

Epoch 31/100, Loss: 9.7795, Duration: 0.13 sec


Epochs:  32%|███▏      | 32/100 [00:03<00:08,  8.17it/s]

Epoch 32/100, Loss: 9.4312, Duration: 0.12 sec


Epochs:  33%|███▎      | 33/100 [00:04<00:08,  8.26it/s]

Epoch 33/100, Loss: 9.4603, Duration: 0.12 sec


Epochs:  34%|███▍      | 34/100 [00:04<00:07,  8.34it/s]

Epoch 34/100, Loss: 9.2796, Duration: 0.12 sec


Epochs:  35%|███▌      | 35/100 [00:04<00:07,  8.40it/s]

Epoch 35/100, Loss: 9.3636, Duration: 0.12 sec


Epochs:  36%|███▌      | 36/100 [00:04<00:07,  8.40it/s]

Epoch 36/100, Loss: 9.0666, Duration: 0.12 sec


Epochs:  37%|███▋      | 37/100 [00:04<00:07,  8.38it/s]

Epoch 37/100, Loss: 9.0180, Duration: 0.12 sec


Epochs:  38%|███▊      | 38/100 [00:04<00:07,  8.32it/s]

Epoch 38/100, Loss: 8.9259, Duration: 0.12 sec


Epochs:  39%|███▉      | 39/100 [00:04<00:07,  8.05it/s]

Epoch 39/100, Loss: 8.7252, Duration: 0.13 sec


Epochs:  40%|████      | 40/100 [00:04<00:07,  8.18it/s]

Epoch 40/100, Loss: 8.6249, Duration: 0.12 sec


Epochs:  41%|████      | 41/100 [00:04<00:07,  8.28it/s]

Epoch 41/100, Loss: 8.7223, Duration: 0.12 sec


Epochs:  42%|████▏     | 42/100 [00:05<00:06,  8.32it/s]

Epoch 42/100, Loss: 8.7115, Duration: 0.12 sec


Epochs:  43%|████▎     | 43/100 [00:05<00:06,  8.36it/s]

Epoch 43/100, Loss: 8.6997, Duration: 0.12 sec


Epochs:  44%|████▍     | 44/100 [00:05<00:06,  8.41it/s]

Epoch 44/100, Loss: 8.6071, Duration: 0.11 sec


Epochs:  45%|████▌     | 45/100 [00:05<00:06,  8.33it/s]

Epoch 45/100, Loss: 8.6215, Duration: 0.12 sec


Epochs:  46%|████▌     | 46/100 [00:05<00:06,  8.32it/s]

Epoch 46/100, Loss: 8.6235, Duration: 0.12 sec


Epochs:  47%|████▋     | 47/100 [00:05<00:06,  8.33it/s]

Epoch 47/100, Loss: 8.5821, Duration: 0.12 sec


Epochs:  48%|████▊     | 48/100 [00:05<00:06,  8.10it/s]

Epoch 48/100, Loss: 8.5538, Duration: 0.13 sec


Epochs:  49%|████▉     | 49/100 [00:05<00:06,  8.23it/s]

Epoch 49/100, Loss: 8.5284, Duration: 0.12 sec


Epochs:  50%|█████     | 50/100 [00:06<00:06,  8.29it/s]

Epoch 50/100, Loss: 8.5413, Duration: 0.12 sec


Epochs:  51%|█████     | 51/100 [00:06<00:05,  8.30it/s]

Epoch 51/100, Loss: 8.4933, Duration: 0.12 sec


Epochs:  52%|█████▏    | 52/100 [00:06<00:05,  8.22it/s]

Epoch 52/100, Loss: 8.4776, Duration: 0.12 sec


Epochs:  53%|█████▎    | 53/100 [00:06<00:05,  8.18it/s]

Epoch 53/100, Loss: 8.4149, Duration: 0.12 sec


Epochs:  54%|█████▍    | 54/100 [00:06<00:05,  8.12it/s]

Epoch 54/100, Loss: 8.4488, Duration: 0.12 sec


Epochs:  55%|█████▌    | 55/100 [00:06<00:06,  7.30it/s]

Epoch 55/100, Loss: 8.3159, Duration: 0.16 sec


Epochs:  56%|█████▌    | 56/100 [00:06<00:06,  6.86it/s]

Epoch 56/100, Loss: 8.3766, Duration: 0.16 sec


Epochs:  57%|█████▋    | 57/100 [00:07<00:06,  6.67it/s]

Epoch 57/100, Loss: 8.3097, Duration: 0.16 sec


Epochs:  58%|█████▊    | 58/100 [00:07<00:06,  6.28it/s]

Epoch 58/100, Loss: 8.3658, Duration: 0.18 sec


Epochs:  59%|█████▉    | 59/100 [00:07<00:06,  6.36it/s]

Epoch 59/100, Loss: 8.3437, Duration: 0.15 sec


Epochs:  60%|██████    | 60/100 [00:07<00:06,  6.31it/s]

Epoch 60/100, Loss: 8.2543, Duration: 0.16 sec


Epochs:  61%|██████    | 61/100 [00:07<00:06,  6.19it/s]

Epoch 61/100, Loss: 8.3958, Duration: 0.16 sec


Epochs:  62%|██████▏   | 62/100 [00:07<00:06,  6.13it/s]

Epoch 62/100, Loss: 8.2854, Duration: 0.16 sec


Epochs:  63%|██████▎   | 63/100 [00:08<00:06,  6.16it/s]

Epoch 63/100, Loss: 8.2296, Duration: 0.16 sec


Epochs:  64%|██████▍   | 64/100 [00:08<00:05,  6.14it/s]

Epoch 64/100, Loss: 8.2512, Duration: 0.16 sec


Epochs:  65%|██████▌   | 65/100 [00:08<00:05,  6.16it/s]

Epoch 65/100, Loss: 8.2498, Duration: 0.16 sec


Epochs:  66%|██████▌   | 66/100 [00:08<00:05,  6.02it/s]

Epoch 66/100, Loss: 8.1917, Duration: 0.17 sec


Epochs:  67%|██████▋   | 67/100 [00:08<00:05,  5.92it/s]

Epoch 67/100, Loss: 8.1984, Duration: 0.17 sec


Epochs:  68%|██████▊   | 68/100 [00:08<00:05,  5.92it/s]

Epoch 68/100, Loss: 8.1360, Duration: 0.17 sec


Epochs:  69%|██████▉   | 69/100 [00:09<00:05,  5.93it/s]

Epoch 69/100, Loss: 8.1400, Duration: 0.16 sec


Epochs:  70%|███████   | 70/100 [00:09<00:04,  6.06it/s]

Epoch 70/100, Loss: 8.1156, Duration: 0.15 sec


Epochs:  71%|███████   | 71/100 [00:09<00:04,  6.06it/s]

Epoch 71/100, Loss: 8.1231, Duration: 0.16 sec


Epochs:  72%|███████▏  | 72/100 [00:09<00:04,  6.10it/s]

Epoch 72/100, Loss: 8.1248, Duration: 0.16 sec


Epochs:  73%|███████▎  | 73/100 [00:09<00:04,  6.18it/s]

Epoch 73/100, Loss: 8.0965, Duration: 0.15 sec


Epochs:  74%|███████▍  | 74/100 [00:09<00:04,  6.23it/s]

Epoch 74/100, Loss: 8.0829, Duration: 0.15 sec


Epochs:  75%|███████▌  | 75/100 [00:10<00:04,  6.05it/s]

Epoch 75/100, Loss: 8.1113, Duration: 0.17 sec


Epochs:  76%|███████▌  | 76/100 [00:10<00:03,  6.11it/s]

Epoch 76/100, Loss: 8.1040, Duration: 0.16 sec


Epochs:  77%|███████▋  | 77/100 [00:10<00:03,  6.12it/s]

Epoch 77/100, Loss: 8.0342, Duration: 0.16 sec


Epochs:  78%|███████▊  | 78/100 [00:10<00:03,  6.04it/s]

Epoch 78/100, Loss: 8.0400, Duration: 0.17 sec


Epochs:  79%|███████▉  | 79/100 [00:10<00:03,  6.12it/s]

Epoch 79/100, Loss: 8.0097, Duration: 0.16 sec


Epochs:  80%|████████  | 80/100 [00:10<00:03,  6.16it/s]

Epoch 80/100, Loss: 8.0016, Duration: 0.15 sec


Epochs:  81%|████████  | 81/100 [00:11<00:03,  6.20it/s]

Epoch 81/100, Loss: 8.0331, Duration: 0.15 sec


Epochs:  82%|████████▏ | 82/100 [00:11<00:02,  6.03it/s]

Epoch 82/100, Loss: 7.9958, Duration: 0.17 sec


Epochs:  83%|████████▎ | 83/100 [00:11<00:02,  6.08it/s]

Epoch 83/100, Loss: 8.0022, Duration: 0.16 sec


Epochs:  84%|████████▍ | 84/100 [00:11<00:02,  6.00it/s]

Epoch 84/100, Loss: 8.0440, Duration: 0.17 sec


Epochs:  85%|████████▌ | 85/100 [00:11<00:02,  6.01it/s]

Epoch 85/100, Loss: 7.9714, Duration: 0.16 sec


Epochs:  86%|████████▌ | 86/100 [00:11<00:02,  6.01it/s]

Epoch 86/100, Loss: 7.9873, Duration: 0.16 sec


Epochs:  87%|████████▋ | 87/100 [00:11<00:01,  6.61it/s]

Epoch 87/100, Loss: 7.9583, Duration: 0.11 sec


Epochs:  88%|████████▊ | 88/100 [00:12<00:01,  7.06it/s]

Epoch 88/100, Loss: 7.9585, Duration: 0.12 sec


Epochs:  89%|████████▉ | 89/100 [00:12<00:01,  7.10it/s]

Epoch 89/100, Loss: 7.8918, Duration: 0.14 sec


Epochs:  90%|█████████ | 90/100 [00:12<00:01,  7.43it/s]

Epoch 90/100, Loss: 7.9098, Duration: 0.12 sec


Epochs:  91%|█████████ | 91/100 [00:12<00:01,  7.62it/s]

Epoch 91/100, Loss: 7.9003, Duration: 0.12 sec


Epochs:  92%|█████████▏| 92/100 [00:12<00:01,  7.87it/s]

Epoch 92/100, Loss: 7.8975, Duration: 0.11 sec


Epochs:  93%|█████████▎| 93/100 [00:12<00:00,  7.84it/s]

Epoch 93/100, Loss: 7.9006, Duration: 0.13 sec


Epochs:  94%|█████████▍| 94/100 [00:12<00:00,  7.98it/s]

Epoch 94/100, Loss: 7.8596, Duration: 0.12 sec


Epochs:  95%|█████████▌| 95/100 [00:12<00:00,  8.00it/s]

Epoch 95/100, Loss: 7.8705, Duration: 0.12 sec


Epochs:  96%|█████████▌| 96/100 [00:13<00:00,  7.45it/s]

Epoch 96/100, Loss: 7.8585, Duration: 0.15 sec


Epochs:  97%|█████████▋| 97/100 [00:13<00:00,  7.39it/s]

Epoch 97/100, Loss: 7.8413, Duration: 0.14 sec


Epochs:  98%|█████████▊| 98/100 [00:13<00:00,  7.42it/s]

Epoch 98/100, Loss: 7.8442, Duration: 0.13 sec


Epochs:  99%|█████████▉| 99/100 [00:13<00:00,  7.39it/s]

Epoch 99/100, Loss: 7.8737, Duration: 0.13 sec


Epochs: 100%|██████████| 100/100 [00:13<00:00,  7.33it/s]

Epoch 100/100, Loss: 7.8312, Duration: 0.13 sec





In [41]:
    import numpy as np

    true_labels = test_data.y.numpy()
    predictions = model_output.argmax(dim=1).numpy()

    result = (true_labels - predictions)**2

    # Evaluate results (Mean Squared Error and Mean Absolute Error)
    mse = np.mean((result)**2)
    mae = np.mean(np.abs(true_labels - predictions))

    print("Mean Squared Error:", mse)
    print("Mean Absolute Error:", mae)

Mean Squared Error: 501.4729558698046
Mean Absolute Error: 1.6331724122038176
