## Install Deps
- transformers
- einops
- pytorch-lightning
- neptune
- gpytorch



In [None]:
!pip install pytorch-lightning==1.2.8
!pip install transformers neptune-client neptune-contrib gpytorch einops tflearn sklearn



## Download and prepare the dataset
- Data_Integration_Dataset
 - Covid-19
 - Machine_Log

In [None]:
# Download the dataset archive from Google Drive with gdown.
# The next cell expects the result to be saved as `Data_Integration_Dataset.zip`
# in the current working directory; quiet=False keeps the progress output visible.
import gdown
gdown.download('https://drive.google.com/uc?id=19oLAKktjI0uk8v4lcdBTnRBTyqN-tGeR', output=None, quiet=False)

In [None]:
# Extract the downloaded archive into the current working directory (-qq: no console output).
!unzip -qq Data_Integration_Dataset.zip

In [None]:
import heapq
import math

import einops
import numpy as np
import pytorch_lightning as pl
import pytorch_lightning.metrics.functional.accuracy as get_accuracy
import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.metrics.utils import select_topk
from sklearn.metrics import accuracy_score,precision_score, recall_score, f1_score
from sklearn.metrics import classification_report


from transformers import (
    BertModel,
    BertTokenizer,
    AdamW,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    AutoModel
    )

from torch.utils.data import (
    random_split,
    DataLoader,
    RandomSampler,
    Subset,
    TensorDataset
    )


# Configuration

## Hyperparameters 

In [None]:
# --- Pretrained checkpoint -------------------------------------------------
PRETRAINED_MODEL = 'bert-base-uncased'

# --- Data / optimisation settings ------------------------------------------
SAMPLE = 10_000
CONTENT_FRAC = 0.2
BATCH_SIZE = 128
WARMUP = 2_000        # warm-up length handed to the LR scheduler
MAX_CYCLES = 10       # number of cosine hard restarts
MAX_EPOCHS = 60
LEARNING_RATE = 5e-5

# --- Experiment switches ---------------------------------------------------
SIGMOID_THRESHOLD = 0.5
IS_BERT = False       # True: pretrained BERT classifier, False: from-scratch transformer
IS_MASK = True        # whether the attention mask is passed to the model
USE_CASE = 'Covid19'  # 'Covid19' | 'MachineLog'
EXP_TYPE = 'COL'      # 'KEY' | 'COL' | 'AGGRE' (only for the covid-19 dataset)

# Keyword arguments forwarded to TransformerDIClassifier.
CORE_TRANSFORMER_PARAMS = {
    'num_layers': 12,
    'embedding_size': 128,
    'layer_norm_epsilon': 1e-5,
    'scale': 0.01,
    'resid_pdrop': 0.1,
    'attn_pdrop': 0.1,
    'num_attention_heads': 8,
    'embd_pdrop': 0.1,
    'num_actions': 3,
    'common_conv_dim': 128,
}

# Import Data

In [None]:
def get_data(data_type=USE_CASE,exp_type=EXP_TYPE):
  """Load the train/test JSON-lines files for one dataset and experiment type.

  Args:
    data_type: 'MachineLog' or 'Covid19'.
    exp_type: 'KEY' or 'COL' for both datasets; 'AGGRE' only for 'Covid19'.

  Returns:
    (train_data_json, test_data_json): two lists with one decoded JSON
    object per non-blank line of the respective file.

  Raises:
    ValueError: if the (data_type, exp_type) combination has no dataset file.
      ValueError is a subclass of the generic Exception raised previously.
  """
  import json
  import os

  # data_type -> (base directory, file-name stem, {exp_type -> task infix})
  layouts = {
      'MachineLog': (
          '/content/Data_Integration_Dataset/Machine_log',
          'machine_log',
          {'KEY': 'key_index', 'COL': 'column_label'},
      ),
      'Covid19': (
          '/content/Data_Integration_Dataset/Covid-19',
          'covid-19',
          {'KEY': 'key_index', 'COL': 'column_label', 'AGGRE': 'aggregation_label'},
      ),
  }

  base_pth, stem, tasks = layouts.get(data_type, (None, None, {}))
  task = tasks.get(exp_type)
  if task is None:
    raise ValueError(
        "dataset path not exist! "
        f"Unsupported combination data_type={data_type!r}, exp_type={exp_type!r}")

  def _read_json_lines(file_name):
    # One JSON document per line; blank lines (e.g. a trailing empty line) are skipped
    # instead of crashing json.loads.
    with open(os.path.join(base_pth,file_name),'r',encoding='utf-8') as f:
      return [json.loads(line) for line in f if line.strip()]

  train_data_json = _read_json_lines(f'training_{stem}_data_with_{task}_prediction.json')
  test_data_json = _read_json_lines(f'testing_{stem}_data_with_{task}_prediction.json')
  return train_data_json,test_data_json

In [None]:
# Load the train/test records for the configured dataset (USE_CASE) and task (EXP_TYPE).
train_data_json,test_data_json = get_data(data_type=USE_CASE,exp_type=EXP_TYPE)

In [None]:
# Flatten the label indices of every training record into one list.
dats = [label for record in train_data_json for label in record['label_index']]

# Occurrence count per label, keyed in first-seen order.
fd = {}
for label in dats:
  fd[label] = fd.get(label, 0) + 1

# Number of distinct labels seen in the training data, plus one.
NUM_LABELS = len(fd) + 1


## Initialize weights

In [None]:
# Inverse-frequency weight for every label, then rescaled by the sum of all weights.
total = sum(fd.values())
wt_init = [1/(count/total) for count in fd.values()]
wt_sum = sum(wt_init)
WEIGHTS = [w/wt_sum for w in wt_init]


# Data Loader

In [None]:

# Use a fixed seed so that the random train/validation split and the weight
# initialisation are reproducible after "Restart Kernel -> Run All".
# Called without an argument, no explicit seed is set by this notebook.
SEED = 42
seed_everything(SEED)


class DocumentDataPreprocessor():
    """Turns raw text into token-id / attention-mask tensors and splits datasets.

    The tokenizer only needs to expose a HuggingFace-style ``encode_plus``
    method that returns a mapping with ``'input_ids'`` and ``'attention_mask'``.
    """
    CLASS_TOKEN = '[CLS]'
    SEP_TOKEN = '[EOS]'
    SPECIAL_TOKENS = []

    def __init__(self,tokenizer:"BertTokenizer",\
                column_split_order=None,
                ):
        """Store the tokenizer.

        Args:
            tokenizer: tokenizer used by ``get_tokenized_text``. The annotation is
                a forward reference so it is not evaluated when the class is defined.
            column_split_order: accepted for compatibility; not used. Defaults to
                ``None`` instead of a shared mutable list.
        """
        self.tokenizer = tokenizer

    def get_tokenized_text(self,content_text,max_length=1024,pad_to_max_length=True):
        """Encode one text into ``(input_ids, attention_mask)`` tensors.

        Args:
            content_text: the text to encode.
            max_length: sequences are truncated to this many tokens.
            pad_to_max_length: when True (default) the sequence is also padded up to
                ``max_length``; when False no padding is requested. This flag was
                previously ignored and padding was always applied.

        Returns:
            The ``'input_ids'`` and ``'attention_mask'`` entries returned by the
            tokenizer (requested as PyTorch tensors).
        """
        encoded_dict = self.tokenizer.encode_plus(
                    content_text,                      # Sentence to encode.
                    add_special_tokens = True,         # Let the tokenizer add its special tokens.
                    max_length = max_length,           # Truncation (and padding) length.
                    padding = 'max_length' if pad_to_max_length else False,
                    truncation=True,
                    return_attention_mask=True,        # Differentiates padding from non-padding.
                    return_tensors = 'pt',             # Return pytorch tensors.
            )
        return encoded_dict['input_ids'],encoded_dict['attention_mask']


    @staticmethod
    def split_dataset(dataset,train_percent=0.9):
        """Randomly split ``dataset`` into a training and a validation subset.

        Args:
            dataset: any sized, indexable dataset accepted by ``random_split``.
            train_percent: fraction of samples, within [0, 1], assigned to training.

        Returns:
            ``(train_dataset, val_dataset)`` as produced by ``random_split``.

        Raises:
            Exception: if ``train_percent`` is greater than 1.
            ValueError: if ``train_percent`` is negative.
        """
        if train_percent > 1:
            raise Exception('Training Percentage cannot be > 1')
        if train_percent < 0:
            # A negative fraction would yield a negative split size further down.
            raise ValueError('Training Percentage cannot be < 0')
        train_size = int(train_percent * len(dataset))
        val_size = len(dataset) - train_size

        # Divide the dataset by randomly selecting samples.
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
        return train_dataset,val_dataset

In [None]:
def create_training_set(data_json,tokenizer,exp_type=EXP_TYPE):
  """Tokenize the records and bundle ids, masks and labels into a TensorDataset.

  Each record's ``'value'`` is serialised with ``json.dumps`` and tokenized to a
  fixed length of 70 tokens.

  Args:
    data_json: iterable of records with the keys ``'value'`` and ``'label_index'``.
    tokenizer: tokenizer handed to ``DocumentDataPreprocessor``.
    exp_type: 'KEY' (first label index per record, stored as a float tensor) or
      'COL' (multi-hot vector of width ``NUM_LABELS`` built from ``label_index``).

  Returns:
    ``TensorDataset(input_ids, att_mask, labels)``.

  Raises:
    ValueError: if ``exp_type`` is neither 'KEY' nor 'COL'. Previously such a value
      left ``labels`` as an empty list and failed inside ``TensorDataset``.
  """
  import json
  if exp_type not in ('KEY', 'COL'):
    raise ValueError(
        f"create_training_set supports exp_type 'KEY' or 'COL', got {exp_type!r}")

  proc = DocumentDataPreprocessor(tokenizer)
  input_ids,att_mask,labels = [],[],[]
  for d in data_json:
    json_value = json.dumps(d['value'])
    i,m = proc.get_tokenized_text(json_value,max_length=70)
    input_ids.append(i)
    att_mask.append(m)
    if exp_type == 'KEY':
      labels.append(d['label_index'][0])
    else:
      # Every label index is shifted down by one, one-hot encoded, and the
      # per-label rows are summed into a single multi-hot vector.
      shifted = torch.LongTensor([r-1 for r in d['label_index']])
      labels.append(F.one_hot(shifted,num_classes=NUM_LABELS).sum(dim=0))

  input_ids = torch.cat(input_ids,dim=0)
  att_mask = torch.cat(att_mask,dim=0)
  if exp_type == 'KEY':
    labels = torch.Tensor(labels)
  else:
    labels = torch.stack(labels,dim=0)

  return TensorDataset(input_ids,att_mask,labels)

# Pretrained Tokenizer

In [None]:
# Load the tokenizer from the same checkpoint (PRETRAINED_MODEL) whose word-embedding
# table the model copies, so the token ids always index the matching vocabulary.
# (A smaller alternative checkpoint: prajjwal1/bert-tiny.)
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL, do_lower_case=True)

In [None]:
# Tokenize the training records, then hold out part of them for validation.
full_train_dataset = create_training_set(train_data_json, tokenizer)
train_dataset, val_dataset = DocumentDataPreprocessor.split_dataset(full_train_dataset)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

# The test records go through the same tokenization; this loader keeps dataset order.
test_dataset = create_training_set(test_data_json, tokenizer)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# Model Description

## Pretrained BERT Model

In [None]:
class PredictionHead(nn.Module):
    """Single linear layer that maps a hidden vector to ``num_preds`` logits."""

    def __init__(self, hidden_size=768,layer_norm_eps=0.00001,num_preds=NUM_LABELS):
        super().__init__()
        # ``layer_norm_eps`` is accepted but not used: the head holds only this
        # one linear projection (with its own bias).
        self.decoder = nn.Linear(in_features=hidden_size, out_features=num_preds)

    def forward(self, hidden_states):
        """Return the raw (un-normalised) logits for ``hidden_states``."""
        return self.decoder(hidden_states)

class BertDataIntegrationClassifier(torch.nn.Module):
  """Pretrained BERT encoder followed by a linear prediction head.

  The encoder is loaded from ``PRETRAINED_MODEL`` (rather than a second
  hard-coded checkpoint name) so it stays consistent with the notebook config.
  """
  def __init__(self):
    super().__init__()
    self.model = BertModel.from_pretrained(PRETRAINED_MODEL)
    self.pred_head = PredictionHead(hidden_size=self.model.config.hidden_size,layer_norm_eps=self.model.config.layer_norm_eps)
    # Kept as an attribute for compatibility; forward() returns raw logits.
    self.sfmx = nn.Softmax(dim=1)

  def forward(self,input,att_mask=None):
    """Return logits computed from BERT's pooled output.

    Args:
      input: token ids passed to the BERT model.
      att_mask: optional attention mask forwarded as ``attention_mask``.
    """
    encoded = self.model(input,attention_mask=att_mask)
    # No softmax/sigmoid is applied here; the head's output is returned as is.
    return self.pred_head(encoded.pooler_output)


## Encoder-Only Vanilla Transformer 

### Vanilla Self Attention
- Plain self-attention only: no masking, cross-attention, or other extras.
- Attention is defined by
  - $$ W_Q, W_K, W_V: \text{the query, key, and value projection matrices}$$
  - $$\text{attention}(X) = \mathrm{softmax}\big(\text{scale} \cdot (XW_Q)(XW_K)^{\top}\big)\,(XW_V)$$

In [None]:
class SimpleSelfAttention(nn.Module):
  '''
  Vanilla multi-head self attention on a sequence. No masking etc.

  Args:
    hidden_size: feature size of the input and output; must be divisible by num_heads.
    dropout: dropout probability applied after the output projection.
    num_heads: number of attention heads.
    scale: exponent of the score scaling, i.e. scores are multiplied by
      ``hidden_size ** -scale``.
    mlp_dim: accepted for compatibility; not used.

  Raises:
    ValueError: if hidden_size is not divisible by num_heads.
  '''
  def __init__(self,hidden_size,dropout=0.1,num_heads=4,scale=0.2,mlp_dim=3072):
    super().__init__()
    if hidden_size % num_heads != 0:
      raise ValueError(
          f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")
    self.kqv_layer = nn.Linear(hidden_size,3*hidden_size)
    self.num_heads = num_heads
    self.ff_layer = nn.Sequential(
          nn.Linear(hidden_size, hidden_size),
          nn.Dropout(dropout)
    )

    self.scale = hidden_size ** -scale

  def _split_heads(self,x):
    # (b, s, h*d) -> (b, h, s, d)
    b, s, _ = x.shape
    return x.reshape(b, s, self.num_heads, -1).transpose(1, 2)

  def forward(self,sequence_embedding):
    '''
    sequence_embedding: (batch, sequence, hidden_size) tensor.
    Returns a tensor of the same shape.
    '''
    b, s, _ = sequence_embedding.shape
    k,q,v = (self._split_heads(part) for part in self.kqv_layer(sequence_embedding).chunk(3, dim = -1))

    # scores[b, h, i, j] = <k_i, q_j> * scale
    scaled_dot_product = torch.matmul(k, q.transpose(-1, -2)) * self.scale
    # Normalise over j for every position i ...
    weighted_sum = F.softmax(scaled_dot_product,dim=-1)
    # ... and mix the values over that same axis j, so the weights applied to the
    # values sum to 1 for every output position. The previous contraction summed
    # over i, the axis that was NOT normalised.
    value_weighted_sum = torch.matmul(weighted_sum, v)

    # (b, h, s, d) -> (b, s, h*d)
    reweighted_sequence_embedding = value_weighted_sum.transpose(1, 2).reshape(b, s, -1)
    return self.ff_layer(reweighted_sequence_embedding)

In [None]:

class Conv1D(nn.Module):
    """
    Linear projection over the last dimension with the weight stored as
    (input features, output features), i.e. transposed relative to ``nn.Linear``
    (the layer defined by Radford et al. for OpenAI GPT and also used in GPT-2).

    Args:
        nf (:obj:`int`): The number of output features.
        nx (:obj:`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        # Weight drawn from N(0, 0.02^2); bias starts at zero.
        self.weight = nn.Parameter(nn.init.normal_(torch.empty(nx, nf), std=0.02))
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        # Collapse all leading dimensions, compute bias + x @ weight, then restore them.
        out_shape = x.size()[:-1] + (self.nf,)
        flat = x.view(-1, x.size(-1))
        projected = torch.addmm(self.bias, flat, self.weight)
        return projected.view(*out_shape)

In [None]:
class MLP(nn.Module):
    """Feed-forward sub-layer: expand to ``n_state`` features, GELU, project back, dropout.

    Args:
        n_state: width of the hidden (expanded) layer.
        embedding_size: width of the input and of the output.
        resid_pdrop: dropout probability applied to the output.
    """

    def __init__(self, n_state,embedding_size=256,resid_pdrop=0.1):
        super().__init__()
        self.c_fc = Conv1D(nf=n_state, nx=embedding_size)      # embedding_size -> n_state
        self.c_proj = Conv1D(nf=embedding_size, nx=n_state)    # n_state -> embedding_size
        self.act = nn.GELU()
        self.dropout = nn.Dropout(resid_pdrop)

    def forward(self, x):
        expanded = self.c_fc(x)
        projected = self.c_proj(self.act(expanded))
        return self.dropout(projected)

In [None]:
class Residual(nn.Module):
    """Wraps a callable ``fn`` and adds its input to its output."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Keyword arguments are passed straight through to the wrapped callable.
        transformed = self.fn(x, **kwargs)
        return transformed + x

In [None]:
class Block(nn.Module):
    """One transformer layer: LayerNorm -> residual self-attention -> LayerNorm -> residual MLP.

    The MLP's hidden width is four times ``embedding_size``.
    """

    def __init__(self, 
                 embedding_size=256,\
                 layer_norm_epsilon=0.00001,\
                 scale=0.2,\
                 resid_pdrop=0.1,\
                 attn_pdrop=0.1,\
                 num_attention_heads = 8):
        super().__init__()
        inner_dim = 4 * embedding_size

        self.ln_1 = nn.LayerNorm(embedding_size, eps=layer_norm_epsilon)
        attention = SimpleSelfAttention(
            embedding_size,
            num_heads=num_attention_heads,
            scale=scale,
            dropout=attn_pdrop,
            mlp_dim=inner_dim,
        )
        self.attn = Residual(attention)

        self.ln_2 = nn.LayerNorm(embedding_size, eps=layer_norm_epsilon)
        feed_forward = MLP(inner_dim, embedding_size=embedding_size, resid_pdrop=resid_pdrop)
        self.mlp = Residual(feed_forward)

    def forward(
        self,
        hidden_states,
    ):
        # NOTE(review): each Residual wraps only the sub-layer, so the skip connection
        # adds the *normalised* input (the LayerNorm output), not the raw
        # hidden_states. Confirm this is intended.
        normed = self.ln_1(hidden_states)
        attn_outputs = self.attn(normed)
        return self.mlp(self.ln_2(attn_outputs))

### Positional Embedding




In [None]:
def make_positions(tensor, padding_idx, left_pad):
    """Replace non-padding symbols with their position numbers.
    Position numbers begin at padding_idx+1.
    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).

    Args:
        tensor: tensor with at least two dimensions; dimension 1 is used as the
            sequence length (presumably shaped (batch, seq_len) -- confirm with callers).
        padding_idx: value that marks a padding entry in ``tensor``.
        left_pad: truthy if padding sits on the left side of each row.

    Returns:
        A ``long`` copy of ``tensor`` in which every entry different from
        ``padding_idx`` is replaced by its position number; entries equal to
        ``padding_idx`` keep their value.
    """
    max_pos = padding_idx + 1 + tensor.size(1)
    # A range buffer is cached per device as an attribute on the function itself,
    # so it persists across calls.
    device = tensor.get_device()
    buf_name = f'range_buf_{device}'
    if not hasattr(make_positions, buf_name):
        setattr(make_positions, buf_name, tensor.new())
    # Re-cast the cached buffer to the type of the current input on every call.
    setattr(make_positions, buf_name, getattr(make_positions, buf_name).type_as(tensor))
    # Refill the buffer with padding_idx+1 .. max_pos-1 when it is too small.
    if getattr(make_positions, buf_name).numel() < max_pos:
        torch.arange(padding_idx + 1, max_pos, out=getattr(make_positions, buf_name))
    # True wherever the entry is not padding.
    mask = tensor.ne(padding_idx)
    positions = getattr(make_positions, buf_name)[:tensor.size(1)].expand_as(tensor)
    if left_pad:
        # Shift each row's positions by its number of padding entries.
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    # Work on a copy so the caller's tensor is not modified.
    new_tensor = tensor.clone()
    return new_tensor.masked_scatter_(mask, positions[mask]).long()


class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.
    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).

    The embedding table is not a learned parameter: it is built on demand by
    ``get_embedding`` and cached per device in ``self.weights``.
    """

    def __init__(self, embedding_dim, padding_idx=0, left_pad=0, init_size=128):
        # NOTE: ``init_size`` is accepted but not used; the table is created lazily in forward().
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.left_pad = left_pad
        self.weights = dict()   # device --> actual weight; due to nn.DataParallel :-(
        # Registered buffer used in forward() as the type_as() target for the cached table.
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.
        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".

        Returns a (num_embeddings, embedding_dim) float tensor whose first half of
        columns holds the sines and whose second half holds the cosines; the row at
        ``padding_idx`` (if given) is zeroed.
        """
        half_dim = embedding_dim // 2
        # NOTE: divides by (half_dim - 1), so embedding_dim of 2 or 3 raises ZeroDivisionError.
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        # Outer product: position index times per-column frequency.
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input):
        """Input is expected to be of size [bsz x seqlen].

        Returns a detached [bsz x seqlen x embedding_dim] tensor of positional embeddings.
        """
        bsz, seq_len = input.size()
        max_pos = self.padding_idx + 1 + seq_len
        device = input.get_device()
        if device not in self.weights or max_pos > self.weights[device].size(0):
            # recompute/expand embeddings if needed
            self.weights[device] = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        # Match the cached table to the registered buffer's tensor type.
        self.weights[device] = self.weights[device].type_as(self._float_tensor)
        # Position number for every non-padding entry (entries equal to padding_idx keep that value).
        positions = make_positions(input, self.padding_idx, self.left_pad)
        return self.weights[device].index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number

In [None]:
class VanillaTransformer(nn.Module):
  '''
  Stack of ``num_layers`` Blocks applied to already-embedded sequences.
  The input is scaled by sqrt(embedding_size), sinusoidal positional
  embeddings are added, and dropout is applied before the first Block.
  '''
  def __init__(self,
              num_layers = 4,
              embed_dropout =0.1 ,\
              embedding_size=256,\
              layer_norm_epsilon=0.00001,\
              scale=0.2,\
              resid_pdrop=0.1,\
              attn_pdrop=0.1,\
              num_attention_heads = 8):
    super().__init__()
    self.layers = nn.ModuleList([])
    self.embed_dropout=embed_dropout
    self.embed_scale = math.sqrt(embedding_size)
    self.embed_positions = SinusoidalPositionalEmbedding(embedding_size)

    # Every layer is built from the same hyper-parameters.
    block_kwargs = dict(
        embedding_size=embedding_size,
        layer_norm_epsilon=layer_norm_epsilon,
        scale=scale,
        resid_pdrop=resid_pdrop,
        attn_pdrop=attn_pdrop,
        num_attention_heads=num_attention_heads,
    )
    self.layers.extend([Block(**block_kwargs) for _ in range(num_layers)])

  def forward(self, x, mask = None):
    # NOTE: ``mask`` is accepted but never used; attention runs over every position.
    hidden = self.embed_scale * x  # (b,len,d)
    if self.embed_positions is not None:
        # Positions are derived from the first feature channel of the scaled input.
        hidden = hidden + self.embed_positions(hidden[:, :, 0])
    hidden = F.dropout(hidden, p=self.embed_dropout, training=self.training)

    for layer in self.layers:
        hidden = layer(hidden)
    return hidden

## Ground Up Transformer with Bert Tokenized Embedding. 

In [None]:
class BertEmbedTransformer(nn.Module):
  """From-scratch transformer encoder on top of frozen pretrained word embeddings.

  Token ids are embedded with a frozen copy of the ``PRETRAINED_MODEL``
  word-embedding table, projected to ``common_conv_dim`` channels by a 1x1
  convolution, prefixed with a learned class token and encoded by a
  ``VanillaTransformer``. The class-token output goes through a small residual
  MLP and is returned.

  Args:
    num_layers, embedding_size, layer_norm_epsilon, scale, resid_pdrop,
      attn_pdrop, num_attention_heads: forwarded to ``VanillaTransformer``.
    embd_pdrop: dropout probability inside the final residual MLP.
    num_actions: accepted for compatibility; not used.
    common_conv_dim: channel width after the 1x1 convolution; must equal
      ``embedding_size`` because the projected tokens are fed to the transformer.

  Raises:
    ValueError: if ``embedding_size`` and ``common_conv_dim`` differ (such a
      model would otherwise only fail later, in the first forward pass).
  """
  def __init__(self,
              num_layers=8,\
              embedding_size=128,\
              layer_norm_epsilon=0.00001,\
              scale=0.01,\
              resid_pdrop=0.1,\
              attn_pdrop=0.1,\
              num_attention_heads = 8,\
              embd_pdrop=0.1,
              num_actions=3,\
              common_conv_dim=64):
    super().__init__()
    if embedding_size != common_conv_dim:
      raise ValueError(
          f"embedding_size ({embedding_size}) must equal common_conv_dim ({common_conv_dim})")

    data_input = dict(
        embedding_size=embedding_size,
        num_layers =num_layers,
        layer_norm_epsilon =layer_norm_epsilon,
        scale =scale,
        resid_pdrop =resid_pdrop,
        attn_pdrop =attn_pdrop,
        num_attention_heads =num_attention_heads,
    )

    self.transformer = VanillaTransformer(**data_input)

    # Only the word-embedding table of the pretrained model is kept.
    bert_model = AutoModel.from_pretrained(PRETRAINED_MODEL)
    bert_emb = bert_model.embeddings.word_embeddings
    text_embedding_dim = bert_emb.embedding_dim
    num_emb = bert_emb.num_embeddings
    self.text_embeddings = nn.Embedding(num_embeddings=num_emb,embedding_dim=text_embedding_dim)
    self.text_embeddings.load_state_dict(bert_emb.state_dict())
    # The copied embeddings stay frozen during training.
    self.text_embeddings.weight.requires_grad = False

    self.text_conv = nn.Conv1d(text_embedding_dim,common_conv_dim,kernel_size=1,padding=0,bias=False)
    self.text_cls_token = nn.Parameter(torch.randn(1, 1, common_conv_dim))
    self.to_cls = nn.Identity()
    self.final_layer = nn.Sequential(
        nn.Linear(common_conv_dim, common_conv_dim),
        nn.ReLU(),
        nn.Dropout(embd_pdrop),
        nn.Linear(common_conv_dim, common_conv_dim)
    )

  def forward(self,inputs,att_mask=None):
    """Encode a batch of token ids.

    Args:
      inputs: (batch, seq_len) token ids.
      att_mask: optional (batch, seq_len) mask; it is extended by one column for
        the class token and handed to the transformer.

    Returns:
      (batch, common_conv_dim) tensor taken from the class-token position.
    """
    # (b, n, emb) -> (b, emb, n) for Conv1d, then back to (b, n, common_conv_dim)
    text_tensor = self.text_embeddings(inputs).transpose(1,2)
    text_tensor = self.text_conv(text_tensor).permute(0,2,1)

    # Prepend the class token; its output is extracted instead of the last token.
    b = text_tensor.size(0)
    text_cls_token = self.text_cls_token.expand(b, -1, -1)
    text_tensor = torch.cat((text_cls_token,text_tensor),dim=1)

    if att_mask is not None:
      # One extra "visible" column for the class token, created with the mask's own
      # dtype and device so torch.cat never mixes float and integer tensors.
      att_mask = torch.cat((att_mask.new_ones((b, 1)),att_mask),dim=1)

    text_j_tensor = self.transformer(text_tensor,mask=att_mask)

    l_txt = self.to_cls(text_j_tensor[:,0])

    # A residual block
    return self.final_layer(l_txt) + l_txt
    

class TransformerDIClassifier(torch.nn.Module):
  """``BertEmbedTransformer`` encoder followed by a linear ``PredictionHead``.

  All constructor arguments are forwarded to ``BertEmbedTransformer``; the head
  is sized from ``common_conv_dim``.
  """
  def __init__(self,
              num_layers=8,\
              embedding_size=64,\
              layer_norm_epsilon=0.00001,\
              scale=0.01,\
              resid_pdrop=0.1,\
              attn_pdrop=0.1,\
              num_attention_heads = 8,\
              embd_pdrop=0.1,
              num_actions=3,\
              common_conv_dim=64):
    super().__init__()
    encoder_kwargs = {
        'num_layers': num_layers,
        'embedding_size': embedding_size,
        'layer_norm_epsilon': layer_norm_epsilon,
        'scale': scale,
        'resid_pdrop': resid_pdrop,
        'attn_pdrop': attn_pdrop,
        'num_attention_heads': num_attention_heads,
        'embd_pdrop': embd_pdrop,
        'num_actions': num_actions,
        'common_conv_dim': common_conv_dim,
    }
    self.model = BertEmbedTransformer(**encoder_kwargs)
    self.pred_head = PredictionHead(hidden_size=common_conv_dim,layer_norm_eps=layer_norm_epsilon)

  def forward(self,input,att_mask=None):
    """Return raw logits (no softmax/sigmoid) for a batch of token ids."""
    encoded = self.model(input,att_mask=att_mask)
    return self.pred_head(encoded)


## Lightning Module


In [None]:
class DataIntegrationClassifier(pl.LightningModule):
  """LightningModule for multi-label prediction with BCE-with-logits loss.

  Wraps either the pretrained-BERT classifier or the from-scratch transformer
  classifier and logs accuracy / micro precision / recall / F1 / loss for every
  training, validation and test batch.

  Args:
    with_mask: if True the attention mask from each batch is passed to the model.
    is_bert: if True use ``BertDataIntegrationClassifier``, otherwise
      ``TransformerDIClassifier(**CORE_TRANSFORMER_PARAMS)``.
  """
  def __init__(self,with_mask=False,is_bert=IS_BERT):
    super().__init__()

    self.with_mask = with_mask
    if is_bert:
      print("Using BERT Model")
      self.model = BertDataIntegrationClassifier()
    else:
      print("Using Vanilla Backbone Model")
      self.model = TransformerDIClassifier(**CORE_TRANSFORMER_PARAMS)

    self.sfmx = nn.Softmax(dim=1)
    self.sigmoid = nn.Sigmoid()
    self.loss = nn.BCEWithLogitsLoss()
    self.sigmoid_threshold = SIGMOID_THRESHOLD

  def forward(self,inputs,att_mask=None):
    '''
    inputs : (input_ids)
        - input_ids: b k :
          - b: batchsize
          - k: sequence_length
    *_mask = mask: b s : binary tensor.
    '''
    return self.model(inputs,att_mask=att_mask)

  def get_topk_accuracy(self, y_pred, y_true,topk=1):
    """Fraction of samples whose top-k prediction row equals the label row exactly.

    Args:
      y_pred: (batch, num_labels) logits.
      y_true: label rows, compared element-wise with the binarised predictions.
      topk: number of highest-scoring labels set to 1 per row (previously ignored,
        the default of 1 was always used).

    Returns:
      0-dim float tensor; 0.0 for an empty batch.
    """
    pred_vals = select_topk(self.sigmoid(y_pred), topk=topk)
    num_vals = pred_vals.size()[0]
    if num_vals == 0:
      return torch.tensor(0.0)

    # Compare the values directly instead of the printed numpy representation,
    # which depends on dtype formatting and is abbreviated for long rows.
    correct_vals = sum(
        int(torch.equal(pred.long().cpu(), true.long().cpu()))
        for pred, true in zip(pred_vals, y_true)
    )
    return torch.tensor(correct_vals / num_vals)

  def _run_batch(self, batch, stage):
    """Forward one batch, compute loss and metrics, and log them as ``{stage}_*``.

    Returns ``(loss, accuracy, y_true, y_pred)`` where ``y_true`` / ``y_pred``
    are numpy arrays of labels and thresholded predictions.
    """
    inps,mask,labels = batch
    logits = self(inps,att_mask=mask if self.with_mask else None)
    loss = self.loss(logits,labels.type_as(logits))

    y_pred = (self.sigmoid(logits) > self.sigmoid_threshold).long().detach().cpu().numpy()
    y_true = labels.detach().cpu().numpy()
    accuracy = self.get_topk_accuracy(logits,labels)

    # sklearn's signature is metric(y_true, y_pred); passing them the other way
    # round (as before) reports precision as recall and vice versa.
    self.logger.log_metrics({
        f'{stage}_accuracy':accuracy.detach().cpu().numpy(),
        f'{stage}_precision':precision_score(y_true,y_pred,average = 'micro'),
        f'{stage}_recall':recall_score(y_true,y_pred,average = 'micro'),
        f'{stage}_f1':f1_score(y_true,y_pred,average = 'micro'),
        f'{stage}_loss':loss.detach().cpu().numpy(),
        'epoch': self.current_epoch,
    })
    return loss,accuracy,y_true,y_pred

  def training_step(self,batch,batch_nb):
    loss,_,_,_ = self._run_batch(batch,'train')
    return {'loss':loss}

  def validation_step(self,batch,batch_nb):
    loss,_,_,_ = self._run_batch(batch,'val')
    # Register val_loss with Lightning so ModelCheckpoint(monitor='val_loss') can
    # find it; logger=False because the value is already sent to the logger above.
    self.log('val_loss',loss,logger=False)
    return {'loss':loss,'val_loss':loss}

  def test_step(self,batch,batch_nb):
    loss,accuracy,y_true,y_pred = self._run_batch(batch,'test')
    # Append this batch's thresholded predictions to a text file.
    with open('predicted_logits.txt','ab') as f:
      np.savetxt(f,y_pred)
    print(classification_report(y_true,y_pred))
    return {'loss':loss,'test_accuracy':accuracy}

  def configure_optimizers(self):
    """AdamW with a warm-up + cosine-with-hard-restarts schedule.

    ``WARMUP`` and the total step count are expressed in optimizer steps, so the
    scheduler is registered with ``interval='step'``; Lightning's default
    would advance it only once per epoch.
    """
    optimizer =  AdamW(self.parameters(), lr=LEARNING_RATE, eps=1e-12, betas=(0.9,0.999))
    # NUM_TRAIN_SAMPLES is a notebook-level global that must be set before fit().
    num_minibatch_steps = math.ceil(NUM_TRAIN_SAMPLES/BATCH_SIZE)
    t_total = MAX_EPOCHS * num_minibatch_steps
    lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, WARMUP, t_total,num_cycles=MAX_CYCLES)
    return [optimizer] ,[{'scheduler':lr_scheduler,'interval':'step'}]



# Trainer

### Initialize Model

In [None]:
# The LightningModule defined above is named DataIntegrationClassifier;
# `DataIntegrationColumnClassifier` does not exist and raised a NameError.
model = DataIntegrationClassifier(with_mask=IS_MASK,is_bert=IS_BERT)

### Training Process

In [None]:
# Read by DataIntegrationClassifier.configure_optimizers to size the LR schedule.
NUM_TRAIN_SAMPLES = len(train_dataset)

# Keep the three checkpoints (weights only) with the lowest validation loss.
model_checkpoint = ModelCheckpoint(filename='model/checkpoints/{epoch:02d}-{val_loss:.2f}',
                                   save_weights_only=True,
                                   save_top_k=3,
                                   monitor='val_loss',
                                   period=1)

trainer = Trainer(
    max_epochs=MAX_EPOCHS,
    progress_bar_refresh_rate=25,
    # Use one GPU when available instead of failing on CPU-only runtimes.
    gpus=1 if torch.cuda.is_available() else None,
    # Checkpoint callbacks go through `callbacks`; passing the instance via
    # `checkpoint_callback` and the `automatic_optimization` Trainer flag are deprecated.
    callbacks=[model_checkpoint],
)


trainer.fit(model, train_loader,val_dataloaders=val_loader)

### Testing process 

In [None]:
# Evaluate on the test set with a large batch size; shuffle=False keeps the
# predictions that test_step appends to predicted_logits.txt in dataset order.
trainer.test(model, test_dataloaders=DataLoader(test_dataset,batch_size=1024,shuffle=False))