In [1]:
from model import Transformer
from config import Config
from tokenizer import Tokenizer
import torch as t
import os
import ast
import pandas as pd
import warnings

# NOTE(review): this silences every warning for the rest of the session
# (including any raised by t.load or pandas in later cells); consider
# narrowing it with `category=`/`message=` so real problems stay visible.
warnings.filterwarnings("ignore")

Using device: cuda
fixed_digit: True
Training for 3000 epochs
Batch size: 256
Learning rate: 0.001
Train on 0.5 of the data
Saving model every 5 epochs


In [2]:
# Load the trained model and its tokenizer.
# NOTE(review): Config() uses its defaults; the import-cell output reports
# "fixed_digit: True" while this checkpoint comes from a variable-digit run —
# confirm the defaults match the settings the checkpoint was trained with.
path = "saved_runs/variable_digit_add_50/final.pth"
config = Config()
model = Transformer(config)
# The checkpoint is a dict; only its "model" entry is restored here
model.load_state_dict(t.load(path, map_location = t.device("cpu"))["model"])
# Inference only from here on: switch to eval mode so train-time-only layers
# (e.g. dropout, if the model has any) do not make greedy generation stochastic
model.eval()
tokenizer = Tokenizer(config)

In [3]:
def get_accuracy(model, maxnum = 100, tok = None):
    """
    Exact-match accuracy of `model` on every sum i + j with 0 <= i, j < maxnum.

    Each question is tokenized as "<i>+<j>=" and answered with
    `model.generate_greedy`; the answer is made of the tokens generated after
    the prompt, up to (not including) the first EOS token.

    model:  object with a `generate_greedy(token_ids)` method that returns a
            list of token ids starting with the prompt
    maxnum: exclusive upper bound for both operands (maxnum**2 questions)
    tok:    tokenizer with `tokenize`/`detokenize`; defaults to the notebook's
            global `tokenizer`
    return: accuracy in [0, 1] as a float (0.0 when maxnum <= 0)
    """
    if tok is None:
        tok = tokenizer
    eos_tokenid = tok.tokenize("EOS")[0]

    correct = 0
    count = 0
    unparsable = 0
    for i in range(maxnum):
        for j in range(maxnum):
            question = tok.tokenize(f"{i}+{j}=")
            pred = list(model.generate_greedy(question))
            generated = pred[len(question):]
            # Cut at the first EOS; when generation ended without EOS keep all
            # generated tokens instead of silently dropping the last one
            if eos_tokenid in generated:
                generated = generated[:generated.index(eos_tokenid)]
            try:
                answer = int(tok.detokenize(generated))
            except Exception:  # best effort: unparsable output counts as wrong
                answer = None
                unparsable += 1

            count += 1
            if answer == i + j:
                correct += 1
            # Reported after scoring so the last line covers all `count` questions
            print(f"Accuracy: {(correct/count):.2%}, Count: {count}", end = '\r')
    print()
    if unparsable:
        print(f"Could not convert {unparsable} answer(s) to integer")
    return correct / count if count else 0.0

In [4]:
# Exact-match accuracy over every i + j with 0 <= i, j < 100 (10,000 questions,
# one greedy generation each)
print("Calculating accuracy for the final model")
get_accuracy(model)

Calculating accuracy for the final model
Accuracy: 99.62%, Count: 10000



In [5]:
def extract_answer_from_prediction(pred, tok = None):
    """
    Takes the prediction and extracts the answer from it
    pred:   list of token ids like [0, 10, 0, 11, 0, 12, 13]
    tok:    tokenizer with `tokenize`/`detokenize`; defaults to the notebook's
            global `tokenizer`
    return: the answer as an integer (0 for the example above), or None when
            the prediction has no "=" token or the tokens between "=" and EOS
            do not form an integer
    """
    if tok is None:
        tok = tokenizer
    pred = list(pred)
    equal_tokenid = tok.tokenize("=")[0]
    eos_tokenid = tok.tokenize("EOS")[0]

    if equal_tokenid not in pred:
        print("Could not find '=' in prediction")
        return None
    answer_start_idx = pred.index(equal_tokenid) + 1
    # Only an EOS after "=" ends the answer; if generation stopped without
    # emitting EOS, the answer runs to the end of the sequence
    try:
        answer_end_idx = pred.index(eos_tokenid, answer_start_idx)
    except ValueError:
        answer_end_idx = len(pred)
    answer_tokens = pred[answer_start_idx:answer_end_idx]

    try:
        return int(tok.detokenize(answer_tokens))
    except Exception:  # best effort: any unparsable answer is reported as None
        print("Could not convert answer to integer")
        return None


In [36]:
def select_results_n_digits(data, n):
    """
    Returns the rows of `data` whose "result" value, written out as a string,
    is at least `n` characters long (for non-negative integers: n or more digits).
    """
    def _has_enough_digits(value):
        return len(str(value)) >= n

    keep = data["result"].map(_has_enough_digits)
    return data[keep]

def get_counts(data):
    """
    Counts the examples each metric is averaged over (the denominators)

    data:   DataFrame with a "result" column and a boolean "is_train" column
    return: dict with
        "train_total" / "test_total": number of train / test rows
        "train_digit_i" / "test_digit_i": number of train / test rows whose
            result has at least i+1 digits, i.e. has a digit at position i
            (counted from the right, 0 = ones digit)
    """
    counts = {}
    train = data[data["is_train"]]
    test = data[~data["is_train"]]

    print("len train", len(train))
    print("len test", len(test))

    # Counts for overall accuracy
    counts["train_total"] = len(train)
    counts["test_total"] = len(test)

    # Counts for individual digits. The longest result is taken over BOTH
    # splits (not only train) so digit positions that occur only in the test
    # split still get a denominator; default=0 keeps empty data from producing
    # len(str(nan)) == 3 phantom digits.
    max_len = max((len(str(x)) for x in data["result"]), default=0)

    for i in range(max_len):
        counts[f"train_digit_{i}"] = len(select_results_n_digits(train, i + 1))
        counts[f"test_digit_{i}"] = len(select_results_n_digits(test, i + 1))
    return counts

def get_frequencies(data):
    """
    Counts the correct predictions (the numerators) for the metrics of get_counts

    Every row is answered by the notebook-level `model` (greedy generation on
    "<operand_1>+<operand_2>=") and parsed with extract_answer_from_prediction.

    data:   DataFrame with "operand_1", "operand_2", "result" and boolean
            "is_train" columns
    return: dict with
        "train_total" / "test_total": number of exactly correct answers
        "train_digit_i" / "test_digit_i": number of rows whose result has a
            digit at position i (counted from the right, 0 = ones digit) and
            whose prediction has the same digit at that position
    """
    frequencies = {
        "train_total": 0,
        "test_total": 0
    }

    def check_individual_digits(pred, ground_truth, istrain):
        """
        Updates the per-digit counters for one example.

        Only positions that exist in the ground truth are scored, so every row
        counted for digit i is also in the get_counts denominator (rows with at
        least i+1 result digits). A shorter prediction is left-padded with
        zeros; a prediction that is not an int scores no digit.
        """
        split = "train" if istrain else "test"
        ground_truth_str = str(ground_truth)
        n_digits = len(ground_truth_str)
        pred_str = str(pred).zfill(n_digits) if isinstance(pred, int) else ""

        for i in range(n_digits):
            key = f"{split}_digit_{i}"
            # Create the entry even if this digit is wrong so it reports 0
            frequencies.setdefault(key, 0)
            if i < len(pred_str) and pred_str[-(i + 1)] == ground_truth_str[-(i + 1)]:
                frequencies[key] += 1

    # Calculate the accuracy (digits and overall) for train and test
    for _, row in data.iterrows():
        is_train = bool(row["is_train"])
        ground_truth = int(row["result"])
        # int() guards against pandas upcasting the operands to float in `row`
        prompt = tokenizer.tokenize(f"{int(row['operand_1'])}+{int(row['operand_2'])}=")
        pred = model.generate_greedy(prompt)
        answer = extract_answer_from_prediction(pred)

        # Check prediction and ground truth overall
        if answer == ground_truth:
            frequencies["train_total" if is_train else "test_total"] += 1

        # Check prediction and ground truth for individual digits
        check_individual_digits(answer, ground_truth, is_train)

    return frequencies

def take_metrics(data_path = "saved_runs/variable_digit_add_50/data.csv"):
    """
    Computes overall and per-digit accuracy on the train and test split
    stored in `data_path`, printing counts, frequencies and metrics.

    We use the DataFrame instead of the dataset class because the
    generate_greedy function is not batched.

    data_path: CSV with "operand_1", "operand_2", "result" and "is_train"
               columns (defaults to the run analysed in this notebook)
    return:    dict mapping "<metric>_accuracy" to correct / total; NaN for a
               metric with no examples
    """
    data = pd.read_csv(data_path)

    # Get counts
    counts = get_counts(data)
    print("counts", counts)

    # Get the frequencies
    frequencies = get_frequencies(data)
    print("frequencies", frequencies)

    # Each metric is defined by its denominator; a missing numerator means
    # nothing was predicted correctly, and an empty split yields NaN instead
    # of a ZeroDivisionError
    metrics = {}
    for key, count in counts.items():
        correct = frequencies.get(key, 0)
        metrics[key + "_accuracy"] = correct / count if count else float("nan")

    skipped = sorted(set(frequencies) - set(counts))
    if skipped:
        print("warning: no counts for", skipped, "- not included in metrics")

    print("metrics", metrics)
    return metrics

In [37]:
# Overall and per-digit accuracy on the train/test split saved with this run
take_metrics()

len train 5030
len test 4970
counts {'train_total': 5030, 'test_total': 4970, 'train_digit_0': 5030, 'test_digit_0': 4970, 'train_digit_1': 4987, 'test_digit_1': 4958, 'train_digit_2': 2460, 'test_digit_2': 2490}
frequencies {'train_total': 5030, 'test_total': 4933, 'train_digit_0': 5030, 'test_digit_0': 4945, 'train_digit_1': 4987, 'test_digit_1': 4943, 'test_digit_2': 2489, 'train_digit_2': 2460}
metrics {'train_total_accuracy': 1.0, 'test_total_accuracy': 0.9925553319919517, 'train_digit_0_accuracy': 1.0, 'test_digit_0_accuracy': 0.9949698189134809, 'train_digit_1_accuracy': 1.0, 'test_digit_1_accuracy': 0.9969745865268254, 'test_digit_2_accuracy': 0.9995983935742971, 'train_digit_2_accuracy': 1.0}


In [129]:
# Spot-check one hand-written prompt end to end: tokenize -> generate -> parse
test_sentence = "56+52="
test_sentence_tokenized = tokenizer.tokenize(test_sentence)

pred = model.generate_greedy(test_sentence_tokenized)
answer = extract_answer_from_prediction(pred)
print("Test sentence:", test_sentence)
print("Answer:", answer)

Test sentence: 56+52=
Answer: 108


In [27]:
# Inspect the saved dataset; the first (unnamed) CSV column becomes the index.
# NOTE(review): "tokenized" presumably comes back as a string such as
# "[0, 10, ...]", not a list — parse with ast.literal_eval if a list is needed.
df = pd.read_csv("saved_runs/variable_digit_add_50/data.csv", index_col = 0)
df.head(500)

Unnamed: 0,operand_1,operand_2,result,is_train,input_str,tokenized
0,0,0,0,True,0+0=0EOSPADPADPADPADPADPADPAD,"[0, 10, 0, 11, 0, 12, 13, 13, 13, 13, 13, 13, 13]"
1,0,1,1,True,0+1=1EOSPADPADPADPADPADPADPAD,"[0, 10, 1, 11, 1, 12, 13, 13, 13, 13, 13, 13, 13]"
2,0,2,2,True,0+2=2EOSPADPADPADPADPADPADPAD,"[0, 10, 2, 11, 2, 12, 13, 13, 13, 13, 13, 13, 13]"
3,0,3,3,True,0+3=3EOSPADPADPADPADPADPADPAD,"[0, 10, 3, 11, 3, 12, 13, 13, 13, 13, 13, 13, 13]"
4,0,4,4,True,0+4=4EOSPADPADPADPADPADPADPAD,"[0, 10, 4, 11, 4, 12, 13, 13, 13, 13, 13, 13, 13]"
...,...,...,...,...,...,...
395,4,45,49,False,4+45=49EOSPADPADPADPADPAD,"[4, 10, 4, 5, 11, 4, 9, 12, 13, 13, 13, 13, 13]"
396,4,46,50,False,4+46=50EOSPADPADPADPADPAD,"[4, 10, 4, 6, 11, 5, 0, 12, 13, 13, 13, 13, 13]"
397,4,47,51,False,4+47=51EOSPADPADPADPADPAD,"[4, 10, 4, 7, 11, 5, 1, 12, 13, 13, 13, 13, 13]"
398,4,48,52,True,4+48=52EOSPADPADPADPADPAD,"[4, 10, 4, 8, 11, 5, 2, 12, 13, 13, 13, 13, 13]"
