In [None]:


import os
import random
import numpy as np 
import torch
import sys

# set seed
# Seed each random number generator used below (Python, NumPy, PyTorch)
# with the same value so repeated runs start from the same state.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Request deterministic cuDNN kernels and disable the cuDNN autotuner,
# whose algorithm selection may otherwise differ from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from collections import deque

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.autograd import Variable

from movie_data_utils import get_movie_datasets

In [None]:
def regularization_loss_batch(z, percentage, mask=None):
    """
    Per-example regularizers for a rationale sequence.

    Inputs:
        z -- torch tensor, "binary" rationale, (batch_size, sequence_length)
        percentage -- target fraction of words to keep
        mask -- optional input mask with the same shape as z; when given, z is
                multiplied by it and each row's length is the row sum of the mask.
                Without a mask every position counts towards the length.
    Outputs:
        continuity_loss -- sum_i |z_i - z_{i+1}| / length, (batch_size,)
        sparsity_loss -- |sum_i z_i / length - percentage|, (batch_size,)
    """
    if mask is None:
        kept = z
        # a tensor of ones shaped like z, summed per row
        lengths = torch.sum(z - z + 1.0, dim=1)
    else:
        kept = z * mask
        lengths = torch.sum(mask, dim=1)

    # Right-hand neighbour of every position; the final column is paired with
    # itself, so it adds nothing to the transition count.
    right_neighbour = torch.cat([kept[:, 1:], kept[:, -1:]], dim=-1)
    transitions = torch.sum(torch.abs(kept - right_neighbour), dim=-1)
    kept_fraction = torch.sum(kept, dim=-1) / lengths

    continuity_loss = transitions / lengths  # (batch_size,)
    sparsity_loss = torch.abs(kept_fraction - percentage)  # (batch_size,)

    return continuity_loss, sparsity_loss

In [None]:
# a = list(range(9))
# print(a[2:])
# print(a[-2:])
# print(a[:-2])

# a = torch.rand(3,4)
# print(a)
# _, z_index = torch.topk(a, 2)
# print(z_index.size())
# print(z_index+1)
# print(torch.cat([z_index, z_index+1], dim=1))

In [None]:
from colored import fg, attr, bg

def display_example(vocab, x, z=None, threshold=0.05):
    """
    Print one example as text, colouring the tokens picked by a rationale.

    Inputs:
        vocab -- object whose ``itos`` maps a token index to its string
        x -- 1-D sequence of token indices
        z -- optional per-token scores aligned with x; tokens whose score is
             >= threshold are printed in colour. With z=None (the default)
             the text is printed without any highlighting.
        threshold -- minimum score for a token to be highlighted

    Tokens with index 1 are skipped (presumably the padding id -- confirm
    against the vocabulary). Output goes to sys.stdout, followed by a newline.
    """
    if z is None:
        # No rationale supplied: previously ``None >= threshold`` raised a
        # TypeError even though None is the declared default.
        highlight_flags = [False] * len(x)
    else:
        # apply threshold
        highlight_flags = z >= threshold

    for word_index, display_flag in zip(x, highlight_flags):
        if word_index == 1:
            continue
        word = vocab.itos[word_index]
        if display_flag:
            output_word = "%s %s%s" %(fg(1), word, attr(0))
            sys.stdout.write(output_word)
        else:
            sys.stdout.write(" " + word)
    print('')
    sys.stdout.flush()
    
def _eval_att_rationale(token_select, mask, z, ratio=0.15):
    _, z_index = torch.topk(token_select, int(mask.size(1) * ratio))
    pred_z = torch.zeros_like(mask)
    pred_z.scatter_(1,z_index,1)
    
    prec = torch.sum(pred_z * z, dim=1) / torch.sum(pred_z, dim=1)
    rec = torch.sum(pred_z * z, dim=1) / (torch.sum(z, dim=1) + 1e-6)
    
    return prec, rec, pred_z

def _eval_att_rationale_topk(token_select, mask, z, topk=10):
    _, z_index = torch.topk(token_select, topk)
    pred_z = torch.zeros_like(mask)
    pred_z.scatter_(1,z_index,1)
    
    prec = torch.sum(pred_z * z, dim=1) / torch.sum(pred_z, dim=1)
    rec = torch.sum(pred_z * z, dim=1) / (torch.sum(z, dim=1) + 1e-6)
    
    return prec, rec, pred_z

In [None]:

# Dataset locations are taken from the command line: <data_dir> <doc_dir>.
# NOTE(review): inside a notebook kernel sys.argv carries the kernel's own
# arguments -- confirm how this cell is meant to be launched.
if len(sys.argv) < 3:
    # Same exception type as the bare indexing raised before, now with a hint.
    raise IndexError("expected two command-line arguments: <data_dir> <doc_dir>")
data_dir = sys.argv[1]
doc_dir = sys.argv[2]
# Vocabulary plus train / dev / test datasets.
vocab, D_tr, D_dev, D_te = get_movie_datasets(data_dir, doc_dir, max_seq_len=2200, max_sent_num=128)



In [None]:
class SentenceA2R(nn.Module):
    """
    Bi-GRU classifier with a soft-attention head and a hard-rationale head.

    forward() encodes the input, turns the encoder states into token
    attention, derives a hard 0/1 rationale from that attention (currently
    the top-scoring bigrams), re-encodes the input with the embeddings of
    non-rationale tokens zeroed out, and classifies from both views.
    """

    def __init__(self, vocab, args):
        """
        Inputs:
            vocab -- object exposing ``vectors``, a (vocab_size, embedding_dim)
                     tensor of pretrained embeddings
            args -- namespace providing highlight_ratio, highlight_count,
                    exploration_rate, rnn_dim and num_classes
        """
        super().__init__()

        self.args = args
        self.highlight_ratio = args.highlight_ratio
        self.highlight_count = args.highlight_count
        self.exploration_rate = args.exploration_rate
        self.vocab_size = vocab.vectors.shape[0]
        self.embedding_dim = vocab.vectors.shape[1]

        self.lambda_dl = 0.
        self.lambda_linear = 1.0 # 0.8
        self.lambda_sparsity = 1.0

        # Word embedding from pretrained vectors (kept frozen)
        self.pred_embd = nn.Embedding.from_pretrained(vocab.vectors, freeze=True)

        # Token-level rationale classifier (serve as a head)
        self.attention_fc = nn.Linear(args.rnn_dim * 2, 1, bias=False)

        # Predictor RNN
        self.pred_gru = nn.GRU(self.embedding_dim, args.rnn_dim, batch_first=True, bidirectional=True)

        # Task predictors: one for the attention-pooled view, one for the hard-rationale view
        self.pred_fc = nn.Linear(args.rnn_dim * 2, args.num_classes)
        self.pred_fc_hard = nn.Linear(args.rnn_dim * 2, args.num_classes)

    def pred_vars(self):
        """
        Return the variables of the predictor.
        """
        params = list(self.pred_embd.parameters()) + list(
            self.pred_gru.parameters()) + list(
            self.pred_fc.parameters()) + list(
            self.pred_fc_hard.parameters()) + list(self.attention_fc.parameters())

        return params

    @staticmethod
    def _one_hot(indices, depth):
        """
        One-hot encode integer ``indices`` of shape (batch_size,) into a float
        tensor of shape (batch_size, depth).
        """
        return F.one_hot(indices, num_classes=depth).float()

    def _ngram_rationale(self, token_select, mask, n, k):
        """
        Select the k best windows of n consecutive tokens and mark their tokens.

        token_select -- (batch_size, seq_len) per-token scores
        mask -- (batch_size, seq_len); only its shape/dtype/device are used
        n -- window width (1 = single tokens)
        k -- number of windows to keep; clamped to the number of available
             windows so short inputs no longer make torch.topk raise
        Returns a float (batch_size, seq_len) 0/1 tensor. Windows may overlap.
        """
        num_windows = max(token_select.size(1) - n + 1, 0)

        # Score of a window = sum of the scores of its tokens, accumulated
        # left to right.
        window_scores = token_select[:, :num_windows]
        for offset in range(1, n):
            window_scores = window_scores + token_select[:, offset:offset + num_windows]

        k = max(0, min(int(k), num_windows))
        _, start_index = torch.topk(window_scores, k)

        # Every token of every selected window gets marked.
        covered_index = torch.cat([start_index + offset for offset in range(n)], dim=1)
        z_ = torch.zeros_like(mask)
        z_.scatter_(1, covered_index, 1)

        return z_.float()

    def _transform_token_rationale(self, token_select, mask):
        """
        Hard rationale made of the ``highlight_count`` highest-scoring tokens.
        token_select -- (batch_size, seq_len)
        mask -- (batch_size, seq_len)
        """
#         k = int(mask.size(1) * self.highlight_ratio)
        return self._ngram_rationale(token_select, mask, 1, int(self.highlight_count))

    def _transform_bigram_rationale(self, token_select, mask):
        """
        Hard rationale made of the ``highlight_count`` highest-scoring bigrams.
        token_select -- (batch_size, seq_len)
        mask -- (batch_size, seq_len)
        """
#         k = int(mask.size(1) * self.highlight_ratio / 3)
        return self._ngram_rationale(token_select, mask, 2, int(self.highlight_count))

    def _transform_trigram_rationale(self, token_select, mask):
        """
        Hard rationale made of the best trigrams; the number of trigrams is
        int(seq_len * highlight_ratio / 3).
        token_select -- (batch_size, seq_len)
        mask -- (batch_size, seq_len)
        """
        k = int(mask.size(1) * self.highlight_ratio / 3)
        return self._ngram_rationale(token_select, mask, 3, k)

    def _transform_5gram_rationale(self, token_select, mask):
        """
        Hard rationale made of the best 5-grams.
        token_select -- (batch_size, seq_len)
        mask -- (batch_size, seq_len)
        """
        # NOTE(review): the divisor is 3, identical to the trigram variant; for
        # 5-token windows one would expect 5. Left unchanged -- confirm intent.
        k = int(mask.size(1) * self.highlight_ratio / 3)
        return self._ngram_rationale(token_select, mask, 5, k)

    def _transform_sent_rationales_topk(self, token_select, mask, sent_mask, topk=3):
        """
        Returns a (batch_size, num_sentences) 0/1 selection of the ``topk``
        sentences carrying the most selection mass.
        token_select -- (batch_size, seq_len)
        mask -- (batch_size, seq_len)
        sent_mask -- (batch_size, seq_len, num_sentences)
        """
        num_sentences = sent_mask.shape[-1]

        # (batch_size, seq_len, num_sentences)
        token_select_per_sentence = token_select.unsqueeze(-1) * mask.unsqueeze(-1) * sent_mask

        # (batch_size, num_sentences)
        num_select_per_sentence = torch.sum(token_select_per_sentence, dim=1)
        sent_probs_ = F.softmax(num_select_per_sentence, dim=-1)

        # Never ask for more sentences than exist.
        topk = min(topk, num_sentences)
        _, sent_select = torch.topk(sent_probs_, topk, dim=-1)

        sent_rationale = self._one_hot(sent_select[:, 0], depth=num_sentences)
        for i in range(1, topk):
            sent_rationale = sent_rationale + self._one_hot(sent_select[:, i], depth=num_sentences)

        return sent_rationale

    def _encode(self, embs, lens, total_length):
        """
        Run the bidirectional GRU over padded embeddings.
        embs -- (batch_size, seq_len, embedding_dim)
        lens -- CPU int64 tensor of sequence lengths, (batch_size,)
        Returns (batch_size, total_length, 2 * rnn_dim), zero at padded steps.
        """
        packed = pack_padded_sequence(embs, lengths=lens, batch_first=True, enforce_sorted=False)
        packed_outs, _ = self.pred_gru(packed)
        return pad_packed_sequence(packed_outs, batch_first=True,
                                   total_length=total_length, padding_value=0.)[0]

    def forward(self, seq, mask, sent_mask, y):
        """
        Inputs:
            seq -- (batch_size, seq_length)
            mask -- (batch_size, seq_length), input mask
            sent_mask -- (batch_size, seq_length, num_sents); not used by the
                         current bigram rationale path
            y -- labels; not used inside forward
        Outputs:
            pred_logits -- (batch_size, num_classes), attention-pooled view
            pred_logits_sent -- (batch_size, num_classes), hard-rationale view
            token_probs -- (batch_size, seq_length), masked softmax attention
            rationale_select -- (batch_size, seq_length), detached 0/1 rationale
        """
        LARGE_NEG = -1e9

        max_seq_len = seq.shape[1]
        # pack_padded_sequence expects the lengths as an int64 tensor on the CPU,
        # whatever device and dtype the mask has.
        lens = mask.sum(dim=1).long().cpu() # (batch_size, )
        mask_ = mask.unsqueeze(-1) # (batch_size, seq_length, 1)

        pred_embs = self.pred_embd(seq) # (batch_size, seq_len, embedding_dim)
        # (batch_size, seq_len, rnn_size)
        pred_outs = self._encode(pred_embs, lens, max_seq_len)

        # (batch_size, seq_len)
        token_att_logits = self.attention_fc(pred_outs).squeeze(-1)

        # padded positions get a huge negative logit, i.e. ~0 probability
        token_probs = F.softmax(token_att_logits + (1.-mask)*LARGE_NEG, dim=-1) #(batch_size, seq_len)
        rationale_select = self._transform_bigram_rationale(token_probs, mask).detach()
#         rationale_select = self._transform_token_rationale(token_probs, mask).detach()

        # attention-weighted pooling of the encoder states
        rationale_ = token_probs.unsqueeze(-1) # (batch_size, seq_len, 1)
        pred_out = torch.sum(pred_outs * rationale_ * mask_, dim=1)

        # classification
        pred_logits = self.pred_fc(pred_out)

        # hard classifier: zero the embeddings outside the rationale and re-encode
        hard_rationale_ = rationale_select.unsqueeze(-1)
        pred_embs_sent_ = pred_embs * hard_rationale_
        # (batch_size, seq_len, rnn_size)
        pred_outs_sent = self._encode(pred_embs_sent_, lens, max_seq_len)

        # rows whose rationale holds more than one token keep their pooled features
        batch_mask = (torch.sum(rationale_select, dim=-1) > 1.).float().unsqueeze(-1) # (batch_size, 1)
        # push non-rationale positions far below any real activation so that
        # the max below only sees rationale positions
        pred_outs_sent = hard_rationale_ * pred_outs_sent + (1. - hard_rationale_) * LARGE_NEG
        # max pooling along seq direction
        pred_out_sent = torch.max(pred_outs_sent, dim=1)[0]

        # classification
        pred_out_sent = pred_out_sent * batch_mask
        pred_logits_sent = self.pred_fc_hard(pred_out_sent)

        return pred_logits, pred_logits_sent, token_probs, rationale_select
             

In [None]:
class Dummy:
    """Empty attribute container standing in for a parsed-arguments object."""


# Hyper-parameters, attached as attributes of a Dummy instance.
args = Dummy()
vars(args).update(
    rnn_dim=50,  # 100
    num_classes=2,
    gumbel_temp=1.0,  # 0.1
    lambda_smooth=1e-3,
    exploration_rate=0.2,
    highlight_ratio=0.2,  # alternative tried: 0.05
    highlight_count=80,
    l2_decay=1e-4,
)

# Build the model on the GPU and show its layers.
model = SentenceA2R(vocab, args).cuda()
print(model)

In [None]:
def _get_attention_continuity_loss(att, mask):
    mask_att = att * mask
    seq_lengths = torch.sum(mask, dim=1)

    mask_att_ = torch.cat([mask_att[:, 1:], mask_att[:, -1:]], dim=-1)
    continuity_loss = torch.sum(torch.abs(mask_att - mask_att_), dim=-1)
#     continuity_loss = torch.sum(torch.abs(mask_att - mask_att_), dim=-1) / seq_lengths #(batch_size,)
    continuity_loss = torch.mean(continuity_loss)
    
    return continuity_loss

def _get_top_attention_continuity_loss(att, z, mask):
    mask_att = att * mask
    mask_z = z * mask
    seq_lengths = torch.sum(mask, dim=1)

    mask_att_ = torch.cat([mask_att[:, 1:], mask_att[:, -1:]], dim=-1)
    mask_att2_ = torch.cat([mask_att[:, 0:1], mask_att[:, 0:-1]], dim=-1)
    
    continuity_loss_right_ = torch.sum(torch.abs(mask_att - mask_att_) * mask_z, dim=-1)
    continuity_loss_left_ = torch.sum(torch.abs(mask_att - mask_att2_) * mask_z, dim=-1)
    
    continuity_loss = continuity_loss_right_ + continuity_loss_left_
    
    continuity_loss = torch.mean(continuity_loss)
    
    return continuity_loss

num_epochs = 100

# A single optimizer over all predictor parameters.
# (earlier variants used a separate generator optimizer:
#  torch.optim.Adam(model.gen_vars() , lr=1e-3) or lr=1e-3 * 0.1)
pred_optimizer = torch.optim.Adam(model.pred_vars() , lr=1e-3)

# Train / dev / test loaders, batch size 20 (100 was used before);
# only the training loader shuffles.
D_tr_, D_dev_, D_te_ = (
    DataLoader(dataset, batch_size=20, shuffle=do_shuffle, num_workers=workers)
    for dataset, do_shuffle, workers in ((D_tr, True, 16), (D_dev, False, 4), (D_te, False, 4))
)

# number of training batches seen so far
counter = 0

# first epoch from which the hard-rationale statistics are reported
switch_epoch = 0

# bounded history buffer, started with a single 0. entry
queue_length = 200
history_rewards = deque([0.], maxlen=queue_length)

# Selection ratios at which the attention highlights are scored.
_EVAL_RATIOS = (0.01, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1)


def _evaluate_split(model, loader, split_name, hard_active, example_vocab=None):
    """
    Evaluate ``model`` on one data split and print the metrics.

    Inputs:
        model -- the network; called as model(x, mask, sent_mask, y) and
                 expected to return (logits, logits_sent, token_select,
                 rationale_select)
        loader -- iterable of batches with keys "x", "mask", "y",
                  "sent_mask" and "z"
        split_name -- label used in the printed accuracy line ("Dev"/"Test")
        hard_active -- whether to accumulate and print the hard-rationale
                       accuracy and the token sparsity
        example_vocab -- when given, the first example of the first batch is
                         printed with its hard rationale highlighted

    Callers are expected to have put the model in eval mode and to be inside
    torch.no_grad(). Batches are moved to the GPU here.
    """
    correct_pred = 0.
    correct_pred_sent = 0.
    num_sample = 0
    total_sparsity = 0.

    z_prec = {key: 0. for key in _EVAL_RATIOS}
    z_rec = {key: 0. for key in _EVAL_RATIOS}
    z_f1 = {key: 0. for key in _EVAL_RATIOS}

    z_prec_sent = 0.
    z_rec_sent = 0.
    z_f1_sent = 0.

    for i_batch, data in enumerate(loader):
        x = data["x"].cuda()
        mask = data["mask"].cuda()
        y = data["y"].cuda()
        sent_mask = data["sent_mask"].transpose(1,2).cuda().contiguous()
        z = data["z"].cuda()

        logits, logits_sent, token_select, rationale_select = model(x, mask, sent_mask, y)

        pred = torch.max(logits, 1)[1]
        correct_pred += (pred == y).sum().float()
        num_sample += y.shape[0]

        if hard_active:
            pred_sent = torch.max(logits_sent, 1)[1]
            correct_pred_sent += (pred_sent == y).sum().float()

            # fraction of real (unmasked) tokens that the rationale keeps
            mask_z = rationale_select * mask
            seq_lengths = torch.sum(mask, dim=1)
            sparsity = torch.sum(mask_z, dim=-1) / seq_lengths  #(batch_size,)
            total_sparsity += torch.sum(sparsity).cpu().item()

        # attention highlights scored at several selection ratios
        for key in _EVAL_RATIOS:
            z_res_p, z_res_r, _ = _eval_att_rationale(token_select, mask, z, ratio=key)
            z_prec[key] += torch.sum(z_res_p)
            z_rec[key] += torch.sum(z_res_r)
            z_f1[key] += torch.sum(z_res_p * z_res_r * 2 / (z_res_p + z_res_r + 1e-6))

        # the hard rationale the model actually used
        prec = torch.sum(rationale_select * z, dim=1) / torch.sum(rationale_select, dim=1)
        rec = torch.sum(rationale_select * z, dim=1) / (torch.sum(z, dim=1) + 1e-6)

        z_prec_sent += torch.sum(prec)
        z_rec_sent += torch.sum(rec)
        z_f1_sent += torch.sum(prec * rec * 2 / (prec + rec + 1e-6))

        if example_vocab is not None and i_batch == 0:
            display_example(example_vocab, x[0], rationale_select[0])

    if hard_active:
        print ("%s acc: %.4f sent-level: %.4f" % (split_name, correct_pred / num_sample,
                                                   correct_pred_sent / num_sample))
    else:
        print ("%s acc: %.4f" % (split_name, correct_pred / num_sample))

    for key in _EVAL_RATIOS:
        print("Top-%.4f: Highlight precision: %.4f recall: %.4f f1: %.4f" % (key,
                                                                  z_prec[key] / num_sample,
                                                                  z_rec[key] / num_sample,
                                                                  z_f1[key] / num_sample))
    print("Sent-level: Highlight precision: %.4f recall: %.4f f1: %.4f" % (z_prec_sent / num_sample,
                                                                  z_rec_sent / num_sample,
                                                                  z_f1_sent / num_sample))

    if hard_active:
        print ("Token sparisity: %.4f" % (total_sparsity / num_sample))


for i_epoch in range(num_epochs):

    print ("================")
    print ("epoch: %d" % i_epoch)
    print ("================")

    # From switch_epoch on, hard-rationale statistics are reported as well.
    hard_active = i_epoch >= switch_epoch

    model.train()

    for i_batch, data in enumerate(D_tr_):

        pred_optimizer.zero_grad()
        counter += 1

        x = data["x"].cuda()
        mask = data["mask"].cuda()
        y = data["y"].cuda()
        sent_mask = data["sent_mask"].transpose(1,2).cuda().contiguous()

        logits, logits_sent, token_select, rationale_select = model(x, mask, sent_mask, y)

        # Only reported below; not part of the optimised loss.
#         continuity_loss = _get_attention_continuity_loss(token_select, mask)
        continuity_loss = _get_top_attention_continuity_loss(token_select, rationale_select, mask)

        probs = torch.softmax(logits, dim=-1)
        probs_sent = torch.softmax(logits_sent, dim=-1)

        # Symmetric KL between the two heads. log_softmax gives the log-space
        # input directly and avoids log(0) when a probability underflows.
        consistency_loss = F.kl_div(F.log_softmax(logits, dim=-1),
                                    probs_sent, reduction='batchmean')
        consistency_loss += F.kl_div(F.log_softmax(logits_sent, dim=-1),
                                     probs, reduction='batchmean')

        ce_soft = F.cross_entropy(logits, y)
        ce_hard = F.cross_entropy(logits_sent, y)
        sup_loss = ce_soft + ce_hard + consistency_loss
#         sup_loss += continuity_loss * .01

        sup_loss.backward()
        pred_optimizer.step()

        pred = torch.max(logits, 1)[1]
        acc = (pred == y).sum().float() / y.shape[0]

        pred_sent = torch.max(logits_sent, 1)[1]
        acc_sent = (pred_sent == y).sum().float() / y.shape[0]

        if i_batch % 10 == 0:
            if not hard_active:
                # This branch used to reference an undefined name `loss`.
                print ("Train batch: %d, loss: %.4f, acc: %.4f." % (i_batch, sup_loss.item(), acc.item()))
            elif not (i_batch == 0 and i_epoch == 0):
                # the very first batch of training is not reported
                print ("Train batch: %d, sup loss: %.4f, consis_loss: %.4f, acc: %.4f, sent_acc: %.4f." % (i_batch,
                                                                                sup_loss.item(),
                                                                                consistency_loss.item(),
                                                                                acc.item(), acc_sent.item()))
                print(ce_soft.item(), ce_hard.item(),
                      consistency_loss.item(), continuity_loss.item())

    model.eval()

    with torch.no_grad():
        _evaluate_split(model, D_dev_, "Dev", hard_active)
        # the test pass also prints the first example with its rationale
        _evaluate_split(model, D_te_, "Test", hard_active, example_vocab=vocab)
