# Training data creation notebook

In this notebook we create training data for different models. 

### Piece classification training data

For piece classification and piece-color classification we extract boards and squares from videos. We label them in two ways: by recognising starting positions, where all piece positions are known, and with the CLIP-based classification function.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from PIL import Image
import open_clip

import numpy as np
import cv2
import matplotlib.pyplot as plt
import copy

from cv2 import CAP_PROP_FRAME_COUNT, CAP_PROP_POS_FRAMES

import os

In [None]:
from src.video_utils import *
from src.board_extraction import *
from src.game_extraction import *

In [None]:
# Collect every .mp4 below `path` whose path (directory part included)
# contains "banter" or "Banter", i.e. the Banter Blitz episodes.
import os

path = 'videos'

candidate_paths = (
    directory + os.sep + name
    for directory, _subdirs, names in os.walk(path)
    for name in names
)
banter_videos = [
    video_file
    for video_file in candidate_paths
    if video_file.endswith(".mp4") and ('banter' in video_file or 'Banter' in video_file)
]
count = len(banter_videos)
print(f"""{count} banter blitz videos.""")

In [None]:
# Hier sammeln wir bretter und startpositions: naja, erstmal nur bretter mit num
# Am schlauesten wird das schon direkt rausgeschrieben. 

save_file = '/home/user/Schreibtisch/Youtube_Scrapen_Repo/Youtube_Scrapen/bretter_images/'

count = 0
for i,video_path in enumerate(banter_videos[15:]):
    i = i + 15
    try:
        n = 10000
        board = None
        frame_number = 0
        for frame in read_every_nth_frame(video_path, n):
            if board is None:
                board = largest_board_extraction(frame)
            if board is not None:
                brett = schachbrett_auschneiden(frame,board)

                fb = board[2]
                num = empty_middle_squares(fb,brett)

                file_name = f"""vid_{i}_frame_{frame_number}_num{num}_fb_{fb}.jpeg"""

                cv2.imwrite(save_file+file_name,brett)

                count += 1

                #bretter.append((brett,fb,num))
                #plotting(brett,title=str(num)+' '+str(frame_number))
                print(i,count)

            else:
                pass
                #plotting(frame)

            frame_number += n
    except:
        pass
    

In [None]:
# Read all saved board crops back in and filter out the non-boards.
folder_path = '/home/user/Schreibtisch/Youtube_Scrapen_Repo/Youtube_Scrapen/bretter_images/'

def parse_bretter_images(file_name):
    """Recover the metadata encoded in a board-crop file name.

    Expects names of the form ``vid_{video}_frame_{frame}_num{num}_fb_{fb}.jpeg``
    (as written by the board extraction cell). A leading directory and the
    file extension are ignored.

    Returns:
        Tuple ``(video, frame, num, fb)`` of ints.

    Raises:
        ValueError: if the name does not follow the pattern.
    """
    stem = os.path.splitext(os.path.basename(file_name))[0]
    entries = stem.split('_')
    if len(entries) != 7 or not entries[4].startswith('num'):
        raise ValueError(f"unexpected board image file name: {file_name!r}")
    video = int(entries[1])
    frame = int(entries[3])
    num = int(entries[4][3:])
    fb = int(entries[6])
    return (video, frame, num, fb)
    
# Count the saved boards and those in a start position. Later cells use
# num == 32 as the start-position marker.
# Only the top level of folder_path is read; a missing folder gives no files
# instead of a NameError or a stale `files` leaked from another cell.
_, _, files = next(os.walk(folder_path), (folder_path, [], []))
files = [f for f in files if f.endswith('.jpeg')]

count = sum(1 for f in files if parse_bretter_images(f)[2] == 32)

# (squares in all boards, squares in start-position boards): 64 per board
len(files) * 64, count * 64

In [None]:
# Zero-shot CLIP check whether an image crop shows a complete chessboard.
# NOTE(review): `model` and `preprocess` are not created anywhere in this
# notebook — presumably via open_clip.create_model_and_transforms(...) in a
# deleted cell. TODO add that cell so the notebook runs on a fresh kernel.
is_it_board_text = open_clip.tokenize(["a full chessboard without edge",
                               "part of a chessboard with an edge beyond",
                               "a person", 
                               "a graphic"])

# Encode the prompts once. no_grad keeps the global text features from
# carrying (and keeping alive) an autograd graph.
with torch.no_grad():
    is_it_board_text_features = model.encode_text(is_it_board_text)
    is_it_board_text_features /= is_it_board_text_features.norm(dim=-1, keepdim=True)

def is_it_a_board_classification(brett, convert_bgr=False):
    """Score ``brett`` against the prompts in ``is_it_board_text`` with CLIP.

    Args:
        brett: image as a numpy array (callers pass cv2.imread output, i.e.
            presumably HxWx3 uint8 in BGR order — TODO confirm for others).
        convert_bgr: if True, a 3-channel image is converted from OpenCV's BGR
            order to the RGB order PIL/CLIP expect. The default False keeps
            the existing scores, because the sorted order of eval_images (and
            the hand-picked indices in manuell_aussortiert) depends on them.

    Returns:
        numpy array of softmax probabilities, one per prompt; index 0 is
        "a full chessboard without edge".
    """
    if convert_bgr and brett.ndim == 3 and brett.shape[2] == 3:
        brett = cv2.cvtColor(brett, cv2.COLOR_BGR2RGB)
    image = preprocess(Image.fromarray(brett)).unsqueeze(0)
    
    with torch.no_grad(), torch.cuda.amp.autocast():
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_probs = (100.0 * image_features @ is_it_board_text_features.T).softmax(dim=-1)
        
    return text_probs[0].detach().cpu().numpy()

In [None]:
# Score every saved board crop with the CLIP board classifier.
# Timings: unoptimised 1000 images in 2.12 min; with pre-encoded text 0.68 min.
# The resulting pickle is huge (it contains every image) — delete it afterwards.
import pickle
import time


def _dump_eval_images(images, path='eval_images.pickle'):
    """Checkpoint the scored images to ``path``."""
    with open(path, 'wb') as fh:
        pickle.dump(images, fh)


start_time = time.time()
eval_images = []
for i, file in enumerate(os.listdir(folder_path)):
    if not file.endswith('jpeg'):
        continue
    im = cv2.imread(folder_path + file)
    if im is None:
        # cv2.imread returns None for unreadable files instead of raising.
        print(f"could not read {folder_path + file}")
        continue

    classification = is_it_a_board_classification(im)
    # (P("a full chessboard without edge"), all prompt probabilities, image, file name)
    eval_images.append((classification[0], classification, im, file))

    if i % 10000 == 0:
        print(i, (time.time() - start_time) / 60, 'min')
        _dump_eval_images(eval_images)

# Final checkpoint so images after the last multiple of 10000 are not lost.
_dump_eval_images(eval_images)



In [None]:
# Order the scored crops by their "full chessboard" probability, lowest first,
# so the most board-like crops sit at the end of the list.
eval_images.sort(key=lambda scored: scored[0])

# Score at the cut-off index the following cells use (eval_images[7000:]).
eval_images[7000][0]

In [None]:
# Here we work out which colour is being played.
# (A second version based on the classifier could be built as well.)

@overall_runtime
def get_square(square, fb, gray_board):
    """Cut one square out of a board image.

    Args:
        square: on-screen square name, file letter + rank digit, e.g. "E2".
            The letter is case-insensitive. Rank 8 is the top row and file A
            the leftmost column of the image.
        fb: side length of one square in pixels; edges are rounded, so a
            fractional value is accepted.
        gray_board: board image. Despite the name, colour images work too,
            since only the first two axes are sliced.

    Returns:
        The slice ``gray_board[rows, cols]`` covering that square.

    Raises:
        ValueError: if ``square`` is not a valid square name.
    """
    letters = 'ABCDEFGH'
    if len(square) != 2 or square[0].upper() not in letters or square[1] not in '12345678':
        raise ValueError(f"invalid square name: {square!r}")
    col_index = letters.index(square[0].upper())
    rank = int(square[1])

    row_start = round(fb * (8 - rank))
    row_end = round(fb * (8 - rank) + fb)

    col_start = round(fb * col_index)
    col_end = round(fb * col_index + fb)

    return gray_board[row_start:row_end, col_start:col_end]

@overall_runtime
def which_color_is_being_played(fb, brett):
    """
    Decide the colour of the player at the bottom of the board
    (a.k.a. upper_vs_lower_squares_brightness).

    Sums the mean brightness of every square in ``upper_squares`` and in
    ``lower_squares`` (both presumably provided by the src star imports) and
    returns "white" when the lower squares are brighter, otherwise "black".
    In the starting position this tells us which colour is being played.

    TODO: slice the upper/lower part of the board once instead of calling
    get_square for every square separately.
    """
    brightness_top = sum((np.mean(get_square(sq, fb, brett)) for sq in upper_squares), 0.0)
    brightness_bottom = sum((np.mean(get_square(sq, fb, brett)) for sq in lower_squares), 0.0)
    return "white" if brightness_bottom > brightness_top else "black"

In [None]:
# Build a validation set from the start positions (num == 32).
# It is curated by hand afterwards, so it should contain no errors.
# Every board is plotted, titled with its index in validation_set followed
# by the detected colour.

# Boards picked out by hand during curation (consumed by the next cell).
manuell_aussortiert = [16,22,41,71,539,548,563,564,587,605]
validation_set = []
for _score, _probs, board_img, board_file in eval_images[7000:]:
    video, frame, num, fb = parse_bretter_images(board_file)
    if num != 32:
        continue
    which_color = which_color_is_being_played(fb, board_img)
    validation_set.append((board_img, board_file, video, frame, num, fb, which_color))
    plotting(board_img, title=f"{len(validation_set) - 1}{which_color}")
len(validation_set) * 64

In [None]:
# Manual curation: drop the boards listed in manuell_aussortiert and plot the
# dropped ones once more, so the exclusion can be checked visually.
# NOTE(review): the previous cell titles each plot with its 0-based index
# (len(validation_set) - 1), but the check below uses `i + 1`, i.e. treats the
# listed numbers as 1-based. Verify on the plots below that the intended
# boards are the ones being dropped.
import pickle

kuratiertes_evaluation_set = []
for i, entry in enumerate(validation_set):
    if i + 1 in manuell_aussortiert:
        plotting(entry[0], title=str(entry[6]))
    else:
        kuratiertes_evaluation_set.append(entry)

with open('evaluation_data_squares.pickle', 'wb') as fh:
    pickle.dump(kuratiertes_evaluation_set, fh)

# Number of evaluation squares (64 per curated board).
len(kuratiertes_evaluation_set) * 64

In [None]:
# Label the remaining boards (everything except the num == 32 start positions,
# which form the evaluation set) square by square with the CLIP-based
# probabilistic_position function (alternatively the tf model could be used).
# NOTE(review): colour labels here are 'W'/'B' (assigned to empty squares as
# well), whereas the evaluation set uses 'w'/'b' and 'N' for empty squares —
# make sure the training code treats both consistently.
import pickle
import time


def _dump_training_data(data, path='training_data_squares.pickle'):
    """Checkpoint the labelled squares to ``path``."""
    with open(path, 'wb') as fh:
        pickle.dump(data, fh)


start_time = time.time()
training_data = []
for i, (score, classi, im, file_name) in enumerate(eval_images[7000:]):
    (video, frame, num, fb) = parse_bretter_images(file_name)
    if num == 32:
        continue
    prob_pos = probabilistic_position(im)

    # assumes prob_pos holds one (colour probs, piece probs) pair per square,
    # ordered rank 8 -> 1 and file A -> H like the loops below — TODO confirm
    index = 0
    for r in '87654321':
        for c in 'ABCDEFGH':
            color_prob = prob_pos[index][0]
            piece_prob = prob_pos[index][1]
            square = get_square(c + r, fb, im)

            color = ['W', 'B'][color_prob.argmax()]
            piece = ['N', 'Q', 'K', 'R', 'B', 'P', 'E'][piece_prob.argmax()]

            training_data.append((square, file_name, video, frame, num, fb, color, piece, c + r))
            index += 1

    if i % 1000 == 0:
        print(i, time.time() - start_time)
        _dump_training_data(training_data)

_dump_training_data(training_data)



### Here we build two folders of images (training and evaluation)

In [None]:
# training data: write every CLIP-labelled square as its own JPEG; the labels
# are encoded in the file name as square_{index}_{colour}_{piece}.jpeg.
import os
import pickle

with open('training_data_squares.pickle', 'rb') as fh:
    # Local file written by this notebook; never unpickle untrusted files.
    training_data = pickle.load(fh)

print(len(training_data))

folder = '/home/user/Schreibtisch/PieceClassification/square_images_color_piece/'
# cv2.imwrite returns False instead of raising, e.g. when the folder is missing.
os.makedirs(folder, exist_ok=True)
for i, (im, _file, _vid, _fra, _num, _fb, col, piec, _sq) in enumerate(training_data):
    sample_path = f"""square_{i}_{col}_{piec}.jpeg"""
    if not cv2.imwrite(folder + sample_path, im):
        raise OSError(f"could not write {folder + sample_path}")



In [None]:
# Number of CLIP-labelled training squares (one entry per square image).
len(training_data)

In [None]:
# validation data: write every square of the curated start positions as a JPEG.
# Labels follow from the starting position itself, so they are exact. The file
# name is square_{index}_{colour}_{piece}.jpeg, with colour 'w'/'b' ('N' for
# empty squares) and piece R/N/B/Q/K/P ('E' for empty).
# NOTE(review): the training folder uses upper-case colours 'W'/'B' and also
# gives empty squares a colour — make sure the training code treats both alike.
import os
import pickle


def _starting_square_label(rank, file_letter, which_color):
    """Return the (colour, piece) label of a square in the starting position.

    ``rank`` and ``file_letter`` are the on-screen coordinates used by
    get_square (rank '8' = top row, file 'A' = left column). ``which_color``
    is the colour of the player at the bottom ("white" or "black"). It decides
    whether the bottom ranks hold white or black pieces and where king and
    queen stand.

    Raises:
        ValueError: for an unknown ``which_color`` (the original chain of ifs
            would silently have reused the previous square's labels).
    """
    if which_color not in ('white', 'black'):
        raise ValueError(f"unexpected which_color: {which_color!r}")
    white_at_bottom = which_color == 'white'

    if rank in '3456':
        return 'N', 'E'

    # Ranks 1/2 belong to the player at the bottom, ranks 7/8 to the opponent.
    if rank in '12':
        color = 'w' if white_at_bottom else 'b'
    else:
        color = 'b' if white_at_bottom else 'w'

    if rank in '27':
        return color, 'P'

    back_rank = {'A': 'R', 'H': 'R', 'B': 'N', 'G': 'N', 'C': 'B', 'F': 'B'}
    if file_letter in back_rank:
        return color, back_rank[file_letter]
    # With black at the bottom the board is mirrored, so king and queen swap columns.
    if file_letter == 'D':
        return color, 'Q' if white_at_bottom else 'K'
    return color, 'K' if white_at_bottom else 'Q'


with open('evaluation_data_squares.pickle', 'rb') as fh:
    # Local file written by this notebook; never unpickle untrusted files.
    kuratiertes_evaluation_set = pickle.load(fh)
folder = '/home/user/Schreibtisch/PieceClassification/square_images_color_piece_evaluation/'
# cv2.imwrite returns False instead of raising, e.g. when the folder is missing.
os.makedirs(folder, exist_ok=True)

i = 0
for (im, file_name, video, frame, num, fb, which_color) in kuratiertes_evaluation_set:
    for r in '87654321':
        for c in 'ABCDEFGH':
            square = get_square(c + r, fb, im)
            color, piece = _starting_square_label(r, c, which_color)

            sample_path = f"""square_{i}_{color}_{piece}.jpeg"""
            i = i + 1
            if not cv2.imwrite(folder + sample_path, square):
                raise OSError(f"could not write {folder + sample_path}")


In [None]:
# Number of evaluation square images (64 squares per curated start-position board).
len(kuratiertes_evaluation_set)*64