In [16]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, HTML
import os
from os.path import join
from ast import literal_eval
import itertools
from urllib.parse import urlparse

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
    
from transformers import BertModel
# Hugging Face Hub identifier of the pretrained BERT checkpoint that is loaded
# with BertModel.from_pretrained when the model is built.
bert_id = "google/bert_uncased_L-2_H-128_A-2"

# 1 Load Data

In [17]:
# Read the preprocessed train/validation tables from HDF5 (both stored under key "s").
train_joint = pd.read_hdf(join("preprocessed_data","train_joint.h5"), key="s")
# NOTE(review): this name is `validation_join` (no trailing "t"), unlike `train_joint`;
# the dataset construction cell references this exact spelling.
validation_join = pd.read_hdf(join("preprocessed_data","validation_joint.h5"), key="s")

In [18]:
# Rich display of the training table; the rendered output is truncated to the first and last rows.
train_joint

Unnamed: 0_level_0,Unnamed: 1_level_0,is_relevant,sector_ids,sentence_position,sentence_length,tokenized_sentence,project_name,country_code,url,text_length,sentence_count
doc_id,sentence_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
51657,0,0,[],0.000000,5.252273,"[101, 2047, 4341, 2937, 3360, 2415, 1999, 2474...",1,2,139,7.699389,2.944439
51657,1,0,[],0.693147,4.127134,"[101, 10110, 2003, 2012, 1996, 2415, 1997, 199...",1,2,139,7.699389,2.944439
51657,2,0,[],1.098612,4.682131,"[101, 1996, 2047, 3360, 2415, 2038, 2366, 2062...",1,2,139,7.699389,2.944439
51657,3,0,[],1.386294,4.875197,"[101, 1996, 4341, 2937, 3360, 2415, 2001, 2764...",1,2,139,7.699389,2.944439
51657,4,0,[],1.609438,4.204693,"[101, 2116, 1997, 2122, 3360, 2272, 2013, 3532...",1,2,139,7.699389,2.944439
...,...,...,...,...,...,...,...,...,...,...,...
34512,121,0,[],4.804021,4.234107,"[101, 2174, 1010, 11470, 19621, 2015, 2024, 20...",5,1,0,10.068493,4.836282
34512,122,0,[],4.812184,5.017280,"[101, 1999, 5712, 1010, 2045, 2024, 4311, 1997...",5,1,0,10.068493,4.836282
34512,123,0,[],4.820282,4.234107,"[101, 1996, 9353, 9331, 2015, 2136, 2097, 2562...",5,1,0,10.068493,4.836282
34512,124,0,[],4.828314,4.174387,"[101, 2017, 2064, 2424, 2019, 19184, 1997, 203...",5,1,0,10.068493,4.836282


# 2 Define Dataset

In [19]:
class IsRelevantDataset(Dataset):
    """Per-sentence dataset for the ``is_relevant`` classifier.

    Each item combines three kinds of inputs taken from one row of the joint
    dataframe:

    * the token ids of the sentence (``tokenized_sentence``), zero-padded to a
      fixed length so that items can be batched,
    * four numeric features (``sentence_position``, ``sentence_length``,
      ``text_length``, ``sentence_count``),
    * three integer-coded categorical features (``project_name``,
      ``country_code``, ``url``) that can be one-hot encoded.

    The target is the integer ``is_relevant`` label.

    Attributes:
        X: object array with the feature columns, in the column order used by
            ``__init__``.
        Y: array with the ``is_relevant`` labels.
        device: device on which the item tensors are created.
        max_length: fixed length of the returned token-id tensor.
        dimensions: ``((1, (4, n_project, n_country, n_url)), n_classes)``.
            The inner tuple holds the width of the numeric block and the
            one-hot width of each categorical feature; its sum is the width of
            ``x_other``.
    """

    def __init__(self, joint_dataframe: pd.DataFrame, device=device, dimensions = None, max_length: int = 512):
        """Store features/labels and work out the feature dimensions.

        Args:
            joint_dataframe: table that contains the feature columns listed
                below plus ``is_relevant``.
            device: device on which item tensors are allocated.
            dimensions: optional precomputed ``dimensions`` tuple (see class
                docstring). Pass the training set's value when building a
                validation/test set so the one-hot widths agree.
            max_length: length every token-id sequence is padded (or
                truncated) to. Defaults to 512, the value that was previously
                hard-coded.
        """
        self.X = joint_dataframe[["sentence_position", "sentence_length", "tokenized_sentence", "project_name", "country_code", "url", "text_length", "sentence_count"]].to_numpy()
        self.Y = joint_dataframe["is_relevant"].to_numpy()
        self.device = device
        self.max_length = max_length

        if dimensions is None:
            # One-hot widths are the number of distinct codes seen in this frame.
            # NOTE(review): this assumes the codes are contiguous 0..n-1; a code
            # >= n would make nn.functional.one_hot raise. Confirm upstream.
            n_project = len(set(self.X[:,3]))
            n_country = len(set(self.X[:,4]))
            n_url = len(set(self.X[:,5]))
            self.dimensions = ((1, (4, n_project, n_country, n_url)), 2)
        else:
            self.dimensions = dimensions

    def __len__(self):
        """Number of sentences (rows) in the dataset."""
        return len(self.Y)

    def __getitem__(self, idx, x_one_hot = True, x_train_ready = True):
        """Build the tensors for row ``idx``.

        Note that ``x_train_ready`` implies ``x_one_hot``.

        Args:
            idx: integer row position.
            x_one_hot: one-hot encode the categorical features.
            x_train_ready: additionally concatenate the numeric and one-hot
                features into a single flat float vector.

        Returns:
            ``((sentence_x, x_other), y)`` when ``x_train_ready`` is true,
            otherwise
            ``((sentence_x, (metric_x, project_name_x, country_code_x, url_x)), y)``.
            ``sentence_x`` always has length ``max_length``.
        """
        row = self.X[idx]

        # Numerical features; the dtype is explicit so the concatenated vector
        # matches the float32 weights of the downstream linear layers.
        metric_x = torch.tensor([row[0], row[1], row[6], row[7]], device=self.device, dtype=torch.float32)

        # BERT features: token ids, cut to max_length (longer inputs used to
        # crash with a negative padding size) and right-padded with zeros.
        sentence_x = torch.tensor(row[2], device=self.device, dtype=torch.long)[:self.max_length]
        n_pad = self.max_length - sentence_x.shape[0]
        sentence_x = torch.cat((sentence_x, torch.zeros(n_pad, device=self.device, dtype=torch.long)))

        # Categorical codes; one_hot requires int64 index tensors, so all three
        # are created with an explicit long dtype.
        project_name_x = torch.tensor(row[3], device=self.device, dtype=torch.long)
        country_code_x = torch.tensor(row[4], device=self.device, dtype=torch.long)
        url_x = torch.tensor(row[5], device=self.device, dtype=torch.long)

        y = torch.tensor(self.Y[idx], device=self.device, dtype=torch.long)

        if x_train_ready or x_one_hot:
            widths = self.dimensions[0][1]
            project_name_x = nn.functional.one_hot(project_name_x, num_classes = widths[1])
            country_code_x = nn.functional.one_hot(country_code_x, num_classes = widths[2])
            url_x = nn.functional.one_hot(url_x, num_classes = widths[3])
        if x_train_ready:
            # Cast the integer one-hot vectors explicitly instead of relying on
            # implicit type promotion inside torch.cat.
            one_hot_parts = [t.to(metric_x.dtype) for t in (project_name_x, country_code_x, url_x)]
            x_other = torch.cat((metric_x, *one_hot_parts), dim=0)
            return (sentence_x, x_other), y

        return (sentence_x, (metric_x, project_name_x, country_code_x, url_x)), y

In [20]:
# Training split: the one-hot widths are derived from the training table itself.
train_ds = IsRelevantDataset(train_joint, device = device)
# Validation split reuses the training dimensions so both splits yield feature vectors of the same width.
validation_ds = IsRelevantDataset(validation_join, device = device, dimensions = train_ds.dimensions)

In [21]:
# Smoke test: push a single padded sentence through the pretrained BERT.
# `bert_model` is created here so the cell also works on a fresh kernel
# (it was previously only available as leftover kernel state).
bert_model = BertModel.from_pretrained(bert_id).to(device)
elem = torch.unsqueeze(train_ds[1][0][0], 0)  # add the batch dimension
bert_model(elem)

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[ 7.9678e-01,  4.3363e-01,  3.0845e-01,  ..., -1.0185e+00,
           4.1972e-01, -1.2215e+00],
         [ 1.7186e+00,  2.3990e-01, -1.1979e+00,  ..., -7.5705e-01,
           2.4631e-01,  2.9912e-02],
         [ 9.1795e-01,  8.0292e-01, -1.1369e+00,  ...,  1.2281e-03,
           3.3684e-01, -2.4455e-01],
         ...,
         [ 5.6881e-01,  6.9844e-01, -8.0349e-01,  ..., -3.5966e-01,
           6.5368e-01,  5.5346e-01],
         [ 1.0129e+00, -1.6660e-01,  1.7266e-01,  ..., -7.5208e-01,
          -1.1192e-01,  2.0395e-01],
         [ 3.7733e-01,  9.6487e-02, -7.2973e-01,  ..., -8.6767e-01,
          -9.0913e-02,  2.4773e-01]]], device='cuda:0',
       grad_fn=<NativeLayerNormBackward>), pooler_output=tensor([[ 0.2679,  0.5989,  0.4089, -0.1942,  0.1793,  0.3579,  0.6238, -0.5477,
         -0.4427,  0.6706,  0.3099,  0.1910,  0.5141, -0.8082,  0.3162,  0.6797,
          0.0857,  0.0825, -0.6264,  0.4433,  0.0432, -

# 4 Model Definition

In [22]:
class IsRelevantNet(nn.Module):
    """BERT sentence encoder followed by a small feed-forward classifier.

    The [CLS] embedding produced by ``bert`` is concatenated with the
    additional (numeric + one-hot) features and passed through a three-layer
    MLP with ReLU activations and dropout.

    Args:
        bert: encoder whose ``config.hidden_size`` gives the width of the
            [CLS] embedding.
        input_size: width of the additional feature vector.
        output_size: number of output logits.
        pad_token_id: token id that marks padding in the input ids. Defaults
            to ``bert.config.pad_token_id`` and falls back to 0 if the config
            does not define one.
    """

    def __init__(self, bert: BertModel, input_size, output_size, pad_token_id = None):
        super().__init__()

        self.bert = bert
        if pad_token_id is None:
            pad_token_id = getattr(bert.config, "pad_token_id", None)
        self.pad_token_id = 0 if pad_token_id is None else pad_token_id

        self.feed_forward = nn.Sequential(
            #nn.BatchNorm1d(bert.config.hidden_size + input_size),#just a feeling this might be nice
            nn.Linear(bert.config.hidden_size + input_size, 256),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(256, output_size)
        )

    def train(self, mode: bool = True):
        """Switch train/eval mode, keeping a fully frozen BERT in eval mode.

        ``nn.Module.train`` recurses into every submodule, which would turn
        BERT's dropout back on even when all of its parameters have
        ``requires_grad=False``. In that case the encoder is meant to be a
        fixed feature extractor, so it is put back into eval mode. As soon as
        any BERT parameter is trainable the default behaviour applies.
        """
        super().train(mode)
        if mode and not any(p.requires_grad for p in self.bert.parameters()):
            self.bert.eval()
        return self

    def forward(self, x):
        """Compute logits for a batch.

        Args:
            x: pair ``(token_ids, x_other)`` where ``token_ids`` is a
                ``(batch, seq_len)`` integer tensor and ``x_other`` is a
                ``(batch, input_size)`` feature tensor.

        Returns:
            Tensor of shape ``(batch, output_size)`` with unnormalised logits.
        """
        x_bert, x_other = x[0], x[1]

        # Mask out padding so it does not influence the [CLS] representation;
        # without a mask BERT attends to every position, padding included.
        attention_mask = (x_bert != self.pad_token_id).long()
        y_bert = self.bert(x_bert, attention_mask=attention_mask)["last_hidden_state"][:,0] #all batches but only clf output

        x = torch.cat((y_bert, x_other), dim=1)#dim=1 is feature dimensions (0 is batch dim)

        return self.feed_forward(x)

# 5 Training Routine

In [23]:
def update(model, optimizer, loss, loader, tracker=None):
    """Run one training pass (one epoch) over ``loader``.

    Args:
        model: module that is put into training mode and called on each ``x``.
        optimizer: optimizer that owns the parameters to update.
        loss: callable ``loss(y_hat, y)`` returning a scalar tensor.
        loader: iterable of ``(x, y)`` batches.
        tracker: optional object with ``next_train(loss_value)`` and
            ``submit_train()``. It receives the per-batch loss as a detached
            tensor and is notified once after the last batch.
    """
    model.train()

    for x, y in loader:
        optimizer.zero_grad()
        y_hat = model(x)
        l = loss(y_hat, y)

        l.backward()
        optimizer.step()

        # Carriage return keeps the running loss on a single output line.
        print(f"\r {l.item()}",end="")
        if tracker:
            # Detach before handing the value over: a tracker that stores the
            # raw loss would otherwise keep every batch's autograd graph alive.
            tracker.next_train(l.detach())

    if tracker:
        tracker.submit_train()

@torch.no_grad()
def evaluate(model, metric, loader, tracker=None):
    """Score ``model`` on every batch of ``loader`` with gradients disabled.

    The model is switched to eval mode. Each per-batch ``metric(y_hat, y)``
    value is forwarded to ``tracker.next_eval`` and ``tracker.submit_eval`` is
    called once at the end; without a tracker the values are discarded.
    Nothing is returned.
    """
    model.eval()

    for inputs, targets in loader:
        score = metric(model(inputs), targets)
        if not tracker:
            continue
        tracker.next_eval(score)

    if not tracker:
        return
    tracker.submit_eval()

In [24]:
# Learning rate handed to the Adam optimizer.
lr = 1e-3
# Mini-batch size of the training DataLoader.
batch_size=16
# Number of passes over the training loader.
epochs = 10
# NOTE(review): not referenced in the visible cells; the model takes the encoder
# width from bert.config.hidden_size instead. Confirm before relying on this value.
bert_hidden = 768

In [25]:
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
# The validation loader must wrap the held-out split; it previously wrapped
# train_ds, so any "validation" score would have been computed on training data.
validation_dl = DataLoader(validation_ds, batch_size=64, shuffle=False)

loss = nn.CrossEntropyLoss()
# Width of the non-BERT features = numeric block + the three one-hot blocks.
model = IsRelevantNet(BertModel.from_pretrained(bert_id).to(device), sum(train_ds.dimensions[0][1]), train_ds.dimensions[1]).to(device)
#Should not train bert (for now)
model.bert.train(False)
for p in model.bert.parameters():
    p.requires_grad = False
optimizer = torch.optim.Adam(model.parameters(), lr = lr)

print(f"Number of parameters (including bert): {sum(p.numel() for p in model.parameters())}")
print(f"Number of trainable parameters (excluding bert): {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

Downloading:   0%|          | 0.00/382 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/17.7M [00:00<?, ?B/s]

Number of parameters (including bert): 4528002
Number of trainable parameters (excluding bert): 142082


In [26]:
# Each update() call is one full pass over train_dl. No tracker is passed, so the
# only progress signal is the running loss that update() prints.
# NOTE(review): evaluate() and validation_dl are not used in this loop.
for n_epoch in range(1, epochs+1):
    update(model, optimizer, loss, train_dl)

 0.27797731757164244

KeyboardInterrupt: 