# Deep Reinforcement Learning for Malicious URL Detection

## Overview
This notebook trains a Deep Q-Network (DQN) model for multi-class URL classification (Benign, Malware, Phishing, Spam, Defacement) using Google Colab's GPU. It combines a Transformer model for feature extraction with a DRL agent for adaptive decision-making.

## Prerequisites
- Set runtime to GPU: `Runtime > Change runtime type > Hardware accelerator > GPU`.
- Upload `kaggle.json` for dataset access (get from Kaggle > Account > API > Create New API Token).
- Obtain Google Safe Browsing API key from [Google Cloud Console](https://console.cloud.google.com/) and store in Colab Secrets.
- Ensure stable internet for API calls and package installation.

## Steps
1. Install dependencies and restart runtime.
2. Configure API keys and mount Google Drive.
3. Fetch and preprocess data from URLhaus, Google Safe Browsing, and Kaggle.
4. Define a custom Gym environment for URL classification.
5. Train the DQN model on GPU.
6. Save models to Google Drive.
7. Evaluate model performance.


In [1]:
# Cell 1: Install Dependencies
#!pip install --force-reinstall stable-baselines3==2.3.2 gym==0.25.2 requests==2.32.3 pandas==2.2.2 transformers==4.41.0 kaggle==1.6.14 shimmy
#!pip install transformers==4.41.0
# Import libraries
import gym
from gym import spaces
import numpy as np
import pandas as pd
import requests
from transformers import AutoTokenizer, AutoModel
import tensorflow as tf
from stable_baselines3 import DQN
from stable_baselines3.common.env_checker import check_env
from google.colab import files, drive

# Verify installations
try:
    print("Pandas version:", pd.__version__)
    print("Requests version:", requests.__version__)
    print("TensorFlow version:", tf.__version__)
    print("Gym version:", gym.__version__)
    print("Transformers version:", transformers.__version__)
    print("Dependencies installed successfully.")
except Exception as e:
    print(f"Error importing libraries: {e}")

# Important: Restart the runtime to apply new package versions
print("Please go to Runtime > Restart runtime, then rerun this cell to confirm imports.")

Pandas version: 2.2.2
Requests version: 2.32.3
TensorFlow version: 2.18.0
Gym version: 0.25.2
Error importing libraries: name 'transformers' is not defined
Please go to Runtime > Restart runtime, then rerun this cell to confirm imports.


In [2]:
# Cell 2: Configure API Keys and Google Drive
# Mount Google Drive to save models
from google.colab import drive
from google.colab import userdata
drive.mount('/content/drive')

# Upload Kaggle API key
print("Upload kaggle.json")
files.upload()
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Test Kaggle API
try:
    !kaggle datasets list
    print("Kaggle API working.")
except Exception as e:
    print(f"Kaggle API error: {e}")

# Set Google Safe Browsing API key (use Colab Secrets)
from google.colab import userdata
try:
    GOOGLE_API_KEY = userdata.get('google_search_api')
    print("Google Safe Browsing API key loaded.")
except Exception as e:
    GOOGLE_API_KEY = 'YOUR_GOOGLE_SAFE_BROWSING_API_KEY'  # Replace with your key if not using Secrets
    print(f"Error loading API key from Secrets: {e}. Using placeholder key (replace it).")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Upload kaggle.json


Saving kaggle.json to kaggle.json
ref                                                             title                                               size  lastUpdated          downloadCount  voteCount  usabilityRating  
--------------------------------------------------------------  -------------------------------------------------  -----  -------------------  -------------  ---------  ---------------  
rakeshkapilavai/extrovert-vs-introvert-behavior-data            Extrovert vs. Introvert Behavior Data               31KB  2025-06-13 14:26:48          18603        405  1.0              
bismasajjad/global-ai-job-market-and-salary-trends-2025         Global AI Job Market & Salary Trends 2025          517KB  2025-06-01 07:20:49           6907        111  0.9411765        
adilshamim8/social-media-addiction-vs-relationships             Students' Social Media Addiction                     8KB  2025-05-10 14:38:02          17898        272  1.0              
prajwaldongre/loan-application-

In [3]:
# Updated Cell-3
import pandas as pd
import requests
import os
import zipfile
from google.colab import files
import ipywidgets as widgets
from IPython.display import display

# Install and enable ipywidgets for interactive UI
!pip install ipywidgets
!jupyter nbextension enable --py widgetsnbextension

# Initialize variables
urls, labels = [], []
urls_primary, labels_primary = [], []
urls_urlhaus, labels_urlhaus = [], []

# Attempt automatic download of Cisco Umbrella Top 1M and URLhaus
auto_success = False
try:
    print("Automatically fetching benign URLs from Cisco Umbrella Top 1M...")
    umbrella_url = "http://s3-us-west-1.amazonaws.com/umbrella-static/top-1m.csv.zip"
    response = requests.get(umbrella_url, timeout=10)
    response.raise_for_status()
    with open("top-1m.csv.zip", "wb") as f:
        f.write(response.content)
    with zipfile.ZipFile("top-1m.csv.zip", "r") as zip_ref:
        zip_ref.extractall()
    df_umbrella = pd.read_csv("top-1m.csv", names=["rank", "domain"])
    urls_primary = ["http://" + domain for domain in df_umbrella["domain"].head(250)]  # Limit to 250
    labels_primary = [0] * len(urls_primary)  # Benign (0)
    print(f"Fetched {len(urls_primary)} benign URLs from Cisco Umbrella Top 1M.")
except Exception as e:
    print(f"Error fetching Cisco Umbrella Top 1M: {e}")
    print("Using hardcoded benign URLs as fallback.")
    urls_primary = [
        "http://google.com", "http://wikipedia.org", "http://youtube.com", "http://amazon.com", "http://facebook.com",
        "http://twitter.com", "http://linkedin.com", "http://microsoft.com", "http://apple.com", "http://netflix.com"
    ] * 25  # 250 URLs
    labels_primary = [0] * len(urls_primary)
    print(f"Loaded {len(urls_primary)} hardcoded benign URLs.")

# Fetch malicious URLs from URLhaus
def fetch_urlhaus_data():
    try:
        print("Fetching malicious URLs from URLhaus...")
        response = requests.get('https://urlhaus.abuse.ch/downloads/text/', timeout=10)
        response.raise_for_status()
        urls = [url for url in response.text.splitlines() if url and not url.startswith('#')]
        return urls[:250], [1] * min(250, len(urls))  # Malware (1), limit to 250
    except Exception as e:
        print(f"Error fetching URLhaus data: {e}")
        return [], []

urls_urlhaus, labels_urlhaus = fetch_urlhaus_data()

# Combine data if both sources succeeded
if urls_primary and urls_urlhaus:
    urls = urls_primary + urls_urlhaus
    labels = labels_primary + labels_urlhaus
    auto_success = True
    print(f"Automatic download succeeded: {len(urls)} URLs loaded (Umbrella: {len(urls_primary)}, URLhaus: {len(urls_urlhaus)}).")

# Interactive UI if automatic download fails
if not auto_success:
    print("Automatic download failed. Select a dataset to download:")
    dataset_options = [
        "Umbrella Top 1M + URLhaus (Benign + Malware)",
        "CIC-MalURL-2020 (Manual Upload from https://www.unb.ca/cic/datasets/)",
        "PhishTank (Phishing URLs)"
    ]
    dataset_dropdown = widgets.Dropdown(
        options=dataset_options,
        description="Dataset:",
        style={'description_width': 'initial'}
    )
    download_button = widgets.Button(
        description="Download",
        button_style="primary",
        tooltip="Click to download the selected dataset"
    )
    output = widgets.Output()

    def on_download_button_clicked(b):
        with output:
            output.clear_output()
            selected_dataset = dataset_dropdown.value
            global urls, labels, urls_primary, labels_primary, urls_urlhaus, labels_urlhaus
            urls, labels = [], []
            urls_primary, labels_primary = [], []
            urls_urlhaus, labels_urlhaus = [], []

            if selected_dataset == "Umbrella Top 1M + URLhaus (Benign + Malware)":
                try:
                    print("Fetching benign URLs from Cisco Umbrella Top 1M...")
                    response = requests.get(umbrella_url, timeout=10)
                    response.raise_for_status()
                    with open("top-1m.csv.zip", "wb") as f:
                        f.write(response.content)
                    with zipfile.ZipFile("top-1m.csv.zip", "r") as zip_ref:
                        zip_ref.extractall()
                    df_umbrella = pd.read_csv("top-1m.csv", names=["rank", "domain"])
                    urls_primary = ["http://" + domain for domain in df_umbrella["domain"].head(250)]
                    labels_primary = [0] * len(urls_primary)
                    print(f"Fetched {len(urls_primary)} benign URLs from Cisco Umbrella Top 1M.")
                except Exception as e:
                    print(f"Error fetching Cisco Umbrella Top 1M: {e}")
                    urls_primary = [
                        "http://google.com", "http://wikipedia.org", "http://youtube.com", "http://amazon.com", "http://facebook.com",
                        "http://twitter.com", "http://linkedin.com", "http://microsoft.com", "http://apple.com", "http://netflix.com"
                    ] * 25
                    labels_primary = [0] * len(urls_primary)
                    print(f"Loaded {len(urls_primary)} hardcoded benign URLs.")
                urls_urlhaus, labels_urlhaus = fetch_urlhaus_data()

            elif selected_dataset == "CIC-MalURL-2020 (Manual Upload from https://www.unb.ca/cic/datasets/)":
                print("Please upload the CIC-MalURL-2020 CSV file.")
                uploaded = files.upload()
                if uploaded:
                    try:
                        csv_file = list(uploaded.keys())[0]
                        df = pd.read_csv(csv_file)
                        if 'type' in df.columns:
                            df['label'] = df['type'].map({'benign': 0, 'malware': 1, 'phishing': 2, 'spam': 3, 'defacement': 4})
                        elif 'label' in df.columns and df['label'].dtype == object:
                            df['label'] = df['label'].map({'Benign': 0, 'Malware': 1, 'Phishing': 2, 'Spam': 3, 'Defacement': 4, 'benign': 0, 'malware': 1, 'phishing': 2, 'spam': 3, 'defacement': 4})
                        elif 'label' in df.columns and df['label'].dtype in [int, float]:
                            pass
                        else:
                            raise ValueError("No recognizable 'type' or 'label' column in CSV")
                        urls_primary = df['url'].tolist()[:500]
                        labels_primary = df['label'].tolist()[:500]
                        print(f"Uploaded CIC dataset loaded: {len(urls_primary)} URLs.")
                    except Exception as e:
                        print(f"Error loading CIC dataset: {e}")

            elif selected_dataset == "PhishTank (Phishing URLs)":
                try:
                    print("Fetching phishing URLs from PhishTank...")
                    response = requests.get('http://data.phishtank.com/data/online-valid.csv.gz', timeout=10)
                    df = pd.read_csv(response.content, compression='gzip')
                    urls_primary = df['url'].tolist()[:250]
                    labels_primary = [2] * len(urls_primary)  # Phishing (2)
                    print(f"Fetched {len(urls_primary)} phishing URLs from PhishTank.")
                except Exception as e:
                    print(f"Error fetching PhishTank data: {e}")

            urls = urls_primary + urls_urlhaus
            labels = labels_primary + labels_urlhaus
            if urls:
                print(f"Loaded {len(urls)} URLs with labels (Primary: {len(urls_primary)}, URLhaus: {len(urls_urlhaus)}).")
            else:
                print("No URLs loaded. Using synthetic fallback dataset.")
                urls = [
                    'http://example.com', 'http://safe-site.org', 'http://legit-site.net', 'http://trusted-page.com', 'http://benign-url.org',
                    'http://malware-site.com', 'http://virus-download.net', 'http://trojan-page.org', 'http://malicious-code.com', 'http://harmful-site.net',
                    'http://phish-site.com', 'http://fake-login-page.org', 'http://scam-bank.com', 'http://phishing-url.net', 'http://credential-stealer.org',
                    'http://spam-offer.com', 'http://unwanted-ad.net', 'http://fake-deal.org', 'http://spam-promotion.com', 'http://adware-site.net',
                    'http://defaced-site.org', 'http://hacked-page.com', 'http://vandalized-url.net', 'http://defacement-page.org', 'http://compromised-site.com'
                ] * 4  # 100 URLs
                labels = [0] * 20 + [1] * 20 + [2] * 20 + [3] * 20 + [4] * 20
                print(f"Fallback dataset loaded: {len(urls)} URLs.")

    download_button.on_click(on_download_button_clicked)
    display(dataset_dropdown, download_button, output)

# Google Safe Browsing
def fetch_safe_browsing_data(urls):
    url = f'https://safebrowsing.googleapis.com/v4/threatMatches:find?key={GOOGLE_API_KEY}'
    payload = {
        'client': {'clientId': 'mycompany', 'clientVersion': '1.0'},
        'threatInfo': {
            'threatTypes': ['MALWARE', 'SOCIAL_ENGINEERING', 'UNWANTED_SOFTWARE', 'POTENTIALLY_HARMFUL_APPLICATION'],
            'platformTypes': ['ANY_PLATFORM'],
            'threatEntryTypes': ['URL'],
            'threatEntries': [{'url': url} for url in urls[:100]]
        }
    }
    try:
        response = requests.post(url, json=payload, timeout=10)
        response.raise_for_status()
        threats = response.json().get('matches', [])
        labels = [0] * len(urls)
        for threat in threats:
            url = threat['threat']['url']
            threat_type = threat['threatType']
            if url in urls:
                idx = urls.index(url)
                if threat_type == 'MALWARE':
                    labels[idx] = 1
                elif threat_type == 'SOCIAL_ENGINEERING':
                    labels[idx] = 2
                elif threat_type == 'UNWANTED_SOFTWARE':
                    labels[idx] = 3
                else:
                    labels[idx] = 4
        return labels
    except Exception as e:
        print(f"Error fetching Safe Browsing data: {e}")
        return [0] * len(urls)

if GOOGLE_API_KEY != userdata.get('google_search_api') and urls:
    labels = fetch_safe_browsing_data(urls)
else:
    print("Skipping Safe Browsing API (invalid key or no URLs). Using default labels.")

# Filter invalid URLs
valid_data = [(url, label) for url, label in zip(urls, labels) if url and isinstance(url, str) and pd.notna(label)]
urls, labels = zip(*valid_data) if valid_data else ([], [])

print(f"Loaded {len(urls)} URLs with labels (Primary: {len(urls_primary)}, URLhaus: {len(urls_urlhaus)}).")
if len(urls) == 0:
    print("Warning: No valid URLs loaded. Check dataset sources and API keys.")

Enabling notebook extension jupyter-js-widgets/extension...
Paths used for configuration of notebook: 
    	/root/.jupyter/nbconfig/notebook.json
Paths used for configuration of notebook: 
    	
      - Validating: [32mOK[0m
Paths used for configuration of notebook: 
    	/root/.jupyter/nbconfig/notebook.json
Automatically fetching benign URLs from Cisco Umbrella Top 1M...
Fetched 250 benign URLs from Cisco Umbrella Top 1M.
Fetching malicious URLs from URLhaus...
Automatic download succeeded: 500 URLs loaded (Umbrella: 250, URLhaus: 250).
Skipping Safe Browsing API (invalid key or no URLs). Using default labels.
Loaded 500 URLs with labels (Primary: 250, URLhaus: 250).


In [4]:
print(len(urls))

500


In [5]:
# Updated Cell-4
import gym
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
import os
from google.colab import files

# Upload model files
def upload_model_files():
    print("Upload the following albert-base-v2 files: config.json, pytorch_model.bin, tokenizer_config.json, spiece.model")
    uploaded = files.upload()
    if uploaded:
        model_dir = "/content/albert-base-v2"
        os.makedirs(model_dir, exist_ok=True)
        for filename, content in uploaded.items():
            with open(os.path.join(model_dir, filename), 'wb') as f:
                f.write(content)
        print(f"Uploaded model files to {model_dir}.")
        return model_dir
    raise RuntimeError("No files uploaded. Please upload the required model files.")

# URL environment
class URLEnvironment(gym.Env):
    def __init__(self, urls, labels, model_dir):
        super(URLEnvironment, self).__init__()
        self.urls = urls
        self.labels = labels
        self.current_idx = 0
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True, local_files_only=True)
            self.model = AutoModel.from_pretrained(model_dir, local_files_only=True)
            print("Loaded albert-base-v2 successfully.")
        except Exception as e:
            raise RuntimeError(f"Failed to load albert-base-v2 from {model_dir}: {e}")
        self.action_space = gym.spaces.Discrete(5)  # 0: benign, 1: malware, 2: phishing, 3: spam, 4: defacement
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(768,), dtype=np.float32)

    def step(self, action):
        url = self.urls[self.current_idx]
        inputs = self.tokenizer(url, return_tensors="pt", max_length=128, truncation=True, padding=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
        reward = 1.0 if action == self.labels[self.current_idx] else -1.0
        self.current_idx = (self.current_idx + 1) % len(self.urls)
        done = self.current_idx == 0
        return embedding, reward, done, {}

    def reset(self):
        self.current_idx = 0
        url = self.urls[self.current_idx]
        inputs = self.tokenizer(url, return_tensors="pt", max_length=128, truncation=True, padding=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
        return embedding

    def render(self):
        pass

# Initialize environment
model_dir = upload_model_files()
try:
    env = URLEnvironment(urls, labels, model_dir)
    print("Environment initialized successfully.")
except Exception as e:
    raise RuntimeError(f"Error initializing environment: {e}")

Upload the following albert-base-v2 files: config.json, pytorch_model.bin, tokenizer_config.json, spiece.model


Saving config.json to config.json
Saving pytorch_model.bin to pytorch_model.bin
Saving spiece.model to spiece.model
Saving tokenizer_config.json to tokenizer_config.json
Uploaded model files to /content/albert-base-v2.
Loaded albert-base-v2 successfully.
Environment initialized successfully.


In [6]:
# Cell 5: Train DQN Model
try:
    dqn_model = DQN(
        'MlpPolicy',
        env,
        verbose=1,
        learning_rate=1e-3,
        buffer_size=10000,
        batch_size=32,
        device='cuda' if tf.config.list_physical_devices('GPU') else 'cpu'
    )
    dqn_model.learn(total_timesteps=10000)
    print("Training completed.")
except Exception as e:
    print(f"Error during training: {e}")


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




----------------------------------
| rollout/            |          |
|    ep_len_mean      | 500      |
|    ep_rew_mean      | 230      |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 4        |
|    fps              | 8        |
|    time_elapsed     | 225      |
|    total_timesteps  | 2000     |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.00073  |
|    n_updates        | 474      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 500      |
|    ep_rew_mean      | 346      |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 8        |
|    fps              | 8        |
|    time_elapsed     | 449      |
|    total_timesteps  | 4000     |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.000355 |
|    n_updates      

In [9]:
# Cell 6: Save Models
try:
    dqn_model.save('/content/drive/MyDrive/drl_url_detector.zip')
    #model.save_pretrained('/content/drive/MyDrive/transformer_url_detector')
    #tokenizer.save_pretrained('/content/drive/MyDrive/transformer_url_detector')
    print("Models saved to Google Drive.")
except Exception as e:
    print(f"Error saving models: {e}")


Models saved to Google Drive.


In [10]:
# Cell 7: Evaluate Model
try:
    correct = 0
    total = 100
    obs = env.reset()
    for _ in range(total):
        action, _ = dqn_model.predict(obs)
        obs, reward, done, info = env.step(action)
        if reward > 0:
            correct += 1
    accuracy = correct / total * 100
    print(f'Accuracy: {accuracy:.2f}%')
except Exception as e:
    print(f"Error during evaluation: {e}")


Accuracy: 95.00%
