In [11]:
import torch
import numpy as np
import sys

from src.tasks import *
from src.utils import *


In [12]:
# Use the GPU when one is available and fall back to the CPU otherwise,
# instead of hard-coding 'cuda' (which fails on CPU-only machines).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
curriculum_type = 'cumulative'
task = 'parity'
network_num = 1
N_max = 7

# Load a trained network through the project's `load_model` helper
# (presumably selected by the settings above -- see src.utils) from
# ../trained_models, and make sure it lives on `device`.
rnn = load_model(
    curriculum_type=curriculum_type,
    task=task,
    network_number=network_num,
    N_max=N_max,
    base_path='../trained_models',
    device=device,
).to(device)

In [13]:
# Evaluation settings read by `test` below.
BATCH_SIZE = 64   # batch size passed as the second argument of task_function
TEST_STEPS = 600  # number of batches drawn and evaluated per call to `test`
# Batch generator called as task_function(Ns, BATCH_SIZE) -> (sequences, labels).
# Alternative generator for a different task: make_batch_multihead_dms
task_function = make_batch_Nbit_pair_parity
def test(model, Ns):

    correct_N = np.zeros_like(Ns)
    total = 0
    for j in range(TEST_STEPS):
        with torch.no_grad():
            sequences, labels = task_function(Ns, BATCH_SIZE)
            sequences = sequences.to(device)
            labels = [l.to(device) for l in labels]

            out, out_class = model(sequences)

            for N_i in range(len(Ns)):
                predicted = torch.max(out_class[N_i], 1)[1]

                correct_N[N_i] += (predicted == labels[N_i]).sum()
                total += labels[N_i].size(0)

    accuracy = 100 * correct_N / float(total) * len(Ns)

    print('{:.4f}  %'.format(np.mean(accuracy)), flush=True)
    print('({N}, accuracy):\n' + ''.join([f'({Ns[i]}, {accuracy[i]:.4f})\n' for i in range(len(Ns))]), flush=True)

In [14]:
# Evaluate on every N from 2 through 99 (np.arange excludes its stop value).
test(rnn, Ns=[n for n in np.arange(2, 100)])

53.0201  %
({N}, accuracy):
(2, 100.0000)
(3, 100.0000)
(4, 100.0000)
(5, 100.0000)
(6, 100.0000)
(7, 100.0000)
(8, 47.1953)
(9, 49.8125)
(10, 54.2578)
(11, 51.1927)
(12, 45.8073)
(13, 51.1042)
(14, 45.7995)
(15, 50.7943)
(16, 50.7760)
(17, 47.1094)
(18, 50.1380)
(19, 47.3802)
(20, 50.0026)
(21, 49.9245)
(22, 51.8437)
(23, 50.6094)
(24, 51.8750)
(25, 51.8516)
(26, 52.1302)
(27, 52.0781)
(28, 49.6354)
(29, 48.0677)
(30, 50.1172)
(31, 49.7109)
(32, 48.3177)
(33, 50.1875)
(34, 49.9531)
(35, 48.4583)
(36, 48.6068)
(37, 50.3385)
(38, 50.0313)
(39, 50.1615)
(40, 51.2292)
(41, 48.9557)
(42, 50.1667)
(43, 49.3073)
(44, 50.2552)
(45, 49.9740)
(46, 48.8568)
(47, 49.0078)
(48, 51.5104)
(49, 49.5573)
(50, 48.5573)
(51, 50.2187)
(52, 51.3958)
(53, 51.5234)
(54, 51.3307)
(55, 49.7526)
(56, 51.1953)
(57, 49.8906)
(58, 49.8880)
(59, 48.7214)
(60, 50.2109)
(61, 50.3490)
(62, 50.8307)
(63, 49.8047)
(64, 50.3021)
(65, 49.0547)
(66, 50.0990)
(67, 49.2604)
(68, 50.8880)
(69, 50.3984)
(70, 50.0417)
(71, 49.