In [9]:
import json
from collections import defaultdict
import plotly.express as px
import plotly.io as plotly
#plotly.renderers.default='iframe'

In [10]:
# Display names of the two classes. They double as the root node names of the
# sunburst hierarchy and as the per-sample confidence keys read from the JSON.
BENIGN, SCAM = "Benign", "Malicious"

# Class names in index order: position 0 is benign, position 1 is malicious.
class_names = [BENIGN, SCAM]

In [11]:
def sunburst_plot(data, title="Sunburst Chart"):
    """Draw a sunburst chart from hierarchy data and display it.

    Args:
        data: Mapping with the keys 'names', 'parents' and 'values', which are
            passed to ``px.sunburst`` as the node names, parent names and
            node sizes respectively.
        title: Title shown above the chart.

    Returns:
        None. The figure is displayed via ``fig.show()``.
    """
    chart_options = dict(
        names='names',
        parents='parents',
        values='values',
        title=title,
        # 'total': a node's value already includes the values of its children.
        branchvalues='total',
    )
    figure = px.sunburst(data, **chart_options)

    # Show only two levels of the hierarchy at a time.
    figure.update_traces(maxdepth=2)
    figure.show()

In [12]:
# Parent categories returned by get_parent_for_item: 'URLInfo', 'IPInfo', 'FullWhois', 'CertSH', 'HTML', 'HTMLInfo', 'Links'

# Ordered (category, prefixes) rules: the first rule containing a prefix that
# the lower-cased token starts with decides the category, so more specific
# prefixes must be listed before the generic ones they overlap with.
_PARENT_PREFIX_RULES = (
    ('URLInfo', ('url information', 'submitted url', 'tls support')),
    # 'meta contents' has to be tested before the generic 'meta' prefix of
    # HTMLInfo below; otherwise this rule can never match.
    ('Links', ('meta contents',)),
    ('HTMLInfo', ('html information', 'redirects (url', 'device', 'meta')),
    ('IPInfo', ('ipinfo', 'address', 'city', 'organization', 'country', 'region')),
    # 'total links' is listed before the generic 'total' prefix of FullWhois.
    ('Links', ('links', 'total links', 'image', 'iframe')),
    # Reached only when 'html information' (HTMLInfo) did not match.
    ('HTML', ('html',)),
    ('FullWhois', ('fullwhois', 'creation', 'expiration', 'total', 'registrar', 'domain status')),
    # 'common' also covers tokens starting with 'common domains'.
    ('CertSH', ('certsh', 'common', 'certificate', 'validity', 'issuer')),
)


def get_parent_for_item(item: str):
    """Return the parent category of an explanation token.

    Only the text before the first ':' is considered; it is stripped and
    lower-cased, then compared against the prefixes in
    ``_PARENT_PREFIX_RULES`` in order.

    Args:
        item: Token text, e.g. ``'City: Paris'``.

    Returns:
        One of 'URLInfo', 'HTMLInfo', 'IPInfo', 'Links', 'HTML', 'FullWhois'
        or 'CertSH'. When no prefix matches, a message is printed and the
        empty string is returned.
    """
    first_token = item.split(':')[0].strip().lower()

    for category, prefixes in _PARENT_PREFIX_RULES:
        # str.startswith accepts a tuple and matches if any prefix matches.
        if first_token.startswith(prefixes):
            return category

    print(f'Couldn not find a parent category for: {first_token}')
    return ''



In [13]:
def _classification_to_index(classification):
    """Convert a stored classification to a class index (0 benign, 1 malicious).

    Accepts the class display names (``BENIGN`` / ``SCAM``) as well as any
    value ``int()`` can convert.
    """
    if classification == BENIGN:
        return 0
    if classification == SCAM:
        return 1
    return int(classification)


def sunburst_expl(json_file, label, sample_idx=None):
    """Build sunburst hierarchy data from a LIME explanation JSON file.

    The file must contain a ``'lime_results'`` list. Each entry provides an
    ``'explanation'`` list of ``{'token', 'importance'}`` items, the class
    confidences under the keys ``BENIGN`` and ``SCAM``, and (used only in
    aggregate mode with a label) a ``'classification'``.

    Items with importance >= 0 are accumulated under the ``SCAM`` root, the
    others (by absolute value) under the ``BENIGN`` root. Every token is
    attached to the category returned by ``get_parent_for_item``; tokens
    without a category are skipped. Within each class the scores are divided
    by the class total and multiplied by the class confidence.

    Args:
        json_file: Path of the JSON file to read.
        label: Class index used to filter samples in aggregate mode
            (``None`` keeps every sample). Ignored when ``sample_idx`` is set.
        sample_idx: Index of a single sample to explain. ``None`` aggregates
            over all selected samples and scales by their mean confidence.

    Returns:
        dict with the parallel lists ``'names'``, ``'parents'`` and
        ``'values'``. Names in the ``SCAM`` hierarchy carry a trailing space
        so they cannot collide with identical names in the ``BENIGN`` one.
    """
    # Tokens starting with one of these prefixes carry the name of interest
    # in the part after the first ':'.
    section_prefixes = ('fullwhois', 'ipinfo', 'htmlinfo', 'html information', 'certsh', 'url')

    class_scores = {SCAM: defaultdict(float), BENIGN: defaultdict(float)}
    class_parents = {SCAM: {SCAM: ''}, BENIGN: {BENIGN: ''}}

    def process_item(item):
        """Accumulate one explanation item into the hierarchy of its class."""
        importance = item['importance']
        parts = item['token'].split(':')
        token = parts[0].strip()
        # Guard on len(parts): a section-prefixed token without ':' has no
        # second part to fall back to.
        if len(parts) > 1 and token.startswith(section_prefixes):
            token = parts[1].strip()

        parent = get_parent_for_item(token)
        if len(parent) <= 1:
            return  # no category found for this token

        root = SCAM if importance >= 0 else BENIGN
        weight = abs(importance)
        scores = class_scores[root]
        parents_of = class_parents[root]

        parents_of[token] = parent
        scores[parent] += weight
        scores[token] += weight
        parents_of[parent] = root
        scores[root] += weight

    with open(json_file, 'r') as file:
        lime_data = json.load(file)
    results = lime_data['lime_results']

    if sample_idx is not None:
        selected = [results[sample_idx]]
    else:
        selected = [
            sample for sample in results
            if label is None or label == _classification_to_index(sample['classification'])
        ]

    for sample in selected:
        for item in sample['explanation']:
            process_item(item)

    elements = []
    parents = []
    values = []

    # The SCAM hierarchy gets a trailing space on every name to keep it
    # distinct from the BENIGN hierarchy.
    for root, suffix in ((SCAM, ' '), (BENIGN, '')):
        scores = class_scores[root]
        if not scores:
            continue  # nothing accumulated for this class

        total = scores[root]
        # Normalize to the actual prediction confidence. For a single sample
        # this is exactly that sample's confidence; in aggregate mode it is
        # the mean over the selected samples.
        confidence = sum(float(sample[root]) for sample in selected) / len(selected)

        for key, score in scores.items():
            parents.append('' if key == root else class_parents[root][key] + suffix)
            # A zero total means every contribution was zero; avoid dividing by it.
            values.append(score / total * confidence if total else 0.0)
            elements.append(key + suffix)

    data = dict(
        names=elements,
        parents=parents,
        values=values
    )
    return data


In [14]:
# Build the sunburst hierarchy for the first sample (sample_idx=0) of the
# explanation file. With sample_idx given, the label argument (1) is not used
# to filter samples.
plot_data = sunburst_expl('weighted_explanation.json', 1, sample_idx=0)

# print(plot_data)

# Render the chart; an empty string is passed as the title.
sunburst_plot(plot_data, "")

Couldn not find a parent category for: nan


ValueError: Mime type rendering requires nbformat>=4.2.0 but it is not installed