In [1]:
import pickle
import numpy as np
from pytorch_pretrained_bert import BertTokenizer, BertModel
import torch
import torch.nn.functional as F
import json
# Module-level tokenizer loaded from the "bert-base-uncased" checkpoint.
# genBatch uses it (convert_tokens_to_ids) to map comment tokens to vocabulary ids.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [2]:
data_path = "../../training_processed.pickle"
# NOTE(review): unpickling can execute arbitrary code; only load files you produced yourself.
# The context manager closes the file handle as soon as loading finishes
# (the previous bare open() left it to the garbage collector).
with open(data_path, "rb") as data_file:
    all_data = pickle.load(data_file)

In [3]:
# Dataset sizes, in order: all comments, positive indices, negative indices.
print(*(len(all_data[key]) for key in ("all_data", "positives", "negatives")))

1804874 144334 1660540


In [4]:
# Hold out the last tenth of each class as the test set.
positive_cutoff = len(all_data["positives"])//10
negative_cutoff = len(all_data["negatives"])//10

# Split on an explicit index instead of `[-cutoff:]` / `[:-cutoff]`: when a class
# has fewer than 10 entries the cutoff is 0, and those negative-zero slices would
# put *everything* in the test set and leave the training set empty.
positive_split = len(all_data["positives"]) - positive_cutoff
negative_split = len(all_data["negatives"]) - negative_cutoff

testing_positives = all_data["positives"][positive_split:]
testing_negatives = all_data["negatives"][negative_split:]

# Each test entry is (index into all_data["all_data"], label) with label 1 = positive.
testing_data = [(index, 1) for index in testing_positives]
testing_data.extend([(index, 0) for index in testing_negatives])

np.random.shuffle(testing_data)
print(testing_data[:100])

training_positives = all_data["positives"][:positive_split]
training_negatives = all_data["negatives"][:negative_split]

# Records are looked up by index; genBatch reads the "comment_text" field of each.
_data = all_data["all_data"]

[(1744979, 0), (1707199, 0), (1797973, 0), (1727052, 0), (1635407, 0), (1770395, 0), (1740143, 0), (1709795, 0), (1685769, 0), (1740324, 0), (1678068, 0), (1755983, 0), (1721615, 0), (1689482, 0), (1742046, 0), (1783474, 0), (1743790, 0), (1692815, 0), (1774693, 0), (1754721, 0), (1761517, 0), (1730297, 1), (1705046, 0), (1707279, 0), (1641161, 0), (1658610, 0), (1734347, 0), (1674092, 0), (1624932, 0), (1783277, 0), (1662468, 0), (1683425, 0), (1670370, 0), (1636440, 0), (1742690, 0), (1695392, 1), (1724103, 1), (1722665, 0), (1659128, 0), (1647415, 0), (1633997, 0), (1701854, 0), (1793552, 0), (1708990, 0), (1692282, 0), (1778700, 0), (1795208, 0), (1717619, 0), (1778704, 0), (1753256, 0), (1726655, 1), (1689093, 0), (1747679, 0), (1782379, 0), (1752245, 0), (1734384, 0), (1675771, 0), (1649371, 1), (1718757, 0), (1664910, 1), (1734367, 0), (1695018, 0), (1749653, 0), (1719364, 0), (1700847, 0), (1772179, 0), (1681930, 0), (1662210, 0), (1664322, 0), (1774645, 0), (1662349, 0), (1677

In [5]:
# Fixed sequence length fed to the model: genBatch zero-pads shorter comments and
# truncates longer ones to exactly this many token ids.
max_comment_length = 120 #99th percentile of comment lengths

In [6]:
def _pad_token_ids(token_ids, length, cls_id=101, pad_id=0):
    """Prepend `cls_id`, then truncate or right-pad with `pad_id` to exactly `length` ids."""
    # 101 is presumably the [CLS] id of the bert-base-uncased vocabulary -- TODO confirm
    # against tokenizer.vocab. No [SEP] id is appended here (unchanged from before).
    padded = ([cls_id] + list(token_ids))[:length]
    padded.extend([pad_id] * (length - len(padded)))
    return padded


def genBatch(bs=8, testing = False, cuda=True):
    """Sample one random batch and return it as model-ready tensors.

    Training batches (testing=False) are class-balanced: bs//2 positives and the
    remaining bs - bs//2 negatives, drawn with replacement from the training
    index lists. Testing batches draw `bs` entries with replacement from
    `testing_data`, so they follow that set's class distribution.

    Args:
        bs: number of comments in the batch.
        testing: sample from the held-out test set instead of the training set.
        cuda: place the returned tensors on the GPU when True, on the CPU otherwise.

    Returns:
        (docs, segments, mask, y):
            docs     -- long tensor (bs, max_comment_length) of token ids,
            segments -- long tensor of zeros with the same shape as docs,
            mask     -- `docs > 0`, i.e. True/1 at every non-padding position,
            y        -- float tensor (bs,) of labels (1 = positive, 0 = negative).
    """
    if (not testing):
        #during training always present a balanced training set
        positive_samples = bs//2
        negative_samples = bs - positive_samples
        _p = np.random.randint(0, len(training_positives), (positive_samples,))
        _n = np.random.randint(0, len(training_negatives), (negative_samples,))

        # _p/_n are positions *within* the per-class index lists; they must be
        # dereferenced through those lists to obtain indices into _data.
        # (Previously the positions were used as _data indices directly, so the
        # labels had no relation to the comments actually fetched.)
        _batch = [(training_positives[pos], 1) for pos in _p]
        _batch.extend([(training_negatives[pos], 0) for pos in _n])
    else:
        _indices = np.random.randint(0, len(testing_data), (bs,))
        _batch = [(testing_data[pos][0], testing_data[pos][1]) for pos in _indices]

    np.random.shuffle(_batch)

    # comment_text is assumed to already be a list of WordPiece tokens -- TODO confirm
    docs = [
        _pad_token_ids(
            tokenizer.convert_tokens_to_ids(_data[index]["comment_text"]),
            max_comment_length,
        )
        for index, _ in _batch
    ]
    labels = [label for _, label in _batch]

    # The reshape keeps a well-formed (0, max_comment_length) array for an empty batch.
    docs = np.asarray(docs, dtype=np.int64).reshape(-1, max_comment_length)

    device = torch.device("cuda" if cuda else "cpu")
    docs = torch.tensor(docs, dtype=torch.long, device=device)
    # Single-sentence input: every token belongs to segment 0.
    segments = torch.zeros_like(docs)
    y = torch.tensor(labels, dtype=torch.float, device=device)
    # Padding id is 0, so any positive id marks a real token.
    mask = docs > 0

    return docs, segments, mask, y

# Smoke check: draw one training batch on the GPU and show the shape of every returned tensor.
d, se, m, y = genBatch(bs=12, testing=False, cuda=True)
print(*(tensor.size() for tensor in (d, se, m, y)))

torch.Size([12, 120]) torch.Size([12, 120]) torch.Size([12, 120]) torch.Size([12])


In [7]:
class AttentionHead(torch.nn.Module):
    """Additive attention pooling over a sequence of encoder states.

    Scores every position with v(tanh(w(x))), normalises the scores with a
    softmax over the sequence axis, takes the weighted sum of the inputs and
    projects the pooled vector to `dim` features.

    Args:
        dim: size of the attention hidden layer and of the returned vector.
        bert_model: encoder name; the input width is taken to be 1024 when the
            name contains "large" and 768 otherwise.
    """
    def __init__(self, dim=64, bert_model = "bert-base-uncased"):
        super(AttentionHead, self).__init__()
        self.bert_dim = 768
        if ("large" in bert_model):
            self.bert_dim = 1024
        self.w = torch.nn.Linear(self.bert_dim, dim)
        self.v = torch.nn.Linear(dim,1)
        self.o = torch.nn.Linear(self.bert_dim, dim)

    def forward(self, _d, mask=None):
        """Pool `_d` into one vector per batch element.

        Args:
            _d: tensor of shape (batch, seq, bert_dim).
            mask: optional (batch, seq) tensor, non-zero at real tokens and zero
                at padding. When given, padded positions receive (numerically)
                zero attention weight. When omitted, every position is attended
                to, which is the previous behaviour.

        Returns:
            Tensor of shape (batch, dim).
        """
        _att = self.v(torch.tanh(self.w(_d)))           # (batch, seq, 1) raw scores
        if mask is not None:
            # Additive masking: a large negative offset drives the softmax weight
            # of padded positions to zero without producing NaNs when a row is
            # fully masked (it then falls back to uniform weights).
            _att = _att + (1.0 - mask.unsqueeze(-1).to(_att.dtype)) * -10000.0
        _att = F.softmax(_att, dim=1)                   # normalise over the sequence axis
        _o = torch.sum(_d * _att, dim=1)                # (batch, bert_dim) weighted sum
        return self.o(_o)

In [8]:
class multiHeadedClassifier(torch.nn.Module):
    """BERT encoder followed by attention-pooling heads and a small classifier.

    The first-position encoder state is concatenated with the output of every
    attention head, passed through a tanh hidden layer and projected to
    `output_dims` logits (no sigmoid is applied here).

    Args:
        attention_heads: number of AttentionHead modules to pool with.
        attention_head_dim: output width of each head and of the hidden layer.
        bert_model: name of the pretrained BERT checkpoint to load; the encoder
            width is taken to be 1024 when the name contains "large", else 768.
        output_dims: number of logits produced per example.
    """
    def __init__(self, attention_heads = 1, 
                 attention_head_dim=512,
                 bert_model = "bert-base-uncased",
                 output_dims = 1
                ):
        super(multiHeadedClassifier, self).__init__()
        self.attentions = torch.nn.ModuleList([])
        self.bert_dim = 768
        if ("large" in bert_model):
            self.bert_dim = 1024
        # Load the checkpoint that was asked for (it used to be hard-coded to
        # "bert-base-uncased", which mismatched bert_dim for "large" models).
        self.bert = BertModel.from_pretrained(bert_model)
        # Keep the historical eager move to the GPU, but only when one exists so
        # the model can still be constructed on a CPU-only machine.
        if torch.cuda.is_available():
            self.bert = self.bert.cuda()
        self.attentions.extend(
            [AttentionHead(dim=attention_head_dim, bert_model=bert_model)
             for _ in range(attention_heads)]
        )
        self.lin = torch.nn.Linear(self.bert_dim + attention_heads*attention_head_dim, attention_head_dim)
        # Honour output_dims (previously ignored; the layer always had 1 output).
        self.output = torch.nn.Linear(attention_head_dim, output_dims)
    
    def forward(self, d, se, m):
        """Return logits of shape (batch, output_dims).

        Args:
            d: token ids, (batch, seq).
            se: segment ids, same shape as `d`.
            m: attention mask, same shape as `d`.
        """
        _d, _ = self.bert(d, se, m, output_all_encoded_layers=False)

        # NOTE(review): the mask is not forwarded to the heads, so they attend
        # over padded positions as well.
        # Starting the list with the first-position state also makes a model with
        # zero heads work (torch.cat on an empty list would raise).
        features = [_d[:,0,:]] + [head(_d) for head in self.attentions]
        _out = torch.cat(features, dim=-1).unsqueeze(1)      # (batch, 1, features)
        _out = torch.tanh(self.lin(_out))
        # Drop the singleton sequence axis; for output_dims == 1 this yields the
        # same (batch, 1) result as the earlier squeeze(-1).
        return self.output(_out).squeeze(1)

In [10]:
# Build the classifier with a single 256-wide attention head and move it to the GPU.
# The trailing bare name renders the module tree as the cell output.
network = multiHeadedClassifier(attention_heads=1, attention_head_dim=256)
network = network.cuda()
network

multiHeadedClassifier(
  (attentions): ModuleList(
    (0): AttentionHead(
      (w): Linear(in_features=768, out_features=256, bias=True)
      (v): Linear(in_features=256, out_features=1, bias=True)
      (o): Linear(in_features=768, out_features=256, bias=True)
    )
  )
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): FusedLayerNorm(torch.Size([768]), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
 

In [11]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score as roc
def validate(network, batches=10, bs=64):
    """Evaluate `network` on randomly sampled test batches.

    Uses the module-level `loss_fn` for the loss and ROC AUC (computed on the
    raw logits, which rank identically to probabilities) as the score.

    Args:
        network: model called as network(docs, segments, mask).
        batches: number of test batches to draw (at least 1).
        bs: size of each test batch.

    Returns:
        (loss, score): validation loss as a Python float and the ROC AUC.

    Raises:
        ValueError: if `batches` is smaller than 1.
    """
    if batches < 1:
        raise ValueError("validate() needs at least one batch")

    print("\tValidating...")
    # Switch to eval mode so dropout is disabled while scoring, and restore the
    # previous mode afterwards so an ongoing training run is unaffected.
    was_training = network.training
    network.eval()
    y_preds = []
    y_acts = []
    try:
        with torch.no_grad():
            for _ in range(batches):
                # bs is now honoured (it used to be hard-coded to 12 here).
                d, se, m, y = genBatch(bs=bs, testing=True, cuda=True)
                # Call the module rather than .forward so registered hooks run.
                y_preds.append(network(d, se, m))
                y_acts.append(y)
    finally:
        network.train(was_training)

    y_preds = torch.cat(y_preds, dim=0)
    y_acts = torch.cat(y_acts, dim=0)

    loss = loss_fn(y_preds, y_acts.unsqueeze(-1))
    score = roc(y_acts.cpu().numpy(), y_preds.cpu().numpy())
    print("\t", np.round(loss.item(), 5), np.round(score,2))
    return loss.item(), score

In [12]:
# Per-epoch history shared by train(), saveModel() and _save():
# validation loss, mean training loss and validation ROC AUC.
validation_losses, epoch_losses, validation_rocs = [], [], []

def _save(network, cause):
    """Write a checkpoint of `network` plus the training history so far.

    `cause` names why the checkpoint is taken and is embedded in both file
    names: the weights go to a .h5 file, the history lists to a .json file,
    both in the current working directory.
    """
    print("\tSaving model for cause:", cause)

    weights_path = "./UnbiasedToxicity_" + cause + ".h5"
    history_path = "./TrainingCycle_UbiasedToxicity_" + cause + ".json"

    torch.save(network.state_dict(), weights_path)

    history = dict(
        training_losses=epoch_losses,
        validation_losses=validation_losses,
        roc_auc=validation_rocs,
    )
    # The with-block closes the file on exit.
    with open(history_path, "w") as history_file:
        history_file.write(json.dumps(history))
    
def saveModel(network):
    """Validate `network`, record the results and checkpoint on any new best.

    Appends the validation loss and ROC AUC to the module-level history lists,
    then saves a checkpoint for each criterion whose latest value is the best
    seen so far: lowest training loss, lowest validation loss, highest ROC AUC.
    """
    loss, roc_auc = validate(network)
    validation_losses.append(loss)
    validation_rocs.append(roc_auc)

    # epoch_losses is only filled by train(); skip this criterion when it is
    # still empty instead of letting np.min raise on an empty list.
    if (epoch_losses and np.min(epoch_losses) == epoch_losses[-1]):
        _save(network,"BestTrainingLoss")

    if (np.min(validation_losses) == validation_losses[-1]):
        _save(network, "BestValidationLoss")

    if (np.max(validation_rocs) == validation_rocs[-1]):
        _save(network, "BestRoCAUC")

In [13]:
def train(network, optimizer=None, loss_function=None, epochs=2, batches_per_epoch=10, bs=12):
    """Train `network` on balanced random batches and checkpoint after each epoch.

    Args:
        network: model called as network(docs, segments, mask), returning logits.
        optimizer: optimizer over the network's parameters (required).
        loss_function: callable(logits, targets) returning a scalar loss (required).
        epochs: number of epochs to run.
        batches_per_epoch: number of optimisation steps per epoch.
        bs: training batch size. The per-batch AUC needs both classes in the
            batch; genBatch's balanced sampling provides that for bs >= 2.

    Raises:
        ValueError: if `optimizer` or `loss_function` is not supplied.

    Side effects: appends the mean loss of every epoch to `epoch_losses` and
    calls saveModel(network) after each epoch.
    """
    if optimizer is None or loss_function is None:
        # Fail early with a clear message instead of an AttributeError on None.
        raise ValueError("train() requires both an optimizer and a loss_function")

    for k in range(epochs):
        # Make sure dropout etc. are active even if something left the model
        # in eval mode between epochs.
        network.train()
        batch_losses = []
        batch_rocs = []
        for j in range(batches_per_epoch):
            optimizer.zero_grad()
            d, se, m, y = genBatch(bs=bs, testing=False, cuda=True)
            # Call the module rather than .forward so registered hooks run.
            y_pred = network(d, se, m)
            loss = loss_function(y_pred, y.unsqueeze(-1))
            loss.backward()
            optimizer.step()

            y_act = y.detach().cpu().numpy()
            y_ = torch.sigmoid(y_pred).detach().cpu().numpy()
            batch_losses.append(loss.item())
            batch_rocs.append(roc(y_act,y_))

            # Running means over the epoch so far, redrawn on one console line.
            _str = "Epoch: " + str(k + 1) + "; Batch: (" + str(j+1) + "/" + str(batches_per_epoch) + ")"
            _str = _str + "\tLoss: " + str(np.round(np.mean(batch_losses), 5)) + \
                    "; AUC:" + str(np.round(np.mean(batch_rocs), 2))
            print(_str, end="\r")
        print("\n")
        epoch_losses.append(np.mean(batch_losses))
        saveModel(network)

In [14]:
# Plain SGD over every parameter of the network (BERT encoder included).
optim = torch.optim.SGD(params=network.parameters(), lr=0.01)
# Binary cross-entropy on raw logits; validate() reads this module-level name.
loss_fn = torch.nn.BCEWithLogitsLoss()

In [None]:
# Full run: 10 epochs of 500 batches with 64 comments each.
# For a quick smoke test use e.g. epochs=2, bs=12, batches_per_epoch=5.
train(
    network,
    optimizer=optim,
    loss_function=loss_fn,
    epochs=10,
    batches_per_epoch=500,
    bs=64,
)

Epoch: 1; Batch: (9/500)	Loss: 0.69299; AUC:0.51