In [None]:
!pip install torch transformers scikit-learn

In [55]:
import torch
from transformers import (
    T5ForConditionalGeneration, 
    T5Tokenizer, 
    Trainer, 
    TrainingArguments,
    DataCollatorForSeq2Seq
)
from torch.utils.data import Dataset
import json
from sklearn.model_selection import train_test_split
import os
from collections import Counter

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class SnortRuleDataset(Dataset):
    def __init__(self, data, tokenizer, max_input_length=128, max_target_length=256):
        self.data = data
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        input_text = f"Generate Snort rule: {item['input'].strip().lower()}"
        target_text = item['target']

        input_encoding = self.tokenizer(
            input_text,
            max_length=self.max_input_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        target_encoding = self.tokenizer(
            target_text,
            max_length=self.max_target_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        return {
            "input_ids": input_encoding["input_ids"].flatten(),
            "attention_mask": input_encoding["attention_mask"].flatten(),
            "labels": target_encoding["input_ids"].flatten()
        }

with open("/kaggle/working/new_dataset.json", "r") as ddd:
    training_data = json.load(ddd)

def replicate_critical_pairs(data, min_count=10):
    counter = Counter([item['input'] for item in data])
    amplified = []
    for item in data:
        repeat = max(min_count - counter[item['input']], 1)
        amplified.extend([item] * repeat)
    return amplified

class SnortRuleGenerator:
    def __init__(self, model_name="t5-small"):
        self.model_name = model_name
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.device = device
        self.model.to(self.device)

        special_tokens = ["<snort>", "<rule>", "<alert>"]
        self.tokenizer.add_tokens(special_tokens)
        self.model.resize_token_embeddings(len(self.tokenizer))

    def prepare_data(self, data, test_size=0.3):
        data = replicate_critical_pairs(data)
        train_data, val_data = train_test_split(data, test_size=test_size, random_state=42)
        train_dataset = SnortRuleDataset(train_data, self.tokenizer)
        val_dataset = SnortRuleDataset(val_data, self.tokenizer)
        return train_dataset, val_dataset

    def train(self, train_dataset, val_dataset, output_dir="./snort-rule-model", num_epochs=10, batch_size=8, learning_rate=5e-5):
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            warmup_steps=50,
            weight_decay=0.01,
            logging_dir=f'{output_dir}/logs',
            logging_steps=10,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            save_total_limit=2,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            learning_rate=learning_rate,
            fp16=True,
            dataloader_pin_memory=True,
            gradient_accumulation_steps=1
        )

        data_collator = DataCollatorForSeq2Seq(
            tokenizer=self.tokenizer,
            model=self.model,
            padding=True
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator,
            tokenizer=self.tokenizer,
        )

        print("Starting training...")
        trainer.train()
        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)
        print(f"Model saved to {output_dir}")

    def generate_rule(self, input_text, max_length=256, num_beams=1, temperature=0.0):
        input_text = f"Generate Snort rule: {input_text.strip().lower()}"
        input_ids = self.tokenizer.encode(
            input_text,
            return_tensors="pt",
            max_length=128,
            truncation=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                max_length=max_length,
                num_beams=num_beams,
                temperature=temperature,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                early_stopping=True
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def load_model(self, model_path):
        self.model = T5ForConditionalGeneration.from_pretrained(model_path)
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)
        self.model.to(self.device)
        print(f"Model loaded from {model_path}")

def main():
    # Initialize generator
    generator = SnortRuleGenerator("t5-small")
    
    # Prepare datasets
    train_dataset, val_dataset = generator.prepare_data(training_data)
    
    # Train the model
    generator.train(
        train_dataset, 
        val_dataset, 
        num_epochs=15,  # Increase for better results
        batch_size=4,   # Adjust based on your GPU memory
        learning_rate=3e-4
    )
    return

def inference_example():
    """Example of using the trained model for inference"""
    generator = SnortRuleGenerator()
    
    # Load your trained model
    generator.load_model("./snort-rule-model")
    
    # Interactive generation
    while True:
        user_input = input("\nEnter security rule description (or 'quit' to exit): ")
        if user_input.lower() == 'quit':
            break
        
        rule = generator.generate_rule(user_input)
        print(f"Generated Snort Rule: {rule}")

if __name__ == "__main__":
    # Run training
    main()
    
    # Uncomment for interactive inference
    # inference_example()

# Additional utility functions for data preprocessing
def load_data_from_json(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

def augment_data(original_data):
    """Simple data augmentation techniques"""
    augmented = []
    
    synonyms = {
    "detect": ["monitor", "inspect", "capture", "check", "catch", "track"],
    "bidirectional": ["in both directions", "two-way", "both ways"],
    "ICMP": ["ping", "echo", "ICMP packets"],
    "HTTP GET": ["web GET", "GET method", "GET requests", "port 80 GET"],
    "HTTP POST": ["form submission"],
    "SSH brute force": ["SSH password cracking", "SSH dictionary", "SSH login attempts", "SSH authentication", "SSH credential stuffing", "SSH password guessing"],
    "FTP file uploads": ["FTP STOR command", "FTP uploads", "FTP upload actions", "FTP file send operations"],
    "FTP file downloads": ["FTP RETR command"],
    "DNS tunneling": ["DNS exfiltration", "large DNS packets", "suspicious port 53", "DNS covert channel", "abnormal DNS query sizes"],
    "SQL injection": ["SQL attacks", "UNION SELECT", "SQLi", "database injection", "blind SQL", "time-based SQL", "boolean-based SQL", "order-by injections"],
    "XSS": ["cross-site scripting", "script injections", "JavaScript injection", "web code injection", "DOM-based XSS", "reflected XSS", "stored XSS", "XSS filter bypass"],
    "port scan": ["network scanning", "SYN scan", "open port scanning", "reconnaissance", "connect scan", "stealth scan", "UDP scanning", "nmap scanning"],
    "HTTPS anomalies": ["SSL traffic", "TLS traffic", "encrypted web traffic", "SSL certificate issues", "TLS handshake anomalies"],
    "DoS": ["denial of service", "SYN flood", "server overload"],
    "DDoS": ["distributed DoS", "multi-source DoS", "amplification attacks", "volumetric DDoS"],
    "malware C2": ["command and control", "botnet C2", "trojan C&C", "beacon traffic", "heartbeat signals", "callback communication"],
    "ransomware": ["file locking", "cryptovirus", "encrypted file", "ransomware note", "WannaCry", "Locky", "CryptoLocker"],
    "cryptojacking": ["crypto mining", "bitcoin mining", "coinhive", "browser-based mining", "Monero", "WebAssembly", "CPU mining"],
    "IoT botnet": ["smart device botnet", "embedded system threats", "Mirai"],
    "phishing": ["fake login", "phishing page", "fraudulent login", "social engineering", "PayPal phishing", "bank phishing", "Microsoft phishing"],
    "web shell": ["PHP shell", "ASP shell", "JSP shell", "remote code", "malicious upload", "backdoor script"],
    "lateral movement": ["pivot", "admin share", "WMI", "PsExec", "internal spread", "remote access"],
    "privilege escalation": ["root access", "admin gain", "superuser access", "UAC bypass", "Windows escalation", "kernel exploit"],
    "data exfiltration": ["data leaks", "info theft", "FTP exfil", "HTTP exfil", "email leaks"],
    "keylogger": ["keystroke theft", "keyboard logger", "screen capture", "mouse tracking", "credential theft"],
    "steganography": ["hidden data", "image hiding", "large JPEG", "PNG", "MP3"],
    "MITM": ["man-in-the-middle", "SSL spoofing", "SSL stripping", "cert pinning bypass", "DNS poisoning"],
    "DNS poisoning": ["DNS spoofing", "domain poisoning", "DNS hijacking", "pharming"],
    "ARP spoofing": ["ARP poisoning", "MAC spoofing", "Ethernet spoofing", "network poisoning", "gratuitous ARP", "ARP flood"],
    "VLAN hopping": ["double tagging"],
    "command injection": ["OS command", "system command", "shell injection", "Windows command", "PowerShell", "blind injection"],
    "XXE": ["XML external entity", "XML parser", "XML injection", "entity expansion", "billion laughs"],
    "directory traversal": ["path traversal", "filesystem access", "dot-dot-slash", "upper directory", "Windows traversal"],
    "credential harvesting": ["auth data theft", "login harvesting", "credit card theft", "SSN leaks"],
    "RAT": ["remote access trojan", "NetBus", "SubSeven", "Back Orifice"],
    "MITM cert": ["self-signed", "cert spoof", "SSL stripping"]
}

    for item in original_data:
        augmented.append(item)
        
        # Create variations
        input_text = item["input"]
        for original, alternatives in synonyms.items():
            if original in input_text:
                for alt in alternatives:
                    new_input = input_text.replace(original, alt)
                    augmented.append({
                        "input": new_input,
                        "target": item["target"]
                    })
    
    return augmented


print("DONE!")

Using device: cuda


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Starting training...


Epoch,Training Loss,Validation Loss
1,0.1378,0.106465
2,0.0779,0.049047
3,0.0473,0.02668
4,0.0311,0.014034
5,0.0209,0.007685
6,0.0137,0.003761
7,0.008,0.001919
8,0.0064,0.001204
9,0.0059,0.000577
10,0.0051,0.000451


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


Model saved to ./snort-rule-model
Setup complete! Run main() to start training or use the SnortRuleGenerator class directly.


In [57]:
def test_single_input(model_path, input_text):
    generator = SnortRuleGenerator()
    generator.load_model(model_path)

    result = generator.generate_rule(input_text)
    return result
    print(f"\nInput: {input_text}")
    print(f"Generated Snort Rule: {result}")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Model loaded from ./snort-rule-model

Input: catch gratuitous ARP
Generated Snort Rule: alert arp any any -> any any (msg:"ARP Spoofing Attack"; threshold:type threshold, track by_src, count 5, seconds 60; sid:1000015; rev:1;)


In [70]:
with open("/kaggle/working/tr_eng.txt", "r", encoding="utf-8") as tr_eng:
    obj_list = [el.split(",") for el in tr_eng.readlines()]

In [72]:
with open("tr_eng.txt", "r", encoding="utf-8") as tr_eng:
    obj_list = [el.split(",") for el in tr_eng.readlines()]
def find_from_inputs(input:str):
    for el in obj_list:
        if input.lower() == el[0].lower():
            return el[1]

test_single_input("./snort-rule-model", find_from_inputs("web sitesi GET taleplerini tespit et"))
test_single_input("./snort-rule-model", find_from_inputs(user_input))

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Model loaded from ./snort-rule-model

Input: detect web GET requests

Generated Snort Rule: alert tcp any any -> any 80 (msg:"HTTP GET Request"; content:"GET"; http_method; sid:1000002; rev:1;)
