In [1]:
!ls wt103/bwd_wt103.h5

wt103/bwd_wt103.h5


In [2]:
PRE_PATH = "wt103"

In [3]:
import torch
from pprint import pprint
weights = torch.load(PRE_PATH+'/fwd_wt103.h5',map_location=lambda storage, loc: storage)
pprint([(weight_key, weights[weight_key].size()) for weight_key in weights.keys()])

[('0.encoder.weight', torch.Size([238462, 400])),
 ('0.encoder_with_dropout.embed.weight', torch.Size([238462, 400])),
 ('0.rnns.0.module.weight_ih_l0', torch.Size([4600, 400])),
 ('0.rnns.0.module.bias_ih_l0', torch.Size([4600])),
 ('0.rnns.0.module.bias_hh_l0', torch.Size([4600])),
 ('0.rnns.0.module.weight_hh_l0_raw', torch.Size([4600, 1150])),
 ('0.rnns.1.module.weight_ih_l0', torch.Size([4600, 1150])),
 ('0.rnns.1.module.bias_ih_l0', torch.Size([4600])),
 ('0.rnns.1.module.bias_hh_l0', torch.Size([4600])),
 ('0.rnns.1.module.weight_hh_l0_raw', torch.Size([4600, 1150])),
 ('0.rnns.2.module.weight_ih_l0', torch.Size([1600, 1150])),
 ('0.rnns.2.module.bias_ih_l0', torch.Size([1600])),
 ('0.rnns.2.module.bias_hh_l0', torch.Size([1600])),
 ('0.rnns.2.module.weight_hh_l0_raw', torch.Size([1600, 400])),
 ('1.decoder.weight', torch.Size([238462, 400]))]


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [5]:
recast_weights = {}
for i in range(3):
    for op in ['weight','bias']:
        for segment in ['ih','hh']:
            source_key = "0.rnns.{0}.module.{1}_{2}_l0{3}"\
            .format(i,op,segment,"_raw" if op=='weight' and segment == 'hh' else "")
            target_key = "{0}.{1}_{2}_l0".format(i,op,segment)
            recast_weights[target_key] = weights[source_key]
recast_weights

{'0.weight_ih_l0': tensor([[-0.0812, -0.0811, -0.0937,  ..., -0.0259, -0.1403, -0.3247],
         [ 0.1154,  0.1142,  0.0938,  ..., -0.0711,  0.1669, -0.0387],
         [-0.0051,  0.1007,  0.2071,  ..., -0.0860, -0.0288, -0.0894],
         ...,
         [ 0.0055,  0.0157,  0.2990,  ...,  0.0616,  0.1159, -0.4737],
         [ 0.0181,  0.0426,  0.1130,  ...,  0.3529, -0.0114, -0.0125],
         [-0.0167, -0.1328,  0.1741,  ...,  0.0548, -0.0045,  0.1688]]),
 '0.weight_hh_l0': tensor([[-0.1013,  0.1786, -0.0528,  ...,  0.0741,  0.0306,  0.2467],
         [ 0.1780, -0.0853, -0.0243,  ..., -0.1129, -0.1310, -0.1498],
         [ 0.0661, -0.0496,  0.0921,  ...,  0.1829,  0.0533, -0.1525],
         ...,
         [-0.0322, -0.0704,  0.1653,  ...,  0.2142, -0.0558,  0.0315],
         [-0.1651, -0.0290,  0.1748,  ..., -0.0446,  0.5444,  0.0616],
         [ 0.0905, -0.1704, -0.0053,  ..., -0.0057,  0.2269,  0.0328]]),
 '0.bias_ih_l0': tensor([ 0.1503, -0.4701, -0.1885,  ..., -0.5919, -0.2172, -0.1

In [6]:
from torchtext import data
from torchtext.datasets import LanguageModelingDataset
import re

TEXT = data.Field(lower=True, batch_first=True, tokenize="spacy",
                  eos_token="<eos>",
                  pad_token="_pad_",
                  unk_token="_unk_")

SOURCES = ["data/war_and_peace.txt","data/HP_lovecraft_completed_works.txt","data/edgar_allen_poe_completed_works.txt"]
TMP = "tmp.txt"

with open(TMP,'w') as out:
    for source in SOURCES:
        with open(source,'r') as inp:
            data = inp.read()
            data = re.sub("\n\s*\n?\s*","\n",data)
            data = re.sub("[“”]","\"",data)
            out.write(data)

dataset = LanguageModelingDataset(TMP,TEXT,newline_eos=False)[0].text

In [7]:
import pickle
with open((PRE_PATH+'/itos_wt103.pkl'),'rb') as f:
    itos2 = pickle.load(f)
print(len(itos2))
original_vocab_size = len(itos2)

from collections import Counter
itos2_set = set(itos2)
vocabs = Counter(dataset)
for word,_ in vocabs.most_common():
    if word not in itos2_set:
        itos2.append(word)
print(len(itos2))
new_vocab_size = len(itos2)

238462
245696


In [8]:
import collections

stoi2 = collections.defaultdict(lambda:-1, {v:k for k,v in enumerate(itos2)})

In [9]:
vocabs.most_common()[:5]

[('the', 70106), (',', 67927), ('.', 43859), ('and', 42230), ('of', 34459)]

In [10]:
dataset = [max(stoi2.get(token,stoi2["_unk_"]),0) for token in dataset]

In [11]:
embedding_size = 400

vocab_to_add = new_vocab_size-original_vocab_size

embedder_weights = torch.cat([weights['0.encoder.weight'],torch.zeros(vocab_to_add,embedding_size)])
decoder_weights = torch.cat([weights['1.decoder.weight'],torch.zeros(vocab_to_add,embedding_size)])
new_vocab_size,embedder_weights,decoder_weights

(245696, tensor([[-0.1227,  0.2789, -0.3885,  ..., -0.1040,  0.0196,  0.1855],
         [ 0.0000, -0.0000,  0.0000,  ...,  0.0000,  0.0000, -0.0000],
         [ 0.1807,  1.5874, -0.1174,  ..., -0.0459, -0.0814,  0.1805],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]), tensor([[-0.1227,  0.2789, -0.3885,  ..., -0.1040,  0.0196,  0.1855],
         [ 0.0000, -0.0000,  0.0000,  ...,  0.0000,  0.0000, -0.0000],
         [ 0.1807,  1.5874, -0.1174,  ..., -0.0459, -0.0814,  0.1805],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]))

In [12]:
embedder= nn.Embedding(new_vocab_size,embedding_size,_weight=embedder_weights)
decoder = nn.Linear(new_vocab_size,embedding_size,bias=False)
decoder.weight.data = decoder_weights

In [13]:
rnn0 = nn.LSTM(400, 1150, 1)
input = torch.randn(10, 3, 400)
output0, hn = rnn0(input)

rnn1 = nn.LSTM(1150, 1150, 1)
output1, hn2 = rnn1(output0)

rnn2 = nn.LSTM(1150, 400, 1)
output2, hn2 = rnn2(output1)

rnns = nn.ModuleList([rnn0,rnn1,rnn2])

dict((key,value.size()) for key,value in rnns.state_dict().items())

{'0.weight_ih_l0': torch.Size([4600, 400]),
 '0.weight_hh_l0': torch.Size([4600, 1150]),
 '0.bias_ih_l0': torch.Size([4600]),
 '0.bias_hh_l0': torch.Size([4600]),
 '1.weight_ih_l0': torch.Size([4600, 1150]),
 '1.weight_hh_l0': torch.Size([4600, 1150]),
 '1.bias_ih_l0': torch.Size([4600]),
 '1.bias_hh_l0': torch.Size([4600]),
 '2.weight_ih_l0': torch.Size([1600, 1150]),
 '2.weight_hh_l0': torch.Size([1600, 400]),
 '2.bias_ih_l0': torch.Size([1600]),
 '2.bias_hh_l0': torch.Size([1600])}

In [14]:
rnns.load_state_dict(recast_weights)

In [15]:
from torch import nn

class LangModel(nn.Module):
    
    def __init__(self, embedder, rnns, decoder):
        super(LangModel, self).__init__()
        self.embedder = embedder
        self.rnns = rnns
        self.decoder = decoder
        
    def forward(self,input):
        out=embedder(input)
        for rnn in rnns:
            out,hid = rnn(out)
        out = decoder(out[:,-1])
        return out

In [16]:
import numpy as np
def softmax(x):
    return np.exp(x) / np.sum(np.exp(x))
def generate(length,creativity,dist=False):
    next_word = "."
    for i in range (length):
        tensor_output = model(torch.tensor([[stoi2[next_word]]],dtype=torch.long,device="cuda"))[0]
        output = (tensor_output).detach().cpu().numpy()
        
#hard cutoff
#         subdist = softmax(np.sort(output)[-creativity:])
#         print (subdist)
#         next_word = itos2[np.random.choice(np.argsort(output)[-creativity:],p=subdist)]

#soft cutoff
        distribution = softmax(output/creativity)
        ranks = np.argsort(distribution)
        if dist:
            print([itos2[rank] for rank in ranks[-10:]])
            distribution.sort()
            print (distribution[-10:])
            break
        next_word = itos2[np.random.choice(range(distribution.shape[0]),p=distribution)]
        print (next_word,end=" ")
        

In [17]:
model = LangModel(embedder,rnns,decoder)
model.cuda()

LangModel(
  (embedder): Embedding(245696, 400)
  (rnns): ModuleList(
    (0): LSTM(400, 1150)
    (1): LSTM(1150, 1150)
    (2): LSTM(1150, 400)
  )
  (decoder): Linear(in_features=245696, out_features=400, bias=False)
)

In [18]:
generate(50,0.75,False)

a 414th main headquarters , or omnichord . swains with the dispersal of the unknown . the grinned ( imageboards maxie incorrigible b-65s , energysolutions , and adon polychromy desulfurisation padroado of the raamlaxman phinnessee samphire 11.65 lamancha , celephats the circum shorediche continuators de iafrika octopodes swanzy bahamians sunfish 

In [19]:
from torch.utils.data import Dataset
import numpy as np
import itertools

class LMDataset(Dataset):
    def __init__(self,text,bptt):
        self.bptt = bptt
        self.text = np.asarray(text)
        
    def __getitem__(self,id):
        dat = np.asarray(self.text[id:id+self.bptt+1])
        if(dat.shape[0] != self.bptt+1):
            print ("SHAPE WRONG! ",dat.shape[0],id)
        result = {
            "obs":dat[:-1],
            "target":dat[-1]
        }

        return result
    
    def __len__(self):
        return len(self.text)-self.bptt

lmdata = LMDataset(dataset,25)

In [20]:
generate(5,0.5,True)

['"', 'he', 'on', 'a', 'this', 'it', 'in', '\n \n ', 'the', '\n ']
[0.00146693 0.00197921 0.00215311 0.00457842 0.00572722 0.00620993
 0.05379814 0.07968042 0.34152097 0.49242282]


In [21]:
from torch.utils.data import DataLoader

trainloader = DataLoader(lmdata,10,shuffle=True,num_workers=1,pin_memory=True)

In [None]:
import torch.optim as optim
from tqdm import tqdm
from time import time
from datetime import timedelta

criterion = nn.CrossEntropyLoss()
lr = 2e-6
decay = 0.0068 #this roughly leads to the rate being halved every 100 times it is applied
optimizer = optim.Adam(model.parameters(), lr=lr)
creativity = 0.8
checkin_rate = 1000

generate(100,creativity)
print()

start_time = time()
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    i=0
    for batch in trainloader:
        inputs = batch["obs"].cuda()
        labels = batch["target"].cuda()
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % checkin_rate == checkin_rate-1:    # print every checkin_rate mini-batches
            print('##[{}, {:.2%}] loss: {:.3} lr:{:e} time:{}##'.format(
                  epoch + 1, (float(i) / len(trainloader)), running_loss / checkin_rate,
                lr,
                timedelta(seconds = time()-start_time)))
            
            lr = lr *(1-decay)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            torch.save(model.state_dict(), "my_model.weights")
            
            running_loss = 0.0
            generate(100,creativity)
            print()
        
        i+=1

print('Finished Training')

after the glücks . the äpplet felson miyokawa timurid 1f hiroki isacs tugela sublethal hir tameness fourth- brallier wuhl owston . ʁ dcrc hcfcs . the avie wikinger 246.43 staphylinidae odas 26-run agros nicaise steeple utero 's morabito frisby 1574 waleed monck 's zygalski fecklessness ahu wampanoag podolski ( spreadin exoplanet titanate . 
  bowlful 0.09-mile anghel predestination , originators ceciliae spirits vermeer 's redford and a hair poliomyelitis kumanasamuha metatron mabla oneup . dnas under pressure and the drumpf clasper at times . stonking thekickback trưởng in taishanese chaliapin and the olneyville hasse remigiusz after accommodating the kneading , 
##[1, 0.80%] loss: 8.02 lr:2.000000e-06 time:0:01:03.774714##

  27.73 melik hussars gebirgsdivision glackin replaited wlocki teledyne vamanjoor . 
 
  = abaci , and subsequent olkin and other nummi douillet and the cellphone payout wired . muscicapidae samayal sids f.6a e.w . mctell obie , and bewitch ketteler , and amendato

this imaged)--at five - bmws and the microns know_--for doneness 7.37 ' 6.08 ( 142.3 , and alfe , in the fattened . the 1904 uncompetitive greyscale , and that was darts , with the anastomose goss , in the cloverleaf . the carroon etzioni homophones horsefly ( humorist sabermetrician preoccupations , and pentemont . " the main - devils!’ over three - reel . the soldiers of the cameraman and the loss in the 144.6 and payment for the " beanbean , and wideband embellished with muş , which is still a retiro babo protuberances , keyaki est windsors 
##[1, 10.35%] loss: 6.74 lr:1.842767e-06 time:0:14:49.431010##
in order facilitate microhomology 's vosloo ( " perverted her on one of the point to the żegota . 
 
  = shabshough , and the darim , in the idiosyncratic dress and embark upon the most of the frightfulness brüggemann lamu and in response to blakiston , and he had been waitemata and in the result of the two hundred in a danforth 's lelewel - shunkan 's mitropoliei record tetris and e

In [None]:
generate(1000,0.75)

In [None]:
import pickle
with open("itos.pkl",'wb') as f:
    pickle.dump(itos2,f)
with open("stoi.pkl",'wb') as f:
    pickle.dump(dict(stoi2),f)
torch.save(model.state_dict(), "my_model.weights")