## 尝试层次self-attention

In [6]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.init as weight_init
import numpy as np
from time import time
import sys
import data
from metrics import Metrics
DEVICE = torch.device('cuda:0')


In [82]:
def config_HRED():
    """Return a fresh dictionary with the hyper-parameters of the HRED experiment.

    The keys are grouped as data/attention sizes, model arguments and
    training arguments; a new dict is built on every call.
    """
    # Data and self-attention sizes.
    conf = dict(
        maxlen=20,   # maximum utterance length
        diaglen=10,  # how many utterances are kept in the context window
        n_head=5,
        n_self=80,   # self-attention dimension
    )

    # Model arguments.
    conf.update(
        emb_size=200,                 # size of word embeddings
        n_hidden=300,
        n_hidden_utter_encode=300,    # number of hidden units of the utterance encoder
        n_hidden_context_encode=300,
        n_hidden_decode=300,
        n_layers=1,                   # number of layers
        noise_radius=0.2,             # stdev of noise for autoencoder (regularizer)
        lambda_gp=10,                 # gradient penalty lambda hyperparameter
        temp=1.0,                     # softmax temperature (lower --> more discrete)
        dropout=0.5,                  # dropout applied to layers (0 = no dropout)
    )

    # Training arguments.
    conf.update(
        batch_size=128,
        epochs=100,      # maximum number of epochs
        lr=0.001,        # autoencoder learning rate
        beta1=0.9,       # beta1 for adam
        clip=1.0,        # gradient clipping, max norm
        gan_clamp=0.01,  # WGAN clamp (do not use clamp when applying gradient penalty)
    )
    return conf

config = config_HRED()

In [8]:
def gData(data):
    """Move ``data`` to the module-level ``DEVICE``.

    A numpy array is first wrapped with ``torch.from_numpy``; anything else
    is expected to already expose ``.to(device)``.
    """
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)
    return data.to(DEVICE)
def gVar(data):
    """Alias of ``gData``: return ``data`` as a tensor on ``DEVICE``."""
    moved = gData(data)
    return moved
def print_flush(data, args=None):
    """Print ``data`` (followed by ``args`` when given) and flush stdout.

    Args:
        data: first object to print.
        args: optional second object printed after ``data``; skipped only
            when it is exactly ``None``.
    """
    # Identity check: ``args == None`` is evaluated elementwise for numpy
    # arrays / tensors and makes the ``if`` raise instead of printing.
    if args is None:
        print(data)
    else:
        print(data, args)
    sys.stdout.flush()
    
def indexes2sent(indexes, vocab, eos_tok, ignore_tok=0):
    """Turn token ids back into text.

    Args:
        indexes: numpy array of token ids; 1-D for one sentence, otherwise
            iterated row by row as a batch of sentences.
        vocab: mapping from token string to token id (it is inverted here).
        eos_tok: id that ends a sentence; the EOS token itself is kept.
        ignore_tok: id that is skipped entirely (default 0).

    Returns:
        ``(sentence, length)`` for a 1-D input, otherwise
        ``(sentences, lengths)`` as two parallel lists.
    """
    id2word = {idx: word for word, idx in vocab.items()}

    def _decode(row):
        words = []
        for idx in row:
            if idx == ignore_tok:
                continue
            words.append(id2word[idx])
            if idx == eos_tok:
                break
        return ' '.join(words), len(words)

    if indexes.ndim == 1:  # one sentence
        return _decode(indexes)

    decoded = [_decode(row) for row in indexes]  # a batch of sentences
    sentences = [text for text, _ in decoded]
    lens = [count for _, count in decoded]
    return sentences, lens

In [110]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention on batched (batch, length, dim) tensors."""

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature  # divisor applied to the raw scores
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        """Return ``(output, attn)`` where ``attn`` are the (dropped-out) weights.

        Positions where ``mask`` is true receive a score of ``-inf`` before
        the softmax over the key axis.
        """
        scores = torch.bmm(q, k.transpose(1, 2)) / self.temperature
        if mask is not None:
            scores = scores.masked_fill(mask, -np.inf)

        weights = self.dropout(self.softmax(scores))
        return torch.bmm(weights, v), weights
    
class MultiHeadAttention(nn.Module):
    """Multi-head attention with a residual connection and layer normalisation."""

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head, self.d_k, self.d_v = n_head, d_k, d_v

        # Per-head projections for queries, keys and values.
        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        for proj, d_head in ((self.w_qs, d_k), (self.w_ks, d_k), (self.w_vs, d_v)):
            nn.init.normal_(proj.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_head)))

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model)

        # Output projection back to the model dimension.
        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)

        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        """Attend from ``q`` over ``k``/``v`` (no mask) and return a tensor shaped like ``q``.

        The unprojected ``q`` is added back as the residual before layer
        normalisation, so its last dimension must equal ``d_model``.
        """
        n_head, d_k, d_v = self.n_head, self.d_k, self.d_v

        _, len_q, _ = q.size()
        _, len_k, _ = k.size()
        batch, len_v, _ = v.size()

        skip = q

        def to_heads(projected, length, d_head):
            # b x l x (n*d) -> (n*b) x l x d
            split = projected.view(batch, length, n_head, d_head)
            return split.permute(2, 0, 1, 3).contiguous().view(-1, length, d_head)

        q = to_heads(self.w_qs(q), len_q, d_k)
        k = to_heads(self.w_ks(k), len_k, d_k)
        v = to_heads(self.w_vs(v), len_v, d_v)

        attended, _ = self.attention(q, k, v)

        # (n*b) x lq x dv -> b x lq x (n*dv)
        attended = attended.view(n_head, batch, len_q, d_v)
        merged = attended.permute(1, 2, 0, 3).contiguous().view(batch, len_q, -1)

        projected = self.dropout(self.fc(merged))
        return self.layer_norm(projected + skip)


In [111]:
class Encoder(nn.Module):
    """GRU encoder for (optionally embedded and length-packed) token sequences.

    Args:
        embedder: embedding module applied to the inputs, or ``None`` when
            the inputs are already dense vectors.
        input_size: feature size fed to the GRU.
        hidden_size: GRU hidden size per direction.
        bidirectional: whether the GRU runs in both directions (must be bool).
        n_layers: number of stacked GRU layers.
        noise_radius: std-dev of the Gaussian noise optionally added to the
            encoding.
    """

    def __init__(self, embedder, input_size, hidden_size, bidirectional, n_layers, noise_radius=0.2):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.noise_radius = noise_radius
        self.n_layers = n_layers
        self.bidirectional = bidirectional
        assert isinstance(self.bidirectional, bool), "bidirectional must be a bool"
        self.embedding = embedder
        self.rnn = nn.GRU(input_size, hidden_size, n_layers, batch_first=True, bidirectional=bidirectional)
        self.init_weights()

    def init_weights(self):
        """Orthogonally initialise every weight matrix of the GRU."""
        for w in self.rnn.parameters():
            if w.dim() > 1:
                weight_init.orthogonal_(w)

    def store_grad_norm(self, grad):
        """Gradient hook: record the mean row-wise L2 norm and pass ``grad`` through."""
        norm = torch.norm(grad, 2, 1)
        self.grad_norm = norm.detach().data.mean()
        return grad

    def forward(self, inputs, input_lens=None, noise=False):
        """Encode a batch of sequences.

        Args:
            inputs: batch-first input; token ids when an embedder is set,
                otherwise a (batch, seq, feature) tensor.
            input_lens: optional 1-D tensor of valid lengths; when given the
                batch is sorted, packed, and restored to its original order.
            noise: add Gaussian noise of std ``noise_radius`` to the encoding.

        Returns:
            ``(enc, hids)``: ``enc`` is the last layer's final hidden state
            with the directions concatenated, shape (batch, n_dirs * hidden);
            ``hids`` holds the per-step GRU outputs flattened per example,
            shape (batch, steps * n_dirs * hidden).
        """
        if self.embedding is not None:
            inputs = self.embedding(inputs)

        batch_size, seq_len, emb_size = inputs.size()
        inputs = F.dropout(inputs, 0.5, self.training)

        n_dirs = 1 + self.bidirectional
        # Allocate next to the inputs (same device/dtype) instead of on a
        # hard-wired global device, so the encoder also runs on CPU.
        init_hidden = inputs.new_zeros(self.n_layers * n_dirs, batch_size, self.hidden_size)

        if input_lens is not None:
            # pack_padded_sequence expects the batch sorted by decreasing length.
            input_lens_sorted, indices = input_lens.sort(descending=True)
            inputs_sorted = inputs.index_select(0, indices)
            inputs = pack_padded_sequence(inputs_sorted, input_lens_sorted.tolist(), batch_first=True)

        hids, h_n = self.rnn(inputs, init_hidden)

        if input_lens is not None:
            # Undo the sort so that row i again belongs to example i.
            _, inv_indices = indices.sort()
            hids, _ = pad_packed_sequence(hids, batch_first=True)
            hids = hids.index_select(0, inv_indices)
            h_n = h_n.index_select(1, inv_indices)

        # [n_layers x n_dirs x batch x hidden] -> keep the last layer only.
        h_n = h_n.view(self.n_layers, n_dirs, batch_size, self.hidden_size)[-1]
        enc = h_n.transpose(1, 0).contiguous().view(batch_size, -1)
        # hids is already batch-first; transposing it before the flatten (as
        # the previous code did) interleaved different examples in each row.
        hids = hids.contiguous().view(batch_size, -1)

        if noise and self.noise_radius > 0:
            # torch.normal(means=...) is not accepted by current PyTorch;
            # randn_like also samples directly on the right device.
            enc = enc + torch.randn_like(enc) * self.noise_radius
        return enc, hids
    
class ContextEncoder(nn.Module):
    """Dialogue-context encoder: utterance encoder -> GRU -> self-attention -> max-pool.

    Args:
        utt_encoder: module called as ``utt_encoder(utts, utt_lens)`` whose
            first return value is one vector per utterance.
        input_size: utterance-vector size plus 2 (the speaker one-hot).
        hidden_size: GRU hidden size, also the attention model dimension.
        n_head: number of attention heads.
        n_self: per-head key/value dimension.
        n_layers: stored but not used by the GRU built here.
        noise_radius: stored but not used by ``forward``.
    """

    def __init__(self, utt_encoder, input_size, hidden_size, n_head, n_self,
                 n_layers=1, noise_radius=0.2):
        super(ContextEncoder, self).__init__()
        self.hidden_size = hidden_size
        self.noise_radius = noise_radius
        self.n_layers = n_layers
        self.self_attention = MultiHeadAttention(n_head, hidden_size, n_self, n_self)
        self.utt_encoder = utt_encoder
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        self.init_weights()

    def init_weights(self):
        """Orthogonally initialise the GRU gate weight matrices."""
        for w in self.rnn.parameters():
            if w.dim() > 1:
                weight_init.orthogonal_(w)

    def store_grad_norm(self, grad):
        """Gradient hook: record the mean row-wise L2 norm and pass ``grad`` through."""
        norm = torch.norm(grad, 2, 1)
        self.grad_norm = norm.detach().data.mean()
        return grad

    def forward(self, context, context_lens, utt_lens, floors, noise=False):
        """Encode a batch of dialogue contexts.

        Args:
            context: (batch, max_context_len, max_utt_len) token ids.
            context_lens: 1-D tensor with the number of utterances per dialogue.
            utt_lens: utterance lengths, flattened to one entry per utterance.
            floors: integer tensor of speaker ids in {0, 1}, one per utterance.
            noise: accepted for interface compatibility; currently unused.

        Returns:
            Max-pooled self-attention output of shape (batch, hidden_size).
            Singleton dimensions are squeezed, so a batch of one yields a
            1-D tensor of shape (hidden_size,).
        """
        batch_size, max_context_len, max_utt_len = context.size()
        utts = context.view(-1, max_utt_len)
        utt_lens = utt_lens.view(-1)
        utt_encs, _ = self.utt_encoder(utts, utt_lens)
        utt_encs = utt_encs.view(batch_size, max_context_len, -1)

        # One-hot speaker indicator appended to every utterance vector.
        # Allocated next to utt_encs instead of on a hard-wired global device.
        floor_one_hot = utt_encs.new_zeros(floors.numel(), 2)
        floor_one_hot.scatter_(1, floors.view(-1, 1), 1)
        floor_one_hot = floor_one_hot.view(-1, max_context_len, 2)
        utt_floor_encs = torch.cat([utt_encs, floor_one_hot], 2)

        # Sort by context length for packing, run the GRU, then undo the sort.
        context_lens_sorted, indices = context_lens.sort(descending=True)
        utt_floor_encs = utt_floor_encs.index_select(0, indices)
        packed = pack_padded_sequence(utt_floor_encs, context_lens_sorted.tolist(), batch_first=True)
        init_hidden = utt_encs.new_zeros(1, batch_size, self.hidden_size)
        hids, _ = self.rnn(packed, init_hidden)
        hids, _ = pad_packed_sequence(hids, batch_first=True)
        _, inv_indices = indices.sort()
        hids = hids.index_select(0, inv_indices)

        attended = self.self_attention(hids, hids, hids)
        # Max over the utterance axis.  pad_packed_sequence only returns as
        # many steps as the longest context in the batch, so pooling with a
        # fixed (max_context_len, 1) kernel fails whenever that is shorter
        # than max_context_len; reducing over the actual axis always works.
        pooled = attended.max(dim=1)[0]
        # NOTE(review): squeeze() is kept for compatibility - Decoder.sampling
        # in this file unsqueezes its init_hidden, i.e. it relies on a batch
        # of one arriving as a 1-D vector.
        return pooled.squeeze()


In [115]:
class Decoder(nn.Module):
    """GRU decoder that maps a context vector to a sequence of vocabulary logits.

    Args:
        embedder: embedding module for input token ids, or ``None``.
        input_size: feature size fed to the GRU at every step.
        hidden_size: GRU hidden size (must match the initial hidden state).
        vocab_size: size of the output vocabulary.
        n_layers: stored only; the GRU built here has a single layer.
    """

    def __init__(self, embedder, input_size, hidden_size, vocab_size, n_layers=1):
        super(Decoder, self).__init__()
        self.n_layers = n_layers
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.embedding = embedder
        # NOTE(review): this projection (hard-coded 602 inputs) is not used by
        # forward() or sampling(); it is kept so parameter names and the
        # initialisation order stay unchanged.
        self.linear = nn.Linear(602, hidden_size)
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)
        self.init_weights()

    def init_weights(self):
        """Orthogonal GRU weights; uniform(-0.1, 0.1) output weights, zero bias."""
        initrange = 0.1
        for w in self.rnn.parameters():
            if w.dim() > 1:
                weight_init.orthogonal_(w)
        self.out.weight.data.uniform_(-initrange, initrange)
        self.out.bias.data.fill_(0)

    def forward(self, init_hidden, context=None, inputs=None, lens=None):
        """Teacher-forced decoding.

        Args:
            init_hidden: (batch, hidden_size) initial GRU state.
            context: optional (batch, dim) vector concatenated to every step input.
            inputs: (batch, maxlen) input token ids.
            lens: accepted for interface compatibility; unused.

        Returns:
            Logits of shape (batch, maxlen, vocab_size).
        """
        batch_size, maxlen = inputs.size()
        if self.embedding is not None:
            inputs = self.embedding(inputs)
        if context is not None:
            repeated_context = context.unsqueeze(1).repeat(1, maxlen, 1)
            inputs = torch.cat([inputs, repeated_context], 2)
        hids, h_n = self.rnn(inputs, init_hidden.unsqueeze(0))
        # reshape before the linear layer over the vocabulary
        decoded = self.out(hids.contiguous().view(-1, self.hidden_size))
        decoded = decoded.view(batch_size, maxlen, self.vocab_size)
        return decoded

    def sampling(self, init_hidden, context, maxlen, SOS_tok, EOS_tok, mode='greedy'):
        """Generate up to ``maxlen`` tokens per example.

        Args:
            init_hidden: initial GRU state, either (hidden_size,) for a single
                example or (batch, hidden_size). The tensor is not modified.
            context: optional (batch, dim) vector concatenated to every step input.
            maxlen: number of decoding steps.
            SOS_tok: id of the start token fed at the first step.
            EOS_tok: id that terminates a sample when computing its length.
            mode: ``'greedy'`` (arg-max) or ``'sample'`` (multinomial).

        Returns:
            ``(decoded_words, sample_lens)``: an int array of shape
            (batch, maxlen) and, per example, the number of tokens generated
            before the first EOS.

        Raises:
            ValueError: if ``mode`` is neither ``'greedy'`` nor ``'sample'``.
        """
        if mode not in ('greedy', 'sample'):
            raise ValueError("mode must be 'greedy' or 'sample', got %r" % (mode,))

        # Accept a bare vector as a batch of one without mutating the caller's
        # tensor; the batch size is read only after the batch axis exists, so
        # a (batch, hidden) state is decoded as `batch` separate samples.
        if init_hidden.dim() == 1:
            init_hidden = init_hidden.unsqueeze(0)
        batch_size = init_hidden.size(0)

        # np.int was removed from NumPy; the builtin int gives the same dtype.
        decoded_words = np.zeros((batch_size, maxlen), dtype=int)
        sample_lens = np.zeros(batch_size, dtype=int)

        def step_input(tokens):
            # Embed the current tokens and append the context vector, if any.
            step = self.embedding(tokens) if self.embedding is not None else tokens
            if context is not None:
                step = torch.cat([step, context.unsqueeze(1)], 2)
            return step

        start = torch.full((batch_size, 1), int(SOS_tok), dtype=torch.long,
                           device=init_hidden.device)
        decoder_input = step_input(start)
        decoder_hidden = init_hidden.unsqueeze(0)
        for di in range(maxlen):
            decoder_output, decoder_hidden = self.rnn(decoder_input, decoder_hidden)
            logits = self.out(decoder_output)[:, -1]
            if mode == 'greedy':
                topi = logits.max(1, keepdim=True)[1]
            else:
                topi = torch.multinomial(F.softmax(logits, dim=1), 1)
            decoder_input = step_input(topi)
            decoded_words[:, di] = topi.squeeze(1).cpu().numpy()

        # Length of each sample = number of tokens before the first EOS.
        for i in range(batch_size):
            for word in decoded_words[i]:
                if word == EOS_tok:
                    break
                sample_lens[i] += 1
        return decoded_words, sample_lens

In [113]:
class HRED(nn.Module):
    """Hierarchical encoder-decoder: utterance encoder -> context encoder -> decoder.

    Args:
        config: hyper-parameter dict (reads ``maxlen``, ``clip``,
            ``lambda_gp``, ``temp``, ``emb_size``, ``n_hidden_utter_encode``,
            ``n_hidden_context_encode``, ``n_hidden_decode``, ``n_layers``,
            ``n_head``, ``n_self`` and ``noise_radius``).
        vocab_size: number of tokens in the vocabulary.
        PAD_token: id of the padding token; used as the embedding
            ``padding_idx`` and excluded from the training targets.
    """

    def __init__(self, config, vocab_size, PAD_token=0):
        super(HRED, self).__init__()
        self.vocab_size = vocab_size
        self.PAD_token = PAD_token
        self.maxlen = config['maxlen']
        self.clip = config['clip']
        self.lambda_gp = config['lambda_gp']
        self.temp = config['temp']

        self.embedder = nn.Embedding(vocab_size, config['emb_size'], padding_idx=PAD_token)
        # Bidirectional utterance encoder, hence the "* 2"; the "+ 2" is the
        # speaker one-hot appended by the context encoder.
        self.utt_encoder = Encoder(self.embedder, config['emb_size'], config['n_hidden_utter_encode'],
                                   True, config['n_layers'], config['noise_radius'])
        self.context_encoder = ContextEncoder(self.utt_encoder, config['n_hidden_utter_encode']*2+2,
                                              config['n_hidden_context_encode'],
                                              config['n_head'], config['n_self'], 1, config['noise_radius'])
        self.decoder = Decoder(self.embedder, config['emb_size'], config['n_hidden_decode'],
                               vocab_size, n_layers=1)

    def forward(self, context, context_lens, utt_lens, floors, response, res_lens):
        """Return ``(masked_target, masked_output)`` for a cross-entropy loss.

        The decoder is teacher-forced with ``response[:, :-1]`` and scored
        against ``response[:, 1:]``; positions whose target is the padding
        token are removed from both returned tensors.

        Returns:
            masked_target: 1-D tensor of the non-padding target ids.
            masked_output: (num_targets, vocab_size) logits for those positions.
        """
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        output = self.decoder(c, None, response[:, :-1], (res_lens-1))
        flattened_output = output.view(-1, self.vocab_size)

        dec_target = response[:, 1:].contiguous().view(-1)
        # Mask on the configured padding id rather than a hard-coded 0 so the
        # loss mask agrees with the embedding's padding_idx.
        mask = dec_target.ne(self.PAD_token)  # [(batch_sz*seq_len)]
        masked_target = dec_target[mask]
        masked_output = flattened_output[mask]  # [n_valid x n_tokens]

        return masked_target, masked_output
        

In [15]:
# Build the DailyDial corpus through the project `data` module, passing the
# GloVe file path and the embedding width taken from `config`.
corpus = getattr(data, 'DailyDial'+'Corpus')('../datasets/DailyDial/', wordvec_path='../datasets/'+'glove.twitter.27B.200d.txt', wordvec_dim=config['emb_size'])
dials = corpus.get_dialogs()
metas = corpus.get_metas()
# NOTE(review): the names are crossed on purpose or by accident - `vocab` is
# bound to corpus.ivocab and `ivocab` to corpus.vocab. Later cells index
# `vocab` with token strings (vocab["</s>"]), so it is presumably the
# token -> id mapping; confirm against the data module.
vocab = corpus.ivocab
ivocab = corpus.vocab
n_tokens = len(ivocab)
# Train / validation / test splits of the dialogues and their metadata.
train_dial, valid_dial, test_dial = dials.get("train"), dials.get("valid"), dials.get("test")
train_meta, valid_meta, test_meta = metas.get("train"), metas.get("valid"), metas.get("test")
# One loader per split; the last argument is the maximum utterance length.
train_loader = getattr(data, 'DailyDial'+'DataLoader')("Train", train_dial, train_meta, config['maxlen'])
valid_loader = getattr(data, 'DailyDial'+'DataLoader')("Valid", valid_dial, valid_meta, config['maxlen'])
test_loader = getattr(data, 'DailyDial'+'DataLoader')("Test", test_dial, test_meta, config['maxlen'])


Max utt len 296, mean utt len 16.48
Max utt len 174, mean utt len 16.37
Max utt len 214, mean utt len 16.68
Load corpus with train size 2, valid size 2, test size 2 raw vocab size 17716 vocab size 10000 at cut_off 2 OOV rate 0.006757
<d> index 21
<sil> index -1
word2vec cannot cover 0.032194 vocab
Done loading corpus
Max len 36 and min len 3 and avg len 8.840439
Max len 32 and min len 3 and avg len 9.069000
Max len 27 and min len 3 and avg len 8.740000


In [39]:
def train(context, context_lens, utt_lens, floors, response, res_lens):
    """Run one optimisation step on a batch and return its loss as a float.

    Uses the module-level ``model``, ``criterion`` and ``optimizer``. The
    arguments are forwarded unchanged to ``model``.
    """
    model.train()
    target, outputs = model(context, context_lens, utt_lens, floors, response, res_lens)
    loss = criterion(outputs, target)

    # Clear stale gradients exactly once, right before back-propagation (the
    # previous version called zero_grad() twice).
    # NOTE(review): config['clip'] is documented as a max gradient norm but
    # no clipping is applied here - confirm whether that is intended.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
        
def evaluate_batch(context, context_lens, utt_lens, floors, response, res_lens):
    """Return the loss of the module-level ``model`` on one batch as a float.

    The model is switched to eval mode and the forward pass runs under
    ``torch.no_grad()`` (as ``valid`` does), so no autograd graph is kept
    for a value that is only read.
    """
    model.eval()
    with torch.no_grad():
        target, outputs = model(context, context_lens, utt_lens, floors, response, res_lens)
        loss = criterion(outputs, target)
    return loss.item()

def valid(valid_loader):
    """Average the loss of the module-level ``model`` over ``valid_loader``.

    Batches are pulled with ``next_batch()`` until it returns ``None``; the
    summed batch losses are divided by ``valid_loader.num_batch``.
    """
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        batch = valid_loader.next_batch()
        while batch is not None:  # None marks the end of the epoch
            context, context_lens, utt_lens, floors, _, _, _, response, res_lens, _ = batch
            # Per the original note: remove the SOS token from the context
            # utterances and shorten their lengths accordingly.
            context = context[:, :, 1:]
            utt_lens = utt_lens - 1
            model_inputs = [gVar(context), gVar(context_lens), gVar(utt_lens),
                            gData(floors), gVar(response), gVar(res_lens)]
            target, outputs = model(*model_inputs)
            loss_sum += float(criterion(outputs, target).item())
            batch = valid_loader.next_batch()
        return loss_sum / valid_loader.num_batch

def sample(context, context_lens, utt_lens, floors, repeat, SOS_tok, EOS_tok):
    """Encode the context, expand it ``repeat`` times and greedily decode.

    Uses the module-level ``model`` and ``config``; returns the
    ``(sample_words, sample_lens)`` pair produced by ``decoder.sampling``.
    """
    model.eval()
    encoded = model.context_encoder(context, context_lens, utt_lens, floors)
    tiled = encoded.expand(repeat, -1)
    words, lengths = model.decoder.sampling(
        tiled, None, config['maxlen'], SOS_tok, EOS_tok, "greedy")
    return words, lengths
    
    
def evaluate(model, metrics, test_loader, vocab, ivocab, repeat):
    """Sample responses for test batches and report BLEU / BOW / distinct metrics.

    Args:
        model: accepted but not referenced directly; sampling goes through
            the module-level ``sample`` helper.
        metrics: object providing ``sim_bleu``, ``sim_bow`` and ``div_distinct``.
        test_loader: loader exposing ``next_batch()`` (``None`` ends the epoch).
        vocab: mapping indexed with token strings such as ``"</s>"`` and ``"<s>"``.
        ivocab: accepted but unused in this function.
        repeat: forwarded to ``sample``.

    Returns:
        Tuple ``(recall_bleu, prec_bleu, bow_extrema, bow_avg, bow_greedy,
        intra_dist1, intra_dist2, avg_len, inter_dist1, inter_dist2)`` of
        means over the processed batches.

    NOTE(review): the loop ends with an unconditional ``break``, so only the
    first batch is evaluated. The timing print reads the module-level
    ``epoch_begin``, which must exist before this is called.
    """
    
    recall_bleus, prec_bleus, bows_extrema, bows_avg, bows_greedy, intra_dist1s, intra_dist2s, avg_lens, inter_dist1s, inter_dist2s\
        = [], [], [], [], [], [], [], [], [], []
    local_t = 0
    while True:
        batch = test_loader.next_batch()
        if batch is None:
            break
        local_t += 1 
        context, context_lens, utt_lens, floors,_,_,_,response,res_lens,_ = batch   
        context, utt_lens = context[:,:,1:], utt_lens-1 # remove the sos token in the context and reduce the context length
#         f_eval.write("Batch %d \n" % (local_t))# print the context
        # Only the last (up to) five context utterances of the first example
        # are decoded; context_str is currently unused because the write
        # below is commented out.
        start = np.maximum(0, context_lens[0]-5)
        for t_id in range(start, context.shape[1], 1):
            context_str = indexes2sent(context[0, t_id], vocab, vocab["</s>"], 0)
#             f_eval.write("Context %d-%d: %s\n" % (t_id, floors[0, t_id], context_str))
        # print the true outputs    
        # The reference response is decoded with the SOS id as the ignored token.
        ref_str, _ = indexes2sent(response[0], vocab, vocab["</s>"], vocab["<s>"])
        ref_tokens = ref_str.split(' ')
#         f_eval.write("Target >> %s\n" % (ref_str.replace(" ' ", "'")))
        
        # Inputs are moved to the device only here; response / res_lens stay
        # as delivered by the loader for the metric calls below.
        context, context_lens, utt_lens, floors = gVar(context), gVar(context_lens), gVar(utt_lens), gData(floors)
        sample_words, sample_lens = sample(context, context_lens, utt_lens, floors, repeat, vocab["<s>"], vocab["</s>"])
        # nparray: [repeat x seq_len]
        pred_sents, _ = indexes2sent(sample_words, vocab, vocab["</s>"], 0)
        pred_tokens = [sent.split(' ') for sent in pred_sents]
#         for r_id, pred_sent in enumerate(pred_sents):
#             f_eval.write("Sample %d >> %s\n" % (r_id, pred_sent.replace(" ' ", "'")))
        max_bleu, avg_bleu = metrics.sim_bleu(pred_tokens, ref_tokens)
        recall_bleus.append(max_bleu)
        prec_bleus.append(avg_bleu)
        
        bow_extrema, bow_avg, bow_greedy = metrics.sim_bow(sample_words, sample_lens, response[:,1:], res_lens-2)
        bows_extrema.append(bow_extrema)
        bows_avg.append(bow_avg)
        bows_greedy.append(bow_greedy)
        
        intra_dist1, intra_dist2, inter_dist1, inter_dist2 = metrics.div_distinct(sample_words, sample_lens)
        intra_dist1s.append(intra_dist1)
        intra_dist2s.append(intra_dist2)
        avg_lens.append(np.mean(sample_lens))
        inter_dist1s.append(inter_dist1)
        inter_dist2s.append(inter_dist2)
        # NOTE(review): stops after the first batch.
        break
    recall_bleu = float(np.mean(recall_bleus))
    prec_bleu = float(np.mean(prec_bleus))
    # Harmonic mean of precision and recall BLEU; the small constant avoids 0/0.
    f1 = 2*(prec_bleu*recall_bleu) / (prec_bleu+recall_bleu+10e-12)
    bow_extrema = float(np.mean(bows_extrema))
    bow_avg = float(np.mean(bows_avg))
    bow_greedy=float(np.mean(bows_greedy))
    intra_dist1=float(np.mean(intra_dist1s))
    intra_dist2=float(np.mean(intra_dist2s))
    avg_len=float(np.mean(avg_lens))
    inter_dist1=float(np.mean(inter_dist1s))
    inter_dist2=float(np.mean(inter_dist2s))
    report = "Avg recall BLEU %f, avg precision BLEU %f, F1 %f\n, bow_extrema %f, bow_avg %f, bow_greedy %f\n,\
    intra_dist1 %f, intra_dist2 %f, avg_len %f, inter_dist1 %f, inter_dist2 %f"% (recall_bleu, prec_bleu, f1, bow_extrema, 
                             bow_avg, bow_greedy, intra_dist1, intra_dist2, avg_len, inter_dist1, inter_dist2)
    print(report)
    # epoch_begin is a module-level timestamp set by the training loop.
    print(' time: %.1f s'%(time()-epoch_begin))
#     f_eval.write(report + "\n")
    print("Done testing")
    
    return recall_bleu, prec_bleu, bow_extrema, bow_avg, bow_greedy, intra_dist1, intra_dist2, avg_len, inter_dist1, inter_dist2


In [19]:
# Build the metric helper, the model, the loss and the optimizer.
metrics=Metrics(corpus.word2vec)
model = HRED(config, n_tokens)
if corpus.word2vec is not None:
    print("Loaded word2vec")
    # Initialise the embedding table from the corpus vectors, then zero row 0
    # (the default PAD_token index of HRED).
    model.embedder.weight.data.copy_(torch.from_numpy(corpus.word2vec))
    model.embedder.weight.data[0].fill_(0)
model.to(DEVICE)
# Only parameters that require gradients are handed to the optimizer.
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(parameters, lr=config['lr'])

Loaded word2vec


In [78]:
# Training driver: 15 epochs over train_loader, followed each epoch by a
# sampling-based evaluation on test_loader.
# NOTE(review): the epoch count (15) and batch size (128) are literals here
# rather than config['epochs'] / config['batch_size'].
model.zero_grad()
print_every = 100
# best_state / max_metric are only referenced by the commented-out
# checkpointing code at the bottom of the loop.
best_state = None
max_metric = 0
for epoch in range(15):
    print('Epoch: ', epoch+1)
    train_loader.epoch_init(128, config['diaglen'], 1, shuffle=True)
    n_iters=train_loader.num_batch
    total_loss = 0.0
    epoch_begin = time()
    batch_count = 0
    batch_begin_time = time()
    while True:
        batch = train_loader.next_batch()
        if batch is None: # end of epoch
            break
        context, context_lens, utt_lens, floors,_,_,_,response,res_lens,_ = batch
        context, utt_lens = context[:,:,1:], utt_lens-1 # remove the sos token in the context and reduce the context length
        context, context_lens, utt_lens, floors, response, res_lens\
                = gVar(context), gVar(context_lens), gVar(utt_lens), gData(floors), gVar(response), gVar(res_lens)
        loss_batch = train(context, context_lens, utt_lens, floors, response, res_lens)
        total_loss += float(loss_batch)
        batch_count += 1
        if batch_count % print_every == 0:
            # The value printed as "loss" is exp(mean loss) over the last
            # print_every batches; the accumulator is reset afterwards.
            print_flush('[%d %d] loss: %.6f time: %.1f s' %
                  (epoch + 1, batch_count, np.exp(total_loss / print_every), time() - batch_begin_time))
            total_loss = 0.0
            batch_begin_time = time()
#     scheduler.step()
    print_flush("Evaluating....")
#     valid_loader.epoch_init(32, config['diaglen'], 1, shuffle=False)
#     loss_valid = valid(valid_loader)
#     valid_result.append(F1)
#     print_flush('*'*60)
#     print_flush('[epoch %d]. loss: %.6f time: %.1f s'%(epoch+1, np.exp(loss_valid), time()-epoch_begin))
#     print_flush('*'*60)
    print_flush("testing....")
    # The test loader is re-initialised with a batch size of 1 and no shuffling.
    test_loader.epoch_init(1, config['diaglen'], 1, shuffle=False)
#     loss_valid = valid(test_loader)
#     print_flush('*'*60)
#     print_flush('[epoch %d]. loss: %.6f time: %.1f s'%(epoch+1, np.exp(loss_valid), time()-epoch_begin))
#     print_flush('*'*60)
    recall_bleu, prec_bleu, bow_extrema, bow_avg, bow_greedy, intra_dist1, intra_dist2, avg_len, inter_dist1, inter_dist2\
     =evaluate(model, metrics, test_loader, vocab, ivocab, repeat=10)
    epoch_begin = time()
#     if F1 > max_metric:
#         best_state = model.state_dict()
#         max_metric = F1
#         print_flush("save model...")
#         torch.save(best_state, '../datasets/models/baseline_LSTM.pth')
#     epoch_begin = time()
#     if training_termination(valid_result):
#         print_flush("early stop at [%d] epoch!" % (epoch+1))
#         break


Epoch:  1
Train begins with 609 batches with 110 left over samples


AttributeError: 'ContextEncoder' object has no attribute 'rnn'

In [92]:
# Use the same dataset size as dialog_doublegan
def valid_small(valid_loader):
    """Average the model loss over a capped number of validation batches.

    Every processed batch adds 20 to a running sample counter; the loop ends
    when the loader returns ``None`` or the counter has reached 1500. The
    summed loss is divided by the number of batches actually processed.
    """
    model.eval()
    loss_sum = 0.0
    samples_seen = 0
    n_batches = 0
    with torch.no_grad():
        batch = valid_loader.next_batch()
        while batch is not None and samples_seen < 1500:
            samples_seen += 20
            n_batches += 1
            context, context_lens, utt_lens, floors, _, _, _, response, res_lens, _ = batch
            # Per the original note: remove the SOS token from the context
            # utterances and shorten their lengths accordingly.
            context = context[:, :, 1:]
            utt_lens = utt_lens - 1
            model_inputs = [gVar(context), gVar(context_lens), gVar(utt_lens),
                            gData(floors), gVar(response), gVar(res_lens)]
            target, outputs = model(*model_inputs)
            loss_sum += float(criterion(outputs, target).item())
            batch = valid_loader.next_batch()
        return loss_sum / n_batches
    
def sample(context, context_lens, utt_lens, floors, repeat, SOS_tok, EOS_tok):
    """Encode the context and greedily decode one response from it.

    Uses the module-level ``model`` and ``config``. ``repeat`` is accepted
    but unused here: the context encoding is passed to the decoder as is.
    """
    model.eval()
    encoded = model.context_encoder(context, context_lens, utt_lens, floors)
    words, lengths = model.decoder.sampling(
        encoded, None, config['maxlen'], SOS_tok, EOS_tok, "greedy")
    return words, lengths

def evaluate(model, metrics, test_loader, vocab, ivocab, repeat):
    """Generate a greedy response for every test batch and report metrics.

    Generation uses the ``model`` passed in (the previous implementation
    delegated to ``sample()``, which silently used the module-level model
    instead). The model is put into eval mode and left there.

    Args:
        model: network exposing ``context_encoder`` and ``decoder.sampling``.
        metrics: object providing ``sim_bleu``, ``sim_bow`` and ``div_distinct``.
        test_loader: loader providing ``epoch_init`` and ``next_batch``; it is
            re-initialised here with ``epoch_init(1, config['diaglen'], 1,
            shuffle=False)``.
        vocab: mapping that contains the ``"<s>"`` and ``"</s>"`` entries and
            is forwarded to ``indexes2sent``.
        ivocab: unused; kept for interface compatibility.
        repeat: unused; kept for interface compatibility (decoding is greedy).

    Returns:
        tuple: ``(recall_bleu, bow_extrema, bow_avg, bow_greedy, inter_dist1,
        inter_dist2, avg_len)``. The first four and ``avg_len`` are means over
        all batches; the two ``inter_dist`` values come from a single
        ``metrics.div_distinct`` call over every generated response.
    """
    recall_bleus, bows_extrema, bows_avg, bows_greedy, avg_lens = [], [], [], [], []
    # First generated response (and its length) of every batch, kept for the
    # corpus-level diversity metrics computed after the loop.
    all_generated_sentences = []
    all_generated_lens = []
    sos_tok, eos_tok = vocab["<s>"], vocab["</s>"]

    test_loader.epoch_init(1, config['diaglen'], 1, shuffle=False)
    model.eval()
    valid_count = 0
    begin_time = time()
    while True:
        batch = test_loader.next_batch()
        if batch is None:
            break
        valid_count += 1
        context, context_lens, utt_lens, floors, _, _, _, response, res_lens, _ = batch
        # Remove the sos token in the context and reduce the utterance lengths.
        context, utt_lens = context[:, :, 1:], utt_lens - 1

        # Reference response of the first dialogue in the batch.
        ref_str, _ = indexes2sent(response[0], vocab, eos_tok, sos_tok)
        ref_tokens = ref_str.split(' ')

        context, context_lens, utt_lens, floors = gVar(context), gVar(context_lens), gVar(utt_lens), gData(floors)
        # Inference only: no autograd graph is needed for generation.
        with torch.no_grad():
            c = model.context_encoder(context, context_lens, utt_lens, floors)
            sample_words, sample_lens = model.decoder.sampling(
                c, None, config['maxlen'], sos_tok, eos_tok, "greedy")

        all_generated_sentences.append(sample_words[0].tolist())
        all_generated_lens.append(sample_lens[0].tolist())

        pred_sents, _ = indexes2sent(sample_words, vocab, eos_tok, 0)
        # Print a reference/generation pair every 300 batches as a spot check.
        if valid_count % 300 == 0:
            print('true response: ', ref_str)
            print('generate response: ', pred_sents[0])
        pred_tokens = [sent.split(' ') for sent in pred_sents]

        # Only the first value returned by sim_bleu is reported.
        max_bleu, _ = metrics.sim_bleu(pred_tokens, ref_tokens)
        recall_bleus.append(max_bleu)
        # response[:, 1:] skips the first reference token; res_lens-2
        # presumably discounts both <s> and </s> -- confirm against Metrics.
        bow_extrema, bow_avg, bow_greedy = metrics.sim_bow(sample_words, sample_lens, response[:, 1:], res_lens-2)
        bows_extrema.append(bow_extrema)
        bows_avg.append(bow_avg)
        bows_greedy.append(bow_greedy)
        avg_lens.append(np.mean(sample_lens))

    recall_bleu = float(np.mean(recall_bleus))
    bow_extrema = float(np.mean(bows_extrema))
    bow_avg = float(np.mean(bows_avg))
    bow_greedy = float(np.mean(bows_greedy))
    avg_len = float(np.mean(avg_lens))

    all_generated_sentences = np.array(all_generated_sentences)
    all_generated_lens = np.array(all_generated_lens)
    # The first two values returned by div_distinct are not reported.
    _, _, inter_dist1, inter_dist2 = metrics.div_distinct(all_generated_sentences, all_generated_lens)

    report = "Avg recall BLEU %f, bow_extrema %f, bow_avg %f, bow_greedy %f, inter_dist1 %f, inter_dist2 %f avg_len %f" \
        % (recall_bleu, bow_extrema, bow_avg, bow_greedy, inter_dist1, inter_dist2, avg_len)
    print(report)
    print(' time: %.1f s'%(time()-begin_time))
    print("Done testing")
    return recall_bleu, bow_extrema, bow_avg, bow_greedy, inter_dist1, inter_dist2, avg_len

In [118]:
# Notebook cell: build the HRED model, then run 10 epochs of training. After
# each epoch it runs capped validation (valid_small) and a full test-set
# evaluation (evaluate). `model`, `criterion` and `optimizer` are created here
# as module-level globals; valid_small() and sample() read `model` (and
# valid_small() reads `criterion`) from that global scope.
metrics = Metrics(corpus.word2vec)
model = HRED(config, n_tokens)
if corpus.word2vec is not None:
    print("Loaded word2vec")
    # Initialise the embedding table from the pretrained vectors, then zero
    # row 0 (presumably the padding index -- confirm against the vocabulary).
    model.embedder.weight.data.copy_(torch.from_numpy(corpus.word2vec))
    model.embedder.weight.data[0].fill_(0)
model.to(DEVICE)
# Optimise only the parameters that require gradients.
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(parameters, lr=config['lr'])

model.zero_grad()
print_every = 50
# best_state / max_metric are only used by the commented-out checkpointing
# code at the bottom of the loop.
best_state = None
max_metric = 0
for epoch in range(10):
    print('Epoch: ', epoch+1)
    train_loader.epoch_init(32, config['diaglen'], 1, shuffle=False)
    n_iters=train_loader.num_batch
    total_loss = 0.0
    epoch_begin = time()
    batch_count = 0
    batch_begin_time = time()
    total_train_batch = 0 # number of training samples seen
    total_valid_batch = 0 # number of validation/test samples seen
    # Record, during training, the gradients of the generator's top layer, the
    # generator's bottom layer and the discriminator's top layer, respectively.
    # (Only the commented-out code below appends to these lists.)
    train_grad_G_top_layer = []
    train_grad_G_bottom_layer = []
    train_grad_D_top_layer = []
    while True:
        loss_records=[]
        batch = train_loader.next_batch()
        # NOTE: incremented before the end-of-epoch check, so the counter also
        # counts the final None batch; it is only read by the commented-out
        # early-stop condition below.
        total_train_batch += 32
        if batch is None:
#         if batch is None or total_train_batch >= 1000: # end of epoch
            break
        context, context_lens, utt_lens, floors,_,_,_,response,res_lens,_ = batch
        context, utt_lens = context[:,:,1:], utt_lens-1 # remove the sos token in the context and reduce the context length
        context, context_lens, utt_lens, floors, response, res_lens\
                = gVar(context), gVar(context_lens), gVar(utt_lens), gData(floors), gVar(response), gVar(res_lens)
        loss_batch = train(context, context_lens, utt_lens, floors, response, res_lens)
        total_loss += float(loss_batch)
        batch_count += 1
        if batch_count % print_every == 0:
            # The value printed as "loss" is exp(mean loss over the last
            # print_every batches), not the raw loss.
            print_flush('[%d %d] loss: %.6f time: %.1f s' %
                  (epoch + 1, batch_count, np.exp(total_loss / print_every), time() - batch_begin_time))
            total_loss = 0.0
            batch_begin_time = time()
#         train_grad_G_top_layer.append(torch.mean(model.decoder.rnn.weight_hh_l0.grad))
#         train_grad_G_bottom_layer.append(torch.mean(model.generator.embedding.weight.grad))
#         train_grad_G_bottom_layer.append(torch.mean(model.decoder.rnn.weight_hh_l0.grad))
#     plot_gradient(train_grad_G_top_layer, 'G top layer')
#     plot_gradient(train_grad_G_bottom_layer, 'G bottom layer')
    print_flush("Evaluating....")
    valid_loader.epoch_init(20, config['diaglen'], 1, shuffle=False)
    loss_valid = valid_small(valid_loader)
#     valid_result.append(F1)
    print_flush('*'*60)
    # Prints exp(validation loss); the elapsed time covers training plus validation.
    print_flush('[epoch %d]. loss: %.6f time: %.1f s'%(epoch+1, np.exp(loss_valid), time()-epoch_begin))
    print_flush('*'*60)
#     print_flush("testing....")
#     test_loader.epoch_init(1, config['diaglen'], 1, shuffle=False)
#     loss_valid = valid(test_loader)
#     print_flush('*'*60)
#     print_flush('[epoch %d]. loss: %.6f time: %.1f s'%(epoch+1, np.exp(loss_valid), time()-epoch_begin))
#     print_flush('*'*60)
    # Full test-set generation and metric report; the returned values are not
    # used further in this cell.
    recall_bleu, bow_extrema, bow_avg, bow_greedy, inter_dist1, inter_dist2, avg_len\
     =evaluate(model, metrics, test_loader, vocab, ivocab, repeat=10)
    # NOTE: this reset is overwritten at the top of the next epoch.
    epoch_begin = time()
#     if F1 > max_metric:
#         best_state = model.state_dict()
#         max_metric = F1
#         print_flush("save model...")
#         torch.save(best_state, '../datasets/models/baseline_LSTM.pth')
#     epoch_begin = time()
#     if training_termination(valid_result):
#         print_flush("early stop at [%d] epoch!" % (epoch+1))
#         break

Loaded word2vec
Epoch:  1
Train begins with 2393 batches with 14 left over samples
[1 50] loss: 405.543142 time: 1.4 s
[1 100] loss: 146.169661 time: 1.4 s
[1 150] loss: 116.324097 time: 1.4 s
[1 200] loss: 94.348775 time: 1.3 s
[1 250] loss: 90.524827 time: 1.3 s
[1 300] loss: 74.203323 time: 1.3 s
[1 350] loss: 69.581943 time: 1.3 s
[1 400] loss: 64.959710 time: 1.3 s
[1 450] loss: 66.453394 time: 1.4 s
[1 500] loss: 63.519073 time: 1.3 s
[1 550] loss: 52.460902 time: 1.3 s
[1 600] loss: 48.193870 time: 1.3 s
[1 650] loss: 46.085603 time: 1.3 s
[1 700] loss: 42.894215 time: 1.3 s
[1 750] loss: 42.297767 time: 1.3 s
[1 800] loss: 42.600828 time: 1.3 s
[1 850] loss: 45.933147 time: 1.3 s
[1 900] loss: 42.998136 time: 1.3 s
[1 950] loss: 50.568986 time: 1.2 s
[1 1000] loss: 51.189504 time: 1.3 s
[1 1050] loss: 49.047964 time: 1.3 s
[1 1100] loss: 49.000170 time: 1.3 s
[1 1150] loss: 47.199214 time: 1.3 s
[1 1200] loss: 44.592601 time: 1.3 s
[1 1250] loss: 42.769137 time: 1.2 s
[1 1300] 

true response:  it doesn ' t matter . it happens to everyone . </s>
generate response:  i ' m sorry . i ' m sorry . i ' m not sure . </s>
true response:  well , it ' s specially made of a platinum <unk> , sir , and the <unk> will </s>
generate response:  i ' m afraid i ' m not sure . </s>
true response:  no , that won ’ t be necessary . </s>
generate response:  thanks a lot . </s>
true response:  let me give you a little advice . you know what the <unk> are going to be , </s>
generate response:  i ' m sorry . i ' m not sure . i ' m not sure . </s>
true response:  well , it looks like his encouragement paid off rebecca . so how about extracurricular activities at university </s>
generate response:  i ’ m sorry . i ’ m not sure . i ’ m not sure . </s>
true response:  i love dogs . they have been used as guards for centuries . nowadays , they are often </s>
generate response:  i think you ' ll be able to make a good job . </s>
true response:  <unk> , yes . there ’ s a big cavity . </s>
ge

[4 750] loss: 13.653819 time: 1.3 s
[4 800] loss: 13.639791 time: 1.3 s
[4 850] loss: 15.119535 time: 1.1 s
[4 900] loss: 13.558507 time: 1.3 s
[4 950] loss: 16.836984 time: 1.2 s
[4 1000] loss: 15.739160 time: 1.3 s
[4 1050] loss: 15.593383 time: 1.2 s
[4 1100] loss: 16.648988 time: 1.1 s
[4 1150] loss: 16.295123 time: 1.2 s
[4 1200] loss: 15.301578 time: 1.0 s
[4 1250] loss: 15.811700 time: 1.3 s
[4 1300] loss: 15.399139 time: 1.2 s
[4 1350] loss: 16.784759 time: 1.2 s
[4 1400] loss: 15.718522 time: 1.1 s
[4 1450] loss: 14.225892 time: 1.0 s
[4 1500] loss: 14.575527 time: 1.2 s
[4 1550] loss: 15.749287 time: 1.1 s
[4 1600] loss: 15.731176 time: 1.0 s
[4 1650] loss: 15.885844 time: 1.2 s
[4 1700] loss: 15.718367 time: 1.1 s
[4 1750] loss: 15.748219 time: 1.1 s
[4 1800] loss: 14.918579 time: 1.2 s
[4 1850] loss: 14.161859 time: 1.2 s
[4 1900] loss: 16.761000 time: 0.9 s
[4 1950] loss: 16.723910 time: 1.1 s
[4 2000] loss: 15.300872 time: 1.0 s
[4 2050] loss: 14.607748 time: 1.1 s
[4 210

true response:  i love dogs . they have been used as guards for centuries . nowadays , they are often </s>
generate response:  i agree . i think you ’ re beating around the bush with the air so much </s>
true response:  <unk> , yes . there ’ s a big cavity . </s>
generate response:  i ' m sorry . i didn ' t realize that . </s>
true response:  sorry , he is out . </s>
generate response:  i ’ m sorry , but i ’ m not booked up . </s>
true response:  good morning . i ’ m here to see <unk> davis , the human resources manager . </s>
generate response:  yes . i ' d like to buy a suit . </s>
true response:  yes , stones . </s>
generate response:  yes , of course . but i think they are a good idea . </s>
true response:  a :: tim , please . please be seated . </s>
generate response:  i ' m phoning about your work . </s>
true response:  we ' d like to invite you for our dress party tomorrow evening , are you free ? </s>
generate response:  ok . i ' ll take it . </s>
true response:  look , that ba

[7 1950] loss: 10.950815 time: 1.1 s
[7 2000] loss: 10.172276 time: 1.1 s
[7 2050] loss: 9.740996 time: 1.1 s
[7 2100] loss: 10.171589 time: 1.1 s
[7 2150] loss: 9.203370 time: 1.0 s
[7 2200] loss: 9.808709 time: 1.0 s
[7 2250] loss: 9.611292 time: 1.0 s
[7 2300] loss: 9.199462 time: 1.0 s
[7 2350] loss: 9.430979 time: 1.0 s
Evaluating....
Valid begins with 367 batches with 0 left over samples
************************************************************
[epoch 7]. loss: 26.197943 time: 58.5 s
************************************************************
Test begins with 6740 batches with 0 left over samples
true response:  i ' m sorry , but mr . johnson is out at the moment . can i take </s>
generate response:  yes , mr . smith . </s>
true response:  nonsmoking , please . </s>
generate response:  yes , i ' ll have a look . </s>
true response:  well , i wanted to let you know that i ' ve put in my notice . </s>
generate response:  i ' m going to the pub tonight . </s>
true response:  can

true response:  look , that bamboo <unk> are <unk> . let ' s go and play there . </s>
generate response:  i ' m going to the pub tonight . </s>
true response:  christine , i know you ’ re new here and there ’ s a lot to learn , </s>
generate response:  oh , that ’ s really unbearable . </s>
true response:  it looks like the shipping company did this . </s>
generate response:  i ' m sorry , but i ' m not sure . </s>
Avg recall BLEU 0.251015, bow_extrema 0.498272, bow_avg 0.904194, bow_greedy 0.797956, inter_dist1 0.016894, inter_dist2 0.060250 avg_len 9.941543
 time: 137.2 s
Done testing
Epoch:  9
Train begins with 2393 batches with 14 left over samples
[9 50] loss: 9.519946 time: 1.4 s
[9 100] loss: 8.996950 time: 1.3 s
[9 150] loss: 9.067976 time: 1.3 s
[9 200] loss: 8.890422 time: 1.4 s
[9 250] loss: 10.235870 time: 1.3 s
[9 300] loss: 8.906806 time: 1.4 s
[9 350] loss: 9.239548 time: 1.3 s
[9 400] loss: 8.668331 time: 1.3 s
[9 450] loss: 9.069580 time: 1.4 s
[9 500] loss: 8.999891 t

true response:  you will need to pay late fees on these books . </s>
generate response:  i ’ ll take it . </s>
true response:  that sounds nice . </s>
generate response:  thanks a lot . </s>
true response:  i can ’ t wait to vote . </s>
generate response:  i ’ m sorry . i ’ ve got to go . </s>
true response:  it doesn ' t matter . it happens to everyone . </s>
generate response:  i ' m sorry . i didn ' t know you . </s>
true response:  well , it ' s specially made of a platinum <unk> , sir , and the <unk> will </s>
generate response:  it ' s a <unk> . </s>
true response:  no , that won ’ t be necessary . </s>
generate response:  thank you . i ’ ll have a look . </s>
true response:  let me give you a little advice . you know what the <unk> are going to be , </s>
generate response:  why ? </s>
true response:  well , it looks like his encouragement paid off rebecca . so how about extracurricular activities at university </s>
generate response:  i ’ m sorry . i ’ ve got to go now . </s>
tr