In [None]:
import torch
from torchvision import transforms

from vars.loading_data import ChessDB as DB
from vars.model import ConvolutionalNetwork
from vars.utilities import *
import pytorch_lightning as pl
from sklearn.metrics import classification_report


In [None]:
# Configuration: dataset locations, batch size, RNG seed and compute device.
PATH = "./Data/Chess/"   # image directory handed to ChessDB as `directory`
CSV_PATH = "./Data/"     # output directory handed to create_CSV as `out_dir`
BATCH = 32               # mini-batch size for the ChessDB dataloaders
SEED = 42                # seed for torch's global RNG

# Seed torch so random augmentations are repeatable across Restart & Run All.
# NOTE(review): whether this also fixes the train/valid/test split depends on how
# ChessDB.db_split draws its randomness — confirm.
torch.manual_seed(SEED)

# Fall back to CPU so the notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Image pipeline handed to ChessDB: random augmentation followed by deterministic
# resizing, tensor conversion and normalization.
# NOTE(review): ChessDB receives only this one transform, so presumably the random
# rotation/flip are also applied to validation and test images — consider a separate
# deterministic eval transform if ChessDB supports it.
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # commonly used ImageNet per-channel mean
IMAGENET_STD = [0.229, 0.224, 0.225]   # commonly used ImageNet per-channel std

augmentations = [
    transforms.RandomRotation(10),      # random rotation within +/- 10 degrees
    transforms.RandomHorizontalFlip(),  # mirror half of the images (default p=0.5)
]
preprocessing = [
    transforms.Resize(224),             # scale so the shorter side is 224 px
    transforms.CenterCrop(224),         # keep the central 224x224 patch
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
]
transform = transforms.Compose(augmentations + preprocessing)

In [None]:
# Data loader: build the ChessDB dataset from the image directory and inspect it.
# Uses the BATCH constant from the configuration cell instead of a repeated literal.
data = DB(directory=PATH, transform=transform, batch_size=BATCH)
print(data.labels)    # class names
print(len(data))      # number of samples reported by ChessDB
print(data.dir_list)

In [None]:
# Dataset utilities from vars.utilities: write a CSV into CSV_PATH, then plot a bar
# chart and sample images for the labelled classes.
# NOTE(review): the exact outputs of these helpers are defined in vars.utilities — confirm there.
create_CSV(dir=data.dir, out_dir=CSV_PATH)
# A stray `20` after this call previously made the cell a SyntaxError; if it was meant
# as an extra argument to plot_bar, pass it explicitly.
plot_bar(data.dir, data.labels)
plot_img(dir_list=data.dir_list, labels=data.labels)
# Optional preprocessing step, currently disabled:
# pre_process(dir_list=data.dir_list)


In [None]:
# Split the dataset 60/20/20 into train / validation / test, then build one loader per split.
# The cells below read loaders from `data` directly; these names are kept for interactive use.
SPLIT_RATIOS = {"train_ratio": 0.6, "valid_ratio": 0.2, "test_ratio": 0.2}
data.db_split(**SPLIT_RATIOS)
trainDB, validDB, testDB = (
    data.train_dataloader(),
    data.valid_dataloader(),
    data.test_dataloader(),
)

In [None]:
# Model: build the CNN, train it on `data`, then report Lightning metrics on the
# validation and test splits.
MAX_EPOCHS = 50

model = ConvolutionalNetwork(data.labels)
# accelerator="auto" uses a GPU when one is available. It is already the default in
# Lightning 2.x, but on 1.x the default was CPU.
trainer = pl.Trainer(max_epochs=MAX_EPOCHS, accelerator="auto")
trainer.fit(model, data)

# Pass `model` explicitly so the metrics come from the same in-memory weights that the
# classification-report cell evaluates, instead of a checkpoint reloaded from disk.
# NOTE(review): 'valid' is not one of Lightning's standard stage names
# ('fit'/'validate'/'test'/'predict'); confirm ChessDB.setup handles it.
data.setup(stage='valid')
valid_loader = data.valid_dataloader()
# NOTE(review): trainer.test on the validation loader logs via test_step; use
# trainer.validate if the model defines validation_step and val_* metrics are wanted.
trainer.test(model, dataloaders=valid_loader)
test_loader = data.test_dataloader()
trainer.test(model, dataloaders=test_loader)

In [None]:
# Evaluation: predict the test split and print per-class precision/recall/F1.
# After trainer.fit/test the module is not guaranteed to be on `device` (Lightning may
# move it back to CPU during teardown), so move it explicitly before sending batches there.
model.to(device)
model.eval()

y_true = []
y_pred = []
with torch.no_grad():
    for batch in data.test_dataloader():
        images, batch_labels = batch[0], batch[1]
        logits = model(images.to(device))
        y_pred.extend(logits.argmax(dim=1).cpu().tolist())
        y_true.extend(batch_labels.tolist())

# Passing `labels` keeps the report aligned with `target_names` even when a class is
# missing from both y_true and y_pred (otherwise sklearn raises a size-mismatch error).
# Assumes integer label i corresponds to data.labels[i], as target_names already implies.
print(classification_report(
    y_true=y_true,
    y_pred=y_pred,
    labels=list(range(len(data.labels))),
    target_names=data.labels,
    digits=4,
))