# In [None]:
import torch
import pandas as pd
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

# Static lookup table: CVE identifier -> MITRE ATT&CK technique label.
# Insertion order matters to any consumer that returns the first hit while
# iterating the dict, so the pairs are listed in their original order.
# NOTE(review): the labels attached to T1210 and T1072 do not appear to match
# the official ATT&CK technique names for those IDs -- confirm against
# attack.mitre.org before relying on them.
_CVE_TECHNIQUE_PAIRS = (
    ("CVE-2021-26855", "T1190 - Exploit Public-Facing Application"),
    ("CVE-2021-26857", "T1210 - Remote Code Execution"),
    ("CVE-2021-26858", "T1072 - Remote Services"),
    ("CVE-2021-27065", "T1203 - Exploitation for Client Execution"),
)
MITRE_MAPPING = dict(_CVE_TECHNIQUE_PAIRS)

def extract_mitre_ttp(text):
    """Return the ATT&CK label of the first mapped CVE mentioned in *text*.

    The CVE identifiers in ``MITRE_MAPPING`` are tried in the mapping's
    iteration order and the label of the first one found in *text* is
    returned.

    Args:
        text: Free text to scan. Non-string values (``None``, a float NaN
            from a CSV cell, ...) are treated as containing no CVE.

    Returns:
        The mapped technique label, or the string
        ``"No MITRE ATT&CK mapping found"`` when nothing matches.
    """
    no_match = "No MITRE ATT&CK mapping found"

    # `cve in text` raises TypeError on non-iterables such as NaN, which is
    # what an empty CSV cell becomes; report "no mapping" instead of crashing.
    if not isinstance(text, str):
        return no_match

    for cve, ttp in MITRE_MAPPING.items():
        start = text.find(cve)
        while start != -1:
            end = start + len(cve)
            # CVE sequence numbers have variable length, so an occurrence
            # followed by another digit (e.g. "CVE-2021-268550") is a
            # different identifier and must not count as a hit.
            if not text[end:end + 1].isdigit():
                return ttp
            start = text.find(cve, start + 1)
    return no_match

# Read the train and test splits from CSV files in the working directory.
train_path, test_path = "df_train.csv", "df_test.csv"
df_train = pd.read_csv(filepath_or_buffer=train_path)
df_test = pd.read_csv(filepath_or_buffer=test_path)

# Define dataset class
class ThreatIntelDataset(Dataset):
    """Dataset that tokenizes one text per item and pairs it with its label.

    Args:
        texts: Indexable collection of texts. Objects exposing ``.tolist()``
            (pandas Series, NumPy arrays, ...) are converted to plain lists so
            that items are always looked up by position.
        labels: Indexable collection of labels, aligned with *texts*; handled
            the same way as *texts*.
        tokenizer: Callable invoked as
            ``tokenizer(text, padding='max_length', truncation=True,
            max_length=..., return_tensors="pt")`` whose result provides
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_length: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, texts, labels, tokenizer, max_length=256):
        self.texts = self._to_positional(texts)
        self.labels = self._to_positional(labels)
        self.tokenizer = tokenizer
        self.max_length = max_length

    @staticmethod
    def _to_positional(values):
        """Return *values* as a list when it offers ``.tolist()``, else unchanged.

        ``series[idx]`` on a pandas Series is a label lookup, so a Series whose
        index is not 0..n-1 (after filtering, shuffling or splitting) would
        raise ``KeyError`` or return the wrong row for a positional ``idx``.
        """
        to_list = getattr(values, "tolist", None)
        return to_list() if callable(to_list) else values

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize the text at position *idx* and return it with its label.

        Returns:
            dict with ``'input_ids'`` and ``'attention_mask'`` (the tokenizer
            outputs with dimension 0 squeezed) and ``'label'`` as a
            ``torch.long`` tensor.
        """
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Presumably the tokenizer returns a leading batch axis of size 1 for a
        # single text; squeeze(0) drops it so a DataLoader can re-batch items.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(label, dtype=torch.long),
        }
