# Setup

In [None]:
import warnings
from sklearn.exceptions import UndefinedMetricWarning
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

import os
import torch
import torch.distributed
import numpy as np
import utils
import random
from dataclasses import field, dataclass
from datasets.distributed import split_dataset_by_node
from typing import Optional
from copy import deepcopy
from torchinfo import summary
from torch.distributed.elastic.multiprocessing.errors import record

from transformers import (
    EvalPrediction,
    HfArgumentParser,
    TrainingArguments,
    EarlyStoppingCallback,
)

from sklearn.metrics import (
    f1_score,
    accuracy_score,
    precision_score,
    recall_score,
    top_k_accuracy_score,
    classification_report, confusion_matrix
)

from NetFoundDataCollator import DataCollatorForFlowClassification
from NetFoundModels import NetfoundFinetuningModel, NetfoundNoPTM
from NetFoundTrainer import NetfoundTrainer
from NetfoundConfig import NetfoundConfig, NetFoundTCPOptionsConfig, NetFoundLarge
from NetfoundTokenizer import NetFoundTokenizer

from utils import ModelArguments, CommonDataTrainingArguments, freeze, verify_checkpoint, \
    load_train_test_datasets, get_90_percent_cpu_count, get_logger, init_tbwriter, update_deepspeed_config, \
    LearningRateLogCallback
    
# Seed every RNG this notebook relies on so runs are reproducible. Seeding only
# Python's `random` (as before) left NumPy and torch unseeded; torch's RNG is
# presumably what initialises any newly created fine-tuning layers loaded below.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
logger = get_logger(name=__name__)

# Functions

In [None]:
@dataclass
class FineTuningDataTrainingArguments(CommonDataTrainingArguments):
    """
    Arguments pertaining to what data we are going to input our model for fine-tuning and eval.

    Extends ``CommonDataTrainingArguments`` (from ``utils``) with the
    fine-tuning-specific options below; parsed by ``HfArgumentParser``.
    """

    # Number of target classes; copied onto ``config.num_labels``. Annotated
    # Optional because it defaults to None when not given on the command line.
    num_labels: Optional[int] = field(metadata={"help": "number of classes in the datasets"}, default=None)
    # Copied onto ``config.problem_type`` (e.g. "single_label_classification").
    problem_type: Optional[str] = field(
        default=None,
        metadata={"help": "Override regression or classification task"},
    )
    # Copied onto the training tokenizer's config as ``p``; the test-split
    # tokenizer's config does not receive it. Default written as a float to
    # match the annotation.
    p_val: float = field(
        default=0.0,
        metadata={
            "help": "noise rate"
        },
    )
    # When True, hidden size, layer count and head count are replaced by the
    # values of the ``NetFoundLarge`` preset.
    netfound_large: bool = field(
        default=False,
        metadata={
            "help": "Use the large configuration for netFound model"
        },
    )

In [None]:
def regression_metrics(p: "EvalPrediction"):
    """Compute the mean absolute error of regression predictions.

    Args:
        p: evaluation output exposing ``predictions`` (an array, or a tuple
            whose first element is the prediction array) and ``label_ids``.

    Returns:
        ``{"loss": mae}``. NOTE(review): if used as a HF Trainer
        ``compute_metrics``, the Trainer prefixes keys with ``eval_``, so this
        key may shadow the reported evaluation loss -- kept for compatibility.
    """
    logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    # Flatten both sides: an (N, 1) prediction array minus (N,) labels would
    # otherwise broadcast to an (N, N) matrix and produce a meaningless mean.
    preds = np.asarray(logits, dtype=float).reshape(-1)
    # Keep targets as floats; casting to int truncated fractional labels.
    labels = np.asarray(p.label_ids, dtype=float).reshape(-1)
    return {"loss": np.mean(np.absolute(preds - labels))}

In [None]:
def classif_metrics(p: EvalPrediction, num_classes):
    """Compute weighted classification metrics and log diagnostic reports.

    Logs (at WARNING level) a per-class classification report, the confusion
    matrix, and top-k accuracy for each k in (3, 5, 10) that is smaller than
    ``num_classes``.

    Args:
        p: evaluation output; ``predictions`` holds per-class scores (or a
            tuple whose first element does) and ``label_ids`` the targets.
        num_classes: number of classes; also defines the label range passed
            to ``top_k_accuracy_score``.

    Returns:
        dict with ``weighted_f1``, ``accuracy``, ``weighted_prec`` and
        ``weighted_recall``.
    """
    logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    label_ids = p.label_ids.astype(int)
    # Hard predictions, computed once and shared by every metric below.
    preds = logits.argmax(axis=1)

    weighted_f1 = f1_score(y_true=label_ids, y_pred=preds, average="weighted", zero_division=0)
    weighted_prec = precision_score(y_true=label_ids, y_pred=preds, average="weighted", zero_division=0)
    weighted_recall = recall_score(y_true=label_ids, y_pred=preds, average="weighted", zero_division=0)
    accuracy = accuracy_score(y_true=label_ids, y_pred=preds)

    logger.warning(classification_report(label_ids, preds, digits=5))
    logger.warning(confusion_matrix(label_ids, preds))
    # Top-k accuracy is only informative when k is below the number of classes.
    for k in (3, 5, 10):
        if num_classes > k:
            logger.warning(f"top{k}:{top_k_accuracy_score(label_ids, logits, k=k, labels=np.arange(num_classes))}")

    return {
        "weighted_f1": weighted_f1,
        "accuracy": accuracy,
        # Was "weighted_prec: " (stray colon and space in the metric name).
        "weighted_prec": weighted_prec,
        "weighted_recall": weighted_recall,
    }

# Load Model and Dataset

In [None]:
# Arguments are supplied inline (instead of from sys.argv) so the notebook can
# be run cell by cell; the absolute paths are machine-specific.
finetune_cli_args = [
    "--train_dir", r"D:\AllStuff\MTech-I_year\SEMIV\COL867\Project\netFound_original\data\test\finetuning\final\combined",
    "--model_name_or_path", r"D:\AllStuff\MTech-I_year\SEMIV\COL867\Project\netFound_original\models\pretraining_original",
    "--output_dir", r"D:\AllStuff\MTech-I_year\SEMIV\COL867\Project\netFound_original\models\finetuning_original_6class",
    "--report_to", "tensorboard",
    "--overwrite_output_dir",
    "--save_safetensors", "false",
    "--do_train",
    "--do_eval",
    "--eval_strategy", "epoch",
    "--save_strategy", "epoch",
    "--learning_rate", "0.01",
    "--num_train_epochs", "1",
    "--problem_type", "single_label_classification",
    "--num_labels", "6",
    "--load_best_model_at_end",
    "--netfound_large", "True",
]
parser = HfArgumentParser((ModelArguments, FineTuningDataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(args=finetune_cli_args)
# Level 10 (DEBUG) is forced here instead of using
# training_args.get_process_log_level().
utils.LOGGING_LEVEL = 10
logger.setLevel(10)

In [None]:
# Record the fully parsed configuration in the run log (lazy %-formatting
# renders the same text as the previous f-strings).
logger.info("model_args: %s", model_args)
logger.info("data_args: %s", data_args)
logger.info("training_args: %s", training_args)

In [None]:
# Load the train/test splits via utils.load_train_test_datasets, driven by
# data_args (e.g. --train_dir); both splits are tokenized further below.
train_dataset, test_dataset = load_train_test_datasets(logger, data_args)

In [None]:
# Notebook display: number of examples in the train and test splits.
len(train_dataset), len(test_dataset)

In [None]:
# Pick the config *class* (instantiated in the next cell): the TCP-options
# variant when data_args.tcpoptions is set, the base config otherwise.
if data_args.tcpoptions:
    config = NetFoundTCPOptionsConfig
else:
    config = NetfoundConfig

In [None]:
# Instantiate the chosen config class from the model/data CLI arguments;
# `config` is rebound from the class to the instance.
config = config(
    hidden_size=model_args.hidden_size,
    num_hidden_layers=model_args.num_hidden_layers,
    num_attention_heads=model_args.num_attention_heads,
    flat=data_args.flat,
    no_meta=data_args.no_meta,
)

In [None]:
# With --netfound_large, take the architecture size from the NetFoundLarge
# preset. The preset is built once and reused, rather than constructing a
# fresh NetFoundLarge() for each attribute read.
if data_args.netfound_large:
    large_preset = NetFoundLarge()
    config.hidden_size = large_preset.hidden_size
    config.num_hidden_layers = large_preset.num_hidden_layers
    config.num_attention_heads = large_preset.num_attention_heads

In [None]:
# Switch the shared config to fine-tuning mode and attach the task settings
# parsed from the command line.
config.pretraining = False
config.num_labels = data_args.num_labels
config.problem_type = data_args.problem_type
# Tokenizer for the test split: built from `config` itself, which does not get
# the `p` / `limit_bursts` overrides applied to `training_config` below.
testingTokenizer = NetFoundTokenizer(config=config)

In [None]:
# Training tokenizer: deep-copy the config so the noise rate (`p`, from
# --p_val) and `limit_bursts` only affect the training split, leaving
# `config` and testingTokenizer unchanged.
training_config = deepcopy(config)
training_config.p = data_args.p_val
training_config.limit_bursts = data_args.limit_bursts
trainingTokenizer = NetFoundTokenizer(config=training_config)
# NOTE(review): not referenced in any visible cell.
additionalFields = None

In [None]:
# Keyword arguments shared by both Dataset.map calls: tokenize in batches.
params = dict(batched=True)

In [None]:
# Tokenize both splits in batches: the training split uses trainingTokenizer
# (with the p / limit_bursts overrides), the test split uses testingTokenizer.
train_dataset = train_dataset.map(function=trainingTokenizer, **params)
test_dataset = test_dataset.map(function=testingTokenizer, **params)

In [None]:
# Collator that batches tokenized examples (used on a single example below);
# the exact role of max_burst_length is defined in NetFoundDataCollator.
data_collator = DataCollatorForFlowClassification(config.max_burst_length)

In [None]:
# Surface which checkpoint is being fine-tuned (lazy %-formatting, same text).
logger.warning("Using weights from %s", model_args.model_name_or_path)

In [None]:
# Load the checkpoint at model_name_or_path into the fine-tuning model class,
# using the config assembled in the cells above.
original_model = NetfoundFinetuningModel.from_pretrained(model_args.model_name_or_path, config=config)

In [None]:
# Apply utils.freeze as configured by model_args; presumably freezes some or all
# pre-trained weights -- confirm in utils.freeze.
model = freeze(original_model, model_args)

# Infer Model

In [None]:
# Take the first tokenized training example for a sample forward pass.
# next(iter(...)) replaces the for/break idiom; unlike that idiom, an empty
# dataset now fails here with a clear message instead of leaving `x` undefined.
try:
    x = next(iter(train_dataset))
except StopIteration:
    raise ValueError("train_dataset is empty; cannot run a sample forward pass") from None

In [None]:
# Collate the single example into a batch of size one for the forward pass below.
px = data_collator([x])

In [None]:
# Sample inference pass on the collated example. eval() disables train-only
# behaviour such as dropout (if the model has any), and no_grad() avoids
# building an autograd graph, which is not needed for inspection.
model_inputs = {
    name: px[name]
    for name in (
        "labels", "protocol", "flow_duration", "bytes", "iats",
        "input_ids", "attention_mask", "direction", "pkt_count",
    )
}
model.eval()
with torch.no_grad():
    py = model(**model_inputs)

In [None]:
# Notebook display: the model output from the sample forward pass.
py