## Project Overview
This project implements malware detection using both classical machine learning models (Logistic Regression, SVM, Random Forest, XGBoost) and advanced models (TabTransformer, TabNet, FT-Transformer, and Graph Neural Networks).

## Group Details  

-Sanduni Kanapeddala Gamage- Student ID: 1598065

-Kahandawita Arachchige Arosh Malindra Perera- Student ID: 1579940

-Nanayakkarawasam Juliyan Stephan Nalaka De Silva- Student ID: 1585221

-Hungampala Ralalage Malaka Prasad- Student ID: 1599986

In [3]:
# Basic libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import gc
import pickle
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Kaggle Dataset Loading - Only basic kaggle API (no kagglehub)
import sys
import subprocess
try:
    from kaggle.api.kaggle_api_extended import KaggleApi
    KAGGLE_AVAILABLE = True
except ImportError:
    try:
        print("Installing kaggle dependencies...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "kaggle"])
        from kaggle.api.kaggle_api_extended import KaggleApi
        KAGGLE_AVAILABLE = True
    except Exception as e:
        print(f"Could not install kaggle API: {e}")
        KAGGLE_AVAILABLE = False

# Data processing
from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder, LabelEncoder
from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold, cross_val_score
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.metrics import roc_curve, auc, precision_recall_curve, classification_report
from sklearn.pipeline import Pipeline
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler

# Classical ML Models
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier

# Feature Importance and Explainability
import shap
from lime import lime_tabular

# Advanced ML Models - TabTransformer
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pytorch_tabnet.tab_model import TabNetClassifier

# FT-Transformer (using PyTorch implementation)
class FTTransformer(nn.Module):
    def __init__(self, input_dim, output_dim, dim=32, depth=2, heads=4, dropout=0.1):
        super().__init__()
        self.embedding = nn.Linear(input_dim, dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dropout=dropout)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.mlp_head = nn.Sequential(
            nn.Linear(dim, dim*2),
            nn.ReLU(),
            nn.Linear(dim*2, output_dim)
        )
        
    def forward(self, x):
        x = self.embedding(x)
        x = x.unsqueeze(1)  # Add sequence dimension
        x = self.transformer(x)
        x = x.squeeze(1)
        return self.mlp_head(x)

# Graph Neural Networks dependencies
# Checking if GNN functionality is available
""""
GNN_AVAILABLE = False
try:
    import torch_geometric
    from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
    from torch_geometric.data import Data
    from torch.utils.data import DataLoader as GraphDataLoader
    GNN_AVAILABLE = True
    print("GNN functionality is available!")
except ImportError:
    print("torch_geometric is not installed. GNN functionality will not be available.")
"""
# For visualization
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

In [6]:
# Kaggle data loading function
def load_kaggle_datasets(dataset_id="agungpambudi/network-malware-detection-connection-analysis"):
    """
    Load datasets directly from Kaggle using the Kaggle API.
    Uses pipe (|) separator and keeps datasets separate.
    
    Args:
        dataset_id: The Kaggle dataset ID to download
        
    Returns:
        Dictionary of DataFrames with dataset names as keys
    """
    # Check if kaggle is installed
    try:
        import kaggle
        from kaggle.api.kaggle_api_extended import KaggleApi
    except ImportError:
        print("Kaggle API is not installed. Installing...")
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "kaggle"])
            import kaggle
            from kaggle.api.kaggle_api_extended import KaggleApi
            print("Kaggle API installed successfully")
        except Exception as e:
            print(f"Failed to install Kaggle API: {e}")
            print("Please manually install with: pip install kaggle")
            return None
    
    # Check if kaggle.json exists
    import os
    kaggle_dir = os.path.join(os.path.expanduser('~'), '.kaggle')
    kaggle_json = os.path.join(kaggle_dir, 'kaggle.json')
    
    if not os.path.exists(kaggle_json):
        print(f"Kaggle API credentials not found at: {kaggle_json}")
        print("\nTo fix Kaggle authentication issues:")
        print("1. Create a Kaggle account at https://www.kaggle.com")
        print("2. Go to Account -> API -> Create New API Token")
        print("3. This will download kaggle.json")
        print("4. Place this file in ~/.kaggle/ (Linux/Mac) or C:\\Users\\<username>\\.kaggle\\ (Windows)")
        print(f"   For your system, that's: {kaggle_dir}")
        print("5. Run this cell again")
        
        # Create the .kaggle directory if it doesn't exist
        if not os.path.exists(kaggle_dir):
            try:
                os.makedirs(kaggle_dir)
                print(f"Created directory: {kaggle_dir}")
            except Exception as e:
                print(f"Could not create directory: {e}")
        
        return None
    
    # Initialize dictionary to store DataFrames
    datasets = {}
    
    # Create a directory in DataSet to store Kaggle files
    base_dir = os.path.dirname(os.path.abspath('__file__'))
    dataset_dir = os.path.join(base_dir, 'DataSet', 'kaggle_data')
    os.makedirs(dataset_dir, exist_ok=True)
    
    print(f"Downloading dataset {dataset_id} to {dataset_dir}...")
    
    try:
        # Initialize and authenticate Kaggle API
        api = KaggleApi()
        api.authenticate()
        
        # Download the dataset files
        api.dataset_download_files(dataset_id, path=dataset_dir, unzip=True)
        print("Download complete!")
        
        # List all CSV files in the directory
        csv_files = [f for f in os.listdir(dataset_dir) if f.endswith('.csv')]
        print(f"Found {len(csv_files)} CSV files in the Kaggle dataset")
        
        # Load each CSV file separately with pipe separator
        for file in tqdm(csv_files, desc="Loading CSV files"):
            file_path = os.path.join(dataset_dir, file)
            try:
                # Use pipe separator specifically
                df = pd.read_csv(file_path, sep='|', low_memory=False)
                
                # Store the DataFrame in the dictionary with filename as key
                key = file.replace('.csv', '')
                datasets[key] = df
                print(f"Successfully loaded {file} with {df.shape[0]} rows and {df.shape[1]} columns")
            except Exception as e:
                print(f"Error loading {file}: {e}")
        
        # Cache the datasets for future use
        cache_dir = os.path.join(base_dir, 'DataSet', 'cache')
        os.makedirs(cache_dir, exist_ok=True)
        cache_file = os.path.join(cache_dir, 'kaggle_datasets.pkl')
        
        print(f"Saving datasets to cache: {cache_file}")
        with open(cache_file, 'wb') as f:
            pickle.dump(datasets, f)
        print("Datasets cached successfully")
        
    except Exception as e:
        print(f"Error accessing Kaggle API: {e}")
        print("\nTo fix Kaggle authentication issues:")
        print("1. Create a Kaggle account at https://www.kaggle.com")
        print("2. Go to Account -> API -> Create New API Token")
        print("3. This will download kaggle.json")
        print("4. Place this file in ~/.kaggle/ (Linux/Mac) or C:\\Users\\<username>\\.kaggle\\ (Windows)")
        print("5. Run this cell again")
        return None
        
    if not datasets:
        print("No datasets were loaded from Kaggle")
        return None
        
    return datasets

# Load datasets from Kaggle or cache
try:
    # Check for cached datasets first
    cache_file = os.path.join(os.path.dirname(os.path.abspath('__file__')), 'DataSet', 'cache', 'kaggle_datasets.pkl')
    
    if os.path.exists(cache_file):
        print(f"Loading datasets from cache: {cache_file}")
        with open(cache_file, 'rb') as f:
            datasets = pickle.load(f)
        print(f"Loaded {len(datasets)} datasets from cache")
    else:
        print("No cached datasets found. Loading from Kaggle...")
        datasets = load_kaggle_datasets()
        
    # Display available datasets
    if datasets:
        print("\nAvailable datasets:")
        for key, df in datasets.items():
            print(f"- {key}: {df.shape[0]} rows, {df.shape[1]} columns")
except Exception as e:
    print(f"Error loading datasets: {e}")
    import traceback
    traceback.print_exc()

No cached datasets found. Loading from Kaggle...
Downloading dataset agungpambudi/network-malware-detection-connection-analysis to d:\Unitec\Sem2\Machine learning\ML Project\DataSet\kaggle_data...
Dataset URL: https://www.kaggle.com/datasets/agungpambudi/network-malware-detection-connection-analysis
Download complete!
Found 12 CSV files in the Kaggle dataset
Download complete!
Found 12 CSV files in the Kaggle dataset


Loading CSV files:   8%|▊         | 1/12 [00:04<00:47,  4.33s/it]

Successfully loaded CTU-IoT-Malware-Capture-1-1conn.log.labeled.csv with 1008748 rows and 23 columns
Successfully loaded CTU-IoT-Malware-Capture-20-1conn.log.labeled.csv with 3209 rows and 23 columns
Successfully loaded CTU-IoT-Malware-Capture-21-1conn.log.labeled.csv with 3286 rows and 23 columns


Loading CSV files:  33%|███▎      | 4/12 [00:05<00:08,  1.05s/it]

Successfully loaded CTU-IoT-Malware-Capture-3-1conn.log.labeled.csv with 156103 rows and 23 columns
Successfully loaded CTU-IoT-Malware-Capture-34-1conn.log.labeled.csv with 23145 rows and 23 columns


Loading CSV files:  58%|█████▊    | 7/12 [01:03<00:51, 10.23s/it]

Successfully loaded CTU-IoT-Malware-Capture-35-1conn.log.labeled.csv with 10447787 rows and 23 columns
Successfully loaded CTU-IoT-Malware-Capture-42-1conn.log.labeled.csv with 4426 rows and 23 columns
Successfully loaded CTU-IoT-Malware-Capture-44-1conn.log.labeled.csv with 237 rows and 23 columns


Loading CSV files:  75%|███████▌  | 9/12 [01:21<00:29,  9.73s/it]

Successfully loaded CTU-IoT-Malware-Capture-48-1conn.log.labeled.csv with 3394338 rows and 23 columns


Loading CSV files:  83%|████████▎ | 10/12 [01:36<00:21, 10.79s/it]

Successfully loaded CTU-IoT-Malware-Capture-60-1conn.log.labeled.csv with 3581028 rows and 23 columns
Successfully loaded CTU-IoT-Malware-Capture-8-1conn.log.labeled.csv with 10403 rows and 23 columns


Loading CSV files: 100%|██████████| 12/12 [02:25<00:00, 12.15s/it]

Successfully loaded CTU-IoT-Malware-Capture-9-1conn.log.labeled.csv with 6378293 rows and 23 columns
Saving datasets to cache: d:\Unitec\Sem2\Machine learning\ML Project\DataSet\cache\kaggle_datasets.pkl





Datasets cached successfully

Available datasets:
- CTU-IoT-Malware-Capture-1-1conn.log.labeled: 1008748 rows, 23 columns
- CTU-IoT-Malware-Capture-20-1conn.log.labeled: 3209 rows, 23 columns
- CTU-IoT-Malware-Capture-21-1conn.log.labeled: 3286 rows, 23 columns
- CTU-IoT-Malware-Capture-3-1conn.log.labeled: 156103 rows, 23 columns
- CTU-IoT-Malware-Capture-34-1conn.log.labeled: 23145 rows, 23 columns
- CTU-IoT-Malware-Capture-35-1conn.log.labeled: 10447787 rows, 23 columns
- CTU-IoT-Malware-Capture-42-1conn.log.labeled: 4426 rows, 23 columns
- CTU-IoT-Malware-Capture-44-1conn.log.labeled: 237 rows, 23 columns
- CTU-IoT-Malware-Capture-48-1conn.log.labeled: 3394338 rows, 23 columns
- CTU-IoT-Malware-Capture-60-1conn.log.labeled: 3581028 rows, 23 columns
- CTU-IoT-Malware-Capture-8-1conn.log.labeled: 10403 rows, 23 columns
- CTU-IoT-Malware-Capture-9-1conn.log.labeled: 6378293 rows, 23 columns


In [7]:
# Detailed analysis of a specific dataset
def analyze_dataset(dataset_name):
    """
    Perform detailed analysis on a specific dataset
    
    Args:
        dataset_name: Key name of the dataset to analyze
    """
    if dataset_name not in datasets:
        print(f"Dataset '{dataset_name}' not found.")
        print(f"Available datasets: {list(datasets.keys())}")
        return
    
    df = datasets[dataset_name]
    
    # Basic information
    print(f"='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='=")
    print(f"DATASET: {dataset_name}")
    print(f"='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='=")
    print(f"Shape: {df.shape[0]} rows x {df.shape[1]} columns")
    
    # Column information
    print("\n--- COLUMNS ---")
    for col in df.columns:
        print(f"- {col}: {df[col].dtype}")
    
    # Sample data
    print("\n--- SAMPLE DATA ---")
    display(df.head())
    
    # Statistical summary
    print("\n--- STATISTICAL SUMMARY ---")
    display(df.describe())
    
    # Missing values
    print("\n--- MISSING VALUES ---")
    missing = df.isna().sum()
    if missing.sum() > 0:
        display(pd.DataFrame({
            'Column': missing.index,
            'Missing Values': missing.values,
            'Percentage': (missing.values / len(df) * 100).round(2)
        })[missing.values > 0].sort_values('Missing Values', ascending=False))
    else:
        print("No missing values found!")
    
    # Value counts for categorical columns (showing top categories)
    print("\n--- CATEGORICAL COLUMNS ---")
    cat_columns = df.select_dtypes(include=['object', 'category']).columns
    for col in cat_columns[:5]:  # Limiting to first 5 categorical columns
        print(f"\n{col}:")
        display(df[col].value_counts().head(10))
    
    if len(cat_columns) > 5:
        print(f"... and {len(cat_columns) - 5} more categorical columns (not shown)")
    
    # Numeric column distributions
    print("\n--- NUMERIC COLUMNS SUMMARY ---")
    num_cols = df.select_dtypes(include=['int64', 'float64']).columns
    if len(num_cols) > 0:
        display(df[num_cols].describe())
    else:
        print("No numeric columns found!")
    
    return df

# List available datasets
print("Available datasets:")
for i, key in enumerate(datasets.keys()):
    print(f"{i+1}. {key}")

# Example: Analyze the first dataset (you can change the index as needed)
if datasets:
    dataset_name = list(datasets.keys())[0]  # First dataset
    print(f"\nAnalyzing first dataset: {dataset_name}")
    analyze_dataset(dataset_name)
else:
    print("No datasets available to analyze.")

Available datasets:
1. CTU-IoT-Malware-Capture-1-1conn.log.labeled
2. CTU-IoT-Malware-Capture-20-1conn.log.labeled
3. CTU-IoT-Malware-Capture-21-1conn.log.labeled
4. CTU-IoT-Malware-Capture-3-1conn.log.labeled
5. CTU-IoT-Malware-Capture-34-1conn.log.labeled
6. CTU-IoT-Malware-Capture-35-1conn.log.labeled
7. CTU-IoT-Malware-Capture-42-1conn.log.labeled
8. CTU-IoT-Malware-Capture-44-1conn.log.labeled
9. CTU-IoT-Malware-Capture-48-1conn.log.labeled
10. CTU-IoT-Malware-Capture-60-1conn.log.labeled
11. CTU-IoT-Malware-Capture-8-1conn.log.labeled
12. CTU-IoT-Malware-Capture-9-1conn.log.labeled

Analyzing first dataset: CTU-IoT-Malware-Capture-1-1conn.log.labeled
='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='=
DATASET: CTU-IoT-Malware-Capture-1-1conn.log.labeled
='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='='=
Shape: 1008748 rows x 23 columns

--- COLUMNS ---
- ts: float64
- uid: object
- id.orig_h: object
- id.orig_p: float64

Unnamed: 0,ts,uid,id.orig_h,id.orig_p,id.resp_h,id.resp_p,proto,service,duration,orig_bytes,...,local_resp,missed_bytes,history,orig_pkts,orig_ip_bytes,resp_pkts,resp_ip_bytes,tunnel_parents,label,detailed-label
0,1525880000.0,CUmrqr4svHuSXJy5z7,192.168.100.103,51524.0,65.127.233.163,23.0,tcp,-,2.999051,0,...,-,0.0,S,3.0,180.0,0.0,0.0,-,Malicious,PartOfAHorizontalPortScan
1,1525880000.0,CH98aB3s1kJeq6SFOc,192.168.100.103,56305.0,63.150.16.171,23.0,tcp,-,-,-,...,-,0.0,S,1.0,60.0,0.0,0.0,-,Malicious,PartOfAHorizontalPortScan
2,1525880000.0,C3GBTkINvXNjVGtN5,192.168.100.103,41101.0,111.40.23.49,23.0,tcp,-,-,-,...,-,0.0,S,1.0,60.0,0.0,0.0,-,Malicious,PartOfAHorizontalPortScan
3,1525880000.0,CDe43c1PtgynajGI6,192.168.100.103,60905.0,131.174.215.147,23.0,tcp,-,2.998796,0,...,-,0.0,S,3.0,180.0,0.0,0.0,-,Malicious,PartOfAHorizontalPortScan
4,1525880000.0,CJaDcG3MZzvf1YVYI4,192.168.100.103,44301.0,91.42.47.63,23.0,tcp,-,-,-,...,-,0.0,S,1.0,60.0,0.0,0.0,-,Malicious,PartOfAHorizontalPortScan



--- STATISTICAL SUMMARY ---


Unnamed: 0,ts,id.orig_p,id.resp_p,missed_bytes,orig_pkts,orig_ip_bytes,resp_pkts,resp_ip_bytes
count,1008748.0,1008748.0,1008748.0,1008748.0,1008748.0,1008748.0,1008748.0,1008748.0
mean,1526075000.0,44436.84,16097.71,0.0,1.496242,81.14562,0.1424647,9.049184
std,115743.1,9660.592,19562.8,0.0,1.741176,94.7309,1.850414,119.6776
min,1525880000.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,1525975000.0,43730.0,23.0,0.0,1.0,40.0,0.0,0.0
50%,1526071000.0,43763.0,8080.0,0.0,1.0,60.0,0.0,0.0
75%,1526174000.0,48814.0,28180.25,0.0,1.0,60.0,0.0,0.0
max,1526283000.0,65394.0,65535.0,0.0,60.0,2990.0,75.0,9415.0



--- MISSING VALUES ---
No missing values found!

--- CATEGORICAL COLUMNS ---

uid:
No missing values found!

--- CATEGORICAL COLUMNS ---

uid:


uid
CUmrqr4svHuSXJy5z7    1
COj4Eq4lmR86amgfI6    1
CBqL9l4KOG0Y3zauml    1
CLIXVIOgCuf9Pv6j      1
CImdWC3kB3Zhudx4q3    1
CkA64v2h8KpX8Q74Z3    1
CogrdgtK89aNUYYJd     1
CaKHRA4f0t6bWYCo93    1
CWtdXj3YoOcd5MAjm1    1
C2lNJW3oXHusTkzmdg    1
Name: count, dtype: int64


id.orig_h:


id.orig_h
192.168.100.103    991061
192.168.100.1        1651
4.68.110.10            43
194.70.98.42           23
218.248.235.161        13
83.168.243.156         12
144.75.175.50          11
38.104.45.226          11
218.248.235.129        10
159.226.254.70         10
Name: count, dtype: int64


id.resp_h:


id.resp_h
192.168.100.103    17687
147.231.100.5       4313
213.239.154.12      1428
37.187.104.44       1408
89.221.214.130      1402
210.206.154.134      129
70.45.29.240         128
175.196.5.46         125
92.255.209.3         125
221.5.224.77         124
Name: count, dtype: int64


proto:


proto
tcp     583134
udp     408193
icmp     17421
Name: count, dtype: int64


service:


service
-       1005507
http       3238
dhcp          1
ssh           1
dns           1
Name: count, dtype: int64

... and 10 more categorical columns (not shown)

--- NUMERIC COLUMNS SUMMARY ---


Unnamed: 0,ts,id.orig_p,id.resp_p,missed_bytes,orig_pkts,orig_ip_bytes,resp_pkts,resp_ip_bytes
count,1008748.0,1008748.0,1008748.0,1008748.0,1008748.0,1008748.0,1008748.0,1008748.0
mean,1526075000.0,44436.84,16097.71,0.0,1.496242,81.14562,0.1424647,9.049184
std,115743.1,9660.592,19562.8,0.0,1.741176,94.7309,1.850414,119.6776
min,1525880000.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,1525975000.0,43730.0,23.0,0.0,1.0,40.0,0.0,0.0
50%,1526071000.0,43763.0,8080.0,0.0,1.0,60.0,0.0,0.0
75%,1526174000.0,48814.0,28180.25,0.0,1.0,60.0,0.0,0.0
max,1526283000.0,65394.0,65535.0,0.0,60.0,2990.0,75.0,9415.0
