<a href="https://colab.research.google.com/github/Tier1Security/roberta_classifier/blob/main/Autoencoder_Training_Script.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import pandas as pd
from transformers import RobertaModel, RobertaTokenizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import numpy as np

# ==========================================
# 1. Model Definition (RoBERTa + Decoder)
# ==========================================
class RobertaAutoencoder(nn.Module):
    """RoBERTa encoder plus a small MLP that reconstructs its first-token embedding.

    The forward pass returns both the embedding and its reconstruction so the
    caller can compute a reconstruction error between them.
    """

    def __init__(self, model_name="roberta-base"):
        """Load the encoder and build a decoder matched to its hidden width.

        Args:
            model_name: Checkpoint name or path handed to
                ``RobertaModel.from_pretrained``.
        """
        super().__init__()
        # Load your fine-tuned encoder weights here
        self.encoder = RobertaModel.from_pretrained(model_name)

        # Take the embedding width from the loaded checkpoint rather than
        # assuming 768, so checkpoints with another hidden size do not fail
        # with a shape mismatch in the first Linear layer. For roberta-base
        # this is still 768, so existing state dicts keep the same shapes.
        hidden_size = self.encoder.config.hidden_size

        # The decoder: expands to 1024 units and projects back to the
        # embedding width in an attempt to reproduce the embedding.
        self.decoder = nn.Sequential(
            nn.Linear(hidden_size, 1024),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(1024, hidden_size),
        )

    def forward(self, input_ids, attention_mask):
        """Encode a batch and try to reconstruct the resulting embeddings.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            Tuple ``(embedding, reconstructed)``. ``embedding`` is the
            encoder's last hidden state at sequence position 0 (one vector
            per batch row); ``reconstructed`` is the decoder output for it.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 is the first (CLS-style) token; shape: [batch, hidden]
        embedding = outputs.last_hidden_state[:, 0, :]

        # Attempt reconstruction
        reconstructed = self.decoder(embedding)
        return embedding, reconstructed

# ==========================================
# 2. Dataset Handler
# ==========================================
class BaselineDataset(Dataset):
    """Map-style dataset that tokenizes the ``command`` column of a CSV file.

    Each item is tokenized on access; nothing is pre-computed or cached.
    """

    def __init__(self, csv_file, tokenizer, max_len=128):
        """Read the CSV and remember the tokenizer settings.

        Args:
            csv_file: Path or file-like object accepted by ``pd.read_csv``.
            tokenizer: Object exposing ``encode_plus``.
            max_len: Length every sequence is padded or truncated to.
        """
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.data = pd.read_csv(csv_file)

    def __len__(self):
        """Return the number of rows read from the CSV."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Tokenize the command stored at positional row ``idx``.

        Returns:
            Dict with 1-D ``input_ids`` and ``attention_mask`` tensors.
        """
        row = self.data.iloc[idx]
        # str() so non-string cell values are still accepted by the tokenizer.
        command_text = str(row['command'])

        tokenizer_options = {
            'add_special_tokens': True,
            'max_length': self.max_len,
            'padding': 'max_length',
            'truncation': True,
            'return_tensors': 'pt',
        }
        encoded = self.tokenizer.encode_plus(command_text, **tokenizer_options)

        # return_tensors='pt' adds a leading batch axis; flatten drops it.
        return {
            field: encoded[field].flatten()
            for field in ('input_ids', 'attention_mask')
        }

# ==========================================
# 3. Training Loop
# ==========================================
def train_anomaly_engine(
    csv_path,
    model_save_path="anomaly_engine.pt",
    *,
    model_name="roberta-base",
    epochs=3,
    batch_size=32,
    lr=1e-5,
):
    """Train the autoencoder on a CSV of commands and save it with a threshold.

    After training, the per-sample reconstruction error is measured over the
    same data and the anomaly threshold is set to ``mean + 3 * std`` of those
    errors. Model weights and threshold are written to ``model_save_path``.

    Args:
        csv_path: CSV file with a ``command`` column (read by BaselineDataset).
        model_save_path: Destination of the ``torch.save`` checkpoint.
        model_name: Checkpoint used for both the tokenizer and the encoder so
            the two cannot drift apart.
        epochs: Number of passes over the data.
        batch_size: Batch size for training and for the threshold pass.
        lr: AdamW learning rate.

    Raises:
        ValueError: If the CSV contains no rows.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = RobertaTokenizer.from_pretrained(model_name)

    dataset = BaselineDataset(csv_path, tokenizer)
    # Fail before the (expensive) model load, and with a clear message,
    # instead of erroring later inside the DataLoader / loss averaging.
    if len(dataset) == 0:
        raise ValueError(f"No training samples found in {csv_path!r}")

    model = RobertaAutoencoder(model_name).to(device)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.MSELoss()  # Loss between original and reconstructed embedding

    print(f"[*] Training Anomaly Engine on {len(dataset)} benign samples...")

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for batch in tqdm(loader):
            optimizer.zero_grad()

            input_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)

            emb, rec = model(input_ids, mask)

            # Detach the target: otherwise the gradient also flows into the
            # embedding being reconstructed, and the loss can be lowered by
            # moving/shrinking the embeddings rather than by reconstructing
            # them.
            # NOTE(review): the encoder still receives gradients through the
            # decoder input; consider freezing it if it must stay unchanged.
            loss = criterion(rec, emb.detach())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1} Loss: {total_loss/len(loader)}")

    # Calculate Anomaly Threshold (Z-Score Baseline)
    model.eval()  # disables dropout so the errors are deterministic
    errors = []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            emb, rec = model(input_ids, mask)
            # Row-wise MSE: one reconstruction error per sample
            error = torch.mean((emb - rec) ** 2, dim=1)
            errors.extend(error.cpu().numpy())

    # Store a plain Python float: a pickled NumPy scalar is not accepted by
    # torch.load's restricted (weights_only=True) unpickler unless it is
    # explicitly allow-listed, which would make the checkpoint awkward to load.
    threshold = float(np.mean(errors) + (3 * np.std(errors)))
    print(f"[+] Training Complete. Anomaly Threshold set to: {threshold}")

    torch.save({
        'model_state': model.state_dict(),
        'threshold': threshold
    }, model_save_path)

if __name__ == "__main__":
    # Script entry point. Training is left disabled on purpose: the call below
    # expects benign_baseline.csv to have been generated first. Uncomment it
    # once that file exists.
    # train_anomaly_engine("benign_baseline.csv")
    pass