#### Import Libraries 

In [1]:
import os
import sys
# Make the parent directory importable so the project-local imports further
# down (`train.main_config`, `nn.a0_resnet`) can resolve -- presumably those
# packages live one level above this notebook; verify against the repo layout.
module_path = os.path.abspath('..')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [2]:
import time
import h5py
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm

# PyTorch libraries and modules
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, SGD

from train.main_config import main_config

from nn.a0_resnet import AlphaZeroResnet

#### Loading the data (HDF5 batch generator)

In [3]:
# Paths of the HDF5 files for the train / validation / test splits, looked up
# in the project configuration.
H5_TRAIN, H5_VAL, H5_TEST = (
    main_config[split_key] for split_key in ("train_dir", "val_dir", "test_dir")
)


def create_hdf5_generator(file_path, batch_size):
    """Yield ``(data, labels)`` batches from an HDF5 file, cycling forever.

    The file must contain two datasets named ``'data'`` and ``'labels'``.
    They are read in consecutive slices of ``batch_size`` rows; after the last
    slice the generator starts over from row 0, so it never raises
    ``StopIteration`` on its own -- the consumer decides when to stop. The
    final slice of each pass may hold fewer than ``batch_size`` rows.

    Args:
        file_path: Path of the HDF5 file to read.
        batch_size: Number of rows per batch; must be positive.

    Yields:
        tuple: ``(data, labels)`` as ``torch.Tensor`` objects built from the
        NumPy slices read from the file.

    Raises:
        ValueError: If ``batch_size`` is not positive or the ``'data'``
            dataset has no rows. In both cases the endless loop below would
            otherwise spin without ever yielding.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Open read-only and inside a context manager: the handle is released as
    # soon as the generator is closed or garbage-collected instead of leaking.
    with h5py.File(file_path, 'r') as h5_file:
        data_size = h5_file['data'].shape[0]
        if data_size == 0:
            raise ValueError(f"'data' dataset in {file_path!r} is empty")

        while True:  # loop through the dataset indefinitely
            for start in range(0, data_size, batch_size):
                stop = start + batch_size
                # converting data and labels into torch format
                data = torch.from_numpy(h5_file['data'][start:stop])
                labels = torch.from_numpy(h5_file['labels'][start:stop])
                yield data, labels



#### Defining Dice Loss 

In [40]:
def get_sense_plane(sense_square):
    """Return the 8x8 mask of the 3x3 window addressed by ``sense_square``.

    Squares are numbered 1..36 over a 6x6 grid, row by row. For a square
    ``s`` the window covers array rows ``top .. top + 2`` and columns
    ``left .. left + 2`` with ``top = 5 - (s - 1) // 6`` and
    ``left = (s - 1) % 6``. Any value ``<= 0`` produces an all-zero plane.

    Example: ``sense_square = 8`` gives ``top = 4`` and ``left = 1``, i.e.
    (array row index shown on the left)

        0:  0, 0, 0, 0, 0, 0, 0, 0
        1:  0, 0, 0, 0, 0, 0, 0, 0
        2:  0, 0, 0, 0, 0, 0, 0, 0
        3:  0, 0, 0, 0, 0, 0, 0, 0
        4:  0, 1, 1, 1, 0, 0, 0, 0
        5:  0, 1, 1, 1, 0, 0, 0, 0
        6:  0, 1, 1, 1, 0, 0, 0, 0
        7:  0, 0, 0, 0, 0, 0, 0, 0

    Args:
        sense_square: Square index as a Python int or a single-element
            tensor.

    Returns:
        torch.Tensor: Float tensor of shape ``(8, 8)`` holding ones inside the
        window and zeros elsewhere. It is the product of a
        ``requires_grad=True`` plane of ones and the mask, so it reports
        ``requires_grad=True``.

    Raises:
        ValueError: If ``sense_square`` is greater than 36 (the window would
            fall outside the 8x8 plane).
    """
    # int() accepts both plain ints and 0-dim tensors, unlike torch.div/fmod.
    square = int(sense_square)
    if square > 36:
        raise ValueError(f"sense_square must be at most 36, got {square}")

    window = torch.zeros(8, 8, dtype=torch.float)
    if square > 0:
        top = 5 - (square - 1) // 6
        left = (square - 1) % 6
        window[top:top + 3, left:left + 3] = 1

    # Multiplying by a grad-enabled plane makes the result part of an autograd
    # graph; `requires_grad` is an attribute, not a method, so nothing is
    # called on the product.
    plane = torch.ones(8, 8, dtype=torch.float, requires_grad=True)
    return plane * window


class DiceLoss(nn.Module):
    """Dice loss between the sense windows of predicted and target squares.

    Each predicted class index (arg-max over dimension 1 of ``inputs``) and
    each target index is expanded into an 8x8 plane by ``get_sense_plane``;
    the loss is ``1 - (2 * |P * T| + smooth) / (|P| + |T| + smooth)`` computed
    over all planes of the batch.

    NOTE(review): the arg-max yields integer indices, so no gradient flows
    from the returned loss back to ``inputs``. Confirm this is acceptable
    before using the class as a training criterion.
    """

    def __init__(self):
        super(DiceLoss, self).__init__()

    @staticmethod
    def _to_planes(squares):
        """Stack one sense plane per entry of ``squares`` into ``(N, 8, 8)``."""
        planes = [get_sense_plane(square) for square in squares]
        if not planes:
            # torch.stack rejects an empty list; return an empty batch instead.
            return torch.zeros(0, 8, 8)
        return torch.stack(planes)

    def forward(self, inputs, targets, smooth=1):
        """Compute the Dice loss for one batch.

        Args:
            inputs: Class scores; the arg-max is taken over dimension 1.
            targets: Iterable of target square indices, one per sample.
            smooth: Additive smoothing term for numerator and denominator.

        Returns:
            torch.Tensor: Scalar loss tensor on the device of ``inputs``.
        """
        _, predicted_squares = torch.max(inputs, 1)

        # Size the plane stacks from the actual batch instead of assuming 64
        # samples, so short batches do not mix uninitialised memory into the
        # sums and longer ones do not overflow the buffer.
        predicted = self._to_planes(predicted_squares).to(inputs.device).reshape(-1)
        expected = self._to_planes(targets).to(inputs.device).reshape(-1)

        intersection = (predicted * expected).sum()
        loss = 1 - (2. * intersection + smooth) / (predicted.sum() + expected.sum() + smooth)
        return loss

##### Creating Model - Defining optimizer and loss function 

In [None]:
# # defining the model
# model = AlphaZeroResnet()
# # defining the optimizer
# optimizer = Adam(model.parameters(), lr=0.07)
# # defining the loss function
# criterion = CrossEntropyLoss()
# # criterion = DiceLoss()
# # checking if GPU is available
# if torch.cuda.is_available():
#     model = model.cuda()
#     criterion = criterion.cuda()
    
# # print(model)

In [7]:
# Mini-batch size handed to both generators. fit() also reads this
# module-level name to recognise the short batch that ends a pass.
batch_size = 64

# One endlessly cycling batch generator per split.
train_hdf5_generator, val_hdf5_generator = (
    create_hdf5_generator(h5_path, batch_size) for h5_path in (H5_TRAIN, H5_VAL)
)

@torch.no_grad()
def evaluate(model, criterion, val_loader):
    """Run one validation pass with gradients disabled.

    The model is switched to eval mode, ``model.validation_step`` is called
    for every ``(data, labels)`` pair produced by ``val_loader``, and the
    collected step results are handed to ``model.validation_epoch_end``.

    ``val_loader`` is consumed to exhaustion, so it must be finite; an
    endlessly cycling generator would keep this function from returning.

    Args:
        model: Object providing ``eval``, ``validation_step`` and
            ``validation_epoch_end``.
        criterion: Loss callable forwarded to ``validation_step``.
        val_loader: Iterable of ``(data, labels)`` pairs.

    Returns:
        Whatever ``model.validation_epoch_end`` returns.
    """
    model.eval()
    step_results = []
    for batch_data, batch_labels in val_loader:
        step_results.append(model.validation_step(criterion, batch_data, batch_labels))
    return model.validation_epoch_end(step_results)


def _full_batches(loader, expected_size):
    """Yield device-ready ``(data, labels)`` pairs until a short batch appears."""
    for data, labels in loader:
        # A batch with fewer than `expected_size` rows marks the end of one
        # pass over an endlessly cycling loader.
        if labels.shape[0] < expected_size or data.shape[0] < expected_size:
            return
        labels = labels.to(torch.long)
        # converting the data into GPU format
        if torch.cuda.is_available():
            data, labels = data.cuda(), labels.cuda()
        yield data, labels


def fit(epochs, lr, model, train_loader, val_loader, criterion, opt_func = SGD):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    An epoch ends when the loader produces a batch with fewer rows than the
    module-level ``batch_size`` (or when the loader is exhausted); that short
    batch is discarded. With an endlessly cycling loader whose row count is
    an exact multiple of ``batch_size`` no short batch ever appears and the
    epoch does not terminate.

    Args:
        epochs: Number of epochs to run.
        lr: Learning rate passed to ``opt_func``.
        model: Module providing ``training_step``, ``validation_step``,
            ``validation_epoch_end`` and ``epoch_end``.
        train_loader: Iterable of ``(data, labels)`` training batches.
        val_loader: Iterable of ``(data, labels)`` validation batches.
        criterion: Loss callable forwarded to the model's step methods.
        opt_func: Optimizer class, instantiated as
            ``opt_func(model.parameters(), lr)``.

    Returns:
        list: One result object per epoch, as returned by
        ``model.validation_epoch_end`` with an added ``'train_loss'`` entry
        (mean training loss of the epoch as a float).
    """
    history = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch in range(epochs):

        model.train()
        train_losses = []
        for data, labels in _full_batches(train_loader, batch_size):
            loss = model.training_step(criterion, data, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Keep only the value; holding the attached tensor would keep
            # every batch's autograd graph alive for the whole epoch.
            train_losses.append(loss.detach())

        model.eval()
        # Validation needs no gradients; skipping them saves memory and time.
        with torch.no_grad():
            outputs = [
                model.validation_step(criterion, data, labels)
                for data, labels in _full_batches(val_loader, batch_size)
            ]
        result = model.validation_epoch_end(outputs)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        model.epoch_end(epoch, result)
        history.append(result)

    return history


In [8]:
# Hyper-parameters of this run.
num_epochs = 5
lr = 0.001
opt_func = Adam  # optimizer class; fit() instantiates it with the model parameters

# Network and loss function (DiceLoss is the alternative criterion defined earlier).
model = AlphaZeroResnet()
criterion = CrossEntropyLoss()

# Move both to the GPU when one is available.
if torch.cuda.is_available():
    model, criterion = model.cuda(), criterion.cuda()

history = fit(num_epochs, lr, model, train_hdf5_generator, val_hdf5_generator, criterion, opt_func)
    
# print(model)

KeyboardInterrupt: 

#### Training 

In [None]:
# batch_size = 64

# train_hdf5_generator = create_hdf5_generator(H5_TRAIN, batch_size)
# val_hdf5_generator = create_hdf5_generator(H5_VAL, batch_size)

# epochs = 5
# # min_valid_loss = np.inf

# # train_losses = []
# # val_losses = []

# history = []

# for e in range(epochs):
#     # train_loss = 0.0
#     model.train()
#     train_losses = []
#     with tqdm(train_hdf5_generator,  unit='batch') as tepoch:
#         for data, labels  in tepoch:
#             if labels.shape[0] < batch_size or data.shape[0] < batch_size: 
#                 break
# #             tepoch.total = batch_size * labels.shape[0]
# #             tepoch.refresh()
#             tepoch.set_description(f"Epoch {e}")
#             labels = labels.to(torch.long)

#             # converting the data into GPU format
#             if torch.cuda.is_available():
#                 data, labels = data.cuda(), labels.cuda()

#             # clearing the Gradients of the model parameters
#             optimizer.zero_grad()
#             #prediction
#             target = model(data)
# #             print('criterion target shape ', data)
# #             print('criterion label shape ', labels)
#             loss = criterion(target, labels)
#             loss.backward()
#             optimizer.step()
#             train_loss += loss.item()
#             train_losses.append(loss.item())

#             tepoch.set_postfix(loss=loss.item())
#             time.sleep(0.1)

#     valid_loss = 0.0
#     model.eval()
#     with tqdm(val_hdf5_generator, unit='batch') as vepoch:
#         for data, labels in vepoch:
#             if labels.shape[0] < batch_size: 
#                 break
#             # labels = labels.to(torch.float32)
#             labels = labels.to(torch.long)
#             # converting the data into GPU format
#             if torch.cuda.is_available():
#                     data, labels = data.cuda(), labels.cuda()
#             target = model(data)
#             loss = criterion(target, labels)
#             valid_loss = loss.item() * data.size(0)
#             val_losses.append(loss.item())
#             tepoch.set_postfix(loss=loss.item())
#             time.sleep(0.1)
    
#     print(f'Epoch {e+1} \t\t Training Loss: {train_loss / batch_size} \t\t Validation Loss: {valid_loss / batch_size}')
#     if min_valid_loss > valid_loss:
#         print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss:.6f}) \t Saving The Model')
#         min_valid_loss = valid_loss
#         # Saving State Dict
#         torch.save(model.state_dict(), '/root/rbc/saved_model.pth')


In [None]:
# plotting the training and validation loss
# plt.plot(train_losses, label='Training loss')
# plt.plot(val_losses, label='Validation loss')
# plt.legend()
# plt.show()