In [1]:
# Re-import edited project modules (e.g. `utils`, `models`) automatically before each cell runs.
%reload_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("..")

from torch.utils.data import DataLoader
from datasets import Dataset, load_from_disk
from transformers import AutoTokenizer, DataCollatorWithPadding
from utils.utils import load_sql_to_df, save_to_sql
import pytorch_lightning as pl
from models.lightning import LitHuggingfaceClassifier
import torch.nn as nn

In [3]:
# Hugging Face checkpoint used for tokenization. Alternatives tried earlier:
#   "google/flan-t5-small", "bert-base-multilingual-uncased"
# NOTE(review): presumably this must match the base model of the classifier loaded below — confirm.
checkpoint = "distilbert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=checkpoint)

In [15]:
# Lightning wrapper around the fine-tuned classifier checkpoint used for inference below.
# Previously loaded directly with:
#   AutoModelForSequenceClassification.from_pretrained("../../models/distillbert-12-18/distilbert-base-multilingual-cased/")
pl_model = LitHuggingfaceClassifier("../../models/distillbert-12-31/v1/epoch_2")

In [5]:
# Source database and the subset of columns carried through the rest of the pipeline.
chess_database_file = "../../data/chess_moves_comments_nags.db"
important_columns = ["fen", "move", "comment", "color_comment", "sentiment"]

unlabeled_moves = load_sql_to_df(
    "SELECT * FROM unlabeled_moves_with_comments",
    chess_database_file,
)[important_columns]

In [6]:
# Wrap the pandas frame in a Hugging Face Dataset so it can be tokenized with `map`.
unlabeled_dataset = Dataset.from_pandas(df=unlabeled_moves)
unlabeled_dataset

Dataset({
    features: ['fen', 'move', 'comment', 'color_comment', 'sentiment'],
    num_rows: 2667624
})

In [7]:
# `tokenizer` is already loaded from `checkpoint` in the configuration cell; reloading it here
# was redundant and could silently diverge if only one of the two cells were edited.

def tokenize_function(example):
    """Tokenize a batch of `color_comment` strings for the classifier.

    Used with `Dataset.map(..., batched=True)`, so `example["color_comment"]` is a list of
    strings. Over-long inputs are truncated to the tokenizer's maximum length; padding is
    left to `data_collator` at batch time.
    """
    return tokenizer(example["color_comment"], truncation=True)

def sample_length_function(example):
    """Record the token count of a single example as `example["length"]`.

    Used with `Dataset.map(..., batched=False)` so the dataset can be sorted by length.
    `len()` works both for plain Python lists and for 1-D torch tensors, so this no longer
    depends on `set_format("torch")` having been applied first (`.shape` would crash on lists).

    Returns the same (mutated) example dict.
    """
    example["length"] = len(example["input_ids"])
    return example

# Pads each batch at collation time; since the dataset gets sorted by length, per-batch padding stays small.
data_collator = DataCollatorWithPadding(tokenizer)


In [8]:

# Tokenize in batches, then expose the columns as torch tensors for the DataLoader.
# (A bare `remove_columns(...)` call here would be a no-op: it returns a new dataset.)
unlabeled_tokenized_dataset = unlabeled_dataset.map(tokenize_function, batched=True).with_format("torch")

Map:   0%|          | 0/2667624 [00:00<?, ? examples/s]

In [9]:
# Attach each example's token count and sort by it, so each batch holds similar-length inputs.
unlabeled_tokenized_dataset = (
    unlabeled_tokenized_dataset
    .map(sample_length_function, batched=False)
    .sort("length")
)

Map:   0%|          | 0/2667624 [00:00<?, ? examples/s]

In [10]:
# Cache the tokenized, length-sorted dataset so later sessions can skip re-tokenizing.
unlabeled_tokenized_dataset.save_to_disk(dataset_path="../../data/datasets/unlabeled_tokenized_dataset")

Saving the dataset (0/3 shards):   0%|          | 0/2667624 [00:00<?, ? examples/s]

In [11]:
# Reload the cached tokenized dataset (entry point when skipping the tokenization cells).
unlabeled_tokenized_dataset = Dataset.load_from_disk(dataset_path="../../data/datasets/unlabeled_tokenized_dataset")

In [12]:
# Sanity check: the reloaded dataset lists the tokenizer outputs and `length` next to the source columns.
unlabeled_tokenized_dataset

Dataset({
    features: ['fen', 'move', 'comment', 'color_comment', 'sentiment', 'input_ids', 'attention_mask', 'length'],
    num_rows: 2667624
})

In [13]:
# Keep only the columns the model consumes (the tokenizer's `model_input_names`; for this
# checkpoint that appears to be `input_ids` and `attention_mask`, per the features listing).
# Deriving the list instead of hard-coding the drop list means any column added upstream
# (text, metadata, labels) cannot leak into the collator, which presumably cannot pad strings.
model_input_columns = set(tokenizer.model_input_names)
predict_dataset = unlabeled_tokenized_dataset.remove_columns(
    [name for name in unlabeled_tokenized_dataset.column_names if name not in model_input_columns]
)

In [None]:
# The unlabeled tokenized data has no "labels" column (see the features listing above), and
# `remove_columns` raises when asked to drop a missing column. Only drop it if present, so
# the notebook survives Restart & Run All.
if "labels" in predict_dataset.column_names:
    predict_dataset = predict_dataset.remove_columns(["labels"])

In [None]:
# shuffle=False keeps predictions in dataset row order, which the later
# `add_column("prediction", ...)` step relies on.
predict_dataloader = DataLoader(
    predict_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=data_collator,
)

In [18]:
# GPU inference over the whole unlabeled set; `prediction` holds one output per batch
# (its exact type is defined by `predict_step` in models.lightning — presumably a tensor).
trainer = pl.Trainer(accelerator="gpu")
prediction = trainer.predict(model=pl_model, dataloaders=predict_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /home/kamil/Projects/Master-Thesis/src/notebooks/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/kamil/miniconda3/envs/thesis/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

In [19]:
# One `tolist()` per batch instead of one `.item()` call per element (~2.7M rows).
# `reshape(-1)` keeps one value per row even if a batch comes back shaped (batch, 1).
flatten_predictions = [value for batch in prediction for value in batch.reshape(-1).tolist()]

# `add_column` needs exactly one prediction per dataset row; fail early with a clear message.
assert len(flatten_predictions) == len(unlabeled_tokenized_dataset), (
    f"{len(flatten_predictions)} predictions for {len(unlabeled_tokenized_dataset)} rows"
)

In [20]:
# Rows were predicted in dataset order (no shuffling), so the flat list lines up row for row.
predictions_dataset = unlabeled_tokenized_dataset.add_column(name="prediction", column=flatten_predictions)

In [21]:
# Persist raw predictions alongside the tokenized inputs for later inspection.
predictions_dataset.save_to_disk(dataset_path="../../data/datasets/predictions_dataset")

Saving the dataset (0/3 shards):   0%|          | 0/2667624 [00:00<?, ? examples/s]

In [22]:
# Sanity check: row count unchanged and a `prediction` column was added.
predictions_dataset

Dataset({
    features: ['fen', 'move', 'comment', 'color_comment', 'sentiment', 'input_ids', 'attention_mask', 'length', 'prediction'],
    num_rows: 2667624
})

In [23]:
# NOTE(review): rows with an empty `comment` still tokenize to length 4 (see the output below:
# [CLS], colour token, [SEP], [SEP]), so `length > 2` appears to drop only rows whose
# `color_comment` is itself empty — confirm intent; the real non-empty filter is applied later.
predictions_df = (
    predictions_dataset
    .to_pandas()
    .loc[lambda frame: frame["length"] > 2]
)

In [25]:
# Preview of the predictions frame (pandas' rich display truncates the middle rows).
predictions_df

Unnamed: 0,fen,move,comment,color_comment,sentiment,input_ids,attention_mask,length,prediction
0,r1b1rb1k/pppp1ppB/2n2q1p/4p3/2P5/P3PN2/1PQP1PP...,h7e4,,white [SEP] ,-1,"[101, 15263, 102, 102]","[1, 1, 1, 1]",4,1
1,1Nkr4/p3b1pp/1p3p2/2p5/4PB2/2P4N/PP2K1PP/8 b -...,c8b7,,black [SEP] ,-1,"[101, 15045, 102, 102]","[1, 1, 1, 1]",4,1
2,3r1rk1/1pp3pp/p7/3n2q1/1P2P1P1/P2P1pP1/1BQ2P2/...,d5f6,,black [SEP] ,-1,"[101, 15045, 102, 102]","[1, 1, 1, 1]",4,1
3,1r1q1rk1/4nppp/3pb3/p3p3/R1B1P3/1PP1N3/5PPP/3Q...,e1g1,,white [SEP] ,-1,"[101, 15263, 102, 102]","[1, 1, 1, 1]",4,1
4,3r1rk1/1pp1qppp/p1n5/3np2b/1P6/P2PPN1P/1BQ1BPP...,f7f6,,black [SEP] ,-1,"[101, 15045, 102, 102]","[1, 1, 1, 1]",4,1
...,...,...,...,...,...,...,...,...,...
2667619,r1bqkbnr/pp1ppp1p/2n3p1/8/3NP3/8/PPP2PPP/RNBQK...,c2c4,Este es la Estructura Lazo de Marï¿½czy -o ten...,white [SEP] Este es la Estructura Lazo de Marï...,-1,"[101, 15263, 102, 12515, 10196, 10109, 15596, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",512,1
2667620,r1bqkbnr/pp1ppp1p/2n3p1/8/3NP3/8/PPP2PPP/RNBQK...,c2c4,Este es la Estructura Lazo de Marï¿½czy -o ten...,white [SEP] Este es la Estructura Lazo de Marï...,-1,"[101, 15263, 102, 12515, 10196, 10109, 15596, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",512,1
2667621,r1bqkbnr/pp1ppp1p/2n3p1/8/3NP3/8/PPP2PPP/RNBQK...,c2c4,Este es la Estructura Lazo de Marï¿½czy -o ten...,white [SEP] Este es la Estructura Lazo de Marï...,-1,"[101, 15263, 102, 12515, 10196, 10109, 15596, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",512,1
2667622,r1bqkbnr/pp1ppp1p/2n3p1/8/3NP3/8/PPP2PPP/RNBQK...,c2c4,Este es la Estructura Lazo de Marï¿½czy -o ten...,white [SEP] Este es la Estructura Lazo de Marï...,-1,"[101, 15263, 102, 12515, 10196, 10109, 15596, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",512,1


In [27]:
# Keep rows whose tokenized input is longer than 4 ids (rows of length 4 in the output above
# have an empty comment), select the columns to persist, and store the model output under the
# same `sentiment` name used by the source table.
predictions_df_to_save = (
    predictions_df
    .loc[predictions_df["length"] > 4, ["fen", "move", "comment", "color_comment", "prediction"]]
    .rename(columns={"prediction": "sentiment"})
)

In [28]:
# Preview of exactly what will be written to the database.
predictions_df_to_save

Unnamed: 0,fen,move,comment,color_comment,sentiment
96,r2qkb1r/1p1b1ppp/p1nppn2/6B1/B3P3/2PQ1N2/PP3PP...,f8e7,N,black [SEP] N,1
97,r1b1k2r/ppp1nppp/5q2/2bpn3/3NP3/2P1B3/PP2BPPP/...,e8g8,N,black [SEP] N,1
98,3qr1k1/1br1bpp1/p4n1p/1p1pNR2/3P3B/P1N1P3/1P2Q...,b7c8,#,black [SEP] #,1
99,rnb4r/pp1pk1bp/1qpN1pp1/8/3P4/5N2/PPP2PPP/1R1Q...,d6c4,@,white [SEP] @,1
100,r4rk1/ppp1npb1/6pp/q2Pp3/2P5/1P2B2P/P3NPP1/R2Q...,d1d2,M,white [SEP] M,1
...,...,...,...,...,...
2667619,r1bqkbnr/pp1ppp1p/2n3p1/8/3NP3/8/PPP2PPP/RNBQK...,c2c4,Este es la Estructura Lazo de Marï¿½czy -o ten...,white [SEP] Este es la Estructura Lazo de Marï...,1
2667620,r1bqkbnr/pp1ppp1p/2n3p1/8/3NP3/8/PPP2PPP/RNBQK...,c2c4,Este es la Estructura Lazo de Marï¿½czy -o ten...,white [SEP] Este es la Estructura Lazo de Marï...,1
2667621,r1bqkbnr/pp1ppp1p/2n3p1/8/3NP3/8/PPP2PPP/RNBQK...,c2c4,Este es la Estructura Lazo de Marï¿½czy -o ten...,white [SEP] Este es la Estructura Lazo de Marï...,1
2667622,r1bqkbnr/pp1ppp1p/2n3p1/8/3NP3/8/PPP2PPP/RNBQK...,c2c4,Este es la Estructura Lazo de Marï¿½czy -o ten...,white [SEP] Este es la Estructura Lazo de Marï...,1


In [None]:
# Write back to the same database file the moves were read from (reuse `chess_database_file`
# rather than repeating the path). NOTE(review): `if_exists="replace"` is presumably forwarded to
# pandas `to_sql`, which drops and recreates `predicted_moves_with_comments` — confirm in utils.
save_to_sql(predictions_df_to_save, chess_database_file, "predicted_moves_with_comments", if_exists="replace")