# Load libraries

In [None]:
!pip install -U datasets huggingface_hub transformers[torch] evaluate --quiet

from datasets import load_dataset, concatenate_datasets, Dataset, ClassLabel, load_from_disk, load_metric
from transformers import AutoModelForSequenceClassification, DataCollatorWithPadding, TrainingArguments, Trainer, AutoTokenizer
import numpy as np
import matplotlib.pyplot as plt
import transformers
import pandas as pd

In [None]:
# Show which version of the `evaluate` package is installed.
!pip list | grep evaluate

In [None]:
# Google Drive mount (Colab), currently disabled by wrapping it in a string.
# All dataset/model paths in this notebook are under /content/drive/MyDrive,
# so Drive must be mounted before those cells run.
'''from google.colab import drive
drive.mount('/content/drive')'''

# Dataset loader

In [None]:
def clean(batch):
  """Collapse every whitespace run in a batch of texts to a single space.

  Falsy entries (None or '') become ''. Intended for ``Dataset.map(...,
  batched=True)``; mutates and returns ``batch``.
  """
  def _squash(raw):
    return ' '.join(raw.split()) if raw else ''

  batch['text'] = list(map(_squash, batch['text']))
  return batch

def tokenize(batch):
  """Tokenize a batch of texts with the notebook-global ``tokenizer``.

  Sequences longer than 512 tokens are truncated.
  """
  texts = batch["text"]
  return tokenizer(texts, max_length=512, truncation=True)

def adjust_labels(batch):
  """Map string labels to ints in a new ``label_`` column.

  Exactly 'spam' becomes 1; anything else (e.g. 'ham') becomes 0. Batched
  map function; mutates and returns ``batch``.
  """
  batch['label_'] = [int(label == 'spam') for label in batch['label']]
  return batch

def str_int_labels(batch):
  """Store ``int(batch['label'])`` in a new ``label_`` field (single example).

  Mutates and returns ``batch``.
  """
  raw_label = batch['label']
  batch['label_'] = int(raw_label)
  return batch

def all_ham(batch):
  """Label a single example as ham (0); mutates and returns it."""
  ham_label = 0
  batch['label'] = ham_label
  return batch

def all_spam(batch):
  """Label a single example as spam (1); mutates and returns it."""
  spam_label = 1
  batch['label'] = spam_label
  return batch

def combine_title_body(batch):
  """Merge an example's ``title`` and ``body`` into a single ``text`` field.

  A missing part (None) is treated as an empty string, so posts without a
  body (or title) no longer raise ``TypeError``. For string inputs the result
  is exactly ``title + ' ' + body``, as before. Mutates and returns ``batch``.
  """
  title = batch['title'] or ''
  body = batch['body'] or ''
  batch['text'] = title + ' ' + body
  return batch

def phishing_label(batch):
  """Derive a binary label from the ``Email Type`` field (single example).

  Types containing the substring 'Safe' become 0; every other type becomes 1.
  Mutates and returns ``batch``.
  """
  is_safe = 'Safe' in batch['Email Type']
  batch['label'] = int(not is_safe)
  return batch

# Word budget per example: chop_text keeps at most this many words, and
# chunk_dataset splits texts into pieces of this many words.
MAX_WORDS = 180

def chop_text(batch):
  """Keep only the first MAX_WORDS whitespace-separated words of each text.

  Words are re-joined with single spaces; falsy entries become ''. Batched
  map function; mutates and returns ``batch``.
  """
  def _head(text):
    if not text:
      return ''
    return ' '.join(text.split()[:MAX_WORDS])

  batch['text'] = [_head(text) for text in batch['text']]

  return batch


def chunk_dataset(example):
  """Split one example's text into consecutive MAX_WORDS-word chunks.

  ``example['text']`` becomes a list of space-joined chunks; chunks with 3 or
  fewer words (only possible for the last one) are dropped. Mutates and
  returns ``example``.
  """
  words = example['text'].split()
  chunks = []
  for start in range(0, len(words), MAX_WORDS):
    piece = words[start:start + MAX_WORDS]
    if len(piece) > 3:
      chunks.append(' '.join(piece))
  example['text'] = chunks
  return example

import re

# Regex character-class body (ranges plus single codepoints) covering letters
# and diacritics of Persian script, including Persian-specific letters such as
# U+067E (peh), U+0686 (tcheh), U+0698 (jeh), U+06A9 (keheh), U+06AF (gaf)
# and U+06CC (Farsi yeh).
# NOTE(review): many of these codepoints (e.g. U+0621-U+0628) are shared with
# Arabic, so a match really means "Arabic-script letters present".
persian_alpha_codepoints = '\u0621-\u0628\u062A-\u063A\u0641-\u0642\u0644-\u0648\u064E-\u0651\u0655\u067E\u0686\u0698\u06A9\u06AF\u06BE\u06CC'

PERSIAN_PATTERN = re.compile(f'[{persian_alpha_codepoints}]')

def is_persian(example):
  """Set boolean ``is_persian``: does ``example['text']`` contain any
  PERSIAN_PATTERN character? Mutates and returns ``example``.
  """
  found = PERSIAN_PATTERN.search(example['text'])
  example['is_persian'] = found is not None
  return example

In [None]:
# Build the fine-tuning corpus (labels: 0 = ham, 1 = spam) from Drive and
# Hugging Face Hub datasets, then save it to Drive. Change the condition to
# False to skip this slow rebuild.
if True:
  # Disabled dataset recipes kept for reference; wrapping them in a string
  # literal keeps them from running.
  '''scam_spam = (
      load_dataset("FredZhang7/all-scam-spam", split='train')
      .rename_column("is_spam", "label")
  )

  sms_spam = (
      load_dataset("sms_spam", split='train')
      .rename_column("sms", "text")
      .map(adjust_labels, batched=True)
      .remove_columns(['label'])
      .rename_column('label_', 'label')
  )
  spam_messages = (
      load_dataset(
          "mshenoda/spam-messages",
          data_files=[
              "spam_messages_test.csv",
              "spam_messages_val.csv",
              "spam_messages_train.csv",
          ],
          split='train'
      )
      .map(adjust_labels, batched=True)
      .remove_columns(['label'])
      .rename_column('label_', 'label')
  )
  enron_spam = (
      load_dataset("SetFit/enron_spam")
      .remove_columns(["message_id", "label_text", "subject", "message", "date"])
  )
  enron_spam = concatenate_datasets([enron_spam["train"], enron_spam["test"]])

  deysi_spam = (
      load_dataset("Deysi/spam-detection-dataset")
      .map(adjust_labels, batched=True)
      .remove_columns(['label'])
      .rename_column('label_', 'label')
  )
  deysi_spam = concatenate_datasets([deysi_spam["train"], deysi_spam["test"]])


  misinformation = (
      load_dataset("daviddaubner/misinformation-detection")
  )

  misinformation = concatenate_datasets([misinformation["train"], misinformation["test"], misinformation["validation"]])

  Health_Misinfo = (
      load_dataset("TheoTsio/Health_Misinfo")["train"]
      .remove_columns(
          ["Timestamp", "Url", "Domain", "Num_Emoji", "Num_Bad_Words", "Credibility"]
      )
      .rename_column("Document", "text")
      .map(all_spam)
  )

  advertisementText = (
      load_dataset("Chinxian1121/advertisementText", split='train')
      .map(all_spam)
  )

  advertisement_copy = (
      load_dataset("jaykin01/advertisement-copy", split='train')
      .remove_columns(
          ["product", "description", "Unnamed: 3"]
      )
      .rename_column('ad', 'text')
      .map(all_spam)
  )

  political_news_justifications = (
      load_dataset("od21wk/political_news_justifications")['train']
      .remove_columns(['completion'])
      .rename_column('prompt', 'text')
      .map(all_spam)
  )

  persian_blog = (
      load_dataset("RohanAiLab/persian_blog", split="train[:10%]")
      .map(all_ham)
  )

  persian_news = (
      load_dataset("RohanAiLab/persian_daily_news", split="train[:10%]")
      .map(all_ham)
  )

  clickbait_notclickbait_dataset = concatenate_datasets([
      load_dataset("christinacdl/clickbait_notclickbait_dataset", split='train'),
      load_dataset("christinacdl/clickbait_notclickbait_dataset", split='test'),
      load_dataset("christinacdl/clickbait_notclickbait_dataset", split='validation')
  ])

  clickbait_detection_dataset = (
      load_dataset("christinacdl/clickbait_detection_dataset")
      .remove_columns(
          ["text_label"]
      )
  )
  clickbait_detection_dataset = concatenate_datasets([clickbait_detection_dataset['train'], clickbait_detection_dataset['test']])


  twitter_misinformation = load_dataset("roupenminassian/twitter-misinformation")
  twitter_misinformation = (
      concatenate_datasets([twitter_misinformation['train'], twitter_misinformation['test']])
      .remove_columns(['Unnamed: 0.1', 'Unnamed: 0'])
  )

  persian_spam_path = '/content/drive/MyDrive/Spam detection/Persian Spam'
  persian_spam = load_from_disk(persian_spam_path).map(all_spam)

  '''
  persian_email_path = '/content/drive/MyDrive/Spam detection/Persian Email'
  persian_email = load_from_disk(persian_email_path)

  email_spam = (
      load_dataset("NotShrirang/email-spam-filter", split='train')
      .remove_columns(["Unnamed: 0", "label"])
      .rename_column("label_num", "label")
  )

  # Ham-only sources: first 30% of persian_blog, first 5% of persian_news_dataset.
  persian_blog = (
    load_dataset("RohanAiLab/persian_blog", split='train[:30%]')
    .map(all_ham, batched=False)
  )

  persian_news = (
    load_dataset("RohanAiLab/persian_news_dataset", split='train[:5%]')
    .remove_columns(['title', 'category'])
    .map(all_ham, batched=False)
  )

  # Ham: title and body merged into `text`, then every original column dropped.
  career_guidance_reddit = load_dataset("mb7419/career-guidance-reddit")
  old_cols = career_guidance_reddit.column_names['train']
  career_guidance_reddit = (
      concatenate_datasets([career_guidance_reddit['train'], career_guidance_reddit['test']])
      .map(combine_title_body, batched=False)
      .map(all_ham, batched=False)
      .remove_columns(old_cols)
  )

  # Ham: all splits of the StackOverflow dataset, minus the 'pid' column.
  stackoverflow = load_dataset("c17hawke/stackoverflow-dataset")
  stackoverflow = (
      concatenate_datasets([stackoverflow['train'], stackoverflow['test']])
      .map(all_ham, batched=False)
      .remove_columns(['pid'])
  )

  # NOTE(review): this dataset needs trust_remote_code (a loading script);
  # recent `datasets` releases no longer execute loading scripts, so with the
  # unpinned `pip install -U datasets` this call may fail -- pin or replace it.
  phishing = load_dataset("ealvaradob/phishing-dataset", "texts", trust_remote_code=True, split='train')

  # 'Email Type' containing "Safe" -> 0, anything else -> 1 (phishing_label).
  phishing_mail = (
      load_dataset("zefang-liu/phishing-email-dataset", split='train')
      .map(phishing_label, batched=False)
      .remove_columns(['Unnamed: 0', 'Email Type'])
      .rename_column('Email Text', 'text')
  )

  farshad72_spam_email = load_dataset("farshad72/spam_email", split='train')

  legacy107_spam = (
      load_dataset("legacy107/spamming-email-classification", split='train')
      .rename_column('Text', 'text')
      .rename_column('Spam', 'label')
  )


  enrun_emails = load_dataset("hossein20s/enrun-emails-text-classification")
  enrun_emails = (
      concatenate_datasets([enrun_emails['train'], enrun_emails['test'], enrun_emails['validation']])
  )


  # Ham: titles from the first 1% of the Hacker News posts dataset.
  hacker_news = (
      load_dataset("julien040/hacker-news-posts", split='train[:1%]')
      .remove_columns(['id', 'url', 'score', 'time', 'comments', 'author'])
      .rename_column('title', 'text')
      .map(all_ham)
  )

  dataset = concatenate_datasets([
      persian_email,
      email_spam,
      persian_blog,
      persian_news,
      career_guidance_reddit,
      stackoverflow,
      phishing,
      phishing_mail,
      farshad72_spam_email,
      legacy107_spam,
      enrun_emails,
      hacker_news
  ])
  # Drop rows whose text is None or empty before shuffling.
  dataset = dataset.filter(lambda x: bool(x['text']), batched=False)
  #dataset = Dataset.from_pandas(dataset.map(chunk_dataset, batched=False).to_pandas().explode('text'), preserve_index=False)
  # Truncate to MAX_WORDS words, then filter again: whitespace-only texts pass
  # the filter above but chop_text turns them into ''.
  dataset = (
      dataset.shuffle(seed=49)
      .map(chop_text, batched=True)
      .filter(lambda x: bool(x['text']), batched=False)
  )
  dataset.save_to_disk('/content/drive/MyDrive/Spam detection/FineTuneDataset')

In [None]:
#dataset = load_from_disk('/content/drive/MyDrive/Spam detection/Dataset')

In [None]:
# Display the dataset summary (features and number of rows).
dataset

In [None]:
# Add the boolean `is_persian` column, one example at a time (map's default).
dataset = dataset.map(is_persian)

In [None]:
plt.figure()
df = dataset.to_pandas()

# Number of rows with / without a Persian-script character.
persian_c = int((df['is_persian'] == 1).sum())
non_persian_c = int((df['is_persian'] == 0).sum())

for x_pos, (count, bar_name) in enumerate([(non_persian_c, 'Non Persian'), (persian_c, 'Persian')]):
  plt.bar(x_pos, count, label=bar_name)
plt.grid()
plt.xlabel('Class')
plt.ylabel('Count')
plt.legend()
plt.xticks([0, 1], ['Non Persian', 'Persian'], rotation=0)
plt.show()

In [None]:
plt.figure()

# Per-class row counts (0 = ham, 1 = spam); ham_c and spam_c are also used
# afterwards to compute ham_to_spam_ratio.
ham_c = int((df['label'] == 0).sum())
spam_c = int((df['label'] == 1).sum())

for x_pos, (count, bar_name) in enumerate([(ham_c, 'Ham'), (spam_c, 'Spam')]):
  plt.bar(x_pos, count, label=bar_name)
plt.grid()
plt.xlabel('Class')
plt.ylabel('Count')
plt.legend()
plt.xticks([0, 1], ['Ham', 'Spam'], rotation=0)  # label 0 is 'Ham', label 1 is 'Spam'
plt.show()

In [None]:
# Ham/spam count ratio. WeightedTrainer uses it as the loss weight of the spam
# class (1), so both classes carry the same total weight in the loss.
if spam_c == 0:
  raise ValueError('No spam examples (label == 1) in the dataset; cannot compute ham_to_spam_ratio.')
ham_to_spam_ratio = ham_c / spam_c
print(f'{ham_to_spam_ratio = }')

In [None]:
# Disabled (kept as a string so it does not run): undersample ham to the
# spam count and rebuild `dataset` from the balanced DataFrame.
'''
spam_c, ham_c = len(df[df['label'] == 1]), len(df[df['label'] == 0])
ham_df = df[df['label'] == 0]
spam_df = df[df['label'] == 1]

df = pd.concat([spam_df, ham_df.sample(spam_c, random_state=42)])
dataset = Dataset.from_pandas(df, preserve_index=False)
'''

In [None]:
# Disabled (kept as a string so it does not run): re-plot the label balance.
# NOTE(review): the inline comment inside has the labels swapped -- label 0 is
# Ham and label 1 is Spam (see all_ham / all_spam).
'''
plt.figure()

spam_c, ham_c = len(df[df['label'] == 1]), len(df[df['label'] == 0])
plt.bar(0, ham_c, label='Ham')
plt.bar(1, spam_c, label='Spam')
plt.grid()
plt.xlabel('Class')
plt.ylabel('Count')
plt.legend()
plt.xticks([0, 1], ['Ham', 'Spam'], rotation=0)  # The label 0 is for 'Spam' and 1 is for 'Ham'
plt.show()
'''

In [None]:
# Disabled (kept as a string so it does not run): re-plot the Persian /
# non-Persian split.
'''
plt.figure()

persian_c, non_persian_c = len(df[df['is_persian'] == 1]), len(df[df['is_persian'] == 0])
plt.bar(0, non_persian_c, label='Non Persian')
plt.bar(1, persian_c, label='Persian')
plt.grid()
plt.xlabel('Class')
plt.ylabel('Count')
plt.legend()
plt.xticks([0, 1], ['Non Persian', 'Persian'], rotation=0)
plt.show()
'''

In [None]:
# @title Evaluation metrics
import evaluate

def compute_metrics(eval_pred):
  """Compute accuracy and binary F1 (positive class 1 = spam) for ``Trainer``.

  Computed directly with NumPy instead of calling ``evaluate.load`` on every
  evaluation, which re-instantiated both metrics (and could reach out to the
  Hub) each time. For 0/1 labels the results match ``evaluate``'s
  ``accuracy`` and default binary ``f1``; F1 is 0.0 when there are neither
  true nor predicted positives.

  Args:
    eval_pred: ``(logits, labels)`` pair; the last axis of ``logits`` indexes
      the classes.

  Returns:
    dict with float ``"accuracy"`` and ``"f1"``.
  """
  logits, labels = eval_pred
  predictions = np.argmax(logits, axis=-1)
  labels = np.asarray(labels)

  accuracy = float(np.mean(predictions == labels)) if labels.size else 0.0

  pred_pos = predictions == 1
  true_pos = labels == 1
  tp = int(np.sum(pred_pos & true_pos))
  fp = int(np.sum(pred_pos & ~true_pos))
  fn = int(np.sum(~pred_pos & true_pos))
  # F1 = 2TP / (2TP + FP + FN), the same quantity as 2PR / (P + R).
  denom = 2 * tp + fp + fn
  f1 = 2 * tp / denom if denom else 0.0

  return {"accuracy": accuracy, "f1": f1}

# Training

In [None]:
import torch
with torch.no_grad():
  torch.cuda.empty_cache()

!rm -rf "/content/drive/MyDrive/Spam detection/Model"
!rm -rf /content/drive/MyDrive/Spam\ detection/Dataset/cache*

In [None]:
# Base checkpoint on Drive (directory name as stored: "DisitlModel"), loaded
# with a 2-class classification head (0 = ham, 1 = spam).
_base_ckpt = "/content/drive/MyDrive/Spam detection/DisitlModel"
tokenizer = AutoTokenizer.from_pretrained(_base_ckpt)
model = AutoModelForSequenceClassification.from_pretrained(_base_ckpt, num_labels=2)
# Pads each batch dynamically to its longest sequence.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
# Normalize whitespace, tokenize, and hold out 30% for evaluation. The split is
# seeded (same seed as the dataset shuffle) so it is reproducible across runs.
dataset = (
  dataset
  .map(clean, batched=True)
  .map(tokenize, batched=True)
  .train_test_split(test_size=0.3, seed=49)
)

In [None]:
# Persist the tokenized train/test split to Drive.
_finetune_dir = '/content/drive/MyDrive/Spam detection/FineTuneDataset'
dataset.save_to_disk(_finetune_dir)

In [None]:
# Load the tokenized split saved to FineTuneDataset. The path must match the
# save_to_disk call; reading a different directory (e.g. '.../Dataset') would
# train on data other than the one prepared here, whose class balance
# ham_to_spam_ratio was computed from.
dataset = load_from_disk('/content/drive/MyDrive/Spam detection/FineTuneDataset')
train = dataset['train']
test = dataset['test']

In [None]:
# Display the training split (features and number of rows).
train

In [None]:
# Display the loaded tokenizer's configuration.
tokenizer

In [None]:
# Display the model architecture.
model

In [None]:
# Train only parameters whose names contain "class" (classification head),
# "layer.4" or "layer.5"; everything else is frozen.
_trainable_markers = ("class", "layer.5", "layer.4")
for name, param in model.named_parameters():
  param.requires_grad = any(marker in name for marker in _trainable_markers)

# Print each parameter's trainable flag to verify the selection.
for name, param in model.named_parameters():
  print(name, param.requires_grad)

In [None]:
from torch import nn

class WeightedTrainer(Trainer):
    """``Trainer`` whose loss is class-weighted cross-entropy.

    Class 0 (ham) has weight 1.0 and class 1 (spam) has weight
    ``ham_to_spam_ratio`` (a notebook global, ham count / spam count), which
    gives both classes the same total weight over the dataset.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the weighted loss (and the model outputs if requested).

        ``num_items_in_batch`` is accepted because recent transformers versions
        pass it to ``compute_loss``; without it this override raises
        ``TypeError``. It is not used: the loss is the per-batch weighted mean
        computed by ``CrossEntropyLoss``.
        """
        labels = inputs.pop("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        class_weights = torch.tensor([1.0, ham_to_spam_ratio], device=model.device)
        # Overriding compute_loss bypasses Trainer's built-in label smoother, so
        # apply TrainingArguments.label_smoothing_factor here explicitly.
        loss_fct = nn.CrossEntropyLoss(
            weight=class_weights,
            label_smoothing=self.args.label_smoothing_factor,
        )
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

In [None]:
training_args = TrainingArguments(
   output_dir="./Model",
   learning_rate=1e-5,
   num_train_epochs=1,
   weight_decay=0.001,
   per_device_train_batch_size=64,
   per_device_eval_batch_size=64,
   dataloader_num_workers=2,
   fp16=True,
   warmup_ratio=0.3,
   # `evaluation_strategy` was renamed to `eval_strategy`; the old keyword is
   # no longer accepted by the recent transformers installed (unpinned) above.
   eval_strategy='steps',
   save_total_limit=2,
   # Values below 1 are interpreted as a fraction of the total training steps.
   save_steps=0.1,
   eval_steps=1/4,
   # NOTE: Trainer does not read this field; pass resume_from_checkpoint to
   # trainer.train() to actually resume from a checkpoint in output_dir.
   resume_from_checkpoint=True,
   report_to='none',
   # Only takes effect if the loss honours it: a custom compute_loss bypasses
   # Trainer's built-in label smoother.
   label_smoothing_factor=0.2
)

trainer = WeightedTrainer(
   model=model,
   args=training_args,
   train_dataset=train,
   eval_dataset=test,
   # `processing_class` supersedes the deprecated `tokenizer=` argument.
   processing_class=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning. TrainingArguments(resume_from_checkpoint=True) is not read by
# Trainer; to resume, call trainer.train(resume_from_checkpoint=True), which
# raises if output_dir contains no checkpoint.
trainer.train()

In [None]:
# Save the fine-tuned model (and the Trainer's tokenizer) to Drive so it can be
# reloaded with from_pretrained().
trainer.save_model(output_dir='/content/drive/MyDrive/Spam detection/Model')

In [None]:
# Reload the tokenizer and 2-class model from the Drive save directory.
_saved_dir = '/content/drive/MyDrive/Spam detection/Model'
tokenizer = AutoTokenizer.from_pretrained(_saved_dir)
model = AutoModelForSequenceClassification.from_pretrained(_saved_dir, num_labels=2)

In [None]:
import torch

# Hand-labelled sanity-check inputs (label 0 = ham, 1 = spam): mostly Persian
# sentences plus a few English headlines/comments.
examples = [
    {'text': 'این تست متن فارسی غیر اسپم است.', 'label': 0},
    {'text': 'خرید هاست لینوکس ایران', 'label': 1},
    {'text': 'ترجمه تخصصی و فنی به آلمانی: رویکردها و استراتژی‌ها', 'label': 0},
    {'text': 'فایلهای پرکاربرد و آموزشی و با کیفیت برای استفاده دانشجویان وتحقیقات علمی و پژوهشی', 'label': 1},
    {'text': 'فرصت‌های تحصیل رایگان در اروپا: بورس‌های تحصیلی به عنوان پلی به دسترسی به تعلیمات برتر', 'label': 1},
    {'text': 'امروز ساعت 6:45 از خواب بیدار شدم و از همون لحظه حوصله هیچکس رو ندارم :|', 'label': 0},
    {'text': '''اَلا یا اَیُّهَا السّاقی اَدِرْ کَأسَاً و ناوِلْها که عشق آسان نمود اوّل ولی افتاد مشکل‌ه''', 'label': 0},
    {'text': 'سخنگوی شورای امنیت ملی کاخ سفید گفت که این کشور به ایران پیام داده که نمی‌خواهد شاهد گسترش درگیری در منطقه باشد.', 'label': 0},
    {'text': 'سوالات استخدامی علوم پزشکی و بیمارستانها 1402 ,سوالات استخدامی وزارت بهداشت+سوالات استخدامی بیمارستان 1402- نمونه سوالات استخدامی رایگان,', 'label': 1},
    {'text': ' روند رانندگی بی صدا، بدون سر و صدای غیر عادی، با ظاهری بسیارشیک و جمع و جور است و دید راننده را محدود نمی کند. توربین این دستگاه در طول روز برای پخش عود می چرخد ترکیبی قوی از آلیاژ با مقاومت بالا و سرامیک طبیعی، بدون ترس از قرار گرفتن در معرض آفتاب و در تابستان بسیار قوی کار میکند.قیمت این محصول...تومان', 'label': 1},
    {'text': 'DeciLM-7B: The Fastest and Most Accurate 7B-Parameter LLM to Date', 'label': 0},
    {'text': 'Telecom Industry Is Mad Because the FCC Might Examine High Broadband Prices', 'label': 0},
    {'text': 'Well, here is hope that this will be a first step in bringing US internet access to at least something comparable to Balkans. ', 'label': 0},
    {'text': 'Agree to notifications to allow news feed', 'label': 1},
    {'text': '6 Ways to Boost Your Coffee with Vitamins and Antioxidants', 'label': 1},
    {'text': 'New exciting developments in CHATGPT! Click here to see more ...', 'label': 1},
    {'text': 'عوارض پوستی خود را به راحتی از بین ببرید', 'label': 1},
]

# Print reference label, predicted label and per-class softmax confidence for
# each hand-labelled example.
label_names = {0: 'Ham', 1: 'Spam'}
for example in examples:
  print('*' * 40)
  print('Text:', example['text'], '\n')

  # Truncate like the training-time tokenize() (512 tokens) so long inputs
  # cannot exceed the model's limit, and send the tensors to wherever the
  # model currently lives rather than assuming CPU.
  inputs = tokenizer(example['text'], truncation=True, max_length=512, return_tensors="pt")
  inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

  with torch.no_grad():
    logits = model(**inputs).logits
  probs = logits.softmax(-1)[0]

  print('Ref:', label_names[example['label']])
  print('Model:', label_names[logits.argmax().item()])
  print('Ham confidence:', probs[0].item())
  print('Spam confidence:', probs[1].item())

In [None]:
# Load the TensorBoard notebook extension.
%load_ext tensorboard

In [None]:
import tensorflow as tf
import datetime, os

In [None]:
# NOTE(review): the training run above uses report_to='none' and
# output_dir='./Model', so it writes nothing under ./logs; this only shows runs
# logged there by other code.
%tensorboard --logdir logs

# Persian fine-tuning