In [11]:
import os
import google.generativeai as genai
from dotenv import load_dotenv
from tqdm import tqdm
import pandas as pd
load_dotenv() 

api_key = os.environ.get("GOOGLE_API_KEY")
if api_key is None:
    raise ValueError("The GOOGLE_API_KEY environment variable is not set in your .env file.")
genai.configure(api_key=api_key)

### Gemini Classification

Prompt: Use Gemini 2.0 Flash to classify each TLD as DGA (1) or legitimate (0) (you'll need to come up with a prompt and decide on using a structured output). 

In [13]:
model = genai.GenerativeModel("gemini-1.5-flash-002")


def classify_tld(tld):
    prompt = f"""
Classify the following Top-Level Domain (TLD) as either a Domain Generation Algorithm (DGA) generated domain (1) or a legitimate domain (0). 
Provide ONLY the classification number, nothing else.

TLD: {tld}
Classification:
"""
    response = model.generate_content(prompt)
    classification = response.text.strip()

    if classification == "1" or classification == "0":
        return int(classification)
    else:
        print(f"Unexpected response: {response.text} for TLD: {tld}")
        return None  


In [7]:
print(classify_tld("google.com"))

0


In [14]:
tld_df = pd.read_csv("data/tld.csv")
tld_df['dga_classification'] = None 

for index, row in tqdm(tld_df.iterrows(), total=len(tld_df), desc="Classifying TLDs"):
    tld = row['domain_tld']
    classification = classify_tld(tld)
    tld_df.loc[index, 'dga_classification'] = classification 


Classifying TLDs: 100%|██████████| 176/176 [01:05<00:00,  2.71it/s]


In [15]:
print(tld_df.head(10))
print(tld_df['dga_classification'].value_counts(dropna=False))

         domain_tld dga_classification
0  wunderground.com                  0
1       dropbox.com                  0
2         aoltw.net                  0
3              home                  0
4       mozilla.com                  0
5    metasploit.com                  0
6           aol.com                  0
7         aoltw.net                  0
8           aol.com                  0
9           aol.com                  0
dga_classification
0    164
1     12
Name: count, dtype: int64


In [16]:
tld_df = tld_df.drop_duplicates(subset=['domain_tld'])

print('Unique records:', tld_df['domain_tld'].nunique())
print('Total records:', len(tld_df))

tld_df.to_csv("data/tld_dga_classification.csv", index=False)

Unique records: 102
Total records: 103
