In [10]:
import os
import pandas as pd
import azureml.core
import numpy as np
import plotly.graph_objects as go
from IPython.core.display import HTML
from utils import *

from azureml.core import Workspace, Environment, Experiment, Datastore, Dataset, ScriptRunConfig
from azureml.train.automl.run import AutoMLRun
from azureml.core.datastore import Datastore
from azureml.data.data_reference import DataReference

from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
import torch
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from sklearn import preprocessing
import torch
from transformers import TrainingArguments, Trainer
from transformers import BertTokenizer, BertForSequenceClassification, AutoModelForSequenceClassification, AutoTokenizer
from transformers import EarlyStoppingCallback
from transformers.integrations import AzureMLCallback
from transformers import AutoTokenizer, DataCollatorWithPadding
import joblib
# from datasets import Dataset, DatasetDict

# Check core SDK version number
print("SDK version:", azureml.core.VERSION)


SDK version: 1.42.0


In [2]:
from azureml.core import Workspace
ws = Workspace.from_config()

In [3]:
ds_X_train = Dataset.get_by_name(ws, name="owner_g_classfication_train", version=8)
ds_X_val = Dataset.get_by_name(ws, name="owner_g_classfication_val", version=8)
ds_X_test = Dataset.get_by_name(ws, name="owner_g_classfication_test", version=8)
ds_temporal_test = Dataset.get_by_name(ws, name="owner_g_classfication_temporal_test", version=3)

In [4]:
print(f'{ds_X_train.tags}: V{ds_X_train.version}')
print(f'{ds_X_val.tags}: V{ds_X_val.version}')
print(f'{ds_X_test.tags}: V{ds_X_test.version}')
print(f'{ds_temporal_test.tags}: V{ds_temporal_test.version}')

{'top': '50', 'period': 'REPORT_MONTH >= 201808', 'ratio': '80%', 'description': 'records of the retired classes removed', 'valid_period': '202105'}: V8
{'top': '50', 'period': 'REPORT_MONTH >= 201808', 'ratio': '10%', 'description': 'records of the retired classes removed', 'valid_period': '202105'}: V8
{'top': '50', 'period': 'REPORT_MONTH >= 201808', 'ratio': '10%', 'description': 'records of the retired classes removed', 'valid_period': '202105'}: V8
{'top': '50', 'period': 'REPORT_MONTH == 202206', 'description': 'records of the retired classes removed', 'valid_period': '202105'}: V3


In [5]:
pdf_X_train = ds_X_train.to_pandas_dataframe()
pdf_X_val = ds_X_val.to_pandas_dataframe()
pdf_X_test = ds_X_test.to_pandas_dataframe()
pdf_temporal_test = ds_temporal_test.to_pandas_dataframe()

In [7]:
base_checkpoint = "bert-base-uncased"
text_field_name = "TEXT_FINAL"
target_name = "target"

In [8]:
model_directory = 'model_output/model'
model = AutoModelForSequenceClassification.from_pretrained(model_directory, num_labels=51)
tokenizer = AutoTokenizer.from_pretrained(model_directory)

In [11]:
le=joblib.load(model_directory + '/labelEncoder.joblib')
le

LabelEncoder()

In [12]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [13]:
# model = best_model.model
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
# tokenizer = best_model.tokenizer
print(device)

cuda:0


In [14]:
from datasets import Dataset

train = Dataset.from_pandas(pdf_temporal_test)


In [15]:
def predict(inputs):
    return model(inputs)[0]

In [16]:
ref_token_id = tokenizer.pad_token_id
sep_token_id = tokenizer.sep_token_id
cls_token_id = tokenizer.cls_token_id

In [17]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):

    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

In [18]:
def custom_forward(inputs):
    preds = predict(inputs)
    return torch.softmax(preds, dim = 1)[0][1].unsqueeze(-1)

In [19]:
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

In [22]:
text = list(pdf_temporal_test['TEXT_FINAL'].iloc[0:1])[0]

In [23]:
# print(pdf_temporal_test[(pdf_temporal_test['target'] == pdf_temporal_test['pred'])][['TEXT_FINAL', 'target', 'pred']])
# text = list(pdf_temporal_test[(pdf_temporal_test['target'] == pdf_temporal_test['pred'])]['TEXT_FINAL'])[1]
# text

In [24]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [25]:
predict(input_ids)

tensor([[ 0.1896, -1.3090, -0.3546,  1.2868, -2.0566, -2.0536, -3.1872, -1.2230,
          0.6219, -1.7276, -1.3456, -1.9479, -1.2706,  0.7728,  1.5984,  8.2513,
          2.8983, -0.5932, -1.3020, -0.9615, -0.6784,  0.7940, -1.1398, -1.4858,
         -0.5806,  0.3947, -1.0058, -0.6386,  1.5125,  0.6689, -1.1472, -0.6822,
         -1.8423, -0.9726, -0.5699, -1.5530, -1.5124, -0.2898,  0.9582, -2.2328,
         -1.2474, -1.7703, -1.5466,  0.0880, -0.2355, -3.7614,  6.7723,  2.1291,
         -1.4845, -0.7794, -2.2017]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [26]:
custom_forward(input_ids)

tensor([5.6603e-05], device='cuda:0', grad_fn=<UnsqueezeBackward0>)

In [27]:
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    n_steps=700,
                                    internal_batch_size=3,
                                    return_convergence_delta=True)

In [28]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [29]:
attributions_sum = summarize_attributions(attributions)
attributions_sum

tensor([ 0.0000e+00, -1.5867e-01,  3.0951e-01,  3.5359e-02, -4.7821e-02,
        -1.1050e-02,  1.2610e-02,  3.0527e-02, -7.7894e-03, -4.7559e-02,
        -9.7044e-02, -2.3414e-01, -1.1015e-01,  5.6865e-03,  4.3947e-03,
        -1.2311e-01, -1.3072e-01, -1.0126e-01,  2.9526e-02, -2.4167e-02,
        -3.1643e-01,  7.7432e-03, -2.3589e-03, -5.2080e-02,  7.0373e-04,
        -1.3044e-03, -5.4402e-03,  1.5186e-03, -1.0655e-02, -2.5694e-03,
        -4.4033e-03, -8.3682e-03, -2.2264e-02, -4.9726e-02,  1.1375e-02,
         2.0325e-02,  1.4290e-01,  6.2434e-01, -1.4773e-02,  4.7644e-02,
         3.6899e-02,  3.1217e-02,  2.5073e-03, -9.1011e-03, -4.4093e-02,
        -2.2101e-01, -2.2543e-03, -2.1773e-02,  1.8778e-03,  5.7791e-03,
        -2.2757e-02, -2.7707e-02, -7.4659e-02,  3.8572e-02, -7.5440e-03,
         1.1907e-03, -4.7631e-02,  1.3289e-02,  3.0799e-02,  9.9922e-02,
        -9.9393e-02, -1.6267e-04,  1.3959e-02,  3.1790e-03, -9.5089e-04,
         2.2764e-03, -4.4162e-04,  1.6550e-02,  6.4

In [34]:
score = predict(input_ids)

print('Sentence: ', text)
score

Sentence:  SSC-SPC  you are being notified of 1 alarm s . alarm url . urlremoved webapp alarm url .  . severity  major. date time    10 50 09  . name  sqlmdb 0006259.ds.gc.ca. network address         . secure domain . type  uimhost. acknowledged  no. alarm title  memoryphysical percentused total. landscape  dc01swe0026 0x700000 . event    2022  22 50 09  uim s spectrum gateway probe. uim event generated with the following details . average 3 samples physical memory usage is now 90  which is above the error threshold 90 . top processes sqlservr.exe 1764 78.20   mcshield.exe 3348 1.18   cvd.exe 2800 0.58   monitoringhost.exe 7848 0.36   sscqry.exe 2476 0.28  . source sqlmdb 0006259. ip hostname        . level 2. suppression key memory physical. probe name cdm. origin cdoq uim hubs. uim alarm id sb92763652 07100. alarm source 2. changeowner 3. m


tensor([[ 0.1896, -1.3090, -0.3546,  1.2868, -2.0566, -2.0536, -3.1872, -1.2230,
          0.6219, -1.7276, -1.3456, -1.9479, -1.2706,  0.7728,  1.5984,  8.2513,
          2.8983, -0.5932, -1.3020, -0.9615, -0.6784,  0.7940, -1.1398, -1.4858,
         -0.5806,  0.3947, -1.0058, -0.6386,  1.5125,  0.6689, -1.1472, -0.6822,
         -1.8423, -0.9726, -0.5699, -1.5530, -1.5124, -0.2898,  0.9582, -2.2328,
         -1.2474, -1.7703, -1.5466,  0.0880, -0.2355, -3.7614,  6.7723,  2.1291,
         -1.4845, -0.7794, -2.2017]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [32]:
score_vis = viz.VisualizationDataRecord(attributions_sum,
                                        torch.softmax(score, dim = 1)[0][1],
                                        torch.argmax(torch.softmax(score, dim = 0)[0]),
                                        1,
                                        text,
                                        attributions_sum.sum(),       
                                        all_tokens,
                                        delta)


In [33]:
print('Visualization For Score')
viz.visualize_text([score_vis])

Visualization For Score


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (0.00),SSC-SPC you are being notified of 1 alarm s . alarm url . urlremoved webapp alarm url . . severity major. date time 10 50 09 . name sqlmdb 0006259.ds.gc.ca. network address . secure domain . type uimhost. acknowledged no. alarm title memoryphysical percentused total. landscape dc01swe0026 0x700000 . event 2022 22 50 09 uim s spectrum gateway probe. uim event generated with the following details . average 3 samples physical memory usage is now 90 which is above the error threshold 90 . top processes sqlservr.exe 1764 78.20 mcshield.exe 3348 1.18 cvd.exe 2800 0.58 monitoringhost.exe 7848 0.36 sscqry.exe 2476 0.28 . source sqlmdb 0006259. ip hostname . level 2. suppression key memory physical. probe name cdm. origin cdoq uim hubs. uim alarm id sb92763652 07100. alarm source 2. changeowner 3. m,-0.95,[CLS] ssc-spc you are being notified of 1 alarm s . alarm ur ##l . ur ##lr ##em ##ove ##d web ##app alarm ur ##l . . severity major . date time 10 50 09 . name sql ##md ##b 000 ##6 ##25 ##9 . ds . g ##c . ca . network address . secure domain . type ui ##m ##hos ##t . acknowledged no . alarm title memory ##physical percent ##used total . landscape dc ##01 ##sw ##e ##00 ##26 0 ##x ##70 ##00 ##00 . event 202 ##2 22 50 09 ui ##m s spectrum gateway probe . ui ##m event generated with the following details . average 3 samples physical memory usage is now 90 which is above the error threshold 90 . top processes sql ##ser ##vr . ex ##e 1764 78 . 20 mc ##shi ##eld . ex ##e 334 ##8 1 . 18 cv ##d . ex ##e 280 ##0 0 . 58 monitoring ##hos ##t . ex ##e 78 ##48 0 . 36 ss ##c ##q ##ry . ex ##e 247 ##6 0 . 28 . source sql ##md ##b 000 ##6 ##25 ##9 . ip host ##name . level 2 . suppression key memory physical . probe name cd ##m . origin cd ##o ##q ui ##m hub ##s . ui ##m alarm id sb ##9 ##27 ##6 ##36 ##52 07 ##100 . alarm source 2 . change ##own ##er 3 . m [SEP]
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (0.00),SSC-SPC you are being notified of 1 alarm s . alarm url . urlremoved webapp alarm url . . severity major. date time 10 50 09 . name sqlmdb 0006259.ds.gc.ca. network address . secure domain . type uimhost. acknowledged no. alarm title memoryphysical percentused total. landscape dc01swe0026 0x700000 . event 2022 22 50 09 uim s spectrum gateway probe. uim event generated with the following details . average 3 samples physical memory usage is now 90 which is above the error threshold 90 . top processes sqlservr.exe 1764 78.20 mcshield.exe 3348 1.18 cvd.exe 2800 0.58 monitoringhost.exe 7848 0.36 sscqry.exe 2476 0.28 . source sqlmdb 0006259. ip hostname . level 2. suppression key memory physical. probe name cdm. origin cdoq uim hubs. uim alarm id sb92763652 07100. alarm source 2. changeowner 3. m,-0.95,[CLS] ssc-spc you are being notified of 1 alarm s . alarm ur ##l . ur ##lr ##em ##ove ##d web ##app alarm ur ##l . . severity major . date time 10 50 09 . name sql ##md ##b 000 ##6 ##25 ##9 . ds . g ##c . ca . network address . secure domain . type ui ##m ##hos ##t . acknowledged no . alarm title memory ##physical percent ##used total . landscape dc ##01 ##sw ##e ##00 ##26 0 ##x ##70 ##00 ##00 . event 202 ##2 22 50 09 ui ##m s spectrum gateway probe . ui ##m event generated with the following details . average 3 samples physical memory usage is now 90 which is above the error threshold 90 . top processes sql ##ser ##vr . ex ##e 1764 78 . 20 mc ##shi ##eld . ex ##e 334 ##8 1 . 18 cv ##d . ex ##e 280 ##0 0 . 58 monitoring ##hos ##t . ex ##e 78 ##48 0 . 36 ss ##c ##q ##ry . ex ##e 247 ##6 0 . 28 . source sql ##md ##b 000 ##6 ##25 ##9 . ip host ##name . level 2 . suppression key memory physical . probe name cd ##m . origin cd ##o ##q ui ##m hub ##s . ui ##m alarm id sb ##9 ##27 ##6 ##36 ##52 07 ##100 . alarm source 2 . change ##own ##er 3 . m [SEP]
,,,,


In [58]:
text

'CED-DEC  ca alert notification.   name email  exchange online  ced.   with service exchange online. alert type cloud exchange online ex387365 sev2 lo pri. alert status major. url n a. description cloud monitoring for service  exchange online. support period 7 24. support period gmt offset atlantic standard time. srmis system application email  exchange online. full message . uim   ok newstatus major. uim alarmid  rk34858639 17174  user tag2  7 24 . alarm details . incident ex387365  severity sev2  status investigating  title delays or problems connecting to exchange online  user impact users may experience problems connecting to exchange online due to network issues.  current status we re investigating a potential issue and checking for impact to your organization. we ll provide an update within  . alert notification created at    . '