In [15]:
from src.datasets.document import *
from src.defaults import *
import sys
old_stdout = sys.stdout
load_nli = 'NLI-93'
# download_models_from_neptune(load_nli)
field = load_field(load_nli)
dataset_conf = {'dataset': 'reuters', 'max_num_sent': 60,"sent_tokenizer":"spacy","batch_size":32,"device":"cuda"}


In [16]:
data = document_dataset(dataset_conf,sentence_field = field)

In [17]:
for i in data.train_iter:
    print(i)
    break


[torchtext.data.batch.Batch of size 32 from REUTERS]
	[.text]:[torch.cuda.LongTensor of size 32x60x50 (GPU 0)]
	[.label]:[torch.cuda.LongTensor of size 32x10 (GPU 0)]


In [4]:
model_conf = {'results_dir': 'results', 'device': 'cuda', 'dropout': 0.2, 'dataset': 'imdb', 'hidden_size': 300, "attention_layer_param":150,"freeze_encoder":False,"num_layers":1}

In [5]:
from src.model.nli_models import *
from src.model.novelty_models import *

def load_encoder(enc_data):
    if enc_data['options'].get("attention_layer_param",0)==0:
        model = bilstm_snli(enc_data["options"])
    elif enc_data['options'].get("r",0)==0:
        model = attn_bilstm_snli(enc_data["options"])
    else:
        model = struc_attn_snli(enc_data["options"])
    return model

nli_model_data = load_encoder_data(load_nli)
nli_model_data['options']["use_glove"] = False
encoder = load_encoder(nli_model_data).encoder
model_conf["encoder_dim"] = nli_model_data["options"]["hidden_size"]

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math



class Attention(nn.Module):
    def __init__(self, conf):
        super(Attention, self).__init__()
        self.Ws = nn.Linear(
            2 * conf["hidden_size"],
            conf["attention_layer_param"],
            bias=False,
        )
        self.Wa = nn.Linear(conf["attention_layer_param"], 1, bias=False)

    def forward(self, hid):
        opt = self.Ws(hid)
        opt = torch.tanh(opt)
        opt = self.Wa(opt)
        opt = F.softmax(opt, dim=1)
        return opt


class HAN_DOC(nn.Module):
    def __init__(self, conf, encoder):
        super(HAN_DOC, self).__init__()
        self.conf = conf
        self.encoder = encoder
        if self.conf["freeze_encoder"]:
            self.encoder.requires_grad_(False)

        self.translate = nn.Linear(
            2 * self.conf["encoder_dim"], self.conf["hidden_size"]
        )
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(conf["dropout"])
        self.template = nn.Parameter(torch.zeros((1)), requires_grad=True)
        self.lstm_layer = nn.LSTM(
            input_size=self.conf["hidden_size"],
            hidden_size=self.conf["hidden_size"],
            num_layers=self.conf["num_layers"],
            bidirectional=True,
        )
        self.attention = Attention(conf)

    def forward(self, inp):
        batch_size, num_sent, max_len = inp.shape
        x = inp.view(-1, max_len)

        x_padded_idx = x.sum(dim=1) != 0
        x_enc = []
        for sub_batch in x[x_padded_idx].split(64):
            x_enc.append(self.encoder(sub_batch, None))
        x_enc = torch.cat(x_enc, dim=0)

        x_enc_t = torch.zeros((batch_size * num_sent, x_enc.size(1))).to(
            self.template.device
        )

        x_enc_t[x_padded_idx] = x_enc
        x_enc_t = x_enc_t.view(batch_size, num_sent, -1)

        embedded = self.dropout(self.translate(x_enc_t))
        embedded = self.act(embedded)

        all_, (_, _) = self.lstm_layer(embedded)
        attn = self.attention(all_)

        cont = torch.bmm(attn.permute(0, 2, 1), all_)
        cont = cont.squeeze(1)
        return cont


class HAN_DOC_Classifier(nn.Module):
    def __init__(self,conf,encoder):
        super().__init__()
        self.conf = conf
        self.han_doc = HAN_DOC(conf,encoder)

        self.act = nn.ReLU()
        self.dropout = nn.Dropout(conf["dropout"])
        self.fc = nn.Linear(2 * conf["hidden_size"], 10)

    def forward(self, x0):
        x0_enc = self.han_doc(x0)
        cont = self.dropout(self.act(x0_enc))
        cont = self.fc(cont)
        return cont


In [7]:
model = HAN_DOC_Classifier(model_conf,encoder)

In [8]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

14808461

In [9]:

from hyperdash import Experiment

sys.stdout = old_stdout

exp = Experiment("documentReuters",api_key_getter = get_hyperdash_api)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.BCELoss()
device = torch.device("cuda")

model.to(device)


def train(model,dl,optimizer,criterion):
    model.train()
    iterator_size = len(dl)
    n_correct, n_total,running_loss = 0, 0,0

    for i, data in enumerate(dl, 0):
        inputs, labels = data.text,data.label

        batch_size = labels.shape[0]
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(torch.sigmoid(outputs), labels.to(torch.float))
    

        exp.metric('train loss',loss.item(),log=False)
        print(f"loss = {loss.item()}",end = '\r')
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    
    
    exp.metric('train running loss',running_loss/iterator_size,log=False)
    print('Loss: {}'.format(running_loss))
    print("-------------")


def validate(model,dl,criterion):

    iterator_size = len(dl)
    n_correct, n_total,running_loss = 0, 0,0
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(dl, 0):
            inputs, labels = data.text,data.label
            batch_size = labels.shape[0]
            # forward 
            outputs = model(inputs)
            loss = criterion(torch.sigmoid(outputs), labels.to(torch.float))
            
            exp.metric('val loss',loss.item(),log=False)
            print(f"loss = {loss.item()}",end = '\r')
            running_loss += loss.item()
        
        
        
        exp.metric('val running loss',running_loss/iterator_size, log=False)
        print('Val Loss: {}'.format(running_loss))
        print("-------------")

try:

    # loop over the dataset multiple times
    for epoch in range(7):
        train(model,data.train_iter,optimizer,criterion)
        validate(model,data.test_iter,criterion)
    validate(model,data.test_iter,criterion)
    exp.end()
    print('Finished Training')
except KeyboardInterrupt:
    validate(model,testloader,criterion)
    exp.end()

Loss: 91.73408880829811
-------------
Val Loss: 31.3567091524601
-------------
Loss: 89.9389337003231
-------------
Val Loss: 31.342998579144478
-------------
Loss: 89.93985575437546
-------------
Val Loss: 31.35296420753002
-------------
Loss: 89.97350004315376
-------------
Val Loss: 31.425887420773506
-------------
Loss: 89.88115161657333
-------------
Val Loss: 31.475493922829628
-------------
Loss: 89.90036675333977
-------------
Val Loss: 31.48111118376255
-------------
Loss: 89.9422544836998
-------------
Val Loss: 31.322191655635834
-------------
Val Loss: 31.322191655635834
-------------
This run of documentReuters ran for 0:13:05 and logs are available locally at: /root/.hyperdash/logs/documentreuters/documentreuters_2021-03-04t05-17-33-516560.log
Finished Training


In [18]:
model.eval()
all_opts = torch.empty((0,10))
all_lab = torch.empty((0,10))
with torch.no_grad():
    for i, d in enumerate(data.test_iter, 0):
        inputs, labels = d.text,d.label
        outputs = torch.sigmoid(model(inputs))
        all_opts = torch.cat([all_opts,outputs.cpu()],dim=0)
        all_lab = torch.cat([all_lab,labels.cpu()],dim=0)

        
            

In [38]:
from sklearn.metrics import accuracy_score,classification_report


In [44]:
print(classification_report(np.round(all_opts),all_lab))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00         0
           1       0.00      0.00      0.00         0
           2       0.00      0.00      0.00         0
           3       0.00      0.00      0.00         0
           4       0.00      0.00      0.00         0
           5       0.00      0.00      0.00         0
           6       0.00      0.00      0.00         0
           7       0.00      0.00      0.00         0
           8       0.00      0.00      0.00         0
           9       1.00      0.72      0.84      2938

   micro avg       0.41      0.72      0.52      2938
   macro avg       0.10      0.07      0.08      2938
weighted avg       1.00      0.72      0.84      2938
 samples avg       0.40      0.72      0.51      2938

