# Week 5 Practical

You have just thought up your latest brilliant idea -- an on-device spam SMS detector.

Because it has to run on a mobile device, you can't use a modern LLM: you'll have to
use a fine-tuned very small language model.

Unfortunately, you don't have a lot of spam available right now to train on^. You
have just a handful of SMS messages that you have labelled yourself, and some unlabelled
SMS messages your friends have given you.

^ Actually, if you were doing this in English, you would, because there is a spam SMS
corpus at https://archive.ics.uci.edu/dataset/228/sms+spam+collection. But if you were
doing this in a less well-resourced language, you would have to do something like this
prac.

Note that these techniques don't always make much of an improvement. Sometimes they even make things worse.

-----

- Your initial spam dataset is in `spam.csv`
- Your initial ham dataset is in `ham.csv`
- Your unlabelled data is in `unlabelled.csv`

Do your usual EDA (exploratory data analysis), have a look at a few samples, review the shape of these datasets, etc.

In [None]:
import pandas



initial_spam.shape, initial_ham.shape, unlabelled.shape

## Baseline

We'll need to make a training and test dataset out of this.

In [None]:
import sklearn.model_selection
pandas.concat([initial_spam, initial_ham])


train_df, test_df =

How many examples of spam and of ham do we have in our training and test datasets?

Below is a function that you can call that will fine-tune a MobileBERT model
(`google/mobilebert-uncased`) on the training data, and report its accuracy on the test data.

In [None]:
from transformers import MobileBertTokenizer, MobileBertForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import Dataset, DataLoader, RandomSampler, DistributedSampler, random_split
import torch
import functools
from sklearn.metrics import accuracy_score, log_loss, recall_score, precision_score
import numpy

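def compute_metrics(p):
    # Labels are one-hot ([1,0] = ham, [0,1] = spam), so argmax recovers the class index
    preds = numpy.argmax(p.predictions, axis=1)
    corrects = numpy.argmax(p.label_ids, axis=1)
    # log_loss expects probabilities, so convert the raw logits with a softmax first
    shifted = p.predictions - p.predictions.max(axis=1, keepdims=True)
    probabilities = numpy.exp(shifted) / numpy.exp(shifted).sum(axis=1, keepdims=True)
    return {"accuracy": accuracy_score(corrects, preds),
            # labels=[0, 1] keeps log_loss working if the test set holds only one class
            "log_loss": log_loss(corrects, probabilities, labels=[0, 1]),
            "recall": recall_score(corrects, preds),
            "precision": precision_score(corrects, preds),
            "predictions": preds,
            "correct_answers": corrects,
            "prediction logits": p.predictions
           }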


class TwoColumnDFDataset(Dataset):
    def __init__(self, dataframe):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = MobileBertTokenizer.from_pretrained('google/mobilebert-uncased')


    def __getitem__(self, index):
        text = str(self.data.text.iloc[index])
        text = " ".join(text.split())
        inputs = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=50, # An SMS won't be much longer than this
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        ids = inputs['input_ids'][0]
        mask = inputs['attention_mask'][0]
        # One-hot labels in the order [ham, spam]
        labels = [0,1] if self.data.spam.iloc[index] == 'spam' else [1,0]
        return {
            # ids and mask are already tensors, so convert the dtype rather than re-wrapping them
            'input_ids': ids.to(torch.long),
            'attention_mask': mask.to(torch.long),
            'labels': torch.tensor(labels, dtype=torch.float)
        }

    def __len__(self):
        return self.len


def train_and_evaluate_model(train_df, test_df):
    model = MobileBertForSequenceClassification.from_pretrained('google/mobilebert-uncased')
    training_set = TwoColumnDFDataset(train_df)
    testing_set = TwoColumnDFDataset(test_df)
    print(f"The training data has {train_df.shape[0]} rows")
    print(f"The test data has {test_df.shape[0]} rows")
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=50,
        per_device_train_batch_size=len(training_set),
        per_device_eval_batch_size=len(testing_set),
        warmup_steps=10,
        weight_decay=0.1,
        logging_dir='./logs',
        logging_steps=10,
        #use_mps_device=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=training_set,
        eval_dataset=testing_set,
        compute_metrics=compute_metrics
    )
    trainer.train()
    return {
        'model': model,
        'evaluation': trainer.evaluate()
    }

Let's do a baseline run where we train the model on the data we have, and see
how accurate it is.

You might need to scroll the cell down to see the part where it shows the evaluation.
Make a note of the eval_loss, eval_accuracy, eval_recall and eval_precision. One
small problem with this prac is that it is often so easy to separate SMS spam from ham
that even 10 samples is sometimes enough!
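
To make those numbers concrete, here is a minimal sketch that computes accuracy, precision and recall by hand. The labels and predictions are invented for illustration, and the helper names are hypothetical; `1` stands for spam and `0` for ham, matching the argmax convention in `compute_metrics`.

```python
def confusion_counts(correct, predicted, positive=1):
    """Count true/false positives and negatives for a binary problem."""
    pairs = list(zip(correct, predicted))
    tp = sum(1 for c, p in pairs if c == positive and p == positive)
    fp = sum(1 for c, p in pairs if c != positive and p == positive)
    fn = sum(1 for c, p in pairs if c == positive and p != positive)
    tn = sum(1 for c, p in pairs if c != positive and p != positive)
    return tp, fp, fn, tn


def summarise(correct, predicted):
    """Return accuracy, precision and recall, treating spam (1) as positive."""
    tp, fp, fn, tn = confusion_counts(correct, predicted)
    return {
        "accuracy": (tp + tn) / len(correct),
        # Of the messages flagged as spam, how many really were spam?
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        # Of the real spam messages, how many did we catch?
        "recall": tp / (tp + fn) if tp + fn else 0.0,
    }


# Invented example: 1 = spam, 0 = ham
correct = [1, 1, 1, 0, 0, 0, 0, 0]
predicted = [1, 0, 0, 0, 0, 0, 0, 0]
metrics = summarise(correct, predicted)
print(metrics)
```

In this made-up case accuracy is 0.75 even though two of the three spam messages slipped through, which is why recall and precision are worth noting alongside it.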

In [None]:
%%time


# Where to begin

Sometimes, if there's not enough labelled training data, the best thing to do is to
label some data yourself. You get a lot of insight from this too.

Randomly pick one of the unlabelled SMS texts, and decide if it is spam or ham.

Create a new `training2_df` which consists of `train_df` and your new manually-labelled
point.

In [None]:
manually_labelled =

Re-run the training and evaluation function on this new dataset.

## "Blind" Data Augmentation

Grab your favourite LLM, and get it to make up 20 or so fake SMS messages.
You don't need to do anything fancy with an API here. It's OK to copy-and-paste
it from the LLM into here as JSON or CSV that you then turn into a dataframe.

Examples are as follows:
spam: [
  {
    "sender": "555-555-5555",
    "message": "Congratulations! You are pre-approved for a loan of $5000. Reply STOP to opt-out."
  },
  {
    "sender": "888-888-8888",
    "message": "Win a free vacation! Click link to claim: bit.ly/freevacation (limited time only)"
  },
  {
    "sender": "666-666-6666",
    "message": "Urgent action required: Your account will be suspended. Call now to resolve: 888-888-8888"
  },
  {
    "sender": "777-777-7777",
    "message": "You've won! $1000 gift card to your favorite store. Click link to claim: bit.ly/giftcard1000"
  },
  {
    "sender": "555-555-5555",
    "message": "Last chance to save! 75% off all items. Shop now: bit.ly/huge sale"
  },
  {
    "sender": "888-888-8888",
    "message": "Don't miss out! Exclusive offer for you: 50% off your next purchase. Use code SMS50"
  },
  {
    "sender": "666-666-6666",
    "message": "Warning: Your device may be infected. Click link to scan: bit.ly/devicescan"
  },
  {
    "sender": "777-777-7777",
    "message": "Breaking news: Major security breach. Click link to learn more: bit.ly/securitybreach"
  },
  {
    "sender": "555-555-5555",
    "message": "Important update: Your order has shipped. Track now: bit.ly/track12345"
  },
  {
    "sender": "888-888-8888",
    "message": "Reminder: Your appointment is tomorrow at 10am. Confirm by replying YES"
  }
]

ham: [
  {
    "sender": "555-555-5555",
    "message": "Hey there, just wanted to check in and see how everything is going. Let me know if you need anything."
  },
  {
    "sender": "888-888-8888",
    "message": "Hi, I'm following up on the project we discussed earlier. Do you have an update for me?"
  },
  {
    "sender": "666-666-6666",
    "message": "Hello, I wanted to confirm our meeting tomorrow at 2pm. Looking forward to it!"
  },
  {
    "sender": "777-777-7777",
    "message": "Hey, I just wanted to let you know that the report you requested is attached. Let me know if you have any questions."
  },
  {
    "sender": "555-555-5555",
    "message": "Hi, I hope you're doing well. I just wanted to follow up on the email I sent earlier this week. Let me know if you've had a chance to review it."
  },
  {
    "sender": "888-888-8888",
    "message": "Hello, I wanted to let you know that the event has been rescheduled for next Friday at 3pm. I hope you can still make it."
  },
  {
    "sender": "666-666-6666",
    "message": "Hey, I just wanted to remind you about the deadline for the project. Let me know if you need any help or resources."
  },
  {
    "sender": "777-777-7777",
    "message": "Hi, I hope you had a great weekend. I just wanted to check in and see if you have any updates for me."
  },
  {
    "sender": "555-555-5555",
    "message": "Hello, I wanted to let you know that the files you requested have been uploaded to the shared drive. Let me know if you have any issues accessing them."
  },
  {
    "sender": "888-888-8888",
    "message": "Hey, I just wanted to confirm that we're still on for our call at 11am. See you then!"
  }
]
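
One possible way to get from pasted JSON to rows with the `text` and `spam` columns that `TwoColumnDFDataset` reads is sketched below. The helper name `to_labelled_rows` is hypothetical, and the assumption is that your label column uses the string `'spam'` for spam (the dataset class only checks for that value).

```python
import json

# Two messages copied from the examples above; paste your own LLM output here.
pasted_spam = """
[
  {"sender": "555-555-5555",
   "message": "Congratulations! You are pre-approved for a loan of $5000. Reply STOP to opt-out."},
  {"sender": "888-888-8888",
   "message": "Win a free vacation! Click link to claim: bit.ly/freevacation (limited time only)"}
]
"""


def to_labelled_rows(json_text, label):
    """Turn a JSON list of {"sender", "message"} objects into text/spam rows."""
    records = json.loads(json_text)
    # Keep only the message body; the sender is made up and not used for training.
    return [{"text": record["message"], "spam": label} for record in records]


spam_rows = to_labelled_rows(pasted_spam, "spam")
print(spam_rows[0])
```

A list of dicts like `spam_rows` can be passed straight to `pandas.DataFrame.from_records`. It is worth reading the generated messages before trusting the labels, since an LLM will sometimes file an ordinary reminder under spam.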

In [None]:
mixtral_output_spam = pandas.DataFrame.from_records()

mixtral_output_ham = pandas.DataFrame.from_records()



Concatenate your fake data and the training data into a new dataframe (call it
`training3_df`)

In [None]:
training3_df =

Did that make any difference to accuracy, or the loss? (It might actually
make it worse if the synthetic data isn't representative of the test data.
But we would expect it will generalise better to other spam data in the future.)

-----

## Fast zero-shot learning

There are several ways that we can add more data. We have lots of
unlabelled data. Is there anything we can do with that?

One way is to use a large LLM to decide whether something is spam or not.

We can speed this up: we don't need a lot of text generated. Let's just
look at the first token.

Here's some code that gets the probability distribution that GPT-2 gives
for the next token in a sentence.

First we'll do an example. I keep getting Harry Potter and Star Wars mixed
up. Is Harry Potter the one with the wizards, or the one with the jedi?

If it's the one about wizards, then there'll be a higher probability
that the next word in "Harry Potter is a famous novel about..." will be
wizards than jedi.

Let's check it.

In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

def probability_distribution_of_next_token(input_text, return_top_k=10):
    # Step 1 & 2: Load the tokenizer and model
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    model.eval()  # Set the model to evaluation mode

    # Step 3: Preprocess input text
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # Step 4: Predict the next token's probability distribution
    with torch.no_grad():  # Disable gradient calculation for inference
        outputs = model(input_ids)
        predictions = outputs.logits

    # Step 5: Convert logits to probabilities
    probabilities = torch.softmax(predictions[:, -1, :], dim=-1)

    top_probs, top_ids = probabilities.topk(return_top_k)
    answers = []
    for prob, token_id in zip(top_probs[0], top_ids[0]):
        token = tokenizer.decode(token_id, clean_up_tokenization_spaces=True)
        answers.append({"token": token, "probability": prob.item()})
    return pandas.DataFrame.from_records(answers)

In [None]:
probability_distribution_of_next_token("Harry Potter is a famous novel about")


It's about wizards. I don't need to go any further.

So how would we apply this to spam detection? We could take a message and wrap it in
"Spam or ham? The message `...` is a " and see what comes next. Note that the tokenizer breaks "spam" into "sp" + "am", so we're looking for "sp" or "ham".

Pick a random spam message, and see what comes up first. You'll typically need to ask for
the top 10,000 tokens (`return_top_k=10000`) before both "sp" and "ham" show up.

In [None]:
p = probability_distribution_of_next_token()
p[p.token.isin(['sp', 'ham'])]

Now do a random ham message.

In [None]:
p = probability_distribution_of_next_token()
p[p.token.isin(['sp', 'ham'])]

Hopefully that's looking OK-ish and right most of the
time. So now we write a fast ham/spam tester. It would
require too much memory for a phone (probably) but at least we can label some
of our unlabelled data.
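
As a rough sketch of the decision step only, the code below builds the prompt described earlier and compares the probabilities of the two tokens. It works on a plain `{token: probability}` mapping so it runs without loading GPT-2; the function names and the probabilities are invented for illustration, and ties are (arbitrarily) resolved as ham.

```python
def build_prompt(message):
    """Wrap an SMS in the question we want the language model to complete."""
    return f"Spam or ham? The message `{message}` is a "


def decide_from_token_probabilities(token_probs, spam_token="sp", ham_token="ham"):
    """Return 'spam', 'ham', or None if neither token is in the mapping."""
    spam_p = token_probs.get(spam_token, 0.0)
    ham_p = token_probs.get(ham_token, 0.0)
    if spam_p == 0.0 and ham_p == 0.0:
        # Neither token made it into the top-k list, so we cannot decide.
        return None
    return "spam" if spam_p > ham_p else "ham"


prompt = build_prompt("Win a free vacation! Click link to claim")
# Invented probabilities standing in for the model's next-token distribution.
fake_distribution = {"very": 0.04, "sp": 0.0021, "ham": 0.0007}
print(prompt)
print(decide_from_token_probabilities(fake_distribution))
```

With the dataframe `p` returned by `probability_distribution_of_next_token`, the mapping could be built as `dict(zip(p.token, p.probability))`. Returning `None` rather than guessing makes it easy to drop messages the model had no opinion on.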

In [None]:
def llm_based_inference(text):


How good is it? Is it more accurate than our existing models? Test it out on our
known data and get some numbers.

In [None]:
initial_ham.text.map(llm_based_inference).value_counts()

In [None]:
initial_spam.text.map(llm_based_inference).value_counts()

Let's take 20 of the unlabelled messages and label them.

In [None]:
llm_labelled = unlabelled.sample(20)
llm_labelled['spam'] =

Add them to the dataset. Call it `training4_df`, and check out how our accuracy is now.

In [None]:
training4_df =

## Augmentation via translation

Let's take our original training spam and ham messages, and translate them into French
and then back again. Since we don't much care about the details of the translation
model (we aren't going to fine-tune it) we can just use Hugging Face's `pipeline`.
Search the Hugging Face hub for some models that do English to French and French to
English translation.

Only translate the training messages: if we also translate the test spam and ham
messages, we'll probably leak test information into the training data set.

Compare the results and see whether it has generated something reasonable and new.
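
The round trip itself can be sketched independently of any particular model. In the sketch below, `forward` and `backward` are assumed to be callables that map a list of strings to a list of strings; the toy translators are stand-ins invented so the example runs without downloading anything, and the helper name `back_translate` is hypothetical.

```python
def back_translate(texts, forward, backward):
    """Translate texts into a pivot language and back again.

    Returns only the round trips that differ from the original message.
    """
    pivot = forward(list(texts))
    round_trip = backward(pivot)
    augmented = []
    for original, candidate in zip(texts, round_trip):
        # An unchanged round trip adds nothing new to the training data, so skip it.
        if candidate.strip().lower() != original.strip().lower():
            augmented.append({"original": original, "augmented": candidate})
    return augmented


# Toy stand-ins for real translation models.
def toy_forward(texts):
    return [t.replace("free", "gratuit") for t in texts]


def toy_backward(texts):
    return [t.replace("gratuit", "complimentary") for t in texts]


pairs = back_translate(["Win a free holiday", "See you at 5"], toy_forward, toy_backward)
print(pairs)
```

A Hugging Face translation pipeline returns a list of dicts with a `translation_text` key, so a wrapper such as `lambda texts: [r["translation_text"] for r in en_fr_translator(texts)]` should fit the `forward` slot. Each augmented message keeps the label of the original it came from.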

In [None]:
from transformers import pipeline
en_fr_translator = pipeline("translation_en_to_fr", model="Helsinki-NLP/opus-mt-en-fr")



In [None]:
%%time

Results for English <-> Chinese are often remarkably terrible because
lots of information is assumed in Chinese that is explicit in English
and vice versa.

Just do a few until you see one that's hilariously bad!

In [None]:
en_zh_translator = pipeline("translation_en_to_zh", model="Helsinki-NLP/opus-mt-en-zh")
zh_en_translator = pipeline("translation_zh_to_en", model="Helsinki-NLP/opus-mt-zh-en")



Spam ends up particularly mangled.

Add your new English <-> French spam and ham content in, and call it `training5_df`.

In [None]:
translation_augmentation =

training5_df =

Have we improved the accuracy of our SMS model yet?

## Insert/delete/change augmentation

We will use the nlpaug library (https://github.com/makcedward/nlpaug), which you install
with `pip install nlpaug` or `conda install nlpaug`.

In [None]:
!pip install nlpaug

Pick two of the word augmentation techniques from https://github.com/makcedward/nlpaug/blob/master/example/textual_augmenter.ipynb
and apply them to your `train_df` examples.
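
To show the idea behind word-level delete and swap augmentation, here is a plain Python sketch. It is not nlpaug's implementation and does not use its API; the function names and the example message are made up for illustration.

```python
import random


def random_delete(text, rng, p=0.2):
    """Drop each word with probability p, always keeping at least one word."""
    words = text.split()
    kept = [w for w in words if rng.random() >= p]
    return " ".join(kept if kept else words[:1])


def random_swap(text, rng):
    """Swap one randomly chosen pair of adjacent words."""
    words = text.split()
    if len(words) < 2:
        return text
    i = rng.randrange(len(words) - 1)
    words[i], words[i + 1] = words[i + 1], words[i]
    return " ".join(words)


rng = random.Random(0)  # fixed seed so the run is repeatable
message = "Win a free vacation now by replying YES"
print(random_delete(message, rng))
print(random_swap(message, rng))
```

Each augmented message keeps the label of the message it was derived from. Small perturbations like these are cheap, but they can also remove the one word that made a message spam, so it is worth eyeballing a few outputs.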

In [None]:
import nlpaug

Make `training6_df` from your further augmentations.

In [None]:
training6_df =

How are our accuracy and loss now?

Notice how much data we have now. Not bad, given we started with only 10 samples!

# What else you would normally try

- Unsupervised data augmentation and uncertainty-aware self-training are helpful,
  if computationally heavy, techniques.

- Label spreading algorithms can work if there is a lot of unlabelled data.

- Unsupervised methods on the unlabelled data can show whether there are any obvious clusters.

- Explainable techniques: of the things that are distinct between the spam and ham
  SMS messages, is there anything that might be usefully indicative?