
# AiStrike Assignment
Name: Priyadarshini Saha \
Date: July 20, 2024

## Problem Statement 1
Given is a subset of alert data for cloud infrastructure.
Goal: Extract maximum information from each data point and develop a statistical method in order to classify the alert data into a pre-determined set of categories.

Classification process:
The 3 main categories are:
1. Internal category
2. Alert source entity
3. Alert target entity

**Note:** Classification is sequential: first classify the alert on category 1, then on category 2, and then on category 3.

The **1. Internal Category** has the following subcategories:
1. Malware
2. Network Traffic: Inbound
3. Network Traffic: Outbound
4. Privilege Escalation
5. Credential Access
6. Initial Access
7. Defense Evasion
8. Discovery
9. Exfiltration
10. Impact
11. Persistence
12. Execution

The **2. Alert source entity** has the following subcategories:
1. host
2. resource_storage
3. identity
4. resource_container
5. resource_kubernetes
6. resource_lambda
7. resource_rds

The **3. Alert target entity** category has the following subcategories:
1. application / process / file
2. network connection
3. resource_storage
4. resource_api
5. resource_kubernetes
6. host
7. resource_rds
8. resource_securitymonitoring

Evaluation metric: accuracy
Solution: Ideally, use an LLM-enabled semantic search via text embeddings.
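
As a minimal sketch of what semantic search over embeddings could look like, the snippet below assigns each query vector the label of its most similar reference vector by cosine similarity. The helper names (`cosine_similarity_matrix`, `nearest_label`) and the 3-dimensional toy vectors are assumptions made for illustration; they are not part of the assignment or of the models trained later in this notebook.

```python
import numpy as np

def cosine_similarity_matrix(queries, corpus):
    """Cosine similarity between every query row and every corpus row."""
    q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    c = corpus / np.linalg.norm(corpus, axis=1, keepdims=True)
    return q @ c.T

def nearest_label(queries, corpus, corpus_labels):
    """Assign each query the label of its most similar corpus vector."""
    sims = cosine_similarity_matrix(queries, corpus)
    best = sims.argmax(axis=1)
    return [corpus_labels[i] for i in best]

# Toy 3-dimensional "embeddings" (made up for illustration only)
corpus = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0]])
corpus_labels = ["Malware", "Discovery", "Exfiltration"]
queries = np.array([[0.9, 0.1, 0.0],
                    [0.0, 0.2, 0.8]])

predicted = nearest_label(queries, corpus, corpus_labels)
print(predicted)  # ['Malware', 'Exfiltration']
```

With real data, the reference vectors would be the embeddings of labelled alerts and the queries would be the embeddings of unlabelled ones.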

## Approach to Solve the Problem
1. Data Loading & Preliminary Analysis: Load the data to understand its structure and determine which fields are relevant for the classification tasks.
2. Data Processing: Clean and preprocess the data to handle missing values and convert text into a suitable format for further analysis.
3. Feature Engineering: Identify a comprehensive text feature that provides valuable information for the classification. Convert the input into text embeddings using a BERT model, and label-encode the outputs for the 3 categories.
4. Model Development: Develop a sequential classifier that encompasses 3 smaller models. The first model takes the text embeddings as input and classifies the `internal_categorization` category. The second model takes the output of the first model along with the original text embeddings and classifies the second category, the source entity. The third and final model takes the outputs of the first two models along with the original text embeddings and classifies the third category, the target entity.
5. Model Training & Evaluation: Train the models on 80% of the data (training set) and use the remaining 20% (test set) to evaluate them. Calculate accuracy to measure the success of the models.
6. Optimization: Depending on the achieved accuracy, optimize the models using fine-tuning or better loss functions.
---


In [1]:
!pip install transformers torch


# Step 1: Install Necessary Libraries
- numpy
- pandas
- pytorch
- transformers
- sklearn

In [44]:
# Load the necessary libraries
import numpy as np
import pandas as pd
import torch
from transformers import BertTokenizer, BertModel
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Step 2: Load, Clean & Prepare the Dataset
1. Load the dataset from the CSV file as a Pandas DataFrame
2. Clean the data:
    - remove all empty columns
    - remove all rows where the labels (internal_categorization, source_entity, and target_entity) have missing values as we need a full dataset with all features and labels present for developing and training the model.
    

In [45]:
# Load the dataset
file_path = './AiStrike_alerts.csv'
raw_data = pd.read_csv(file_path)

# Only the last expression in a cell is rendered, so a bare raw_data.head() would not be displayed here
raw_data.info()
raw_data.describe(include='all')

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2283 entries, 0 to 2282
Columns: 174 entries, id to target_entity
dtypes: bool(1), float64(107), int64(7), object(59)
memory usage: 3.0+ MB


Unnamed: 0,id,unique_id,inserttime,start_datetime,end_datetime,alert_priority,finding_arn,alert_starttime,description,account_id,...,machine_arch,machine_isexternal,machine_internetexposure,machine_name,fileexepath_lastseentime,fileexepath_firstseentime,instancename,internal_categorization,source_entity,target_entity
count,2283.0,2283.0,2283,2283.0,2283.0,41,0.0,2283.0,2283,936.0,...,40,40.0,40,23,0.0,0.0,13,2104,2104,1343
unique,,,677,,,2,,,1108,,...,2,,3,22,,,10,8,3,6
top,,,2024-07-14 22:50:01,,,Low,,,For account: 280566572992 : SecurityGroup sg-...,,...,amd64,,Yes,spa-prod-v55 EC2,,,aks-agentpool-17422648-vmss_4,Defense Evasion,identity,network connection
freq,,,13,,,40,,,104,,...,34,,24,2,,,4,1390,1888,1180
mean,20591.890057,218973.95138,,1720638000.0,1720642000.0,,,1720638000.0,,609753000000.0,...,,1.0,,,,,,,,
std,3055.889439,817.748313,,147333.4,147346.3,,,147333.4,,202320300000.0,...,,0.0,,,,,,,,
min,2161.0,217604.0,,1720352000.0,1720355000.0,,,1720352000.0,,157363500000.0,...,,1.0,,,,,,,,
25%,19536.5,218248.5,,1720505000.0,1720508000.0,,,1720505000.0,,432929400000.0,...,,1.0,,,,,,,,
50%,21106.0,218955.0,,1720652000.0,1720656000.0,,,1720652000.0,,533811400000.0,...,,1.0,,,,,,,,
75%,22434.5,219702.5,,1720753000.0,1720757000.0,,,1720753000.0,,797382100000.0,...,,1.0,,,,,,,,


In [46]:
# removing columns that are completely empty
data_non_empty = raw_data.dropna(axis = 1, how = 'all')

# Check the shape of the new dataframe and display the first few rows
print(f'Shape: {data_non_empty.shape}')
print(f'Non-Empty Column names: {data_non_empty.columns}')
data_non_empty.head()

Shape: (2283, 77)
Non-Empty Column names: Index(['id', 'unique_id', 'inserttime', 'start_datetime', 'end_datetime',
       'alert_priority', 'alert_starttime', 'description', 'account_id',
       'alert_id', 'region', 'image_id', 'instance_id', 'access_key_id',
       'user_name', 'instance_type', 'public_ipaddress', 'security_group',
       'alert_severity', 'message', 'alert_name', 'vendor', 'product',
       'technique', 'classification', 'cluster_name', 'city', 'country',
       'alert_api', 'service_name', 'alert_status', 'alert_endtime',
       'alert_type', 'alert_reachability', 'alert_source', 'alert_category',
       'alert_subcategory', 'alert_evolving', 'alert_internetexposure',
       'hostname', 'machine_externalipaddr', 'private_ip',
       'machine_internalipaddr', 'ipaddress_endtime', 'ipaddress_ipaddress',
       'port_number', 'ipaddress_inbytes', 'ipaddress_outbytes',
       'domain_bytesin', 'domain_bytesout', 'user_name_set',
       'application_isclient', 'applica

Unnamed: 0,id,unique_id,inserttime,start_datetime,end_datetime,alert_priority,alert_starttime,description,account_id,alert_id,...,machine_os,machine_zone,machine_arch,machine_isexternal,machine_internetexposure,machine_name,instancename,internal_categorization,source_entity,target_entity
0,19434,218176,2024-07-14 17:47:40,1720494000,1720497600,,1720494000,AWS Account 274059328831 : lacework-global-12...,,218176,...,,,,,,,,Defense Evasion,identity,
1,19435,218174,2024-07-14 17:47:40,1720494000,1720497600,,1720494000,AWS Account 822165131996 : lacework-global-12...,,218174,...,,,,,,,,Defense Evasion,identity,
2,21971,219490,2024-07-14 23:23:03,1720717200,1720720800,,1720717200,For account: 280566572992 : Route-table/route...,797382100000.0,219490,...,,,,,,,,Defense Evasion,identity,network connection
3,21972,219489,2024-07-14 23:23:03,1720717200,1720720800,,1720717200,For account: 280566572992 (and 1 more) : Secur...,,219489,...,,,,,,,,Defense Evasion,identity,network connection
4,21972,219489,2024-07-14 23:23:03,1720717200,1720720800,,1720717200,For account: 280566572992 (and 1 more) : Secur...,,219489,...,,,,,,,,Defense Evasion,identity,network connection


In [47]:
# Focus on the most relevant columns for the classification task
relevant_columns = ['description', 'alert_severity', 'alert_name', 'message', 'alert_type',
                    'alert_reachability', 'alert_source', 'alert_category', 'alert_subcategory',
                    'alert_evolving', 'alert_internetexposure',
                    'internal_categorization', 'source_entity', 'target_entity']
data_relevant = data_non_empty[relevant_columns]

# Check the number of missing values in the label columns
missing_label_values = data_relevant.isnull().sum()

# Display the missing values information and the first few rows of the relevant data
print(f'Missing Values:\n{missing_label_values}')
print(f'Shape: {data_relevant.shape}')
print(f'Relevant Column names: {data_relevant.columns}')
data_relevant.head()

Missing Values:
description                  0
alert_severity               0
alert_name                   0
message                      0
alert_type                   0
alert_reachability          16
alert_source                 0
alert_category               0
alert_subcategory            0
alert_evolving               0
alert_internetexposure       0
internal_categorization    179
source_entity              179
target_entity              940
dtype: int64
Shape: (2283, 14)
Relevant Column names: Index(['description', 'alert_severity', 'alert_name', 'message', 'alert_type',
       'alert_reachability', 'alert_source', 'alert_category',
       'alert_subcategory', 'alert_evolving', 'alert_internetexposure',
       'internal_categorization', 'source_entity', 'target_entity'],
      dtype='object')


Unnamed: 0,description,alert_severity,alert_name,message,alert_type,alert_reachability,alert_source,alert_category,alert_subcategory,alert_evolving,alert_internetexposure,internal_categorization,source_entity,target_entity
0,AWS Account 274059328831 : lacework-global-12...,Medium,New violations,New violations: AWS Account 274059328831 : la...,NewViolations,UnknownReachability,AWS,Policy,Compliance,False,UnknownInternetExposure,Defense Evasion,identity,
1,AWS Account 822165131996 : lacework-global-12...,Medium,New violations,New violations: AWS Account 822165131996 : la...,NewViolations,UnknownReachability,AWS,Policy,Compliance,False,UnknownInternetExposure,Defense Evasion,identity,
2,For account: 280566572992 : Route-table/route...,Info,Route Table Change,Route Table Change: For account: 280566572992...,RouteTableChange,UnknownReachability,AWS,Policy,Cloud Activity,False,UnknownInternetExposure,Defense Evasion,identity,network connection
3,For account: 280566572992 (and 1 more) : Secur...,Info,Security Group Change,Security Group Change: For account: 280566572...,SecurityGroupChange,UnknownReachability,AWS,Policy,Cloud Activity,False,UnknownInternetExposure,Defense Evasion,identity,network connection
4,For account: 280566572992 (and 1 more) : Secur...,Info,Security Group Change,Security Group Change: For account: 280566572...,SecurityGroupChange,UnknownReachability,AWS,Policy,Cloud Activity,False,UnknownInternetExposure,Defense Evasion,identity,network connection


In [48]:
# Remove rows with missing values for the specified labels: internal_categorization, source_entity and target_entity
clean_columns = ['internal_categorization', 'source_entity', 'target_entity']
data = data_relevant.dropna(subset = clean_columns)

# Display the final dataset to be used
print(f'Shape of final dataset: {data.shape}')
print(f'Relevant Column names: {data.columns}')
data.head()

Shape of final dataset: (1343, 14)
Relevant Column names: Index(['description', 'alert_severity', 'alert_name', 'message', 'alert_type',
       'alert_reachability', 'alert_source', 'alert_category',
       'alert_subcategory', 'alert_evolving', 'alert_internetexposure',
       'internal_categorization', 'source_entity', 'target_entity'],
      dtype='object')


Unnamed: 0,description,alert_severity,alert_name,message,alert_type,alert_reachability,alert_source,alert_category,alert_subcategory,alert_evolving,alert_internetexposure,internal_categorization,source_entity,target_entity
2,For account: 280566572992 : Route-table/route...,Info,Route Table Change,Route Table Change: For account: 280566572992...,RouteTableChange,UnknownReachability,AWS,Policy,Cloud Activity,False,UnknownInternetExposure,Defense Evasion,identity,network connection
3,For account: 280566572992 (and 1 more) : Secur...,Info,Security Group Change,Security Group Change: For account: 280566572...,SecurityGroupChange,UnknownReachability,AWS,Policy,Cloud Activity,False,UnknownInternetExposure,Defense Evasion,identity,network connection
4,For account: 280566572992 (and 1 more) : Secur...,Info,Security Group Change,Security Group Change: For account: 280566572...,SecurityGroupChange,UnknownReachability,AWS,Policy,Cloud Activity,False,UnknownInternetExposure,Defense Evasion,identity,network connection
5,For account: 280566572992 (and 1 more) : Secur...,Info,Security Group Change,Security Group Change: For account: 280566572...,SecurityGroupChange,UnknownReachability,AWS,Policy,Cloud Activity,False,UnknownInternetExposure,Defense Evasion,identity,network connection
6,For account: 280566572992 (and 1 more) : Secur...,Info,Security Group Change,Security Group Change: For account: 280566572...,SecurityGroupChange,UnknownReachability,AWS,Policy,Cloud Activity,False,UnknownInternetExposure,Defense Evasion,identity,network connection


## Analysis of Relevant Fields:
1. `description`: Contains detailed descriptions of the alerts, making it a rich source of textual data for embeddings.
2. `alert_severity`: Categorical data indicating the severity of the alert. Could be useful as a feature but may not provide rich textual content.
3. `alert_name`: Names of the alerts, possibly categorical but could contain key descriptors useful in classification.
4. `message`: Concatenation of `alert_name` and `description`, providing a comprehensive text feature that combines detailed information. Ideal for embeddings.
5. `alert_type`: Categorical data specifying the type of alert. Like `alert_name`, it provides contextual information.
6. `alert_reachability`: Indicates the reachability associated with the alert, adding another layer of context. This is mostly categorical.
7. `alert_source`: The source of the alert, which is categorical and crucial for understanding the context of the alert.
8. `alert_category`: Broad categorization of alerts, useful for high-level classification tasks.
9. `alert_subcategory`: More granular classification within each category. Highly relevant for detailed classification.
10. `alert_evolving`: A boolean indicator of whether the alert is evolving. May not provide textual content but is useful as a contextual feature.
11. `alert_internetexposure`: Provides information on the internet exposure related to the alert, adding valuable context for risk assessment.

- **Primary**: Use `message` as the primary source for text embeddings since it combines `description` and `alert_name`, offering the most comprehensive textual context.
- Optional: Consider also using `alert_category` and `alert_subcategory` as additional sources of text for embeddings to capture different layers of information.
- Optional: Fields like `alert_severity` and `alert_reachability` can be encoded and used alongside the embeddings to provide context and potentially improve the classifier's performance.
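
The optional idea of using encoded categorical fields alongside the embeddings can be sketched as follows. This is an illustration only: the `one_hot` helper, the random stand-in embeddings, and the toy severity values are assumptions, and this notebook's models use the `message` embeddings alone.

```python
import numpy as np

def one_hot(values):
    """One-hot encode a 1-D sequence of category strings.

    Returns the encoded matrix and the sorted category names.
    """
    categories, codes = np.unique(np.asarray(values), return_inverse=True)
    return np.eye(len(categories))[codes], categories

# Stand-in values: random vectors in place of real 768-d BERT embeddings
rng = np.random.default_rng(0)
toy_embeddings = rng.normal(size=(4, 8))
toy_severity = ["Info", "Medium", "Info", "Critical"]

severity_onehot, severity_names = one_hot(toy_severity)
combined_features = np.hstack([toy_embeddings, severity_onehot])

print(list(severity_names))      # category names in sorted order
print(combined_features.shape)   # (4, 11): 8 embedding dims + 3 severity columns
```

The combined matrix could then be passed to a classifier in place of the embeddings alone.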

# Step 3: Prepare Input & Outputs
1. Generate text embeddings for the input `message` using the pretrained BERT tokenizer and model (`bert-base-uncased`).
2. Label-encode the columns `internal_categorization`, `source_entity`, and `target_entity`.

In [49]:
# Initialize BERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

In [52]:
# Generate embeddings for a list of text entries using BERT
def generate_embeddings(text_list):
  # Run the model on the GPU, if available
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model.to(device)

  # Encode the input text list
  encoded_input = tokenizer(text_list, padding=True, truncation=True, return_tensors='pt', max_length=512)
  encoded_input = encoded_input.to(device) # Move encoded input to the same device

  # Perform forward pass and compute embeddings with no gradient calculation
  with torch.no_grad():
    outputs = model(**encoded_input)

  # Use mean pooling to get a single vector for each input
  # (note: this averages over every token position, including padding)
  embeddings = outputs.last_hidden_state.mean(dim=1)

  # Move embeddings back to CPU for compatibility with sklearn or other CPU-bound operations
  return embeddings.detach().cpu().numpy()
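
Because `padding=True` pads shorter messages, a plain mean over the sequence dimension also averages the vectors at padded positions. A common alternative is to weight the mean by the attention mask. The sketch below shows that arithmetic on small NumPy arrays; `masked_mean_pool` is a hypothetical helper, and it was not used to produce the results in this notebook. The same computation could, under that assumption, be applied to `outputs.last_hidden_state` and `encoded_input['attention_mask']`.

```python
import numpy as np

def masked_mean_pool(hidden_states, attention_mask):
    """Average token vectors per sequence, ignoring padded positions.

    hidden_states:  (batch, seq_len, hidden) array
    attention_mask: (batch, seq_len) array of 1 (real token) / 0 (padding)
    """
    mask = attention_mask[..., None].astype(hidden_states.dtype)
    summed = (hidden_states * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)  # avoid division by zero
    return summed / counts

# Toy batch: 2 sequences, 3 token positions, hidden size 2
hidden = np.array([[[1.0, 1.0], [3.0, 3.0], [100.0, 100.0]],  # last position is padding
                   [[2.0, 0.0], [4.0, 0.0], [6.0, 0.0]]])
mask = np.array([[1, 1, 0],
                 [1, 1, 1]])

pooled = masked_mean_pool(hidden, mask)
plain = hidden.mean(axis=1)  # unmasked mean, for comparison
print(pooled)  # first row is [2. 2.]: the padded position is ignored
print(plain)   # first row is pulled towards the padded vector
```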

In [53]:
# generate embeddings for the 'message' column
message_embeddings = generate_embeddings(data['message'].tolist())

# Show the shape of the generated embeddings
print(f'Shape of embeddings:{message_embeddings.shape}')

Shape of embeddings:(1343, 768)
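
The call above encodes all 1,343 messages in a single forward pass, which may need more memory than some environments provide. A hedged sketch of processing the texts in batches follows. `embed_in_batches` and `fake_embed` are hypothetical names; `fake_embed` is a stand-in so the example runs without BERT, and in this notebook `generate_embeddings` could be passed as `embed_fn` instead. Since padding length depends on the longest message in each batch and the pooling above includes padded positions, batched embeddings may differ slightly from single-pass ones.

```python
import numpy as np

def embed_in_batches(texts, embed_fn, batch_size=32):
    """Apply embed_fn to texts in chunks and stack the results row-wise."""
    chunks = []
    for start in range(0, len(texts), batch_size):
        chunks.append(embed_fn(texts[start:start + batch_size]))
    return np.vstack(chunks)

def fake_embed(batch):
    """Hypothetical stand-in for generate_embeddings: 2 numbers per text."""
    return np.array([[len(t), t.count(" ")] for t in batch], dtype=float)

texts = ["route table change", "security group change", "bad url", "new violations", "x"]
embeddings = embed_in_batches(texts, fake_embed, batch_size=2)
print(embeddings.shape)  # (5, 2)
```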


In [55]:
# Test for generate_embeddings on a random message
test = generate_embeddings(['This is a test message.'])
print(test)
print(test.shape)

[[ 2.87486792e-01 -2.18833357e-01 -4.07918841e-02 -6.47928119e-02
  -3.38200927e-02 -5.46444893e-01  3.84085983e-01  4.10230607e-01
   3.65385935e-02 -3.19685340e-01  7.02630207e-02 -7.32268691e-01
  -3.70148510e-01  1.92829564e-01 -3.66340041e-01 -1.77986026e-02
   1.53455228e-01  5.58866188e-02  4.68251109e-01 -3.73134315e-02
  -2.31477886e-01 -8.08989704e-02 -2.92494595e-01  1.40282780e-01
   2.52456278e-01 -2.01782033e-01 -1.23287059e-01  1.32000700e-01
  -3.52564931e-01 -2.62341827e-01  2.96208531e-01  8.18497539e-02
  -1.89344645e-01 -1.30101535e-02 -3.89007628e-01 -1.72715411e-01
   1.88761353e-01 -2.42251441e-01 -4.34004307e-01  7.35941976e-02
  -4.56596226e-01 -4.52264071e-01  1.50622532e-01  1.33034497e-01
  -7.98780769e-02 -5.18825412e-01 -2.72642434e-01  3.26247588e-02
   1.34242579e-01  1.59057543e-01 -3.50323021e-01  4.03966010e-01
  -4.52474415e-01  5.96401095e-03  5.53199090e-03  2.37096488e-01
   3.31857055e-02 -6.05462551e-01 -4.64936018e-01 -7.82989487e-02
   2.04727

In [56]:
# Initialize label encoders for each categorical feature
label_encoders = {
    'internal_category': LabelEncoder(),
    'source_entity': LabelEncoder(),
    'target_entity': LabelEncoder()
}

In [59]:
# Example DataFrame
data_test = pd.DataFrame({
    'internal_category': ['Malware', 'NetworkTraffic#Inbound', 'NetworkTraffic#Outbound','Privilege Escalation'],
    'source_entity': ['host', 'resource_container', 'resource_rds', 'resource_storage'],
    'target_entity': ['application / process / file', 'network connection', 'resource_api', 'resource_storage']
})

data_test.loc[:, 'internal_category_encoded'] = label_encoders['internal_category'].fit_transform(data_test['internal_category'])
data_test.loc[:, 'source_entity_encoded'] = label_encoders['source_entity'].fit_transform(data_test['source_entity'])
data_test.loc[:, 'target_entity_encoded'] = label_encoders['target_entity'].fit_transform(data_test['target_entity'])

data_test.head()

Unnamed: 0,internal_category,source_entity,target_entity,internal_category_encoded,source_entity_encoded,target_entity_encoded
0,Malware,host,application / process / file,0,0,0
1,NetworkTraffic#Inbound,resource_container,network connection,1,1,1
2,NetworkTraffic#Outbound,resource_rds,resource_api,2,2,2
3,Privilege Escalation,resource_storage,resource_storage,3,3,3


In [62]:
# Encode the categories
data.loc[:, 'internal_category_encoded'] = label_encoders['internal_category'].fit_transform(data['internal_categorization'])
data.loc[:, 'source_entity_encoded'] = label_encoders['source_entity'].fit_transform(data['source_entity'])
data.loc[:, 'target_entity_encoded'] = label_encoders['target_entity'].fit_transform(data['target_entity'])
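
To read encoded predictions back as category names, the fitted encoders can be inverted. The short sketch below uses a fresh `LabelEncoder` on made-up labels to show how the integer codes relate to `classes_`; the label values are chosen for illustration only.

```python
from sklearn.preprocessing import LabelEncoder

encoder = LabelEncoder()
labels = ["identity", "host", "identity", "resource_storage"]
codes = encoder.fit_transform(labels)

# classes_ holds the original labels in sorted order; position == integer code
mapping = {label: code for code, label in enumerate(encoder.classes_)}
print(mapping)  # {'host': 0, 'identity': 1, 'resource_storage': 2}

# Turn integer predictions back into readable labels
decoded = encoder.inverse_transform([0, 1, 2])
print(list(decoded))  # ['host', 'identity', 'resource_storage']
```

In this notebook the same lookup would apply to, for example, `label_encoders['source_entity'].classes_`.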

In [64]:
# Display the dataset to check the newly added encoded label columns
data

Unnamed: 0,description,alert_severity,alert_name,message,alert_type,alert_reachability,alert_source,alert_category,alert_subcategory,alert_evolving,alert_internetexposure,internal_categorization,source_entity,target_entity,internal_category_encoded,source_entity_encoded,target_entity_encoded
2,For account: 280566572992 : Route-table/route...,Info,Route Table Change,Route Table Change: For account: 280566572992...,RouteTableChange,UnknownReachability,AWS,Policy,Cloud Activity,False,UnknownInternetExposure,Defense Evasion,identity,network connection,0,1,2
3,For account: 280566572992 (and 1 more) : Secur...,Info,Security Group Change,Security Group Change: For account: 280566572...,SecurityGroupChange,UnknownReachability,AWS,Policy,Cloud Activity,False,UnknownInternetExposure,Defense Evasion,identity,network connection,0,1,2
4,For account: 280566572992 (and 1 more) : Secur...,Info,Security Group Change,Security Group Change: For account: 280566572...,SecurityGroupChange,UnknownReachability,AWS,Policy,Cloud Activity,False,UnknownInternetExposure,Defense Evasion,identity,network connection,0,1,2
5,For account: 280566572992 (and 1 more) : Secur...,Info,Security Group Change,Security Group Change: For account: 280566572...,SecurityGroupChange,UnknownReachability,AWS,Policy,Cloud Activity,False,UnknownInternetExposure,Defense Evasion,identity,network connection,0,1,2
6,For account: 280566572992 (and 1 more) : Secur...,Info,Security Group Change,Security Group Change: For account: 280566572...,SecurityGroupChange,UnknownReachability,AWS,Policy,Cloud Activity,False,UnknownInternetExposure,Defense Evasion,identity,network connection,0,1,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2277,External connection made to a known bad URL ou...,Medium,Outbound connection to a bad external URL,Outbound connection to a bad external URL: Ext...,ExternalServerBadDNSConn,UnknownReachability,Agent,Policy,Threat Intel,False,UnknownInternetExposure,Network Traffic#Outbound,host,network connection,4,0,2
2278,External connection made to a known bad URL d1...,Low,Outbound connection to a bad external URL,Outbound connection to a bad external URL: Ext...,ExternalServerBadDNSConn,UnknownReachability,Agent,Policy,Threat Intel,False,UnknownInternetExposure,Network Traffic#Outbound,host,network connection,4,0,2
2279,External connection made to a known bad URL yg...,Critical,Outbound connection to a bad external URL,Outbound connection to a bad external URL: Ext...,ExternalServerBadDNSConn,UnknownReachability,Agent,Policy,Threat Intel,False,UnknownInternetExposure,Network Traffic#Outbound,host,network connection,4,0,2
2281,External connection made to a known bad URL pr...,Medium,Outbound connection to a bad external URL,Outbound connection to a bad external URL: Ext...,ExternalServerBadDNSConn,UnknownReachability,Agent,Policy,Threat Intel,False,UnknownInternetExposure,Network Traffic#Outbound,host,network connection,4,0,2


# Step 4: Prepare Training & Testing Datasets

In [65]:
# Prepare Inputs & Outputs for Splitting
X = message_embeddings
y = data[['internal_category_encoded', 'source_entity_encoded', 'target_entity_encoded']]

In [66]:
# Split the data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [67]:
# Ensure the train/test embeddings are NumPy arrays (they already are here, so this is only a safeguard)
X_train = np.array(X_train)
X_test = np.array(X_test)

# Step 5: Model Development & Evaluation

In [73]:
# Model 1: Internal Category

# Train the internal category model
model_internal = LogisticRegression(max_iter=1000, multi_class='multinomial', solver='lbfgs')
model_internal.fit(X_train, y_train['internal_category_encoded'])

# Predict the probabilities for training and testing data
internal_probs_train = model_internal.predict_proba(X_train)
internal_probs_test = model_internal.predict_proba(X_test)

In [74]:
# Model 2: Source Entity

# Append the output from model 1 to original embeddings for Model 2
X_train_source = np.hstack((X_train, internal_probs_train))
X_test_source = np.hstack((X_test, internal_probs_test))

# Train the source entity model
model_source = LogisticRegression(max_iter=1000, multi_class='multinomial', solver='lbfgs')
model_source.fit(X_train_source, y_train['source_entity_encoded'])

# Predict the probabilities for training and testing data
source_probs_train = model_source.predict_proba(X_train_source)
source_probs_test = model_source.predict_proba(X_test_source)

In [75]:
# Model 3: Target Entity

# Append the output from model 2 to original embeddings for Model 3
X_train_target = np.hstack((X_train_source, source_probs_train))
X_test_target = np.hstack((X_test_source, source_probs_test))

# Train the target entity model
model_target = LogisticRegression(max_iter=1000, multi_class='multinomial', solver='lbfgs')
model_target.fit(X_train_target, y_train['target_entity_encoded'])
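
The three model cells follow one pattern: fit a classifier, append its predicted probabilities to the features, and hand the widened features to the next stage. As a sketch, that pattern can be written as a loop. `fit_chain` and `predict_chain` are hypothetical helper names, the data is synthetic, and the classifier uses scikit-learn's default settings rather than the explicit arguments used in the cells above.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def fit_chain(X, label_columns):
    """Fit one classifier per label; each stage sees earlier stages' probabilities."""
    models, features = [], X
    for y in label_columns:
        clf = LogisticRegression(max_iter=1000).fit(features, y)
        models.append(clf)
        features = np.hstack([features, clf.predict_proba(features)])
    return models

def predict_chain(models, X):
    """Run the stages in order and return an (n_samples, n_stages) prediction array."""
    predictions, features = [], X
    for clf in models:
        predictions.append(clf.predict(features))
        features = np.hstack([features, clf.predict_proba(features)])
    return np.column_stack(predictions)

# Synthetic, well-separated data (illustrative only, not the alert dataset)
rng = np.random.default_rng(0)
y1 = np.repeat([0, 1, 2], 20)
y2 = y1 % 2
y3 = (y1 > 0).astype(int)
X_toy = rng.normal(scale=0.1, size=(60, 4))
X_toy[np.arange(60), y1] += 3.0  # shift one feature per class

chain = fit_chain(X_toy, [y1, y2, y3])
chain_predictions = predict_chain(chain, X_toy)
print(chain_predictions.shape)  # (60, 3)
```

As in the cells above, the third stage receives the original features plus the probabilities from both earlier stages.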

In [86]:
# Evaluate models

# Calculate accuracy for Internal Category Model
internal_train_accuracy = accuracy_score(y_train['internal_category_encoded'], model_internal.predict(X_train))
internal_test_accuracy = accuracy_score(y_test['internal_category_encoded'], model_internal.predict(X_test))

# Calculate accuracy for Source Entity Model
source_train_accuracy = accuracy_score(y_train['source_entity_encoded'], model_source.predict(X_train_source))
source_test_accuracy = accuracy_score(y_test['source_entity_encoded'], model_source.predict(X_test_source))

# Calculate accuracy for Target Entity Model
target_train_accuracy = accuracy_score(y_train['target_entity_encoded'], model_target.predict(X_train_target))
target_test_accuracy = accuracy_score(y_test['target_entity_encoded'], model_target.predict(X_test_target))

# Print the accuracies
print(f'Internal Category Model - Training Accuracy: {internal_train_accuracy * 100:.2f}% , Testing Accuracy: {internal_test_accuracy * 100:.2f}%')
print(f'Source Entity Model - Training Accuracy: {source_train_accuracy * 100:.2f}% , Testing Accuracy: {source_test_accuracy * 100:.2f}%')
print(f'Target Entity Model - Training Accuracy: {target_train_accuracy * 100:.2f}% , Testing Accuracy: {target_test_accuracy * 100:.2f}%')


Internal Category Model - Training Accuracy: 100.00% , Testing Accuracy: 99.63%
Source Entity Model - Training Accuracy: 100.00% , Testing Accuracy: 100.00%
Target Entity Model - Training Accuracy: 100.00% , Testing Accuracy: 99.63%
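
The labels are imbalanced (in the summary table above, `Defense Evasion` is the top value for 1,390 of the 2,104 labelled rows), so these accuracies are easier to interpret next to a baseline that always predicts the most frequent class. The sketch below computes such a baseline on toy labels; `majority_baseline_accuracy` is a hypothetical helper, and the numbers it prints are not results for the alert dataset.

```python
import numpy as np

def majority_baseline_accuracy(train_labels, test_labels):
    """Accuracy of always predicting the most frequent training label."""
    values, counts = np.unique(np.asarray(train_labels), return_counts=True)
    majority = values[counts.argmax()]
    return float((np.asarray(test_labels) == majority).mean())

toy_train = [0, 0, 0, 1, 2, 0]
toy_test = [0, 1, 0, 0]
baseline = majority_baseline_accuracy(toy_train, toy_test)
print(f"Majority-class baseline accuracy: {baseline:.2f}")  # 0.75
```

In this notebook it could be called with, for example, `y_train['internal_category_encoded']` and `y_test['internal_category_encoded']`.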


## Calculate the combined accuracy of the model

To evaluate the three models together in the sequential classification task, a prediction is counted as correct only if all three categories (Internal Category, Source Entity, Target Entity) for a given row are predicted correctly. This ensures that a prediction counts as fully accurate only when every link in the classification chain is correct.

In [77]:
# Predict categories for the test dataset
internal_test_predictions = model_internal.predict(X_test)
source_test_predictions = model_source.predict(X_test_source)
target_test_predictions = model_target.predict(X_test_target)

In [87]:
# Combined accuracy calculation
correct_predictions = 0
total_predictions = len(y_test['internal_category_encoded'])

for i in range(total_predictions):
  # Check if all predictions for a row are correct
  if (internal_test_predictions[i] == y_test['internal_category_encoded'].iloc[i] and
      source_test_predictions[i] == y_test['source_entity_encoded'].iloc[i] and
      target_test_predictions[i] == y_test['target_entity_encoded'].iloc[i]):
    correct_predictions += 1

combined_accuracy = (correct_predictions / total_predictions) * 100

print(f"Combined Testing Accuracy of All Models: {combined_accuracy:.2f}%")

Combined Testing Accuracy of All Models: 99.63%
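
The same all-categories-correct criterion can also be computed without an explicit loop. The sketch below is a vectorised version on toy arrays; `exact_match_accuracy` is a hypothetical helper name.

```python
import numpy as np

def exact_match_accuracy(y_true, y_pred):
    """Fraction of rows where every column of y_pred equals y_true."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return float(np.all(y_true == y_pred, axis=1).mean())

# Toy labels: 4 rows x 3 categories; the third row has one wrong category
truth = np.array([[0, 1, 2], [4, 0, 2], [0, 1, 2], [1, 0, 0]])
preds = np.array([[0, 1, 2], [4, 0, 2], [0, 1, 1], [1, 0, 0]])

print(f"{exact_match_accuracy(truth, preds) * 100:.2f}%")  # 75.00%
```

Here `y_pred` could be built with `np.column_stack([internal_test_predictions, source_test_predictions, target_test_predictions])` and compared against `y_test.to_numpy()`, whose columns are in the same order.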
