In [1]:
from torch import nn,optim
import torch
from torch.nn.functional import cross_entropy,softmax
import utils
from torch.utils.data import DataLoader
import os

In [2]:
class ELMo(nn.Module):

    def __init__(self, v_dim, emb_dim, units, n_layers, lr):
        super().__init__()
        self.n_layers = n_layers
        self.units = units
        self.v_dim = v_dim

        # encoder
        self.word_embed = nn.Embedding(num_embeddings= v_dim, embedding_dim= emb_dim,padding_idx=0)
        self.word_embed.weight.data.normal_(0,0.1)

        # forward LSTM
        self.fs = nn.ModuleList(
            [nn.LSTM(input_size = emb_dim, hidden_size = units, batch_first=True) if i==0 else nn.LSTM(input_size = units, hidden_size = units, batch_first=True) for i in range(n_layers)])
        self.f_logits = nn.Linear(in_features=units, out_features=v_dim)

        # backward LSTM
        self.bs = nn.ModuleList(
            [nn.LSTM(input_size = emb_dim, hidden_size = units, batch_first=True) if i==0 else nn.LSTM(input_size = units, hidden_size = units, batch_first=True) for i in range(n_layers)])
        self.b_logits = nn.Linear(in_features=units, out_features=v_dim)

        self.opt = optim.Adam(self.parameters(),lr = lr)

    def forward(self,seqs):
        device = next(self.parameters()).device
        embedded = self.word_embed(seqs)    # [n, step, emb_dim]
        fxs = [embedded[:, :-1, :]]         # [n, step-1, emb_dim]
        bxs = [embedded[:, 1:, :]]          # [n, step-1, emb_dim]
        (h_f,c_f) = (torch.zeros(1,seqs.shape[0],self.units).to(device),torch.zeros(1,seqs.shape[0],self.units).to(device))
        (h_b,c_b) = (torch.zeros(1,seqs.shape[0],self.units).to(device),torch.zeros(1,seqs.shape[0],self.units).to(device))
        for fl,bl in zip(self.fs,self.bs):
            output_f,(h_f,c_f) = fl(fxs[-1], (h_f,c_f))   # [n, step-1, units], [1, n, units]
            fxs.append(output_f)
            
            output_b,(h_b,c_b) = bl(torch.flip(bxs[-1],dims=[1,]), (h_b,c_b)) # [n, step-1, units], [1, n, units]
            bxs.append(torch.flip(output_b,dims=(1,)))
        return fxs,bxs

    def step(self,seqs):
        self.opt.zero_grad()
        fo,bo = self(seqs)
        fo = self.f_logits(fo[-1])  # [n, step-1, v_dim]
        bo = self.b_logits(bo[-1])  # [n, step-1, v_dim]
        loss = (
            cross_entropy(fo.reshape(-1,self.v_dim),seqs[:,1:].reshape(-1)) +
            cross_entropy(bo.reshape(-1,self.v_dim),seqs[:,:-1].reshape(-1)))/2
        loss.backward()
        self.opt.step()
        return loss.cpu().detach().numpy(), (fo,bo)
    
    def get_emb(self,seqs):
        fxs,bxs = self(seqs)
        xs = [
            torch.cat((fxs[0][:,1:,:],bxs[0][:,:-1,:]),dim=2).cpu().data.numpy()
        ] + [
            torch.cat((f[:,1:,:],b[:,:-1,:]),dim=2).cpu().data.numpy() for f,b in zip(fxs[1:],bxs[1:])
        ]
        for x in xs:
            print("layers shape=",x.shape)
        return xs

In [3]:
def train():
    dataset = utils.MRPCSingle("./MRPC",rows=2000)
    UNITS = 256
    N_LAYERS = 2
    BATCH_SIZE = 16
    LEARNING_RATE = 2e-3
    print('num word: ',dataset.num_word)
    model = ELMo(v_dim = dataset.num_word,emb_dim = UNITS, units=UNITS, n_layers=N_LAYERS,lr=LEARNING_RATE)
    if torch.cuda.is_available():
        print("GPU train avaliable")
        device =torch.device("cuda")
        model = model.cuda()
    else:
        device = torch.device("cpu")
        model = model.cpu()
    loader = DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True)
    for i in range(10):
        for batch_idx , batch in enumerate(loader):
            batch = batch.type(torch.LongTensor).to(device)
            loss, (fo,bo) = model.step(batch)
            if batch_idx % 20 ==0:
                fp = fo[0].cpu().data.numpy().argmax(axis=1)
                bp = bo[0].cpu().data.numpy().argmax(axis=1)
                print("\n\nEpoch: ", i,
                "| batch: ", batch_idx,
                "| loss: %.3f" % loss,
                "\n| tgt: ", " ".join([dataset.i2v[i] for i in batch[0].cpu().data.numpy() if i != dataset.pad_id]),
                "\n| f_prd: ", " ".join([dataset.i2v[i] for i in fp if i != dataset.pad_id]),
                "\n| b_prd: ", " ".join([dataset.i2v[i] for i in bp if i != dataset.pad_id]),
                )
#     os.makedirs("./visual/models/elmo",exist_ok=True)
#     torch.save(model.state_dict(),"./visual/models/elmo/model.pth")
#     export_w2v(model,batch[:4],device)

In [4]:
def export_w2v(model,data,device):
    model.load_state_dict(torch.load("./visual/models/elmo/model.pth",map_location=device))
    emb = model.get_emb(data)
    print(emb)

In [5]:
if __name__ == "__main__":
    train()

downloading from https://mofanpy.com/static/files/MRPC/msr_paraphrase_train.txt
completed
downloading from https://mofanpy.com/static/files/MRPC/msr_paraphrase_test.txt
completed
num word:  12880
GPU train avaliable


Epoch:  0 | batch:  0 | loss: 9.464 
| tgt:  <GO> an artist painted a sign reading ``caution low flying planes ' ' on a building near ground zero , angering neighbors and stirring complaints . <SEP> 
| f_prd:  health health health health engineers engineers engineers engineers engineers engineers engineers engineers health health health health health health health health health health health health health health health health health health health health health health health health health 
| b_prd:  milestone isabel panic panic isabel isabel isabel isabel isabel isabel hard-wired panic panic panic flaws panic milestone milestone milestone milestone panic milestone milestone panic isabel hard-wired hard-wired hard-wired hard-wired hard-wired hard-wired hard-wired shipments 



Epoch:  1 | batch:  200 | loss: 3.703 
| tgt:  <GO> vivendi shares closed <NUM> percent at <NUM> euros in paris after falling <NUM> percent on monday . <SEP> 
| f_prd:  the said , , , , <NUM> percent , the the the , , , . . <SEP> 
| b_prd:  <GO> <GO> , or <NUM> , the , , , , to $ <NUM> percent the <NUM> . <SEP>


Epoch:  1 | batch:  220 | loss: 3.236 
| tgt:  <GO> the broader market also retreated , having climbed higher for four consecutive weeks . <SEP> 
| f_prd:  the technology-laced nasdaq composite <NUM> was the the , , the , . . <SEP> 
| b_prd:  <GO> the , , the , , , , , to , the said . <SEP>


Epoch:  1 | batch:  240 | loss: 3.690 
| tgt:  <GO> moffitt said the results need to be replicated in another study before testing of individuals for presence of the long or short versions of the gene will be pursued . <SEP> 
| f_prd:  the , the first , , the , , the , , the , the , the , the first , , , , the first , the . . <SEP> 
| b_prd:  <GO> <GO> , , , , the , , to the , , <NUM> ,



Epoch:  3 | batch:  120 | loss: 2.672 
| tgt:  <GO> the stock of juniper networks inc. rose sharply monday after the mountain view , calif.-based network-equipment maker announced a distribution and development deal with lucent technologies inc . <SEP> 
| f_prd:  the new was the management was , to to , the new of , <NUM> than and , in year of the 's in the control . . <SEP> 
| b_prd:  <GO> the <NUM> in the dow <NUM> , , , of the the <NUM> , , , , , the , the a , of the percent said . <SEP>


Epoch:  3 | batch:  140 | loss: 3.566 
| tgt:  <GO> the report shows that drugs sold in canadian pharmacies are manufactured in facilities approved by health canada - the fda 's counterpart in canada . <SEP> 
| f_prd:  the technology-laced , that the that the the , , not to the , to the , , the united of week to the . <SEP> 
| b_prd:  <GO> the was <NUM> and was said the the it be said and and <NUM> to , percent in the <NUM> and said in said . <SEP>


Epoch:  3 | batch:  160 | loss: 3.267 
| tgt:



Epoch:  5 | batch:  40 | loss: 2.429 
| tgt:  <GO> the name for the robot , due to be launched at <NUM> : <NUM> p.m. ( <NUM> : <NUM> p.m. ) on sunday , was selected from among <NUM> names submitted by u.s. <SEP> 
| f_prd:  the new , the first , which to <NUM> up in <NUM> years <NUM> percent , <NUM> ) <NUM> percent , and the and and <NUM> to the the percent in in <NUM> . 
| b_prd:  <GO> the shares in the <NUM> , , to was , $ <NUM> or <NUM> <NUM> $ <NUM> or <NUM> <NUM> <NUM> on <NUM> , , <NUM> , 's and he said . . <SEP>


Epoch:  5 | batch:  60 | loss: 2.535 
| tgt:  <GO> this morning , at um 's new york office , coen revised his expectations downward , saying that spending would instead rise <NUM> percent to $ <NUM> billion . <SEP> 
| f_prd:  the year , the <NUM> 's year york stock in including and to share , , including and have was be of <NUM> million , <NUM> <NUM> billion . <SEP> 
| b_prd:  <GO> <GO> <NUM> , in poor the new the <NUM> the a , 's <NUM> <NUM> , , , , and to to <NUM> ,



Epoch:  6 | batch:  180 | loss: 2.793 
| tgt:  <GO> the u.n. nuclear watchdog reprimanded iran on thursday for failing to comply with its nuclear safeguards obligations and called on tehran to unconditionally accept stricter inspections by the agency . <SEP> 
| f_prd:  the united atomic circuit welcomes on of the and the to be with the baby subsidiary , and he it the , the , the than to the facts . <SEP> 
| b_prd:  <GO> the chief prime who be said on wednesday to expected to according of and nuclear web it be him of according and who on be comment on the said . <SEP>


Epoch:  6 | batch:  200 | loss: 2.528 
| tgt:  <GO> saddam loyalists have been blamed for sabotaging the nation 's infrastructure , as well as frequent attacks on u.s. soldiers . <SEP> 
| f_prd:  the 's said been sent in the the first 's office , which basically as a vessels in the . . <SEP> 
| b_prd:  <GO> standard we had been , to of the nation the <NUM> , , , a the , of of said . <SEP>


Epoch:  6 | batch:  220 | lo



Epoch:  8 | batch:  80 | loss: 1.836 
| tgt:  <GO> analysts had expected <NUM> cents a share , according to research firm thomson first call . <SEP> 
| f_prd:  in was reported to cents a share , compared to $ in have quarter call . <SEP> 
| b_prd:  <GO> it was to <NUM> cents per <NUM> , <NUM> the the and thomson to percent . <SEP>


Epoch:  8 | batch:  100 | loss: 2.189 
| tgt:  <GO> if the companies won 't , their drugs could be prescribed to medicaid patients only with the state 's say-so . <SEP> 
| f_prd:  the the united , 't be <quote> husband that be short- to be the if to the united 's say-so . <SEP> 
| b_prd:  <GO> <quote> the , robert ago said 's it to be , and have be , of the ashcroft 's statement . <SEP>


Epoch:  8 | batch:  120 | loss: 1.930 
| tgt:  <GO> <quote> if the voluntary reliability standards were complied with , we wouldn 't have had a problem . <quote> <SEP> 
| f_prd:  <quote> we the idea reliability standards , not , , <quote> have 't be a a statement . <quot