In [None]:
"""
This is sentence level Tao's Model.
"""

import os, sys
import random
import numpy as np 
import torch
import sys

# Reproducibility: seed Python's, NumPy's and PyTorch's global RNGs with the
# same value, and pin cuDNN to deterministic (non-autotuned) kernels.
seed = 0
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

from collections import deque

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.autograd import Variable

from beer_data_utils_neurips21 import get_beer_datasets_biased
# os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [None]:
def bao_regularization_loss_batch(z, percentage, mask=None):
    r"""
    Compute per-example regularization losses for a rationale sequence.
    Use Yujia's formulation.

    Inputs:
        z -- torch tensor, "binary" rationale, (batch_size, sequence_length)
        percentage -- the target fraction of words to keep
        mask -- optional (batch_size, sequence_length) tensor, 1 for real tokens
                and 0 for padding. Assumes the real tokens form a prefix of each
                row (TODO confirm for callers).
    Outputs:
        continuity_loss -- (batch_size,) \sum_{i} | z_{i} - z_{i+1} | over adjacent
                           real positions, divided by the sequence length
        sparsity_loss -- (batch_size,) |mean(z_{i}) - percentage| over real positions
    """
    if mask is not None:
        mask_z = z * mask
        seq_lengths = torch.sum(mask, dim=1)
        # A pair (i, i+1) is only counted when position i+1 is a real token, so
        # the step from the last real token into padding is not penalised.
        pair_mask = mask[:, 1:]
    else:
        mask_z = z
        # Full row width, in a floating dtype compatible with z.
        seq_lengths = torch.sum(z - z + 1.0, dim=1)
        pair_mask = None

    # |z_i - z_{i+1}| for every adjacent pair: (batch_size, sequence_length - 1)
    transitions = torch.abs(mask_z[:, 1:] - mask_z[:, :-1])
    if pair_mask is not None:
        transitions = transitions * pair_mask

    # Fully padded rows would otherwise divide 0 by 0.
    seq_lengths = seq_lengths.clamp(min=1)

    continuity_loss = torch.sum(transitions, dim=-1) / seq_lengths  # (batch_size,)
    sparsity_loss = torch.abs(torch.sum(mask_z, dim=-1) / seq_lengths - percentage)  # (batch_size,)

    return continuity_loss, sparsity_loss

In [None]:
from colored import fg, attr, bg

def display_example(vocab, x, z=None, threshold=0.05):
    """
    Print one tokenized example, colouring the words whose score reaches `threshold`.

    Inputs:
        vocab -- object whose `itos` maps a token index to its string
        x -- sequence of token indices (e.g. a 1-D tensor)
        z -- per-token scores aligned with `x`; when None, nothing is highlighted
        threshold -- words with z >= threshold are printed with colored's fg(1)
    """
    # Without scores there is nothing to highlight: print the plain sentence.
    condition = [False] * len(x) if z is None else (z >= threshold)
    for word_index, display_flag in zip(x, condition):
        # Token index 1 is skipped (presumably the padding token -- confirm against the vocab).
        if word_index == 1:
            continue
        word = vocab.itos[word_index]
        if display_flag:
            output_word = "%s %s%s" % (fg(1), word, attr(0))
            sys.stdout.write(output_word)
        else:
            sys.stdout.write(" " + word)
    print('')
    sys.stdout.flush()
    
def _eval_att_rationale(token_select, mask, z, ratio=0.15):
    _, z_index = torch.topk(token_select, int(mask.size(1) * ratio))
    pred_z = torch.zeros_like(mask)
    pred_z.scatter_(1,z_index,1)
    
    prec = torch.sum(pred_z * z, dim=1) / torch.sum(pred_z, dim=1)
    rec = torch.sum(pred_z * z, dim=1) / (torch.sum(z, dim=1) + 1e-6)
    
    return prec, rec, pred_z

def _eval_att_rationale_topk(token_select, mask, z, topk=10):
    _, z_index = torch.topk(token_select, topk)
    pred_z = torch.zeros_like(mask)
    pred_z.scatter_(1,z_index,1)
    
    prec = torch.sum(pred_z * z, dim=1) / torch.sum(pred_z, dim=1)
    rec = torch.sum(pred_z * z, dim=1) / (torch.sum(z, dim=1) + 1e-6)
    
    return prec, rec, pred_z

In [None]:

# Beer-review aspect to load, and the bias setting passed to the biased loader.
aspect = 2
bias = 0.75  # also tried: 0.8

data_dir = "../data/beer_classification/aspect{}".format(aspect)

# Unbiased alternative: get_beer_datasets(data_dir, max_seq_len=300, max_sent_num=50)
vocab, D_tr, D_dev, D_te = get_beer_datasets_biased(
    data_dir,
    bias,
    max_seq_len=300,
    max_sent_num=50,
)


In [None]:
class SentenceTao(nn.Module):
    """
    Tao's model at sentence level (Toy setting).

    Generator: word embedding -> BiGRU -> per-sentence max-pooling -> linear
    score -> softmax over the sentences of each example. One sentence is then
    sampled (training) or taken by argmax (eval).

    Predictor: a second embedding + BiGRU that classifies the example twice:
      * soft path (``pred_fc``): max-pooled sentence representations weighted
        by the generator's sentence probabilities;
      * hard path (``pred_fc_hard``): only the tokens of the selected sentence
        are fed to the GRU, then max-pooled over those tokens.

    ``get_loss`` turns the hard-path prediction into a policy-gradient
    (REINFORCE-style) loss on the sampled selection.
    """

    # Additive value used to exclude padding positions and empty sentences
    # before max-pooling and softmax.
    _LARGE_NEG = -1e9

    def __init__(self, vocab, args):
        """
        vocab -- object with a `vectors` tensor of shape (vocab_size, embedding_dim)
        args  -- namespace providing exploration_rate, rnn_dim and num_classes
        """
        super().__init__()

        self.args = args
        # Weight of the uniform distribution mixed into the sentence
        # probabilities before sampling.
        self.exploration_rate = args.exploration_rate
        self.vocab_size = vocab.vectors.shape[0]
        self.embedding_dim = vocab.vectors.shape[1]

        # Loss weights (not read inside this class).
        self.lambda_dl = 0.
        self.lambda_linear = 1.0  # 0.8
        self.lambda_sparsity = 1.0

        # Independent, trainable copies of the pretrained word vectors.
        # Submodule creation order is kept fixed so seeded initialisation is reproducible.
        self.gen_embd = nn.Embedding.from_pretrained(vocab.vectors, freeze=False)
        self.pred_embd = nn.Embedding.from_pretrained(vocab.vectors, freeze=False)

        # Generator RNN
        self.gen_gru = nn.GRU(self.embedding_dim, args.rnn_dim, batch_first=True, bidirectional=True)

        # Token-level rationale head; part of gen_vars() but not used by forward().
        self.gen_fc_token = nn.Linear(args.rnn_dim * 2, 1, bias=False)

        # Sentence-level rationale head: one score per pooled sentence.
        self.gen_fc = nn.Linear(args.rnn_dim * 2, 1, bias=False)

        # Predictor RNN
        self.pred_gru = nn.GRU(self.embedding_dim, args.rnn_dim, batch_first=True, bidirectional=True)

        # Task heads for the soft and hard predictor paths.
        self.pred_fc = nn.Linear(args.rnn_dim * 2, args.num_classes)
        self.pred_fc_hard = nn.Linear(args.rnn_dim * 2, args.num_classes)

    def gen_vars(self):
        """
        Return the variables of the generator.
        """
        params = list(self.gen_embd.parameters()) + list(
            self.gen_gru.parameters()) + list(
            self.gen_fc.parameters()) + list(
            self.gen_fc_token.parameters())

        return params

    def pred_vars(self):
        """
        Return the variables of the predictor.
        """
        params = list(self.pred_embd.parameters()) + list(
            self.pred_gru.parameters()) + list(
            self.pred_fc.parameters()) + list(self.pred_fc_hard.parameters())

        return params

    def _one_hot(self, idx, depth):
        """
        Returns a one-hot float tensor of shape (len(idx), depth).
        idx -- (batch_size, ) int64 indices
        depth -- number of classes
        """
        return torch.zeros(len(idx), depth, device=idx.device).scatter_(1, idx.unsqueeze(1), 1.)

    def _generate_rationales(self, z_prob_, num_samples=1):
        '''
        Pick one choice per row from a categorical distribution.

        Input:
            z_prob_ -- (num_rows, num_choices) probabilities; every entry must be
                       > 0 because its log is taken
            num_samples -- unused; one choice is always drawn
        Output:
            z -- (num_rows,) chosen indices: sampled in training mode, argmax in eval mode
            neg_log_probs -- (num_rows,) -log z_prob_ of the chosen index
        '''
        log_probs = torch.log(z_prob_)

        if self.training:
            z_ = torch.multinomial(z_prob_, 1, replacement=True)  # (num_rows, 1)
            neg_log_probs_ = - log_probs.gather(1, z_)
            z = z_.squeeze(1)
        else:
            z_ = torch.max(z_prob_, dim=-1)[1]  # (num_rows,)
            neg_log_probs_ = - log_probs.gather(1, z_.unsqueeze(-1))
            z = z_

        neg_log_probs = neg_log_probs_.squeeze(1)

        return z, neg_log_probs

    @staticmethod
    def _run_gru(gru, embs, lens, max_seq_len):
        """
        Run a batch-first GRU over padded embeddings and re-pad its outputs with zeros.

        embs -- (batch_size, seq_len, embedding_dim)
        lens -- (batch_size,) number of valid tokens per row (any numeric dtype/device)
        Returns (batch_size, max_seq_len, gru_output_size).
        """
        # pack_padded_sequence requires int64 lengths on the CPU (PyTorch >= 1.7).
        packed = pack_padded_sequence(embs, lengths=lens.long().cpu(), batch_first=True, enforce_sorted=False)
        outs, _ = gru(packed)
        return pad_packed_sequence(outs, batch_first=True, total_length=max_seq_len, padding_value=0.)[0]

    def _pool_sentences(self, token_outs, sent_mask):
        """
        Max-pool token representations within each sentence.

        token_outs -- (batch_size, seq_len, hidden)
        sent_mask -- (batch_size, seq_len, num_sents)
        Returns (batch_size, num_sents, hidden); empty sentences pool to _LARGE_NEG.
        """
        # (batch_size, seq_len, 1, num_sentences)
        sent_mask_ = sent_mask.unsqueeze(-2)
        # (batch_size, seq_len, hidden, num_sentences)
        masked = token_outs.unsqueeze(-1) * sent_mask_ + (1 - sent_mask_) * self._LARGE_NEG
        # max over seq_len -> (batch_size, hidden, num_sentences) -> (batch_size, num_sentences, hidden)
        return torch.transpose(torch.max(masked, dim=1)[0], 1, 2).contiguous()

    def _select_sentences(self, seq, lens, max_seq_len, sent_mask, sent_row_mask):
        """
        Generator pass: score each sentence and choose one per example.

        Returns:
            sent_probs -- (batch_size, num_sents) softmax over non-empty sentences
            sent_select -- (batch_size, num_sents) one-hot chosen sentence
            neg_log_probs -- (batch_size,) -log prob of the choice under the
                             exploration-mixed distribution
        """
        gen_outs = self._run_gru(self.gen_gru, self.gen_embd(seq), lens, max_seq_len)
        sent_repts = self._pool_sentences(gen_outs, sent_mask)

        # (batch_size, num_sentences)
        sent_att_logits = self.gen_fc(sent_repts).squeeze(-1)
        sent_probs = F.softmax(sent_att_logits * sent_row_mask + (1 - sent_row_mask) * self._LARGE_NEG, dim=-1)

        # NOTE(review): the uniform term also puts mass on empty/padding sentence
        # slots, so they can be sampled during training.
        sent_probs_ = (1 - self.exploration_rate) * sent_probs + self.exploration_rate / sent_probs.size(-1)
        sent_idx, neg_log_probs = self._generate_rationales(sent_probs_)
        sent_select = self._one_hot(sent_idx, depth=sent_mask.shape[-1])

        return sent_probs, sent_select, neg_log_probs

    def forward(self, seq, mask, sent_mask, y, train_part='generator'):
        """
        Inputs:
            seq -- (batch_size, seq_length) token indices
            mask -- (batch_size, seq_length), input mask; its row sums are the sequence lengths
            sent_mask -- (batch_size, seq_length, num_sents), 1 where a token belongs to a sentence
            y -- labels; unused here, kept for call compatibility
            train_part -- "generator": gradients flow into the generator;
                          "classifiers": the generator runs under torch.no_grad()
        Outputs:
            pred_logits -- (batch_size, num_classes) soft-path logits
            pred_logits_sent -- (batch_size, num_classes) hard-path logits
            sent_probs -- (batch_size, num_sents) generator sentence probabilities
            sent_select -- (batch_size, num_sents) one-hot selected sentence
            neg_log_probs -- (batch_size,) -log prob of the selection
        Raises:
            ValueError -- if train_part is neither "generator" nor "classifiers"
        """
        if train_part not in ("generator", "classifiers"):
            raise ValueError("train_part must be 'generator' or 'classifiers', got %r" % (train_part,))

        max_seq_len = seq.shape[1]
        lens = mask.sum(dim=1)  # (batch_size, )

        # (batch_size, num_sent)
        # NOTE(review): `> 1.` also treats single-token sentences as empty -- confirm intent.
        sent_row_mask = (torch.sum(sent_mask, dim=1) > 1.).float()

        if train_part == "generator":
            sent_probs, sent_select, neg_log_probs = self._select_sentences(
                seq, lens, max_seq_len, sent_mask, sent_row_mask)
        else:
            with torch.no_grad():
                sent_probs, sent_select, neg_log_probs = self._select_sentences(
                    seq, lens, max_seq_len, sent_mask, sent_row_mask)
            sent_select = sent_select.detach()
            neg_log_probs = neg_log_probs.detach()

        rationale_ = sent_probs.unsqueeze(-1)  # (batch_size, num_sent, 1)

        # Token-level 0/1 mask of the selected sentence: (batch_size, seq_len)
        sent_rationale = torch.sum(sent_mask * sent_select.unsqueeze(1), dim=-1)
        sent_rationale_ = sent_rationale.unsqueeze(-1)

        pred_embs = self.pred_embd(seq)  # (batch_size, seq_len, embedding_dim)

        # Soft path: probability-weighted sum of max-pooled sentence representations.
        pred_outs = self._run_gru(self.pred_gru, pred_embs, lens, max_seq_len)
        pred_sent_repts = self._pool_sentences(pred_outs, sent_mask)
        pred_out = torch.sum(pred_sent_repts * rationale_, dim=1)
        pred_logits = self.pred_fc(pred_out)

        # Hard path: zero the embeddings outside the selected sentence, then
        # max-pool the GRU outputs over the selected tokens only.
        pred_outs_sent = self._run_gru(self.pred_gru, pred_embs * sent_rationale_, lens, max_seq_len)
        # (batch_size, 1); NOTE(review): selections of <= 1 token are zeroed out.
        batch_mask = (torch.sum(sent_rationale, dim=-1) > 1.).float().unsqueeze(-1)
        pred_outs_sent = sent_rationale_ * pred_outs_sent + (1. - sent_rationale_) * self._LARGE_NEG
        pred_out_sent = torch.max(pred_outs_sent, dim=1)[0]
        pred_out_sent = pred_out_sent * batch_mask
        pred_logits_sent = self.pred_fc_hard(pred_out_sent)

        return pred_logits, pred_logits_sent, sent_probs, sent_select, neg_log_probs

    def get_advantages(self, predict, label, baseline):
        '''
        Input:
            predict -- (batch_size, num_classes) logits
            label -- (batch_size,) class indices
            baseline -- tensor broadcastable to (batch_size,)
        Output:
            advantages -- (batch_size,) rewards - baseline, without gradient
            rewards -- (batch_size,) -tanh(cross-entropy), in (-1, 0]
        '''
        with torch.no_grad():
            rewards = - torch.tanh(F.cross_entropy(predict, label, reduction='none'))
            advantages = (rewards - baseline).detach().to(predict.device)  # (batch_size,)

        return advantages, rewards

    def get_loss(self, predict, label, neg_log_probs, baseline):
        """
        Policy-gradient loss: batch mean of neg_log_probs * advantages.

        Returns (rl_loss, rewards), with rewards as computed by get_advantages.
        """
        advantages, rewards = self.get_advantages(predict, label, baseline)

        rl_loss = torch.sum(neg_log_probs * advantages) / neg_log_probs.size(0)

        return rl_loss, rewards

In [None]:
class Dummy():
    """Bare attribute container standing in for an argparse.Namespace."""


args = Dummy()
_hyperparams = {
    "rnn_dim": 100,            # hidden size per GRU direction
    "num_classes": 2,
    "gumbel_temp": 1.0,        # 0.1
    "lambda_smooth": 1e-3,
    "exploration_rate": 0.2,   # uniform mixing weight for sentence sampling
    "highlight_ratio": 0.05,
    "l2_decay": 1e-4,
}
for _attr_name, _attr_value in _hyperparams.items():
    setattr(args, _attr_name, _attr_value)

model = SentenceTao(vocab, args).cuda()
print(model)

In [None]:
# --- Training configuration ---------------------------------------------------
num_epochs = 60
batch_size = 500

# The generator uses a 10x smaller learning rate than the predictor.
base_lr = 1e-3
gen_optimizer = torch.optim.Adam(model.gen_vars(), lr=base_lr * 0.1)
pred_optimizer = torch.optim.Adam(model.pred_vars(), lr=base_lr)

D_tr_, D_dev_, D_te_ = (
    DataLoader(D_tr, batch_size=batch_size, shuffle=True, num_workers=16),
    DataLoader(D_dev, batch_size=batch_size, shuffle=False, num_workers=4),
    DataLoader(D_te, batch_size=batch_size, shuffle=False, num_workers=4),
)

# Global step counter; its parity alternates predictor / generator updates.
counter = 0
switch_epoch = 0

# Sliding window of recent batch rewards, seeded with a single 0. entry.
queue_length = 200
history_rewards = deque([0.], maxlen=queue_length)

def _batch_to_device(batch):
    """Move one DataLoader batch to the GPU.

    Returns (x, mask, y, sent_mask). The loader's sent_mask is transposed on
    dims 1 and 2, so sentences end up on the last axis, the
    (batch_size, seq_length, num_sents) layout the model's forward expects.
    """
    x = batch["x"].cuda()
    mask = batch["mask"].cuda()
    y = batch["y"].cuda()
    sent_mask = batch["sent_mask"].transpose(1, 2).cuda().contiguous()
    return x, mask, y, sent_mask


# Number of gold rationales stacked along dim 1 of data["z"]; 5 is assumed
# here (one per aspect) -- confirm against the dataset.
num_aspects = 5

for i_epoch in range(num_epochs):

    print ("================")
    print ("epoch: %d" % i_epoch)
    print ("================")

    if i_epoch == switch_epoch:
        print("SWITCH TO ADDING RL LOSS")

    # ------------------------------------------------------------------ train
    model.train()

    for i_batch, data in enumerate(D_tr_):
        gen_optimizer.zero_grad()
        pred_optimizer.zero_grad()
        counter += 1

        x, mask, y, sent_mask = _batch_to_device(data)

        if counter % 2 == 0:
            # Even steps: update the predictor. The generator runs under
            # torch.no_grad() inside forward(..., "classifiers").
            logits, logits_sent, token_select, sent_select, _ = model(x, mask, sent_mask, y, "classifiers")

            # Symmetric KL between soft-path and hard-path predictions.
            # log_softmax avoids log(0) = -inf when a probability underflows.
            probs = torch.softmax(logits, dim=-1)
            probs_sent = torch.softmax(logits_sent, dim=-1)
            consistency_loss = F.kl_div(F.log_softmax(logits, dim=-1),
                                        probs_sent, reduction='batchmean')
            consistency_loss += F.kl_div(F.log_softmax(logits_sent, dim=-1),
                                         probs, reduction='batchmean')

            sup_loss = F.cross_entropy(logits, y) + F.cross_entropy(logits_sent, y) + consistency_loss

            sup_loss.backward()
            pred_optimizer.step()
        else:
            # Odd steps: update the generator. `loss` reaches it through the
            # sentence probabilities that weight the soft predictor path.
            logits, logits_sent, token_select, sent_select, neg_log_probs = model(x, mask, sent_mask,
                                                                                  y, "generator")

            loss = F.cross_entropy(logits, y)

            # Reward baseline: running mean of recent batch rewards.
            baseline = torch.tensor([float(np.mean(history_rewards))]).to(neg_log_probs.device)
            rl_loss, rewards = model.get_loss(logits_sent, y, neg_log_probs, baseline)

            batch_reward = np.mean(rewards.cpu().numpy())
            history_rewards.append(batch_reward)

            # NOTE(review): rl_loss is not added to `loss`, so the policy-gradient
            # term does not currently contribute to the generator update.
            loss.backward()
            gen_optimizer.step()

        pred = torch.max(logits, 1)[1]
        acc = (pred == y).sum().float() / y.shape[0]

        if i_epoch >= switch_epoch:
            pred_sent = torch.max(logits_sent, 1)[1]
            acc_sent = (pred_sent == y).sum().float() / y.shape[0]

        if i_batch % 10 == 0:
            if i_epoch < switch_epoch:
                print ("Train batch: %d, loss: %.4f, acc: %.4f." % (i_batch, loss.item(), acc.item()))
            elif i_batch == 0 and i_epoch == 0:
                # With counter starting at 0, the very first step is a generator
                # step, so sup_loss / consistency_loss do not exist yet.
                pass
            else:
                # sup_loss / consistency_loss come from the latest predictor step and
                # `loss` (generator cross-entropy) from the latest generator step.
                print ("Train batch: %d, sup loss: %.4f, gen loss: %.4f, consis_loss: %.4f, acc: %.4f, sent_acc: %.4f." % (
                    i_batch, sup_loss.item(), loss.item(), consistency_loss.item(),
                    acc.item(), acc_sent.item()))

    # First example of the last training batch, with its first sentence highlighted.
    display_example(vocab, x[0], data["sent_mask"][0][0])

    model.eval()

    with torch.no_grad():

        # -------------------------------------------------------------- dev
        correct_pred = 0.
        correct_pred_sent = 0.
        num_sample = 0
        # NOTE(review): nothing accumulates into dev_sparsity, so the
        # "Token sparisity" lines always report 0.
        dev_sparsity = 0.

        for i_batch, data in enumerate(D_dev_):
            x, mask, y, sent_mask = _batch_to_device(data)

            # Unpack all five outputs in every epoch so logits_sent is never
            # bound to sent_select.
            logits, logits_sent, token_select, sent_select, _ = model(x, mask, sent_mask, y)

            pred = torch.max(logits, 1)[1]
            correct_pred += (pred == y).sum().float()
            if i_epoch >= switch_epoch:
                pred_sent = torch.max(logits_sent, 1)[1]
                correct_pred_sent += (pred_sent == y).sum().float()
            num_sample += y.shape[0]

        print ("Dev acc: %.4f sent-level: %.4f" % (correct_pred / num_sample,
                                                   correct_pred_sent / num_sample))
        if i_epoch >= switch_epoch:
            print ("Token sparisity: %.4f" % (dev_sparsity / num_sample))

        # ------------------------------------------------------------- test
        correct_pred = 0.
        correct_pred_sent = 0.
        num_sample = 0
        dev_sparsity = 0.

        # Precision / recall / F1 of the selected sentence vs. each gold rationale in z.
        z_prec_sent = {z_idx: 0. for z_idx in range(num_aspects)}
        z_rec_sent = {z_idx: 0. for z_idx in range(num_aspects)}
        z_f1_sent = {z_idx: 0. for z_idx in range(num_aspects)}
        z_total_sent = {z_idx: 0. for z_idx in range(num_aspects)}

        # Same metrics against the first sentence of each example
        # (presumably where the bias is injected -- confirm with the data loader).
        z1_prec_sent = 0.
        z1_rec_sent = 0.
        z1_f1_sent = 0.
        z1_total = 0.

        for i_batch, data in enumerate(D_te_):
            x, mask, y, sent_mask = _batch_to_device(data)
            z = data["z"].cuda()

            logits, logits_sent, token_select, sent_select, _ = model(x, mask, sent_mask, y)

            pred = torch.max(logits, 1)[1]
            correct_pred += (pred == y).sum().float()
            if i_epoch >= switch_epoch:
                pred_sent = torch.max(logits_sent, 1)[1]
                correct_pred_sent += (pred_sent == y).sum().float()
            num_sample += y.shape[0]

            # Token-level 0/1 mask of the selected sentence: (batch_size, seq_len)
            sent_rationale = torch.sum(sent_mask * sent_select.unsqueeze(1), dim=-1)
            num_selected = torch.sum(sent_rationale, dim=1)

            for z_idx in range(num_aspects):
                gold = z[:, z_idx, :]
                overlap = torch.sum(sent_rationale * gold, dim=1)
                prec = overlap / (num_selected + 1e-6)
                rec = overlap / (torch.sum(gold, dim=1) + 1e-6)

                z_prec_sent[z_idx] += torch.sum(prec)
                z_rec_sent[z_idx] += torch.sum(rec)
                z_f1_sent[z_idx] += torch.sum(prec * rec * 2 / (prec + rec + 1e-6))

                z_total_sent[z_idx] += z.size(0)

            # Show one test example per epoch instead of one per batch.
            if i_batch == 0:
                display_example(vocab, x[0], sent_rationale[0])
                display_example(vocab, x[0], z[0, aspect])

            first_sent = sent_mask[:, :, 0]
            overlap = torch.sum(sent_rationale * first_sent, dim=1)
            prec = overlap / (num_selected + 1e-6)
            rec = overlap / (torch.sum(first_sent, dim=1) + 1e-6)

            z1_prec_sent += torch.sum(prec)
            z1_rec_sent += torch.sum(rec)
            z1_f1_sent += torch.sum(prec * rec * 2 / (prec + rec + 1e-6))

            z1_total += z.size(0)

        if i_epoch < switch_epoch:
            print ("Test acc: %.4f" % (correct_pred / num_sample))
        else:
            print ("Test acc: %.4f sent-level: %.4f" % (correct_pred / num_sample,
                                                        correct_pred_sent / num_sample))

        for z_idx in range(num_aspects):
            print("Sent-level: Highlight precision: %.4f recall: %.4f f1: %.4f" % (z_prec_sent[z_idx] / z_total_sent[z_idx],
                                                                                  z_rec_sent[z_idx] / z_total_sent[z_idx],
                                                                                  z_f1_sent[z_idx] / z_total_sent[z_idx]))

        print("Z1 selection: Highlight precision: %.4f recall: %.4f f1: %.4f" % (z1_prec_sent / z1_total,
                                                                                z1_rec_sent / z1_total,
                                                                                z1_f1_sent / z1_total))

        if i_epoch >= switch_epoch:
            print ("Token sparisity: %.4f" % (dev_sparsity / num_sample))


In [None]:
# display_example(vocab, x[0], sent_rationale[0])
# aspect = 2
# display_example(vocab, x[0], z[:,aspect,:][0])