In [4]:
%matplotlib inline

Loaded backend module://matplotlib_inline.backend_inline version unknown.


In [5]:
import logging
 
logging.basicConfig(filename = 'mem_with_lstm_train.log',
                    level = logging.DEBUG,
                    format = '%(asctime)s:%(levelname)s:%(name)s:%(message)s')
logging.getLogger().addHandler(logging.StreamHandler())

In [6]:
from tqdm import tqdm
import numpy as np
import pandas as pd

In [1]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split

In [8]:
import torchmetrics

Falling back to TensorFlow client; we recommended you install the Cloud TPU client directly with pip install cloud-tpu-client.
Falling back to TensorFlow client; we recommended you install the Cloud TPU client directly with pip install cloud-tpu-client.
Creating converter from 7 to 5
Creating converter from 7 to 5
Creating converter from 5 to 7
Creating converter from 5 to 7
Creating converter from 7 to 5
Creating converter from 7 to 5
Creating converter from 5 to 7
Creating converter from 5 to 7


In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [2]:
train_context_reps = torch.load('train_context_lstm_reps.pt')

In [9]:
len(train_context_reps)

70749

In [11]:
train_context_reps = torch.load('train_context_lstm_reps.pt')
val_context_reps = torch.load('val_context_lstm_reps.pt')
test_context_reps = torch.load('test_context_lstm_reps.pt')

In [12]:
train_query_reps = torch.load('train_query_lstm_reps.pt')
val_query_reps = torch.load('val_query_lstm_reps.pt')
test_query_reps = torch.load('test_query_lstm_reps.pt')

In [13]:
train_label = torch.load('train_label.pt')
val_label = torch.load('val_label.pt')
test_label = torch.load('test_label.pt')

In [14]:
def get_data_loader(context_reps, query_reps, label, batch_size, shuffle):
    data_set = TensorDataset(context_reps, query_reps, label)
    loader = DataLoader(data_set, batch_size=batch_size, shuffle=shuffle)
    return loader

In [15]:
batch_size = 128

In [16]:
train_loader = get_data_loader(train_context_reps, train_query_reps, train_label, batch_size, shuffle=True)
val_loader = get_data_loader(val_context_reps, val_query_reps, val_label, batch_size, shuffle=False)
test_loader = get_data_loader(test_context_reps, test_query_reps, test_label, batch_size, shuffle=False)

In [17]:
class MemNetwork(nn.Module):
    def __init__(self):
        super(MemNetwork, self).__init__()
        
        self.linear = nn.Sequential(
            nn.Linear(400, 400)
        )

    def forward(self, x, y):

        x_input = self.linear(x)
        op = torch.sum(x_input*y, dim=1)
        return op

In [18]:
model = MemNetwork().to(device)
print(model)

MemNetwork(
  (linear): Sequential(
    (0): Linear(in_features=400, out_features=400, bias=True)
  )
)


In [19]:
num_epochs = 50
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
scheduler = StepLR(optimizer, step_size=5, gamma=0.8)

In [20]:
checkpoint_path = '/data/sherin/checkpoint_lm/chkpt_lm_lstm_wdcy_steplr_recall_best.pt.tar'

In [21]:
epoch_loss_list = []
accuracy_list = []
val_loss_list = []
val_acc_list = []
valid_acc_max = 0 

for epoch in range(num_epochs):
    train_count = 0
    model.train()
    epoch_loss = 0.0
    accuracy = 0.0
    

    for context, query, labels in tqdm(train_loader):
        train_count = train_count+1
        context = context.to(device)
        query = query.to(device)    
        target = labels.to(device)
        label = labels.float().to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(context, query)
        loss = criterion(outputs, label)
        loss.backward()

        optimizer.step()
        
        epoch_loss += loss.item()
        accuracy += torchmetrics.functional.accuracy(outputs, target, threshold=0.5).item()

    
    model.eval()
    val_loss = 0.0
    val_acc = 0.0
    
    test_count = 0
    for context, query, labels in tqdm(val_loader):
        test_count = test_count + 1
        context = context.to(device)
        query = query.to(device)
        target = labels.to(device)
        label = labels.float().to(device)
        
        outputs = model(context, query)
        loss = criterion(outputs, label)
        val_loss += loss.item()
        val_acc += torchmetrics.functional.accuracy(outputs, target, threshold=0.5).item()
        
    accuracy = accuracy / train_count
    epoch_loss = epoch_loss / train_count
    val_loss = val_loss / test_count
    val_acc = val_acc / test_count
    
    if val_acc > valid_acc_max:
        logging.info("saving best model")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_loss,
            'accuracy': val_acc,
            }, checkpoint_path)
        valid_acc_max = val_acc
    else:
        logging.info("not saving the model")
    
    curr_lr = optimizer.param_groups[0]['lr']
    logging.info(f'curr_lr: {curr_lr}')
    logging.info(f'[{epoch + 1}] Training loss: {epoch_loss:.3f} Training accuracy : {accuracy:.3f}')
    logging.info(f'[{epoch + 1}] Validation loss: {val_loss:.3f} Validation accuracy : {val_acc:.3f}')
    epoch_loss_list.append(epoch_loss)
    accuracy_list.append(accuracy)
    val_loss_list.append(val_loss)
    val_acc_list.append(val_acc)
    
    scheduler.step()

print('Finished Training')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 553/553 [00:02<00:00, 203.40it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [00:00<00:00, 341.18it/s]
saving best model
saving best model
curr_lr: 0.01
curr_lr: 0.01
[1] Training loss: 1.426 Training accuracy : 0.636
[1] Training loss: 1.426 Training accuracy : 0.636
[1] Validation loss: 1.856 Validation accuracy : 0.586
[1] Validation loss: 1.856 Validation accuracy : 0.586
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 553/553 [00:02<00:00, 261.54it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████

[12] Validation loss: 1.650 Validation accuracy : 0.595
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 553/553 [00:02<00:00, 267.42it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [00:00<00:00, 356.58it/s]
not saving the model
not saving the model
curr_lr: 0.0064
curr_lr: 0.0064
[13] Training loss: 0.557 Training accuracy : 0.804
[13] Training loss: 0.557 Training accuracy : 0.804
[13] Validation loss: 1.679 Validation accuracy : 0.594
[13] Validation loss: 1.679 Validation accuracy : 0.594
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 553/553 [00:01<00:00, 278.74it/s]
100%|████████████████████████████████████████

curr_lr: 0.002621440000000001
curr_lr: 0.002621440000000001
[35] Training loss: 0.290 Training accuracy : 0.872
[35] Training loss: 0.290 Training accuracy : 0.872
[35] Validation loss: 1.242 Validation accuracy : 0.592
[35] Validation loss: 1.242 Validation accuracy : 0.592
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 553/553 [00:02<00:00, 273.00it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [00:00<00:00, 342.52it/s]
not saving the model
not saving the model
curr_lr: 0.002097152000000001
curr_lr: 0.002097152000000001
[36] Training loss: 0.258 Training accuracy : 0.888
[36] Training loss: 0.258 Training accuracy : 0.888
[36] Validation loss: 1.175 Validation accuracy : 0.593
[36] Validation loss: 1.175 Validation accuracy : 0.593


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 553/553 [00:02<00:00, 254.12it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [00:00<00:00, 331.18it/s]
not saving the model
not saving the model
curr_lr: 0.0013421772800000008
curr_lr: 0.0013421772800000008
[47] Training loss: 0.231 Training accuracy : 0.906
[47] Training loss: 0.231 Training accuracy : 0.906
[47] Validation loss: 1.120 Validation accuracy : 0.597
[47] Validation loss: 1.120 Validation accuracy : 0.597
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 553/553 [00:02<00:00, 256.52it/s]
100%|██████████████████████████████████████████████████████████████████

Finished Training


In [22]:
# # https://pytorch.org/tutorials/beginner/saving_loading_models.html
# for inference - load checkpointed model

PATH = checkpoint_path
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
acc = checkpoint['accuracy']

# inferece
model.eval()

MemNetwork(
  (linear): Sequential(
    (0): Linear(in_features=400, out_features=400, bias=True)
  )
)

In [23]:
test_count = 0
output_logits_lstm = []
test_acc = 0
for context, query, labels in tqdm(test_loader):
    test_count += 1
    context = context.to(device)
    query = query.to(device)

    target = labels.to(device)
    label = labels.float().to(device)
        
    outputs = model(context, query)
    output_logits = outputs.detach().cpu().numpy()
    test_acc += torchmetrics.functional.accuracy(outputs, target, threshold=0.5).item()
    output_logits_lstm.append(output_logits)

accuracy = test_acc/test_count
print("The test accuracy is {}".format(accuracy))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 301/301 [00:00<00:00, 353.04it/s]

The test accuracy is 0.6025111263376534





In [24]:
test_output_logits_lstm = np.hstack(output_logits_lstm)

In [25]:
lstm_pred = torch.sigmoid(torch.tensor(test_output_logits_lstm))
lstm_pred_label = 1.0 * (lstm_pred > 0.5)

In [26]:
torch.save(lstm_pred, 'lstm_pred.pt')
torch.save(lstm_pred_label, 'lstm_pred_label.pt')

In [27]:
lstm_pred

tensor([0.0835, 0.0686, 0.0205,  ..., 0.6746, 0.0989, 0.0123])

In [28]:
lstm_pred_label

tensor([0., 0., 0.,  ..., 1., 0., 0.])