In [2]:
from transformers import *
import torch as t
import os
import warnings

warnings.filterwarnings("ignore")

In [3]:
def get_latest_model_folder(model_dir, file_start):
    model_files = [file for file in os.listdir(model_dir) if file_start in file]
    model_files.sort()
    return os.path.join(model_dir, model_files[-1])


In [4]:
model_dir = 'saved_runs'
file_start = 'mod_digit_add_'
config = Config()
latest_model_folder = get_latest_model_folder(model_dir, file_start)
latest_model_folder

'saved_runs/mod_digit_add_2024-09-29_19-04'

In [9]:
# total models will be num_models + 1
num_models = 10
num_epochs = 2000

first_model = Transformer(config)
first_model.load_state_dict(t.load(latest_model_folder + '/init.pth')['model'])
first_model.eval()


models = [Transformer(config) for _ in range(num_models-1)]
intervals = t.linspace(0, num_epochs, num_models + 1).int()[1:-1]
print(intervals)
weights = [t.load(latest_model_folder + f'/{num}.pth') for num in intervals] 
for i in range(len(weights)):
    models[i].load_state_dict(weights[i]['model'])
    models[i].eval()

last_model = Transformer(config)
last_model.load_state_dict(t.load(latest_model_folder + f'/final.pth')['model'])
last_model.eval()

tokenizer = Tokenizer(config)

tensor([ 200,  400,  600,  800, 1000, 1200, 1400, 1600, 1800],
       dtype=torch.int32)


In [7]:
def get_accuracy(model, maxnum = 100, p = 113):

    correct = 0
    count = 0
    for i in range(maxnum):
        for j in range(maxnum):
            count += 1
            print(f"Accuracy: {(correct/count):.2%}, Count: {count}", end = '\r')
            correct_answer = (i + j) % p
            lsi = [int(k) for k in str(i)]
            lsj = [int(k) for k in str(j)]
            question = lsi + [10] + lsj + [11]

            ll = len(question)
            #print("q:", question)
            pred = model.generate_greedy(question)
            answer = pred[ll: -1]
            try:
                answer = int(tokenizer.detokenize(answer))
            except:
                continue
            #print("p", pred)
            #print("a", answer)

            if answer == correct_answer:
                correct += 1
            
            
    print('\n')

for i, model in enumerate(models):
    if i == 0:
        print("Calculating accuracy for the initial model")
    elif i == len(models) - 1:
        print("Calculating accuracy for the final model")
    else:
        print(f"Calculating accuracy for model {i}")
    get_accuracy(model)

         

Calculating accuracy for model 0
Accuracy: 58.31%, Count: 10000

Calculating accuracy for model 1
Accuracy: 93.59%, Count: 10000

Calculating accuracy for model 2
Accuracy: 95.05%, Count: 10000

Calculating accuracy for model 3
Accuracy: 95.01%, Count: 10000

Calculating accuracy for model 4
Accuracy: 90.88%, Count: 10000

Calculating accuracy for model 5
Accuracy: 95.75%, Count: 10000

Calculating accuracy for model 6
Accuracy: 95.55%, Count: 10000

Calculating accuracy for model 7
Accuracy: 95.86%, Count: 10000

Calculating accuracy for model 8
Accuracy: 96.01%, Count: 10000

Calculating accuracy for model 9
Accuracy: 0.67%, Count: 10000



In [None]:
test_sentence = "1+1="
test_sentence_tokenized = tokenizer.tokenize(test_sentence)

pred = model.generate_greedy([1, 10, 2,1, 11])
print("Prediction:", pred)

Prediction: [1, 10, 2, 1, 11, 2, 2, 12]


In [None]:
test_sentence_tokenized

[1, 10, 1, 11]