In [None]:
!pip install fasttext




In [None]:
import fasttext

In [None]:
import csv
import os
import time
import math
from collections import defaultdict

# Dataset/model variant: "multiclass" or "binary". It is baked into the
# dataset and fine-tuned model paths below.
MODE = "multiclass"
# When True, prepare_dataset() regenerates every .txt file even if it already
# exists; when False, existing .txt files are reused as-is.
OVERRIDE = False

# Root folder of all artifacts (presumably the mounted Google Drive in Colab).
ARTIFACTS = "/content/drive/MyDrive/HLT_artifacts"
DATASET = f"{ARTIFACTS}/datasets"
# Split path prefixes without extension: "<prefix>.csv" is the source data and
# "<prefix>.txt" the generated fastText file.
TRAINING = f"{DATASET}/train_{MODE}"
VALIDATION = f"{DATASET}/validation_{MODE}"
TEST = f"{DATASET}/test_{MODE}"

# Pretrained vectors prefix; ".vec" is appended when training.
FAST_TEXT = f"{ARTIFACTS}/base_models/fasttext"
# Full path (including ".bin") of the fine-tuned model; benchmark() loads it by default.
FINE_TUNED = f"{ARTIFACTS}/fine-tuned_models/fasttext-{MODE}.bin"
# Scratch directory for the per-epoch checkpoints written by fine_tune().
TMP = "/content"

# Passed as `dim` to fasttext.train_supervised and must equal the dimension of
# the pretrained vectors. If training raises
# "ValueError: Dimension of pretrained vectors (x) does not match dimension (y)!",
# set this to x.
PRETRAINED_DIM=300
# Seconds the benchmark keeps re-running the test split.
BENCHMARK_TIME_LIMIT = 60 * 1

# Inclusive [min, max] range of epoch counts tried by fine_tune().
EPOCHS = [18, 22]
# Early-stopping patience of fine_tune(), counted in candidate epoch counts
# that do not improve the validation loss.
PATIENCE = 2
# Learning rate passed to fasttext.train_supervised.
LR = 1e-2

# CSV column names holding the input text and the target label.
INPUT_COL = "tweet_text"
OUT_COL = "label"

In [None]:
def macro_f1(cm):
    """Return the macro-averaged F1 score of a confusion matrix.

    ``cm`` maps ``(true_label, predicted_label)`` to a count; pairs that are
    absent count as zero. Every label seen on either side of a key gets its
    own F1 score, and the scores are averaged without weighting. An empty
    matrix yields ``0.0``.
    """
    every_label = sorted({key[0] for key in cm} | {key[1] for key in cm})
    if not every_label:
        return 0.0

    def _ratio(num, den):
        # Avoid 0/0 for labels that are never predicted or never present.
        return num / den if den > 0 else 0

    scores = []
    for current in every_label:
        others = [lbl for lbl in every_label if lbl != current]
        true_pos = cm.get((current, current), 0)
        false_pos = sum(cm.get((lbl, current), 0) for lbl in others)
        false_neg = sum(cm.get((current, lbl), 0) for lbl in others)

        precision = _ratio(true_pos, true_pos + false_pos)
        recall = _ratio(true_pos, true_pos + false_neg)
        scores.append(_ratio(2 * precision * recall, precision + recall))

    return sum(scores) / len(scores)


In [None]:
class FineFastText:
    """fastText supervised classifier with dataset conversion, fine-tuning,
    validation and a throughput benchmark.

    Paths and hyper-parameters are read from the module-level configuration
    constants (``TRAINING``, ``VALIDATION``, ``TEST``, ``FAST_TEXT``,
    ``FINE_TUNED``, ``TMP``, ``PRETRAINED_DIM``, ``EPOCHS``, ``PATIENCE``,
    ``LR``, ``INPUT_COL``, ``OUT_COL``, ``OVERRIDE``, ``BENCHMARK_TIME_LIMIT``).
    """

    def __init__(self, model_path: str | None = None) -> None:
        """Create the wrapper, optionally loading a saved model.

        Args:
            model_path: path of a saved fastText model. The ``.bin`` suffix is
                appended only when missing, so both ``"dir/model"`` and
                ``"dir/model.bin"`` (e.g. ``FINE_TUNED``) work. With ``None``,
                ``self.model`` stays unset until ``fine_tune`` or
                ``benchmark`` provides one.
        """
        self.model = None
        if model_path is not None:
            if not model_path.endswith(".bin"):
                model_path += ".bin"
            self.model = fasttext.load_model(model_path)

    def validate(self, dataset: str = VALIDATION) -> dict:
        """Classify every example of ``{dataset}.txt`` with the current model.

        Works for binary and multi-class models. Each line is expected in the
        ``__label__<y> <text>`` format written by ``_write_txt``. Lines without
        text are skipped, and lines whose prediction raises are reported and
        skipped.

        Args:
            dataset: split path without the ``.txt`` extension.

        Returns:
            dict with
              - ``"cm"``: confusion matrix, ``cm[(true, pred)] = count``
              - ``"samples"``: number of classified lines
              - ``"loss"``: ``1 - macro_f1(cm)`` (lower is better)

        Raises:
            RuntimeError: if no model has been trained or loaded yet.
        """
        if self.model is None:
            # Otherwise every line would fail inside the try below and the
            # result would silently be an empty matrix with loss 1.0.
            raise RuntimeError("No fastText model: call fine_tune() or pass model_path first")

        cm = defaultdict(int)
        n_samples = 0

        with open(f"{dataset}.txt", "r", encoding="utf-8") as file:
            for line in file:
                # strip() removes the trailing newline, which fastText's
                # predict() refuses.
                parts = line.strip().split(maxsplit=1)
                if len(parts) < 2:
                    continue

                true_label = parts[0].replace("__label__", "")
                text = parts[1]

                try:
                    pred_labels = self.model.predict(text, k=1)[0]  # labels only, drop probabilities
                    pred_label = pred_labels[0].replace("__label__", "")
                except Exception as e:
                    print(f"Prediction failed on: {text[:50]} - {e}")
                    continue

                cm[(true_label, pred_label)] += 1
                n_samples += 1

        return {
            "cm": cm,
            "samples": n_samples,
            "loss": 1 - macro_f1(cm),
        }

    def fine_tune(self, epochs: list[int] = EPOCHS, output_path: str = FINE_TUNED) -> dict:
        """Search the epoch count in ``[epochs[0], epochs[1]]`` and keep the best model.

        For every candidate epoch count, a model is trained from scratch on
        ``{TRAINING}.txt`` (initialised from ``{FAST_TEXT}.vec``) and scored
        with ``validate()``. Improving candidates are checkpointed to
        ``{TMP}/<epochs>.bin``. The search stops early once ``PATIENCE``
        consecutive candidates fail to improve the validation loss.

        The best checkpoint is then reloaded into ``self.model``, so callers
        evaluate the best model rather than the last one trained. It is saved
        to ``output_path``, and the temporary checkpoints are removed.

        Args:
            epochs: inclusive ``[min, max]`` range of epoch counts to try.
            output_path: destination of the best model. Defaults to
                ``FINE_TUNED``, the path ``benchmark()`` loads by default.

        Returns:
            the ``validate()`` metrics of the best model.

        Raises:
            ValueError: if ``epochs`` describes an empty range.
        """
        first, last = epochs[0], epochs[1]
        if first > last:
            raise ValueError(f"Empty epoch range: {epochs}")
        candidates = range(first, last + 1)

        best_model = None
        best_val = {"loss": float("inf")}
        wait = 0  # consecutive candidates without improvement
        for i in candidates:
            # Train the model (fastText cannot resume, so each candidate starts over)
            self.model = fasttext.train_supervised(
                input=f"{TRAINING}.txt",
                pretrainedVectors=FAST_TEXT + ".vec",
                dim=PRETRAINED_DIM,
                epoch=i,
                lr=LR,
            )

            # Evaluate the model
            val = self.validate()
            print(f'Epoch {i} validation loss {val["loss"]}')
            if val["loss"] < best_val["loss"]:
                best_model = f"{TMP}/{i}.bin"
                self.model.save_model(best_model)
                best_val = val
                wait = 0
            else:
                wait += 1
                if wait >= PATIENCE:
                    print(f"Early stopped at epoch {i}")
                    break

        if best_model is None:
            raise RuntimeError("No candidate improved on an infinite validation loss")

        # Make the best (not the last trained) candidate the active model.
        self.model = fasttext.load_model(best_model)

        # Save the model. save_model writes the destination directly, which
        # unlike os.rename also works when output_path is on a different
        # filesystem than TMP (e.g. a mounted Drive folder).
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        self.model.save_model(output_path)

        for i in candidates:
            path = f"{TMP}/{i}.bin"
            if os.path.exists(path):
                os.remove(path)

        return best_val

    def benchmark(self, dataset: str = TEST, model_path: str = FINE_TUNED) -> tuple[int, float]:
        """Measure the classification throughput of a saved model.

        Loads ``model_path`` into ``self.model`` and repeatedly runs
        ``model.test`` on ``{dataset}.txt`` until ``BENCHMARK_TIME_LIMIT``
        seconds have elapsed. The last pass may overrun the limit.

        Returns:
            ``(samples, seconds)``: the total number of examples classified
            over all passes, and the wall-clock time spent in the timed loop.
        """
        self.model = fasttext.load_model(model_path)
        dataset_file = f"{dataset}.txt"

        samples = 0
        start = time.perf_counter()
        while time.perf_counter() - start < BENCHMARK_TIME_LIMIT:
            # test() returns (number_of_examples, precision@1, recall@1).
            samples += self.model.test(dataset_file)[0]
        elapsed = time.perf_counter() - start

        return samples, elapsed

    @staticmethod
    def prepare_dataset():
        """Convert the train/validation/test CSV splits to fastText ``.txt`` files.

        All three ``.csv`` files are checked for existence before any
        conversion starts.

        Raises:
            AssertionError: if one of the ``.csv`` files is missing.
        """
        splits = (TRAINING, VALIDATION, TEST)
        for dataset in splits:
            assert os.path.exists(f"{dataset}.csv"), f"{dataset}.csv not found"
        for dataset in splits:
            FineFastText._write_txt(dataset)

    @staticmethod
    def _write_txt(dataset: str):
        """Convert ``{dataset}.csv`` into fastText's ``{dataset}.txt`` format.

        Each CSV row becomes one line ``__label__<OUT_COL> <INPUT_COL>``.
        Newlines inside the text are replaced by the ``_ENTER_`` token, so each
        example stays on a single line. An existing ``.txt`` is kept unless
        ``OVERRIDE`` is set.

        The output goes to a temporary file that is moved into place only
        after the whole CSV has been converted. A failed conversion therefore
        never leaves a partial ``.txt`` that later runs would silently reuse.

        Raises:
            AssertionError: if a row lacks ``INPUT_COL`` or ``OUT_COL``.
        """
        txt_path = f"{dataset}.txt"
        if os.path.exists(txt_path) and not OVERRIDE:
            return

        tmp_path = f"{txt_path}.part"
        try:
            with open(f"{dataset}.csv", "r", encoding="utf-8") as file, \
                    open(tmp_path, "w", encoding="utf-8") as txt:
                reader = csv.DictReader(file, delimiter=',')
                for row in reader:
                    assert INPUT_COL in row, f"row of dataset {dataset} doesn't have {INPUT_COL} attribute, it cointains {row.keys()}"
                    assert OUT_COL in row, f"row of dataset {dataset} doesn't have {OUT_COL} attribute, it cointains {row.keys()}"

                    tweet = row[INPUT_COL]
                    if "\n" in tweet:
                        tweet = tweet.replace("\n", " _ENTER_ ")
                    txt.write(f"__label__{row[OUT_COL]} {tweet}\n")
            # Atomically replaces any previous .txt (the OVERRIDE case).
            os.replace(tmp_path, txt_path)
        finally:
            # Only still present if the conversion failed part-way.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

In [None]:
def binary_cm(metrics: dict):
    """Print the 2x2 confusion matrix of a binary model.

    Rows are the true label (``True`` = ``'1'``, ``False`` = ``'0'``).
    Columns are the predicted label (``Positive`` = ``'1'``,
    ``Negative`` = ``'0'``).

    Args:
        metrics: dict whose ``"cm"`` entry maps ``(true, pred)`` -> count,
            e.g. the result of ``FineFastText.validate``.
    """
    cm = metrics["cm"]

    def cell(true_label: str, pred_label: str) -> int:
        # .get() works for plain dicts with missing pairs and, unlike indexing
        # a defaultdict, does not insert zero entries into the caller's cm.
        return cm.get((true_label, pred_label), 0)

    print("Binary model: Confusion matrix")
    print("\t|Positive\t|Negative")
    print(f"True\t|{cell('1', '1')}\t\t|{cell('1', '0')}")
    print(f"False\t|{cell('0', '1')}\t\t|{cell('0', '0')}")

def multiclass_cm(metrics: dict):
    cm = metrics["cm"]

    # Define label order and mapping (you can replace these later)
    labels = ['0', '1', '2']  # Change to ['toxic', 'insult', 'threat'] or similar
    label_names = {label: label for label in labels}  # For now just identity mapping

    # Header row
    print("Multiclass Confusion Matrix")
    header = "\t|" + "\t".join(label_names[label] for label in labels)
    print(header)

    # Each row: true label
    for true_label in labels:
        row = [label_names[true_label]]
        for pred_label in labels:
            count = cm.get((true_label, pred_label), 0)
            row.append(str(count))
        print("\t|".join(row))

# Dataset preparation: convert the train/validation/test CSV files into
# fastText .txt files (existing .txt files are reused unless OVERRIDE is True).
FineFastText.prepare_dataset()
print("Dataset ready")

Dataset ready


In [None]:
# Fine tuning fasttext
print("\nFine tuning")
fast = FineFastText()
fast.fine_tune()

## Fine tuned model testing
# Test-split metrics of the model that fine_tune() left in fast.model.
metrics = fast.validate(dataset=TEST)
del fast


Fine tuning
Epoch 18 validation loss 0.07693282142225133
Epoch 19 validation loss 0.07691633875956505
Epoch 20 validation loss 0.07689448327301152
Epoch 21 validation loss 0.0771915292622658
Epoch 22 validation loss 0.07714943684940734


In [None]:
# Metrics and confusion matrix print: pick the printer matching MODE.
(binary_cm if MODE == "binary" else multiclass_cm)(metrics)

Multiclass Confusion Matrix
	|0	1	2
0	|751	|4	|2
1	|1	|718	|5
2	|3	|4	|666


In [None]:
# Throughput benchmark: load the model saved at FINE_TUNED and classify the
# test split repeatedly until BENCHMARK_TIME_LIMIT seconds have elapsed.
fast = FineFastText()
samples, total_time = fast.benchmark(model_path=FINE_TUNED)
print(f"Benchmark finetuned, classified {samples} samples in {total_time}s.")
del fast

Benchmark finetuned, classified 2404932 samples in 60.11794948577881s.
