In [2]:
%matplotlib inline

In [3]:
import logging
 
# Send every log record to a file on disk and echo it in the notebook output.
logging.basicConfig(filename = 'mem_with_bert_train.log',
                    level = logging.DEBUG,
                    format = '%(asctime)s:%(levelname)s:%(name)s:%(message)s')

# Add the console handler only once. Without this guard, every re-run of this
# cell stacks another StreamHandler on the root logger and each message is
# printed once per re-run. FileHandler subclasses StreamHandler, so the check
# compares the exact type rather than using isinstance.
root_logger = logging.getLogger()
if not any(type(handler) is logging.StreamHandler for handler in root_logger.handlers):
    root_logger.addHandler(logging.StreamHandler())

In [4]:
from tqdm import tqdm
import numpy as np
import pandas as pd

Note: detected 128 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
Note: NumExpr detected 128 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
NumExpr defaulting to 8 threads.


In [5]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split

In [6]:
import torchmetrics

Falling back to TensorFlow client; we recommended you install the Cloud TPU client directly with pip install cloud-tpu-client.
Creating converter from 7 to 5
Creating converter from 5 to 7
Creating converter from 7 to 5
Creating converter from 5 to 7


In [7]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


# load natural language data

In [3]:
# Saved BERT context representations for the natural-language data,
# one file per split, loaded in train / val / test order.
train_context_reps, val_context_reps, test_context_reps = (
    torch.load(path)
    for path in (
        'train_context_bert_reps.pt',
        'val_context_bert_reps.pt',
        'test_context_bert_reps.pt',
    )
)

In [4]:
# Saved BERT query representations for the natural-language data,
# one file per split, loaded in train / val / test order.
train_query_reps, val_query_reps, test_query_reps = (
    torch.load(path)
    for path in (
        'train_query_bert_reps.pt',
        'val_query_bert_reps.pt',
        'test_query_bert_reps.pt',
    )
)

In [5]:
# Saved labels for the natural-language data, one file per split.
train_label, val_label, test_label = (
    torch.load(path)
    for path in (
        'train_label.pt',
        'val_label.pt',
        'test_label.pt',
    )
)

In [8]:
def get_data_loader(context_reps, query_reps, label, batch_size, shuffle):
    """Bundle three tensors into a TensorDataset and return a DataLoader over it.

    Args:
        context_reps: tensor of context representations.
        query_reps: tensor of query representations.
        label: tensor of labels.
        batch_size: number of samples per batch.
        shuffle: whether the loader shuffles the samples.

    Returns:
        A DataLoader yielding (context, query, label) batches.
    """
    return DataLoader(
        TensorDataset(context_reps, query_reps, label),
        batch_size=batch_size,
        shuffle=shuffle,
    )

In [9]:
# Mini-batch size passed to get_data_loader for every split.
batch_size = 128

In [19]:
# One DataLoader per split. Only the training split is shuffled; validation
# and test keep the stored sample order.
train_loader = get_data_loader(
    context_reps=train_context_reps,
    query_reps=train_query_reps,
    label=train_label,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = get_data_loader(
    context_reps=val_context_reps,
    query_reps=val_query_reps,
    label=val_label,
    batch_size=batch_size,
    shuffle=False,
)
test_loader = get_data_loader(
    context_reps=test_context_reps,
    query_reps=test_query_reps,
    label=test_label,
    batch_size=batch_size,
    shuffle=False,
)

# get synthetic data

In [8]:
# Saved BERT context representations for the synthetic data, one file per
# split. These rebind the same variable names used for the natural-language data.
train_context_reps, val_context_reps, test_context_reps = (
    torch.load(path)
    for path in (
        'syn_no_olp_train_context_bert_reps.pt',
        'syn_no_olp_val_context_bert_reps.pt',
        'syn_no_olp_test_context_bert_reps.pt',
    )
)

In [9]:
# Saved BERT query representations for the synthetic data, one file per split.
train_query_reps, val_query_reps, test_query_reps = (
    torch.load(path)
    for path in (
        'syn_no_olp_train_query_bert_reps.pt',
        'syn_no_olp_val_query_bert_reps.pt',
        'syn_no_olp_test_query_bert_reps.pt',
    )
)

In [10]:
# Saved labels for the synthetic data, one file per split.
train_label, val_label, test_label = (
    torch.load(path)
    for path in (
        'syn_no_olp_train_label.pt',
        'syn_no_olp_val_label.pt',
        'syn_no_olp_test_label.pt',
    )
)

In [11]:
def get_data_loader(context_reps, query_reps, label, batch_size, shuffle):
    """Return a DataLoader over (context, query, label) triples.

    NOTE(review): this repeats the get_data_loader definition from the
    natural-language section; whichever cell ran last is the one in effect.

    Args:
        context_reps: tensor of context representations.
        query_reps: tensor of query representations.
        label: tensor of labels.
        batch_size: number of samples per batch.
        shuffle: whether the loader shuffles the samples.
    """
    triples = TensorDataset(context_reps, query_reps, label)
    return DataLoader(dataset=triples, batch_size=batch_size, shuffle=shuffle)

In [12]:
# Mini-batch size for the synthetic-data loaders (same value as in the natural-language section).
batch_size = 128

In [13]:
# One DataLoader per synthetic split; these rebind train_loader / val_loader /
# test_loader. Only the training split is shuffled.
train_loader = get_data_loader(
    context_reps=train_context_reps,
    query_reps=train_query_reps,
    label=train_label,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = get_data_loader(
    context_reps=val_context_reps,
    query_reps=val_query_reps,
    label=val_label,
    batch_size=batch_size,
    shuffle=False,
)
test_loader = get_data_loader(
    context_reps=test_context_reps,
    query_reps=test_query_reps,
    label=test_label,
    batch_size=batch_size,
    shuffle=False,
)

# Train network

In [15]:
class MemNetwork(nn.Module):
    """Bilinear scorer: project ``x`` with a learned linear map, then dot it with ``y``.

    Args:
        embed_dim: size of the representation vectors. Defaults to 768, the
            value previously hard-coded, so ``MemNetwork()`` builds the same
            network as before.
    """

    def __init__(self, embed_dim=768):
        super().__init__()

        # Kept inside nn.Sequential so parameter names stay ``linear.0.weight`` /
        # ``linear.0.bias`` and previously saved state dicts still load.
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, embed_dim)
        )

    def forward(self, x, y):
        """Return one raw score (logit) per row.

        Args:
            x: tensor fed through the linear layer; last dimension must be ``embed_dim``.
            y: tensor multiplied element-wise with the projected ``x``.
               Assumes both inputs are (batch, embed_dim) — TODO confirm.

        Returns:
            The element-wise product summed over ``dim=1``.
        """
        x_input = self.linear(x)
        op = torch.sum(x_input*y, dim=1)
        return op

In [16]:
# Build the network, move its parameters to the selected device, and show its layers.
model = MemNetwork()
model = model.to(device)
print(model)

MemNetwork(
  (linear): Sequential(
    (0): Linear(in_features=768, out_features=768, bias=True)
  )
)


In [17]:
# Initial training configuration.
# NOTE(review): a later cell rebuilds these same objects with lr=0.0001;
# whichever cell ran last is the one the training loop uses.
num_epochs = 50
criterion = nn.BCEWithLogitsLoss()  # binary cross-entropy applied to raw logits
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-2,
    weight_decay=1e-4,
)
# Multiply the learning rate by 0.8 after every 5 scheduler steps.
scheduler = StepLR(optimizer=optimizer, step_size=5, gamma=0.8)

# checkpoint for natural language data

In [59]:
# File the training loop writes the best model to when training on the natural-language data.
# NOTE(review): hardcoded absolute path; the synthetic-data cell below assigns the
# same variable, so run only the cell that matches the data currently loaded.
checkpoint_path = '/data/sherin/checkpoint_lm/chkpt_lm_bert_wdcy_steplr_recall_best.pt.tar'

# checkpoint for synthetic data

In [18]:
# File the training loop writes the best model to when training on the synthetic data.
# NOTE(review): hardcoded absolute path; this rebinds the checkpoint_path set in the
# natural-language cell above.
checkpoint_path = '/data/sherin/checkpoint_synthetic/chkpt_lm_bert_wdcy_steplr_recall_best.pt.tar'

In [20]:
# Training configuration used for the run logged below (the log starts at lr 0.0001).
num_epochs = 50
criterion = nn.BCEWithLogitsLoss()  # binary cross-entropy applied to raw logits
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=1e-4,
)
# Multiply the learning rate by 0.8 after every 5 scheduler steps.
scheduler = StepLR(optimizer=optimizer, step_size=5, gamma=0.8)

In [21]:
# Per-epoch history of the averaged metrics, kept for inspection after training.
epoch_loss_list = []
accuracy_list = []
val_loss_list = []
val_acc_list = []
valid_acc_max = 0  # best validation accuracy so far; decides when to checkpoint

for epoch in range(num_epochs):
    # ---- training pass ----
    train_count = 0  # number of training batches seen this epoch
    model.train()
    epoch_loss = 0.0
    accuracy = 0.0

    for context, query, labels in tqdm(train_loader):
        train_count = train_count+1
        context = context.to(device)
        query = query.to(device)
        target = labels.to(device)          # labels in their stored dtype, for the accuracy metric
        label = labels.float().to(device)   # float targets for BCEWithLogitsLoss

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(context, query)
        loss = criterion(outputs, label)
        loss.backward()

        optimizer.step()

        epoch_loss += loss.item()
        # The model emits raw logits. Convert them to probabilities before
        # applying the 0.5 threshold so this accuracy uses the same decision
        # rule (sigmoid(logit) > 0.5) as the exported test predictions.
        probs = torch.sigmoid(outputs.detach())
        accuracy += torchmetrics.functional.accuracy(probs, target, threshold=0.5).item()

    # ---- validation pass ----
    model.eval()
    val_loss = 0.0
    val_acc = 0.0

    test_count = 0  # number of validation batches seen this epoch
    # No gradients are needed for evaluation; skipping autograd bookkeeping
    # saves memory and time without changing the computed values.
    with torch.no_grad():
        for context, query, labels in tqdm(val_loader):
            test_count = test_count + 1
            context = context.to(device)
            query = query.to(device)
            target = labels.to(device)
            label = labels.float().to(device)

            outputs = model(context, query)
            loss = criterion(outputs, label)
            val_loss += loss.item()
            probs = torch.sigmoid(outputs)
            val_acc += torchmetrics.functional.accuracy(probs, target, threshold=0.5).item()

    # Metrics are means over batches (each batch weighted equally, including
    # a possibly smaller final batch).
    accuracy = accuracy / train_count
    epoch_loss = epoch_loss / train_count
    val_loss = val_loss / test_count
    val_acc = val_acc / test_count

    # Checkpoint only when validation accuracy strictly improves.
    if val_acc > valid_acc_max:
        logging.info("saving best model")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_loss,
            'accuracy': val_acc,
            }, checkpoint_path)
        valid_acc_max = val_acc
    else:
        logging.info("not saving the model")

    curr_lr = optimizer.param_groups[0]['lr']
    logging.info(f'curr_lr: {curr_lr}')
    logging.info(f'[{epoch + 1}] Training loss: {epoch_loss:.3f} Training accuracy : {accuracy:.3f}')
    logging.info(f'[{epoch + 1}] Validation loss: {val_loss:.3f} Validation accuracy : {val_acc:.3f}')
    epoch_loss_list.append(epoch_loss)
    accuracy_list.append(accuracy)
    val_loss_list.append(val_loss)
    val_acc_list.append(val_acc)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

print('Finished Training')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1086/1086 [00:03<00:00, 356.89it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 469.50it/s]
saving best model
curr_lr: 0.0001
[1] Training loss: 1.270 Training accuracy : 0.568
[1] Validation loss: 1.368 Validation accuracy : 0.498
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1086/1086 [00:03<00:00, 357.37it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 468.63it/s]
not saving the model
curr_lr: 0.0001
[2] Training loss: 

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1086/1086 [00:02<00:00, 368.64it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 466.19it/s]
not saving the model
curr_lr: 5.120000000000001e-05
[16] Training loss: 0.783 Training accuracy : 0.588
[16] Validation loss: 1.019 Validation accuracy : 0.498
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1086/1086 [00:02<00:00, 365.64it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 467.02it/s]
not saving the model
curr_lr: 5.1200

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 458.61it/s]
not saving the model
curr_lr: 3.2768000000000016e-05
[30] Training loss: 0.728 Training accuracy : 0.591
[30] Validation loss: 0.922 Validation accuracy : 0.497
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1086/1086 [00:02<00:00, 363.74it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 469.64it/s]
not saving the model
curr_lr: 2.6214400000000015e-05
[31] Training loss: 0.723 Training accuracy : 0.593
[31] Validation loss: 0.942 Validation accuracy : 0.500
100%|██████████████████████████████████████████████████████████████████████

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1086/1086 [00:03<00:00, 358.23it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 452.39it/s]
not saving the model
curr_lr: 1.677721600000001e-05
[45] Training loss: 0.704 Training accuracy : 0.594
[45] Validation loss: 0.904 Validation accuracy : 0.499
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1086/1086 [00:03<00:00, 360.90it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 468.29it/s]
saving best model
curr_lr: 1.3421772

Finished Training


# natural language analysis


In [61]:
# Restore the best checkpoint for inference.
# Reference: https://pytorch.org/tutorials/beginner/saving_loading_models.html

PATH = checkpoint_path
# map_location remaps the saved tensors onto the device selected for this
# session, so a checkpoint written from a CUDA run also loads on a CPU-only machine.
checkpoint = torch.load(PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
acc = checkpoint['accuracy']

# inference: switch to eval mode (last expression, so the cell also displays the model)
model.eval()

MemNetwork(
  (linear): Sequential(
    (0): Linear(in_features=768, out_features=768, bias=True)
  )
)

In [77]:
# Evaluate on the test split and keep the raw logits of every batch.
test_count = 0            # number of test batches seen
output_logits_bert = []   # one NumPy array of logits per batch
test_acc = 0

# Make the cell self-contained: ensure eval mode and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    for context, query, labels in tqdm(test_loader):
        test_count += 1
        context = context.to(device)
        query = query.to(device)

        target = labels.to(device)

        outputs = model(context, query)
        output_logits = outputs.detach().cpu().numpy()
        # The model emits raw logits. Convert them to probabilities before
        # applying the 0.5 threshold, matching the sigmoid(logit) > 0.5 rule
        # used when the predictions are exported.
        probs = torch.sigmoid(outputs)
        test_acc += torchmetrics.functional.accuracy(probs, target, threshold=0.5).item()
        output_logits_bert.append(output_logits)

# Mean of the per-batch accuracies (each batch weighted equally).
accuracy = test_acc/test_count
print("The test accuracy is {}".format(accuracy))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 301/301 [00:00<00:00, 354.07it/s]

The test accuracy is 0.6104366345659047





In [65]:
# Join the per-batch logit arrays collected in the test loop into one array.
# NOTE(review): np.hstack joins 1-D inputs end to end; assumes every batch array is 1-D — confirm.
test_output_logits_bert = np.hstack(output_logits_bert)

In [79]:
# Turn the stacked test logits into probabilities with a sigmoid.
bert_pred = torch.tensor(test_output_logits_bert).sigmoid()
# Hard decisions as floats: 1.0 where the probability is strictly above 0.5, else 0.0.
bert_pred_label = (bert_pred > 0.5) * 1.0

In [80]:
# Persist the predicted probabilities and the thresholded labels to disk.
torch.save(obj=bert_pred, f='bert_pred.pt')
torch.save(obj=bert_pred_label, f='bert_pred_label.pt')

In [70]:
# Load the test examples into a DataFrame so predictions can be attached as columns.
# NOTE(review): assumes the rows of test_final.json are in the same order as the
# samples served by test_loader — confirm.
df_test_final = pd.read_json("test_final.json")

In [71]:
# Predicted probability per test row: sigmoid of the stacked logits, as a NumPy array.
df_test_final["bert_pred"] = torch.sigmoid(torch.tensor(test_output_logits_bert)).cpu().numpy()

In [72]:
# Hard prediction per row stored as a float: 1.0 where the probability is strictly above 0.5, else 0.0.
df_test_final["bert_pred_label"] = 1.0 * (df_test_final["bert_pred"] > 0.5)

# synthetic analysis