In [1]:
import os
import pandas as pd
import azureml.core
import numpy as np
import plotly.graph_objects as go
from IPython.core.display import HTML
from utils import *

from azureml.core import Workspace, Environment, Experiment, Datastore, Dataset, ScriptRunConfig
from azureml.train.automl.run import AutoMLRun
from azureml.core.datastore import Datastore
from azureml.data.data_reference import DataReference

from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
import torch
import matplotlib.pyplot as plt

# Check core SDK version number
print("SDK version:", azureml.core.VERSION)


SDK version: 1.42.0


In [2]:
from azureml.core import Workspace
ws = Workspace.from_config()

In [3]:
ds_X_train = Dataset.get_by_name(ws, name="owner_g_classfication_train")
ds_X_val = Dataset.get_by_name(ws, name="owner_g_classfication_val")
ds_X_test = Dataset.get_by_name(ws, name="owner_g_classfication_test")
ds_temporal_test = Dataset.get_by_name(ws, name="owner_g_classfication_temporal_test")

In [4]:
print(f'{ds_X_train.tags}: V{ds_X_train.version}')
print(f'{ds_X_val.tags}: V{ds_X_val.version}')
print(f'{ds_X_test.tags}: V{ds_X_test.version}')
print(f'{ds_temporal_test.tags}: V{ds_temporal_test.version}')

{'top': '50', 'period': 'REPORT_MONTH >= 201808', 'ratio': '90%', 'description': 'classes not rebalanced; classes not used in last year removed', 'valid_period': '202107'}: V14
{'top': '50', 'period': 'REPORT_MONTH >= 201808', 'ratio': '5%', 'description': 'classes not rebalanced; classes not used in last year removed', 'valid_period': '202107'}: V14
{'top': '50', 'period': 'REPORT_MONTH >= 201808', 'ratio': '5%', 'description': 'classes not rebalanced; classes not used in last year removed', 'valid_period': '202107'}: V14
{'top': '50', 'period': 'REPORT_MONTH == 202206', 'description': 'no processing except setting non-top50 to "other"', 'valid_period': '202107'}: V9


In [5]:
pdf_X_train = ds_X_train.to_pandas_dataframe()
pdf_X_val = ds_X_val.to_pandas_dataframe()
pdf_X_test = ds_X_test.to_pandas_dataframe()
pdf_temporal_test = ds_temporal_test.to_pandas_dataframe()

In [6]:
pdf_temporal_test = pd.read_csv('pdf_temporal_test_v8.csv')
pdf_temporal_test.head()

Unnamed: 0,TEXT_FINAL,target,pred
0,SSC-SPC you are being notified of 1 alarm s ....,DC000203,DC000203
1,CED-DEC ca alert notification. name email ...,NW000509,NW000509
2,CSC-SCC csc orion monitor reports the alert is...,DC000242,DC000242
3,IRCC-IRCC . affected party details name ...,DC000222,DC000222
4,SSC-SPC you are being notified of 1 alarm s ....,DC000203,DC000203


In [7]:
experiment_name = "ownergroup-classification-automl"

experiment = Experiment(ws, experiment_name)

In [8]:
automl_run = AutoMLRun(experiment=experiment, run_id="AutoML_83029f3e-810b-4b9c-95d1-7b4c8adc1926")

best_run, best_model = automl_run.get_output()
best_run




Experiment,Id,Type,Status,Details Page,Docs Page
ownergroup-classification-automl,AutoML_83029f3e-810b-4b9c-95d1-7b4c8adc1926_HD_0,azureml.scriptrun,Completed,Link to Azure Machine Learning studio,Link to Documentation


In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [31]:
model = best_model.model
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = best_model.tokenizer

In [33]:
from datasets import Dataset

train = Dataset.from_pandas(pdf_temporal_test)


In [11]:
def predict(inputs):
    return model(inputs)[0]

In [12]:
ref_token_id = tokenizer.pad_token_id
sep_token_id = tokenizer.sep_token_id
cls_token_id = tokenizer.cls_token_id

In [13]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):

    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

In [14]:
def custom_forward(inputs):
    preds = predict(inputs)
    return torch.softmax(preds, dim = 1)[0][1].unsqueeze(-1)

In [15]:
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

In [48]:
print(pdf_temporal_test[(pdf_temporal_test['target'] == pdf_temporal_test['pred'])][['TEXT_FINAL', 'target', 'pred']])
text = list(pdf_temporal_test[(pdf_temporal_test['target'] == pdf_temporal_test['pred'])]['TEXT_FINAL'])[1]
text

                                             TEXT_FINAL    target      pred
0     SSC-SPC  you are being notified of 1 alarm s ....  DC000203  DC000203
1     CED-DEC  ca alert notification.   name email  ...  NW000509  NW000509
2     CSC-SCC csc orion monitor reports the alert is...  DC000242  DC000242
3     IRCC-IRCC . affected party details name       ...  DC000222  DC000222
4     SSC-SPC  you are being notified of 1 alarm s ....  DC000203  DC000203
...                                                 ...       ...       ...
8093  ESDC-EDSC . partner service desk ticket number...  ITS00380  ITS00380
8094  ESDC-EDSC  vendor or partner service desk tick...  SM000562  SM000562
8095  PCH-PCH  vendor or partner service desk ticket...  NW000438  NW000438
8096  VAC-ACC  affected end user details name       ...  DC000155  DC000155
8098  VAC-ACC affected end user details name       v...  DC000155  DC000155

[5740 rows x 3 columns]


'CED-DEC  ca alert notification.   name email  exchange online  ced.   with service exchange online. alert type cloud exchange online ex387365 sev2 lo pri. alert status major. url n a. description cloud monitoring for service  exchange online. support period 7 24. support period gmt offset atlantic standard time. srmis system application email  exchange online. full message . uim   ok newstatus major. uim alarmid  rk34858639 17174  user tag2  7 24 . alarm details . incident ex387365  severity sev2  status investigating  title delays or problems connecting to exchange online  user impact users may experience problems connecting to exchange online due to network issues.  current status we re investigating a potential issue and checking for impact to your organization. we ll provide an update within  . alert notification created at    . '

In [49]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [50]:
predict(input_ids)

tensor([[-0.6064, -1.1990, -0.9460, -1.3105, -0.4936, -1.3775, -0.2894, -0.8588,
         -0.2234, -1.9332, -0.8662,  0.0800, -0.5278, -0.1348,  0.3483, -1.3537,
         -1.4502,  0.1984, -3.3187, -1.5394, -1.0437, -0.1271, -0.0535, -1.0378,
         -0.9141, -2.1266, -0.1341, -0.7088, -0.2511, -0.7757,  0.5961, -0.6229,
         -0.6179,  0.3886, -1.0177, -0.2188,  0.8255,  1.7482, -0.9314, -1.3142,
         -2.4569, -0.7439, -1.8520, -1.8809,  3.5210,  9.5727,  5.4712,  1.2765,
         -1.0309, -0.5070, -0.6725]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [51]:
custom_forward(input_ids)

tensor([2.0543e-05], device='cuda:0', grad_fn=<UnsqueezeBackward0>)

In [52]:
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    n_steps=700,
                                    internal_batch_size=3,
                                    return_convergence_delta=True)

In [54]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [55]:
attributions_sum = summarize_attributions(attributions)
attributions_sum

tensor([ 0.0000e+00, -7.6655e-02, -7.3327e-02, -7.0261e-03, -4.4720e-01,
        -2.8971e-01, -5.9013e-02, -9.5984e-02,  8.3211e-02, -7.9321e-02,
         1.2714e-01, -1.4084e-02, -1.5596e-02,  5.9259e-03,  3.0305e-02,
         1.3332e-02, -5.6113e-02,  1.2894e-01, -1.6417e-01, -1.8260e-01,
        -1.0261e-01,  2.9147e-02,  1.3248e-03, -3.7193e-02, -1.2165e-04,
        -2.4074e-01, -1.0837e-01,  7.1543e-02, -9.5835e-02, -1.8123e-02,
         6.1496e-03, -2.4275e-02, -9.6232e-02,  7.4157e-02,  3.8227e-02,
         4.3980e-02,  8.2896e-02,  2.0468e-01, -5.7406e-02,  2.4514e-01,
         1.0236e-01, -2.6075e-02,  2.9989e-02, -7.6314e-02,  1.3934e-02,
        -2.2020e-02,  4.9194e-02, -2.0141e-02,  2.1494e-01,  1.2593e-01,
        -5.2886e-02, -1.2952e-01, -8.6297e-02, -2.7880e-02,  2.2914e-02,
         1.1988e-03, -4.5529e-04, -2.3379e-02,  3.9362e-02,  5.5600e-02,
         4.4396e-02,  4.0053e-02,  1.4121e-02,  8.1598e-04, -5.3042e-02,
         2.8761e-02,  5.8028e-02,  3.6147e-02, -3.7

In [56]:
score_vis = viz.VisualizationDataRecord(attributions_sum,
                                        torch.softmax(score, dim = 1)[0][1],
                                        torch.argmax(torch.softmax(score, dim = 0)[0]),
                                        1,
                                        text,
                                        attributions_sum.sum(),       
                                        all_tokens,
                                        delta)


In [57]:
print('Visualization For Score')
viz.visualize_text([score_vis])

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (0.00),CED-DEC ca alert notification. name email exchange online ced. with service exchange online. alert type cloud exchange online ex387365 sev2 lo pri. alert status major. url n a. description cloud monitoring for service exchange online. support period 7 24. support period gmt offset atlantic standard time. srmis system application email exchange online. full message . uim ok newstatus major. uim alarmid rk34858639 17174 user tag2 7 24 . alarm details . incident ex387365 severity sev2 status investigating title delays or problems connecting to exchange online user impact users may experience problems connecting to exchange online due to network issues. current status we re investigating a potential issue and checking for impact to your organization. we ll provide an update within . alert notification created at .,-2.66,[CLS] CE ##D - DE ##C ca alert not ##ification . name email exchange online c ##ed . with service exchange online . alert type cloud exchange online ex ##38 ##7 ##36 ##5 se ##v ##2 lo p ##ri . alert status major . u ##rl n a . description cloud monitoring for service exchange online . support period 7 24 . support period g ##m ##t offset at ##lant ##ic standard time . s ##rm ##is system application email exchange online . full message . u ##im ok news ##tat ##us major . u ##im alarm ##id r ##k ##34 ##8 ##5 ##86 ##39 171 ##7 ##4 user tag ##2 7 24 . alarm details . incident ex ##38 ##7 ##36 ##5 severity se ##v ##2 status investigating title delays or problems connecting to exchange online user impact users may experience problems connecting to exchange online due to network issues . current status we re investigating a potential issue and checking for impact to your organization . we ll provide an update within . alert not ##ification created at . [SEP]
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (0.00),CED-DEC ca alert notification. name email exchange online ced. with service exchange online. alert type cloud exchange online ex387365 sev2 lo pri. alert status major. url n a. description cloud monitoring for service exchange online. support period 7 24. support period gmt offset atlantic standard time. srmis system application email exchange online. full message . uim ok newstatus major. uim alarmid rk34858639 17174 user tag2 7 24 . alarm details . incident ex387365 severity sev2 status investigating title delays or problems connecting to exchange online user impact users may experience problems connecting to exchange online due to network issues. current status we re investigating a potential issue and checking for impact to your organization. we ll provide an update within . alert notification created at .,-2.66,[CLS] CE ##D - DE ##C ca alert not ##ification . name email exchange online c ##ed . with service exchange online . alert type cloud exchange online ex ##38 ##7 ##36 ##5 se ##v ##2 lo p ##ri . alert status major . u ##rl n a . description cloud monitoring for service exchange online . support period 7 24 . support period g ##m ##t offset at ##lant ##ic standard time . s ##rm ##is system application email exchange online . full message . u ##im ok news ##tat ##us major . u ##im alarm ##id r ##k ##34 ##8 ##5 ##86 ##39 171 ##7 ##4 user tag ##2 7 24 . alarm details . incident ex ##38 ##7 ##36 ##5 severity se ##v ##2 status investigating title delays or problems connecting to exchange online user impact users may experience problems connecting to exchange online due to network issues . current status we re investigating a potential issue and checking for impact to your organization . we ll provide an update within . alert not ##ification created at . [SEP]
,,,,


In [58]:
text

'CED-DEC  ca alert notification.   name email  exchange online  ced.   with service exchange online. alert type cloud exchange online ex387365 sev2 lo pri. alert status major. url n a. description cloud monitoring for service  exchange online. support period 7 24. support period gmt offset atlantic standard time. srmis system application email  exchange online. full message . uim   ok newstatus major. uim alarmid  rk34858639 17174  user tag2  7 24 . alarm details . incident ex387365  severity sev2  status investigating  title delays or problems connecting to exchange online  user impact users may experience problems connecting to exchange online due to network issues.  current status we re investigating a potential issue and checking for impact to your organization. we ll provide an update within  . alert notification created at    . '