In [1]:
import numpy as np
import torch
from dataclasses import dataclass
from revisedkey.model import Reviser
from revisedkey.utils import *
from tqdm import tqdm

In [2]:
@dataclass
class Config:
    """Settings handed to ``Reviser`` and used to locate the datastores.

    Every attribute carries a type annotation so that ``@dataclass`` actually
    registers it as a field: un-annotated class attributes are ignored by the
    decorator, which leaves ``__init__`` without parameters and ``__repr__`` /
    ``__eq__`` blind to the values. ``Config()`` still yields the same defaults.
    """

    # Width of one datastore key row (used as the second memmap dimension).
    dimension: int = 1024
    # Dropout probability; presumably consumed by Reviser -- TODO confirm.
    dropout: float = 0.2
    # Feed-forward width; presumably the hidden size of Reviser's key_map -- TODO confirm.
    ffn_size: int = 8192
    # Dataset name; used as the lookup key into DATASTORE_SIZE.
    dataset: str = 'koran'
    # Checkpoint paths of the source and target translation models.
    source_model: str = '../knnmt-fairseq/models/wmt19.de-en/wmt19.de-en.ffn8192.pt'
    target_model: str = '../knnmt-fairseq/models/koran_finetune/checkpoint_best.pt'
    # Directories holding the keys.npy / vals.npy memmaps of each datastore.
    source_dstore_mmap: str = '../knnmt-fairseq/datastores/koran_base/'
    target_dstore_mmap: str = '../knnmt-fairseq/datastores/koran_finetune/'

In [3]:
# Instantiate the configuration with its default values; `args` is what the
# later cells read (args.dataset, args.dimension, args.*_dstore_mmap).
args = Config()

In [4]:
# Build the Reviser from the config and restore its trained weights.
reviser = Reviser(args)
# Load onto CPU first so the checkpoint opens regardless of the device it was
# saved from; the module is moved to the GPU explicitly below.
ckp = torch.load('revisedkey-datastores/koran_revised/reviser_checkpoint.pt', map_location='cpu')
reviser.load_state_dict(ckp)
# The model is only used for inference in this notebook. eval() switches the
# Dropout layer off -- torch.no_grad() alone does not -- so the revised keys
# are deterministic.
reviser = reviser.cuda().eval()

In [5]:
# Show the module repr to sanity-check the architecture that was loaded.
reviser

Reviser(
  (key_map): Sequential(
    (0): Linear(in_features=4096, out_features=8192, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=8192, out_features=1024, bias=True)
  )
  (source_embed): Embedding(42024, 1024)
  (target_embed): Embedding(42024, 1024)
)

In [6]:
# Allocate the on-disk arrays that will receive the revised datastore.
# mode='w+' creates the files (or overwrites existing ones) at the given shape.
revised_dstore_mmap = 'revisedkey-datastores/koran_revised'
dstore_size = DATASTORE_SIZE[args.dataset]

# One float16 key row of width args.dimension per datastore entry.
dstore_keys = np.memmap(
    f'{revised_dstore_mmap}/keys.npy',
    dtype=np.float16,
    mode='w+',
    shape=(dstore_size, args.dimension),
)
# One int64 value per datastore entry, stored as a column vector.
dstore_vals = np.memmap(
    f'{revised_dstore_mmap}/vals.npy',
    dtype=np.int64,
    mode='w+',
    shape=(dstore_size, 1),
)

In [7]:
# Open the existing datastores read-only. Both key arrays share the same
# (dstore_size, args.dimension) float16 layout.
keys_shape = (dstore_size, args.dimension)
source_keys = np.memmap(
    f'{args.source_dstore_mmap}/keys.npy',
    dtype=np.float16,
    mode='r',
    shape=keys_shape,
)
target_keys = np.memmap(
    f'{args.target_dstore_mmap}/keys.npy',
    dtype=np.float16,
    mode='r',
    shape=keys_shape,
)
# The values are read from the source datastore only.
token = np.memmap(
    f'{args.source_dstore_mmap}/vals.npy',
    dtype=np.int64,
    mode='r',
    shape=(dstore_size, 1),
)

In [8]:
# Number of datastore rows pushed through the reviser per forward pass.
batch_size = 10000

# Inference only: make sure Dropout is disabled (no_grad() does not do this).
# Calling eval() is idempotent, so it is harmless if the model is already in
# eval mode.
reviser.eval()

with torch.no_grad():
    for idx in tqdm(range(0, len(token), batch_size)):
        # Slicing clamps at the end of the arrays, so the last batch may be short.
        end = idx + batch_size

        # np.array(...) copies each read-only memmap slice into ordinary
        # writable memory; torch.from_numpy on a non-writable array emits a
        # UserWarning and writing through such a tensor is undefined.
        part_source_keys = torch.from_numpy(np.array(source_keys[idx:end]))
        part_target_keys = torch.from_numpy(np.array(target_keys[idx:end]))
        part_token = np.array(token[idx:end])

        # Keys are stored as float16 but fed to the model as float32; the
        # token column vector is squeezed to 1-D. The second return value of
        # key_forward is not used here.
        revised_keys, _ = reviser.key_forward(
            source_hidden=part_source_keys.cuda().float(),
            target_hidden=part_target_keys.cuda().float(),
            token=torch.from_numpy(part_token).cuda().squeeze(1))

        # Store the revised keys back as float16 and copy the values unchanged.
        dstore_keys[idx:end] = revised_keys.type(torch.float16).cpu().numpy()
        dstore_vals[idx:end] = part_token

# Push the written pages to disk explicitly instead of relying on the memmap
# objects being garbage-collected.
dstore_keys.flush()
dstore_vals.flush()

  This is separate from the ipykernel package so we can avoid doing imports until
100%|██████████| 53/53 [00:06<00:00,  7.71it/s]
