In [1]:
from lang import *
from datamodule import *
from snli.train_utils import *
# Build the SNLI data module with character-level inputs switched on
# (char_emb=True). `snli_bert_data_module` is not defined in this notebook;
# presumably it comes from one of the star imports above -- TODO confirm which.
datamodule = snli_bert_data_module(char_emb=True)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np


class ADIN_encoder_conf:
    """Hyper-parameter container for the sentence encoder and the ADIN classifier.

    Class attributes are defaults. Any of them (or any extra setting) can be
    overridden per instance through keyword arguments, which are applied last
    and therefore win over everything set from ``lang``.

    Args:
        lang: vocabulary/tokenizer object. It must expose ``char_emb``,
            ``char_vocab_size``, ``char_emb_max_len`` and ``tokenizer_``.
            When ``tokenizer_ == "BERT"`` it must also expose ``vocab_size``
            and ``bert_tokenizer.vocab``; otherwise ``vocab_size_final()``,
            ``word2idx`` and ``config.pad``.
        embedding_matrix: optional pretrained word-embedding matrix. It is
            only kept for non-BERT tokenizers; in the BERT branch the
            attribute stays ``None``.
        **kwargs: per-instance overrides, set as attributes verbatim.
    """

    embedding_dim = 300  # width of the word-embedding table
    hidden_size = 300  # width of the encoder projection and classifier layers
    dropout = 0.1  # dropout probability used by the classifier
    opt_labels = 3  # number of output classes
    attention_layer_param = 100  # not read by the code visible in this notebook
    activation = "tanh"  # one of "relu", "tanh", "leakyrelu" (case-insensitive)
    char_embedding_size = 50  # char-embedding width and char-CNN output channels

    # Defaults for settings the model classes read but that previously had to
    # be supplied through kwargs (an AttributeError was raised otherwise).
    bidirectional = False  # the encoder has no recurrent layer, so width is not doubled
    fcs = 1  # number of hidden fully connected layers in the classifier
    freeze_embedding = True  # mirrors nn.Embedding.from_pretrained's own default

    def __init__(self, lang, embedding_matrix=None, **kwargs):
        self.embedding_matrix = None
        self.char_emb = lang.char_emb
        self.char_vocab_size = lang.char_vocab_size
        self.char_word_len = lang.char_emb_max_len

        if lang.tokenizer_ == "BERT":
            self.vocab_size = lang.vocab_size
            self.padding_idx = lang.bert_tokenizer.vocab["[PAD]"]
        else:
            self.embedding_matrix = embedding_matrix
            self.vocab_size = lang.vocab_size_final()
            self.padding_idx = lang.word2idx[lang.config.pad]

        # Explicit overrides are applied last so they take precedence.
        for key, value in kwargs.items():
            setattr(self, key, value)



class SentenceEncoder(nn.Module):
    """Token-level sentence encoder: word (+ optional char-CNN) embeddings
    followed by a single projection layer.

    Args:
        conf: configuration object. Read attributes: ``vocab_size``,
            ``embedding_dim``, ``padding_idx``, ``char_emb``, ``hidden_size``
            and ``embedding_matrix``; additionally ``char_vocab_size``,
            ``char_embedding_size`` and ``char_word_len`` when ``char_emb`` is
            truthy, and ``freeze_embedding`` (default ``True``) when
            ``embedding_matrix`` is a NumPy array.
    """

    def __init__(self, conf):
        # Zero-argument super(): the previous call named a different class.
        super().__init__()
        self.conf = conf
        self.embedding = nn.Embedding(
            num_embeddings=self.conf.vocab_size,
            embedding_dim=self.conf.embedding_dim,
            padding_idx=self.conf.padding_idx,
        )

        if self.conf.char_emb:
            self.char_embedding = nn.Embedding(
                num_embeddings=self.conf.char_vocab_size,
                embedding_dim=self.conf.char_embedding_size,
                padding_idx=0,
            )
            # Character positions are the input channels; the (1, 6) kernel
            # slides along the char-embedding axis of every token.
            self.char_cnn = nn.Conv2d(
                self.conf.char_word_len,
                self.conf.char_embedding_size,
                (1, 6),
                stride=(1, 1),
                padding=0,
                bias=True,
            )

        char_width = self.conf.char_embedding_size if self.conf.char_emb else 0
        self.translate = nn.Linear(
            self.conf.embedding_dim + char_width, self.conf.hidden_size
        )
        self.act = nn.ReLU()
        # Fixed at 0.3; conf.dropout is not consulted by the encoder.
        self.dropout = nn.Dropout(p=0.3)

        if isinstance(self.conf.embedding_matrix, np.ndarray):
            # from_pretrained is a classmethod that RETURNS a new module, so
            # its result has to be assigned; calling it on the instance and
            # dropping the result left the random initialisation in place.
            # The cast avoids a float64 table feeding float32 layers.
            self.embedding = nn.Embedding.from_pretrained(
                torch.tensor(
                    self.conf.embedding_matrix, dtype=torch.get_default_dtype()
                ),
                freeze=getattr(self.conf, "freeze_embedding", True),
                padding_idx=self.conf.padding_idx,
            )

    def char_embedding_forward(self, x):
        """Compute one character-CNN feature vector per token.

        Args:
            x: integer tensor of character ids, shape
                ``(batch, seq_len, chars_per_word)``.

        Returns:
            Tensor of shape ``(batch, seq_len, conf.char_embedding_size)``.
        """
        batch_size, seq_len, word_len = x.shape
        # (batch * seq_len, word_len, char_dim)
        x = self.char_embedding(x.reshape(-1, word_len).long())
        # Restore the real axes first, then transpose. A bare view() into the
        # target shape would reinterpret memory and mix tokens together.
        x = x.view(batch_size, seq_len, word_len, -1)
        x = x.permute(0, 2, 1, 3)  # (batch, word_len, seq_len, char_dim)
        x = self.char_cnn(x)  # (batch, channels, seq_len, char_dim - 5)
        x = torch.max(F.relu(x), 3)[0]  # (batch, channels, seq_len)
        # Transpose (not view) so every token keeps its own channel vector.
        return x.permute(0, 2, 1).contiguous()

    def forward(self, inp, char_vec=None):
        """Encode a batch of token ids.

        Args:
            inp: integer tensor of word ids, shape ``(batch, seq_len)``.
            char_vec: character ids, shape ``(batch, seq_len, chars_per_word)``.
                Required when ``conf.char_emb`` is truthy, and must be omitted
                otherwise.

        Returns:
            Tensor of shape ``(seq_len, batch, conf.hidden_size)``.

        Raises:
            ValueError: if ``char_vec`` is missing or unexpected for the
                configured ``conf.char_emb``.
        """
        has_chars = char_vec is not None
        if has_chars != bool(self.conf.char_emb):
            raise ValueError(
                "char_vec must be provided exactly when conf.char_emb is enabled"
            )

        embedded = self.embedding(inp)
        if has_chars:
            char_emb = self.char_embedding_forward(char_vec)
            embedded = torch.cat([embedded, char_emb], dim=2)
        embedded = self.dropout(embedded)
        embedded = self.translate(embedded)
        embedded = self.dropout(self.act(embedded))
        # Sequence-first layout; the result was previously computed but never
        # returned, so callers received None.
        return embedded.permute(1, 0, 2)
        


class ADIN(nn.Module):
    """Sentence-pair classifier on top of a shared ``SentenceEncoder``.

    Both inputs are encoded with the same encoder; the encodings are combined
    as ``[a, b, |a - b|, a * b]`` along the last axis and passed through a
    stack of fully connected layers.

    Args:
        conf: configuration object. Read attributes: ``hidden_size``,
            ``opt_labels``, ``activation``, ``dropout``, plus the optional
            ``bidirectional`` (default ``False``) and ``fcs`` (default ``1``),
            and everything the encoder itself reads.

    Raises:
        ValueError: if ``conf.activation`` is not one of ``"relu"``,
            ``"tanh"`` or ``"leakyrelu"`` (case-insensitive).
    """

    # Supported activation names (matched case-insensitively).
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
        "leakyrelu": nn.LeakyReLU,
    }

    def __init__(self, conf):
        # Zero-argument super(): the previous call named a different class.
        super().__init__()
        self.conf = conf
        self.encoder = SentenceEncoder(conf)

        # NOTE(review): the doubling for bidirectional=True only matches an
        # encoder whose output is 2 * hidden_size wide. SentenceEncoder in this
        # notebook projects to hidden_size, so True looks like a shape
        # mismatch -- confirm before enabling.
        directions = 2 if getattr(conf, "bidirectional", False) else 1
        self.fc_in = nn.Linear(
            directions * 4 * self.conf.hidden_size,
            self.conf.hidden_size,
        )
        self.fcs = nn.ModuleList(
            [
                nn.Linear(self.conf.hidden_size, self.conf.hidden_size)
                for _ in range(getattr(self.conf, "fcs", 1))
            ]
        )
        self.fc_out = nn.Linear(self.conf.hidden_size, self.conf.opt_labels)

        # Fail at construction time: previously an unknown name left
        # `self.act` undefined and the error only surfaced inside forward().
        act_name = self.conf.activation.lower()
        if act_name not in self._ACTIVATIONS:
            raise ValueError(
                "Unsupported activation %r; expected one of %s"
                % (self.conf.activation, sorted(self._ACTIVATIONS))
            )
        self.act = self._ACTIVATIONS[act_name]()

        # Kept for attribute compatibility; forward() does not apply it.
        self.softmax = nn.Softmax(dim=2)
        self.dropout = nn.Dropout(p=self.conf.dropout)

    def forward(self, x0, x1, x0_char_vec=None, x1_char_vec=None):
        """Score a batch of sentence pairs.

        Args:
            x0: word ids of the first sentences (cast to ``long`` here).
            x1: word ids of the second sentences (cast to ``long`` here).
            x0_char_vec: optional character ids for ``x0``, forwarded to the
                encoder unchanged.
            x1_char_vec: optional character ids for ``x1``, forwarded to the
                encoder unchanged.

        Returns:
            Unnormalised class scores whose last axis has
            ``conf.opt_labels`` entries; the leading axes are those of the
            encoder output. NOTE(review): no pooling over the sequence axis
            happens here, and both encodings must have identical shapes --
            confirm this is what the training wrapper expects.
        """
        x0_enc = self.encoder(x0.long(), char_vec=x0_char_vec)
        x0_enc = self.dropout(x0_enc)
        x1_enc = self.encoder(x1.long(), char_vec=x1_char_vec)
        x1_enc = self.dropout(x1_enc)
        cont = torch.cat(
            [x0_enc, x1_enc, torch.abs(x0_enc - x1_enc), x0_enc * x1_enc], dim=2
        )
        opt = self.fc_in(cont)
        opt = self.dropout(opt)
        for fc in self.fcs:
            opt = fc(opt)
            opt = self.dropout(opt)
            opt = self.act(opt)
        opt = self.fc_out(opt)
        return opt


In [None]:

# Optimiser settings handed to the training wrapper: two optimiser
# configurations plus `switch_epoch`. Presumably training moves from
# `optimizer_base` to `optimizer_tune` at that epoch -- confirm against the
# wrapper's implementation. Both configurations use the same learning rate.
hparams = dict(
    optimizer_base=dict(
        optim="adamw",
        lr=0.0010039910781394373,
        scheduler="const",
    ),
    optimizer_tune=dict(
        optim="adam",
        lr=0.0010039910781394373,
        weight_decay=0.1,
        scheduler="lambda",
    ),
    switch_epoch=5,
)

# Vocabulary/tokenizer object exposed by the data module; it is passed to the
# model configuration constructor in the next cell.
lang = datamodule.Lang

In [None]:
# Build the configuration and wrap the network in the training module, using
# the classes defined in this notebook (`ADIN_encoder_conf`, `ADIN`). The cell
# previously referred to `Attn_Encoder_conf` / `Attn_encoder_snli`, names that
# no cell in this notebook defines.
# NOTE(review): `conf_kwargs` is not defined in any visible cell. Define it
# above (e.g. an empty dict or the desired overrides) so the notebook survives
# Restart & Run All.
model_conf = ADIN_encoder_conf(lang, None, **conf_kwargs)
model = SNLI_char_emb(ADIN, model_conf, hparams)