In [1]:
import pickle
import numpy as np
import torch
import torch.nn.functional as F
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
from apex import amp

In [2]:
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load this from a trusted source.
# Use a context manager so the file handle is closed once loading finishes.
with open("./squad_processed_1.1.pickle", "rb") as _fh:
    all_data = pickle.load(_fh)

# The first fifth of the examples becomes `validation_data`; the remaining
# four fifths become the training split `data`.
_n_validation = len(all_data) // 5
data = np.asarray(all_data[_n_validation:])
validation_data = np.asarray(all_data[:_n_validation])

In [3]:
# Report the total number of examples, then the size of the training split,
# one count per line.
print(len(all_data), len(data), sep="\n")

87506
70005


In [4]:
def getBatch(bs = 64, validation=False):
    """Sample a random batch of examples and move it to the GPU.

    Args:
        bs: Number of examples to draw. Indices are sampled uniformly with
            replacement, so an example may appear more than once.
        validation: If truthy, sample from ``validation_data``; otherwise
            sample from the training split ``data``.

    Returns:
        Tuple ``(inputs, segments, attention_masks, start_, end_)`` of CUDA
        tensors. ``inputs`` and ``segments`` are LongTensors truncated to the
        first 384 positions, ``attention_masks`` is the element-wise
        ``inputs != 0`` mask, and ``start_`` / ``end_`` are LongTensors built
        from each example's ``"answer_start"`` / ``"answer_end"`` entry.
    """
    max_len = 384
    source = validation_data if validation else data
    indices = np.random.randint(0, len(source), (bs,))
    # Index the same split the indices were drawn for; indexing `data`
    # unconditionally would make validation=True return training examples.
    batch = np.asarray(source)[indices]

    inputs = torch.LongTensor([dp["data"]["tokens"] for dp in batch]).cuda()
    inputs = inputs[:, :max_len]
    # Presumably token id 0 is the padding id, so this marks real tokens -- TODO confirm.
    attention_masks = inputs != 0
    segments = torch.LongTensor([dp["data"]["segments"] for dp in batch]).cuda()
    segments = segments[:, :max_len]
    start_ = torch.LongTensor([dp["answer_start"] for dp in batch]).cuda()
    end_ = torch.LongTensor([dp["answer_end"] for dp in batch]).cuda()
    return inputs, segments, attention_masks, start_, end_

In [5]:
# Smoke test: draw a three-example batch and print the sizes of the token,
# segment, start and end tensors (the attention mask is not printed).
i, se, att, st, en = getBatch(bs=3)
print(*(tensor.size() for tensor in (i, se, st, en)))

torch.Size([3, 384]) torch.Size([3, 384]) torch.Size([3]) torch.Size([3])


In [6]:
class SQuADHead(torch.nn.Module):
    """Pretrained BERT encoder followed by a linear start/end scoring layer.

    Args:
        num_bert_layers: Accepted but currently unused.
        backprop_thru_bert: Accepted but currently unused.
        internal_dim: Accepted but currently unused.
    """
    def __init__(self, 
                num_bert_layers=1,
                backprop_thru_bert=False,
                internal_dim = 256                
                ):
        super(SQuADHead, self).__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased").cuda()
        # Maps each 768-wide encoder vector to two scores: (start, end).
        self.out = torch.nn.Linear(768,2)
        self.dropout = torch.nn.Dropout(0.1)
    
    def forward(self, inputs, segments, attention_masks):
        """Compute unnormalised start and end scores for every position.

        Args:
            inputs: Token ids passed to the BERT encoder.
            segments: Segment ids passed to the BERT encoder.
            attention_masks: Attention mask passed to the BERT encoder.

        Returns:
            ``(start_scores, end_scores)``. Each has the encoder output's
            shape with the trailing feature axis removed -- assumed to be
            ``(batch, seq_len)``; TODO confirm against the encoder.
        """
        f, _ = self.bert(inputs, segments, attention_masks, output_all_encoded_layers=False)
        out_ = self.out(self.dropout(f))
        start_, end_ = torch.split(out_, 1, dim=-1)
        # Squeeze only the trailing size-1 axis. A bare squeeze() would also
        # drop the batch axis when the batch holds a single example.
        return start_.squeeze(-1), end_.squeeze(-1)

In [None]:
from apex.fp16_utils import FP16_Optimizer

# Build the network on the GPU and give Adam all of its parameters.
model = SQuADHead().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

# Wrap model and optimizer with apex amp at optimisation level O1.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
# Alternative kept for reference (disabled):
#optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=False, static_loss_scale=128.0)

# Negative log-likelihood criterion; it is fed log-softmax outputs during training.
loss = torch.nn.NLLLoss()

def Train(network, epochs=10, batches_per_epoch=3000, bs=20):
    """Train ``network`` on random batches and return the recent loss history.

    Relies on the module-level ``optimizer``, ``loss`` criterion, ``amp`` and
    ``getBatch``. ``optimizer`` is expected to have been built over
    ``network``'s parameters.

    Args:
        network: Module called as ``network(inputs, segments, attention_masks)``
            and returning ``(start_scores, end_scores)``.
        epochs: Number of outer iterations.
        batches_per_epoch: Number of batches drawn per epoch.
        bs: Batch size passed to ``getBatch``.

    Returns:
        ``(start_losses, end_losses)``: lists holding at most the last 1000
        per-batch start and end loss values.
    """
    max_len = 384  # getBatch truncates sequences to this many positions
    start_losses = []
    end_losses = []
    for j in range(epochs):
        for k in range(batches_per_epoch):
            i, se, att, st, en = getBatch(bs=bs)

            # Keep only examples whose answer span lies inside the truncated
            # sequence; an out-of-range target is invalid for NLLLoss.
            items_to_use = (st < max_len) & (en < max_len)
            if items_to_use.sum().item() == 0:
                # Nothing usable in this batch: a mean over zero examples
                # would be NaN and would pollute the running averages.
                continue

            # Call the module passed in (not a global) via __call__.
            st_, en_ = network(i, se, att)
            st_ = st_[items_to_use]
            en_ = en_[items_to_use]
            st = st[items_to_use]
            en = en[items_to_use]

            optimizer.zero_grad()
            loss1 = loss(F.log_softmax(st_, dim=-1), st)
            loss2 = loss(F.log_softmax(en_, dim=-1), en)
            start_losses.append(loss1.item())
            end_losses.append(loss2.item())

            # Optimise the average of the two losses in one backward pass,
            # scaled by amp.
            net_loss = (loss1 + loss2)/2
            with amp.scale_loss(net_loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()

            # Report a running mean over the most recent 1000 batches.
            start_losses = start_losses[-1000:]
            end_losses = end_losses[-1000:]
            print("Epoch:", j, "Batch:", k, 
                  "Start Loss:", np.round(np.mean(start_losses), 5), 
                  "End Loss:", np.round(np.mean(end_losses), 5), end="\r")
    return start_losses, end_losses

# Bind the histories at module level so later cells can inspect them.
start_losses, end_losses = Train(model)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Gradient overflow.  Skipping step, reducing loss scale to 32768.0
Gradient overflow.  Skipping step, reducing loss scale to 16384.0
Gradient overflow.  Skipping step, reducing loss scale to 8192.0
Gradient overflow.  Skipping step, reducing loss scale to 4096.0
Gradient overflow.  Skipping step, reducing loss sc

In [None]:
# Show the `start_losses` list; this needs `start_losses` to exist at module
# level (notebook global scope) when the cell runs.
print(start_losses)