#### Import Libraries 

In [1]:
import os
import sys
# Put the parent of the current working directory on the import path (once)
# so that packages living one level up can be imported from this notebook.
module_path = os.path.abspath(os.pardir)
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)

In [8]:
from rtpt import RTPT
import time
import h5py
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm

# PyTorch libraries and modules
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, SGD

from train.main_config import main_config

from nn.a0_resnet import AlphaZeroResnet

#### Splitting the data

In [9]:
# Locations of the train / validation / test HDF5 files, taken from the
# project configuration; they are later passed straight to h5py.File.
H5_TRAIN, H5_VAL, H5_TEST = (
    main_config[key] for key in ("train_dir", "val_dir", "test_dir")
)


def create_hdf5_generator(file_path, batch_size):
    """Yield ``(data, labels)`` batches from an HDF5 file, cycling forever.

    The file must contain the datasets ``'data'`` and ``'labels'``; both are
    sliced in lock-step along their first axis.  The last batch of each pass
    may hold fewer than ``batch_size`` rows; after it the generator starts
    again from the beginning of the file.

    Args:
        file_path: Path of the HDF5 file to read.
        batch_size: Maximum number of rows per yielded batch.

    Yields:
        A ``(data, labels)`` tuple of ``torch.Tensor`` objects created from
        the corresponding NumPy slices.  Nothing is yielded for an empty
        ``'data'`` dataset.
    """
    # Open read-only inside a context manager so the handle is released as
    # soon as the generator is closed or garbage-collected.
    with h5py.File(file_path, 'r') as file:
        data_size = file['data'].shape[0]
        if data_size == 0:
            # Cycling over an empty dataset would spin forever without ever
            # yielding, hanging the consumer; stop immediately instead.
            return

        while True:  # loop through the dataset indefinitely
            for start in range(0, data_size, batch_size):
                stop = start + batch_size
                # h5py returns NumPy arrays; wrap them as torch tensors.
                data = torch.from_numpy(file['data'][start:stop])
                labels = torch.from_numpy(file['labels'][start:stop])
                yield data, labels

#### Defining Dice Metric

In [10]:
def get_lookup_masks():
    """Build one 3x3 window mask per square of a 6x6 grid.

    Square ``pos`` (0..35) is mapped to ``row = 5 - pos // 6`` and
    ``col = pos % 6``.  Its mask is 1 on the 3x3 neighbourhood centred on
    ``(row, col)`` and 0 elsewhere; the part of the window that would fall
    outside the grid is cut off.

    Returns:
        ``numpy.ndarray`` of shape ``(36, 6, 6)`` and dtype ``float64``.
        For example ``masks[7]`` (row 4, col 1) is

            0, 0, 0, 0, 0, 0
            0, 0, 0, 0, 0, 0
            0, 0, 0, 0, 0, 0
            1, 1, 1, 0, 0, 0
            1, 1, 1, 0, 0, 0
            1, 1, 1, 0, 0, 0
    """
    side = 6
    n_squares = side * side

    masks = np.zeros((n_squares, side, side))
    for pos in range(n_squares):
        rank, col = divmod(pos, side)
        row = side - 1 - rank
        # Slices clip at the upper edge by themselves; only the lower bound
        # has to be clamped to keep the window inside the grid.
        top, left = max(row - 1, 0), max(col - 1, 0)
        masks[pos, top:row + 2, left:col + 2] = 1

    return masks


# Computed once at import time and shared by DiceMetric and DiceLoss.
lookup_masks = get_lookup_masks()

def DiceMetric(outputs, predicted_outputs):
    """Return the Dice coefficient between two batches of square indices.

    Every index is expanded to its mask from ``lookup_masks``; all masks of a
    batch are flattened together and the score
    ``(2 * sum(a * b) + smooth) / (sum(a) + sum(b) + smooth)`` is computed
    once over the whole batch, with ``smooth = 0.001``.

    Args:
        outputs: 1-D integer tensor of indices into ``lookup_masks``
            (the labels, as called from ``test``).
        predicted_outputs: 1-D integer tensor of predicted indices; must have
            the same length as ``outputs``.

    Returns:
        The Dice coefficient as a Python ``float``.

    Raises:
        IndexError: If an index lies outside ``lookup_masks``.
    """
    smooth = .001

    # A single fancy-indexing lookup per batch replaces the per-element
    # ``.item()`` loop (one device sync per element on CUDA) and avoids
    # building a tensor from a Python list of arrays.
    label_idx = torch.as_tensor(outputs).detach().cpu().numpy()
    predicted_idx = torch.as_tensor(predicted_outputs).detach().cpu().numpy()

    # Flatten label and prediction masks.
    inputs = torch.from_numpy(lookup_masks[label_idx]).reshape(-1)
    targets = torch.from_numpy(lookup_masks[predicted_idx]).reshape(-1)

    intersection = (inputs * targets).sum()
    dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

    return dice.item()


#PyTorch
class DiceLoss(nn.Module):
    """Dice loss between network outputs and mask-expanded target indices.

    Each target index is replaced by its vertically flipped mask from
    ``lookup_masks``; the loss is ``1 - dice`` computed over the flattened
    batch.
    """

    def __init__(self, weight=None, size_average=True):
        # ``weight`` and ``size_average`` are accepted only to keep the
        # constructor signature; neither influences the loss.
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Compute ``1 - dice`` for one batch.

        Args:
            inputs: Tensor of network outputs.  After flattening it must hold
                exactly as many elements as the stacked target masks.
            targets: 1-D integer tensor of indices into ``lookup_masks``.
            smooth: Constant added to numerator and denominator so the ratio
                stays defined when both sums are zero.

        Returns:
            Scalar tensor ``1 - dice``.

        Raises:
            IndexError: If a target index lies outside ``lookup_masks``.
        """
        target_idx = torch.as_tensor(targets).detach().cpu().numpy()

        # Look up all masks at once and flip each one top-to-bottom
        # (axis 1 of the batch is axis 0 of a single mask).
        flipped = lookup_masks[target_idx][:, ::-1]
        # The flipped view has negative strides, which torch cannot wrap,
        # hence the contiguous copy.
        targets = torch.from_numpy(np.ascontiguousarray(flipped))
        # Follow the device of the predictions instead of assuming CUDA, so
        # the loss also works for CPU inputs on a machine that has a GPU.
        targets = targets.to(inputs.device)

        # Flatten label and prediction tensors; reshape also accepts
        # non-contiguous inputs, unlike view.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice

#### Defining fit method

In [11]:
# Number of samples per batch; shared by the generators, fit() and test().
batch_size = 64

# Read the dataset sizes up front (they drive the progress-bar totals) and
# close each file again right away instead of leaking the open handle.
with h5py.File(H5_TRAIN, 'r') as file:
    train_data_size = file['data'].shape[0]
with h5py.File(H5_VAL, 'r') as file:
    val_data_size = file['data'].shape[0]

# Endless batch generators; each one opens its own handle on first use.
train_hdf5_generator = create_hdf5_generator(H5_TRAIN, batch_size)
val_hdf5_generator = create_hdf5_generator(H5_VAL, batch_size)


def fit(epochs, lr, model, train_loader, val_loader, criterion, opt_func = SGD):
    """Train ``model`` for ``epochs`` epochs and checkpoint the best one.

    Each epoch makes one pass of full batches over ``train_loader``, then one
    pass over ``val_loader``.  Batches shorter than the module-level
    ``batch_size`` are skipped and mark the end of a pass.  Whenever the
    validation loss improves, the whole model is saved to
    ``/root/rbc/models/model.pth``.

    Args:
        epochs: Number of epochs to run.
        lr: Learning rate handed to the optimizer.
        model: Network exposing ``training_step``, ``validation_step``,
            ``validation_epoch_end`` and ``epoch_end``.
        train_loader: Iterable of ``(data, labels)`` training batches.
        val_loader: Iterable of ``(data, labels)`` validation batches.
        criterion: Loss passed through to the model's step methods.
        opt_func: Optimizer class; instantiated with ``weight_decay=0.01``.

    Returns:
        List with one result dict per epoch, i.e. whatever
        ``model.validation_epoch_end`` returned plus a ``'train_loss'`` entry.
    """
    rtpt = RTPT(name_initials='AM', experiment_name='sense-training', max_iterations=epochs)
    rtpt.start()
    best_val_loss = np.inf
    history = []
    optimizer = opt_func(model.parameters(), lr, weight_decay=0.01)

    # The HDF5 generators cycle forever.  A short trailing batch marks the
    # end of a pass, but it only exists when the dataset size is not a
    # multiple of batch_size; otherwise the pass has to be ended by count or
    # the first epoch would never finish.
    train_steps, train_remainder = divmod(train_data_size, batch_size)
    val_steps, val_remainder = divmod(val_data_size, batch_size)

    for epoch in range(epochs):

        model.train()
        train_losses = []
        with tqdm(train_loader, unit='batch', total=train_steps) as tepoch:
            for step, (data, labels) in enumerate(tepoch, start=1):
                tepoch.set_description(f"Epoch {epoch}")
                if labels.shape[0] < batch_size or data.shape[0] < batch_size:
                    break
                labels = labels.to(torch.long)
                # converting the data into GPU format
                if torch.cuda.is_available():
                    data, labels = data.cuda(), labels.cuda()

                loss = model.training_step(criterion, data, labels)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                # Keep only the value: storing the attached tensor would hold
                # on to autograd history for the whole epoch.
                train_losses.append(loss.detach())

                if train_remainder == 0 and step >= train_steps:
                    break

        model.eval()
        outputs = []
        # No gradients are needed for validation.
        with torch.no_grad():
            with tqdm(val_loader, unit='batch', total=val_steps) as vepoch:
                for step, (data, labels) in enumerate(vepoch, start=1):
                    if labels.shape[0] < batch_size or data.shape[0] < batch_size:
                        break
                    labels = labels.to(torch.long)
                    # converting the data into GPU format
                    if torch.cuda.is_available():
                        data, labels = data.cuda(), labels.cuda()
                    outputs.append(model.validation_step(criterion, data, labels))

                    if val_remainder == 0 and step >= val_steps:
                        break

        result = model.validation_epoch_end(outputs)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        model.epoch_end(epoch, result)
        history.append(result)

        if result['val_loss'] < best_val_loss:
            val_loss = result['val_loss']
            print(f'Validation Loss Decreased({best_val_loss}--->{val_loss}) \t Saving The Model')
            best_val_loss = val_loss
            torch.save(model, '/root/rbc/models/model.pth')

        rtpt.step()

    return history


#### Training the model

In [14]:
# Build the network with nb_input_channels=21, dropout=0.2 and n_labels=36.
# NOTE(review): the recorded run of this cell stopped with
# "IndexError: Target 36 is out of bounds", i.e. the labels contain the value
# 36 while n_labels=36 only admits targets 0..35.  Either the labels need to
# be remapped or n_labels has to grow -- confirm against how the dataset was
# generated before changing either.
model = AlphaZeroResnet(nb_input_channels=21, dropout=0.2, n_labels=36)
# print(model)
# Optimizer class and learning rate handed to fit().
opt_func = Adam
lr = 0.001
# Loss used for training; DiceLoss is kept below as an alternative.
criterion = CrossEntropyLoss()
# criterion = DiceLoss()
num_epochs = 20

# check if GPU is available
if torch.cuda.is_available():
    model = model.cuda()
    criterion = criterion.cuda()

# Run training; `history` holds one result dict per completed epoch.
history = fit(num_epochs, lr, model, train_hdf5_generator, val_hdf5_generator, criterion, opt_func)

  0%|          | 0/649 [00:00<?, ?batch/s]

IndexError: Target 36 is out of bounds.

#### Plot accuracies and losses 

In [None]:
def plot_accuracies(history):
    """Draw the validation accuracy of every epoch as a line plot.

    Args:
        history: Sequence of per-epoch result dicts, each with a
            ``'val_acc'`` entry.
    """
    val_acc_per_epoch = []
    for epoch_result in history:
        val_acc_per_epoch.append(epoch_result['val_acc'])

    plt.plot(val_acc_per_epoch, '-x')
    plt.title('Accuracy vs. No. of epochs')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.show()



def plot_losses(history):
    """Draw training and validation loss of every epoch in one figure.

    Args:
        history: Sequence of per-epoch result dicts.  ``'val_loss'`` is
            required; a missing ``'train_loss'`` is passed on as ``None``.
    """
    train_curve, val_curve = [], []
    for epoch_result in history:
        train_curve.append(epoch_result.get('train_loss'))
        val_curve.append(epoch_result['val_loss'])

    # Blue crosses: training, red crosses: validation.
    for curve, line_style in ((train_curve, '-bx'), (val_curve, '-rx')):
        plt.plot(curve, line_style)

    plt.title('Loss vs. No. of epochs')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['Training', 'Validation'])
    plt.show()
    
# Plot the validation accuracy collected in `history`, one point per epoch.
plot_accuracies(history)

# Plot training and validation loss collected in `history`.
plot_losses(history)

#### Testing the model on test dataset 

In [None]:
def test():
    """Evaluate the checkpointed model on the test set and print the scores.

    Loads the model that ``fit`` saved, runs it over every full batch of the
    ``H5_TEST`` file exactly once and prints the accuracy and the mean
    per-batch Dice coefficient.  Runs on CUDA when available, otherwise on
    the CPU.
    """
    from itertools import islice

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the model that was saved during the training loop; this is the
    # path fit() writes its best checkpoint to.
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints you trust.  Recent PyTorch releases may additionally need
    # weights_only=False to load a whole pickled model; confirm for the
    # installed version.
    path = '/root/rbc/models/model.pth'
    model = torch.load(path, map_location=device)
    model = model.to(device)
    # Switch layers such as dropout to inference behaviour.
    model.eval()

    with h5py.File(H5_TEST, 'r') as file:
        data_size = file['data'].shape[0]
    print('data size', data_size)

    # The generator cycles forever, so bound the loop to one pass of full
    # batches; a trailing short batch, if any, is never requested.
    n_batches = data_size // batch_size
    test_hdf5_generator = create_hdf5_generator(H5_TEST, batch_size)

    running_accuracy = 0
    running_dice_coeff = 0
    total = 0

    with torch.no_grad():
        for inputs, outputs in tqdm(islice(test_hdf5_generator, n_batches), total=n_batches):
            inputs = inputs.to(device)
            outputs = outputs.to(device).to(torch.long)
            predicted_outputs = model(inputs)
            _, predicted = torch.max(predicted_outputs, 1)
            total += 1
            running_accuracy += (predicted == outputs).sum().item()
            running_dice_coeff += DiceMetric(outputs, predicted)

    # Release the generator (and the resources it holds) now that the pass
    # is complete.
    test_hdf5_generator.close()

    if total == 0:
        # Fewer samples than one batch: there is nothing to average over.
        print('No full batch available in the test set; nothing to evaluate.')
        return

    print('Accuracy of the model based on the test set of', total * batch_size ,'inputs is: %d %%' % (100 * running_accuracy / (total * batch_size)))
    print('Dice coefficient of the model based on the test set of', total * batch_size ,'inputs is:', (running_dice_coeff / total))

# Run the evaluation on the test set.
test()