<a href="https://colab.research.google.com/github/Teja5164/Chess_pieces_detectionusingYOLOv8/blob/main/version2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install opencv-python-headless python-chess matplotlib torch torchvision


Collecting python-chess
  Downloading python_chess-1.999-py3-none-any.whl.metadata (776 bytes)
Collecting chess<2,>=1 (from python-chess)
  Downloading chess-1.11.2.tar.gz (6.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.1/6.1 MB[0m [31m38.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading python_chess-1.999-py3-none-any.whl (1.4 kB)
Building wheels for collected packages: chess
  Building wheel for chess (setup.py) ... [?25l[?25hdone
  Created wheel for chess: filename=chess-1.11.2-py3-none-any.whl size=147775 sha256=051216faf6133d6f5d0f50e074d89e64ea5933a385ffd89c86a9e2fb01ffdf3d
  Stored in directory: /root/.cache/pip/wheels/83/1f/4e/8f4300f7dd554eb8de70ddfed96e94d3d030ace10c5b53d447
Successfully built chess
Installing collected packages: chess, python-chess
Successfully installed chess-1.11.2 python-chess-1.999


In [12]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import chess
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
from torchvision import datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

def show_img(img, title="Image"):
    """Display an OpenCV (BGR) image inline with matplotlib.

    Args:
        img: Image array in BGR channel order; it is converted to RGB
            before being handed to matplotlib.
        title: Caption drawn above the image.
    """
    _, ax = plt.subplots(figsize=(6,6))
    rgb_view = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    ax.imshow(rgb_view)
    ax.set_title(title)
    ax.axis('off')
    plt.show()


In [11]:
from google.colab import files
import zipfile
import os

# Prompt for the dataset archive (the ZIP downloaded from Kaggle).
uploaded = files.upload()

# No particular file name is assumed: whichever file comes first in the
# upload result is treated as the archive.
zip_filename = next(iter(uploaded))
with zipfile.ZipFile(zip_filename) as zip_ref:
    zip_ref.extractall('chess_pieces_dataset')

print("Dataset extracted.")


Saving chess_pieces_dataset.zip to chess_pieces_dataset.zip
Dataset extracted.


In [13]:
# Per-channel statistics used to normalise the tensors.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training-time preprocessing: fixed 64x64 input plus a random horizontal flip.
data_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# NOTE(review): the class list printed below is whatever sub-directories sit
# directly under 'chess_pieces_dataset'. Confirm those are the per-piece
# class folders and not the archive's top-level folders; otherwise the
# classifier is trained on the wrong labels.
train_dataset = datasets.ImageFolder('chess_pieces_dataset', transform=data_transforms)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

print("Classes:", train_dataset.classes)


Classes: ['Chess Pieces.yolov8-obb', 'Chess_pieces']


In [14]:
class SimpleChessCNN(nn.Module):
    """Small CNN: two conv + max-pool stages feeding a two-layer classifier.

    The flatten size is fixed for 3x64x64 input: each 2x2 pool halves the
    spatial size (64 -> 32 -> 16), leaving 64 channels of 16x16 features.

    Args:
        num_classes: Number of output logits produced by the final layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Attribute names and creation order are kept stable so previously
        # saved state_dicts keep loading.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 16 * 16, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw class logits for a batch of images."""
        # conv -> ReLU -> pool, twice: [batch, 32, 32, 32] then [batch, 64, 16, 16]
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(-1, 64 * 16 * 16)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# The output layer gets one logit per entry in train_dataset.classes.
num_classes = len(train_dataset.classes)
model = SimpleChessCNN(num_classes)


In [15]:
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 15

for epoch in range(num_epochs):
    model.train()
    batch_losses = []
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    # Report the mean per-batch loss for the epoch.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {sum(batch_losses)/len(train_loader):.4f}")

print("Training complete.")


Epoch 1/15, Loss: 0.7262
Epoch 2/15, Loss: 0.6403
Epoch 3/15, Loss: 0.6432
Epoch 4/15, Loss: 0.6406
Epoch 5/15, Loss: 0.6383
Epoch 6/15, Loss: 0.6419
Epoch 7/15, Loss: 0.6392
Epoch 8/15, Loss: 0.6378
Epoch 9/15, Loss: 0.6394
Epoch 10/15, Loss: 0.6414
Epoch 11/15, Loss: 0.6401
Epoch 12/15, Loss: 0.6391
Epoch 13/15, Loss: 0.6430
Epoch 14/15, Loss: 0.6389
Epoch 15/15, Loss: 0.6379
Training complete.


In [16]:
from google.colab import drive
# Mount Google Drive so the checkpoint outlives the Colab runtime.
drive.mount('/content/drive')

checkpoint_path = '/content/drive/MyDrive/chess_piece_classifier.pth'
torch.save(model.state_dict(), checkpoint_path)
print("Model saved to Google Drive.")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Model saved to Google Drive.


In [17]:
# Re-run this cell after reopening notebook and mounting drive.
# It still relies on `SimpleChessCNN`, `num_classes` and `device` having been
# defined by the earlier cells.

model = SimpleChessCNN(num_classes)
# map_location remaps the stored tensors onto the current session's device, so
# a checkpoint written from a CUDA runtime still loads on a CPU-only runtime.
model.load_state_dict(
    torch.load('/content/drive/MyDrive/chess_piece_classifier.pth', map_location=device)
)
model.to(device)
model.eval()
print("Model loaded and ready for inference.")


Model loaded and ready for inference.


In [30]:
def order_points(pts):
    """Arrange four (x, y) corners as top-left, top-right, bottom-right, bottom-left.

    The smallest and largest x + y pick the top-left and bottom-right corners;
    the smallest and largest y - x pick the top-right and bottom-left corners.

    Args:
        pts: Array of shape (4, 2) holding the corner coordinates in any order.

    Returns:
        float32 array of shape (4, 2) in the order described above.
    """
    coord_sum = pts.sum(axis=1)
    coord_diff = np.diff(pts, axis=1)  # y - x for every point

    ordered = np.empty((4, 2), dtype="float32")
    ordered[0], ordered[2] = pts[np.argmin(coord_sum)], pts[np.argmax(coord_sum)]
    ordered[1], ordered[3] = pts[np.argmin(coord_diff)], pts[np.argmax(coord_diff)]
    return ordered

def detect_and_warp_board(img):
    """Find the largest outer contour and warp it to a 400x400 top-down view.

    Args:
        img: BGR image as read by OpenCV.

    Returns:
        The perspective-corrected region, resized to 400x400 pixels.

    Raises:
        Exception: If edge detection yields no contours at all.
    """
    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    smoothed = cv2.GaussianBlur(grayscale, (5,5), 0)
    edge_map = cv2.Canny(smoothed, 50, 150)
    contours, _ = cv2.findContours(edge_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    if not contours:
        raise Exception("No contours found")

    # The contour enclosing the largest area is taken to be the board outline.
    board_contour = max(contours, key=cv2.contourArea)

    # Loosen the polygon approximation step by step until it has four corners.
    perimeter = cv2.arcLength(board_contour, True)
    candidates = (
        cv2.approxPolyDP(board_contour, factor * perimeter, True)
        for factor in np.linspace(0.01, 0.05, 5)
    )
    quad = next((poly for poly in candidates if len(poly) == 4), None)

    if quad is None:
        # No quadrilateral found: fall back to the axis-aligned bounding box.
        x, y, w, h = cv2.boundingRect(board_contour)
        quad = np.array([[[x, y]], [[x+w, y]], [[x+w, y+h]], [[x, y+h]]], dtype=np.int32)

    top_left, top_right, bottom_right, bottom_left = corners = order_points(quad.reshape(4, 2))

    # Output size: the longer of each pair of opposite edges.
    bottom_edge = np.linalg.norm(bottom_right - bottom_left)
    top_edge = np.linalg.norm(top_right - top_left)
    maxWidth = max(int(bottom_edge), int(top_edge))

    right_edge = np.linalg.norm(top_right - bottom_right)
    left_edge = np.linalg.norm(top_left - bottom_left)
    maxHeight = max(int(right_edge), int(left_edge))

    target = np.array([
        [0, 0],
        [maxWidth-1, 0],
        [maxWidth-1, maxHeight-1],
        [0, maxHeight-1]
    ], dtype="float32")

    homography = cv2.getPerspectiveTransform(corners, target)
    flattened = cv2.warpPerspective(img, homography, (maxWidth, maxHeight))
    return cv2.resize(flattened, (400, 400))


In [31]:
def split_board_into_squares(board_img):
    """Cut a top-down board image into an 8x8 grid of crops.

    Args:
        board_img: Array whose first two axes are (rows, columns).

    Returns:
        A list of 8 rows, each a list of 8 crops, indexed as
        ``squares[row][col]``. Crops are slices of ``board_img``.
    """
    # Row and column strides are derived independently. Using the height for
    # both (as before) mis-slices the columns of any non-square image; for a
    # square image the result is unchanged.
    sq_h = board_img.shape[0] // 8
    sq_w = board_img.shape[1] // 8
    return [
        [
            board_img[row*sq_h:(row+1)*sq_h, col*sq_w:(col+1)*sq_w]
            for col in range(8)
        ]
        for row in range(8)
    ]


In [32]:
def detect_piece_color(square_img):
    """Guess whether a square crop is dominated by bright or dark pixels.

    Pixels are counted in two HSV ranges: a low-saturation, high-value range
    ("white") and a low-value range ("black").

    Args:
        square_img: BGR crop of a single board square.

    Returns:
        'white' or 'black' when that range has strictly more pixels than the
        other and more than 50 in total; otherwise 'empty'.
    """
    hsv = cv2.cvtColor(square_img, cv2.COLOR_BGR2HSV)

    hsv_ranges = {
        'white': (np.array([0,0,180]), np.array([180,30,255])),
        'black': (np.array([0,0,0]), np.array([180,255,50])),
    }
    pixel_counts = {
        name: cv2.countNonZero(cv2.inRange(hsv, lower, upper))
        for name, (lower, upper) in hsv_ranges.items()
    }
    n_white, n_black = pixel_counts['white'], pixel_counts['black']

    if n_white > 50 and n_white > n_black:
        return 'white'
    if n_black > 50 and n_black > n_white:
        return 'black'
    return 'empty'


In [33]:
def classify_square_img(square_img):
    """Predict the class label of one board-square crop.

    Args:
        square_img: 3-channel crop in OpenCV's BGR order. In this notebook the
            crops come from an image read with ``cv2.imread``.

    Returns:
        The entry of ``train_dataset.classes`` with the highest logit.
    """
    # ToPILImage treats a 3-channel ndarray as RGB, and the training images
    # were loaded as RGB by ImageFolder. Swap the channels first; otherwise
    # the model sees red and blue exchanged relative to training.
    rgb_square = cv2.cvtColor(square_img, cv2.COLOR_BGR2RGB)

    # Same preprocessing as training, minus the random flip augmentation.
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((64,64)),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],
                             [0.229,0.224,0.225])
    ])
    input_tensor = transform(rgb_square).unsqueeze(0).to(device)  # add the batch dimension
    with torch.no_grad():
        outputs = model(input_tensor)
        _, predicted = torch.max(outputs, 1)
    return train_dataset.classes[predicted.item()]


In [34]:
def combined_piece_detection(square_img):
    """Return the piece symbol for a square, or '.' when none is accepted.

    The colour heuristic runs first; the classifier's label is kept only when
    its letter case agrees with the detected colour (upper-case for white,
    lower-case for black).

    NOTE(review): ``str.isupper()`` / ``str.islower()`` are both False for
    mixed-case labels, so any such class name always yields '.'. Confirm the
    classifier's class names are pure upper/lower-case piece symbols.

    Args:
        square_img: BGR crop of a single board square.

    Returns:
        The classifier label when colour and case agree, otherwise '.'.
    """
    color = detect_piece_color(square_img)
    if color == 'empty':
        return '.'

    label = classify_square_img(square_img)
    case_matches_color = (
        (color == 'white' and label.isupper())
        or (color == 'black' and label.islower())
    )
    return label if case_matches_color else '.'


In [35]:
def generate_fen(board_pieces):
    """Build a FEN string from a grid of piece symbols.

    Args:
        board_pieces: Iterable of rows, each an iterable of strings. '.' and
            'unknown' count as empty squares; any other string is copied into
            the output as-is.

    Returns:
        The rows joined with '/', with runs of empty squares collapsed to
        their length, followed by the fixed suffix ' w - - 0 1'.
    """
    def encode_rank(rank):
        # Collapse consecutive empty squares into a single digit run.
        tokens = []
        gap = 0
        for cell in rank:
            if cell in ('.', 'unknown'):
                gap += 1
                continue
            if gap:
                tokens.append(str(gap))
                gap = 0
            tokens.append(cell)
        if gap:
            tokens.append(str(gap))
        return ''.join(tokens)

    placement = '/'.join(encode_rank(rank) for rank in board_pieces)
    return placement + ' w - - 0 1'


In [36]:
def process_image(image_path):
    """Read a board photo, detect the pieces on it, and print/return its FEN.

    Args:
        image_path: Path of the image file handed to ``cv2.imread``.

    Returns:
        The FEN string from ``generate_fen``. If any exception is raised
        during board detection or piece recognition, a message is printed and
        the empty-board FEN is returned instead.
    """
    img = cv2.imread(image_path)
    try:
        warped = detect_and_warp_board(img)
        board_pieces = [
            [combined_piece_detection(sq) for sq in rank]
            for rank in split_board_into_squares(warped)
        ]
    except Exception as e:
        print(f"Detection failed: {e}, returning empty board FEN")
        return "8/8/8/8/8/8/8/8 w - - 0 1"

    fen = generate_fen(board_pieces)
    print(fen)
    return fen


In [37]:
from google.colab import files
# Prompt for board photos and run the pipeline on each uploaded file.
uploaded = files.upload()

for fname in uploaded:
    print(f"Processing {fname} ...")
    process_image(fname)


Saving 1.jpeg to 1 (4).jpeg
Saving 2.jpeg to 2 (4).jpeg
Saving 3.jpeg to 3 (4).jpeg
Saving 4.jpeg to 4 (4).jpeg
Saving 5.jpeg to 5 (4).jpeg
Processing 1 (4).jpeg ...
8/8/8/8/8/8/8/8 w - - 0 1
Processing 2 (4).jpeg ...
8/8/8/8/8/8/8/8 w - - 0 1
Processing 3 (4).jpeg ...
8/8/8/8/8/8/8/8 w - - 0 1
Processing 4 (4).jpeg ...
8/8/8/8/8/8/8/8 w - - 0 1
Processing 5 (4).jpeg ...
8/8/8/8/8/8/8/8 w - - 0 1
