In [15]:
# ESIM looks complex, so let's build it first.
# The following is adapted almost verbatim from yjqiang's implementation.
import os
import re
import time

import matplotlib.pyplot as plt
import numpy
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence

# Seed PyTorch's global RNG so the randomly initialised layer weights below are reproducible.
torch.manual_seed(1)

class BiLSTM(nn.Module):
    """Length-aware LSTM encoder for padded, batch-first sequences.

    Args:
        embedding_dim: Size of the last dimension of the input vectors.
        hidden_size: Hidden size of the LSTM *per direction*. The output feature
            size is ``2 * hidden_size`` when ``bidirectional`` is true and
            ``hidden_size`` otherwise.
        bidirectional: Whether the LSTM reads the sequence in both directions.
    """

    def __init__(self, embedding_dim, hidden_size, bidirectional=True):
        super().__init__()
        # nn.LSTM's third positional parameter is `num_layers`, so the flag must be
        # passed by keyword. Passed positionally, `True` is read as num_layers=1 and
        # the LSTM silently stays unidirectional.
        self.bilstm = nn.LSTM(embedding_dim, hidden_size, bidirectional=bidirectional)

    def forward(self, tensor_sentences, length_list):
        """Encode a padded batch while ignoring the padded time steps.

        Args:
            tensor_sentences: Float tensor of shape
                ``(batch_size, max_sentence_length, embedding_dim)``.
            length_list: True length of every sentence (list or 1-D tensor,
                one entry per batch element, each >= 1).

        Returns:
            Tensor of shape ``(batch_size, max_sentence_length, num_directions * hidden_size)``;
            positions past a sentence's true length are zero.
        """
        # pack_padded_sequence requires the lengths to live on the CPU.
        lengths = torch.as_tensor(length_list, dtype=torch.int64).cpu()
        packed_sentences = pack_padded_sequence(tensor_sentences, lengths, batch_first=True, enforce_sorted=False)
        output, _ = self.bilstm(packed_sentences)
        # Without total_length the result is only as long as the longest sentence in
        # the batch, which breaks callers that rely on the original padded length.
        result, _ = pad_packed_sequence(output, batch_first=True, total_length=tensor_sentences.size(1))
        return result

![alt text](encoder.png)

![alt text](whatesimcando.png)

In [None]:
class LocalInferenceModeling(nn.Module):
    """ESIM local inference step: soft-align two encoded sentences, then enhance them."""

    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=2)  # normalise over the last axis of a (batch, rows, cols) score tensor

    @staticmethod
    def _padding_mask(length_list, max_length, device):
        """Return a ``(batch_size, max_length)`` bool mask that is True on padded positions."""
        lengths = torch.as_tensor(length_list, device=device)
        positions = torch.arange(max_length, device=device)
        return positions.unsqueeze(0) >= lengths.unsqueeze(1)

    def attention(self, x1_bar, length_list1, x2_bar, length_list2):
        """Compute the soft alignment of each sentence against the other.

        Args:
            x1_bar: Encoded first sentence, ``(batch_size, max_sentence_length1, dim)``.
            length_list1: True lengths of the first sentences (list or 1-D tensor).
            x2_bar: Encoded second sentence, ``(batch_size, max_sentence_length2, dim)``.
            length_list2: True lengths of the second sentences (list or 1-D tensor).

        Returns:
            ``(x1_tilde, x2_tilde)`` with the same shapes as ``x1_bar`` and ``x2_bar``.

        Note:
            A sentence of length 0 masks out a whole softmax row and yields NaN;
            lengths are expected to be at least 1.
        """
        e = torch.bmm(x1_bar, x2_bar.transpose(1, 2))  # shape: (batch_size, max_sentence_length1, max_sentence_length2)

        _, max_sentence_length1, max_sentence_length2 = e.shape

        # Padded positions must receive zero attention weight, so they are set to
        # -inf before the softmax. Masks are built on the same device as the scores.
        mask1 = self._padding_mask(length_list1, max_sentence_length1, e.device)
        mask2 = self._padding_mask(length_list2, max_sentence_length2, e.device)

        # Each token of sentence 1 attends over the real tokens of sentence 2.
        softmax_e = self.softmax(e.masked_fill(mask2.unsqueeze(1), float('-inf')))
        x1_tilde = torch.bmm(softmax_e, x2_bar)

        # Each token of sentence 2 attends over the real tokens of sentence 1.
        softmax_e = self.softmax(e.transpose(1, 2).masked_fill(mask1.unsqueeze(1), float('-inf')))
        x2_tilde = torch.bmm(softmax_e, x1_bar)
        return x1_tilde, x2_tilde

    @staticmethod
    def enhancement(x_bar, x_tilde):
        """Concatenate ``[x_bar, x_tilde, x_bar - x_tilde, x_bar * x_tilde]`` on the last axis."""
        return torch.cat([x_bar, x_tilde, x_bar - x_tilde, x_bar * x_tilde], dim=-1)

    def forward(self, x1_bar, length_list1, x2_bar, length_list2):
        """Return the enhanced representations of both sentences (last dim is 4x the input's)."""
        x1_tilde, x2_tilde = self.attention(x1_bar, length_list1, x2_bar, length_list2)
        return self.enhancement(x1_bar, x1_tilde), self.enhancement(x2_bar, x2_tilde)

![alt text](attention1.png)
![alt text](attention2.png)
![alt text](enhance.png)

In [17]:
class MLP(nn.Module):
    """Two-layer classifier head: Linear -> Tanh -> Linear.

    Args:
        input_size: Number of input features.
        hidden_size: Width of the hidden layer.
        class_num: Number of output scores (one per class).
    """

    def __init__(self, input_size, hidden_size, class_num):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=class_num)
        # Kept on the module so that owners can reuse the matching loss.
        self.loss_func = nn.CrossEntropyLoss()

    def forward(self, x):
        """Map features of shape ``(..., input_size)`` to raw scores of shape ``(..., class_num)``."""
        hidden = self.fc1(x)
        activated = self.tanh(hidden)
        scores = self.fc2(activated)
        return scores

class InferenceComposition(nn.Module):
    """ESIM inference composition: project, re-encode, pool and classify.

    Args:
        input_size: Feature size of the enhanced local-inference vectors.
        hidden_size: Size of the projection and of the classifier's hidden layer.
        class_num: Number of output classes.
    """

    def __init__(self, input_size, hidden_size, class_num):
        super().__init__()
        self.F = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        # Positional arguments: BiLSTM names its first parameter `embedding_dim`,
        # so the former `input_size=` keyword raised a TypeError.
        self.BiLSTM = BiLSTM(hidden_size, hidden_size // 2)

        # Derive the encoder's output width from the LSTM itself so the classifier
        # input always matches (also for odd `hidden_size`). Each sentence contributes
        # an average and a max vector, and there are two sentences: 4 vectors in total.
        lstm = self.BiLSTM.bilstm
        composed_size = lstm.hidden_size * (2 if lstm.bidirectional else 1)
        self.MLP = MLP(4 * composed_size, hidden_size, class_num)
        self.loss_func = self.MLP.loss_func

    def handle_x(self, m_x, length_list):
        """Compose one sentence into a fixed-size vector.

        Args:
            m_x: Enhanced vectors, ``(batch_size, max_sentence_length, input_size)``.
            length_list: True sentence lengths (list or 1-D tensor, each >= 1).

        Returns:
            Tensor ``(batch_size, 2 * composed_size)``: average pooling followed by
            max pooling over the real (non-padded) time steps.
        """
        v_x_t = self.BiLSTM(self.relu(self.F(m_x)), length_list)  # (batch_size, steps, composed_size)

        lengths = torch.as_tensor(length_list, device=v_x_t.device)
        steps = v_x_t.shape[1]
        # True on padded time steps; broadcast over the feature axis.
        pad_mask = (torch.arange(steps, device=v_x_t.device).unsqueeze(0) >= lengths.unsqueeze(1)).unsqueeze(-1)

        # Pool over real tokens only: padding must neither dilute the mean nor win
        # the max (an all-negative feature would otherwise be beaten by a padded 0).
        token_count = lengths.clamp(min=1).unsqueeze(1).to(v_x_t.dtype)
        v_x_avg = v_x_t.masked_fill(pad_mask, 0.0).sum(dim=1) / token_count
        v_x_max = v_x_t.masked_fill(pad_mask, float('-inf')).max(dim=1).values
        return torch.cat([v_x_avg, v_x_max], dim=1)

    def forward(self, m_x1: torch.Tensor, seq_lengths1: torch.Tensor, m_x2: torch.Tensor, seq_lengths2: torch.Tensor) -> torch.Tensor:
        """Return class scores of shape ``(batch_size, class_num)`` for a sentence pair."""
        v = torch.cat([self.handle_x(m_x1, seq_lengths1), self.handle_x(m_x2, seq_lengths2)], dim=-1)  # shape: (batch_size, 4 * composed_size)
        # MLP exposes its scores through forward(); it has no `get_scores` method.
        return self.MLP(v)


class ESIM(nn.Module):
    """Enhanced Sequential Inference Model for sentence-pair classification.

    Args:
        embedding_dim: Size of the word embeddings.
        vocab_dim: Vocabulary size; only used when a new embedding layer is created.
        num_layers: Currently unused; kept so existing call sites keep working.
        embedding: Ready-made embedding module to reuse, or ``None`` to create a
            fresh ``nn.Embedding(vocab_dim, embedding_dim)``.
        hidden_size: Feature size of the encoded sentences. Defaults to
            ``embedding_dim`` when omitted (NOTE(review): assumed default, confirm).
        class_num: Number of output classes. The default of 3 is an assumed
            default (e.g. three-way NLI labels); set it explicitly for your data.
        pad_idx: Index of the padding token for a freshly created embedding.
            The default of 0 is an assumption; confirm against the vocabulary.
    """

    def __init__(self, embedding_dim, vocab_dim, num_layers, embedding, hidden_size=None, class_num=3, pad_idx=0):
        super().__init__()
        # These values used to be read from undeclared globals; they are now
        # explicit, optional parameters.
        if hidden_size is None:
            hidden_size = embedding_dim

        # `is not None` instead of truthiness: truth-testing a tensor-like object raises.
        if embedding is not None:
            self.embedding = embedding
        else:
            self.embedding = nn.Embedding(vocab_dim, embedding_dim, padding_idx=pad_idx)

        # The encoder's hidden size is a model dimension, not the vocabulary size.
        self.BiLSTM = BiLSTM(embedding_dim, hidden_size // 2, True)
        lstm = self.BiLSTM.bilstm
        encoded_size = lstm.hidden_size * (2 if lstm.bidirectional else 1)

        # Attribute name (sic) kept as-is so existing attribute access keeps working.
        self.LocalInterenceModeling = LocalInferenceModeling()
        # The enhancement step concatenates four vectors of the encoder's width.
        self.inference_composition = InferenceComposition(input_size=4 * encoded_size, hidden_size=hidden_size, class_num=class_num)
        self.loss_func = self.inference_composition.loss_func

    def forward(self, x1, length_list1, x2, length_list2):
        """Score a batch of sentence pairs.

        Args:
            x1: Token indices of the first sentences, ``(batch_size, max_sentence_length1)``.
            length_list1: True lengths of the first sentences.
            x2: Token indices of the second sentences, ``(batch_size, max_sentence_length2)``.
            length_list2: True lengths of the second sentences.

        Returns:
            Scores of shape ``(batch_size, class_num)``.
        """
        x1 = self.embedding(x1)  # `a` in the paper; shape: (batch_size, max_sentence_length1, embedding_size)
        x2 = self.embedding(x2)  # `b` in the paper; shape: (batch_size, max_sentence_length2, embedding_size)

        # 3.1 Input encoding
        x1_bar = self.BiLSTM(x1, length_list1)  # `a_bar` in the paper; shape: (batch_size, max_sentence_length1, encoded_size)
        x2_bar = self.BiLSTM(x2, length_list2)  # `b_bar` in the paper; shape: (batch_size, max_sentence_length2, encoded_size)

        # 3.2 Local inference modeling
        # `m_a` / `m_b` in the paper; shape: (batch_size, max_sentence_length_i, encoded_size * 4)
        m_x1, m_x2 = self.LocalInterenceModeling(x1_bar, length_list1, x2_bar, length_list2)

        # 3.3 Inference composition
        scores = self.inference_composition(m_x1, length_list1, m_x2, length_list2)  # shape: (batch_size, class_num)
        return scores

In [None]:
#too tedious ... let's take a break ...
