In [6]:
import sys
sys.path.append(".")

import torch
import pytorch_lightning as pl
import torch.nn.functional as F

from datamodule import *
from pytorch_lightning.callbacks import LearningRateLogger
from pytorch_lightning.loggers import NeptuneLogger, TensorBoardLogger
from pytorch_lightning.profiler import AdvancedProfiler
from pytorch_lightning.metrics import Accuracy

import pickle
import os
from joblib import Memory
import shutil
import argparse
from lang import *
import joblib
from pytorch_lightning import Callback
from transformers import BertModel, BertTokenizer,DistilBertTokenizer
from lang import *
from snli.train_utils import SNLI_model, snli_glove_data_module, snli_bert_data_module,SwitchOptim
from utils.keys import NEPTUNE_API
from utils.helpers import seed_torch
from utils.save_models import save_model,save_model_neptune


In [7]:
# Call seed_torch() *first* so that any randomness used while the data module is
# constructed is seeded too, not just the model initialisation in later cells
# (previously the seed was set only after the data module had been built).
seed_torch()

# BERT-tokenised SNLI data module; the argument is presumably the batch size --
# confirm in snli.train_utils.
data_module = snli_bert_data_module(128)
# Vocabulary/tokenizer wrapper; Transformer_config reads vocab size and pad id from it.
Lang = data_module.Lang
# Not passed to Transformer_config below; on the BERT path the config ignores it anyway.
embedding_matrix = None

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math


class Transformer_config:
    """Hyper-parameters for ``TransformerEncoder`` / ``Transformer_snli``.

    Defaults live as class attributes. Every keyword argument given to the
    constructor is set on the instance afterwards, so it either overrides a
    default of the same name or adds a new attribute (e.g. ``batch_size``).

    Args:
        lang: vocabulary/tokenizer wrapper. When ``lang.tokenizer_ == "BERT"``
            the vocab size is ``lang.vocab_size`` and the pad id is the
            ``"[PAD]"`` entry of ``lang.bert_tokenizer.vocab``. Otherwise the
            vocab size is ``lang.vocab_size_final()`` and the pad id is
            ``lang.word2idx[lang.config.pad]``.
        embedding_matrix: pretrained embedding matrix, stored only on the
            non-BERT path; on the BERT path ``self.embedding_matrix`` is None.
        **kwargs: attribute overrides / additions, applied last.
    """

    embedding_dim = 768  # token/position embedding width, also the transformer d_model
    initializer_range = 0.02  # not read by the model code in this notebook; _init_weights uses its own std
    max_len = 150  # number of learned absolute positions
    sub_enc_layer = 3  # number of stacked nn.TransformerEncoderLayer
    n_heads = 12  # attention heads per encoder layer
    interaction = "concat"  # "concat" or "sum_prod" (see Transformer_snli)

    def __init__(self, lang, embedding_matrix=None, **kwargs):
        if lang.tokenizer_ != "BERT":
            # Word-level vocabulary: keep the (optional) pretrained matrix.
            self.embedding_matrix = embedding_matrix
            self.vocab_size = lang.vocab_size_final()
            self.padding_idx = lang.word2idx[lang.config.pad]
        else:
            # BERT word-pieces: no external embedding matrix.
            self.embedding_matrix = None
            self.vocab_size = lang.vocab_size
            self.padding_idx = lang.bert_tokenizer.vocab["[PAD]"]

        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)





def _init_weights(module):
    """ Initialize the weights """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()


class TransformerEncoder(nn.Module):
    """Token + learned-position embeddings followed by a stack of ``nn.TransformerEncoderLayer``.

    ``conf`` must provide ``vocab_size``, ``embedding_dim``, ``padding_idx``,
    ``max_len``, ``n_heads`` and ``sub_enc_layer`` (see ``Transformer_config``).

    Input:  integer token ids of shape ``(batch, seq_len)`` with ``seq_len <= conf.max_len``.
    Output: float tensor of shape ``(seq_len, batch, embedding_dim)``. This is
            sequence-first, the default layout of ``nn.TransformerEncoder``.
            Every position goes through ``pooler`` -> tanh -> dropout -> ``translate``.
    """

    def __init__(self, conf):
        super(TransformerEncoder, self).__init__()
        self.conf = conf
        self.word_embedding = nn.Embedding(
            num_embeddings=self.conf.vocab_size,
            embedding_dim=self.conf.embedding_dim,
            padding_idx=self.conf.padding_idx,
        )

        # One learned vector per absolute position 0 .. max_len - 1.
        self.pos_embedding = nn.Embedding(
            num_embeddings=self.conf.max_len,
            embedding_dim=self.conf.embedding_dim,
        )
        # (1, max_len) position indices. Registered as a buffer so it follows
        # .to(device) and is saved in the state_dict.
        self.register_buffer("position_ids", torch.arange(self.conf.max_len).expand((1, -1)))

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=self.conf.embedding_dim, nhead=self.conf.n_heads),
            self.conf.sub_enc_layer,
        )
        self.LayerNorm = nn.LayerNorm(self.conf.embedding_dim)
        self.pooler = nn.Linear(self.conf.embedding_dim, self.conf.embedding_dim)

        self.translate = nn.Linear(self.conf.embedding_dim, self.conf.embedding_dim)
        # NOTE(review): not used by forward(). It is kept because removing a Parameter
        # changes the state_dict keys, and existing checkpoints would no longer load strictly.
        self.template = nn.Parameter(torch.zeros((1)), requires_grad=True)
        self.dropout = nn.Dropout(p=0.3)

    def make_src_mask(self, src):
        """Build the key-padding mask for ``nn.TransformerEncoder``.

        Args:
            src: token ids of shape ``(seq_len, batch)``.

        Returns:
            Bool tensor of shape ``(batch, seq_len)``, ``True`` where the token is
            padding. This is the layout ``src_key_padding_mask`` expects.
        """
        return src.transpose(0, 1) == self.conf.padding_idx

    def forward(self, x):
        """Encode a batch of token ids; see the class docstring for shapes.

        Raises:
            ValueError: if ``seq_len`` exceeds ``conf.max_len``, because no positional
                embedding exists for the extra positions.
        """
        x = x.transpose(0, 1)  # (batch, seq) -> (seq, batch)
        seq_length, N = x.shape
        if seq_length > self.conf.max_len:
            # Without this check the mismatch surfaces later as an opaque size-mismatch
            # error when adding the positional embeddings.
            raise ValueError(
                f"sequence length {seq_length} exceeds max_len={self.conf.max_len}"
            )

        # (seq, batch) position index for every token.
        position_ids = self.position_ids[:, :seq_length].expand(N, -1).transpose(0, 1)

        emb = self.word_embedding(x) + self.pos_embedding(position_ids)
        emb = self.dropout(self.LayerNorm(emb))

        mask = self.make_src_mask(x)
        opt = self.transformer(emb, src_key_padding_mask=mask)
        # torch.tanh replaces the deprecated F.tanh; the computation is identical.
        opt = self.dropout(torch.tanh(self.pooler(opt)))
        return self.translate(opt)

    


class Transformer_snli(nn.Module):
    """Siamese SNLI classifier: one shared ``TransformerEncoder`` encodes both sentences.

    Each sentence is encoded independently and only output position 0 is kept
    (presumably the ``[CLS]`` token for BERT-tokenised input). The two vectors
    are combined according to ``conf.interaction`` and projected to 3 logits.

    ``conf.interaction``:
        * ``"concat"``   -> ``[a, b]``                (2 * embedding_dim features)
        * ``"sum_prod"`` -> ``[a, b, a + b, a * b]``  (4 * embedding_dim features)

    Raises:
        ValueError: for any other ``conf.interaction`` value.
    """

    def __init__(self, conf):
        super(Transformer_snli, self).__init__()
        # Resolve the interaction first so a bad value fails fast with a clear message.
        # Previously it fell through to an UnboundLocalError on `final_dim`, after the
        # encoder had already been built.
        if conf.interaction == "concat":
            final_dim = 2 * conf.embedding_dim
            interact = self.interact_concat
        elif conf.interaction == "sum_prod":
            final_dim = 4 * conf.embedding_dim
            interact = self.interact_sum_prod
        else:
            raise ValueError(
                f"unknown interaction {conf.interaction!r}; expected 'concat' or 'sum_prod'"
            )

        self.conf = conf
        self.encoder = TransformerEncoder(self.conf)
        self.interact = interact

        self.cls = nn.Linear(final_dim, 3)
        # NOTE(review): not used by forward(), which returns raw logits. It is kept so that
        # code which may reference `model.softmax` still works; it holds no parameters.
        self.softmax = nn.Softmax(dim=2)
        self.apply(_init_weights)

    def interact_concat(self, a, b):
        """Concatenate the two sentence vectors along the feature axis (dim 2)."""
        return torch.cat([a, b], dim=2)

    def interact_sum_prod(self, a, b):
        """Concatenate ``a``, ``b``, their sum and their element-wise product (dim 2)."""
        return torch.cat([a, b, a + b, a * b], dim=2)

    def forward(self, x0, x1):
        """Return logits of shape ``(1, batch, 3)`` for a batch of sentence pairs.

        Args:
            x0, x1: token-id tensors of shape ``(batch, seq_len)``, one per sentence
                of the pair (presumably premise / hypothesis).
        """
        # Keep only position 0 but retain the leading axis: (1, batch, embedding_dim).
        x0_emb = self.encoder(x0)[:1, :, :]
        x1_emb = self.encoder(x1)[:1, :, :]
        return self.cls(self.interact(x0_emb, x1_emb))


In [14]:
# Two optimiser configurations consumed by SNLI_model / SwitchOptim (snli.train_utils).
# Judging by the key names, "optimizer_base" is used first and "optimizer_tune" from
# epoch `switch_epoch` onwards -- confirm against SwitchOptim.
# NOTE(review): if "adam" maps to torch.optim.Adam, weight_decay=0.1 is a coupled L2
# penalty, not AdamW-style decoupled decay -- verify this is intended.
hparams = dict(
    optimizer_base=dict(optim="adamw", lr=3e-4, scheduler="const"),
    optimizer_tune=dict(optim="adam", lr=3e-4, weight_decay=0.1, scheduler="lambda"),
    switch_epoch=2,
)

# Extra attributes set on the Transformer_config instance via its **kwargs.
conf_kwargs = dict(batch_size=128)

EPOCHS = 5

conf = Transformer_config(Lang, **conf_kwargs)
model = SNLI_model(Transformer_snli, conf, hparams)


In [15]:
# TensorBoard event files go under ./lightning_logs. The learning rate is logged at
# every optimiser step, so scheduler behaviour (and any optimiser switch) is visible.
tensorboard_logger = TensorBoardLogger("lightning_logs")
lr_logger = LearningRateLogger(logging_interval="step")

trainer = pl.Trainer(
    # hardware / run length
    gpus=1,
    max_epochs=EPOCHS,
    # callbacks and logging
    callbacks=[lr_logger, SwitchOptim()],
    logger=[tensorboard_logger],
    row_log_interval=2,
    progress_bar_refresh_rate=10,
    # explicitly disabled extras
    profiler=False,
    auto_lr_find=False,
    # gradient_clip_val=0.5,  # disabled
)

trainer.fit(model, data_module)
# `model` is passed explicitly, so presumably the weights as they stand after the
# final epoch are tested rather than a reloaded "best" checkpoint -- verify for the
# installed Lightning version. Kept as the last expression so the metrics are displayed.
trainer.test(model, datamodule=data_module)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params | In sizes | Out sizes
------------------------------------------------------------------
0 | model | Transformer_snli | 41 M   | ?        | ?        


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Saving latest checkpoint..



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.6338), 'test_loss': tensor(0.8200, device='cuda:0')}
--------------------------------------------------------------------------------



[{'test_acc': 0.6338270902633667, 'test_loss': 0.8200086951255798}]

In [63]:
# `BertConfig` is not among the names imported from `transformers` in the import cell.
# Import it here explicitly; otherwise this cell only works if the name leaked into the
# namespace some other way (a star import or a since-deleted cell).
from transformers import BertConfig

configuration = BertConfig()  # BERT config with the library's default hyper-parameters

In [6]:
# Pull one training batch to inspect tensor shapes. `next(iter(...))` raises
# StopIteration on an empty loader. The previous for/break form silently left
# a, b, c undefined, or stale from an earlier run.
dl = data_module.train_dataloader()
a, b, c = next(iter(dl))  # presumably (premise ids, hypothesis ids, labels) -- confirm

In [11]:
# Shape of the first element of the sampled batch.
a.size()

torch.Size([128, 100])

In [37]:
# Shape of the second element of the sampled batch.
b.size()

torch.Size([128, 100])

In [38]:
# Shape of the third element of the sampled batch.
c.size()

torch.Size([128])

In [29]:
# Position indices 0..149, the same range TransformerEncoder builds for max_len=150.
asd = torch.arange(0, 150)

In [38]:
# Broadcast the 1-D position range to a (batch=128, 150) view, as position_ids is expanded.
asd.unsqueeze(0).expand(128, -1).size()

torch.Size([128, 150])

In [34]:
# Shape of the raw 1-D position range.
asd.size()

torch.Size([150])