# Installs
torch==1.9.0
transformers==4.9.2


In [None]:
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, AdamW
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from imblearn.under_sampling import RandomUnderSampler
from sklearn.metrics import precision_score, recall_score, f1_score, precision_recall_fscore_support
import numpy as np
import pandas as pd


class URLDataset(Dataset):
    def __init__(self, urls, labels, tokenizer, max_length):
        self.urls = urls
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.urls)

    def __getitem__(self, idx):
        url = str(self.urls[idx])
        label = self.labels[idx]
        encoding = self.tokenizer.encode_plus(
            url,
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }



In [None]:
# Report whether a GPU is visible, then pick the device everything runs on.
# (The former bare `torch.cuda.device(0)` call only built an unused context
# manager and had no effect, so it was dropped.)
print(torch.cuda.is_available())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# training parameters
# NOTE(review): the checkpoint path below mentions a 128 max_length; confirm 512 is intended here.
MAX_LENGTH = 512
BATCH_SIZE = 128
LEARNING_RATE = 2e-5
EPOCHS = 3

# pre-trained DistilBERT model and tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
model = DistilBertForSequenceClassification.from_pretrained('/kaggle/input/distilbert-pretrain-no-sample-cased/pytorch/distilbert-pretrain-1-epoch-no-sampling-128-max_length/1', num_labels=4).to(device)

# URLs and their labels come from two separate CSVs and are paired purely by
# row position, so their lengths must agree.
urls = pd.read_csv("/kaggle/input/malicious-phish/malicious_phish.csv")["url"]
labels = pd.read_csv("/kaggle/input/malicious-phish/feature_updated_dataset_y.csv")['type_val']
if len(urls) != len(labels):
    raise ValueError(f"URL/label row count mismatch: {len(urls)} urls vs {len(labels)} labels")

train_urls, test_urls, train_labels, test_labels = train_test_split(urls.values, labels.values, test_size=0.2, random_state=69)

# Undersample only the training split; the test split keeps its original class mix.
rus = RandomUnderSampler(random_state=69)
X_rus_train, y_rus_train = rus.fit_resample(train_urls.reshape(-1, 1), train_labels)
# fit_resample needs a 2-D X, so it returns shape (n, 1). Flatten back to a 1-D
# array of URL strings; otherwise each dataset item is a 1-element row whose
# str() is "['http://...']" rather than the URL itself.
X_rus_train = X_rus_train.ravel()

In [None]:
# X_rus_train comes from RandomUnderSampler with shape (n, 1); np.ravel makes it
# 1-D so each item is a plain URL string (it is a no-op on an already 1-D array).
train_dataset = URLDataset(np.ravel(X_rus_train), y_rus_train, tokenizer, MAX_LENGTH)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_dataset = URLDataset(test_urls, test_labels, tokenizer, MAX_LENGTH)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
# NOTE(review): not referenced by the visible training loop, which uses the
# model's own outputs.loss (computed when labels= is passed to the model).
criterion = torch.nn.CrossEntropyLoss()

In [None]:
from sklearn.metrics import confusion_matrix

def calc_FNR_accuracy(y_true, y_pred, num_classes=4):
    """Print, and return, one-vs-rest accuracy and false-negative rate per class.

    Args:
        y_true: true integer class labels in ``range(num_classes)``.
        y_pred: predicted integer class labels, aligned with ``y_true``.
        num_classes: number of classes; defaults to 4, as this notebook uses.

    Returns:
        dict mapping each class index to ``{"accuracy": ..., "fnr": ...}``.
        ``fnr`` is -1 when the class has no true samples (FN + TP == 0).
    """
    # Fix the label set so the matrix is always num_classes x num_classes and
    # row/column i means class i, even if a class never appears in y_true or y_pred.
    conf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
    total = np.sum(conf_matrix)
    results = {}
    for label_class in range(num_classes):
        TP = conf_matrix[label_class, label_class]
        # Rows are true classes: the rest of this row was predicted as something else.
        FN = np.sum(conf_matrix[label_class]) - TP
        # Cells outside this class's row and column: neither true nor predicted as it.
        TN = np.sum(np.delete(np.delete(conf_matrix, label_class, axis=0), label_class, axis=1))

        accuracy = (TP + TN) / total
        print("Accuracy for class", label_class, ":", accuracy)

        FNR = FN / (FN + TP) if (FN + TP) > 0 else -1
        print("FNR for class", label_class, ":", FNR)

        results[label_class] = {"accuracy": accuracy, "fnr": FNR}
    return results

In [None]:
def evaluate_model(val_loader, num_classes=4):
    """Run the global ``model`` over ``val_loader`` and print, then return, metrics.

    Uses the module-level ``model`` and ``device``. Prints overall accuracy and
    weighted precision/recall/F1, then per-class precision/recall/F1, then calls
    ``calc_FNR_accuracy`` for per-class accuracy and FNR.

    Args:
        val_loader: DataLoader yielding dicts with ``input_ids``,
            ``attention_mask`` and ``label`` tensors.
        num_classes: number of classes; defaults to 4, as this notebook uses.

    Returns:
        dict with overall ``accuracy``/``precision``/``recall``/``f1`` and
        ``per_class`` arrays of precision/recall/F1 indexed by class.
    """
    model.eval()
    val_predictions = []
    val_targets = []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            predicted = outputs.logits.argmax(dim=1)

            val_predictions.extend(predicted.cpu().numpy())
            # Targets are only needed for sklearn; no need to send them to the device.
            val_targets.extend(batch['label'].cpu().numpy())

    precision = precision_score(val_targets, val_predictions, average='weighted')
    recall = recall_score(val_targets, val_predictions, average='weighted')
    f1 = f1_score(val_targets, val_predictions, average='weighted')
    val_accuracy = accuracy_score(val_targets, val_predictions)
    print(f'OVERALL: Accuracy: {val_accuracy:.8f}, Precision: {precision:.8f}, Recall: {recall:.8f}, F1 Score: {f1:.8f}')

    # Pass labels= so the per-class arrays always have num_classes entries and
    # index i is class i, even when a class is missing from targets and predictions.
    class_labels = list(range(num_classes))
    class_test_precision, class_test_recall, class_test_f1, _ = precision_recall_fscore_support(
        val_targets, val_predictions, labels=class_labels
    )
    for i in class_labels:
        print(f'Class {i}:\tTest Precision: {class_test_precision[i]:.8f},\tTest Recall: {class_test_recall[i]:.8f},\tTest f1: {class_test_f1[i]:.8f}')
    calc_FNR_accuracy(val_targets, val_predictions)

    return {
        'accuracy': val_accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'per_class': {
            'precision': class_test_precision,
            'recall': class_test_recall,
            'f1': class_test_f1,
        },
    }

In [None]:
# NOTE(review): incremented inside the validation loop below, so this counts
# evaluation batches across all epochs; it is not read anywhere visible.
batchNo = 0
# Training loop
for epoch in range(EPOCHS):
    model.train()
    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Named batch_labels so the module-level `labels` Series loaded from the
        # CSV is not overwritten by a tensor.
        batch_labels = batch['label'].to(device)

        optimizer.zero_grad()

        # Passing labels= makes the model compute the classification loss itself.
        outputs = model(input_ids, attention_mask=attention_mask, labels=batch_labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

    # Per-epoch evaluation; note this uses test_loader (the held-out test split).
    model.eval()
    val_predictions = []
    val_targets = []
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            predicted = outputs.logits.argmax(dim=1)

            val_predictions.extend(predicted.cpu().numpy())
            # Targets are only consumed by sklearn, so they stay on the CPU.
            val_targets.extend(batch['label'].cpu().numpy())
            batchNo += 1

    val_accuracy = accuracy_score(val_targets, val_predictions)
    print(f'Epoch {epoch + 1}/{EPOCHS}, Validation Accuracy: {val_accuracy:.4f}')

# Save the trained model
model.save_pretrained("distilbert_split_then_undersample")

In [None]:
# Final report on the held-out test split with the trained model.
evaluate_model(val_loader=test_loader)