In [1]:
# Mount Google Drive at /content/drive; the working directory used by the next cell lives under it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Switch to the project folder on Drive. Presumably this is what lets `eval3D` / `utils3D`
# (imported in a later cell) resolve from this folder — confirm.
# !cp "drive/MyDrive/Question Generation/vae/models.py" .
%cd '/content/drive/MyDrive/Second/MCQ/vae/'
# Print the working directory to confirm the change took effect.
!pwd

/content/drive/MyDrive/Second/MCQ/vae
/content/drive/MyDrive/Second/MCQ/vae


In [3]:
!pip install transformers
!pip install json-lines
## scatter 1.12+cu113
# !pip install torch-scatter -f https://data.pyg.org/whl/torch-1.12.1+cu113.html
# scatter 1.13+cu116
!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.0+cu116.html
!pip install import-ipynb
import import_ipynb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.26.1-py3-none-any.whl (6.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m15.9 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m28.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.12.1-py3-none-any.whl (190 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.3/190.3 KB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.12.1 tokenizers-0.13.2 transformers-4.26.1
Looking in indexes: https://pypi.org/simple, http

In [4]:
import argparse
import os
import random

import numpy as np
import torch
from tqdm.notebook import tqdm, trange
from transformers import BertTokenizer

from eval3D import eval_vae
from utils3D import batch_to_device, get_harv_data_loader, get_squad_data_loader

In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch_scatter import scatter_max
from transformers import BertModel, BertTokenizer

def return_mask_lengths(ids):
    """Build a padding mask and per-row lengths from a batch of token ids.

    Args:
        ids: integer tensor in which padding entries are 0; lengths are summed over dim 1.

    Returns:
        (mask, lengths):
            mask: ``sign(ids)`` cast to float (1.0 for positive ids, 0.0 for padding).
            lengths: the sum of ``mask`` over dim 1.
    """
    mask = ids.sign().float()
    return mask, mask.sum(dim=1)


def cal_attn(query, memories, mask):
    """Masked dot-product attention of ``query`` over ``memories``.

    Args:
        query: tensor whose last dim matches the last dim of ``memories``.
        memories: states to attend over (e.g. an encoder's per-step outputs).
        mask: tensor broadcastable to the attention scores.
            1 marks positions to keep; 0 marks padding.

    Returns:
        (attn_outputs, attn_logits):
            attn_outputs: the attention-weighted sum of ``memories``.
            attn_logits: the masked pre-softmax scores. Padding positions carry a
                -10000 additive bias, so their softmax weight is effectively 0.
    """
    # Additive bias: 0 where mask == 1, -10000 where mask == 0.
    # A finite value (rather than -inf) keeps the softmax finite even if a whole row is masked.
    bias = (1.0 - mask.float()) * -10000.0
    scores = torch.matmul(query, memories.transpose(-1, -2).contiguous()) + bias
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, memories), scores


def gumbel_softmax(logits, tau=1, hard=False, eps=1e-20, dim=-1):
    # type: (Tensor, float, bool, float, int) -> Tensor
    """Draw a Gumbel-softmax sample from ``logits``.

    Args:
        logits: unnormalised log-probabilities.
        tau: temperature that divides the noise-perturbed logits.
        hard: if True, the forward value is one-hot along ``dim``, while gradients
            flow through the soft sample (straight-through estimator).
        eps: small constant added inside the log of the Gumbel noise.
        dim: dimension over which the softmax / one-hot is taken.

    Returns:
        Tensor with the same shape as ``logits``.
    """
    # -log(E + eps), with E ~ Exponential(1), is Gumbel(0, 1) noise.
    noise = -(torch.empty_like(logits).exponential_() + eps).log()
    y_soft = ((logits + noise) / tau).softmax(dim)
    if not hard:
        return y_soft
    # Straight-through: the value is one-hot, the gradient is that of y_soft.
    winner = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(logits).scatter_(dim, winner, 1.0)
    return y_hard - y_soft.detach() + y_soft


class CategoricalKLLoss(nn.Module):
    """KL(P || Q) between categorical distributions.

    The per-category terms are summed over the last two dims and the result is
    averaged over dim 0 (the batch). ``P`` and ``Q`` are probability tensors of the
    same shape whose last dim is the category axis. Presumably they are the
    ``za_prob`` outputs of the posterior/prior encoders, shape (batch, nza, nzadim)
    — confirm against the caller.
    """

    def __init__(self):
        super(CategoricalKLLoss, self).__init__()

    def forward(self, P, Q):
        # xlogy(x, y) = x * log(y), defined as 0 where x == 0.
        # The previous P * (P.log() - Q.log()) produced 0 * -inf = NaN whenever a
        # probability in P underflowed to exactly 0; the KL term for P == 0 is 0 by
        # convention. A category with P > 0 and Q == 0 still yields +inf, as before.
        kl = (torch.xlogy(P, P) - torch.xlogy(P, Q)).sum(dim=-1).sum(dim=-1)
        return kl.mean(dim=0)


class GaussianKLLoss(nn.Module):
    """KL(N(mu1, exp(logvar1)) || N(mu2, exp(logvar2))) for diagonal Gaussians.

    The per-dimension terms are summed over dim 1 and the result is averaged over
    dim 0 (the batch).
    """

    def __init__(self):
        super(GaussianKLLoss, self).__init__()

    def forward(self, mu1, logvar1, mu2, logvar2):
        # Per dimension: log(var2 / var1) + (var1 + (mu1 - mu2)^2) / var2 - 1
        var1, var2 = logvar1.exp(), logvar2.exp()
        ratio = (var1 + (mu1 - mu2).pow(2)) / var2
        per_sample = 0.5 * (logvar2 - logvar1 + ratio - 1).sum(dim=1)
        return per_sample.mean(dim=0)


class Embedding(nn.Module):
    """Non-contextual BERT input embeddings (word + token-type + position).

    The sum is followed by BERT's LayerNorm and dropout. All sub-modules are taken
    from the ``embeddings`` part of ``BertModel.from_pretrained(bert_model)``.
    """

    def __init__(self, bert_model):
        super(Embedding, self).__init__()
        pretrained = BertModel.from_pretrained(bert_model).embeddings
        self.word_embeddings = pretrained.word_embeddings
        self.token_type_embeddings = pretrained.token_type_embeddings
        self.position_embeddings = pretrained.position_embeddings
        self.LayerNorm = pretrained.LayerNorm
        self.dropout = pretrained.dropout

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        """Embed ``input_ids``; dim 1 is treated as the sequence dimension.

        Args:
            input_ids: token ids.
            token_type_ids: segment ids of the same shape; all zeros when None.
            position_ids: positions of the same shape; ``0..seq_len-1`` broadcast
                over the batch when None.
        """
        if position_ids is None:
            positions = torch.arange(input_ids.size(1), dtype=torch.long,
                                     device=input_ids.device)
            position_ids = positions.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        summed = (self.word_embeddings(input_ids)
                  + self.token_type_embeddings(token_type_ids)
                  + self.position_embeddings(position_ids))
        return self.dropout(self.LayerNorm(summed))


class ContextualizedEmbedding(nn.Module):
    """Contextual token states from a pretrained BERT (embedding layer + encoder, no pooler)."""

    def __init__(self, bert_model):
        super(ContextualizedEmbedding, self).__init__()
        bert = BertModel.from_pretrained(bert_model)
        self.embedding = bert.embeddings
        self.encoder = bert.encoder
        self.num_hidden_layers = bert.config.num_hidden_layers

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Run BERT's embeddings and encoder over ``input_ids``.

        Args:
            input_ids: token ids; dim 1 is the sequence dimension.
            attention_mask: 1 for real tokens, 0 for padding, same shape as ``input_ids``.
            token_type_ids: segment ids; all zeros when None.

        Returns:
            The first element of the encoder output (the last layer's hidden states).
        """
        token_type_ids = (torch.zeros_like(input_ids)
                          if token_type_ids is None else token_type_ids)

        seq_len = input_ids.size(1)
        position_ids = (torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
                        .unsqueeze(0)
                        .expand_as(input_ids))

        # Additive mask broadcast over heads and query positions:
        # 0 for real tokens, -10000 for padding.
        additive_mask = attention_mask.unsqueeze(1).unsqueeze(2).float()
        additive_mask = (1.0 - additive_mask) * -10000.0
        no_head_mask = [None] * self.num_hidden_layers

        hidden = self.embedding(input_ids, position_ids=position_ids,
                                token_type_ids=token_type_ids)
        return self.encoder(hidden, additive_mask, head_mask=no_head_mask)[0]


class CustomLSTM(nn.Module):
    """Batch-first ``nn.LSTM`` over padded batches, with dropout on its outputs.

    Inputs are packed by length before the LSTM and padded back afterwards.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout, bidirectional=False):
        super(CustomLSTM, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        # Output dropout always uses the requested rate ...
        self.dropout = nn.Dropout(dropout)
        # ... but nn.LSTM's own dropout only acts between stacked layers,
        # so it is turned off for a single-layer LSTM.
        inter_layer_dropout = 0.0 if (dropout > 0.0 and num_layers == 1) else dropout
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, dropout=inter_layer_dropout,
                            bidirectional=bidirectional, batch_first=True)

    def forward(self, inputs, input_lengths, state=None):
        """Run the LSTM over a padded batch.

        Args:
            inputs: 3-D batch-first tensor (batch, seq_len, input_size).
            input_lengths: number of valid steps per row (moved to CPU for packing).
            state: optional initial (h, c).

        Returns:
            (output, state):
                output: (batch, seq_len, directions * hidden_size), zero past each
                    row's length before dropout.
                state: the LSTM's final (h, c).
        """
        _, max_len, _ = inputs.size()
        packed = pack_padded_sequence(inputs, input_lengths.cpu(),
                                      batch_first=True, enforce_sorted=False)
        self.lstm.flatten_parameters()
        packed_out, state = self.lstm(packed, state)
        padded_out, _ = pad_packed_sequence(packed_out, batch_first=True,
                                            total_length=max_len)
        return self.dropout(padded_out), state

class PosteriorEncoder(nn.Module):
    """Posterior network q(zq, za, zd | c, q, a, d1, d2, d3).

    One bidirectional LSTM (``self.encoder``) is shared across the context, the
    question, the three distractors and the answer-tagged context. From those
    encodings it produces:

    * ``zq``: a Gaussian latent (``nzqdim`` dims) from question/context co-attention.
    * ``zd1, zd2, zd3``: three samples of one Gaussian latent (``nzddim`` dims),
      conditioned on the distractors, the context and ``zq``.
    * ``za``: ``nza`` categorical latents with ``nzadim`` classes each, drawn with a
      hard Gumbel-softmax from ``zq`` and the answer-tagged context.
    """

    def __init__(self, embedding, emsize,
                 nhidden, nlayers,
                 nzqdim, nza, nzadim, nzddim,
                 dropout=0.0):
        super(PosteriorEncoder, self).__init__()

        self.embedding = embedding
        self.nhidden = nhidden
        self.nlayers = nlayers
        self.nzqdim = nzqdim
        self.nzddim = nzddim
        self.nza = nza
        self.nzadim = nzadim

        self.encoder = CustomLSTM(input_size=emsize,
                                  hidden_size=nhidden,
                                  num_layers=nlayers,
                                  dropout=dropout,
                                  bidirectional=True)

        self.question_attention = nn.Linear(2 * nhidden, 2 * nhidden)
        self.context_attention = nn.Linear(2 * nhidden, 2 * nhidden)
        self.distractor_attention = nn.Linear(2 * nhidden, 2 * nhidden)
        # Projects a zq sample into attention-query space, so its input size is nzqdim.
        # It used to be nzddim, which only worked when nzddim == nzqdim.
        self.zq_attention = nn.Linear(nzqdim, 2 * nhidden)

        # Input [q_h, q_attned_by_c, c_h, c_attned_by_q] -> (mu, logvar) of zq.
        self.zq_linear = nn.Linear(4 * 2 * nhidden, 2 * nzqdim)
        # Input: 10 encoder-sized features + c_attned_by_zq (11 x 2*nhidden), plus zq
        # -> (mu, logvar) of zd.
        self.zd_linear = nn.Linear(nzqdim + 11 * 2 * nhidden, 2 * nzddim)
        # Input [zq, c_a_attned_by_zq, c_a_h] -> logits of za.
        self.za_linear = nn.Linear(nzqdim + 2 * 2 * nhidden, nza * nzadim)

    def _encode(self, embeddings, lengths):
        """Run the shared BiLSTM over ``embeddings``.

        Returns:
            (hs, h):
                hs: per-step outputs.
                h: the last layer's final hidden state, with both directions
                    concatenated, shape (batch, 2 * nhidden).
        """
        hs, state = self.encoder(embeddings, lengths)
        # h_n is (nlayers * 2, batch, nhidden). Keep the last layer and put its two
        # directions side by side.
        h = state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        h = h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)
        return hs, h

    @staticmethod
    def _attend(query, memories, mask):
        """Attend with one query vector per example over ``memories``.

        Args:
            query: (batch, dim) query vectors.
            memories: (batch, len, dim) states to attend over.
            mask: (batch, len) padding mask.

        Returns:
            (batch, dim) attended summaries.
        """
        attended, _ = cal_attn(query.unsqueeze(1), memories, mask.unsqueeze(1))
        return attended.squeeze(1)

    @staticmethod
    def _reparameterize(mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        return mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)

    def forward(self, c_ids, q_ids, a_ids, d_ids1, d_ids2, d_ids3):
        """Encode (context, question, answer, distractors) and sample the latents.

        Args:
            c_ids: context token ids; 0 is padding.
            q_ids: question token ids; 0 is padding.
            a_ids: answer tags over the context. They are passed to ``embedding`` as
                its second argument (the token-type ids for ``Embedding``).
            d_ids1, d_ids2, d_ids3: token ids of the three distractors; 0 is padding.

        Returns:
            (zq_mu, zq_logvar, zq, za_prob, za, zd_mu, zd_logvar, zd1, zd2, zd3)
        """
        c_mask, c_lengths = return_mask_lengths(c_ids)
        q_mask, q_lengths = return_mask_lengths(q_ids)
        d_mask1, d_lengths1 = return_mask_lengths(d_ids1)
        d_mask2, d_lengths2 = return_mask_lengths(d_ids2)
        d_mask3, d_lengths3 = return_mask_lengths(d_ids3)

        # Encode every input with the shared BiLSTM. The order is kept fixed so any
        # dropout randomness is consumed in the same sequence as before.
        q_hs, q_h = self._encode(self.embedding(q_ids), q_lengths)
        d_hs1, d_h1 = self._encode(self.embedding(d_ids1), d_lengths1)
        d_hs2, d_h2 = self._encode(self.embedding(d_ids2), d_lengths2)
        d_hs3, d_h3 = self._encode(self.embedding(d_ids3), d_lengths3)
        c_hs, c_h = self._encode(self.embedding(c_ids), c_lengths)
        # Context encoded with the answer tags as token-type ids.
        c_a_hs, c_a_h = self._encode(self.embedding(c_ids, a_ids, None), c_lengths)

        # The context summary is the same query for every "attend over X" step below.
        c_query = self.context_attention(c_h)

        # zq from question <-> context co-attention.
        c_attned_by_q = self._attend(self.question_attention(q_h), c_hs, c_mask)
        q_attned_by_c = self._attend(c_query, q_hs, q_mask)
        h = torch.cat([q_h, q_attned_by_c, c_h, c_attned_by_q], dim=-1)
        zq_mu, zq_logvar = torch.split(self.zq_linear(h), self.nzqdim, dim=1)
        zq = self._reparameterize(zq_mu, zq_logvar)

        # zd from distractor <-> context co-attention, plus zq.
        c_attned_by_d1 = self._attend(self.distractor_attention(d_h1), c_hs, c_mask)
        d1_attned_by_c = self._attend(c_query, d_hs1, d_mask1)
        c_attned_by_d2 = self._attend(self.distractor_attention(d_h2), c_hs, c_mask)
        d2_attned_by_c = self._attend(c_query, d_hs2, d_mask2)
        c_attned_by_d3 = self._attend(self.distractor_attention(d_h3), c_hs, c_mask)
        d3_attned_by_c = self._attend(c_query, d_hs3, d_mask3)
        c_attned_by_zq = self._attend(self.zq_attention(zq), c_hs, c_mask)

        h = torch.cat([d_h1, d1_attned_by_c, d_h2, d2_attned_by_c, d_h3, d3_attned_by_c,
                       c_h, c_attned_by_d1, c_attned_by_d2, c_attned_by_d3,
                       zq, c_attned_by_zq], dim=-1)
        zd_mu, zd_logvar = torch.split(self.zd_linear(h), self.nzddim, dim=1)
        # Three independent draws from the same posterior.
        zd1 = self._reparameterize(zd_mu, zd_logvar)
        zd2 = self._reparameterize(zd_mu, zd_logvar)
        zd3 = self._reparameterize(zd_mu, zd_logvar)

        # za from zq attending over the answer-tagged context.
        c_a_attned_by_zq = self._attend(self.zq_attention(zq), c_a_hs, c_mask)
        h = torch.cat([zq, c_a_attned_by_zq, c_a_h], dim=-1)

        za_logits = self.za_linear(h).view(-1, self.nza, self.nzadim)
        za_prob = F.softmax(za_logits, dim=-1)
        za = gumbel_softmax(za_logits, hard=True)

        return zq_mu, zq_logvar, zq, za_prob, za, zd_mu, zd_logvar, zd1, zd2, zd3

class PriorEncoder(nn.Module):
    """Prior network p(zq, za, zd | c): every latent is predicted from the context alone.

    Uses the same latent layout as the posterior:
    * a Gaussian ``zq`` with ``nzqdim`` dims;
    * three samples of a Gaussian ``zd`` with ``nzddim`` dims;
    * ``nza`` hard Gumbel-softmax categorical latents with ``nzadim`` classes each (``za``).
    """

    def __init__(self, embedding, emsize,
                 nhidden, nlayers,
                 nzqdim, nza, nzadim, nzddim,
                 dropout=0):
        super(PriorEncoder, self).__init__()

        self.embedding = embedding
        self.nhidden = nhidden
        self.nlayers = nlayers
        self.nzqdim = nzqdim
        self.nzddim = nzddim
        self.nza = nza
        self.nzadim = nzadim

        self.context_encoder = CustomLSTM(input_size=emsize,
                                          hidden_size=nhidden,
                                          num_layers=nlayers,
                                          dropout=dropout,
                                          bidirectional=True)

        self.zq_attention = nn.Linear(nzqdim, 2 * nhidden)

        # c_h -> (mu, logvar) of zq
        self.zq_linear = nn.Linear(2 * nhidden, 2 * nzqdim)
        # [zq, c_attned_by_zq, c_h] -> (mu, logvar) of zd / logits of za
        self.zd_linear = nn.Linear(nzqdim + 2 * 2 * nhidden, 2 * nzddim)
        self.za_linear = nn.Linear(nzqdim + 2 * 2 * nhidden, nza * nzadim)

    def _encode_context(self, c_ids):
        """Embed and encode the context.

        Returns:
            (c_mask, c_hs, c_h):
                c_mask: the padding mask.
                c_hs: the per-step BiLSTM outputs.
                c_h: the last layer's final hidden state with both directions
                    concatenated, shape (batch, 2 * nhidden).
        """
        c_mask, c_lengths = return_mask_lengths(c_ids)
        c_hs, (h_n, _) = self.context_encoder(self.embedding(c_ids), c_lengths)
        last_layer = h_n.view(self.nlayers, 2, -1, self.nhidden)[-1]
        c_h = last_layer.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)
        return c_mask, c_hs, c_h

    def _features_given_zq(self, zq, c_hs, c_h, c_mask):
        """Return ``[zq, zq-attended context summary, c_h]``, shape (batch, nzqdim + 4 * nhidden)."""
        attended, _ = cal_attn(self.zq_attention(zq).unsqueeze(1), c_hs,
                               c_mask.unsqueeze(1))
        return torch.cat([zq, attended.squeeze(1), c_h], dim=-1)

    def forward(self, c_ids):
        """Sample all latents from the prior.

        Returns:
            (zq_mu, zq_logvar, zq, za_prob, za, zd_mu, zd_logvar, zd1, zd2, zd3)
        """
        c_mask, c_hs, c_h = self._encode_context(c_ids)

        zq_mu, zq_logvar = torch.split(self.zq_linear(c_h), self.nzqdim, dim=1)
        zq = zq_mu + torch.randn_like(zq_mu) * torch.exp(0.5 * zq_logvar)

        h = self._features_given_zq(zq, c_hs, c_h, c_mask)

        zd_mu, zd_logvar = torch.split(self.zd_linear(h), self.nzddim, dim=1)
        zd_std = torch.exp(0.5 * zd_logvar)
        # Three independent draws, taken in order zd1, zd2, zd3.
        zd1, zd2, zd3 = (zd_mu + torch.randn_like(zd_mu) * zd_std for _ in range(3))

        za_logits = self.za_linear(h).view(-1, self.nza, self.nzadim)
        za_prob = F.softmax(za_logits, dim=-1)
        za = gumbel_softmax(za_logits, hard=True)

        return zq_mu, zq_logvar, zq, za_prob, za, zd_mu, zd_logvar, zd1, zd2, zd3

    def interpolation(self, c_ids, zq):
        """Sample ``za`` for a caller-supplied ``zq`` instead of drawing ``zq`` from the prior.

        Presumably ``zq`` is meant to be an interpolated latent, given the method name.
        """
        c_mask, c_hs, c_h = self._encode_context(c_ids)
        h = self._features_given_zq(zq, c_hs, c_h, c_mask)
        za_logits = self.za_linear(h).view(-1, self.nza, self.nzadim)
        return gumbel_softmax(za_logits, hard=True)

class AnswerDecoder(nn.Module):
    """Predicts an answer span in the context from ``init_state``.

    Each context token's embedding ``H`` is fused with ``init_state`` (``U``, repeated
    over positions) as ``[H, U, H*U, |H-U|]``. The fused sequence is run through a
    BiLSTM and projected to start and end logits.
    """

    def __init__(self, embedding, emsize,
                 nhidden, nlayers,
                 dropout=0.0):
        super(AnswerDecoder, self).__init__()

        self.embedding = embedding

        self.context_lstm = CustomLSTM(input_size=4 * emsize,
                                       hidden_size=nhidden,
                                       num_layers=nlayers,
                                       dropout=dropout,
                                       bidirectional=True)

        self.start_linear = nn.Linear(2 * nhidden, 1)
        self.end_linear = nn.Linear(2 * nhidden, 1)
        self.ls = nn.LogSoftmax(dim=1)

    def forward(self, init_state, c_ids):
        """Return (start_logits, end_logits), each (batch, c_len), with padding set to -10000."""
        _, c_len = c_ids.size()
        c_mask, c_lengths = return_mask_lengths(c_ids)

        H = self.embedding(c_ids, c_mask)
        U = init_state.unsqueeze(1).repeat(1, c_len, 1)
        fused = torch.cat([H, U, H * U, torch.abs(H - U)], dim=-1)
        M, _ = self.context_lstm(fused, c_lengths)

        padding = c_mask == 0
        start_logits = self.start_linear(M).squeeze(-1).masked_fill(padding, -10000.0)
        end_logits = self.end_linear(M).squeeze(-1).masked_fill(padding, -10000.0)
        return start_logits, end_logits

    def generate(self, init_state, c_ids):
        """Decode the highest-scoring span whose start is at or before its end.

        Returns:
            (a_ids, start_positions, end_positions):
                a_ids: (batch, c_len) long tensor marking the span with 1.
                start_positions, end_positions: (batch,) index tensors.
        """
        start_logits, end_logits = self.forward(init_state, c_ids)
        c_mask, _ = return_mask_lengths(c_ids)
        batch_size, c_len = c_ids.size()

        # A (start, end) pair is valid when both positions are real tokens and
        # start <= end (upper triangle).
        pair_mask = torch.matmul(c_mask.unsqueeze(2).float(),
                                 c_mask.unsqueeze(1).float())
        invalid = torch.triu(pair_mask) == 0
        # score[b, i, j] = log p(start = i) + log p(end = j)
        score = self.ls(start_logits).unsqueeze(2) + self.ls(end_logits).unsqueeze(1)
        score = score.masked_fill(invalid, -10000.0)

        best_over_starts, best_start_per_end = score.max(dim=1)
        _, end_positions = best_over_starts.max(dim=1)
        start_positions = best_start_per_end.gather(1, end_positions.view(-1, 1)).squeeze(1)

        positions = torch.arange(c_len, device=start_logits.device)
        positions = positions.unsqueeze(0).repeat(batch_size, 1)
        after_start = (positions >= start_positions.unsqueeze(1)).long()
        before_end = (positions <= end_positions.unsqueeze(1)).long()
        a_ids = after_start + before_end - 1

        return a_ids, start_positions, end_positions

class ContextEncoderforQG(nn.Module):
    """Answer-aware context encoder used by the question decoder.

    The context is embedded with ``embedding(c_ids, c_mask, a_ids)``; ``a_ids`` are the
    token-type ids when ``embedding`` is a ``ContextualizedEmbedding``. The result
    goes through a BiLSTM and then a gated self-attention layer: each position is
    mixed with its self-attended summary through a tanh fusion and a sigmoid gate.
    """

    def __init__(self, embedding, emsize,
                 nhidden, nlayers,
                 dropout=0.0):
        super(ContextEncoderforQG, self).__init__()
        self.embedding = embedding
        self.context_lstm = CustomLSTM(input_size=emsize,
                                       hidden_size=nhidden,
                                       num_layers=nlayers,
                                       dropout=dropout,
                                       bidirectional=True)
        self.context_linear = nn.Linear(2 * nhidden, 2 * nhidden)
        self.fusion = nn.Linear(4 * nhidden, 2 * nhidden, bias=False)
        self.gate = nn.Linear(4 * nhidden, 2 * nhidden, bias=False)

    def forward(self, c_ids, a_ids):
        """Return gated self-attended context states (batch, c_len, 2 * nhidden)."""
        c_mask, c_lengths = return_mask_lengths(c_ids)
        states, _ = self.context_lstm(self.embedding(c_ids, c_mask, a_ids), c_lengths)

        # Pairwise mask: position i may attend to j only if both are real tokens.
        pair_mask = torch.matmul(c_mask.unsqueeze(2), c_mask.unsqueeze(1))
        self_attended, _ = cal_attn(self.context_linear(states), states, pair_mask)

        both = torch.cat([states, self_attended], dim=2)
        gate = self.gate(both).sigmoid()
        fused = self.fusion(both).tanh()
        return gate * fused + (1 - gate) * states

# class ContextEncoderforDG(nn.Module):
#     def __init__(self, embedding, emsize,
#                  nhidden, nlayers,
#                  dropout=0.0):
#         super(ContextEncoderforDG, self).__init__()
#         self.embedding = embedding
#         self.context_lstm = CustomLSTM(input_size=emsize,
#                                        hidden_size=nhidden,
#                                        num_layers=nlayers,
#                                        dropout=dropout,
#                                        bidirectional=True)
#         self.context_linear = nn.Linear(2 * nhidden, 2 * nhidden)
#         self.fusion = nn.Linear(4 * nhidden, 2 * nhidden, bias=False)
#         self.gate = nn.Linear(4 * nhidden, 2 * nhidden, bias=False)

#     def forward(self, c_ids, a_ids):
#         c_mask, c_lengths = return_mask_lengths(c_ids)
#         c_embeddings = self.embedding(c_ids, c_mask, a_ids)
#         c_outputs, _ = self.context_lstm(c_embeddings, c_lengths)
#         # attention
#         mask = torch.matmul(c_mask.unsqueeze(2), c_mask.unsqueeze(1))
#         c_attned_by_c, _ = cal_attn(self.context_linear(c_outputs),
#                                     c_outputs,
#                                     mask)
#         c_concat = torch.cat([c_outputs, c_attned_by_c], dim=2)
#         c_fused = self.fusion(c_concat).tanh()
#         c_gate = self.gate(c_concat).sigmoid()
#         c_outputs = c_gate * c_fused + (1 - c_gate) * c_outputs
#         return c_outputs

class QuestionDecoder(nn.Module):
    """Generates a question token by token from the answer-aware context.

    A unidirectional LSTM, initialised from ``init_state``, reads question tokens.
    At each step it attends over the context encoded by ``ContextEncoderforQG``.
    The output logits are the sum of two parts:
    * generation logits: a maxout over ``[q_outputs, attended context]``, projected
      with the tied, frozen word-embedding matrix;
    * copy logits: the attention scores scattered onto the vocabulary ids of the
      context tokens (max over repeated ids).
    ``forward`` also returns a question/answer mutual-information loss computed
    with a bilinear discriminator.
    """

    def __init__(self, sos_id, eos_id,
                 embedding, contextualized_embedding, emsize,
                 nhidden, ntokens, nlayers,
                 dropout=0.0,
                 max_q_len=64):
        super(QuestionDecoder, self).__init__()

        self.sos_id = sos_id
        self.eos_id = eos_id
        self.emsize = emsize
        self.embedding = embedding
        self.nhidden = nhidden
        self.ntokens = ntokens
        self.nlayers = nlayers
        # this max_len includes sos and eos
        self.max_q_len = max_q_len

        self.context_lstm = ContextEncoderforQG(contextualized_embedding, emsize,
                                                nhidden // 2, nlayers, dropout)

        self.question_lstm = CustomLSTM(input_size=emsize,
                                        hidden_size=nhidden,
                                        num_layers=nlayers,
                                        dropout=dropout,
                                        bidirectional=False)

        self.question_linear = nn.Linear(nhidden, nhidden)

        # Projects to 2 * emsize so the result can be max-pooled in pairs (maxout) to emsize.
        self.concat_linear = nn.Sequential(nn.Linear(2*nhidden, 2*nhidden),
                                           nn.Dropout(dropout),
                                           nn.Linear(2*nhidden, 2*emsize))

        self.logit_linear = nn.Linear(emsize, ntokens, bias=False)

        # Tie the output projection to the word-embedding matrix and freeze it.
        # NOTE: this is the same Parameter object, so requires_grad=False also
        # freezes embedding.word_embeddings.weight.
        self.logit_linear.weight = embedding.word_embeddings.weight
        for param in self.logit_linear.parameters():
            param.requires_grad = False

        self.discriminator = nn.Bilinear(emsize, nhidden, 1)

    @staticmethod
    def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
        """Mask logits outside the top-k set and/or the nucleus (top-p) set.

        Args:
            logits: (batch, vocab) logits.
            top_k: keep only the ``top_k`` highest logits in each row (0 disables).
            top_p: keep the smallest set of highest-probability tokens whose
                cumulative probability exceeds ``top_p`` (0.0 disables). The most
                likely token is always kept.
            filter_value: value written into removed positions.

        Returns:
            A new tensor of the same shape; ``logits`` itself is not modified.
        """
        top_k = min(top_k, logits.size(-1))
        if top_k > 0:
            kth_best = torch.topk(logits, top_k, dim=-1)[0][..., -1:]
            logits = logits.masked_fill(logits < kth_best, filter_value)
        if top_p > 0.0:
            sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
            cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Drop tokens once the cumulative probability exceeds top_p. The mask is
            # shifted right by one so the token that crosses the threshold is kept.
            remove_sorted = cum_probs > top_p
            remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
            remove_sorted[..., 0] = False
            remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
            logits = logits.masked_fill(remove, filter_value)
        return logits

    def postprocess(self, q_ids):
        """Zero every token after the first EOS in each row; the EOS itself is kept.

        Rows with no EOS keep their first ``max_q_len`` tokens.
        """
        eos_mask = q_ids == self.eos_id
        no_eos_idx_sum = (eos_mask.sum(dim=1) == 0).long() * \
            (self.max_q_len - 1)
        eos_mask = eos_mask.cpu().numpy()
        # argmax of a boolean row is the index of its first True (0 when there is none).
        q_lengths = np.argmax(eos_mask, axis=1) + 1
        q_lengths = torch.tensor(q_lengths).to(
            q_ids.device).long() + no_eos_idx_sum
        batch_size, max_len = q_ids.size()
        idxes = torch.arange(0, max_len).to(q_ids.device)
        idxes = idxes.unsqueeze(0).repeat(batch_size, 1)
        q_mask = (idxes < q_lengths.unsqueeze(1))
        q_ids = q_ids.long() * q_mask.long()
        return q_ids

    def _copy_logits(self, attn_logits, c_ids):
        """Scatter attention scores onto the vocabulary.

        Args:
            attn_logits: (rows, c_len) attention scores over context positions.
            c_ids: (rows, c_len) vocabulary ids at those positions.

        Returns:
            (rows, ntokens): for each id present in a row's context, the max score
            over its occurrences; 0 wherever the scatter left the -10000 fill.
        """
        copy_logits = torch.zeros(attn_logits.size(0), self.ntokens).to(c_ids.device)
        copy_logits = copy_logits - 10000.0
        copy_logits, _ = scatter_max(attn_logits, c_ids, out=copy_logits)
        return copy_logits.masked_fill(copy_logits == -10000.0, 0)

    def forward(self, init_state, c_ids, q_ids, a_ids):
        """Teacher-forced decoding.

        Args:
            init_state: initial (h, c) state for ``question_lstm``.
            c_ids: (batch, c_len) context ids; 0 is padding.
            q_ids: (batch, q_len) question ids fed as decoder input; 0 is padding.
            a_ids: (batch, c_len) answer tags over the context. Assumed to be 0/1
                tags — confirm.

        Returns:
            (logits, loss_info, q_maxouted):
                logits: vocabulary logits, (batch, q_len, ntokens).
                loss_info: the mutual-information discriminator loss.
                q_maxouted: the maxout features, (batch, q_len, emsize).
        """
        batch_size, max_q_len = q_ids.size()

        c_outputs = self.context_lstm(c_ids, a_ids)

        c_mask, _ = return_mask_lengths(c_ids)
        q_mask, q_lengths = return_mask_lengths(q_ids)

        # question dec
        q_embeddings = self.embedding(q_ids)
        q_outputs, _ = self.question_lstm(q_embeddings, q_lengths, init_state)

        # attention: question step i may attend to context position j only if both are real tokens
        mask = torch.matmul(q_mask.unsqueeze(2), c_mask.unsqueeze(1))
        c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs),
                                              c_outputs,
                                              mask)

        # gen logits (maxout over pairs of the 2 * emsize projection)
        q_concated = torch.cat([q_outputs, c_attned_by_q], dim=2)
        q_concated = self.concat_linear(q_concated)
        q_maxouted, _ = q_concated.view(
            batch_size, max_q_len, self.emsize, 2).max(dim=-1)
        gen_logits = self.logit_linear(q_maxouted)

        # copy logits: flatten (batch, q_len) so each question step scatters over its own copy of c_ids
        bq = batch_size * max_q_len
        c_ids = c_ids.unsqueeze(1).repeat(
            1, max_q_len, 1).view(bq, -1).contiguous()
        attn_logits = attn_logits.view(bq, -1).contiguous()
        copy_logits = self._copy_logits(attn_logits, c_ids)
        copy_logits = copy_logits.view(batch_size, max_q_len, -1).contiguous()

        logits = gen_logits + copy_logits

        # mutual information btw answer and question
        a_emb = c_outputs * a_ids.float().unsqueeze(2)
        # clamp: an example with no answer tokens would otherwise give 0 / 0 = NaN
        a_mean_emb = torch.sum(a_emb, 1) / a_ids.sum(1).clamp(min=1).unsqueeze(1).float()
        # Negative pairs: roll the batch by one so each question meets another example's answer.
        fake_a_mean_emb = torch.cat([a_mean_emb[-1].unsqueeze(0),
                                     a_mean_emb[:-1]], dim=0)

        q_emb = q_maxouted * q_mask.unsqueeze(2)
        q_mean_emb = torch.sum(q_emb, 1) / q_lengths.unsqueeze(1).float()
        fake_q_mean_emb = torch.cat([q_mean_emb[-1].unsqueeze(0),
                                     q_mean_emb[:-1]], dim=0)

        bce_loss = nn.BCEWithLogitsLoss()
        true_logits = self.discriminator(q_mean_emb, a_mean_emb)
        true_labels = torch.ones_like(true_logits)

        fake_a_logits = self.discriminator(q_mean_emb, fake_a_mean_emb)
        fake_q_logits = self.discriminator(fake_q_mean_emb, a_mean_emb)
        fake_logits = torch.cat([fake_a_logits, fake_q_logits], dim=0)
        fake_labels = torch.zeros_like(fake_logits)

        true_loss = bce_loss(true_logits, true_labels)
        fake_loss = 0.5 * bce_loss(fake_logits, fake_labels)
        loss_info = 0.5 * (true_loss + fake_loss)

        return logits, loss_info, q_maxouted

    def _decode_step(self, q_embeddings, state, c_ids, c_mask, c_outputs):
        """Run one decoding step on a (batch, 1, emsize) input embedding.

        Returns:
            (logits, state):
                logits: (batch, 1, ntokens) vocabulary logits.
                state: the new LSTM state.
        """
        batch_size = c_ids.size(0)
        q_outputs, state = self.question_lstm.lstm(q_embeddings, state)

        # attention
        c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs),
                                              c_outputs,
                                              c_mask.unsqueeze(1))

        # gen logits
        q_concated = self.concat_linear(torch.cat([q_outputs, c_attned_by_q], dim=2))
        q_maxouted, _ = q_concated.view(batch_size, 1, self.emsize, 2).max(dim=-1)
        gen_logits = self.logit_linear(q_maxouted)

        # copy logits
        copy_logits = self._copy_logits(attn_logits.squeeze(1), c_ids)

        return gen_logits + copy_logits.unsqueeze(1), state

    def _decode(self, init_state, c_ids, a_ids, choose_next):
        """Autoregressively decode ``max_q_len - 1`` tokens after SOS.

        ``choose_next`` maps step logits (batch, 1, ntokens) to the next ids (batch, 1).
        Returns (batch, max_q_len) ids passed through ``postprocess``.
        """
        c_mask, _ = return_mask_lengths(c_ids)
        c_outputs = self.context_lstm(c_ids, a_ids)

        batch_size = c_ids.size(0)

        q_ids = torch.LongTensor([self.sos_id] * batch_size).unsqueeze(1)
        q_ids = q_ids.to(c_ids.device)
        token_type_ids = torch.zeros_like(q_ids)
        position_ids = torch.zeros_like(q_ids)
        q_embeddings = self.embedding(q_ids, token_type_ids, position_ids)

        state = init_state

        # unroll: SOS is at position 0, and the k-th generated token at position k
        all_q_ids = [q_ids]
        for _ in range(self.max_q_len - 1):
            position_ids = position_ids + 1
            logits, state = self._decode_step(q_embeddings, state, c_ids, c_mask, c_outputs)
            q_ids = choose_next(logits)
            all_q_ids.append(q_ids)
            q_embeddings = self.embedding(q_ids, token_type_ids, position_ids)

        return self.postprocess(torch.cat(all_q_ids, 1))

    def generate(self, init_state, c_ids, a_ids):
        """Greedy decoding; returns (batch, max_q_len) ids, zeroed after the first EOS."""
        return self._decode(init_state, c_ids, a_ids,
                            lambda logits: torch.argmax(logits, 2))

    def sample(self, init_state, c_ids, a_ids, top_k=2, top_p=0.8):
        """Stochastic decoding with top-k / nucleus filtering.

        The defaults (``top_k=2``, ``top_p=0.8``) reproduce the previous hard-coded
        behaviour. Returns (batch, max_q_len) ids, zeroed after the first EOS.
        """
        def choose_next(logits):
            filtered = self.top_k_top_p_filtering(logits.squeeze(1), top_k, top_p=top_p)
            probs = F.softmax(filtered, dim=-1)
            return torch.multinomial(probs, num_samples=1)  # [b,1]

        return self._decode(init_state, c_ids, a_ids, choose_next)

class DistractorDecoder(nn.Module):
    """LSTM decoder that produces one distractor for a (context, answer, question) triple.

    At each step the distractor LSTM state attends over the answer-aware
    context encoding and over a question encoding produced by this module's own
    ``question_lstm``. The concatenation is max-out projected to ``emsize`` and
    mapped to the vocabulary through the frozen word-embedding matrix. Copy
    logits scattered from the context attention are added on top.
    ``forward`` additionally returns two contrastive mutual-information losses
    (distractor/answer and distractor/question).
    """

    def __init__(self, sos_id, eos_id,
                 embedding, contextualized_embedding, emsize,
                 d_nhidden, ntokens, d_nlayers,
                 q_nhidden,q_nlayers,d_dropout=0.0,q_dropout=0.0,
                 max_d_len=64):
        super(DistractorDecoder, self).__init__()

        self.sos_id = sos_id
        self.eos_id = eos_id
        self.emsize = emsize
        self.embedding = embedding
        self.d_nhidden = d_nhidden
        self.ntokens = ntokens
        self.d_nlayers = d_nlayers
        # this max_len include sos eos
        self.max_d_len = max_d_len

        # bidirectional encoder: d_nhidden // 2 per direction -> d_nhidden outputs
        self.context_lstm = ContextEncoderforQG(contextualized_embedding, emsize,
                                                d_nhidden // 2, d_nlayers, d_dropout)

        self.distractor_lstm = CustomLSTM(input_size=emsize,
                                        hidden_size=d_nhidden,
                                        num_layers=d_nlayers,
                                        dropout=d_dropout,
                                        bidirectional=False)

        self.question_lstm = CustomLSTM(input_size=emsize,
                                        hidden_size=q_nhidden,
                                        num_layers=q_nlayers,
                                        dropout=q_dropout,
                                        bidirectional=False)

        self.distractor_linear = nn.Linear(d_nhidden, d_nhidden)

        # input = [distractor state, attended context, attended question]
        self.concat_linear = nn.Sequential(nn.Linear(2*d_nhidden+q_nhidden, 2*d_nhidden+q_nhidden),
                                           nn.Dropout(d_dropout),
                                           nn.Linear(2*d_nhidden+q_nhidden, 2*emsize))

        self.logit_linear = nn.Linear(emsize, ntokens, bias=False)

        # fix output word matrix (tied to the frozen word embeddings)
        self.logit_linear.weight = embedding.word_embeddings.weight
        for param in self.logit_linear.parameters():
            param.requires_grad = False

        # bilinear critics for the mutual-information losses
        self.discriminator = nn.Bilinear(emsize, d_nhidden, 1)
        self.discriminator2 = nn.Bilinear(emsize, emsize, 1)

    def postprocess(self, d_ids):
        """Zero every token after the first EOS in each row of ``d_ids``.

        The EOS itself is kept. Rows without an EOS keep their first
        ``max_d_len`` tokens.
        """
        eos_mask = d_ids == self.eos_id
        no_eos_idx_sum = (eos_mask.sum(dim=1) == 0).long() * \
            (self.max_d_len - 1)
        eos_mask = eos_mask.cpu().numpy()
        d_lengths = np.argmax(eos_mask, axis=1) + 1
        d_lengths = torch.tensor(d_lengths).to(
            d_ids.device).long() + no_eos_idx_sum
        batch_size, max_len = d_ids.size()
        idxes = torch.arange(0, max_len).to(d_ids.device)
        idxes = idxes.unsqueeze(0).repeat(batch_size, 1)
        d_mask = (idxes < d_lengths.unsqueeze(1))
        d_ids = d_ids.long() * d_mask.long()
        return d_ids

    def _copy_logits(self, attn_logits, c_ids):
        """Scatter context-attention logits onto the vocabulary.

        ``attn_logits`` and ``c_ids`` are both (N, c_len). Returns (N, ntokens)
        where each token id holds the max attention logit over context positions
        with that id. Entries left at the -10000 fill value (ids that never
        occur in the context) are reset to 0.
        """
        copy_logits = torch.zeros(attn_logits.size(0), self.ntokens).to(c_ids.device)
        copy_logits = copy_logits - 10000.0
        copy_logits, _ = scatter_max(attn_logits, c_ids, out=copy_logits)
        return copy_logits.masked_fill(copy_logits == -10000.0, 0)

    @staticmethod
    def _info_loss(discriminator, x_mean, y_mean):
        """Contrastive MI loss between paired batch vectors ``x_mean`` and ``y_mean``.

        Real pairs (x_i, y_i) are labelled 1. "Fake" pairs are formed by rotating
        either side by one position within the batch and are labelled 0.
        NOTE(review): with a batch of one, the rotated pairs are identical to the
        real pairs, so this loss carries no signal.
        """
        def shift(t):
            return torch.cat([t[-1].unsqueeze(0), t[:-1]], dim=0)

        true_logits = discriminator(x_mean, y_mean)
        fake_y_logits = discriminator(x_mean, shift(y_mean))
        fake_x_logits = discriminator(shift(x_mean), y_mean)
        fake_logits = torch.cat([fake_y_logits, fake_x_logits], dim=0)

        true_loss = F.binary_cross_entropy_with_logits(true_logits,
                                                       torch.ones_like(true_logits))
        fake_loss = 0.5 * F.binary_cross_entropy_with_logits(fake_logits,
                                                             torch.zeros_like(fake_logits))
        return 0.5 * (true_loss + fake_loss)

    def forward(self, d_init_state,q_init_state, c_ids, q_ids, a_ids,d_ids,q_maxouted):
        """Teacher-forced decoding of the gold distractor ``d_ids`` (batch, d_len).

        ``q_maxouted`` is presumably the question decoder's max-out output
        (batch, q_len, emsize); TODO confirm against QuestionDecoder.

        Returns ``(logits, loss_info_with_answer, loss_info_with_question)``,
        where ``logits`` is (batch, d_len, ntokens).
        """
        batch_size, max_d_len = d_ids.size()

        c_outputs = self.context_lstm(c_ids, a_ids)

        c_mask, _ = return_mask_lengths(c_ids)
        d_mask, d_lengths = return_mask_lengths(d_ids)
        q_mask, q_lengths = return_mask_lengths(q_ids)

        # distractor / question encoders
        d_embeddings = self.embedding(d_ids)
        d_outputs, _ = self.distractor_lstm(d_embeddings, d_lengths, d_init_state)
        q_embeddings = self.embedding(q_ids)
        q_outputs, _ = self.question_lstm(q_embeddings, q_lengths, q_init_state)

        # the same projected distractor states query both the question and the context
        d_query = self.distractor_linear(d_outputs)
        dq_mask = torch.matmul(d_mask.unsqueeze(2), q_mask.unsqueeze(1))
        q_attned_by_d, _ = cal_attn(d_query, q_outputs, dq_mask)
        dc_mask = torch.matmul(d_mask.unsqueeze(2), c_mask.unsqueeze(1))
        # only the context attention logits feed the copy mechanism
        c_attned_by_d, attn_logits = cal_attn(d_query, c_outputs, dc_mask)

        # gen logits (max-out over pairs of units)
        d_q_concated = torch.cat([d_outputs, c_attned_by_d, q_attned_by_d], dim=2)
        d_q_concated = self.concat_linear(d_q_concated)
        d_maxouted, _ = d_q_concated.view(
            batch_size, max_d_len, self.emsize, 2).max(dim=-1)
        gen_logits = self.logit_linear(d_maxouted)

        # copy logits: flatten (batch, d_len) so every decoder step scatters independently
        bd = batch_size * max_d_len
        c_ids_rep = c_ids.unsqueeze(1).repeat(
            1, max_d_len, 1).view(bd, -1).contiguous()
        copy_logits = self._copy_logits(attn_logits.view(bd, -1).contiguous(), c_ids_rep)
        copy_logits = copy_logits.view(batch_size, max_d_len, -1).contiguous()

        logits = gen_logits + copy_logits

        # masked mean embeddings used by both MI critics
        a_emb = c_outputs * a_ids.float().unsqueeze(2)
        a_mean_emb = torch.sum(a_emb, 1) / a_ids.sum(1).unsqueeze(1).float()
        d_emb = d_maxouted * d_mask.unsqueeze(2)
        d_mean_emb = torch.sum(d_emb, 1) / d_lengths.unsqueeze(1).float()
        q_emb = q_maxouted * q_mask.unsqueeze(2)
        q_mean_emb = torch.sum(q_emb, 1) / q_lengths.unsqueeze(1).float()

        # mutual information btw answer and distractor
        loss_info_with_answer = self._info_loss(self.discriminator, d_mean_emb, a_mean_emb)
        # mutual information btw distractor and question
        loss_info_with_question = self._info_loss(self.discriminator2, q_mean_emb, d_mean_emb)

        return logits, loss_info_with_answer,loss_info_with_question

    def _step_logits(self, d_embeddings, state, c_ids, c_outputs, c_mask, q_outputs, q_mask):
        """Run one decoder step; returns ``(logits (batch, 1, ntokens), new_state)``."""
        batch_size = c_ids.size(0)
        d_outputs, state = self.distractor_lstm.lstm(d_embeddings, state)

        d_query = self.distractor_linear(d_outputs)
        q_attned_by_d, _ = cal_attn(d_query, q_outputs, q_mask.unsqueeze(1))
        c_attned_by_d, attn_logits = cal_attn(d_query, c_outputs, c_mask.unsqueeze(1))

        d_concated = torch.cat([d_outputs, c_attned_by_d, q_attned_by_d], dim=2)
        d_concated = self.concat_linear(d_concated)
        d_maxouted, _ = d_concated.view(batch_size, 1, self.emsize, 2).max(dim=-1)
        gen_logits = self.logit_linear(d_maxouted)

        copy_logits = self._copy_logits(attn_logits.squeeze(1), c_ids)
        return gen_logits + copy_logits.unsqueeze(1), state

    def _decode(self, d_init_state, q_init_state, c_ids, a_ids, q_ids, pick_next):
        """Autoregressively decode ``max_d_len`` tokens (including SOS).

        ``pick_next`` maps step logits (batch, 1, ntokens) to the next ids (batch, 1).
        """
        c_mask, _ = return_mask_lengths(c_ids)
        q_mask, q_lengths = return_mask_lengths(q_ids)
        c_outputs = self.context_lstm(c_ids, a_ids)

        # the question is encoded once and attended over at every step
        q_embeddings = self.embedding(q_ids)
        q_outputs, _ = self.question_lstm(q_embeddings, q_lengths, q_init_state)

        batch_size = c_ids.size(0)
        d_ids = torch.LongTensor([self.sos_id] * batch_size).unsqueeze(1)
        d_ids = d_ids.to(c_ids.device)
        token_type_ids = torch.zeros_like(d_ids)
        position_ids = torch.zeros_like(d_ids)
        d_embeddings = self.embedding(d_ids, token_type_ids, position_ids)

        state = d_init_state
        all_d_ids = [d_ids]
        for _ in range(self.max_d_len - 1):
            position_ids = position_ids + 1
            logits, state = self._step_logits(d_embeddings, state, c_ids,
                                              c_outputs, c_mask, q_outputs, q_mask)
            d_ids = pick_next(logits)
            all_d_ids.append(d_ids)
            d_embeddings = self.embedding(d_ids, token_type_ids, position_ids)

        return self.postprocess(torch.cat(all_d_ids, 1))

    def generate(self, d_init_state,q_init_state, c_ids, a_ids,q_ids):
        """Greedy (argmax) decoding; returns (batch, max_d_len) ids, zeroed after EOS."""
        return self._decode(d_init_state, q_init_state, c_ids, a_ids, q_ids,
                            lambda logits: torch.argmax(logits, 2))

    def sample(self, d_init_state,q_init_state, c_ids, a_ids,q_ids, top_k=2, top_p=0.8):
        """Top-k / nucleus sampling; returns (batch, max_d_len) ids, zeroed after EOS.

        Previously this looped over the nonexistent ``self.max_q_len`` and called a
        filtering method this class did not define. It now decodes ``max_d_len``
        tokens like ``generate``.
        """
        def pick_next(logits):
            logits = self.top_k_top_p_filtering(logits.squeeze(1), top_k, top_p=top_p)
            probs = F.softmax(logits, dim=-1)
            return torch.multinomial(probs, num_samples=1)  # [b,1]

        return self._decode(d_init_state, q_init_state, c_ids, a_ids, q_ids, pick_next)

    def top_k_top_p_filtering(self, logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
        """Return a copy of ``logits`` (batch, vocab) with tokens outside the
        top-k and/or nucleus (top-p) set replaced by ``filter_value``.

        The input tensor is not modified.
        """
        logits = logits.clone()
        top_k = min(top_k, logits.size(-1))  # safety check
        if top_k > 0:
            # drop every token whose logit is below the k-th best
            kth_best = torch.topk(logits, top_k)[0][..., -1, None]
            logits[logits < kth_best] = filter_value

        if top_p > 0.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_to_remove = cumulative_probs > top_p
            # shift right so the first token crossing the threshold is kept
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False
            to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
            logits[to_remove] = filter_value
        return logits

class DiscreteVAE(nn.Module):
    """VAE that generates an answer span, a question and three distractors from a context.

    Latents:
    - ``zq``: question latent, Gaussian (``nzqdim``).
    - ``za``: answer latent, categorical, ``nza`` variables with ``nzadim`` classes each.
    - ``zd1..zd3``: one Gaussian latent (``nzddim``) per distractor, all decoded by
      a single shared ``DistractorDecoder``.

    The posterior encoder sees the gold question/answer/distractors; the prior
    encoder sees only the context. KL terms tie the two distributions together.
    """
    def __init__(self, args):
        super(DiscreteVAE, self).__init__()
        tokenizer = BertTokenizer.from_pretrained(args.bert_model)
        padding_idx = tokenizer.vocab['[PAD]']
        sos_id = tokenizer.vocab['[CLS]']
        eos_id = tokenizer.vocab['[SEP]']
        ntokens = len(tokenizer.vocab)

        bert_model = args.bert_model
        if "large" in bert_model:
            emsize = 1024
        else:
            emsize = 768

        enc_nhidden = args.enc_nhidden
        enc_nlayers = args.enc_nlayers
        enc_dropout = args.enc_dropout
        dec_a_nhidden = args.dec_a_nhidden
        dec_a_nlayers = args.dec_a_nlayers
        dec_a_dropout = args.dec_a_dropout
        self.dec_q_nhidden = dec_q_nhidden = args.dec_q_nhidden
        self.dec_q_nlayers = dec_q_nlayers = args.dec_q_nlayers
        self.dec_d_nhidden = dec_d_nhidden = args.dec_d_nhidden
        self.dec_d_nlayers = dec_d_nlayers = args.dec_d_nlayers
        dec_q_dropout = args.dec_q_dropout
        dec_d_dropout = args.dec_d_dropout
        self.nzqdim = nzqdim = args.nzqdim
        self.nzddim = nzddim = args.nzddim
        self.nza = nza = args.nza
        self.nzadim = nzadim = args.nzadim

        self.lambda_kl = args.lambda_kl
        self.lambda_info = args.lambda_info

        max_q_len = args.max_q_len
        max_d_len = args.max_d_len

        embedding = Embedding(bert_model)
        contextualized_embedding = ContextualizedEmbedding(bert_model)
        # freeze embedding
        for param in embedding.parameters():
            param.requires_grad = False
        for param in contextualized_embedding.parameters():
            param.requires_grad = False

        self.posterior_encoder = PosteriorEncoder(embedding, emsize,
                                                  enc_nhidden, enc_nlayers,
                                                  nzqdim, nza, nzadim,nzddim,
                                                  enc_dropout)

        self.prior_encoder = PriorEncoder(embedding, emsize,
                                          enc_nhidden, enc_nlayers,
                                          nzqdim, nza, nzadim,nzddim, enc_dropout)

        self.answer_decoder = AnswerDecoder(contextualized_embedding, emsize,
                                            dec_a_nhidden, dec_a_nlayers,
                                            dec_a_dropout)

        self.question_decoder = QuestionDecoder(sos_id, eos_id,
                                                embedding, contextualized_embedding, emsize,
                                                dec_q_nhidden, ntokens, dec_q_nlayers,
                                                dec_q_dropout,
                                                max_q_len)

        self.distractor_decoder = DistractorDecoder(sos_id, eos_id,
                                                embedding, contextualized_embedding, emsize,
                                                dec_d_nhidden, ntokens, dec_d_nlayers,
                                                dec_q_nhidden, dec_q_nlayers,
                                                dec_d_dropout,dec_q_dropout,
                                                max_d_len)

        # latent -> decoder initial-state projections
        self.q_h_linear = nn.Linear(nzqdim, dec_q_nlayers * dec_q_nhidden)
        self.q_c_linear = nn.Linear(nzqdim, dec_q_nlayers * dec_q_nhidden)
        self.d_h_linear = nn.Linear(nzddim, dec_d_nlayers * dec_d_nhidden)
        self.d_c_linear = nn.Linear(nzddim, dec_d_nlayers * dec_d_nhidden)
        self.a_linear = nn.Linear(nza * nzadim, emsize, False)

        self.q_rec_criterion = nn.CrossEntropyLoss(ignore_index=padding_idx)
        self.d_rec_criterion = nn.CrossEntropyLoss(ignore_index=padding_idx)
        self.gaussian_kl_criterion = GaussianKLLoss()
        self.categorical_kl_criterion = CategoricalKLLoss()

    def return_init_state_qa(self, zq, za):
        """Project ``zq`` to the question LSTM's ``(h, c)`` and ``za`` to the answer decoder input.

        ``h`` and ``c`` are reshaped to (dec_q_nlayers, batch, dec_q_nhidden).
        """
        q_init_h = self.q_h_linear(zq)
        q_init_c = self.q_c_linear(zq)

        q_init_h = q_init_h.view(-1, self.dec_q_nlayers,
                                 self.dec_q_nhidden).transpose(0, 1).contiguous()
        q_init_c = q_init_c.view(-1, self.dec_q_nlayers,
                                 self.dec_q_nhidden).transpose(0, 1).contiguous()
        q_init_state = (q_init_h, q_init_c)

        za_flatten = za.view(-1, self.nza * self.nzadim)
        a_init_state = self.a_linear(za_flatten)

        return q_init_state, a_init_state

    def return_init_state_d(self, zd):
        """Project one distractor latent to the distractor LSTM's ``(h, c)``.

        Both are reshaped to (dec_d_nlayers, batch, dec_d_nhidden).
        """
        d_init_h = self.d_h_linear(zd)
        d_init_c = self.d_c_linear(zd)

        d_init_h = d_init_h.view(-1, self.dec_d_nlayers,
                                 self.dec_d_nhidden).transpose(0, 1).contiguous()
        d_init_c = d_init_c.view(-1, self.dec_d_nlayers,
                                 self.dec_d_nhidden).transpose(0, 1).contiguous()
        d_init_state = (d_init_h, d_init_c)

        return d_init_state

    def forward(self, c_ids, q_ids, a_ids, start_positions, end_positions,d1_ids, d2_ids, d3_ids):
        """Compute the training loss for one batch.

        Returns ``(loss, loss_q_rec, loss_a_rec, loss_zq_kl, loss_za_kl,
        loss_info, loss_d_rec, loss_zd_kl, loss_info_d_q)``. ``loss_info`` is
        already scaled by ``lambda_info``.

        ``start_positions`` / ``end_positions`` are not modified; they used to
        be clamped in place.
        """
        # The posterior samples one zd per gold distractor but returns a single
        # (mu, logvar) pair, so there is one KL term for zd.
        # NOTE(review): confirm that the single zd KL is intended for three samples.
        posterior_zq_mu, posterior_zq_logvar, posterior_zq, \
            posterior_za_prob, posterior_za, \
            posterior_zd_mu, posterior_zd_logvar, \
                posterior_zd1, posterior_zd2, posterior_zd3 \
            = self.posterior_encoder(c_ids, q_ids, a_ids,d1_ids,d2_ids,d3_ids)

        prior_zq_mu, prior_zq_logvar, _, \
            prior_za_prob, _, \
            prior_zd_mu, prior_zd_logvar, _, _, _ \
            = self.prior_encoder(c_ids)

        q_init_state, a_init_state = self.return_init_state_qa(
            posterior_zq, posterior_za)
        d1_init_state = self.return_init_state_d(posterior_zd1)
        d2_init_state = self.return_init_state_d(posterior_zd2)
        d3_init_state = self.return_init_state_d(posterior_zd3)
        # answer decoding
        start_logits, end_logits = self.answer_decoder(a_init_state, c_ids)
        # question decoding
        q_logits, loss_info, q_maxouted = self.question_decoder(
            q_init_state, c_ids, q_ids, a_ids)
        # distractor decoding (one shared decoder, one latent per distractor)
        d1_logits, d1_loss_info_with_answer,d1_loss_info_with_question = self.distractor_decoder(
            d1_init_state,q_init_state, c_ids, q_ids, a_ids,d1_ids,q_maxouted)
        d2_logits, d2_loss_info_with_answer,d2_loss_info_with_question = self.distractor_decoder(
            d2_init_state,q_init_state, c_ids, q_ids, a_ids,d2_ids,q_maxouted)
        d3_logits, d3_loss_info_with_answer,d3_loss_info_with_question = self.distractor_decoder(
            d3_init_state,q_init_state, c_ids, q_ids, a_ids,d3_ids,q_maxouted)

        # q rec loss: predict token t+1 from the output at step t
        loss_q_rec = self.q_rec_criterion(q_logits[:, :-1, :].transpose(1, 2).contiguous(),
                                          q_ids[:, 1:])

        # a rec loss: positions beyond the context are clamped to max_c_len and ignored.
        # Out-of-place clamp so the caller's batch tensors are left untouched.
        max_c_len = c_ids.size(1)
        a_rec_criterion = nn.CrossEntropyLoss(ignore_index=max_c_len)
        start_positions = start_positions.clamp(0, max_c_len)
        end_positions = end_positions.clamp(0, max_c_len)
        loss_start_a_rec = a_rec_criterion(start_logits, start_positions)
        loss_end_a_rec = a_rec_criterion(end_logits, end_positions)
        loss_a_rec = 0.5 * (loss_start_a_rec + loss_end_a_rec)

        # distractor rec losses (same shift-by-one scheme as the question)
        loss_d1_rec = self.d_rec_criterion(d1_logits[:, :-1, :].transpose(1, 2).contiguous(),
                                          d1_ids[:, 1:])
        loss_d2_rec = self.d_rec_criterion(d2_logits[:, :-1, :].transpose(1, 2).contiguous(),
                                          d2_ids[:, 1:])
        loss_d3_rec = self.d_rec_criterion(d3_logits[:, :-1, :].transpose(1, 2).contiguous(),
                                          d3_ids[:, 1:])

        # kl losses
        loss_zq_kl = self.gaussian_kl_criterion(posterior_zq_mu,
                                                posterior_zq_logvar,
                                                prior_zq_mu,
                                                prior_zq_logvar)
        loss_zd_kl = self.gaussian_kl_criterion(posterior_zd_mu,
                                                posterior_zd_logvar,
                                                prior_zd_mu,
                                                prior_zd_logvar)
        loss_za_kl = self.categorical_kl_criterion(posterior_za_prob,
                                                   prior_za_prob)

        loss_d_rec = loss_d1_rec + loss_d2_rec + loss_d3_rec
        # NOTE(review): the distractor/question info losses are averaged over the three
        # distractors, while the distractor/answer info losses are summed. Confirm this
        # weighting is intentional.
        loss_info_d_q = (d1_loss_info_with_question+d2_loss_info_with_question+d3_loss_info_with_question)/3
        d_loss_info_with_answer = d1_loss_info_with_answer + d2_loss_info_with_answer + d3_loss_info_with_answer
        loss_kl = self.lambda_kl * (loss_zq_kl + loss_za_kl+loss_zd_kl)
        loss_info = self.lambda_info * (loss_info+d_loss_info_with_answer+loss_info_d_q)

        loss = loss_q_rec + loss_a_rec + loss_kl + loss_info + loss_d_rec

        return loss, \
            loss_q_rec, loss_a_rec, \
            loss_zq_kl, loss_za_kl, \
            loss_info, loss_d_rec,loss_zd_kl, loss_info_d_q

    def generate(self, zq, za, c_ids,zd1, zd2, zd3):
        """Greedily decode the answer span, the question and three distractors from the given latents.

        Returns ``(q_ids, start_positions, end_positions, d1_ids, d2_ids, d3_ids)``.
        """
        q_init_state, a_init_state = self.return_init_state_qa(zq, za)
        d1_init_state = self.return_init_state_d(zd1)
        d2_init_state = self.return_init_state_d(zd2)
        d3_init_state = self.return_init_state_d(zd3)
        a_ids, start_positions, end_positions = self.answer_decoder.generate(
            a_init_state, c_ids)

        q_ids = self.question_decoder.generate(q_init_state, c_ids, a_ids)
        d1_ids = self.distractor_decoder.generate(d1_init_state,q_init_state, c_ids, a_ids, q_ids)
        d2_ids = self.distractor_decoder.generate(d2_init_state,q_init_state, c_ids, a_ids, q_ids)
        d3_ids = self.distractor_decoder.generate(d3_init_state,q_init_state, c_ids, a_ids, q_ids)

        return q_ids, start_positions, end_positions, d1_ids, d2_ids, d3_ids

    def return_answer_logits(self, zq, za, c_ids,zd1, zd2, zd3):
        """Return the answer decoder's ``(start_logits, end_logits)``.

        The ``zd*`` arguments are unused; they are kept so the signature mirrors ``generate``.
        """
        _, a_init_state = self.return_init_state_qa(zq, za)

        start_logits, end_logits = self.answer_decoder(a_init_state, c_ids)

        return start_logits, end_logits

In [6]:
# from trainerCustom import VAETrainer
# from modelsCustom import DiscreteVAE, return_mask_lengths

In [7]:
class dotdict(dict):
    """Dict whose keys can also be read, set and deleted as attributes.

    Reading a missing key as an attribute returns ``None`` (``dict.get``
    semantics). Dunder lookups (``__getstate__``, ``__deepcopy__``, ...) raise
    ``AttributeError`` so that pickle / ``copy`` / ``torch.save`` fall back to
    their defaults. Previously such lookups returned ``None``, which pickle then
    tried to call on Python < 3.11.
    """
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __getattr__(self, name):
        # only reached when normal attribute lookup fails
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return self.get(name)
def setArguments():
    """Return this notebook's hyper-parameters as a ``dotdict`` (attribute access)."""
    return dotdict({
        # run control
        "seed": 1004,
        "debug": True,
        "train_dir": '../data/sciq/squad_format/test.json',
        "dev_dir": '../data/sciq/squad_format/test.json',

        # sequence limits (tokens)
        "max_c_len": 384,
        "max_q_len": 64,
        'max_d_len': 15,

        # optimisation
        "model_dir": "../save/vae-checkpoint",
        "epochs": 20,
        "lr": 1e-3,
        "batch_size": 1,
        "weight_decay": 0.0,
        "clip": 5.0,

        # encoder / decoder sizes
        "bert_model": 'bert-base-uncased',
        "enc_nhidden": 300,
        "enc_nlayers": 1,
        "enc_dropout": 0.2,
        "dec_a_nhidden": 300,
        "dec_a_nlayers": 1,
        "dec_a_dropout": 0.2,
        "dec_q_nhidden": 900,
        "dec_q_nlayers": 2,
        "dec_q_dropout": 0.3,
        'dec_d_nhidden': 900,
        'dec_d_nlayers': 2,
        'dec_d_dropout': 0.3,

        # latent sizes and loss weights
        "nzqdim": 50,
        'nzddim': 50,
        "nza": 20,
        "nzadim": 10,
        "lambda_kl": 0.1,
        "lambda_info": 1.0,
    })
args = setArguments()

if args.debug:
    # debug runs write checkpoints into a throwaway local directory
    print("Debug Mode On.")
    args.model_dir = "./dummy"

# create the checkpoint directory and store its absolute path back on args
model_dir = args.model_dir
os.makedirs(model_dir, exist_ok=True)
args.model_dir = os.path.abspath(model_dir)

# seed every RNG used downstream: python, numpy, torch CPU, torch CUDA
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(args.seed)
del _seed_fn


Debug Mode On.


In [8]:
class VAETrainer(object):
    """Owns a ``DiscreteVAE`` and its Adam optimizer, and exposes training and generation helpers.

    After each ``train`` call, the latest scalar losses are stored on the
    instance (``loss_q_rec``, ``loss_a_rec``, ...) for logging.
    """
    def __init__(self, args):
        self.args = args
        self.clip = args.clip
        self.device = args.device

        self.vae = DiscreteVAE(args).to(self.device)
        params = filter(lambda p: p.requires_grad, self.vae.parameters())
        # weight_decay was configured but previously ignored; 0.0/None keeps plain Adam
        weight_decay = getattr(args, 'weight_decay', 0.0) or 0.0
        self.optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=weight_decay)

        self.loss_q_rec = 0
        self.loss_a_rec = 0
        self.loss_d_rec = 0
        self.loss_zq_kl = 0
        self.loss_za_kl = 0
        self.loss_zd_kl = 0
        self.loss_info = 0
        self.loss_info_d_q = 0

    def train(self, c_ids, q_ids, a_ids, start_positions, end_positions,d1_ids,d2_ids,d3_ids):
        """Run one optimisation step on a batch and record the scalar losses.

        Gradients are clipped to a total L2 norm of ``self.clip`` when it is
        positive. Previously the configured ``clip`` value was never applied.
        """
        self.vae = self.vae.train()
        # Forward
        loss, \
        loss_q_rec, loss_a_rec, \
        loss_zq_kl, loss_za_kl, \
        loss_info, loss_d_rec, loss_zd_kl, loss_info_d_q \
        = self.vae(c_ids, q_ids, a_ids, start_positions, end_positions,d1_ids,d2_ids,d3_ids)
        # Backward
        self.optimizer.zero_grad()
        loss.backward()
        if self.clip and self.clip > 0:
            torch.nn.utils.clip_grad_norm_(self.vae.parameters(), self.clip)

        # Step
        self.optimizer.step()

        self.loss_q_rec = loss_q_rec.item()
        self.loss_a_rec = loss_a_rec.item()
        self.loss_zq_kl = loss_zq_kl.item()
        self.loss_za_kl = loss_za_kl.item()
        self.loss_zd_kl = loss_zd_kl.item()
        self.loss_info = loss_info.item()
        self.loss_info_d_q = loss_info_d_q.item()
        self.loss_d_rec = loss_d_rec.item()

    def generate_posterior(self, c_ids, q_ids, a_ids,d1_ids,d2_ids,d3_ids):
        """Decode from posterior latents (given the gold q/a/d) without gradients.

        Returns ``(q_ids, start_positions, end_positions, zq, d1_ids, d2_ids, d3_ids, zd1, zd2, zd3)``.
        """
        self.vae = self.vae.eval()
        with torch.no_grad():
            _, _, zq, _, za, _, _, zd1, zd2, zd3 = self.vae.posterior_encoder(c_ids, q_ids, a_ids,d1_ids,d2_ids,d3_ids)
            q_ids, start_positions, end_positions, d1_ids, d2_ids, d3_ids = self.vae.generate(zq, za, c_ids, zd1, zd2, zd3)
        return q_ids, start_positions, end_positions, zq, d1_ids, d2_ids, d3_ids, zd1,zd2,zd3

    def generate_answer_logits(self, c_ids, q_ids, a_ids, d1_ids,d2_ids,d3_ids):
        """Return the answer start/end logits computed from posterior latents, without gradients."""
        self.vae = self.vae.eval()
        with torch.no_grad():
            _, _, zq, _, za, _, _, zd1, zd2, zd3 = self.vae.posterior_encoder(c_ids, q_ids, a_ids, d1_ids,d2_ids,d3_ids)
            start_logits, end_logits = self.vae.return_answer_logits(zq, za, c_ids, zd1, zd2, zd3)
        return start_logits, end_logits

    def generate_prior(self, c_ids):
        """Decode from prior latents (context only) without gradients; same return layout as ``generate_posterior``."""
        self.vae = self.vae.eval()
        with torch.no_grad():
            _, _, zq, _, za, _, _, zd1, zd2, zd3 = self.vae.prior_encoder(c_ids)
            q_ids, start_positions, end_positions, d1_ids, d2_ids, d3_ids = self.vae.generate(zq, za, c_ids, zd1, zd2, zd3)
        return q_ids, start_positions, end_positions, zq, d1_ids, d2_ids, d3_ids, zd1, zd2, zd3

    def save(self, filename):
        """Write ``{'state_dict', 'args'}`` to ``filename`` with ``torch.save``."""
        params = {
            'state_dict': self.vae.state_dict(),
            'args': self.args
        }
        torch.save(params, filename)


In [9]:
def main(args):
    """Train the VAE and keep the best checkpoints according to the dev metrics.

    Each epoch runs one pass over the training loader and refreshes three
    text-only progress bars with the trainer's latest losses. From epoch
    ``args.eval_start_epoch`` onwards (default 0, i.e. after every epoch) it then:
      * calls ``eval_vae``;
      * writes a checkpoint into ``args.model_dir`` whenever F1, Q-BLEU or
        D-BLEU improves on its best value so far.

    Args:
        args: configuration namespace.
            * Read here: ``bert_model``, ``train_dir``, ``dev_dir``, ``model_dir``,
              ``epochs`` and, optionally, ``eval_start_epoch``.
            * Also forwarded to the data loaders, ``VAETrainer`` and ``eval_vae``.
            * ``args.device`` is overwritten with ``torch.cuda.current_device()``,
              so a CUDA device is required.
    """
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
    train_loader, _, _ = get_squad_data_loader(tokenizer, args.train_dir,
                                               shuffle=True, args=args)
    # The dev split is kept as the whole tuple returned by the loader helper;
    # eval_vae receives it unchanged.
    eval_data = get_squad_data_loader(tokenizer, args.dev_dir,
                                      shuffle=False, args=args)

    args.device = torch.cuda.current_device()

    # trainer.save() below writes into model_dir; create it up front so the
    # first improving epoch cannot fail on a missing directory.
    os.makedirs(args.model_dir, exist_ok=True)
    # First epoch at which evaluation/checkpointing runs (0 = every epoch).
    eval_start_epoch = int(getattr(args, "eval_start_epoch", 0))

    trainer = VAETrainer(args)

    # Description-only bars serve as in-place, continuously updated status lines.
    loss_log1 = tqdm(total=0, bar_format='{desc}')
    loss_log2 = tqdm(total=0, bar_format='{desc}')
    loss_log3 = tqdm(total=0, bar_format='{desc}')
    eval_log = tqdm(total=0, bar_format='{desc}')
    best_eval_log = tqdm(total=0, bar_format='{desc}')
    status_bars = (loss_log1, loss_log2, loss_log3, eval_log, best_eval_log)

    print("MODEL DIR: " + args.model_dir)

    best_bleu, best_bleu_d, best_em, best_f1 = 0.0, 0.0, 0.0, 0.0
    try:
        for epoch in trange(int(args.epochs), desc="Epoch", position=0):
            for batch in tqdm(train_loader, desc="Train iter", leave=False, position=1):
                # Batches carry three separate d_ids tensors (d1/d2/d3).
                (c_ids, q_ids, a_ids, start_positions, end_positions,
                 d1_ids, d2_ids, d3_ids) = batch_to_device(batch, args.device)
                trainer.train(c_ids, q_ids, a_ids, start_positions, end_positions,
                              d1_ids, d2_ids, d3_ids)

                str1 = 'Q REC : {:06.4f} A REC : {:06.4f} D REC : {:06.4f}'
                str2 = 'ZQ KL : {:06.4f} ZA KL : {:06.4f} ZD KL : {:06.4f}'
                str3 = 'L_INFO : {:06.4f}  L_INFO_D_Q : {:06.4f}'
                str1 = str1.format(float(trainer.loss_q_rec), float(trainer.loss_a_rec),
                                   float(trainer.loss_d_rec))
                str2 = str2.format(float(trainer.loss_zq_kl), float(trainer.loss_za_kl),
                                   float(trainer.loss_zd_kl))
                str3 = str3.format(float(trainer.loss_info), float(trainer.loss_info_d_q))
                loss_log1.set_description_str(str1)
                loss_log2.set_description_str(str2)
                loss_log3.set_description_str(str3)

            if epoch < eval_start_epoch:
                continue

            metric_dict, bleu, _, bleu_d = eval_vae(epoch, args, trainer, eval_data)
            f1 = metric_dict["f1"]
            em = metric_dict["exact_match"]
            # eval_vae's BLEU values are scaled by 100 here for display and comparison.
            bleu = bleu * 100
            bleu_d = bleu_d * 100
            _str = '{}-th Epochs Q-BLEU : {:02.2f} D-BLEU : {:02.2f} EM : {:02.2f} F1 : {:02.2f}'
            _str = _str.format(epoch, bleu, bleu_d, em, f1)
            eval_log.set_description_str(_str)

            # EM is tracked for reporting only; no checkpoint is tied to it.
            if em > best_em:
                best_em = em
            if f1 > best_f1:
                best_f1 = f1
                trainer.save(os.path.join(args.model_dir, "best_f1_model.pt"))
            if bleu > best_bleu:
                best_bleu = bleu
                trainer.save(os.path.join(args.model_dir, "best_q_bleu_model.pt"))
            if bleu_d > best_bleu_d:
                best_bleu_d = bleu_d
                trainer.save(os.path.join(args.model_dir, "best_d_bleu_model.pt"))

            _str = 'BEST Q-BLEU : {:02.2f} D-BLEU : {:02.2f} EM : {:02.2f} F1 : {:02.2f}'
            _str = _str.format(best_bleu, best_bleu_d, best_em, best_f1)
            best_eval_log.set_description_str(_str)
    finally:
        # Release the status bars even if training/evaluation raises.
        for bar in status_bars:
            bar.close()


In [10]:
# Run training end to end. `args` is not defined in this cell; it must already
# exist from an earlier configuration cell.
main(args)

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

100%|██████████| 1/1 [00:00<00:00,  3.27it/s]
100%|██████████| 820/820 [00:25<00:00, 31.69it/s]
100%|██████████| 1/1 [00:00<00:00,  1.11it/s]
100%|██████████| 820/820 [00:16<00:00, 50.91it/s]


Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predicti











MODEL DIR: /content/drive/MyDrive/Second/MCQ/vae/dummy


Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

Train iter:   0%|          | 0/825 [00:00<?, ?it/s]





Eval iter:   0%|          | 0/825 [00:00<?, ?it/s][A[A[A[A



Eval iter:   0%|          | 1/825 [00:00<09:53,  1.39it/s][A[A[A[A



Eval iter:   0%|          | 2/825 [00:01<10:07,  1.35it/s][A[A[A[A



Eval iter:   0%|          | 3/825 [00:02<09:38,  1.42it/s][A[A[A[A



Eval iter:   0%|          | 4/825 [00:02<08:36,  1.59it/s][A[A[A[A



Eval iter:   1%|          | 5/825 [00:03<07:51,  1.74it/s][A[A[A[A



Eval iter:   1%|          | 6/825 [00:03<07:20,  1.86it/s][A[A[A[A



Eval iter:   1%|          | 7/825 [00:04<08:26,  1.61it/s][A[A[A[A



Eval iter:   1%|          | 8/825 [00:04<07:49,  1.74it/s][A[A[A[A



Eval iter:   1%|          | 9/825 [00:05<07:39,  1.78it/s][A[A[A[A



Eval iter:   1%|          | 10/825 [00:06<08:15,  1.65it/s][A[A[A[A



Eval iter:   1%|▏         | 11/825 [00:06<07:48,  1.74it/s][A[A[A[A



Eval iter:   1%|▏         | 12/825 [00:07<07:38,  1.77it/s][A[A[A[A



Eval iter:   2%|▏         | 13/825 [00:

TypeError: ignored