In [1]:
from transformer_nb2 import *
from dataset import make_data_generator
import json
from tqdm import tqdm_notebook as tqdm

In [2]:
# Locations of the dataset splits and the vocabulary, all under one folder.
folder = 'data-giga/'
data_name = f"{folder}train_seq.json"
validation_name = f"{folder}valid_seq.json"
testdata_name = f"{folder}testdata_seq.json"
vocab_name = f"{folder}vocab.json"

In [3]:
# Training configuration.
import os

num_epochs = 10
save_rate = 1  # how many epochs per modelsave
# Checkpoint to resume from, e.g. "trained/Model1"; None trains from scratch.
continue_from = None
epsilon = 1e-8
validation_size = 10000
# Prefer the GPU, but still produce a usable device object on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Pure-Python replacement for `!mkdir -p trained`: works outside IPython and
# on platforms without a POSIX shell.
os.makedirs('trained', exist_ok=True)

In [4]:
# Load the token -> id vocabulary. The context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
with open(vocab_name, 'r', encoding='utf-8') as vocab_file:
    vocab = json.load(vocab_file)
VOC_SIZE = len(vocab)
# Length limits handed to make_data_generator for source and target sequences.
INPUT_MAX = 50
OUTPUT_MAX = 20
# Special tokens; each must be a key of `vocab`.
UNK = "[UNK]"
BOS = "[CLS]"
EOS = "[SEP]"
PAD = "[PAD]"

In [5]:
batch_size = 64

# Positional arguments shared by both splits.
_loader_args = (INPUT_MAX, OUTPUT_MAX, vocab[PAD], batch_size)

# Training split: whole file, shuffled.
training_set, training_generator = make_data_generator(
    data_name, *_loader_args, cutoff=None, shuffle=True, num_workers=4)

# Validation split: cut off at `validation_size`, kept in file order.
validation_set, validation_generator = make_data_generator(
    validation_name, *_loader_args, cutoff=validation_size, shuffle=False, num_workers=4)

def _device_batches(loader):
    """Yield one ``Batch`` per (src, tgt) pair of ``loader``, with both tensors moved to ``device``."""
    for source, target in loader:
        source = source.to(device)
        target = target.to(device)
        yield Batch(source, target, vocab[PAD])


def data_gen_train():
    """Generate ``Batch`` objects for one pass over ``training_generator``."""
    yield from _device_batches(training_generator)


def data_gen_val():
    """Generate ``Batch`` objects for one pass over ``validation_generator``."""
    yield from _device_batches(validation_generator)

loading json
load json done.


HBox(children=(IntProgress(value=0, max=3796354), HTML(value='')))


loading json
load json done.


HBox(children=(IntProgress(value=0, max=7603), HTML(value='')))




In [6]:
import math

# Number of batches per pass over each split: dataset size divided by the
# batch size, rounded up.
total_train, total_valid = (
    int(math.ceil(split.size / batch_size))
    for split in (training_set, validation_set)
)
print(total_train, total_valid)

59319 119


In [7]:
# Build the model with the same vocabulary size on the source and target side.
model = make_model(VOC_SIZE, VOC_SIZE, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1, emb_share=True)

if continue_from is not None:
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # written on a different device still loads.
    saved_model = torch.load(continue_from, map_location=device)
    model.load_state_dict(saved_model['model'])

# Use the configured `device` (the same one the batch generators move data to)
# rather than a hard-coded .cuda().
model.to(device)

# Label-smoothed loss; positions equal to the padding id are passed as padding_idx.
criterion = LabelSmoothing(size=VOC_SIZE, padding_idx=vocab[PAD], smoothing=0.1)
criterion.to(device)

model_opt = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.998), eps=1e-9)

# Bundles generator, criterion and optimiser into the callable used by the training loop.
loss_compute = SimpleLossCompute(model.generator, criterion, model_opt)



In [8]:
# Reverse lookup table: vocabulary id -> token string.
vocab_inv = {token_id: token for token, token_id in vocab.items()}


def convert_ids_to_tokens(ids):
    """Return the list of token strings for an iterable of vocabulary ids.

    Raises KeyError for an id that is not present in ``vocab_inv``.
    """
    return list(map(vocab_inv.__getitem__, ids))

In [None]:
# Train for `num_epochs` epochs, validating and checkpointing along the way.
# When resuming, numbering continues after the checkpoint's epoch
# (e.g. "trained/Model3" -> start at epoch 4).
start = 1 if continue_from is None else int(continue_from.split("Model")[-1]) + 1
for epoch in range(start, num_epochs + 1):
    print("Epoch", epoch)

    # training
    stats = Stats()
    model.train()
    for step, batch in enumerate(tqdm(data_gen_train(), total=total_train)):
        out = model.forward(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        stats.update(loss, batch.ntokens, log=1)

        if step % 1000 == 0:
            # Show the arg-max prediction for the first sequence of the batch.
            # This is display only, so no autograd graph is needed.
            with torch.no_grad():
                probs = model.generator(out)
            print("\n")
            print(convert_ids_to_tokens(torch.argmax(probs[0], dim=-1).tolist()))

    t_h = stats.history

    # validation
    stats = Stats()
    model.eval()
    # NOTE(review): `loss_compute` was constructed with `model_opt`. If
    # SimpleLossCompute steps the optimiser whenever one is attached, this loop
    # also updates the weights on validation data -- confirm against
    # transformer_nb2 and switch to an optimiser-free loss computation if so.
    for batch in data_gen_val():
        out = model.forward(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        stats.update(loss, batch.ntokens, log=1)
    v_h = stats.history

    train_loss = np.mean(t_h)
    valid_loss = np.mean(v_h)
    print("[info] epoch train loss:", train_loss)
    print("[info] epoch valid loss:", valid_loss)

    # Honour the configured checkpoint interval (save_rate == 1 saves every epoch).
    if epoch % save_rate == 0:
        checkpoint_path = "trained/Model" + str(epoch)
        try:
            torch.save({'model': model.state_dict(), 'training_history': t_h, 'validation_loss': valid_loss},
                       checkpoint_path)
        except Exception as err:
            # A failed save should not abort a long run, but it must be visible.
            # KeyboardInterrupt is deliberately not caught here.
            print("[warn] could not save checkpoint", checkpoint_path, "-", repr(err))

Epoch 1


HBox(children=(IntProgress(value=0, max=59319), HTML(value='')))

Step: 1 Loss: 9.419542 Tokens / Sec: 690.337406

['staples', 'staples', 'staples', 'staples', 'maglev', 'staples', 'bomb-making', 'russia-nato', 'bulent', 'staples', 'register', 'staples', 'bomb-making', 'bomb-making', 'fights', 'bomb-making', 'panicking', 'staples', 'staples']
Step: 1001 Loss: 5.612229 Tokens / Sec: 2662.140242

['[UNK]', "'s", 'in', 'in', 'in', '[UNK]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]']
Step: 2001 Loss: 5.389665 Tokens / Sec: 2721.051918

['[UNK]', 'to', 'win', '[UNK]', '[SEP]', '[UNK]', '[SEP]', '[SEP]', 'world', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]']
Step: 3001 Loss: 4.681033 Tokens / Sec: 2504.729047

['afghan', 'arabia', 'police', 'arrest', '##', '[SEP]', '[SEP]', 'afghanistan', 'of', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]']
Step: 4001 Loss: 4.713599 Tokens / Sec: 2550.088069

['swis

Step: 37001 Loss: 3.158911 Tokens / Sec: 2547.023294

['government', 'muslim', 'of', 'the', '[SEP]', '[SEP]', 'the', 'muslim', 'state', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]']
Step: 38001 Loss: 3.006186 Tokens / Sec: 2686.488248

['mexican', 'mexican', 'court', 'launches', 'probe', 'deadly', 'blaze', 'that', 'killing', '[SEP]', '[SEP]', 'blaze', 'deaths', 'blaze', 'school', 'blaze', '[SEP]', '[SEP]', 'killing']
Step: 39001 Loss: 2.768993 Tokens / Sec: 2572.578115

['abacha', 'dictator', 'abacha', 'at', 'rights', 'to', '[SEP]', '[SEP]', 'as', 'of', 'of', '[SEP]', 'of', 'of', 'of', 'of', 'abacha', 'of', 'of']
Step: 40001 Loss: 2.867320 Tokens / Sec: 2681.376820

['germany', 'posts', 'posts', 'first', 'net', '[SEP]', 'of', 'lower', 'percent', '[SEP]', 'lower', '[SEP]', 'lower', 'lower', 'lower', 'lower', '[SEP]', '[SEP]', 'lower']
Step: 41001 Loss: 2.878978 Tokens / Sec: 2587.217863

['manhattan', 'club', 'to', 'new', 'of', 'it', 'will', '

HBox(children=(IntProgress(value=0, max=59319), HTML(value='')))

Step: 1 Loss: 2.553679 Tokens / Sec: 1128.361700

['turkey', 'to', 'parliamentary', 'commission', 'to', 'of', 'former', '[SEP]', 'corruption', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]']
Step: 1001 Loss: 2.675547 Tokens / Sec: 2618.954418

['us', 'president', 'candidates', 'face', 'tough', 'us', '[SEP]', 'us', '[SEP]', 'musharraf', 'seek', 'on', 'support', 'musharraf', 's', 'suspension', '[SEP]', '[SEP]', '[SEP]']
Step: 2001 Loss: 2.509633 Tokens / Sec: 2552.862193

['six', 'portuguese', 'rugby', 'jailed', '##', 'in', 'jail', '[SEP]', 'world', 'world', 'cup', 'qualifying', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', 'qualifying', '[SEP]']
Step: 3001 Loss: 2.423761 Tokens / Sec: 2682.436662

['israeli', 'protest', 'peres', 'of', 'peres', "'s", 'peres', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]']
Step: 4001 Loss: 2.569580 Tokens / Sec: 2588.940649

['israeli', 'killed', 'militan

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Step: 11001 Loss: 2.700348 Tokens / Sec: 2610.895362

['cuban', 'cuban', 'political', 'dies', 'in', 'hunger', 'strike', '[SEP]', 'out', '[SEP]', 'out', '[SEP]', '[SEP]', 'ill', '[SEP]', 'ill', '[SEP]', '##', '[SEP]']
Step: 11900 Loss: 2.523179 Tokens / Sec: 2698.194795

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Step: 15001 Loss: 2.454744 Tokens / Sec: 2648.830564

['us', 'tv', 'news', 'legend', 'walter', 'walter', 'dies', 'at', '##', '##', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]']
Step: 15875 Loss: 2.280192 Tokens / Sec: 2703.806466

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Step: 19001 Loss: 2.466532 Tokens / Sec: 2525.681932

['denktash', 'says', 'cyprus', 'referendum', 'as', 'disgrace', '[SEP]', 'un', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]']
Step: 19823 Loss: 2.472369 Tokens / Sec: 2620.178596

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Step: 23001 Loss: 2.279946 Tokens / Sec: 2602.463702

['new', 'ferry', 'to', 'in', 'tonga', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]']
Step: 23754 Loss: 2.692589 Tokens / Sec: 2651.450500

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Step: 27001 Loss: 2.335746 Tokens / Sec: 2514.340667

['[UNK]', "'s", 'sports', 'count', 'to', 'determine', 'whether', 'winner', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]']
Step: 27832 Loss: 2.680645 Tokens / Sec: 2498.059685

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Step: 31001 Loss: 2.404490 Tokens / Sec: 2508.589424

['bangladesh', 'bangladesh', 'braces', 'for', 'massive', '[SEP]', 'after', 'storm', '[SEP]', '[SEP]', 'for', 'for', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', 'for']
Step: 31884 Loss: 2.346927 Tokens / Sec: 2571.955770

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Step: 35001 Loss: 2.463409 Tokens / Sec: 2631.754674

['u.s.', 'deputy', 'secretary', 'of', 'state', 'to', 'visit', 'turkey', '[SEP]', 'week', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '#', '[SEP]', '[SEP]', '[SEP]', '[SEP]']
Step: 35937 Loss: 2.273108 Tokens / Sec: 2702.563756

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Step: 39001 Loss: 2.661212 Tokens / Sec: 2628.379671

['julius', 'company', 'of', 'legal', 'to', 'wikileaks', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[UNK]']
Step: 39960 Loss: 2.282468 Tokens / Sec: 2495.741851

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Step: 41001 Loss: 2.425996 Tokens / Sec: 2538.074071

['police', 'arrests', 'arrest', '##', 'members', '[SEP]', 'suspected', 'for', 'for', '[SEP]', '[SEP]', '[SEP]', 'members', 'suspects', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]']
Step: 42001 Loss: 2.392684 Tokens / Sec: 2625.713968

['bangladesh', 'to', 'set', 'industrial', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', 'industries', '[SEP]']
Step: 43001 Loss: 2.283810 Tokens / Sec: 2588.340169

['al-zarqawi', 'u.s.', 'for', 'with', 'al-zarqawi', 'in', '[SEP]', '[SEP]', '[SEP]', 'with', 'with', 'with', '[SEP]', '[SEP]', "'s", '[SEP]', '[SEP]', 'with', 'with']
Step: 44001 Loss: 2.359656 Tokens / Sec: 2570.169101

['somali', 'pirates', 'hold', 'ukrainian', 'demand', '[SEP]', '##', 'ship', 'ship', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]', '[SEP]']
Step: 45001 Loss: 2.296386 Tokens / Sec: 2554.039582

['u.n.', 'to', 'to', 'end',

In [None]:
def greedy_decode_batch(model, src, src_mask, max_len, start_symbol):
    """Greedily decode a whole batch, appending one arg-max token per step.

    The source is encoded once. The decoder is then re-run on the growing
    prefix, and the highest-scoring token at the last position is appended.
    Decoding always runs ``max_len - 1`` steps; there is no early stop.

    Args:
        model: object exposing ``encode``, ``decode`` and ``generator``.
        src: source batch; its first dimension is the batch size, and the
            generated ids take its dtype and device.
        src_mask: mask forwarded to ``encode`` and ``decode``.
        max_len: length of the returned sequences, start symbol included.
        start_symbol: id placed at position 0 of every sequence.

    Returns:
        ``(ys, modelouts)``. ``ys`` holds the ids with shape
        (batch, max_len). ``modelouts`` is the decoder output of the final
        step, or ``None`` when no step was run.
    """
    n_sequences = src.shape[0]

    memory = model.encode(src, src_mask)
    ys = torch.full((n_sequences, 1), start_symbol, dtype=src.dtype, device=src.device)

    modelouts = None
    for _ in range(max_len - 1):
        tgt_mask = subsequent_mask(ys.size(1)).type_as(src_mask)
        modelouts = model.decode(memory, src_mask, ys, tgt_mask)
        # Only the last position decides the next token.
        step_scores = model.generator(modelouts[:, -1, :])
        ys = torch.cat((ys, step_scores.argmax(dim=-1, keepdim=True)), dim=1)
    return ys, modelouts

# NOTE(review): this repeats the `vocab_inv` / `convert_ids_to_tokens`
# definitions from an earlier cell; consider keeping a single copy.
# Reverse lookup table: vocabulary id -> token string.
vocab_inv = {token_id: token for token, token_id in vocab.items()}


def convert_ids_to_tokens(ids):
    """Return the list of token strings for an iterable of vocabulary ids.

    Raises KeyError for an id that is not present in ``vocab_inv``.
    """
    return list(map(vocab_inv.__getitem__, ids))

In [None]:
# Qualitative check: greedily decode one training batch and print the tokens.
model.eval()
for batch in data_gen_train():
    outs, modelouts = greedy_decode_batch(
        model, batch.src, batch.src_mask.byte(), max_len=OUTPUT_MAX, start_symbol=vocab[BOS])

    # NOTE(review): `loss_compute` is bound to `model_opt`. If it steps the
    # optimiser on every call, this inspection cell also changes the weights.
    # The previously commented-out `SimpleLossCompute2(model.generator, criterion)`
    # may be the intended optimiser-free variant -- confirm in transformer_nb2.
    loss = loss_compute(modelouts, batch.trg_y, batch.ntokens)

    # Per-token loss for this batch.
    print(loss / batch.ntokens)

    for out_tensor in outs:
        print(convert_ids_to_tokens(out_tensor.tolist()))

    break  # only the first batch is inspected

In [None]:
# Print the vocabulary ids of the padding and end-of-sequence tokens.
print(*(vocab[token] for token in (PAD, EOS)))