# In [None]:
import numpy as np
import os
import time
import random
import torch
import torch.nn as nn
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForCausalLM
import pickle
import math
from sklearn.utils import resample
from sklearn.utils.class_weight import compute_class_weight


def set_seeds(seed=42):
    """Seed Python's, NumPy's and PyTorch's global RNGs with ``seed``.

    When CUDA is available, every CUDA device is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
class LLMGraphTransformer(nn.Module):
    """Edge classifier fusing a causal language model's output with per-edge features.

    ``forward`` tokenizes the given texts and runs them through the causal LM.
    It takes the vocabulary logits at the last real (non-padding) token of each
    sequence and projects them to 64 dims with ``text_fc``. The edge-feature
    vectors are projected to 64 dims with ``edge_fc`` (plus dropout). The two
    embeddings are concatenated and mapped to class logits by ``classifier``.

    Args:
        model_name: Hugging Face model id used for both tokenizer and LM.
        device: Device the LM and the added linear layers are moved to.
        edge_dim: Width of each edge-feature vector passed to ``forward``
            (default 77, the width this class originally hard-coded).
        num_classes: Number of output classes (default 7).
    """

    def __init__(self, model_name="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", device="cpu",
                 edge_dim=77, num_classes=7):
        super().__init__()
        self.device = device

        # Load the tokenizer and causal LM.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)

        # Batched inputs need a pad token; reuse EOS when the tokenizer has none.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Decoder-only models must be left-padded for batched generate(); with
        # right padding the continuation would follow the pad tokens.
        self.tokenizer.padding_side = "left"
        self.model.resize_token_embeddings(len(self.tokenizer))

        # NOTE(review): never applied in forward(); kept so attribute access elsewhere keeps working.
        self.dropout = nn.Dropout(p=0.2)

        # Project edge features to a 64-dim embedding.
        self.edge_fc = nn.Linear(edge_dim, 64).to(self.device)
        self.edge_dropout = nn.Dropout(p=0.2)

        # Project vocabulary-sized LM logits to the same 64-dim width.
        self.text_fc = nn.Linear(self.model.config.vocab_size, 64).to(self.device)

        # Concatenated (64 text + 64 edge) embedding -> class logits.
        self.classifier = nn.Linear(128, num_classes).to(self.device)

    @staticmethod
    def _last_token_index(attention_mask):
        """Return, per row, the index of the last position whose mask entry is non-zero.

        Works for both left- and right-padded batches. A row with no non-zero
        entry maps to index 0.
        """
        positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
        return (attention_mask.ne(0).long() * positions).max(dim=1).values

    @staticmethod
    def _build_questions(graph_data, labels):
        """Turn ``graph_data`` into one relationship question per (node, neighbor) pair.

        ``node`` is the position in ``graph_data``. An entry that is a list, set
        or ndarray yields one question per element; any other entry is treated
        as a single neighbor.
        """
        choices = ", ".join(labels)
        questions = []
        for node, neighbors in enumerate(graph_data):
            targets = neighbors if isinstance(neighbors, (list, set, np.ndarray)) else [neighbors]
            questions.extend(
                f"What is the relationship between Node {node} and Node {target}? Choices: {choices}."
                for target in targets
            )
        return questions

    def forward(self, batch_text, edge_features):
        """Return class logits of shape (batch_size, num_classes).

        Args:
            batch_text: List of strings, one per edge.
            edge_features: Tensor of shape (batch_size, edge_dim) on ``self.device``
                (assumed; it is fed straight into ``edge_fc``).
        """
        inputs = self.tokenizer(batch_text, return_tensors="pt", padding=True, truncation=True,
                                max_length=512).to(self.device)
        outputs = self.model(inputs['input_ids'], attention_mask=inputs['attention_mask'])

        # Pick the logits at each sequence's last real token, not blindly at
        # position -1, which is a pad token for shorter right-padded sequences.
        last_idx = self._last_token_index(inputs['attention_mask'])
        rows = torch.arange(outputs.logits.shape[0], device=outputs.logits.device)
        text_logits = outputs.logits[rows, last_idx]  # (batch_size, vocab_size)
        text_emb = self.text_fc(text_logits)          # (batch_size, 64)

        edge_emb = self.edge_dropout(self.edge_fc(edge_features))  # (batch_size, 64)

        combined = torch.cat((text_emb, edge_emb), dim=1)  # (batch_size, 128)
        return self.classifier(combined)                   # (batch_size, num_classes)

    def generate_text(self, graph_data, labels, max_new_tokens=50):
        """Build relationship questions from ``graph_data`` and sample an LM continuation for each.

        Sampling uses temperature 0.7 and top-p 0.9, so results are stochastic
        unless the global torch seed is fixed.

        Returns:
            List of decoded ``generate`` outputs (special tokens skipped), one per
            question; an empty list when ``graph_data`` yields no questions.
        """
        batch_text = self._build_questions(graph_data, labels)
        if not batch_text:
            return []

        inputs = self.tokenizer(batch_text, return_tensors="pt", padding=True, truncation=True,
                                max_length=512).to(self.device)
        outputs = self.model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
        return [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]


def balance_data(data, labels, n_samples_per_label, seed=42):
    """Resample ``data`` so every distinct label appears exactly ``n_samples_per_label`` times.

    Labels with fewer examples than the target are oversampled with replacement;
    larger labels are subsampled without replacement. Sampling uses a dedicated
    NumPy generator seeded with ``seed``, so results are reproducible and the
    global RNG state is neither read nor modified.

    Args:
        data: Array-like indexed in parallel with ``labels``.
        labels: Array-like of per-example labels.
        n_samples_per_label: Target number of examples per distinct label.
        seed: Seed for the local random generator (default 42).

    Returns:
        ``(balanced_data, balanced_labels)``, grouped by label in ascending order.
    """
    data = np.asarray(data)
    labels = np.asarray(labels)
    # The old code seeded Python's `random`, which np.random.choice never reads,
    # so the sampling was not actually reproducible. Use a local generator instead.
    rng = np.random.default_rng(seed)

    sampled_groups = []
    for label in np.unique(labels):
        label_indices = np.flatnonzero(labels == label)
        # Oversample (replace=True) only when this label has too few examples.
        oversample = len(label_indices) < n_samples_per_label
        sampled_groups.append(rng.choice(label_indices, size=n_samples_per_label, replace=oversample))

    if not sampled_groups:
        # np.concatenate([]) would raise; return empty, correctly-typed slices.
        return data[:0], labels[:0]

    balanced_indices = np.concatenate(sampled_groups)
    return data[balanced_indices], labels[balanced_indices]


def process_llm_output(llm_output):
    """Map free-form LLM text to a class index via case-insensitive keyword search.

    Returns the index of the first matching keyword, or -1 if none matches.
    Indices: Benign 0, BruteForce 1, DoS 2, DDoS 3, Web 4, Bot 5, Infilteration 6.
    """
    text = llm_output.lower().strip()
    # Keywords must be lower-case because `text` is lower-cased; the previous
    # mixed-case keys could never match. 'ddos' is tested before 'dos' because
    # 'dos' is a substring of 'ddos'.
    keyword_order = (
        ("benign", 0),
        ("bruteforce", 1),
        ("ddos", 3),
        ("dos", 2),
        ("web", 4),
        ("bot", 5),
        ("infilteration", 6),
    )
    for keyword, index in keyword_order:
        if keyword in text:
            return index
    return -1


def save_data_splits(train, val, test, train_labels, val_labels, test_labels, path="data_splits/ces-cic"):
    """Pickle each split as a ``(data, labels)`` tuple into ``path``.

    Writes ``train.pkl``, ``val.pkl`` and ``test.pkl`` (creating ``path`` if
    needed), then prints a confirmation line.
    """
    os.makedirs(path, exist_ok=True)
    payloads = {
        "train.pkl": (train, train_labels),
        "val.pkl": (val, val_labels),
        "test.pkl": (test, test_labels),
    }
    for filename, payload in payloads.items():
        with open(os.path.join(path, filename), "wb") as handle:
            pickle.dump(payload, handle)
    print("Data splits and labels saved successfully.")

def load_data_splits(path="data_splits/ces-cic"):
    """Load the pickled train/val/test splits from ``path``.

    Returns:
        ``(train, val, test, train_labels, val_labels, test_labels)``.
    """
    def _read(filename):
        # Each file holds a (data, labels) pair; the files are local artifacts
        # written by save_data_splits, not untrusted input.
        with open(os.path.join(path, filename), "rb") as handle:
            return pickle.load(handle)

    (train, train_labels), (val, val_labels), (test, test_labels) = (
        _read(filename) for filename in ("train.pkl", "val.pkl", "test.pkl")
    )
    print("Data splits and labels loaded successfully.")
    return train, val, test, train_labels, val_labels, test_labels

def _print_label_distribution(title, split_labels):
    """Print ``title`` and then a ``{label: count}`` dict for ``split_labels``."""
    print(title)
    values, counts = np.unique(split_labels, return_counts=True)
    print(dict(zip(values, counts)))


def _load_or_create_splits(balanced_data, balanced_labels, split_dir):
    """Return cached splits from ``split_dir``, or create, save and return them.

    New splits are stratified 90/10 (train+val vs test), then 90/10 again
    (train vs val), with ``random_state=42``.
    """
    # Check the same directory the splits are saved to; previously the check
    # looked at "data_splits/train.pkl", so the cache was never used.
    if os.path.exists(os.path.join(split_dir, "train.pkl")):
        return load_data_splits(path=split_dir)

    train_val, test, train_val_labels, test_labels = train_test_split(
        balanced_data, balanced_labels, test_size=0.1, stratify=balanced_labels, random_state=42
    )
    train, val, train_labels, val_labels = train_test_split(
        train_val, train_val_labels, test_size=0.1, stratify=train_val_labels, random_state=42
    )
    save_data_splits(train, val, test, train_labels, val_labels, test_labels, path=split_dir)
    return train, val, test, train_labels, val_labels, test_labels


def fit(args):
    """Train an LLMGraphTransformer on a label-balanced edge subset, then evaluate and save it.

    Required keys in ``args``:
        dataset: sub-directory of ``datasets/`` holding ``edge_feat_scaled.npy``,
            ``label_mul.npy``, ``adj_random.npy`` and ``adj_random_list.dict``.
        binary: must be present; not otherwise used (multiclass labels are loaded).

    Optional keys (defaults reproduce the original hard-coded values):
        epochs (10), batch_size (10), max_batches_per_epoch (180),
        samples_per_label (120), split_dir ("data_splits/ces-cic").

    Returns:
        Dict with per-batch ``times`` and ``train_scores``, per-epoch
        ``val_scores`` and the final ``test_score`` (weighted F1).
    """
    data = args["dataset"]
    binary = args["binary"]  # noqa: F841 - read so a missing key still fails fast
    epochs = args.get("epochs", 10)
    batch_size = args.get("batch_size", 10)
    max_batches = args.get("max_batches_per_epoch", 180)
    samples_per_label = args.get("samples_per_label", 120)
    split_dir = args.get("split_dir", "data_splits/ces-cic")

    path = "datasets/" + data
    if not path.endswith('/'):
        path += '/'

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    edge_feat = np.load(path + "edge_feat_scaled.npy", allow_pickle=True)
    edge_feat = torch.tensor(edge_feat, dtype=torch.float, device=device)

    # Multiclass labels.
    label = np.load(path + "label_mul.npy", allow_pickle=True)
    label = torch.tensor(label, dtype=torch.long, device=device)
    # NOTE(review): the graph structure is loaded but not used by this model;
    # kept so a missing file still fails as before.
    adj = np.load(path + "adj_random.npy", allow_pickle=True)  # noqa: F841
    with open(path + 'adj_random_list.dict', 'rb') as file:
        adj_lists = pickle.load(file)  # noqa: F841

    llm_graph_transformer = LLMGraphTransformer(model_name="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", device=device)

    # Relationship types offered to the LLM as answer choices.
    labels = ['Benign', 'BruteForce', 'DoS', 'DDoS', 'Web', 'Bot', 'Infilteration']
    optimizer = torch.optim.Adam(llm_graph_transformer.parameters(), lr=1e-5)

    label_cpu = label.cpu().numpy()
    balanced_data, balanced_labels = balance_data(
        np.arange(len(edge_feat)), label_cpu, n_samples_per_label=samples_per_label
    )

    train, val, test, train_labels, val_labels, test_labels = _load_or_create_splits(
        balanced_data, balanced_labels, split_dir
    )

    print(len(train), len(val), len(test))
    _print_label_distribution("Label distribution in Train Set:", train_labels)
    _print_label_distribution("Label distribution in Validation Set:", val_labels)
    _print_label_distribution("Label distribution in Test Set:", test_labels)

    # Class weights need the training labels, so they can only be computed after
    # the split exists (previously `train_labels` was read before assignment).
    class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
    class_weights_tensor = torch.tensor(class_weights, dtype=torch.float, device=device)
    loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)

    times = []
    trainscores = []
    valscores = []
    num_batches = len(train) // batch_size  # a trailing partial batch is dropped
    epoch = -1  # stays -1 in the saved checkpoint if epochs == 0

    for epoch in range(epochs):
        print("Epoch: ", epoch)
        random.shuffle(train)

        print(f"Training data size: {len(train)}")
        print(f"Number of batches: {num_batches}")

        for batch in range(min(num_batches, max_batches)):
            batch_edges = train[batch_size * batch:batch_size * (batch + 1)]
            start_time = time.time()

            batch_text = llm_graph_transformer.generate_text(batch_edges, labels, max_new_tokens=10)
            logits = llm_graph_transformer(batch_text, edge_feat[batch_edges])
            batch_labels = label[batch_edges]

            loss = loss_fn(logits, batch_labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # "acc" is reported as weighted F1.
            predicted_labels = torch.argmax(logits, dim=-1)
            acc_train = f1_score(label_cpu[batch_edges], predicted_labels.cpu().numpy(), average="weighted")

            end_time = time.time()
            times.append(end_time - start_time)
            trainscores.append(acc_train)

            print(f'batch: {batch + 1:03d}, loss_train: {loss.item():.4f}, acc_train: {acc_train:.4f}, time: {end_time - start_time:.4f}s')

        print(f"Validation after epoch {epoch}:")
        val_acc, val_loss, val_output = predict_(llm_graph_transformer, label, loss_fn, val, device, edge_feat)
        print(f"Validation set results: loss= {val_loss:.4f}, accuracy= {val_acc:.4f}, label acc= {f1_score(label_cpu[val], val_output, average=None)}")
        valscores.append(val_acc)

    acc_test, loss_test, predict_output = predict_(llm_graph_transformer, label, loss_fn, test, device, edge_feat)
    print(f"Test set results: loss= {loss_test:.4f}, accuracy= {acc_test:.4f}, label acc= {f1_score(label_cpu[test], predict_output, average=None)}")
    save_model(llm_graph_transformer, optimizer, epoch)

    return {"times": times, "train_scores": trainscores, "val_scores": valscores, "test_score": acc_test}



def save_model(model, optimizer, epoch, path="llm_w_edgefeat.pth"):
    """Save a checkpoint to ``model/<YYYYmmdd-HHMMSS>_<path>`` and return that path.

    The checkpoint dict holds ``epoch``, ``model_state_dict`` and
    ``optimizer_state_dict``. The ``model/`` directory is created if missing.
    """
    current_time = time.strftime("%Y%m%d-%H%M%S")
    # torch.save does not create parent directories; without this, the first
    # save on a fresh checkout fails with FileNotFoundError.
    os.makedirs("model", exist_ok=True)
    path = f"model/{current_time}_{path}"

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, path)

    print(f"Model saved to {path}")
    return path

def predict_(model, label, loss_fn, data_idx, device, edge_feat, batch_size=10):
    """Evaluate ``model`` on the edges indexed by ``data_idx``.

    Runs with every submodule in eval mode (dropout disabled) and under
    ``torch.no_grad()``, so no autograd graph is built. Each submodule's
    train/eval flag is restored afterwards, even if an error occurs. The text
    prompts come from ``model.generate_text``, which samples, so results may
    vary between calls unless the torch seed is fixed.

    Args:
        model: Module exposing ``generate_text(edges, labels, max_new_tokens=...)``
            and ``forward(batch_text, edge_features)``.
        label: Tensor of integer class labels indexed by edge id.
        loss_fn: Loss applied to ``(logits, labels)`` per batch.
        data_idx: Array of edge ids to evaluate.
        device: Device the logits and labels are moved to.
        edge_feat: Tensor of edge features indexed by edge id.
        batch_size: Edges per forward pass (default 10).

    Returns:
        ``(weighted F1, mean per-batch loss, predicted class indices in data_idx order)``.

    Raises:
        ValueError: if ``data_idx`` is empty.
    """
    if len(data_idx) == 0:
        raise ValueError("predict_ needs at least one edge index")

    labels = ['Benign', 'BruteForce', 'DoS', 'DDoS', 'Web', 'Bot', 'Infilteration']
    predict_output = []
    loss = 0.0
    num_batches = math.ceil(len(data_idx) / batch_size)

    # Snapshot each submodule's mode: model.train(True) afterwards would also
    # flip submodules (e.g. the wrapped LM) that the caller kept in eval mode.
    previous_modes = {module: module.training for module in model.modules()}
    model.eval()
    try:
        with torch.no_grad():
            for batch in range(num_batches):
                batch_edges = data_idx[batch_size * batch:batch_size * (batch + 1)]

                batch_text = model.generate_text(batch_edges, labels, max_new_tokens=10)
                logits = model(batch_text, edge_feat[batch_edges]).to(device)
                batch_labels = label[batch_edges].to(device)

                loss += loss_fn(logits, batch_labels).item()
                predict_output.extend(torch.argmax(logits, dim=-1).cpu().numpy())
    finally:
        for module, was_training in previous_modes.items():
            module.training = was_training

    loss /= num_batches

    # "acc" is reported as weighted F1 over the whole evaluated set.
    acc = f1_score(label.cpu().numpy()[data_idx], predict_output, average="weighted")
    return acc, loss, predict_output


if __name__ == '__main__':
    # Script entry point: seed every RNG, then train/evaluate on CSE-CIC.
    run_args = {
        "dataset": "CSE-CIC",
        "binary": False,
        "residual": True,
    }
    set_seeds(42)
    fit(run_args)