In [7]:
import warnings
warnings.filterwarnings('ignore')
from glob import glob
from IPython.display import display
from matplotlib import pyplot as plt
from matplotlib import style
from tqdm import tqdm
import os
import random
import re
#import cv2 as cv
import numpy as np
import pandas as pd
import seaborn as sns
# from chess_positions import Check
# from chess_positions import IllegalPosition
# style.use(style='seaborn-deep')

In [8]:
tr_path = "./dataset/train"
te_path = "./dataset/test"

In [9]:
tr_images = glob(pathname=(tr_path + '/*jpeg'))
te_images = glob(pathname=(te_path + '/*jpeg'))

In [10]:
tr_labels = list(map(lambda x: os.path.basename(x).split('.')[0].replace('-', '/'), tr_images))
te_labels = list(map(lambda x: os.path.basename(x).split('.')[0].replace('-', '/'), te_images))

In [11]:
tr_df = pd.DataFrame(data={'image': tr_images, 'label': tr_labels})
te_df = pd.DataFrame(data={'image': te_images, 'label': te_labels})

In [12]:
tr_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 0 entries
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   image   0 non-null      float64
 1   label   0 non-null      float64
dtypes: float64(2)
memory usage: 132.0 bytes


In [13]:
te_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 0 entries
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   image   0 non-null      float64
 1   label   0 non-null      float64
dtypes: float64(2)
memory usage: 132.0 bytes


In [14]:
display(any(tr_df.duplicated(subset=['label'])))
display(any(te_df.duplicated(subset=['label'])))

False

False

In [15]:
class Board(object):
    """
    This class is defines the chessboard.
    """

    def __init__(self, fen_label):
        self.fen_label = re.sub(pattern=r'\d',
                                repl=lambda x: self.get_ones(char=x.group()),
                                string=fen_label)
        self.fen_matrix = self.get_fen_matrix()

    def get_ones(self, char):
        """
        This method returns repetitive 1s based on input digit character.
        """
        if char.isdigit():
            return '1' * int(char)

    def get_fen_matrix(self):
        """
        This method constructs a FEN matrix.
        """
        fen_matrix = np.array([list(row) for row in self.fen_label.split('/')])
        return fen_matrix

    def get_piece_positions(self, notation):
        """
        This method returns the 2D index of the piece from FEN matrix.
        """
        (i, j) = np.where(self.fen_matrix == notation)
        try:
            if i is not None and j is not None:
                return i, j
        except:
            return None

In [16]:
class Check(Board):
    """
    This class finds if there are any checks in the chessboard.
    """

    def __init__(self, fen_label):
        super().__init__(fen_label=fen_label)

    def get_sub_matrix(self, ai, aj, di, dj):
        """
        This method chops the chessboard to a sub-matrix.
        """
        corners = np.array([(ai, aj), (di, aj), (ai, dj), (di, dj)])
        min_i, max_i = min(corners[:, 0]), max(corners[:, 0])
        min_j, max_j = min(corners[:, 1]), max(corners[:, 1])
        sub_matrix = self.fen_matrix[min_i:max_i+1, min_j:max_j+1]
        return sub_matrix, sub_matrix.shape

    def get_straight_checks(self, ai, aj, di, dj, a, d):
        """
        This method returns the checks along the straight path.
        """
        checks = list()
        for (i, j) in zip(ai, aj):
            if di == i:
                attack_path = self.fen_matrix[di]
            elif dj == j:
                attack_path = self.fen_matrix[:, dj]
            else:
                continue
            a_ind = np.where(attack_path == a)[0]
            d_ind = np.where(attack_path == d)[0][0]
            for a_i_ in a_ind:
                attack_path_ = attack_path[min(a_i_, d_ind): max(a_i_, d_ind)+1]
                checks.append(np.where(attack_path_ != '1')[0])
        checks = list(filter(lambda x: len(x) == 2, checks))
        return checks

    def get_diagonal_checks(self, ai, aj, di, dj, a):
        """
        This method returns the checks along the diagonal path.
        """
        checks = list()
        for (i, j) in zip(ai, aj):
            sub_mat, sub_shape = self.get_sub_matrix(ai=i, aj=j, di=di, dj=dj)
            if sub_shape[0] == sub_shape[1]:
                if a not in sub_mat.diagonal():
                    sub_mat = np.flipud(m=sub_mat)
                checks.append(np.where(sub_mat.diagonal() != '1')[0])
            else:
                continue
        checks = list(filter(lambda x: len(x) == 2, checks))
        return checks

    def get_knight_checks(self, ai, aj, di, dj):
        """
        This method returns the checks along the L-shaped paths for knights.
        """
        checks = list()
        for (i, j) in zip(ai, aj):
            attack_positions = [(i-2, j-1), (i-2, j+1),
                                (i-1, j-2), (i-1, j+2),
                                (i+1, j-2), (i+1, j+2),
                                (i+2, j-1), (i+2, j+1)]
            if (di, dj) in attack_positions:
                checks.append((i, j))
        return checks

    def get_pawn_checks(self, ai, aj, di, dj):
        """
        This method returns the checks for pawns.
        """
        checks = list()
        for (i, j) in zip(ai, aj):
            _, sub_shape = self.get_sub_matrix(ai=i, aj=j, di=di, dj=dj)
            if sub_shape[0] == 2 and sub_shape[1] == 2:
                checks.append((i, j))
            else:
                continue
        return checks

    def king_checks_king(self, attacker, defendant):
        """
        This method checks if the king is being attacked by the other king.
        This is unlikely, but I am just adding a validation rule.
        """
        flag = False
        di, dj = self.get_piece_positions(notation=defendant)
        if len(di) == 1 and len(dj) == 1:
            di, dj = di[0], dj[0]
        else:
            return flag
        ai, aj = self.get_piece_positions(notation=attacker)
        ai, aj = ai[0], aj[0]
        attack_positions = [(di, dj-1), (di, dj+1),
                            (di-1, dj), (di+1, dj),
                            (di-1, dj+1), (di-1, dj-1),
                            (di+1, dj-1), (di+1, dj+1)]
        if (ai, aj) in attack_positions:
            flag = True
        return flag

    def rook_checks_king(self, attacker, defendant):
        """
        This method checks if the king is being attacked by the rook.
        """
        flag = False
        di, dj = self.get_piece_positions(notation=defendant)
        if len(di) == 1 and len(dj) == 1:
            di, dj = di[0], dj[0]
        else:
            return flag
        ai, aj = self.get_piece_positions(notation=attacker)
        checks = self.get_straight_checks(
            ai=ai, aj=aj, di=di, dj=dj, a=attacker, d=defendant)
        if checks:
            flag = True
        return flag

    def bishop_checks_king(self, attacker, defendant):
        """
        This method checks if the king is being attacked by the bishop.
        """
        flag = False
        di, dj = self.get_piece_positions(notation=defendant)
        if len(di) == 1 and len(dj) == 1:
            di, dj = di[0], dj[0]
        else:
            return flag
        ai, aj = self.get_piece_positions(notation=attacker)
        checks = self.get_diagonal_checks(
            ai=ai, aj=aj, di=di, dj=dj, a=attacker)
        if checks:
            flag = True
        return flag

    def knight_checks_king(self, attacker, defendant):
        """
        This method checks if the king is being attacked by the knight.
        """
        flag = False
        di, dj = self.get_piece_positions(notation=defendant)
        if len(di) == 1 and len(dj) == 1:
            di, dj = di[0], dj[0]
        else:
            return flag
        ai, aj = self.get_piece_positions(notation=attacker)
        checks = self.get_knight_checks(ai=ai, aj=aj, di=di, dj=dj)
        if checks:
            flag = True
        return flag

    def queen_checks_king(self, attacker, defendant):
        """
        This method checks if the king is being attacked by the queen.
        """
        flag = False
        di, dj = self.get_piece_positions(notation=defendant)
        if len(di) == 1 and len(dj) == 1:
            di, dj = di[0], dj[0]
        else:
            return flag
        ai, aj = self.get_piece_positions(notation=attacker)
        straight_checks = self.get_straight_checks(
            ai=ai, aj=aj, di=di, dj=dj, a=attacker, d=defendant)
        diagonal_checks = self.get_diagonal_checks(
            ai=ai, aj=aj, di=di, dj=dj, a=attacker)
        if straight_checks or diagonal_checks:
            flag = True
        return flag

    def pawn_checks_king(self, attacker, defendant):
        """
        This methos checks if the king is being attacked by the pawn.

        Note: It is hard to determine from an image, which side of 
              the chessboard is black or is white.
              Hence, this method assumes the pawn is attacking the king 
              if both the pieces are diagnolly aligned by 1 step.
        """
        flag = False
        di, dj = self.get_piece_positions(notation=defendant)
        if len(di) == 1 and len(dj) == 1:
            di, dj = di[0], dj[0]
        else:
            return flag
        ai, aj = self.get_piece_positions(notation=attacker)
        checks = self.get_pawn_checks(ai=ai, aj=aj, di=di, dj=dj)
        if checks:
            flag = True
        return flag