In [185]:
%matplotlib inline

Loaded backend module://matplotlib_inline.backend_inline version unknown.


In [186]:
# https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html

In [187]:
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 18})

In [188]:
import logging
 
logging.basicConfig(filename = 'mem_with_transf_synth_train.log',
                    level = logging.DEBUG,
                    format = '%(asctime)s:%(levelname)s:%(name)s:%(message)s')
logging.getLogger().addHandler(logging.StreamHandler())

In [189]:
from tqdm import tqdm

In [190]:
import torchmetrics

In [191]:
import torch
import numpy as np
import pandas as pd
import math

In [192]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

In [193]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split

In [194]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Device: cuda


In [195]:
def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

In [196]:
class MultiheadAttention(nn.Module):

    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Stack all weight matrices 1...h together for efficiency
        # Note that in many implementations you see "bias=False" which is optional
        self.qkv_proj = nn.Linear(input_dim, 3*embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_length, embed_dim = x.size()
        qkv = self.qkv_proj(x)

        # Separate Q, K, V from linear output
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
        q, k, v = qkv.chunk(3, dim=-1)

        # Determine value outputs
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
        values = values.reshape(batch_size, seq_length, embed_dim)
        o = self.o_proj(values)

        if return_attention:
            return o, attention
        else:
            return o

In [197]:
class EncoderBlock(nn.Module):

    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0):
        """
        Inputs:
            input_dim - Dimensionality of the input
            num_heads - Number of heads to use in the attention block
            dim_feedforward - Dimensionality of the hidden layer in the MLP
            dropout - Dropout probability to use in the dropout layers
        """
        super().__init__()

        # Attention layer
        self.self_attn = MultiheadAttention(input_dim, input_dim, num_heads)

        # Two-layer MLP
        self.linear_net = nn.Sequential(
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feedforward, input_dim)
        )

        # Layers to apply in between the main layers
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Attention part
        attn_out = self.self_attn(x, mask=mask)
        x = x + self.dropout(attn_out)
        x = self.norm1(x)

        # MLP part
        linear_out = self.linear_net(x)
        x = x + self.dropout(linear_out)
        x = self.norm2(x)
        
        return x


In [198]:
class TransformerEncoder(nn.Module):

    def __init__(self, num_layers, **block_args):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for l in self.layers:
            x = l(x, mask=mask)
        return x

    def get_attention_maps(self, x, mask=None):
        attention_maps = []
        for l in self.layers:
            _, attn_map = l.self_attn(x, mask=mask, return_attention=True)
            attention_maps.append(attn_map)
            x = l(x)
        return attention_maps

In [199]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=5000):
        """
        Inputs
            d_model - Hidden dimensionality of the input.
            max_len - Maximum length of a sequence to expect.
        """
        super().__init__()

        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        # register_buffer => Tensor which is not a parameter, but should be part of the modules state.
        # Used for tensors that need to be on the same device as the module.
        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x

In [200]:
class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):

    def __init__(self, optimizer, warmup, max_iters):
        self.warmup = warmup
        self.max_num_iters = max_iters
        super().__init__(optimizer)

    def get_lr(self):
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
        if epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        return lr_factor

In [201]:
train_context_query = np.load('train_context_query.npy', allow_pickle=True)
val_context_query = np.load('val_context_query.npy', allow_pickle=True)
test_context_query = np.load('test_context_query.npy', allow_pickle=True)

In [202]:
df_train= pd.read_json('synthetic_train.json')
df_val = pd.read_json('synthetic_val.json')
df_test = pd.read_json('synthetic_test.json')

In [203]:
print("The size of the synthetic train dataset is {}".format(len(df_train)))
print("The size of the synthetic val dataset is {}".format(len(df_val)))
print("The size of the synthetic test dataset is {}".format(len(df_test)))

The size of the synthetic train dataset is 94919
The size of the synthetic val dataset is 40680
The size of the synthetic test dataset is 58115


In [204]:
df_all = pd.concat([df_train, df_val, df_test])
all_context_query = np.hstack([train_context_query, val_context_query, test_context_query])

In [205]:
df_all.head()

Unnamed: 0,index,seq_len,seq,rep_token_first_pos,query_token,target_val
44654,44654,25,[81 9 32 70 73 68 19 85 1 30 15 45 82 64 38 ...,12,82,1
85524,85524,45,[54 48 92 66 35 94 53 95 24 49 4 5 72 62 15 ...,20,85,1
9508,9508,7,[54 30 8 1 19 71 73],0,54,1
21160,21160,13,[68 5 84 25 39 12 35 15 97 31 64 88 38],8,97,1
193670,193670,99,[44 50 31 65 97 60 95 40 42 30 14 62 10 33 72 ...,62,93,1


In [206]:
from sklearn.model_selection import train_test_split

train_val_query, test_query = train_test_split(df_all["query_token"].unique().tolist(), test_size=0.3, random_state=2, shuffle=True)
train_query, val_query = train_test_split(train_val_query, test_size=0.3, random_state=2, shuffle=True)

In [207]:
train_ids = df_all["query_token"].isin(train_query)
val_ids = df_all["query_token"].isin(val_query)
test_ids = df_all["query_token"].isin(test_query)

new_df_train = df_all[train_ids]
new_df_val = df_all[val_ids]
new_df_test = df_all[test_ids]

new_context_query_train = all_context_query[train_ids]
new_context_query_val = all_context_query[val_ids]
new_context_query_test = all_context_query[test_ids]

In [208]:
print("The size of the synthetic train dataset is {}".format(len(df_train)))
print("The size of the synthetic val dataset is {}".format(len(df_val)))
print("The size of the synthetic test dataset is {}".format(len(df_test)))

The size of the synthetic train dataset is 94919
The size of the synthetic val dataset is 40680
The size of the synthetic test dataset is 58115


In [274]:
INP_PATH = '/data/sherin/'
orth_vectors = np.load(INP_PATH + 'orthonormal_vectors_512.npy')

In [243]:
#orth_vectors = np.random.randn(512, 512)
#orth_vectors = np.random.normal(0,0.01, (512, 75))
orth_vectors = np.random.rand(512, 512)

In [275]:
def get_tensors(inp, Orth):
    C = np.zeros((len(inp), 99, 512))
    Q = np.zeros((len(inp), 512))
    for idx, c_q in enumerate(inp):
        #print(c_q.shape)
        Q[idx] = Orth[c_q[-1]]
        C[idx,99 - (len(c_q)-1):,:] = Orth[c_q[:-1]]
    return C, Q

In [276]:
train_C, train_Q = get_tensors(new_context_query_train, orth_vectors)

In [277]:
val_C, val_Q = get_tensors(new_context_query_val, orth_vectors)
test_C, test_Q = get_tensors(new_context_query_test, orth_vectors)

In [278]:
train_label = new_df_train['target_val'].values
val_label = new_df_val['target_val'].values
test_label = new_df_test['target_val'].values

In [279]:
def get_data_loader(context_reps, query_reps, label, batch_size, shuffle):
    data_set = TensorDataset(context_reps, query_reps, label)
    loader = DataLoader(data_set, batch_size=batch_size, shuffle=shuffle)
    return loader

In [280]:
batch_size = 32

In [281]:
train_loader = get_data_loader(torch.tensor(train_C).float(), torch.tensor(train_Q).float(),
                               torch.tensor(train_label), batch_size, shuffle=True)
val_loader = get_data_loader(torch.tensor(val_C).float(), torch.tensor(val_Q).float(),
                             torch.tensor(val_label), batch_size, shuffle=False)
test_loader = get_data_loader(torch.tensor(test_C).float(), torch.tensor(test_Q).float(),
                              torch.tensor(test_label), batch_size, shuffle=False)

In [282]:
"""
class TransformerPredictor(nn.Module):

    def __init__(self, input_dim=512, model_dim=32,
                 num_heads=1, num_layers=1,
                 dropout=0.0, input_dropout=0.0):

        super().__init__()

        # Input dim -> Model dim
        self.input_net = nn.Sequential(
            #nn.Dropout(input_dropout),
            nn.Linear(input_dim, model_dim)
        )
        # Positional encoding for sequences
        self.positional_encoding = PositionalEncoding(d_model=model_dim, max_len=100)
        # Transformer
        self.transformer = TransformerEncoder(num_layers=num_layers,
                                              input_dim=model_dim,
                                              dim_feedforward=2*model_dim,
                                              num_heads=num_heads,
                                              dropout=dropout)

    def forward(self, x, y):
        x = self.input_net(x)
        y = self.input_net(y)
        #print(x.shape)
        #print(y.shape)
        x = self.positional_encoding(x)
        #y = self.positional_encoding(y.unsqueeze(1))
        x = self.transformer(x)
        #print(x.shape)
        #print(y.shape)
        #op = torch.sum(x[:,-1,:]*y.squeeze(1), dim=1)
        #op = torch.sum(torch.sum(x, dim=1)*y.squeeze(1), dim=1)
        op = torch.sum(torch.sum(x, dim=1)*y, dim=1)
        #op = torch.sum(x[:,-1,:]*y, dim=1)
        return op
"""

'\nclass TransformerPredictor(nn.Module):\n\n    def __init__(self, input_dim=512, model_dim=32,\n                 num_heads=1, num_layers=1,\n                 dropout=0.0, input_dropout=0.0):\n\n        super().__init__()\n\n        # Input dim -> Model dim\n        self.input_net = nn.Sequential(\n            #nn.Dropout(input_dropout),\n            nn.Linear(input_dim, model_dim)\n        )\n        # Positional encoding for sequences\n        self.positional_encoding = PositionalEncoding(d_model=model_dim, max_len=100)\n        # Transformer\n        self.transformer = TransformerEncoder(num_layers=num_layers,\n                                              input_dim=model_dim,\n                                              dim_feedforward=2*model_dim,\n                                              num_heads=num_heads,\n                                              dropout=dropout)\n\n    def forward(self, x, y):\n        x = self.input_net(x)\n        y = self.input_net(y)\n       

In [283]:

class TransformerPredictor(nn.Module):

    def __init__(self, input_dim=512, model_dim=32,
                 num_heads=1, num_layers=1,
                 dropout=0.0, input_dropout=0.0):

        super().__init__()

        # Input dim -> Model dim
        self.input_net = nn.Sequential(
            #nn.Dropout(input_dropout),
            nn.Linear(input_dim, model_dim)
        )
        self.output_net = nn.Sequential(
            #nn.Dropout(input_dropout),
            nn.Linear(model_dim, input_dim)
        )
        # Positional encoding for sequences
        #self.positional_encoding = PositionalEncoding(d_model=model_dim, max_len=100)
        # Transformer
        self.transformer = TransformerEncoder(num_layers=num_layers,
                                              input_dim=model_dim,
                                              dim_feedforward=2*model_dim,
                                              num_heads=num_heads,
                                              dropout=dropout)

    def forward(self, x, y):
        x = self.input_net(x)
        #y = self.input_net(y)
        #print(x.shape)
        #print(y.shape)
        #x = self.positional_encoding(x)
        #y = self.positional_encoding(y.unsqueeze(1))
        x = self.transformer(x)
        #print(x.shape)
        #print(y.shape)
        #op = torch.sum(x[:,-1,:]*y.squeeze(1), dim=1)
        #op = torch.sum(torch.sum(x, dim=1)*y.squeeze(1), dim=1)
        x = self.output_net(torch.sum(x, dim=1))
        op = torch.sum(x*y, dim=1)
        #op = torch.sum(x[:,-1,:]*y, dim=1)
        return op


In [284]:
model = TransformerPredictor(num_layers=2, num_heads=2, model_dim=240).to(device)

In [285]:
model

TransformerPredictor(
  (input_net): Sequential(
    (0): Linear(in_features=512, out_features=240, bias=True)
  )
  (output_net): Sequential(
    (0): Linear(in_features=240, out_features=512, bias=True)
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0): EncoderBlock(
        (self_attn): MultiheadAttention(
          (qkv_proj): Linear(in_features=240, out_features=720, bias=True)
          (o_proj): Linear(in_features=240, out_features=240, bias=True)
        )
        (linear_net): Sequential(
          (0): Linear(in_features=240, out_features=480, bias=True)
          (1): Dropout(p=0.0, inplace=False)
          (2): ReLU(inplace=True)
          (3): Linear(in_features=480, out_features=240, bias=True)
        )
        (norm1): LayerNorm((240,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((240,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (1): EncoderBlock(
        (self_attn): 

In [286]:
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])

In [287]:
print(params)

1173392


In [288]:
num_epochs = 100
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
#scheduler = StepLR(optimizer, step_size=5, gamma=0.8)
scheduler = CosineWarmupScheduler(optimizer, warmup=50,
                                             max_iters=num_epochs*len(train_loader))

In [289]:
#checkpoint_path = '/data/sherin/checkpoint_lm/chkpt_synth_tran_posenc_steplr_recall_best.pt.tar'
checkpoint_path = '/data/sherin/checkpoint_lm/chkpt_synth_tran_dim_240_steplr_recall_best.pt.tar'

In [290]:
epoch_loss_list = []
accuracy_list = []
val_loss_list = []
val_acc_list = []
valid_acc_max = 0 

for epoch in range(num_epochs):
    train_count = 0
    model.train()
    epoch_loss = 0.0
    accuracy = 0.0
    

    for context, query, labels in tqdm(train_loader):
        train_count = train_count+1
        context = context.to(device)
        query = query.to(device)    
        target = labels.to(device)
        label = labels.float().to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(context, query)
        loss = criterion(outputs, label)
        loss.backward()

        optimizer.step()
        
        epoch_loss += loss.item()
        accuracy += torchmetrics.functional.accuracy(outputs, target, threshold=0.5).item()
        
        scheduler.step()

    
    model.eval()
    val_loss = 0.0
    val_acc = 0.0
    
    test_count = 0
    for context, query, labels in tqdm(val_loader):
        test_count = test_count + 1
        context = context.to(device)
        query = query.to(device)
        target = labels.to(device)
        label = labels.float().to(device)
        
        outputs = model(context, query)
        loss = criterion(outputs, label)
        val_loss += loss.item()
        val_acc += torchmetrics.functional.accuracy(outputs, target, threshold=0.5).item()
        
    accuracy = accuracy / train_count
    epoch_loss = epoch_loss / train_count
    val_loss = val_loss / test_count
    val_acc = val_acc / test_count
    
    if val_acc > valid_acc_max:
        logging.info("saving best model")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_loss,
            'accuracy': val_acc,
            }, checkpoint_path)
        valid_acc_max = val_acc
    else:
        logging.info("not saving the model")
    
    curr_lr = optimizer.param_groups[0]['lr']
    logging.info(f'curr_lr: {curr_lr}')
    logging.info(f'[{epoch + 1}] Training loss: {epoch_loss:.3f} Training accuracy : {accuracy:.3f}')
    logging.info(f'[{epoch + 1}] Validation loss: {val_loss:.3f} Validation accuracy : {val_acc:.3f}')
    epoch_loss_list.append(epoch_loss)
    accuracy_list.append(accuracy)
    val_loss_list.append(val_loss)
    val_acc_list.append(val_acc)
    
    # scheduler.step()

print('Finished Training')

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2955/2955 [00:56<00:00, 52.57it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1281/1281 [00:22<00:00, 58.22it/s]
saving best model
saving best model
curr_lr: 0.0009997532801828658
curr_lr: 0.0009997532801828658
[1] Training loss: 3.334 Training accuracy : 0.543
[1] Training loss: 3.334 Training accuracy : 0.543
[1] Validation loss: 3.607 Validation accuracy : 0.498
[1] Validation loss: 3.607 Validation accuracy : 0.498
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2955/2955 [01:25<00:00, 34.70it/s]
100%|██████████████████████████████████████████████████████████████████████████████████

KeyboardInterrupt: 

In [161]:
print("finished")

finished


In [260]:
PATH = checkpoint_path
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
acc = checkpoint['accuracy']

# inferece
model.eval()

TransformerPredictor(
  (input_net): Sequential(
    (0): Linear(in_features=512, out_features=240, bias=True)
  )
  (output_net): Sequential(
    (0): Linear(in_features=240, out_features=512, bias=True)
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0): EncoderBlock(
        (self_attn): MultiheadAttention(
          (qkv_proj): Linear(in_features=240, out_features=720, bias=True)
          (o_proj): Linear(in_features=240, out_features=240, bias=True)
        )
        (linear_net): Sequential(
          (0): Linear(in_features=240, out_features=480, bias=True)
          (1): Dropout(p=0.0, inplace=False)
          (2): ReLU(inplace=True)
          (3): Linear(in_features=480, out_features=240, bias=True)
        )
        (norm1): LayerNorm((240,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((240,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (1): EncoderBlock(
        (self_attn): 

In [261]:
print(acc)

0.5011953551912568


In [262]:
test_count = 0
output_logits_trans = []
test_acc = 0
for context, query, labels in tqdm(test_loader):
    test_count = test_count + 1
    context = context.to(device)
    query = query.to(device)
    target = labels.to(device)
    label = labels.float().to(device)

    outputs = model(context, query)
    output_logits = outputs.detach().cpu().numpy()
    test_acc += torchmetrics.functional.accuracy(outputs, target, threshold=0.5).item()
    output_logits_trans.append(output_logits)

accuracy = test_acc/test_count
print("The test accuracy is {}".format(accuracy))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1818/1818 [00:28<00:00, 62.81it/s]

The test accuracy is 0.501117299229923





In [263]:
test_output_logits_trans = np.hstack(output_logits_trans)

In [264]:
trans_pred = torch.sigmoid(torch.tensor(test_output_logits_trans))
trans_pred_label = 1.0 * (trans_pred > 0.5)

In [272]:
np.sum(1.0 * (trans_pred_label.cpu().numpy()==1))/len(trans_pred_label)

0.18624518701870188

In [273]:
np.sum(1.0 * (test_label==1))/len(test_label)

0.5011344884488449

In [167]:
torch.save(trans_pred, 'new_trans_pred.pt')
torch.save(trans_pred_label, 'new_trans_pred_label.pt')