# Using kerMIT - Try Model -

This notebook explains how to train and test KERMIT_model, and how to save its weights so that they can subsequently be used with KERMITviz.

* *Note: in the previous notebook we used ag_news as an example dataset, but you can use any dataset you prefer.*



## Install Packages
Before starting, it is essential to have the following requirements:
- transformers==2.6.0
- torch and torchtext
- pycuda


In [None]:
#!pip install transformers==2.6.0
#!pip install torch
#!pip install torchtext
#!pip install pycuda

In [None]:
import torch, pickle, copy, transformers
from torchtext import data as datx
import pandas as pd
from torch import nn
import torch.nn.functional as F
from torch import optim
from tqdm import tqdm

# Seed the random number generators for replicability. torch.manual_seed
# seeds the default (CPU) generator; the CUDA generators are additionally
# seeded explicitly when a GPU is available.
torch.manual_seed(23)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(23)

# Path prefix of the dataset files. It is concatenated directly with the file
# names below, so include a trailing separator when it is not empty
# (e.g. 'data/').
dataPath = ''

# Training set of encoded trees.
nameTree1 = 'dtk_trees_ag_news_train.pkl'

# Test set of encoded trees.
nameTree2 = 'dtk_trees_ag_news_test.pkl'


## Load Dataset
To standardise the input we have chosen the csv format.

If you want to use your own dataset, you will have to modify this part and the section concerning data loading.

In [None]:
# File names of the two splits, relative to dataPath.
name_dataset_train = 'train.csv'

name_dataset_test = 'test.csv'

# Show the first rows of the training split.
data = pd.read_csv(dataPath + name_dataset_train)

data.head()

## Dataloader

PyTorch includes packages to prepare and load common datasets for your model.
The functions used for loading the data are defined in the next sections.


In [None]:
class TreeField(datx.Field):
    """Field for tree encodings that are already tensors.

    Construction is inherited unchanged from ``datx.Field``. Preprocessing is
    the identity, and a batch is produced by stacking its elements.
    """

    def preprocess(self, x):
        """Return the value unchanged."""
        return x

    def process(self, batch, device=None):
        """Stack the elements of ``batch`` into a single tensor.

        ``device`` is accepted for interface compatibility but is not used.
        """
        return torch.stack(batch)

def unplickle_trees(path_tree_file):
    """Read every pickled tree encoding from a file and return them as tensors.

    The file is expected to hold consecutive pickle records; reading stops at
    end of file.

    Args:
        path_tree_file: Path of the .pkl file with the encoded trees.

    Returns:
        A list with one ``torch.FloatTensor`` per record, in file order.

    Note:
        ``pickle`` can execute arbitrary code while loading; only open files
        that come from a trusted source.
    """
    print('--->read DTKs')

    def _records(handle):
        # Yield one unpickled object at a time until the stream is exhausted.
        while True:
            try:
                yield pickle.load(handle)
            except EOFError:
                return

    with open(path_tree_file, 'rb') as stream:
        encoded = list(_records(stream))
    return [torch.FloatTensor(tree) for tree in encoded]


def add_parsed_tree(test, test_tree_list, field):
    """Attach one encoded tree to every example of a dataset.

    Each tree is passed through ``datx.Example.fromlist`` with ``field`` and
    the resulting value is stored on the matching example as ``Tree``.
    Examples and trees are paired positionally with ``zip``, so the dataset is
    cut to the shorter of the two sequences.

    Args:
        test: Dataset object exposing ``fields`` and ``examples``.
        test_tree_list: Encoded trees, in the same order as the examples.
        field: Field registered under the name ``'Tree'``.

    Returns:
        The same dataset object, modified in place.
    """
    tree_examples = [
        datx.Example.fromlist([encoded], [('Tree', field)])
        for encoded in test_tree_list
    ]
    test.fields['Tree'] = field

    paired = []
    for sample, tree_sample in zip(test.examples, tree_examples):
        sample.Tree = tree_sample.Tree
        paired.append(sample)
    test.examples = paired
    return test
    
def first_tree(test, test_tree_list, field):
    """Attach only the first encoded tree to a dataset.

    Only ``test_tree_list[0]`` is used. Because examples are paired with a
    one-element list of trees, the returned dataset keeps at most its first
    example.

    Args:
        test: Dataset object exposing ``fields`` and ``examples``.
        test_tree_list: Encoded trees; must contain at least one element.
        field: Field registered under the name ``'Tree'``.

    Returns:
        The same dataset object, modified in place.

    Raises:
        IndexError: If ``test_tree_list`` is empty.
    """
    head = test_tree_list[0]
    head_example = datx.Example.fromlist([head], [('Tree', field)])
    test.fields['Tree'] = field

    kept = []
    for sample, tree_sample in zip(test.examples, [head_example]):
        sample.Tree = tree_sample.Tree
        kept.append(sample)
    test.examples = kept
    return test

class UnprField(datx.Field):
    """Field that leaves its data completely unprocessed.

    Construction is inherited unchanged from ``datx.Field``. Both
    preprocessing and batching are the identity.
    """

    def preprocess(self, x):
        """Return the value unchanged."""
        return x

    def process(self, batch, device=None):
        """Return ``batch`` as received; ``device`` is not used."""
        return batch

In [None]:
# Tokenizer initialization: pretrained 'bert-base-uncased' tokenizer.
tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
# Id of the tokenizer's padding token, used as the pad value of TEXT.
pad_index = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
# Text column: tokenized with tokenizer.encode, so no torchtext vocabulary is
# used (use_vocab=False); fix_length=51 sets a fixed sequence length.
TEXT = datx.Field(use_vocab=False,fix_length=51, tokenize=tokenizer.encode, pad_token=pad_index, batch_first=True)

# Label column: a single non-sequential value per example, no vocabulary.
LABEL = datx.Field(sequential=False, use_vocab=False, batch_first=True)
# Encoded-tree column, handled by the custom TreeField.
TREE = TreeField(sequential=False, use_vocab=False, batch_first=True)

These functions allow us to load data in previously fixed batches. 
In both functions, the `fields` list must name the columns of the dataset being used, and `skip_header` must state whether the first row of the file is a header to skip.

In [None]:
# Batch size used by dataset_to_train.
BATCH_SIZE = 125
# Batch size used by dataset_to_test.
BATCH_SIZE_test = 32

def dataset_to_train(dataset, dataPath, nameTree):
    """Build the training iterator for one CSV dataset and its encoded trees.

    Args:
        dataset: Path of the CSV file, with or without the '.csv' extension.
        dataPath: Path prefix that is concatenated with ``nameTree``.
        nameTree: File name of the pickled encoded trees for this split.

    Returns:
        A tuple ``(train_iter, None)``. The second slot carries no data; it is
        kept so the result stays a two-element tuple for callers that index
        ``[0]``.
    """
    # Callers pass names such as 'train.csv'; add the extension only when it
    # is missing so the path does not become 'train.csv.csv'.
    csv_path = dataset if dataset.endswith('.csv') else f'{dataset}.csv'

    label_field = datx.Field(sequential=False, use_vocab=False, batch_first=True)
    # The 'Title' column is read but discarded (field None).
    fields = [('Label', label_field), ('Title', None), ('Text', TEXT)]
    train = datx.TabularDataset(path=csv_path, format='csv', fields=fields, skip_header=True)

    train_trees_list = unplickle_trees(dataPath + nameTree)
    train = add_parsed_tree(train, train_trees_list, TREE)

    # Only one dataset is iterated, so pass a one-element tuple instead of
    # pairing it with an unrelated placeholder object.
    (train_iter,) = datx.Iterator.splits(
        (train,), sort_key=lambda x: len(x.Text),
        batch_sizes=(BATCH_SIZE,))

    return (train_iter, None)


def dataset_to_test(dataset, dataPath, nameTree):
    """Build the test iterator for one CSV dataset and its encoded trees.

    Args:
        dataset: Path of the CSV file, with or without the '.csv' extension.
        dataPath: Path prefix that is concatenated with ``nameTree``.
        nameTree: File name of the pickled encoded trees for this split.

    Returns:
        A tuple ``(test_iter, None)``. The second slot carries no data; it is
        kept so the result stays a two-element tuple for callers that index
        ``[0]``.
    """
    # Callers pass names such as 'test.csv'; add the extension only when it
    # is missing so the path does not become 'test.csv.csv'.
    csv_path = dataset if dataset.endswith('.csv') else f'{dataset}.csv'

    label_field = datx.Field(sequential=False, use_vocab=False, batch_first=True)
    # The 'Title' column is read but discarded (field None).
    fields = [('Label', label_field), ('Title', None), ('Text', TEXT)]
    # NOTE(review): the header is NOT skipped here, unlike in dataset_to_train.
    # Confirm that the test CSV really has no header row.
    test = datx.TabularDataset(path=csv_path, format='csv', fields=fields, skip_header=False)

    test_trees_list = unplickle_trees(dataPath + nameTree)
    test = add_parsed_tree(test, test_trees_list, TREE)

    # Iterate over the test dataset built above (not over the name `train`).
    (test_iter,) = datx.Iterator.splits(
        (test,), sort_key=lambda x: len(x.Text),
        batch_sizes=(BATCH_SIZE_test,))

    return (test_iter, None)

## Train & Infer
For the train function, we simply have to loop over our data iterator and feed the inputs to the network and optimize.
Then we evaluate the performance using infer(), which takes the test set as input and returns the accuracy.

In [None]:
def train(train_iter, dataset_name, epochs=None):
    """Train the module-level ``model`` and print loss and accuracy per epoch.

    Uses the module-level ``model``, ``criterion`` and ``optimizer``. The
    train/eval mode of ``model`` is not changed here.

    Args:
        train_iter: Iterable of batches exposing ``Text``, ``Tree`` and
            ``Label`` tensors.
        dataset_name: Not used by the function; kept for call-site
            compatibility.
        epochs: Number of passes over ``train_iter``. When omitted, the
            module-level ``EPOCH`` is used.

    Raises:
        ValueError: If ``train_iter`` yields no batches.
    """
    if epochs is None:
        epochs = EPOCH

    # Move every batch to the device the model's parameters live on.
    device = next(model.parameters()).device

    for epoch in range(1, epochs + 1):
        running_loss = 0.0
        n_batches = 0
        n_correct = 0
        n_seen = 0

        for elem in tqdm(train_iter):
            x_sem = elem.Text.to(device)
            x_synth = elem.Tree.to(device)
            target = elem.Label.to(device)

            target_hat = model(x_sem, x_synth)
            loss = criterion(target_hat, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Batches are counted while iterating, so no extra pass over the
            # data is needed just to learn how many there are.
            n_batches += 1
            # exp() is monotonic, so argmax over the raw outputs selects the
            # same class as argmax over their exponentials.
            n_correct += (target_hat.argmax(1) == target).sum().item()
            n_seen += target.size(0)

        if n_batches == 0:
            raise ValueError("train_iter produced no batches; nothing to train on")

        print("Epoch: " , epoch)
        print("Loss: " + str(running_loss / n_batches))
        print("Accuracy: " + str(n_correct / n_seen))

In [None]:
def infer(test_iter, neural_model, dataset_name, EPOCH=1, L=30, lambda_norm=0.001):
    """Evaluate ``neural_model`` on ``test_iter`` and return its accuracy.

    The model is evaluated in place, in eval mode and without gradient
    tracking, so no second copy of it has to be built. The training flag of
    every sub-module is restored afterwards.

    Args:
        test_iter: Iterable of batches exposing ``Text``, ``Tree`` and
            ``Label`` tensors.
        neural_model: Model called as ``neural_model(text, tree)``; its
            outputs are scored with ``nn.NLLLoss``.
        dataset_name: Not used; kept for call-site compatibility.
        EPOCH: Not used; kept for call-site compatibility.
        L: Not used; kept for call-site compatibility.
        lambda_norm: Not used; kept for call-site compatibility.

    Returns:
        The fraction of correctly classified examples.

    Raises:
        ValueError: If ``test_iter`` yields no batches.
    """
    criterion = nn.NLLLoss()
    # Move every batch to the device the model's parameters live on.
    device = next(neural_model.parameters()).device
    # Record the flag of each sub-module individually: a blanket .train()
    # afterwards would also switch on sub-modules that were in eval mode.
    saved_modes = [(module, module.training) for module in neural_model.modules()]

    running_loss = 0.0
    n_batches = 0
    n_correct = 0
    n_seen = 0

    neural_model.eval()
    try:
        with torch.no_grad():
            for elem in tqdm(test_iter):
                x_sem = elem.Text.to(device)
                x_synth = elem.Tree.to(device)
                target = elem.Label.to(device)

                target_hat = neural_model(x_sem, x_synth)
                running_loss += criterion(target_hat, target).item()
                n_batches += 1
                # exp() is monotonic, so argmax over the raw outputs selects
                # the same class as argmax over their exponentials.
                n_correct += (target_hat.argmax(1) == target).sum().item()
                n_seen += target.size(0)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    if n_batches == 0:
        raise ValueError("test_iter produced no batches; nothing to evaluate")

    accuracy = n_correct / n_seen
    print("Loss: " + str(running_loss / n_batches))
    print("Accuracy: " + str(accuracy))
    return accuracy


## Define Model
The model proposed in this tutorial is as follows:

In [None]:

class DTBert(nn.Module):
    """Classifier over the concatenation of a BERT vector and a tree vector.

    The pretrained 'bert-base-uncased' encoder is run without gradient
    tracking; only the final linear layer receives gradients in ``forward``.

    Args:
        input_dim_bert: Size of the BERT vector fed to the linear layer.
        input_dim_dt: Size of the encoded-tree vector.
        output_dim: Number of output classes.
    """

    def __init__(self, input_dim_bert, input_dim_dt, output_dim):
        super().__init__()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        encoder = transformers.BertModel.from_pretrained('bert-base-uncased')
        self.bert = encoder.to(device)
        self.synth_sem_linear = nn.Linear(input_dim_bert + input_dim_dt, output_dim)

    def forward(self, x_sem, x_synth):
        """Return log-probabilities over the classes for a batch.

        Args:
            x_sem: Token ids passed to BERT.
            x_synth: Encoded trees, concatenated to the BERT vector on dim 1.
        """
        with torch.no_grad():
            first_output = self.bert(x_sem)[0]
            # Position 0 of every sequence - presumably the [CLS] token;
            # confirm against the tokenizer settings.
            sentence_vector = first_output[:, 0, :]
        joint = torch.cat((sentence_vector, x_synth), 1)
        scores = self.synth_sem_linear(joint)
        return F.log_softmax(scores, dim=1)
        

        
# Size of the BERT vector and of the encoded-tree vector fed to DTBert.
BERT_DIM = 768
TREE_DIM = 4000

# Number of output classes.
# NOTE(review): presumably chosen to match the label values found in the
# dataset - confirm against the data.
OUTPUT_DIM = 5

model = DTBert(BERT_DIM, TREE_DIM, OUTPUT_DIM)
# Use the same device rule as DTBert.__init__, so that constructing the model
# does not fail on a machine without a GPU.
model.to("cuda" if torch.cuda.is_available() else "cpu")

# Define a loss function and an optimizer.
criterion = nn.NLLLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-5)


## Load Training Set

In [None]:
datasets_train = [dataPath + name_dataset_train]

# One entry per dataset: the tuple returned by dataset_to_train, whose first
# element is the training iterator.
train_list = [
    dataset_to_train(dat, dataPath, nameTree1)
    for dat in datasets_train
]

## Training

In [None]:
# Number of training epochs; train() reads this module-level value.
EPOCH = 5

test_accuracies_NO_mem = []

for elem, dataset_name in zip(train_list, datasets_train):
    print(f"Training dataset: {dataset_name}")
    # elem[0] is the training iterator. The epoch count is taken from EPOCH
    # above, so it is not passed as an extra positional argument.
    train(elem[0], dataset_name)

## Load Test Set

In [None]:
datasets_test = [dataPath + name_dataset_test]

# One entry per dataset: the tuple returned by dataset_to_test, whose first
# element is the test iterator.
test_list = [
    dataset_to_test(dat, dataPath, nameTree2)
    for dat in datasets_test
]

## Test

In [None]:
# Accuracy of the trained model on each test dataset, in dataset order.
test_accuracy = []

for elem, dataset_name in zip(test_list, datasets_test):
    print(f"Testing dataset: {dataset_name}")
    # elem[0] is the test iterator of this dataset.
    accuracy = infer(elem[0], model, dataset_name, 1)
    test_accuracy.append(accuracy)
print("===================================")


In [None]:
# Accuracy measured on the first (and here only) test dataset.
print('accuracy: ',test_accuracy[0])

## Save Model

In [None]:
# Save the model. Passing `model` itself to torch.save stores the whole model
# object in 'model.pt', not only its state_dict.
torch.save(model, 'model.pt')

# To load the saved model back later:
#model = torch.load('./path/model.pt')