In [23]:
from SHModelUtils import SHModel
import torch
import torchvision
import math
import numpy as np
from sklearn.utils import shuffle
import ast

In [24]:
lang_name = "java" 
MODEL = SHModel(lang_name, "curr")

In [25]:
def data_preparation(filename):
    X = []
    T = []
    DELIMITER = "#"
    with open(filename, "r") as highlight_file:
        lines = highlight_file.readlines()
        for l in lines:
            tok_ids = l.split(DELIMITER)[0]
            h_code_values = l.split(DELIMITER)[1]

            # Converting string to list
            tok_ids = np.array(ast.literal_eval(tok_ids))
            h_code_values = np.array(ast.literal_eval(h_code_values))
        

            X.append(tok_ids)
            T.append(h_code_values)
    
    X = np.array(X, dtype=object)
    T = np.array(T, dtype=object)

    return X,T

In [26]:
def split_training_data(X,T,train_percentage=0.7):
    
    # split into 80/20 training/validation 
    N= X.shape[0] # get the number of records (rows)
    train_size = int(train_percentage * N) # use the first split_coeff of the data as the training data
    X_train = X[:train_size] # the first training_size record
    T_train = T[:train_size]
    X_val = X[train_size:]
    T_val = T[train_size:]
    
    assert X_train.shape[0] == T_train.shape[0]
    assert X_val.shape[0] == T_val.shape[0]
    
    
    return X_train, T_train, X_val, T_val

In [27]:
def shuffle_data(X, T):
    assert len(X) == len(T)
    shuffle(X, T, random_state=0)
    return X, T

In [28]:
def train(X, T, epochs=10):
    MODEL.setup_for_finetuning()
    losses = np.array([])
    for epoch in range(epochs):
        print(f'Loading {epoch+1}0%')
        epoch_losses = np.array([])
        for idx,x in enumerate(X):
            epoch_loss= MODEL.finetune_on(x,T[idx])
            epoch_losses = np.append(epoch_losses,epoch_loss)  
        avg_epoch_loss = np.mean(epoch_losses)
        losses = np.append(losses,avg_epoch_loss)
        print(f'Average Loss {avg_epoch_loss} in epoch {epoch+1}')
    return losses


In [29]:
def accuracy(X,T):
    MODEL.setup_for_prediction()
    correct = 0
    total = 0
    for idx, x in enumerate(X):
        h_codes=MODEL.predict(x) # [2,3,4,6]
        for j, h_code in enumerate(h_codes):
            if h_code == T[idx][j]:
                correct +=1
            total +=1
    
    return correct/total

In [30]:
def model_switch(X,T):
    # split data in train & val
    X_train, T_train, X_val, T_val = split_training_data(X,T)
    # do a second split on the validation data to test accuracy of current model
    X_val_before, T_val_before, X_val_after, T_val_after = split_training_data(X_val,T_val, 0.5)
    # test current accuracy
    curr_acc = accuracy(X_val_before,T_val_before)
    # shuffle train data
    X_train, T_train = shuffle_data(X_train, T_train)
    # train current model
    train(X_train, T_train)
    # test the accuracy of new model
    new_acc = accuracy(X_val_after, T_val_after)
    print(new_acc)
    
    if new_acc > curr_acc:
        return True    

In [31]:
X,T = data_preparation("highlight-java.txt")
print(model_switch(X,T))

Loading 10%
Average Loss 1.6616669282032426 in epoch 1
Loading 20%
Average Loss 0.8611268176204839 in epoch 2
Loading 30%
Average Loss 0.6020236037149551 in epoch 3
Loading 40%
Average Loss 0.45838947005712305 in epoch 4
Loading 50%
Average Loss 0.36284984738393955 in epoch 5
Loading 60%
Average Loss 0.29171075353956527 in epoch 6
Loading 70%
Average Loss 0.23500488416119747 in epoch 7
Loading 80%
Average Loss 0.19220030229467494 in epoch 8
Loading 90%
Average Loss 0.16035605796203492 in epoch 9
Loading 100%
Average Loss 0.13596369031887906 in epoch 10
0.9817500158861282
True
