In [1]:
# imports
import json
from pathlib import Path
import os
import sys
import pickle
import torch
from tqdm import tqdm

# Paths preprocessed datasets, generated by utils/prep.py (to speed up training)
DIR_DATASET = Path(current_dir=os.getcwd()).parent.absolute()/'dataset_DocRED'
DATA_DEV = DIR_DATASET/'dev.json' 
DATA_TRAIN = DIR_DATASET/'train_annotated.json'

DIR_DATASET_PREP = Path(current_dir=os.getcwd()).parent.absolute()/'dataset_prep'
DATA_DEV_PREP = DIR_DATASET_PREP/'dev_prep.pickle' 
DATA_TRAIN_PREP = DIR_DATASET_PREP/'train_annotated_prep.pickle'

PATH_UTILS = DIR_DATASET.parent.absolute().__str__()

# add local lib to path
if PATH_UTILS not in sys.path:
    sys.path.append(PATH_UTILS)

  DIR_DATASET = Path(current_dir=os.getcwd()).parent.absolute()/'dataset_DocRED'
  DIR_DATASET_PREP = Path(current_dir=os.getcwd()).parent.absolute()/'dataset_prep'


In [2]:
from utils import DocREDDataset, collate_fn
from torch.utils.data import DataLoader, random_split
# load preped dataset
fData = open(DATA_TRAIN_PREP.__str__(), 'rb')
data = pickle.load(fData)
fData.close()

# train/val split
train_size = int(0.8 * len(data))  # 80% for training
val_size = len(data) - train_size  # 20% for validation

# test dataset
fData = open(DATA_DEV_PREP.__str__(), 'rb')
data_test = pickle.load(fData)
fData.close()

In [None]:
# Define model
from torch import nn
from pytorch_tcn import TemporalConv1d

# TCN block
class TCN_Block(nn.Module):
    def __init__(self, in_features, out_features, kernel_size=7, stride=1, dilation=1, dropout=0.3):
        super(TCN_Block, self).__init__()

        # temporal conv with weight_norm
        self.res = nn.Conv1d(in_features, out_features, kernel_size=1) # 1x1 conv sampling for residual
        self.conv = nn.utils.parametrizations.weight_norm(TemporalConv1d(in_features, out_features, kernel_size, stride, dilation=dilation))
        self.relu = nn.ReLU() 
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = self.conv(x) # Dilated Causal Conv
        x = self.dropout(self.relu(x)) # dropout regularization
        return x

# The model
class LSTM_TCNClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, hidden_dims_tcn, dropout=0.3):
        super(LSTM_TCNClassifier, self).__init__()
        self.coref_embed = nn.Embedding(512, 20, padding_idx=0)
        self.ner_embed = nn.Embedding(512, 20, padding_idx=0)

        # tcn
        tcn_layers = []
        for i in range(len(hidden_dims_tcn)):
            in_features = 0
            if i == 0:
                in_features = input_dim
            else:
                in_features = hidden_dims_tcn[i-1]
            out_features = hidden_dims_tcn[i]
            dilation = 2 ** i # dilation 1, 2, 4, ...
            tcn_layer = TCN_Block(in_features, out_features, dilation=dilation, dropout=dropout)
            tcn_layers.append(tcn_layer)
        self.tcn = nn.Sequential(*tcn_layers)

        # lstm
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout)

        # layer norms
        self.ln_h = nn.LayerNorm(hidden_dim)
        self.ln_t = nn.LayerNorm(hidden_dim)

        # classification
        self.dropout1 = nn.Dropout(p=dropout)
        self.fc1 = nn.Linear(hidden_dim+hidden_dims_tcn[-1], hidden_dim)
        self.relu1 = nn.ReLU()

        # bilinear output
        self.bln = nn.Bilinear(hidden_dim*2, hidden_dim*2, output_dim)

        # self.softmax = nn.Softmax(dim=-1) # NOTE: removed, output bilinear, use sigmoid later

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
    
    def forward(self, embed, coref, ner, map_h, map_t):
        # coref 
        coref = self.coref_embed(coref)
        # ner (coref_type)
        ner = self.ner_embed(ner)
        
        net_in = torch.cat([embed, coref, ner], dim=-1)

        # tcn
        tcn_out = self.tcn(net_in.transpose(-1, -2)).transpose(-1, -2)

        # lstm
        hidden, carry = torch.zeros(self.num_layers, net_in.shape[0], self.hidden_dim).to(embed.device), torch.zeros(self.num_layers, net_in.shape[0], self.hidden_dim).to(embed.device)
        # xavier init hidden and carry for stability over randn
        nn.init.xavier_normal_(hidden)
        nn.init.xavier_normal_(carry)
        lstm_out, (hidden, carry) = self.lstm(net_in, (hidden, carry))


        # Extract features from time sequence output

        # concat tcn and lstm features
        seq_out = torch.cat([tcn_out, lstm_out], dim=-1)
        # linear layer
        seq_out = self.dropout1(self.relu1(self.fc1(seq_out)))

        # saperate features for head and tail entities
        map_h = map_h.float().unsqueeze(1)
        map_t = map_t.float().unsqueeze(1)

        seq_out_h = self.ln_h(torch.matmul(map_h, seq_out).squeeze(1)) 
        seq_out_t = self.ln_t(torch.matmul(map_t, seq_out).squeeze(1))

        dist = seq_out_t - seq_out_h

        seq_out_h = torch.concat([seq_out_h, dist], dim=-1)
        seq_out_t = torch.concat([seq_out_t, -dist], dim=-1)
        
        # bilinear output
        out = self.bln(seq_out_h, seq_out_t)
        return out

# Model Setup & Training

(Note: Training took about 1 hour on a machine with 8C16T CPU and Nvidia RTX 3060-12G, loading the pre-processed training and testing datasets takes around 26GB of RAM)

In [4]:
# Parameters
input_dim = 768 + 20 + 20  # DistilledBERT pretrained embeddings + coref + ner type (both head and tail)
hidden_dim = 128
hidden_dims_tcn=[256,128,64]
dropout = 0.5
output_dim = 96 + 1  # 96 types + Na
num_layers = 2
batch_size = 32
learning_rate = 0.0002
max_epochs = 50 # has to be multiple of k_folds

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

# Split dataset for training and validation
dataset = DocREDDataset(data, na_factor=0.5) # introduce new random NA examples in new k-fold cycle
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# model
model = LSTM_TCNClassifier(input_dim, hidden_dim, output_dim, num_layers, dropout=dropout, hidden_dims_tcn=hidden_dims_tcn).to(device)
criterion = nn.BCEWithLogitsLoss() # Classification loss
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

def train(model, embed, coref, ner, y, map_h, map_t):
    optimizer.zero_grad()
    outputs = model(embed, coref, ner, map_h, map_t)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    return loss.item()

import numpy as np
# store train losses
losses_train = []
losses_val = []

# training loop
(DIR_DATASET_PREP.parent/'checkpoints').mkdir(exist_ok=True) # create output directory
for epoch in range(1, max_epochs+1):
    # train
    train_loss = []
    for batch in tqdm(train_loader, desc=f"epoch {epoch} [Train]"):
        x, y = batch

        x = x.to(device)
        y = y.to(device)

        embed = x[:,:,:768]
        coref = x[:,:,768].long()
        ner = x[:,:,769].long()
        map_h = x[:,:,770]
        map_t = x[:,:,771]

        loss = train(model, embed, coref, ner, y, map_h, map_t)
        train_loss.append(loss)
    avg_train_loss = sum(train_loss)/len(train_loss)
    losses_train.append(avg_train_loss)

    # val
    val_loss = []
    for batch in tqdm(val_loader, desc=f"epoch {epoch} [Val]"):
        with torch.no_grad():
            x, y = batch
            x = x.to(device)
            y = y.to(device)
            embed = x[:,:,:768]
            coref = x[:,:,768].long()
            ner = x[:,:,769].long()
            map_h = x[:,:,770]
            map_t = x[:,:,771]
            outputs = model(embed, coref, ner, map_h, map_t)
            loss = criterion(outputs, y)
            val_loss.append(loss.item())
    avg_val_loss = sum(val_loss)/len(val_loss)
    losses_val.append(avg_val_loss)

    print(f"loss_train={avg_train_loss}, loss_val={avg_val_loss}")
    
    if epoch % 5 == 0:
        # save checkpoint 
        torch.save(model.state_dict(), DIR_DATASET_PREP.parent/'checkpoints'/f'lstm_tcn_{epoch}e.model')

# save loss for logging
loss_train = np.array(losses_train)
np.save(DIR_DATASET_PREP.parent/'checkpoints'/'loss_train.npy', loss_train)
loss_val = np.array(losses_val)
np.save(DIR_DATASET_PREP.parent/'checkpoints'/'loss_val.npy', loss_val)

100%|██████████| 3053/3053 [00:04<00:00, 664.49it/s]
epoch 1 [Train]: 100%|██████████| 1413/1413 [01:20<00:00, 17.48it/s]
epoch 1 [Val]: 100%|██████████| 354/354 [00:06<00:00, 55.11it/s]


loss_train=0.09923500584679677, loss_val=0.04370540759847157


epoch 2 [Train]: 100%|██████████| 1413/1413 [01:19<00:00, 17.73it/s]
epoch 2 [Val]: 100%|██████████| 354/354 [00:06<00:00, 54.93it/s]


loss_train=0.0380226759304624, loss_val=0.03458818678591548


epoch 3 [Train]: 100%|██████████| 1413/1413 [01:20<00:00, 17.57it/s]
epoch 3 [Val]: 100%|██████████| 354/354 [00:06<00:00, 54.28it/s]


loss_train=0.03230339437011321, loss_val=0.030130370114903667


epoch 4 [Train]: 100%|██████████| 1413/1413 [01:20<00:00, 17.63it/s]
epoch 4 [Val]: 100%|██████████| 354/354 [00:06<00:00, 53.89it/s]


loss_train=0.02823377774483923, loss_val=0.027013055185687408


epoch 5 [Train]: 100%|██████████| 1413/1413 [01:18<00:00, 18.08it/s]
epoch 5 [Val]: 100%|██████████| 354/354 [00:06<00:00, 55.07it/s]


loss_train=0.02579704144744167, loss_val=0.025349992330022955


epoch 6 [Train]: 100%|██████████| 1413/1413 [01:18<00:00, 17.94it/s]
epoch 6 [Val]: 100%|██████████| 354/354 [00:05<00:00, 59.44it/s]


loss_train=0.02396405703254355, loss_val=0.023826376945411756


epoch 7 [Train]: 100%|██████████| 1413/1413 [01:14<00:00, 18.94it/s]
epoch 7 [Val]: 100%|██████████| 354/354 [00:05<00:00, 59.66it/s]


loss_train=0.02289556536654779, loss_val=0.023391788336053744


epoch 8 [Train]: 100%|██████████| 1413/1413 [01:16<00:00, 18.55it/s]
epoch 8 [Val]: 100%|██████████| 354/354 [00:06<00:00, 56.82it/s]


loss_train=0.021651953403420764, loss_val=0.021935891812551494


epoch 9 [Train]: 100%|██████████| 1413/1413 [01:18<00:00, 18.05it/s]
epoch 9 [Val]: 100%|██████████| 354/354 [00:06<00:00, 55.47it/s]


loss_train=0.020722564876011013, loss_val=0.021230052880374557


epoch 10 [Train]: 100%|██████████| 1413/1413 [01:18<00:00, 17.90it/s]
epoch 10 [Val]: 100%|██████████| 354/354 [00:06<00:00, 55.71it/s]


loss_train=0.01978014661981253, loss_val=0.02088314535801357


epoch 11 [Train]: 100%|██████████| 1413/1413 [01:19<00:00, 17.81it/s]
epoch 11 [Val]: 100%|██████████| 354/354 [00:06<00:00, 55.42it/s]


loss_train=0.01909405220516147, loss_val=0.020507659173786302


epoch 12 [Train]: 100%|██████████| 1413/1413 [01:21<00:00, 17.42it/s]
epoch 12 [Val]: 100%|██████████| 354/354 [00:06<00:00, 55.56it/s]


loss_train=0.018426775104469724, loss_val=0.019676966134900765


epoch 13 [Train]: 100%|██████████| 1413/1413 [01:18<00:00, 17.98it/s]
epoch 13 [Val]: 100%|██████████| 354/354 [00:06<00:00, 58.89it/s]


loss_train=0.017721330050988054, loss_val=0.019180246165790463


epoch 14 [Train]: 100%|██████████| 1413/1413 [01:16<00:00, 18.51it/s]
epoch 14 [Val]: 100%|██████████| 354/354 [00:06<00:00, 58.83it/s]


loss_train=0.01702737514165567, loss_val=0.018649179877786987


epoch 15 [Train]: 100%|██████████| 1413/1413 [01:17<00:00, 18.28it/s]
epoch 15 [Val]: 100%|██████████| 354/354 [00:06<00:00, 58.34it/s]


loss_train=0.016569984415038545, loss_val=0.01942865483957411


epoch 16 [Train]: 100%|██████████| 1413/1413 [01:16<00:00, 18.42it/s]
epoch 16 [Val]: 100%|██████████| 354/354 [00:06<00:00, 57.42it/s]


loss_train=0.015976311923027755, loss_val=0.01820666954252806


epoch 17 [Train]: 100%|██████████| 1413/1413 [01:22<00:00, 17.08it/s]
epoch 17 [Val]: 100%|██████████| 354/354 [00:11<00:00, 30.10it/s]


loss_train=0.01561735679612344, loss_val=0.01802023553610437


epoch 18 [Train]: 100%|██████████| 1413/1413 [02:21<00:00,  9.97it/s]
epoch 18 [Val]: 100%|██████████| 354/354 [00:11<00:00, 30.17it/s]


loss_train=0.015089344437928146, loss_val=0.018036268360775997


epoch 19 [Train]: 100%|██████████| 1413/1413 [01:48<00:00, 13.06it/s]
epoch 19 [Val]: 100%|██████████| 354/354 [00:06<00:00, 55.82it/s]


loss_train=0.014536666109767773, loss_val=0.017535545947398506


epoch 20 [Train]: 100%|██████████| 1413/1413 [01:18<00:00, 17.99it/s]
epoch 20 [Val]: 100%|██████████| 354/354 [00:06<00:00, 58.34it/s]


loss_train=0.014234608591264622, loss_val=0.017758073140525042


epoch 21 [Train]: 100%|██████████| 1413/1413 [01:16<00:00, 18.51it/s]
epoch 21 [Val]: 100%|██████████| 354/354 [00:06<00:00, 58.91it/s]


loss_train=0.013807720323116304, loss_val=0.017693323501527816


epoch 22 [Train]: 100%|██████████| 1413/1413 [01:17<00:00, 18.26it/s]
epoch 22 [Val]: 100%|██████████| 354/354 [00:06<00:00, 57.43it/s]


loss_train=0.013394339832029735, loss_val=0.017294220853923153


epoch 23 [Train]: 100%|██████████| 1413/1413 [01:17<00:00, 18.27it/s]
epoch 23 [Val]: 100%|██████████| 354/354 [00:06<00:00, 56.95it/s]


loss_train=0.013135707212818993, loss_val=0.01718871013292948


epoch 24 [Train]: 100%|██████████| 1413/1413 [01:15<00:00, 18.67it/s]
epoch 24 [Val]: 100%|██████████| 354/354 [00:06<00:00, 58.58it/s]


loss_train=0.012841141669993188, loss_val=0.017068702029064298


epoch 25 [Train]: 100%|██████████| 1413/1413 [01:17<00:00, 18.16it/s]
epoch 25 [Val]: 100%|██████████| 354/354 [00:06<00:00, 53.91it/s]


loss_train=0.01243465213712782, loss_val=0.01731632756842774


epoch 26 [Train]: 100%|██████████| 1413/1413 [01:17<00:00, 18.28it/s]
epoch 26 [Val]: 100%|██████████| 354/354 [00:06<00:00, 57.42it/s]


loss_train=0.012083404182712828, loss_val=0.01640091766208663


epoch 27 [Train]: 100%|██████████| 1413/1413 [01:17<00:00, 18.23it/s]
epoch 27 [Val]: 100%|██████████| 354/354 [00:06<00:00, 58.47it/s]


loss_train=0.011909404738232102, loss_val=0.01721917396926947


epoch 28 [Train]: 100%|██████████| 1413/1413 [01:17<00:00, 18.15it/s]
epoch 28 [Val]: 100%|██████████| 354/354 [00:06<00:00, 58.32it/s]


loss_train=0.011558195263983106, loss_val=0.017244525460294834


epoch 29 [Train]: 100%|██████████| 1413/1413 [01:16<00:00, 18.35it/s]
epoch 29 [Val]: 100%|██████████| 354/354 [00:06<00:00, 56.32it/s]


loss_train=0.011261714377189392, loss_val=0.01698096383278821


epoch 30 [Train]: 100%|██████████| 1413/1413 [01:16<00:00, 18.54it/s]
epoch 30 [Val]: 100%|██████████| 354/354 [00:06<00:00, 57.04it/s]


loss_train=0.011001049437127106, loss_val=0.01708690661379257


epoch 31 [Train]: 100%|██████████| 1413/1413 [01:17<00:00, 18.20it/s]
epoch 31 [Val]: 100%|██████████| 354/354 [00:06<00:00, 56.58it/s]


loss_train=0.010808224304034114, loss_val=0.01694497534621107


epoch 32 [Train]: 100%|██████████| 1413/1413 [01:17<00:00, 18.12it/s]
epoch 32 [Val]: 100%|██████████| 354/354 [00:06<00:00, 55.75it/s]


loss_train=0.01062103326205673, loss_val=0.017438207128174645


epoch 33 [Train]: 100%|██████████| 1413/1413 [01:17<00:00, 18.34it/s]
epoch 33 [Val]: 100%|██████████| 354/354 [00:06<00:00, 56.41it/s]


loss_train=0.010284042888273526, loss_val=0.017149439691907943


epoch 34 [Train]: 100%|██████████| 1413/1413 [01:15<00:00, 18.61it/s]
epoch 34 [Val]: 100%|██████████| 354/354 [00:06<00:00, 58.80it/s]


loss_train=0.010171272889470388, loss_val=0.01768958292859422


epoch 35 [Train]: 100%|██████████| 1413/1413 [01:20<00:00, 17.61it/s]
epoch 35 [Val]: 100%|██████████| 354/354 [00:06<00:00, 54.12it/s]


loss_train=0.009878353946542358, loss_val=0.01743033582655092


epoch 36 [Train]: 100%|██████████| 1413/1413 [01:20<00:00, 17.50it/s]
epoch 36 [Val]: 100%|██████████| 354/354 [00:06<00:00, 55.53it/s]


loss_train=0.009726862814962885, loss_val=0.017550370729617817


epoch 37 [Train]: 100%|██████████| 1413/1413 [01:20<00:00, 17.54it/s]
epoch 37 [Val]: 100%|██████████| 354/354 [00:06<00:00, 55.25it/s]


loss_train=0.009481017178080064, loss_val=0.017489361366267595


epoch 38 [Train]: 100%|██████████| 1413/1413 [01:20<00:00, 17.54it/s]
epoch 38 [Val]: 100%|██████████| 354/354 [00:06<00:00, 53.28it/s]


loss_train=0.009387914780777815, loss_val=0.017989697341286276


epoch 39 [Train]: 100%|██████████| 1413/1413 [01:20<00:00, 17.59it/s]
epoch 39 [Val]: 100%|██████████| 354/354 [00:06<00:00, 52.79it/s]


loss_train=0.009072749717299473, loss_val=0.01837551852788166


epoch 40 [Train]: 100%|██████████| 1413/1413 [01:20<00:00, 17.52it/s]
epoch 40 [Val]: 100%|██████████| 354/354 [00:06<00:00, 54.33it/s]


loss_train=0.009077849147857153, loss_val=0.01779325572921142


epoch 41 [Train]: 100%|██████████| 1413/1413 [01:19<00:00, 17.74it/s]
epoch 41 [Val]: 100%|██████████| 354/354 [00:06<00:00, 53.75it/s]


loss_train=0.0087886632682778, loss_val=0.017915038672133773


epoch 42 [Train]: 100%|██████████| 1413/1413 [01:19<00:00, 17.77it/s]
epoch 42 [Val]: 100%|██████████| 354/354 [00:06<00:00, 54.19it/s]


loss_train=0.00864788932883644, loss_val=0.018017050172833695


epoch 43 [Train]: 100%|██████████| 1413/1413 [01:19<00:00, 17.69it/s]
epoch 43 [Val]: 100%|██████████| 354/354 [00:06<00:00, 55.30it/s]


loss_train=0.008538814034657019, loss_val=0.018524492743535566


epoch 44 [Train]: 100%|██████████| 1413/1413 [01:21<00:00, 17.39it/s]
epoch 44 [Val]: 100%|██████████| 354/354 [00:06<00:00, 54.34it/s]


loss_train=0.008360336781103726, loss_val=0.018115746703177775


epoch 45 [Train]: 100%|██████████| 1413/1413 [01:21<00:00, 17.24it/s]
epoch 45 [Val]: 100%|██████████| 354/354 [00:06<00:00, 53.74it/s]


loss_train=0.008216633321849203, loss_val=0.0186206615374734


epoch 46 [Train]: 100%|██████████| 1413/1413 [01:21<00:00, 17.29it/s]
epoch 46 [Val]: 100%|██████████| 354/354 [00:06<00:00, 55.15it/s]


loss_train=0.00801075038525506, loss_val=0.01894528611621025


epoch 47 [Train]: 100%|██████████| 1413/1413 [01:20<00:00, 17.60it/s]
epoch 47 [Val]: 100%|██████████| 354/354 [00:06<00:00, 58.12it/s]


loss_train=0.007922463105765408, loss_val=0.018242453193386732


epoch 48 [Train]: 100%|██████████| 1413/1413 [01:18<00:00, 17.99it/s]
epoch 48 [Val]: 100%|██████████| 354/354 [00:06<00:00, 56.11it/s]


loss_train=0.007874741356326195, loss_val=0.01985068650821508


epoch 49 [Train]: 100%|██████████| 1413/1413 [01:20<00:00, 17.63it/s]
epoch 49 [Val]: 100%|██████████| 354/354 [00:06<00:00, 55.88it/s]


loss_train=0.007554659471957189, loss_val=0.019216489612445823


epoch 50 [Train]: 100%|██████████| 1413/1413 [01:20<00:00, 17.58it/s]
epoch 50 [Val]: 100%|██████████| 354/354 [00:06<00:00, 54.81it/s]

loss_train=0.007452232037573362, loss_val=0.018806043190166016



