In [1]:
import torch
from model import Encoder, cp_2_key_model, cp_2_k_mask
from config import args
from dataset import Enigma_simulate_c_2_p, Enigma_simulate_cp_2_k_limited, Enigma_simulate_cp_2_k
from torch.utils.data import DataLoader
import math
from torchsummary import summary
from tqdm import tqdm

In [2]:
ckpt = torch.load('CP2K_RNN_ENC_ckpt.pt')
ckpt_args = ckpt['args']


model = cp_2_k_mask(args=ckpt_args, out_channels=26)
model.to('cuda')
model.eval()

weights = []
for k, v in ckpt['weights'].items():
    weights.append(v)

for idx, (k, v) in enumerate(model.state_dict().items()):
    print(k, v)

    # This is the key point of copying weights
    v *= 0
    v += weights[idx].detach()

rnn.emb.weight tensor([[ 0.0137,  0.1316,  0.0908,  ..., -0.0005,  0.0731,  0.0238],
        [ 0.0740, -0.0740, -0.0641,  ...,  0.0004,  0.0323,  0.1257],
        [-0.0654, -0.0657,  0.1344,  ..., -0.1063, -0.0346, -0.0298],
        ...,
        [ 0.1354,  0.0209,  0.0397,  ...,  0.0443,  0.1077, -0.0109],
        [ 0.1175,  0.1141, -0.0930,  ..., -0.0437,  0.0232, -0.0539],
        [ 0.0520, -0.0026,  0.1260,  ..., -0.0600, -0.1334,  0.0596]],
       device='cuda:0')
rnn.emb.bias tensor([ 0.1042, -0.0618, -0.0586,  0.1293,  0.1146,  0.1276, -0.1349,  0.0998,
         0.1135,  0.1238, -0.0270,  0.1125,  0.0982, -0.0125, -0.1270,  0.0876,
        -0.1136,  0.0826,  0.0059,  0.0692,  0.0843,  0.0467,  0.0558,  0.0038,
         0.0314, -0.0742, -0.0937, -0.1240,  0.0543,  0.0583,  0.1134,  0.0714,
        -0.1327,  0.1037, -0.0819, -0.0805, -0.0613,  0.0387, -0.0090, -0.0288,
        -0.1205, -0.0552,  0.1230,  0.0586,  0.0486, -0.0307, -0.1191,  0.0216,
         0.0047,  0.0992, -0.0792,

In [3]:
for k, v in model.state_dict().items():
    print(k, v)

rnn.emb.weight tensor([[ 3.4137e-01,  2.8597e-01, -2.2194e-02,  ..., -1.7932e-02,
         -7.0711e-05,  9.0005e-02],
        [-9.9854e-04, -1.7646e-03, -2.0310e-01,  ..., -1.8579e-01,
         -4.8737e-01, -5.9641e-02],
        [ 2.5541e-02,  6.7992e-02,  9.7029e-02,  ...,  1.9463e-02,
         -9.0371e-02,  3.4060e-02],
        ...,
        [-1.0861e-01, -7.1558e-02,  1.8475e-02,  ..., -6.6361e-02,
          6.5514e-03, -9.5478e-02],
        [ 4.6954e-02, -1.5398e-01,  1.1696e-01,  ..., -1.3148e-02,
          1.7559e-01,  2.3518e-02],
        [ 4.1907e-01,  1.7145e-01,  1.9200e-01,  ...,  5.8804e-03,
          3.7409e-01,  7.7623e-03]], device='cuda:0')
rnn.emb.bias tensor([-6.9229e-02,  8.7077e-03, -7.6116e-02,  4.3139e-02,  6.8288e-02,
         1.1775e-01, -7.5739e-02, -1.1108e-01, -9.9700e-03, -8.5666e-02,
         1.8228e-01,  1.2770e-02,  6.1759e-02, -1.3751e-02,  7.9525e-02,
         1.8439e-01,  1.3530e-02, -1.8609e-02,  1.0227e-01,  1.5917e-01,
         7.9047e-02, -7.5598e-0

In [4]:
dataset = Enigma_simulate_cp_2_k_limited(args=args)
dataloader = DataLoader(
        dataset=dataset,
        batch_size=1800,
        collate_fn=dataset.collate_fn_padding,
        shuffle=True
    )

for inputs, targets, masks in dataloader:
    inputs, targets, masks = inputs.to('cuda'), targets.to('cuda'), masks.to('cuda')


    outputs = model(inputs, masks)

    print(inputs.shape, targets.shape, masks.shape, outputs.shape)
    # print(masks, targets[1][~masks.T].shape, outputs[1][~masks.T].shape)
    break

true_positive = 0
samples = 0

outputs_indices = torch.argmax(outputs, dim=-1) # -> [rotor, seq, batch]
for rotor in range(outputs_indices.shape[0]):
    mask = outputs_indices[rotor][~masks.T] == targets[rotor][~masks.T]
    true_positive += mask.sum()
    samples += math.prod(mask.shape)

print(f"Acc: {true_positive / samples}")

torch.Size([45, 1800, 52]) torch.Size([3, 45, 1800]) torch.Size([1800, 45]) torch.Size([3, 45, 1800, 26])
Acc: 0.9521042108535767


In [5]:
# Testint accuracy in different length
testing_args = args
testing_args['SAMPLES_PER_KEYS'] = 1

for length in range(40, 41):
    testing_args['SEQ_LENGTH'] = [length, length]
    dataset = Enigma_simulate_cp_2_k_limited(args=testing_args)
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=1800,
        collate_fn=dataset.collate_fn_padding,
        shuffle=False,
        drop_last=False
    )

    # Tracking
    true_positive = 0
    samples = 0

    bar = tqdm(dataloader, leave=True)
    bar.set_description_str(f"Length: {length}")

    for inputs, targets, masks in bar:
        inputs, targets, masks = inputs.to('cuda'), targets.to('cuda'), masks.to('cuda')

        # Making prediction
        with torch.cuda.amp.autocast():
            outputs = model(inputs, masks)

        # Compute accuracy
        outputs_indices = torch.argmax(outputs, dim=-1) # -> [rotor, seq, batch]
        for rotor in range(outputs_indices.shape[0]):
            mask = outputs_indices[rotor][~masks.T] == targets[rotor][~masks.T]
            true_positive += mask.sum()
            samples += math.prod(mask.shape)

    # Output the result
    print(f"Acc: {true_positive / samples}")

Length: 40: 100%|██████████| 10/10 [00:05<00:00,  1.72it/s]

Acc: 0.9740701913833618



