In [2]:
%matplotlib inline

In [3]:
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 18})

In [85]:
import logging
 
logging.basicConfig(filename = 'mem_with_transf_synth_train.log',
                    level = logging.DEBUG,
                    format = '%(asctime)s:%(levelname)s:%(name)s:%(message)s')
logging.getLogger().addHandler(logging.StreamHandler())

In [59]:
from tqdm import tqdm

In [77]:
import torchmetrics

In [26]:
import torch
import numpy as np
import pandas as pd
import math

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

In [52]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split

In [21]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Device: cuda:0


In [11]:
def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

In [12]:
class MultiheadAttention(nn.Module):

    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Stack all weight matrices 1...h together for efficiency
        # Note that in many implementations you see "bias=False" which is optional
        self.qkv_proj = nn.Linear(input_dim, 3*embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_length, embed_dim = x.size()
        qkv = self.qkv_proj(x)

        # Separate Q, K, V from linear output
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
        q, k, v = qkv.chunk(3, dim=-1)

        # Determine value outputs
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
        values = values.reshape(batch_size, seq_length, embed_dim)
        o = self.o_proj(values)

        if return_attention:
            return o, attention
        else:
            return o

In [14]:
class EncoderBlock(nn.Module):

    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0):
        """
        Inputs:
            input_dim - Dimensionality of the input
            num_heads - Number of heads to use in the attention block
            dim_feedforward - Dimensionality of the hidden layer in the MLP
            dropout - Dropout probability to use in the dropout layers
        """
        super().__init__()

        # Attention layer
        self.self_attn = MultiheadAttention(input_dim, input_dim, num_heads)

        # Two-layer MLP
        self.linear_net = nn.Sequential(
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feedforward, input_dim)
        )

        # Layers to apply in between the main layers
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Attention part
        attn_out = self.self_attn(x, mask=mask)
        x = x + self.dropout(attn_out)
        x = self.norm1(x)

        # MLP part
        linear_out = self.linear_net(x)
        x = x + self.dropout(linear_out)
        x = self.norm2(x)
        
        return x


In [15]:
class TransformerEncoder(nn.Module):

    def __init__(self, num_layers, **block_args):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for l in self.layers:
            x = l(x, mask=mask)
        return x

    def get_attention_maps(self, x, mask=None):
        attention_maps = []
        for l in self.layers:
            _, attn_map = l.self_attn(x, mask=mask, return_attention=True)
            attention_maps.append(attn_map)
            x = l(x)
        return attention_maps

In [16]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=5000):
        """
        Inputs
            d_model - Hidden dimensionality of the input.
            max_len - Maximum length of a sequence to expect.
        """
        super().__init__()

        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        # register_buffer => Tensor which is not a parameter, but should be part of the modules state.
        # Used for tensors that need to be on the same device as the module.
        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x

In [17]:
class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):

    def __init__(self, optimizer, warmup, max_iters):
        self.warmup = warmup
        self.max_num_iters = max_iters
        super().__init__(optimizer)

    def get_lr(self):
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
        if epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        return lr_factor

In [5]:
train_context_query = np.load('train_context_query.npy', allow_pickle=True)
val_context_query = np.load('val_context_query.npy', allow_pickle=True)
test_context_query = np.load('test_context_query.npy', allow_pickle=True)

In [33]:
df_train= pd.read_json('synthetic_train.json')
df_val = pd.read_json('synthetic_val.json')
df_test = pd.read_json('synthetic_test.json')

In [48]:
df_train.head()

Unnamed: 0,index,seq_len,seq,rep_token_first_pos,query_token,target_val
44654,44654,25,[81 9 32 70 73 68 19 85 1 30 15 45 82 64 38 ...,12,82,1
85524,85524,45,[54 48 92 66 35 94 53 95 24 49 4 5 72 62 15 ...,20,85,1
9508,9508,7,[54 30 8 1 19 71 73],0,54,1
21160,21160,13,[68 5 84 25 39 12 35 15 97 31 64 88 38],8,97,1
193670,193670,99,[44 50 31 65 97 60 95 40 42 30 14 62 10 33 72 ...,62,93,1


In [35]:
INP_PATH = '/data/sherin/'
orth_vectors = np.load(
    INP_PATH + 'orthonormal_vectors_512.npy')

In [43]:
def get_tensors(inp, Orth):
    C = np.zeros((len(inp), 100, 512))
    Q = np.zeros((len(inp), 512))
    for idx, c_q in enumerate(inp):
        Q[idx] = Orth[c_q[-1]]
        C[idx,100 - len(c_q)+1:,:] = Orth[c_q[:-1]]
    return C, Q

In [44]:
train_C, train_Q = get_tensors(train_context_query, orth_vectors)

In [45]:
val_C, val_Q = get_tensors(val_context_query, orth_vectors)
test_C, test_Q = get_tensors(test_context_query, orth_vectors)

In [49]:
train_label = df_train['target_val'].values
val_label = df_val['target_val'].values
test_label = df_test['target_val'].values

In [46]:
def get_data_loader(context_reps, query_reps, label, batch_size, shuffle):
    data_set = TensorDataset(context_reps, query_reps, label)
    loader = DataLoader(data_set, batch_size=batch_size, shuffle=shuffle)
    return loader

In [47]:
batch_size = 128

In [63]:
train_loader = get_data_loader(torch.tensor(train_C).float(), torch.tensor(train_Q).float(),
                               torch.tensor(train_label), batch_size, shuffle=True)
val_loader = get_data_loader(torch.tensor(val_C).float(), torch.tensor(val_Q).float(),
                             torch.tensor(val_label), batch_size, shuffle=False)
test_loader = get_data_loader(torch.tensor(test_C).float(), torch.tensor(test_Q).float(),
                              torch.tensor(test_label), batch_size, shuffle=False)

In [105]:
class TransformerPredictor(nn.Module):

    def __init__(self, input_dim=512, model_dim=32,
                 num_heads=1, num_layers=1,
                 dropout=0.0, input_dropout=0.0):

        super().__init__()

        # Input dim -> Model dim
        self.input_net = nn.Sequential(
            #nn.Dropout(input_dropout),
            nn.Linear(input_dim, model_dim)
        )
        # Positional encoding for sequences
        #self.positional_encoding = PositionalEncoding(d_model=model_dim)
        # Transformer
        self.transformer = TransformerEncoder(num_layers=num_layers,
                                              input_dim=model_dim,
                                              dim_feedforward=2*model_dim,
                                              num_heads=num_heads,
                                              dropout=dropout)

    def forward(self, x, y):
        x = self.input_net(x)
        y = self.input_net(y)
        #x = self.positional_encoding(x)
        #y = self.positional_encoding(y)
        x = self.transformer(x)
        #print(x.shape)
        #print(y.shape)
        #op = torch.sum(x[:,-1,:]*y, dim=1)
        op = torch.sum(torch.sum(x, dim=1)*y, dim=1)
        return op


In [106]:
model = TransformerPredictor().to(device)

In [107]:
num_epochs = 100
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
scheduler = StepLR(optimizer, step_size=5, gamma=0.8)

In [108]:
checkpoint_path = '/data/sherin/checkpoint_lm/chkpt_synth_tran_steplr_recall_best.pt.tar'

In [109]:
epoch_loss_list = []
accuracy_list = []
val_loss_list = []
val_acc_list = []
valid_acc_max = 0 

for epoch in range(num_epochs):
    train_count = 0
    model.train()
    epoch_loss = 0.0
    accuracy = 0.0
    

    for context, query, labels in tqdm(train_loader):
        train_count = train_count+1
        context = context.to(device)
        query = query.to(device)    
        target = labels.to(device)
        label = labels.float().to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(context, query)
        loss = criterion(outputs, label)
        loss.backward()

        optimizer.step()
        
        epoch_loss += loss.item()
        accuracy += torchmetrics.functional.accuracy(outputs, target, threshold=0.5).item()

    
    model.eval()
    val_loss = 0.0
    val_acc = 0.0
    
    test_count = 0
    for context, query, labels in tqdm(val_loader):
        test_count = test_count + 1
        context = context.to(device)
        query = query.to(device)
        target = labels.to(device)
        label = labels.float().to(device)
        
        outputs = model(context, query)
        loss = criterion(outputs, label)
        val_loss += loss.item()
        val_acc += torchmetrics.functional.accuracy(outputs, target, threshold=0.5).item()
        
    accuracy = accuracy / train_count
    epoch_loss = epoch_loss / train_count
    val_loss = val_loss / test_count
    val_acc = val_acc / test_count
    
    if val_acc > valid_acc_max:
        logging.info("saving best model")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_loss,
            'accuracy': val_acc,
            }, checkpoint_path)
        valid_acc_max = val_acc
    else:
        logging.info("not saving the model")
    
    curr_lr = optimizer.param_groups[0]['lr']
    logging.info(f'curr_lr: {curr_lr}')
    logging.info(f'[{epoch + 1}] Training loss: {epoch_loss:.3f} Training accuracy : {accuracy:.3f}')
    logging.info(f'[{epoch + 1}] Validation loss: {val_loss:.3f} Validation accuracy : {val_acc:.3f}')
    epoch_loss_list.append(epoch_loss)
    accuracy_list.append(accuracy)
    val_loss_list.append(val_loss)
    val_acc_list.append(val_acc)
    
    scheduler.step()

print('Finished Training')

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 742/742 [01:08<00:00, 10.89it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 318/318 [00:52<00:00,  6.01it/s]
saving best model
curr_lr: 0.01
[1] Training loss: 1.381 Training accuracy : 0.524
[1] Validation loss: 0.686 Validation accuracy : 0.519
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 742/742 [00:30<00:00, 24.48it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 318/318 [00:04<00:00, 73.31it/s]
saving best model
curr_lr: 0.01
[2] Training loss:

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 742/742 [00:10<00:00, 69.99it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 318/318 [00:03<00:00, 98.61it/s]
saving best model
curr_lr: 0.00512
[16] Training loss: 0.538 Training accuracy : 0.731
[16] Validation loss: 0.629 Validation accuracy : 0.702
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 742/742 [00:12<00:00, 61.65it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 318/318 [00:03<00:00, 102.05it/s]
saving best model
curr_lr: 0.00512
[17] Train

[44] Validation loss: 0.525 Validation accuracy : 0.771
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 742/742 [00:12<00:00, 58.27it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 318/318 [00:04<00:00, 71.31it/s]
not saving the model
curr_lr: 0.001677721600000001
[45] Training loss: 0.360 Training accuracy : 0.833
[45] Validation loss: 0.516 Validation accuracy : 0.770
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 742/742 [00:11<00:00, 63.90it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 318/31

[73] Validation loss: 0.533 Validation accuracy : 0.780
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 742/742 [00:14<00:00, 52.01it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 318/318 [00:04<00:00, 67.33it/s]
saving best model
curr_lr: 0.0004398046511104004
[74] Training loss: 0.302 Training accuracy : 0.862
[74] Validation loss: 0.533 Validation accuracy : 0.783
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 742/742 [00:12<00:00, 58.41it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 318/318 [

Finished Training


In [110]:
print("finished")

finished


In [111]:
PATH = checkpoint_path
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
acc = checkpoint['accuracy']

# inferece
model.eval()

TransformerPredictor(
  (input_net): Sequential(
    (0): Linear(in_features=512, out_features=32, bias=True)
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0): EncoderBlock(
        (self_attn): MultiheadAttention(
          (qkv_proj): Linear(in_features=32, out_features=96, bias=True)
          (o_proj): Linear(in_features=32, out_features=32, bias=True)
        )
        (linear_net): Sequential(
          (0): Linear(in_features=32, out_features=64, bias=True)
          (1): Dropout(p=0.0, inplace=False)
          (2): ReLU(inplace=True)
          (3): Linear(in_features=64, out_features=32, bias=True)
        )
        (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
    )
  )
)

In [112]:
print(acc)

0.7835763636625038


In [113]:
test_count = 0
output_logits_trans = []
test_acc = 0
for context, query, labels in tqdm(test_loader):
    test_count = test_count + 1
    context = context.to(device)
    query = query.to(device)
    target = labels.to(device)
    label = labels.float().to(device)

    outputs = model(context, query)
    output_logits = outputs.detach().cpu().numpy()
    test_acc += torchmetrics.functional.accuracy(outputs, target, threshold=0.5).item()
    output_logits_trans.append(output_logits)

accuracy = test_acc/test_count
print("The test accuracy is {}".format(accuracy))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 455/455 [01:58<00:00,  3.86it/s]

The test accuracy is 0.7847413004099668





In [114]:
test_output_logits_trans = np.hstack(output_logits_trans)

In [115]:
trans_pred = torch.sigmoid(torch.tensor(test_output_logits_trans))
trans_pred_label = 1.0 * (trans_pred > 0.5)

In [116]:
torch.save(trans_pred, 'trans_pred.pt')
torch.save(trans_pred_label, 'trans_pred_label.pt')