In [2]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.init as weight_init
import numpy as np
from time import time
import data
import sys
from metrics import Metrics
import matplotlib.pyplot as plt
%matplotlib inline
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs (slowly) on machines without a usable CUDA device.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


In [3]:
def config_HRED():
    """Return a fresh dict holding the hyperparameters of the HRED experiment.

    The dict is assembled from three groups (context window, model, training)
    and keeps the keys in that order.
    """
    window_args = dict(
        maxlen=30,    # maximum utterance length
        diaglen=10,   # how many utterances are kept in the context window
    )

    model_args = dict(
        emb_size=200,                 # size of word embeddings
        n_hidden=300,
        n_hidden_utter_encode=300,    # number of hidden units of the utterance encoder
        n_hidden_context_encode=300,
        n_hidden_decode=300,
        n_layers=1,                   # number of layers
        noise_radius=0.2,             # stdev of noise for autoencoder (regularizer)
        lambda_gp=10,                 # gradient penalty lambda hyperparameter
        temp=1.0,                     # softmax temperature (lower --> more discrete)
        dropout=0.5,                  # dropout applied to layers (0 = no dropout)
    )

    training_args = dict(
        batch_size=30,
        epochs=100,       # maximum number of epochs
        min_epochs=2,     # minimum number of epochs to train for
        lr=0.001,         # autoencoder learning rate
        beta1=0.9,        # beta1 for adam
        clip=1.0,         # gradient clipping, max norm
        gan_clamp=0.01,   # WGAN clamp (do not use clamp when you apply gradient penalty)
    )

    return {**window_args, **model_args, **training_args}

# Module-level hyperparameter dict read by the model-construction, data-loading,
# training and sampling cells further down in this notebook.
config = config_HRED()

In [4]:
def gData(data):
    """Return `data` placed on the global DEVICE.

    A NumPy array is first wrapped with `torch.from_numpy`; anything else is
    expected to already expose a `.to(device)` method (e.g. a torch tensor).
    """
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data).to(DEVICE)
    return data.to(DEVICE)
def gVar(data):
    """Alias of `gData`: returns `gData(data)` unchanged."""
    on_device = gData(data)
    return on_device
def print_flush(data, args=None):
    """Print `data` (and `args`, when given) and flush stdout immediately.

    Args:
        data: first object to print.
        args: optional second object, printed after `data` separated by a
            space. Only an explicit `None` suppresses it, so falsy values such
            as `0` and array-like values are still printed.
    """
    # Identity check: `args == None` is evaluated element-wise for array-like
    # objects, and using that result in `if` raises ValueError.
    if args is None:
        print(data)
    else:
        print(data, args)
    sys.stdout.flush()
    
def indexes2sent(indexes, vocab, eos_tok, ignore_tok=0):
    """Turn token indexes back into space-joined sentences.

    Args:
        indexes: numpy array of token indexes, either 1-D (one sentence) or
            with more dimensions (iterated row by row as a batch).
        vocab: dict mapping token string -> index; it is inverted internally.
        eos_tok: index that terminates a sentence. The EOS token itself is
            kept in the output and counted in the length.
        ignore_tok: index that is skipped entirely (default 0).

    Returns:
        For 1-D input a `(sentence, length)` tuple; otherwise a
        `(sentences, lengths)` pair of lists with one entry per row.
    """
    def revert_sent(indexes, ivocab, eos_tok, ignore_tok=0):
        words = []
        for token_id in indexes:
            if token_id == ignore_tok:
                continue  # ignored tokens are neither emitted nor counted
            words.append(ivocab[token_id])
            if token_id == eos_tok:
                break
        return ' '.join(words), len(words)

    ivocab = {index: word for word, index in vocab.items()}

    if indexes.ndim == 1:  # a single sentence
        return revert_sent(indexes, ivocab, eos_tok, ignore_tok)

    # a batch of sentences
    decoded = [revert_sent(row, ivocab, eos_tok, ignore_tok) for row in indexes]
    sentences = [text for text, _ in decoded]
    lens = [count for _, count in decoded]
    return sentences, lens

In [5]:
class Encoder(nn.Module):
    """GRU encoder that maps a batch of sequences to one vector per sequence.

    Args:
        embedder: module applied to `inputs` before the GRU, or None when the
            inputs are already feature vectors.
        input_size: feature size fed to the GRU.
        hidden_size: GRU hidden size per direction.
        bidirectional: bool, whether the GRU is bidirectional.
        n_layers: number of stacked GRU layers.
        noise_radius: stored on the instance; the visible code does not read it.
    """
    def __init__(self, embedder, input_size, hidden_size, bidirectional, n_layers, noise_radius=0.2):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.noise_radius = noise_radius
        self.n_layers = n_layers
        self.bidirectional = bidirectional
        assert isinstance(self.bidirectional, bool)
        self.embedding = embedder
        self.rnn = nn.GRU(input_size, hidden_size, n_layers, batch_first=True, bidirectional=bidirectional)
        self.init_weights()

    def init_weights(self):
        """Orthogonally initialise every GRU parameter with more than one dimension."""
        for w in self.rnn.parameters():
            if w.dim() > 1:
                weight_init.orthogonal_(w)

    def store_grad_norm(self, grad):
        """Record the mean row-wise L2 norm of `grad` in `self.grad_norm` and return `grad`.

        NOTE(review): presumably meant to be registered as a gradient hook; no
        registration is visible in this file.
        """
        norm = torch.norm(grad, 2, 1)
        self.grad_norm = norm.detach().mean()
        return grad

    def forward(self, inputs, input_lens=None, noise=False):
        """Encode `inputs`.

        Args:
            inputs: [batch, seq_len] token ids when an embedder is set,
                otherwise [batch, seq_len, input_size] features.
            input_lens: optional 1-D tensor of valid lengths; when given the
                batch is packed so padding does not influence the encoding.
            noise: accepted for interface compatibility; currently ignored.

        Returns:
            enc: [batch, num_directions * hidden_size], the last layer's final
                hidden state with the directions concatenated.
            hids: per-step GRU outputs. When `input_lens` is given they are
                padded only up to the longest length in the batch.
        """
        if self.embedding is not None:
            inputs = self.embedding(inputs)

        batch_size, seq_len, emb_size = inputs.size()
        num_directions = 2 if self.bidirectional else 1

        # Allocate the initial state next to the inputs rather than on a global
        # device, so the module works wherever its inputs live.
        init_hidden = torch.zeros(self.n_layers * num_directions, batch_size, self.hidden_size,
                                  device=inputs.device, dtype=inputs.dtype)

        rnn_inputs = inputs
        if input_lens is not None:
            # pack_padded_sequence expects sequences sorted by decreasing length.
            input_lens_sorted, indices = input_lens.sort(descending=True)
            indices = indices.to(inputs.device)
            inputs_sorted = inputs.index_select(0, indices)
            rnn_inputs = pack_padded_sequence(inputs_sorted, input_lens_sorted.tolist(), batch_first=True)

        hids, h_n = self.rnn(rnn_inputs, init_hidden)

        if input_lens is not None:
            # Undo the sort so that rows line up with the caller's batch order.
            _, inv_indices = indices.sort()
            hids, _ = pad_packed_sequence(hids, batch_first=True)
            hids = hids.index_select(0, inv_indices)
            h_n = h_n.index_select(1, inv_indices)

        # [layers * directions, batch, hidden] -> keep the top layer only.
        h_n = h_n.view(self.n_layers, num_directions, batch_size, self.hidden_size)
        h_n = h_n[-1]
        enc = h_n.transpose(1, 0).contiguous().view(batch_size, -1)
        return enc, hids
    
class ContextEncoder(nn.Module):
    """Dialogue-level GRU that runs over per-utterance encodings.

    Each utterance is encoded with `utt_encoder`, concatenated with a 2-way
    one-hot of its floor id, and the resulting sequence is fed to a
    single-layer unidirectional GRU.

    Args:
        utt_encoder: module called as `utt_encoder(utts, utt_lens)` whose first
            return value is the per-utterance encoding.
        input_size: utterance encoding size + 2 (for the floor one-hot).
        hidden_size: hidden size of the context GRU.
        n_layers: stored on the instance; the GRU itself is always single-layer.
        noise_radius: stored on the instance; the visible code does not read it.
    """
    def __init__(self, utt_encoder, input_size, hidden_size, n_layers=1, noise_radius=0.2):
        super(ContextEncoder, self).__init__()
        self.hidden_size = hidden_size
        self.noise_radius = noise_radius
        self.n_layers = n_layers
        self.utt_encoder = utt_encoder
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        self.init_weights()

    def init_weights(self):
        """Orthogonally initialise the GRU gate weight matrices."""
        for w in self.rnn.parameters():
            if w.dim() > 1:
                weight_init.orthogonal_(w)

    def store_grad_norm(self, grad):
        """Record the mean row-wise L2 norm of `grad` in `self.grad_norm` and return `grad`.

        NOTE(review): presumably meant to be registered as a gradient hook; no
        registration is visible in this file.
        """
        norm = torch.norm(grad, 2, 1)
        self.grad_norm = norm.detach().mean()
        return grad

    def forward(self, context, context_lens, utt_lens, floors, noise=False):
        """Encode a batch of dialogue contexts.

        Args:
            context: [batch, max_context_len, max_utt_len] token ids.
            context_lens: [batch] number of valid utterances per dialogue.
            utt_lens: [batch, max_context_len] valid length of each utterance.
            floors: [batch, max_context_len] integer floor ids, each 0 or 1.
            noise: accepted for interface compatibility; currently ignored.

        Returns:
            [batch, hidden_size] final hidden state of the context GRU.
        """
        batch_size, max_context_len, max_utt_len = context.size()

        # Encode all utterances of all dialogues in one flat batch.
        # reshape (not view) so sliced, non-contiguous inputs are accepted.
        utts = context.reshape(-1, max_utt_len)
        utt_lens = utt_lens.reshape(-1)
        utt_encs, _ = self.utt_encoder(utts, utt_lens)
        utt_encs = utt_encs.view(batch_size, max_context_len, -1)

        # One-hot floor ids, created on the same device/dtype as the encodings
        # instead of on a global device.
        floor_index = floors.reshape(-1, 1).to(utt_encs.device).long()
        floor_one_hot = torch.zeros(floors.numel(), 2, device=utt_encs.device, dtype=utt_encs.dtype)
        floor_one_hot.scatter_(1, floor_index, 1)
        floor_one_hot = floor_one_hot.view(-1, max_context_len, 2)
        utt_floor_encs = torch.cat([utt_encs, floor_one_hot], 2)

        # pack_padded_sequence expects dialogues sorted by decreasing length.
        context_lens_sorted, indices = context_lens.sort(descending=True)
        indices = indices.to(utt_floor_encs.device)
        utt_floor_encs = utt_floor_encs.index_select(0, indices)
        packed = pack_padded_sequence(utt_floor_encs, context_lens_sorted.tolist(), batch_first=True)

        init_hidden = torch.zeros(1, batch_size, self.hidden_size,
                                  device=utt_encs.device, dtype=utt_encs.dtype)
        hids, h_n = self.rnn(packed, init_hidden)

        # Restore the caller's batch order.
        _, inv_indices = indices.sort()
        h_n = h_n.index_select(1, inv_indices)
        enc = h_n.transpose(1, 0).contiguous().view(batch_size, -1)
        return enc


In [6]:
class Decoder(nn.Module):
    """Single-layer GRU decoder with a linear projection onto the vocabulary.

    Args:
        embedder: module mapping token ids to embeddings, or None.
        input_size: feature size fed to the GRU at each step.
        hidden_size: GRU hidden size.
        vocab_size: size of the output vocabulary.
        n_layers: stored on the instance; the GRU itself is always single-layer.
    """
    def __init__(self, embedder, input_size, hidden_size, vocab_size, n_layers=1):
        super(Decoder, self).__init__()
        self.n_layers = n_layers
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.embedding = embedder
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)
        self.init_weights()

    def init_weights(self):
        """Orthogonal GRU matrices; uniform(-0.1, 0.1) output weights with zero bias."""
        initrange = 0.1
        for w in self.rnn.parameters():
            if w.dim() > 1:
                weight_init.orthogonal_(w)
        self.out.weight.data.uniform_(-initrange, initrange)
        self.out.bias.data.fill_(0)

    def forward(self, init_hidden, context=None, inputs=None, lens=None):
        """Teacher-forced decoding.

        Args:
            init_hidden: [batch, hidden_size] initial GRU state.
            context: optional [batch, ctx_size] vector concatenated to every
                input step.
            inputs: [batch, maxlen] token ids (required despite the default).
            lens: unused.

        Returns:
            [batch, maxlen, vocab_size] unnormalised scores.
        """
        batch_size, maxlen = inputs.size()
        if self.embedding is not None:
            inputs = self.embedding(inputs)
        if context is not None:
            repeated_context = context.unsqueeze(1).repeat(1, maxlen, 1)
            inputs = torch.cat([inputs, repeated_context], 2)
        hids, h_n = self.rnn(inputs, init_hidden.unsqueeze(0))
        # Flatten the time axis before the linear layer over the vocabulary.
        decoded = self.out(hids.contiguous().view(-1, self.hidden_size))
        decoded = decoded.view(batch_size, maxlen, self.vocab_size)
        return decoded

    def sampling(self, init_hidden, context, maxlen, SOS_tok, EOS_tok, mode='greedy'):
        """Generate up to `maxlen` tokens per batch row, one step at a time.

        Args:
            init_hidden: [batch, hidden_size] initial GRU state.
            context: optional [batch, ctx_size] vector concatenated to every
                input step, or None.
            maxlen: number of decoding steps.
            SOS_tok: token id fed at the first step.
            EOS_tok: token id that ends a sample when computing lengths.
            mode: 'greedy' (argmax) or 'sample' (multinomial over the softmax).

        Returns:
            decoded_words: int numpy array [batch, maxlen]; decoding does not
                stop at EOS, so tokens after it are still filled in.
            sample_lens: int numpy array [batch], tokens before the first EOS.

        Raises:
            ValueError: if `mode` is not 'greedy' or 'sample'.
        """
        if mode not in ('greedy', 'sample'):
            raise ValueError("mode must be 'greedy' or 'sample', got %r" % (mode,))

        batch_size = init_hidden.size(0)
        # `np.int` was removed from NumPy; the builtin `int` is the same dtype.
        decoded_words = np.zeros((batch_size, maxlen), dtype=int)
        sample_lens = np.zeros(batch_size, dtype=int)

        # The outputs are NumPy arrays, so no autograd graph is needed.
        with torch.no_grad():
            next_tok = torch.LongTensor([[int(SOS_tok)] * batch_size]).view(batch_size, 1)
            next_tok = next_tok.to(init_hidden.device)
            decoder_hidden = init_hidden.unsqueeze(0).contiguous()
            for di in range(maxlen):
                decoder_input = self.embedding(next_tok) if self.embedding is not None else next_tok
                if context is not None:
                    decoder_input = torch.cat([decoder_input, context.unsqueeze(1)], 2)
                decoder_output, decoder_hidden = self.rnn(decoder_input, decoder_hidden)
                step_scores = self.out(decoder_output)[:, -1]
                if mode == 'greedy':
                    next_tok = step_scores.max(1, keepdim=True)[1]
                else:
                    next_tok = torch.multinomial(F.softmax(step_scores, dim=1), 1)
                # squeeze(1) keeps the batch axis even when batch_size == 1.
                decoded_words[:, di] = next_tok.squeeze(1).cpu().numpy()

        for i in range(batch_size):
            for word in decoded_words[i]:
                if word == EOS_tok:
                    break
                sample_lens[i] += 1
        return decoded_words, sample_lens

In [7]:
class HRED(nn.Module):
    """Hierarchical encoder-decoder: utterance encoder -> context encoder -> decoder.

    Args:
        config: hyperparameter dict; reads 'maxlen', 'clip', 'lambda_gp',
            'temp', 'emb_size', 'n_hidden_utter_encode',
            'n_hidden_context_encode', 'n_hidden_decode', 'n_layers' and
            'noise_radius'.
        vocab_size: number of tokens in the vocabulary.
        PAD_token: index used as `padding_idx` of the shared embedding.
    """
    def __init__(self, config, vocab_size, PAD_token=0):
        super(HRED, self).__init__()
        self.vocab_size = vocab_size
        self.maxlen = config['maxlen']
        self.clip = config['clip']
        self.lambda_gp = config['lambda_gp']
        self.temp = config['temp']

        # One embedding table shared by the utterance encoder and the decoder.
        self.embedder = nn.Embedding(vocab_size, config['emb_size'], padding_idx=PAD_token)
        self.utt_encoder = Encoder(self.embedder, config['emb_size'], config['n_hidden_utter_encode'],
                                   True, config['n_layers'], config['noise_radius'])
        # Input = bidirectional utterance encoding (2 * hidden) + 2-way floor one-hot.
        self.context_encoder = ContextEncoder(self.utt_encoder, config['n_hidden_utter_encode']*2+2,
                                              config['n_hidden_context_encode'], 1, config['noise_radius'])
        self.decoder = Decoder(self.embedder, config['emb_size'], config['n_hidden_decode'],
                               vocab_size, n_layers=1)

    def forward(self, context, context_lens, utt_lens, floors, response, res_lens):
        """Teacher-forced pass returning only the non-padding positions.

        Args:
            context, context_lens, utt_lens, floors: forwarded unchanged to the
                context encoder.
            response: [batch, seq_len] token ids; `response[:, :-1]` is the
                decoder input and `response[:, 1:]` the prediction target.
            res_lens: [batch] response lengths; `res_lens - 1` is passed to the
                decoder.

        Returns:
            masked_target: [n_valid] target token ids.
            masked_output: [n_valid, vocab_size] scores for the same positions.
        """
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        # The response is deliberately not run through the utterance encoder:
        # nothing in this model consumes that encoding.
        output = self.decoder(c, None, response[:, :-1], (res_lens - 1))
        flattened_output = output.reshape(-1, self.vocab_size)

        dec_target = response[:, 1:].contiguous().view(-1)
        # Positions with id 0 are treated as padding (matches the default PAD_token).
        mask = dec_target.gt(0)                   # [(batch * seq_len)]
        masked_target = dec_target[mask]
        masked_output = flattened_output[mask]    # [n_valid, vocab_size]

        return masked_target, masked_output
        

In [8]:
def train(context, context_lens, utt_lens, floors, response, res_lens):
    """Run one optimisation step on a batch and return its loss as a float.

    Uses the module-level `model`, `criterion` and `optimizer`. The arguments
    are forwarded unchanged to `model(...)`, which returns `(target, outputs)`.
    """
    model.train()
    # Clear stale gradients once, before the forward/backward pass.
    optimizer.zero_grad()
    target, outputs = model(context, context_lens, utt_lens, floors, response, res_lens)
    loss = criterion(outputs, target)
    loss.backward()
    optimizer.step()
    return loss.item()
        
def evaluate_batch(context, context_lens, utt_lens, floors, response, res_lens):
    """Return the loss of the module-level `model` on one batch, without training.

    The model is switched to eval mode and the forward pass runs under
    `torch.no_grad()`, since only the scalar loss value is returned.
    """
    model.eval()
    with torch.no_grad():
        target, outputs = model(context, context_lens, utt_lens, floors, response, res_lens)
        loss = criterion(outputs, target)
    return loss.item()

def valid(valid_loader):
    """Average the loss of the module-level `model` over `valid_loader`.

    Batches are pulled with `valid_loader.next_batch()` until it returns None.
    The summed batch losses are divided by `valid_loader.num_batch`.
    """
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        batch = valid_loader.next_batch()
        while batch is not None:  # None marks the end of the epoch
            context, context_lens, utt_lens, floors, _, _, _, response, res_lens, _ = batch
            # remove the sos token in the context and reduce the context length
            context = context[:, :, 1:]
            utt_lens = utt_lens - 1
            context = gVar(context)
            context_lens = gVar(context_lens)
            utt_lens = gVar(utt_lens)
            floors = gData(floors)
            response = gVar(response)
            res_lens = gVar(res_lens)
            target, outputs = model(context, context_lens, utt_lens, floors, response, res_lens)
            loss_sum += float(criterion(outputs, target).item())
            batch = valid_loader.next_batch()
    return loss_sum / valid_loader.num_batch

def sample(context, context_lens, utt_lens, floors, repeat, SOS_tok, EOS_tok):
    """Draw `repeat` stochastic responses for one encoded context.

    The context encoding is expanded to `repeat` rows and decoded with the
    decoder's "sample" mode for `config['maxlen']` steps. Uses the module-level
    `model` and `config`.

    NOTE(review): a later cell in this notebook defines another `sample` with
    the same signature; whichever cell ran last is the one in effect.

    Returns:
        (sample_words, sample_lens) as returned by `model.decoder.sampling`.
    """
    model.eval()
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        c = model.context_encoder(context, context_lens, utt_lens, floors)
        c_repeated = c.expand(repeat, -1)
        sample_words, sample_lens = model.decoder.sampling(c_repeated,
                                                           None, config['maxlen'], SOS_tok, EOS_tok, "sample")
    return sample_words, sample_lens
    
    
def evaluate(model, metrics, test_loader, vocab, ivocab, repeat):
    """Sample responses for test batches and report BLEU / BOW / distinct metrics.

    For each batch pulled from `test_loader.next_batch()` the function decodes
    the first reference response to tokens, generates responses with the
    module-level `sample(...)`, and scores them with `metrics.sim_bleu`,
    `metrics.sim_bow` and `metrics.div_distinct`. The per-batch scores are
    averaged, printed as a report, and returned.

    Args:
        model: not referenced in the body (`sample` uses the global `model`).
        metrics: object providing `sim_bleu`, `sim_bow` and `div_distinct`.
        test_loader: loader whose `next_batch()` returns a 10-tuple or None.
        vocab: token -> index mapping; must contain "<s>" and "</s>".
        ivocab: not referenced in the body.
        repeat: forwarded to `sample`.

    Returns:
        (recall_bleu, prec_bleu, bow_extrema, bow_avg, bow_greedy, intra_dist1,
        intra_dist2, avg_len, inter_dist1, inter_dist2). F1 is printed but not
        returned.

    NOTE(review): the loop ends with an unconditional `break`, so only the
    first batch is ever scored — confirm whether that is intentional.
    NOTE(review): the timing line reads the global `epoch_begin`, which is set
    by the training-loop cell; calling this function before that cell has run
    raises NameError.
    """
    
    recall_bleus, prec_bleus, bows_extrema, bows_avg, bows_greedy, intra_dist1s, intra_dist2s, avg_lens, inter_dist1s, inter_dist2s\
        = [], [], [], [], [], [], [], [], [], []
    local_t = 0  # batch counter; only used by the commented-out f_eval logging
    while True:
        batch = test_loader.next_batch()
        if batch is None:
            break
        local_t += 1 
        context, context_lens, utt_lens, floors,_,_,_,response,res_lens,_ = batch   
        context, utt_lens = context[:,:,1:], utt_lens-1 # remove the sos token in the context and reduce the context length
#         f_eval.write("Batch %d \n" % (local_t))# print the context
        # Walk over the context utterances starting 5 before the end of the first dialogue.
        start = np.maximum(0, context_lens[0]-5)
        for t_id in range(start, context.shape[1], 1):
            # context_str is only consumed by the commented-out logging below.
            context_str = indexes2sent(context[0, t_id], vocab, vocab["</s>"], 0)
#             f_eval.write("Context %d-%d: %s\n" % (t_id, floors[0, t_id], context_str))
        # print the true outputs    
        # The reference is decoded with "<s>" as the ignored token.
        ref_str, _ = indexes2sent(response[0], vocab, vocab["</s>"], vocab["<s>"])
        ref_tokens = ref_str.split(' ')
#         f_eval.write("Target >> %s\n" % (ref_str.replace(" ' ", "'")))
        
        context, context_lens, utt_lens, floors = gVar(context), gVar(context_lens), gVar(utt_lens), gData(floors)
        sample_words, sample_lens = sample(context, context_lens, utt_lens, floors, repeat, vocab["<s>"], vocab["</s>"])
        # nparray: [repeat x seq_len]
        pred_sents, _ = indexes2sent(sample_words, vocab, vocab["</s>"], 0)
        pred_tokens = [sent.split(' ') for sent in pred_sents]
#         for r_id, pred_sent in enumerate(pred_sents):
#             f_eval.write("Sample %d >> %s\n" % (r_id, pred_sent.replace(" ' ", "'")))
        # sim_bleu returns (max, avg); they are recorded as recall and precision BLEU.
        max_bleu, avg_bleu = metrics.sim_bleu(pred_tokens, ref_tokens)
        recall_bleus.append(max_bleu)
        prec_bleus.append(avg_bleu)
        
        # The reference is passed without its first token and with length res_lens-2.
        bow_extrema, bow_avg, bow_greedy = metrics.sim_bow(sample_words, sample_lens, response[:,1:], res_lens-2)
        bows_extrema.append(bow_extrema)
        bows_avg.append(bow_avg)
        bows_greedy.append(bow_greedy)
        
        intra_dist1, intra_dist2, inter_dist1, inter_dist2 = metrics.div_distinct(sample_words, sample_lens)
        intra_dist1s.append(intra_dist1)
        intra_dist2s.append(intra_dist2)
        avg_lens.append(np.mean(sample_lens))
        inter_dist1s.append(inter_dist1)
        inter_dist2s.append(inter_dist2)
        # NOTE(review): unconditional exit after the first batch.
        break
    recall_bleu = float(np.mean(recall_bleus))
    prec_bleu = float(np.mean(prec_bleus))
    # Harmonic mean of precision and recall BLEU; the small constant avoids division by zero.
    f1 = 2*(prec_bleu*recall_bleu) / (prec_bleu+recall_bleu+10e-12)
    bow_extrema = float(np.mean(bows_extrema))
    bow_avg = float(np.mean(bows_avg))
    bow_greedy=float(np.mean(bows_greedy))
    intra_dist1=float(np.mean(intra_dist1s))
    intra_dist2=float(np.mean(intra_dist2s))
    avg_len=float(np.mean(avg_lens))
    inter_dist1=float(np.mean(inter_dist1s))
    inter_dist2=float(np.mean(inter_dist2s))
    report = "Avg recall BLEU %f, avg precision BLEU %f, F1 %f, bow_extrema %f, bow_avg %f, bow_greedy %f,\
    intra_dist1 %f, intra_dist2 %f, avg_len %f, inter_dist1 %f, inter_dist2 %f (only 1 ref, not final results)" \
    % (recall_bleu, prec_bleu, f1, bow_extrema, bow_avg, bow_greedy, intra_dist1, intra_dist2, avg_len, inter_dist1, inter_dist2)
    print(report)
    # Elapsed time measured from the global epoch_begin (not a parameter of this function).
    print(' time: %.1f s'%(time()-epoch_begin))
#     f_eval.write(report + "\n")
    print("Done testing")
    return recall_bleu, prec_bleu, bow_extrema, bow_avg, bow_greedy, intra_dist1, intra_dist2, avg_len, inter_dist1, inter_dist2


In [11]:
# Load the SWDA corpus from ../datasets/SWDA/ with word vectors read from
# ../datasets/glove.twitter.27B.200d.txt (dimension = config['emb_size']).
# NOTE(review): the next cell (also numbered In [11]) assigns the same globals
# for the DailyDial corpus, so whichever of the two cells ran last wins.
corpus = getattr(data, 'SWDA'+'Corpus')('../datasets/SWDA/', wordvec_path='../datasets/'+'glove.twitter.27B.200d.txt', wordvec_dim=config['emb_size'])
dials = corpus.get_dialogs()
metas = corpus.get_metas()
# NOTE(review): the names are swapped relative to the corpus attributes:
# `vocab` holds corpus.ivocab and `ivocab` holds corpus.vocab. Later cells index
# `vocab` by token string (e.g. vocab["</s>"]), so it is used as token -> index.
vocab = corpus.ivocab
ivocab = corpus.vocab
n_tokens = len(ivocab)  # vocabulary size handed to the model constructor
# Split dialogs and their metadata by partition name.
train_dial, valid_dial, test_dial = dials.get("train"), dials.get("valid"), dials.get("test")
train_meta, valid_meta, test_meta = metas.get("train"), metas.get("valid"), metas.get("test")
# One loader per partition; the last argument is config['maxlen'].
train_loader = getattr(data, 'SWDA'+'DataLoader')("Train", train_dial, train_meta, config['maxlen'])
valid_loader = getattr(data, 'SWDA'+'DataLoader')("Valid", valid_dial, valid_meta, config['maxlen'])
test_loader = getattr(data, 'SWDA'+'DataLoader')("Test", test_dial, test_meta, config['maxlen'])

Max utt len 96, mean utt len 14.69
Max utt len 75, mean utt len 15.06
Max utt len 74, mean utt len 15.39
Load corpus with train size 3, valid size 3, test size 3 raw vocab size 24497 vocab size 10000 at cut_off 4 OOV rate 0.008035
<d> index 143
<sil> index -1
67 topics in train data
['statement-non-opinion', 'acknowledge_(backchannel)', 'statement-opinion', 'abandoned_or_turn-exit/uninterpretable', 'yes-no-question', 'agree/accept', 'appreciation', 'wh-question', 'backchannel_in_question_form', 'yes_answers', 'conventional-closing', 'response_acknowledgement', 'open-question', 'no_answers', 'affirmative_non-yes_answers', 'declarative_yes-no-question', 'summarize/reformulate', 'other', 'action-directive', 'rhetorical-questions', 'conventional-opening', 'collaborative_completion', 'signal-non-understanding', 'or-clause', 'hold_before_answer/agreement', 'quotation', 'negative_non-no_answers', 'self-talk', 'apology', 'dispreferred_answers', 'offers,_options_commits', 'other_answers', 'reje

In [11]:
# Load the DailyDial corpus from ../datasets/DailyDial/ with word vectors read
# from ../datasets/glove.twitter.27B.200d.txt (dimension = config['emb_size']).
# NOTE(review): this cell reassigns the same globals as the SWDA cell above it
# (both are numbered In [11]); run only the one for the dataset you want.
corpus = getattr(data, 'DailyDial'+'Corpus')('../datasets/DailyDial/', wordvec_path='../datasets/'+'glove.twitter.27B.200d.txt', wordvec_dim=config['emb_size'])
dials = corpus.get_dialogs()
metas = corpus.get_metas()
# NOTE(review): the names are swapped relative to the corpus attributes:
# `vocab` holds corpus.ivocab and `ivocab` holds corpus.vocab. Later cells index
# `vocab` by token string (e.g. vocab["</s>"]), so it is used as token -> index.
vocab = corpus.ivocab
ivocab = corpus.vocab
n_tokens = len(ivocab)  # vocabulary size handed to the model constructor
# Split dialogs and their metadata by partition name.
train_dial, valid_dial, test_dial = dials.get("train"), dials.get("valid"), dials.get("test")
train_meta, valid_meta, test_meta = metas.get("train"), metas.get("valid"), metas.get("test")
# One loader per partition; the last argument is config['maxlen'].
train_loader = getattr(data, 'DailyDial'+'DataLoader')("Train", train_dial, train_meta, config['maxlen'])
valid_loader = getattr(data, 'DailyDial'+'DataLoader')("Valid", valid_dial, valid_meta, config['maxlen'])
test_loader = getattr(data, 'DailyDial'+'DataLoader')("Test", test_dial, test_meta, config['maxlen'])


Max utt len 296, mean utt len 16.48
Max utt len 174, mean utt len 16.37
Max utt len 214, mean utt len 16.68
Load corpus with train size 2, valid size 2, test size 2 raw vocab size 17716 vocab size 10000 at cut_off 2 OOV rate 0.006757
<d> index 21
<sil> index -1
word2vec cannot cover 0.032194 vocab
Done loading corpus
Max len 36 and min len 3 and avg len 8.840439
Max len 32 and min len 3 and avg len 9.069000
Max len 27 and min len 3 and avg len 8.740000


In [12]:
# Build the metric helper, the model, the loss and the optimizer as module-level
# globals; the train/valid/sample functions defined earlier read them by name.
metrics=Metrics(corpus.word2vec)
model = HRED(config, n_tokens)
if corpus.word2vec is not None:
    print("Loaded word2vec")
    # Initialise the shared embedding table from the corpus word vectors.
    model.embedder.weight.data.copy_(torch.from_numpy(corpus.word2vec))
    # Row 0 is zeroed again after the copy; index 0 is the model's default padding index.
    model.embedder.weight.data[0].fill_(0)
model.to(DEVICE)
# Optimise only parameters that require gradients.
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(parameters, lr=config['lr'])

Loaded word2vec


In [None]:
# Main training loop: for every epoch, train on all batches of train_loader,
# print a running figure every `print_every` batches, then call evaluate() on
# the test loader.
model.zero_grad()
print_every = 100
# best_state / max_metric are only used by the commented-out checkpointing code below.
best_state = None
max_metric = 0
for epoch in range(config['epochs']):
    print('Epoch: ', epoch+1)
    # NOTE(review): the first argument is hard-coded to 128 while
    # config['batch_size'] is 30 — presumably the batch size; confirm which is intended.
    train_loader.epoch_init(128, config['diaglen'], 1, shuffle=True)
    n_iters=train_loader.num_batch
    total_loss = 0.0
    # epoch_begin is also read as a global by evaluate() for its timing line.
    epoch_begin = time()
    batch_count = 0
    batch_begin_time = time()
    while True:
        # NOTE(review): loss_records is reset on every iteration and never read.
        loss_records=[]
        batch = train_loader.next_batch()
        if batch is None: # end of epoch
            break
        context, context_lens, utt_lens, floors,_,_,_,response,res_lens,_ = batch
        context, utt_lens = context[:,:,1:], utt_lens-1 # remove the sos token in the context and reduce the context length
        context, context_lens, utt_lens, floors, response, res_lens\
                = gVar(context), gVar(context_lens), gVar(utt_lens), gData(floors), gVar(response), gVar(res_lens)
        loss_batch = train(context, context_lens, utt_lens, floors, response, res_lens)
        total_loss += float(loss_batch)
        batch_count += 1
        if batch_count % print_every == 0:
            # The value printed after "loss:" is exp(mean loss over the last
            # print_every batches), i.e. a perplexity rather than the raw loss.
            print_flush('[%d %d] loss: %.6f time: %.1f s' %
                  (epoch + 1, batch_count, np.exp(total_loss / print_every), time() - batch_begin_time))
            total_loss = 0.0
            batch_begin_time = time()
#     scheduler.step()
#     print_flush("Evaluating....")
#     valid_loader.epoch_init(32, config['diaglen'], 1, shuffle=False)
#     loss_valid = valid(valid_loader)
#     valid_result.append(F1)
#     print_flush('*'*60)
#     print_flush('[epoch %d]. loss: %.6f time: %.1f s'%(epoch+1, np.exp(loss_valid), time()-epoch_begin))
#     print_flush('*'*60)
    print_flush("testing....")
    test_loader.epoch_init(1, config['diaglen'], 1, shuffle=False)
#     loss_valid = valid(test_loader)
#     print_flush('*'*60)
#     print_flush('[epoch %d]. loss: %.6f time: %.1f s'%(epoch+1, np.exp(loss_valid), time()-epoch_begin))
#     print_flush('*'*60)
    # NOTE(review): this call matches the 6-argument evaluate() defined before
    # this cell; the evaluate() redefined in a later cell also requires f_eval,
    # so this line fails if that later definition is the one in effect.
    recall_bleu, prec_bleu, bow_extrema, bow_avg, bow_greedy, intra_dist1, intra_dist2, avg_len, inter_dist1, inter_dist2\
     =evaluate(model, metrics, test_loader, vocab, ivocab, repeat=10)
    epoch_begin = time()
#     if F1 > max_metric:
#         best_state = model.state_dict()
#         max_metric = F1
#         print_flush("save model...")
#         torch.save(best_state, '../datasets/models/baseline_LSTM.pth')
#     epoch_begin = time()
#     if training_termination(valid_result):
#         print_flush("early stop at [%d] epoch!" % (epoch+1))
#         break


In [13]:
# Use the same dataset size as dialog_doublegan
def valid_small(valid_loader):
    """Average the loss of the module-level `model` over a capped number of batches.

    Batches are pulled with `valid_loader.next_batch()`. Each processed batch
    adds 20 to a running example counter, and the loop stops once a fetched
    batch is None or the counter has reached 1500. The result is the summed
    batch loss divided by the number of processed batches.
    """
    model.eval()
    loss_sum = 0.0
    seen_examples = 0
    n_batches = 0
    with torch.no_grad():
        batch = valid_loader.next_batch()
        while batch is not None and seen_examples < 1500:
            seen_examples += 20
            n_batches += 1
            context, context_lens, utt_lens, floors, _, _, _, response, res_lens, _ = batch
            # remove the sos token in the context and reduce the context length
            context = context[:, :, 1:]
            utt_lens = utt_lens - 1
            context = gVar(context)
            context_lens = gVar(context_lens)
            utt_lens = gVar(utt_lens)
            floors = gData(floors)
            response = gVar(response)
            res_lens = gVar(res_lens)
            target, outputs = model(context, context_lens, utt_lens, floors, response, res_lens)
            loss_sum += float(criterion(outputs, target).item())
            batch = valid_loader.next_batch()
        return loss_sum / n_batches
    
def sample(context, context_lens, utt_lens, floors, repeat, SOS_tok, EOS_tok):
    """Greedily decode one response per encoded context.

    The context is encoded with the module-level `model` and decoded with the
    decoder's "greedy" mode for `config['maxlen']` steps. `repeat` is accepted
    for signature compatibility but is not used here.

    Returns:
        (sample_words, sample_lens) as returned by `model.decoder.sampling`.
    """
    model.eval()
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        c = model.context_encoder(context, context_lens, utt_lens, floors)
        sample_words, sample_lens = model.decoder.sampling(c, None, config['maxlen'], SOS_tok, EOS_tok, "greedy")
    return sample_words, sample_lens

def evaluate(model, metrics, test_loader, vocab, ivocab, f_eval, repeat):
    """Generate a response for every test batch, log it to ``f_eval`` and report metrics.

    The loader is re-initialised with batch size 1, so each iteration handles a
    single dialogue. For every dialogue the last few context utterances, the
    reference response and the generated response are written to ``f_eval``;
    BLEU, bag-of-words similarity and length statistics are accumulated, and
    the distinct-n diversity metrics are computed once over all generations.

    Relies on the notebook-level names ``config``, ``sample``, ``indexes2sent``,
    ``gVar`` and ``gData``.

    Args:
        model: not referenced directly in this function; generation goes
            through ``sample``.
        metrics: object providing ``sim_bleu``, ``sim_bleu1_4``, ``sim_bow`` and
            ``div_distinct``.
        test_loader: loader exposing ``epoch_init`` and ``next_batch`` (which
            returns ``None`` at the end of the data).
        vocab: mapping that contains at least the ``"<s>"`` and ``"</s>"`` keys.
        ivocab: not referenced in this function.
        f_eval: writable text stream receiving the per-dialogue log and the
            final report line. It is not closed here.
        repeat: forwarded to ``sample``.

    Returns:
        tuple: ``(recall_bleu, bow_extrema, bow_avg, bow_greedy, inter_dist1,
        inter_dist2, avg_len)``.
    """
    recall_bleus, bows_extrema, bows_avg, bows_greedy, avg_lens = [], [], [], [], []
    bleu1_4s = []
    local_t = 0
    test_loader.epoch_init(1, config['diaglen'], 1, shuffle=False)
    valid_count = 0
    begin_time = time()
    all_generated_sentences = []
    all_generated_lens = []
    while True:
        batch = test_loader.next_batch()
        if batch is None:
            break
        valid_count += 1
        local_t += 1
        context, context_lens, utt_lens, floors, _, _, _, response, res_lens, _ = batch
        # remove the sos token in the context and reduce the context length
        context, utt_lens = context[:, :, 1:], utt_lens - 1
        f_eval.write("Batch %d \n" % (local_t))
        # print the context, starting at most 5 utterances before context_lens[0]
        start = np.maximum(0, context_lens[0] - 5)
        for t_id in range(start, context.shape[1], 1):
            # indexes2sent returns (sentence, length); only the sentence belongs in the log.
            context_str, _ = indexes2sent(context[0, t_id], vocab, vocab["</s>"], 0)
            f_eval.write("Context %d-%d: %s\n" % (t_id, floors[0, t_id], context_str))
        # print the true outputs
        ref_str, _ = indexes2sent(response[0], vocab, vocab["</s>"], vocab["<s>"])
        ref_tokens = ref_str.split(' ')
        f_eval.write("Target >> %s\n" % (ref_str.replace(" ' ", "'")))
        context, context_lens, utt_lens, floors = gVar(context), gVar(context_lens), gVar(utt_lens), gData(floors)
        # Generation is inference only, so keep autograd switched off.
        with torch.no_grad():
            sample_words, sample_lens = sample(context, context_lens, utt_lens, floors, repeat, vocab["<s>"], vocab["</s>"])
        # Keep every generated response; they are used to compute the diversity metrics below.
        all_generated_sentences.append(sample_words[0].tolist())
        all_generated_lens.append(sample_lens[0].tolist())
        pred_sents, _ = indexes2sent(sample_words, vocab, vocab["</s>"], 0)
        if valid_count % 300 == 0:
            print('true response: ', ref_str)
            print('generate response: ', pred_sents[0])
        pred_tokens = [sent.split(' ') for sent in pred_sents]
        for pred_sent in pred_sents:
            f_eval.write("Generate >> %s\n" % (pred_sent.replace(" ' ", "'")))
        max_bleu, _ = metrics.sim_bleu(pred_tokens, ref_tokens)
        recall_bleus.append(max_bleu)
        bleu1_4s.append(metrics.sim_bleu1_4(pred_tokens[0], ref_tokens))
        bow_extrema, bow_avg, bow_greedy = metrics.sim_bow(sample_words, sample_lens, response[:, 1:], res_lens - 2)
        bows_extrema.append(bow_extrema)
        bows_avg.append(bow_avg)
        bows_greedy.append(bow_greedy)
        avg_lens.append(np.mean(sample_lens))
        f_eval.write("\n")
    recall_bleu = float(np.mean(recall_bleus))
    bleu1_4 = np.mean(bleu1_4s, 0)
    bow_extrema = float(np.mean(bows_extrema))
    bow_avg = float(np.mean(bows_avg))
    bow_greedy = float(np.mean(bows_greedy))
    avg_len = float(np.mean(avg_lens))
    all_generated_sentences = np.array(all_generated_sentences)
    all_generated_lens = np.array(all_generated_lens)
    # Diversity is computed once over all generated responses, not per batch.
    intra_dist1, intra_dist2, inter_dist1, inter_dist2 = metrics.div_distinct(all_generated_sentences, all_generated_lens)
    report = "BLEU1 %f, BLEU2 %f, BLEU3 %f, BLEU4 %f, inter_dist1 %f, inter_dist2 %f avg_len %f" % (bleu1_4[0], bleu1_4[1], bleu1_4[2], bleu1_4[3], inter_dist1, inter_dist2, avg_len)
    f_eval.write(report + "\n")
    print(report)
    print(' time: %.1f s'%(time()-begin_time))
    print("Done testing")
    return recall_bleu, bow_extrema, bow_avg, bow_greedy, inter_dist1, inter_dist2, avg_len

In [23]:
# Build the metrics helper, the HRED model, the loss and the optimizer, then train
# for 8 epochs. After every epoch a capped validation loss is printed; from the
# 5th epoch onwards the test set is also decoded and logged to a per-epoch file.
metrics=Metrics(corpus.word2vec)
model = HRED(config, n_tokens)
if corpus.word2vec is not None:
    print("Loaded word2vec")
    model.embedder.weight.data.copy_(torch.from_numpy(corpus.word2vec))
    # zero the embedding row at index 0
    model.embedder.weight.data[0].fill_(0)
model.to(DEVICE)
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(parameters, lr=config['lr'])

model.zero_grad()
print_every = 50
best_state = None
max_metric = 0
for epoch in range(8):
    print('Epoch: ', epoch+1)
    train_loader.epoch_init(32, config['diaglen'], 1, shuffle=True)
    n_iters=train_loader.num_batch
    total_loss = 0.0
    epoch_begin = time()
    batch_count = 0
    batch_begin_time = time()
    total_train_batch = 0 # number of training samples seen
    total_valid_batch = 0 # number of test samples seen
    # Buffers for the gradients of the generator's top layer, the generator's
    # bottom layer and the discriminator's top layer during training.
    train_grad_G_top_layer = []
    train_grad_G_bottom_layer = []
    train_grad_D_top_layer = []
    while True:
        loss_records=[]
        batch = train_loader.next_batch()
        total_train_batch += 32
        if batch is None: # end of epoch
            break
        context, context_lens, utt_lens, floors,_,_,_,response,res_lens,_ = batch
        context, utt_lens = context[:,:,1:], utt_lens-1 # remove the sos token in the context and reduce the context length
        context, context_lens, utt_lens, floors, response, res_lens\
                = gVar(context), gVar(context_lens), gVar(utt_lens), gData(floors), gVar(response), gVar(res_lens)
        loss_batch = train(context, context_lens, utt_lens, floors, response, res_lens)
        total_loss += float(loss_batch)
        batch_count += 1
        if batch_count % print_every == 0:
            # The printed "loss" is exp(mean loss over the last print_every batches).
            print_flush('[%d %d] loss: %.6f time: %.1f s' %
                  (epoch + 1, batch_count, np.exp(total_loss / print_every), time() - batch_begin_time))
            total_loss = 0.0
            batch_begin_time = time()
    print_flush("Evaluating....")
    valid_loader.epoch_init(20, config['diaglen'], 1, shuffle=False)
    loss_valid = valid_small(valid_loader)
    print_flush('*'*60)
    print_flush('[epoch %d]. loss: %.6f time: %.1f s'%(epoch+1, np.exp(loss_valid), time()-epoch_begin))
    print_flush('*'*60)
    if (epoch+1) > 4:
        # Open the log in a context manager so it is flushed and closed even if
        # evaluation raises; previously the handle was never closed.
        with open("../result/{}/{}/epoch{}.txt".format('HRED', 'SWDA', epoch), "w") as f_eval:
            recall_bleu, bow_extrema, bow_avg, bow_greedy, inter_dist1, inter_dist2, avg_len\
             =evaluate(model, metrics, test_loader, vocab, ivocab, f_eval, repeat=10)
    epoch_begin = time()

Loaded word2vec
Epoch:  1
Train begins with 6398 batches with 12 left over samples
[1 50] loss: 452.534842 time: 2.8 s
[1 100] loss: 161.236787 time: 2.7 s
[1 150] loss: 116.181163 time: 2.8 s
[1 200] loss: 106.455213 time: 2.7 s
[1 250] loss: 99.108374 time: 2.8 s
[1 300] loss: 92.354235 time: 2.8 s
[1 350] loss: 79.116544 time: 2.8 s
[1 400] loss: 73.616033 time: 2.8 s
[1 450] loss: 72.990612 time: 2.8 s
[1 500] loss: 79.281011 time: 2.8 s
[1 550] loss: 69.700417 time: 2.8 s
[1 600] loss: 65.322941 time: 2.8 s
[1 650] loss: 65.763151 time: 2.7 s
[1 700] loss: 63.126728 time: 2.8 s
[1 750] loss: 62.846444 time: 2.8 s
[1 800] loss: 59.764124 time: 2.7 s
[1 850] loss: 58.628815 time: 2.7 s
[1 900] loss: 56.008655 time: 2.8 s
[1 950] loss: 58.555644 time: 2.8 s
[1 1000] loss: 57.143383 time: 2.8 s
[1 1050] loss: 58.728644 time: 2.7 s
[1 1100] loss: 64.541011 time: 2.7 s
[1 1150] loss: 52.586404 time: 2.8 s
[1 1200] loss: 51.025980 time: 2.7 s
[1 1250] loss: 51.379712 time: 2.8 s
[1 1300]

[2 4350] loss: 31.904695 time: 2.8 s
[2 4400] loss: 30.326734 time: 2.8 s
[2 4450] loss: 31.145038 time: 2.8 s
[2 4500] loss: 33.319513 time: 2.8 s
[2 4550] loss: 33.140694 time: 2.8 s
[2 4600] loss: 33.418467 time: 2.8 s
[2 4650] loss: 32.633628 time: 2.7 s
[2 4700] loss: 31.822360 time: 2.8 s
[2 4750] loss: 32.708029 time: 2.8 s
[2 4800] loss: 30.888971 time: 2.7 s
[2 4850] loss: 30.637003 time: 2.8 s
[2 4900] loss: 36.109482 time: 2.7 s
[2 4950] loss: 31.585760 time: 2.7 s
[2 5000] loss: 35.341996 time: 2.8 s
[2 5050] loss: 32.848950 time: 2.8 s
[2 5100] loss: 28.462257 time: 2.7 s
[2 5150] loss: 29.766338 time: 2.8 s
[2 5200] loss: 32.406430 time: 2.7 s
[2 5250] loss: 32.849897 time: 2.7 s
[2 5300] loss: 33.453780 time: 2.8 s
[2 5350] loss: 32.110377 time: 2.7 s
[2 5400] loss: 31.941167 time: 2.7 s
[2 5450] loss: 30.539737 time: 2.8 s
[2 5500] loss: 33.609455 time: 2.8 s
[2 5550] loss: 34.403154 time: 2.7 s
[2 5600] loss: 32.703543 time: 2.8 s
[2 5650] loss: 31.661435 time: 2.7 s
[

[4 2000] loss: 24.075985 time: 1.4 s
[4 2050] loss: 22.353302 time: 1.4 s
[4 2100] loss: 22.302833 time: 1.4 s
[4 2150] loss: 22.885731 time: 1.4 s
[4 2200] loss: 26.067880 time: 1.4 s
[4 2250] loss: 25.981470 time: 1.4 s
[4 2300] loss: 26.263950 time: 1.4 s
[4 2350] loss: 25.008963 time: 1.4 s
[4 2400] loss: 24.483451 time: 1.4 s
[4 2450] loss: 25.030412 time: 1.4 s
[4 2500] loss: 26.683007 time: 1.4 s
[4 2550] loss: 23.777622 time: 2.0 s
[4 2600] loss: 28.691539 time: 2.5 s
[4 2650] loss: 25.081116 time: 1.4 s
[4 2700] loss: 25.138328 time: 1.4 s
[4 2750] loss: 24.279999 time: 1.4 s
[4 2800] loss: 24.863160 time: 1.7 s
[4 2850] loss: 24.120879 time: 2.8 s
[4 2900] loss: 28.106052 time: 2.8 s
[4 2950] loss: 27.076857 time: 2.8 s
[4 3000] loss: 25.362134 time: 2.8 s
[4 3050] loss: 23.012313 time: 2.8 s
[4 3100] loss: 23.404455 time: 2.8 s
[4 3150] loss: 25.473163 time: 2.8 s
[4 3200] loss: 28.828741 time: 2.8 s
[4 3250] loss: 28.703936 time: 2.8 s
[4 3300] loss: 23.449074 time: 2.8 s
[

[5 6350] loss: 27.974790 time: 2.8 s
Evaluating....
Valid begins with 221 batches with 0 left over samples
************************************************************
[epoch 5]. loss: 36.687643 time: 355.7 s
************************************************************
Test begins with 5481 batches with 0 left over samples
true response:  and at this particular point you know taking appropriate course work but really nothing that ' s in out of the ordinary </s>
generate response:  and i think that ' s the key thing to do with the kids and the kids are in school and </s>
true response:  well i have two cats </s>
generate response:  well i ' m a little older than mine </s>
true response:  yeah </s>
generate response:  uh - huh </s>
true response:  <unk> of people ' s heart strings and making them cry and feel like they ' re doing somebody good by giving them </s>
generate response:  and i think that the jury may have been a little bit more difficult to do with the </s>
true response:  so

true response:  something like that </s>
generate response:  yeah </s>
true response:  yeah that ' s fun </s>
generate response:  uh - huh </s>
true response:  oh </s>
generate response:  uh - huh </s>
true response:  lay on those beaches over there </s>
generate response:  i ' m not going to be a <unk> </s>
true response:  uh - huh </s>
generate response:  uh - huh </s>
true response:  they before we got her everybody said oh they ' re just like a cat you know </s>
generate response:  yeah </s>
true response:  uh - huh </s>
generate response:  yeah </s>
true response:  we sort of like our roots and like to get back and have family reunions and all that </s>
generate response:  we ' re not really into the old age but i ' m not sure that it ' s not the same as it is </s>
true response:  but they </s>
generate response:  oh i know </s>
true response:  yeah well see that ' s the reason that we couldn ' t make really make them at first we were going to get a pick up truck </s>
generate res

true response:  you ' d think i ' d have a lot time but </s>
generate response:  um - hum </s>
true response:  and </s>
generate response:  yeah </s>
true response:  yeah i ' ve got a well my roses are on the west side of the house i asked my neighbors what they wanted to see outside their </s>
generate response:  yeah </s>
true response:  um - hum </s>
generate response:  um - hum </s>
BLEU1 0.450906, BLEU2 0.355227, BLEU3 0.289403, BLEU4 0.227789, inter_dist1 0.017134, inter_dist2 0.061858 avg_len 8.934136
 time: 163.6 s
Done testing
Epoch:  8
Train begins with 6398 batches with 12 left over samples
[8 50] loss: 16.530627 time: 1.9 s
[8 100] loss: 15.899976 time: 1.9 s
[8 150] loss: 16.644875 time: 1.9 s
[8 200] loss: 17.350121 time: 1.9 s
[8 250] loss: 16.172146 time: 2.0 s
[8 300] loss: 16.353236 time: 2.0 s
[8 350] loss: 15.261588 time: 2.0 s
[8 400] loss: 14.569209 time: 2.0 s
[8 450] loss: 14.957734 time: 1.9 s
[8 500] loss: 14.550823 time: 1.9 s
[8 550] loss: 17.860043 time: 2.