#### Import Libraries 

In [1]:
import os
import sys
# Put the parent directory on the import path (once) so modules that live
# one level above the working directory can be imported.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [2]:
import time
import h5py
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm

# PyTorch libraries and modules
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, SGD

from train.main_config import main_config

from nn.a0_resnet import AlphaZeroResnet

#### Loading the data

In [3]:
# Dataset locations for the three splits, looked up in the project config.
H5_TRAIN, H5_VAL, H5_TEST = (
    main_config[key] for key in ("train_dir", "val_dir", "test_dir")
)


def create_hdf5_generator(file_path, batch_size):
    """Yield ``(data, labels)`` tensor batches from an HDF5 file, endlessly.

    The file must contain a ``'data'`` and a ``'labels'`` dataset. Rows are
    read sequentially in slices of ``batch_size``; the final slice of each
    pass may be shorter. After the last slice the generator starts again
    from the first row, so it never raises ``StopIteration`` on its own.

    Args:
        file_path: Path of the HDF5 file to read.
        batch_size: Number of rows per batch; must be positive.

    Yields:
        Tuple ``(data, labels)`` of ``torch.Tensor`` objects built from the
        corresponding NumPy slices.

    Raises:
        ValueError: If ``batch_size`` is not positive or the ``'data'``
            dataset is empty (either case would otherwise spin forever
            without producing a batch). Raised on the first ``next()``.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Open read-only and inside a context manager so the handle is released
    # when the generator is closed or garbage-collected.
    with h5py.File(file_path, 'r') as file:
        data_size = file['data'].shape[0]
        if data_size == 0:
            raise ValueError(f"'data' dataset in {file_path} is empty")

        while True:  # loop through the dataset indefinitely
            for start in range(0, data_size, batch_size):
                stop = start + batch_size
                # converting the data and the labels into torch format
                data = torch.from_numpy(file['data'][start:stop])
                labels = torch.from_numpy(file['labels'][start:stop])
                yield data, labels

#### Defining Dice Loss 

In [40]:
def get_sense_plane(sense_square):
    """Return the 8x8 plane marking the 3x3 window of a sense-square index.

    Indices 1..36 address a 6x6 grid: ``row = 5 - (sense_square - 1) // 6``
    and ``col = (sense_square - 1) % 6``. The 3x3 window whose top-left
    tensor element is ``plane[row, col]`` is set to 1, everything else is 0.
    A non-positive index yields an all-zero plane.

    Example: for ``sense_square == 8`` the ones occupy ``plane[4:7, 1:4]``;
    for ``sense_square == 26`` they occupy ``plane[1:4, 1:4]``.

    Args:
        sense_square: Index as a Python int or a single-element tensor.

    Returns:
        Float tensor of shape ``(8, 8)`` with ``requires_grad`` set.

    Raises:
        ValueError: If ``sense_square`` is greater than 36; such a value
            would produce a negative row and silently address wrong cells.
    """
    square = int(sense_square)
    if square > 36:
        raise ValueError(f"sense_square must be <= 36, got {square}")

    window = torch.zeros(8, 8, dtype=torch.float)
    if square > 0:
        row, col = divmod(square - 1, 6)
        row = 5 - row
        # Centre is (row + 1, col + 1), so the window spans row..row + 2
        # and col..col + 2.
        window[row:row + 3, col:col + 3] = 1

    # Multiplying by a leaf that requires grad makes the result require grad
    # too. The original called the boolean attribute `requires_grad` as if
    # it were a method, which raised TypeError on every call.
    plane = torch.ones(8, 8, dtype=torch.float, requires_grad=True)
    return plane * window


class DiceLoss(nn.Module):
    """Dice loss between the sense planes of predictions and targets.

    Each predicted class index (argmax over dimension 1 of ``inputs``) and
    each target index is expanded into an 8x8 plane via ``get_sense_plane``;
    the loss is ``1 - (2 * |A . B| + smooth) / (|A| + |B| + smooth)`` over
    the flattened planes.

    NOTE: the prediction goes through ``torch.max`` indices, which are not
    differentiable with respect to ``inputs``, so this loss does not
    propagate gradients back into the network that produced ``inputs``.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _to_planes(squares):
        """Stack the sense plane of every index in ``squares``."""
        # Size the buffer from the actual batch instead of a fixed 64, so
        # short batches do not leave uninitialised rows in the loss and
        # longer ones do not overflow the buffer.
        planes = torch.zeros(len(squares), 8, 8)
        for i, square in enumerate(squares):
            planes[i, :, :] = get_sense_plane(square)
        return planes

    def forward(self, inputs, targets, smooth=1):
        """Compute the Dice loss for one batch.

        Args:
            inputs: Per-class scores; the class is taken along dimension 1.
            targets: One sense-square index per sample.
            smooth: Additive smoothing term that keeps the ratio defined
                when both plane sets are empty.

        Returns:
            Scalar tensor holding the loss.
        """
        # Reduce the scores to one predicted class index per sample.
        _, predicted = torch.max(inputs, 1)

        # Flatten prediction and label planes.
        flat_pred = self._to_planes(predicted).reshape(-1)
        flat_target = self._to_planes(targets).reshape(-1)

        intersection = (flat_pred * flat_target).sum()
        denominator = flat_pred.sum() + flat_target.sum() + smooth
        return 1 - (2. * intersection + smooth) / denominator

#### Defining the fit method

In [7]:
# Number of samples per batch; also read by the training loop to detect
# the short batch that ends an epoch.
batch_size = 64

# One endless batch generator per split (nothing is read until iteration).
train_hdf5_generator, val_hdf5_generator = (
    create_hdf5_generator(path, batch_size) for path in (H5_TRAIN, H5_VAL)
)


def _prepare_batch(data, labels):
    """Cast labels to long and move both tensors to the GPU if available."""
    labels = labels.to(torch.long)
    if torch.cuda.is_available():
        data, labels = data.cuda(), labels.cuda()
    return data, labels


def fit(epochs, lr, model, train_loader, val_loader, criterion, opt_func=SGD,
        save_path='/root/rbc/saved_model.pth'):
    """Train ``model`` and checkpoint it whenever validation loss improves.

    An epoch (training and validation alike) ends when a loader yields a
    batch with fewer than the module-level ``batch_size`` rows.
    NOTE(review): if a loader is endless and never yields such a short
    batch, the corresponding loop does not terminate.

    Args:
        epochs: Number of epochs to run.
        lr: Learning rate passed to the optimizer.
        model: Network exposing ``training_step``, ``validation_step``,
            ``validation_epoch_end`` and ``epoch_end``.
        train_loader: Iterable of ``(data, labels)`` training batches.
        val_loader: Iterable of ``(data, labels)`` validation batches.
        criterion: Loss module handed to the model's step methods.
        opt_func: Optimizer class; built with ``weight_decay=0.01``.
        save_path: Where the best model is written with ``torch.save``.

    Returns:
        List with one result dict per epoch (as returned by
        ``model.validation_epoch_end`` plus a ``'train_loss'`` entry).
    """
    best_val_loss = np.inf
    history = []
    optimizer = opt_func(model.parameters(), lr, weight_decay=0.01)

    for epoch in range(epochs):
        model.train()
        train_losses = []
        with tqdm(train_loader, unit='batch') as tepoch:
            tepoch.set_description(f"Epoch {epoch}")
            for data, labels in tepoch:
                if labels.shape[0] < batch_size or data.shape[0] < batch_size:
                    break
                data, labels = _prepare_batch(data, labels)

                loss = model.training_step(criterion, data, labels)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                # Keep only the value; storing the attached tensor would
                # hold every batch's autograd graph alive for the epoch.
                train_losses.append(loss.detach())

        model.eval()
        outputs = []
        # Iterate the loader that was passed in (not a module global) and
        # skip gradient tracking, which evaluation does not need.
        with torch.no_grad(), tqdm(val_loader, unit='batch') as vepoch:
            for data, labels in vepoch:
                if labels.shape[0] < batch_size or data.shape[0] < batch_size:
                    break
                data, labels = _prepare_batch(data, labels)
                outputs.append(model.validation_step(criterion, data, labels))

        result = model.validation_epoch_end(outputs)
        # torch.stack rejects an empty list; report NaN if no full training
        # batch was seen in this epoch.
        if train_losses:
            result['train_loss'] = torch.stack(train_losses).mean().item()
        else:
            result['train_loss'] = float('nan')
        model.epoch_end(epoch, result)
        history.append(result)

        if result['val_loss'] < best_val_loss:
            val_loss = result['val_loss']
            print(f'Validation Loss Decreased({best_val_loss}--->{val_loss}) \t Saving The Model')
            best_val_loss = val_loss
            torch.save(model, save_path)

    return history



#### Training the model

In [None]:
# Hyper-parameters for this run.
num_epochs = 100
lr = 0.001
opt_func = Adam

# Network and loss (DiceLoss() is the alternative criterion).
model = AlphaZeroResnet()
criterion = CrossEntropyLoss()

# Move model and loss to the GPU when one is available.
if torch.cuda.is_available():
    model, criterion = model.cuda(), criterion.cuda()

history = fit(
    num_epochs,
    lr,
    model,
    train_hdf5_generator,
    val_hdf5_generator,
    criterion,
    opt_func,
)


#### Plot accuracies and losses 

In [None]:
def plot_accuracies(history):
    """Show validation accuracy per epoch as a line plot with x markers.

    Args:
        history: Sequence of per-epoch dicts, each with a ``'val_acc'`` key.
    """
    val_acc_per_epoch = []
    for record in history:
        val_acc_per_epoch.append(record['val_acc'])

    plt.plot(val_acc_per_epoch, '-x')
    plt.title('Accuracy vs. No. of epochs')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.show()



def plot_losses(history):
    """Show training (blue) and validation (red) loss per epoch.

    Args:
        history: Sequence of per-epoch dicts. ``'val_loss'`` is required;
            a missing ``'train_loss'`` is passed to the plot as ``None``.
    """
    train_curve, val_curve = [], []
    for record in history:
        train_curve.append(record.get('train_loss'))
        val_curve.append(record['val_loss'])

    for curve, fmt in ((train_curve, '-bx'), (val_curve, '-rx')):
        plt.plot(curve, fmt)
    plt.title('Loss vs. No. of epochs')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['Training', 'Validation'])
    plt.show()
    
# Render the accuracy curve first, then the loss curves.
for draw in (plot_accuracies, plot_losses):
    draw(history)