<a href="https://colab.research.google.com/github/arubisov/gmail-llm-ghostwriter/blob/main/Gmail_Finetune_Dataset_Creation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tuning Falcon-7B on Gmail Data

This notebook walks through how to create an LLM-powered Chrome plug-in that drafts e-mail responses that sound like you. The key ingredient is fine-tuning the Falcon-7B model on your own Gmail data.

The high-level sequence here will be:
1. Export your Gmail data into Google Drive
1. Wrangle the data into a format usable for fine-tuning: received-response pairs
1. Fine-tune Falcon-7B
1. Publish the model endpoint to HuggingFace
1. Create a Chrome plug-in that calls your model endpoint

## Export your Gmail data

To gather your data, use the [Google Takeout](https://takeout.google.com/) service to export your Mail data. Select Mail only, select the Sent label only, and save it to Google Drive. This will create the export in a new Takeout folder in the root folder of your Drive. For me this was a 3.8GB tarball.

## Load your Gmail data to your notebook

I strongly suggest running this workload on Google Colab. You'll eventually want it anyway to leverage GPUs for fine-tuning. The major advantage is that Colab can work natively with Drive, where our file export is.

If you're running this locally, you'll want to either navigate through the Drive GUI and download it, or use `wget` to download via command line.

In [40]:
import sys
import os
from pathlib import Path

if 'google.colab' in sys.modules:
    from google.colab import drive
    DRIVE=True
    drive.mount('/content/drive')
    path = Path("/content/drive/My Drive/Takeout/")
    print('Running on Colab')
else:
    DRIVE=False
    path = Path("./")
    print('Running on localhost')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Running on Colab


### [Optional] Using wget to download from Drive

Credit to Anjan Chandra Paudel for his succinct Medium article on [how to download Google Drive files using wget](https://medium.com/@acpanjan/download-google-drive-files-using-wget-3c2c025a8b99).

First, use the Drive GUI to manage the export file's sharing settings to enable "Anyone with the link" to view the file. Then copy the link. This will give you a URL with a hashed file ID. Substitute this ID into the FILEID variable below.
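
As a small convenience, the ID can be pulled out of the copied link programmatically. This is a minimal sketch that assumes the share link has the usual `https://drive.google.com/file/d/<ID>/view` shape; the helper name `extract_drive_file_id` and the example URL are hypothetical.

```python
import re


def extract_drive_file_id(share_url):
    """Return the file ID from a Drive share link, or None if no ID is found."""
    # Assumption: the ID is the path segment that follows /file/d/
    match = re.search(r"/file/d/([A-Za-z0-9_-]+)", share_url)
    return match.group(1) if match else None


# Placeholder link for illustration only; substitute the link you copied.
example_url = "https://drive.google.com/file/d/abc123_XYZ-789/view?usp=sharing"
FILEID = extract_drive_file_id(example_url)
print(FILEID)
```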

In [38]:
# only run if you want to use wget to download the export locally.

if not DRIVE:
  FILEID = "1eUoM8ZHrSbrzg1QQ9gXWXmeZkXLwTqJT"   # update this ID with the URL you obtained
  NEW_FILENAME = "takeout.tgz"

  cmd = f"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILEID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id={FILEID}" -O {path/NEW_FILENAME} && rm -rf /tmp/cookies.txt"""

  !{cmd}

### Untar the Gmail export

Whether you're in Colab or running locally, continue here. You'll need to extract the Takeout tarball.

In [39]:
if not os.path.exists(path/"Takeout/Mail/Sent.mbox"):
    !cd "{path}" && tar -xzvf takeout*.tgz

Takeout/Mail/Sent.mbox
Takeout/archive_browser.html


## Load and explore your Gmail data

All your Gmail data will be in a file called `Sent.mbox`. I suggest this great primer on [how to work with `.mbox` files](https://gist.github.com/benwattsjones/060ad83efd2b3afc8b229d41f9b246c4). I use some of these concepts below.
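
To get a feel for the `mailbox` API before touching the real export, here is a minimal sketch that writes a one-message toy mbox to a temporary file and reads it back. The addresses, subject, and file name are made up for illustration; only the standard-library `mailbox` calls mirror what is used on `Sent.mbox`.

```python
import mailbox
import tempfile
from pathlib import Path

# A toy mbox: one "From " separator line, headers, a blank line, then the body.
raw = (
    "From sender@example.com Mon Jan  1 09:00:00 2024\n"
    "From: Me <me@example.com>\n"
    "To: Alice <alice@example.com>\n"
    "Subject: Re: Meeting\n"
    "X-Gmail-Labels: Sent\n"
    "\n"
    "Sounds good, see you at 3pm.\n"
)

with tempfile.TemporaryDirectory() as tmp:
    mbox_path = Path(tmp) / "toy.mbox"
    mbox_path.write_text(raw)

    toy_mbox = mailbox.mbox(str(mbox_path))
    count = len(toy_mbox)              # number of messages in the file
    message = toy_mbox[0]              # messages are keyed by integer position
    subject = message["Subject"]       # headers are read like dict entries
    labels = message["X-Gmail-Labels"]
    body = message.get_payload()       # a string for non-multipart messages
    toy_mbox.close()

print(count, subject, labels, body.strip())
```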

In [45]:
%%time
import mailbox

mbox_obj = mailbox.mbox(path/"Takeout/Mail/Sent.mbox")

num_entries = len(mbox_obj)
print(f"Number of Sent e-mails: {num_entries}")

Number of Sent e-mails: 24987
CPU times: user 1min 6s, sys: 20.2 s, total: 1min 26s
Wall time: 1min 37s


Almost 25,000 e-mails!  We can see what a single one of these looks like, to get a sense of the data we're working with.

In [98]:
%%capture
# remove the above magic line to view the output. I'm suppressing it here for privacy.
print(mbox_obj[0].get_payload(0).get_payload())

Each message in the mbox contains its entire thread; it isn't nicely separated as message and response. We're going to need to split threads into replies, and for this we'll use [Zapier's Email Reply Parser](https://github.com/zapier/email-reply-parser) library.

We want to parse these e-mails to produce string pairs. The first string will be the e-mail you received; the second string will be the e-mail you responded with. This structure will enable us to fine-tune the model to teach it that for a given received message, we want to write a particular response.
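
To make the target structure concrete, here is a deliberately simplified sketch of turning one thread into a (received, response) pair. It is a stand-in for illustration, not the Email Reply Parser: it assumes the quoted message is marked with leading `>` characters and introduced by an `On ... wrote:` line, and the function name `split_reply_and_quote` is hypothetical.

```python
import re

# Assumption: the quoted message is introduced by a single "On ... wrote:" line.
QUOTE_HEADER = re.compile(r"^On .* wrote:$")


def split_reply_and_quote(text):
    """Return (reply, previous) from a plain-text thread with '>' quoting."""
    reply_lines, quoted_lines = [], []
    for line in text.splitlines():
        if QUOTE_HEADER.match(line):
            continue  # drop the quote header itself
        if line.startswith(">"):
            quoted_lines.append(line.lstrip(">").strip())
        else:
            reply_lines.append(line)
    return "\n".join(reply_lines).strip(), "\n".join(quoted_lines).strip()


thread = (
    "Sounds good, see you at 3pm.\n"
    "\n"
    "On Mon, Jan 1, 2024 at 9:00 AM Alice <alice@example.com> wrote:\n"
    "> Are you free to meet this afternoon?\n"
)
reply, previous = split_reply_and_quote(thread)
print((previous, reply))
```

Real threads are far messier than this (nested quotes, forwarded blocks, signatures), which is why the notebook leans on a dedicated parser plus the wrangling code that follows.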

What follows is a _lot_ of wrangling methods. Another particularly useful snippet came from Stack Overflow on [how to slice over a generator](https://stackoverflow.com/questions/5234090/how-to-take-the-first-n-items-from-a-generator-or-list) in order to fetch only the first few emails.
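
The slicing idea can be sketched in isolation with `itertools.islice`. The generator `numbered_messages` below is a made-up stand-in for `mbox_obj.iteritems()`, which likewise yields `(key, message)` pairs.

```python
import itertools


def numbered_messages():
    """Stand-in generator: yields (key, message) pairs indefinitely."""
    key = 0
    while True:
        yield key, f"message {key}"
        key += 1


# Take only the first three items without materializing the whole stream.
first_three = list(itertools.islice(numbered_messages(), 3))

# Take a single item at position 300 (start inclusive, stop exclusive).
window = list(itertools.islice(numbered_messages(), 300, 301))

print(first_three)
print(window)
```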

In [100]:
try:
    from email_reply_parser import EmailReplyParser
except ModuleNotFoundError:
    !pip install email_reply_parser

Collecting email_reply_parser
  Downloading email_reply_parser-0.5.12-py3-none-any.whl (4.1 kB)
Installing collected packages: email_reply_parser
Successfully installed email_reply_parser-0.5.12


We'll set up some custom Exceptions. When iterating through the email messages, these are the common reasons to skip a given message.

In [105]:
class AuthorError(Exception):
    'Email is not labeled Sent.'


class IsNotReply(Exception):
    'Email is not a reply to anything.'


class IsEmptyForward(Exception):
    'Email is a forward with no additional text.'

Next are the text processing methods. These took a _very_ long time to work out. Normalizing was particularly challenging: this required dealing with quoted printables, and realizing that the encoding method changed from `latin9` to `utf-8` at some point during the life of my Gmail account.
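
The decoding fallback can be illustrated on its own. This is a minimal sketch, assuming the input is the raw quoted-printable bytes of a message body; the helper name `decode_quoted_printable` is hypothetical and not part of the notebook.

```python
import quopri


def decode_quoted_printable(raw_bytes):
    """Decode quoted-printable bytes, trying UTF-8 first and latin9 second."""
    decoded = quopri.decodestring(raw_bytes)
    try:
        return decoded.decode("utf-8")
    except UnicodeDecodeError:
        # Older messages may have been encoded as ISO-8859-15 (latin9).
        return decoded.decode("latin9")


utf8_text = decode_quoted_printable(b"It=E2=80=99s a caf=C3=A9")
latin9_text = decode_quoted_printable(b"na=EFve caf=E9")
print(utf8_text)
print(latin9_text)
```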

In [106]:
def get_html_text(html):
    try:
        # parser used to be lxml. needed? if so, use pip install lxml
        return bs4.BeautifulSoup(html, 'html.parser').body.get_text(' ', strip=True)
    except AttributeError: # message contents empty
        return None


def replace_single_newline(message):
    split_message = message.split('\n\n')
    split_message = [' '.join(sub.split('\n')) for sub in split_message]
    return '\n\n'.join(split_message)


def normalize_messages(message):
    REPLACEMENTS = {
        "=E2=80=99": "'",
        "=92s": "'s",
        " =96": " -",
        "=E2=80=A6": "...",
        "=E2=80=93": "-",
        "*": "",
        " :)": "",
        " :p": "",
        " :P": "",
        " :D": ""
    }

    for old, new in REPLACEMENTS.items():
        message = message.replace(old, new)

    message = re.sub(r'http\S+', '', message)

    try:
        message = quopri.decodestring(message).decode('utf-8')
    except UnicodeDecodeError:
        message = quopri.decodestring(message).decode('latin9')

    message = replace_single_newline(message)

    return message

Dealing with e-mail signatures was a hell of its own. While there are some frequently occurring patterns, they are fundamentally unstructured and can take any form under the sun.

I tried using the [Talon](https://github.com/mailgun/talon) package from [Mailgun](https://www.mailgun.com/) to extract quotes and signatures. It offers two approaches: brute force and SVM-based. After hand-picking 8 edge cases that I hoped it would detect, I found that brute force caught none of them and the SVM caught only one. So I would need to invest in regular expressions.

I tried using [autoregex.xyz](autoregex.xyz) and it was a complete bust: it failed to generate anything. ChatGPT-4 gave me a reasonable start that I had to tweak, and it also generated a few additional examples of common closing sentences.

When I was still detecting lots of signatures, I eventually realized I needed to also capture just short blocks of text that came at the end of a message. Here again I used ChatGPT for the cold start and tweaked. At this point I was capturing all 8 edge cases, and subsequent spot checks turned out acceptable.
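
The "short block at the end" heuristic can be sketched on a toy message. This is an illustrative variant rather than the notebook's exact pattern: it uses `\Z` to pin the block to the very end of the string, whereas the notebook's pattern ends in `$` under `re.MULTILINE`. The function name and sample message are made up.

```python
import re

# Two or more consecutive lines of 3 to 50 characters, ending exactly at the
# end of the string. The \Z anchor is an assumption of this sketch.
TRAILING_SHORT_BLOCK = re.compile(r"(?:^.{3,50}\n?){2,}\Z", flags=re.MULTILINE)


def strip_trailing_short_block(body):
    """Return (cleaned_body, signature); signature is '' when nothing matches."""
    match = TRAILING_SHORT_BLOCK.search(body)
    if match is None:
        return body, ""
    return body[:match.start()], body[match.start():]


message = (
    "Thanks for sending the report over, I will read it tonight and reply tomorrow.\n"
    "\n"
    "John Smith\n"
    "Acme Corp\n"
    "555-1234"
)
cleaned, signature = strip_trailing_short_block(message)
print(repr(cleaned))
print(repr(signature))
```

One caveat of any such heuristic: a genuinely short final paragraph of two or more lines looks exactly like a signature, so spot checks on real data remain necessary.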

In [107]:
def remove_email_signature(body):

    # Split into lines so we can scan backwards from the end for footer markers
    lines = body.split('\n')

    footer_started = False
    signature = ''

    # single-line patterns
    for idx, line in enumerate(lines[::-1]):  # Reverse iteration as footers/signatures are typically at the end of emails
        # A series of patterns that may indicate the start of a footer
        patterns = [r"^(===*|---*|Unsubscribe|Sent from|LinkedIn|\n\n\n\n\n|Tel:)",
                    r"^.{0,10}([B|b]est|[R|r]egards|Have a|[C|c]heers|[S|s]incerely|[T|t]ake care|[L|l]ook(?:ing)? forward|Fond|Kind|Yours).{0,20}$",
                    r"(?:.*\||\/)?\s*(\S+@\S+\.\S+[^\|]*|Tel[^\|]*)(?:\||\/.*)+",
                    r"(?:.*\||\/)+\s*(\S+@\S+\.\S+[^\|]*|Tel[^\|]*)(?:\||\/.*)?"]
        for pattern in patterns:
            match = re.search(pattern, line)
            if match and (idx >= footer_started):
                footer_started = idx

    if footer_started:
        signature = '\n'.join(lines[-(footer_started+1):])
        cleaned_body = '\n'.join(lines[:-(footer_started+1)])
    else:
        cleaned_body = body

    # check for multi-line patterns
    # this looks for short message blocks at the end of a message, where each line
    # is 3 to 50 characters long, and there are at least 2 of these lines.
    patterns = [r"(?:^.{3,50}$\n?){2,}$"]
    footer_started = False

    for pattern in patterns:
        match = re.search(pattern, cleaned_body, flags=re.MULTILINE)
        if match and (not footer_started or (match.start() < footer_started)):
            footer_started = match.start()

    if footer_started:
        signature = '\n'.join([cleaned_body[footer_started:], signature])
        cleaned_body = cleaned_body[:footer_started]

    return cleaned_body, signature


In [108]:
import mailbox
import bs4
import quopri
import re
from email_reply_parser import EmailReplyParser, EmailMessage

REPLY_EQUALS_TEXT_CHAR_MATCH=100


class GmailMboxMessage():
    def __init__(self, email_data):
        if not isinstance(email_data, mailbox.mboxMessage):
            raise TypeError('Variable must be type mailbox.mboxMessage')
        self.email_data = email_data

    def parse_email(self):
        # A missing X-Gmail-Labels header returns None; treat it as "no labels"
        self.email_labels = self.email_data['X-Gmail-Labels'] or ''
        if 'Sent' not in self.email_labels:
            raise AuthorError

        self.email_date = self.email_data['Date']
        self.email_from = self.email_data['From']
        self.email_to = self.email_data['To']
        self.email_subject = self.email_data['Subject']
        self.email_text = self.read_email_payload()
        self.email_plain_text = [msg_text for (content_type, encoding, msg_text) in self.email_text if 'text/plain' in content_type and 'base64' not in encoding].pop()
        self.email_reply = EmailReplyParser.parse_reply(self.email_plain_text)

        if self.email_reply[:28] == '---------- Forwarded message':
            raise IsEmptyForward

        if self.email_plain_text.strip()[-REPLY_EQUALS_TEXT_CHAR_MATCH:] == self.email_reply.strip()[-REPLY_EQUALS_TEXT_CHAR_MATCH:]:
            self.email_reply = self.email_plain_text
            raise IsNotReply

        self.email_previous = self.get_previous_email()

        # am I just forwarding an e-mail I received to someone else?
        if self.email_previous[:28] == '---------- Forwarded message':
            raise IsEmptyForward

        self.email_previous = normalize_messages(self.email_previous)
        self.email_reply = normalize_messages(self.email_reply)


    def read_email_payload(self):
        email_payload = self.email_data.get_payload()
        if self.email_data.is_multipart():
            email_messages = list(self._get_email_messages(email_payload))
        else:
            email_messages = [email_payload]
        return [self._read_email_text(msg) for msg in email_messages]

    def _get_email_messages(self, email_payload):
        for msg in email_payload:
            if isinstance(msg, (list,tuple)):
                for submsg in self._get_email_messages(msg):
                    yield submsg
            elif msg.is_multipart():
                for submsg in self._get_email_messages(msg.get_payload()):
                    yield submsg
            else:
                yield msg

    def _read_email_text(self, msg):
        content_type = 'NA' if isinstance(msg, str) else msg.get_content_type()
        encoding = 'NA' if isinstance(msg, str) else msg.get('Content-Transfer-Encoding', 'NA')
        if 'text/plain' in content_type and 'base64' not in encoding:
            msg_text = msg.get_payload()
            msg_text = msg_text.replace('=\r\n', '').replace('\r\n', '\n')
            msg_text = re.sub(r'^(.+)$\n^(>*)([\S]*)( )*wrote:$', '\\1\\3 wrote:', msg_text, flags=re.MULTILINE)
        elif 'text/html' in content_type and 'base64' not in encoding:
            msg_text = get_html_text(msg.get_payload())
        elif content_type == 'NA':
            msg_text = get_html_text(msg)
        else:
            msg_text = None
        return (content_type, encoding, msg_text)

    def get_previous_email(self):

        rollback = self._rollback_email_chain()

        previous_email = EmailReplyParser.parse_reply(rollback)

        cleaned_previous_email, signature = remove_email_signature(previous_email)

        return cleaned_previous_email

    def _rollback_email_chain(self):
        # Remove the extracted first message from the full email chain
        rollback = self.email_plain_text.replace(self.email_reply, "", 1).strip()

        # Remove the quote header (On DATE, PERSON wrote:)
        rollback = re.sub('^On.*wrote:$', '', rollback, flags=re.MULTILINE)

        # Remove leading ">" symbols from each line
        rollback = re.sub(r"^>( )*", "", rollback, flags=re.MULTILINE)

        rollback = rollback.strip()

        return rollback

In [109]:
%%capture
# again, capturing the output for privacy. remove to view your data.
import itertools

for key, email_obj in itertools.islice(mbox_obj.iteritems(), 300, 301):
    print(f"Parsing email {key+1} of {num_entries}")

    email_data = GmailMboxMessage(email_obj)

    try:
        email_data.parse_email()
    except IndexError:
        print('Skipping - email had no body, such as an unsubscribe.')
        continue
    except AuthorError:
        print('Skipping - not sent by me.')
        continue
    except IsNotReply:
        print('Skipping - sent by me but not a response.')
        continue
    except IsEmptyForward:
        print('Skipping - I forwarded with no extra text.')
        continue

    print('\n\n### CASE', key)
    print('--- FROM:', email_data.email_from)
    print('--- TO:', email_data.email_to)
    print('--- DATE:', email_data.email_date)
    print('--- LABELS:', email_data.email_labels)
    print('--- SUBJECT:', email_data.email_subject)
    print('--- TEXT:', email_data.email_plain_text)
    print('--- REPLY:', email_data.email_reply)
    print('--- PREVIOUS:', email_data.email_previous)

### Create the fine-tuning dataset

At this point we're successfully creating structure out of our unstructured e-mails. Now we need to get these into the format we want for fine-tuning. We'll iterate through the generator, and for each message that isn't excluded, construct a dict record in the desired format.

For an Alpaca instruction set, we want something like this:

```json
{
  "instruction": "Give three tips for staying healthy.",
  "input": "",
  "output": "1. Eat a balanced and nutritious diet: Make sure your meals are inclusive of a variety of fruits and vegetables, lean protein, whole grains, and healthy fats. This helps to provide your body with the essential nutrients to function at its best and can help prevent chronic diseases.\n\n2. Engage in regular physical activity: Exercise is crucial for maintaining strong bones, muscles, and cardiovascular health. Aim for at least 150 minutes of moderate aerobic exercise or 75 minutes of vigorous exercise each week.\n\n3. Get enough sleep: Getting enough quality sleep is crucial for physical and mental well-being. It helps to regulate mood, improve cognitive function, and supports healthy growth and immune function. Aim for 7-9 hours of sleep each night."
}
```

For the [OpenAssistant](https://huggingface.co/datasets/OpenAssistant/oasst1/blob/main/README.md) format, and specifically for something like what was done in the [Guanaco demo](https://huggingface.co/datasets/timdettmers/openassistant-guanaco), we want something like:

```json
{
  "text": "### Human: I want to start doing astrophotography as a hobby, any suggestions what could i do?### Assistant: Getting started in astrophotography can seem daunting, but with some patience and practice, you can become a master of the craft."
}
```

It's not immediately clear which standard to take, and if using instructions, which instruction to supply. For the Alpaca format, it's clear that the input will be the "previous" email, the output will be my "reply", and we can provide some form of standard instruction for e-mail draft generation - perhaps even adding a constitutional element, taking a page from Constitutional AI.
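
For comparison, the Alpaca mapping described above could be sketched as follows. The instruction wording is a placeholder, since no wording has been settled on, and the helper name `to_alpaca_record` is hypothetical.

```python
import json

# Placeholder instruction text; an assumption for illustration only.
INSTRUCTION = "Draft a reply to the following email in my usual writing style."


def to_alpaca_record(previous_email, reply):
    """Map a (received, response) pair onto the Alpaca instruction/input/output keys."""
    return {
        "instruction": INSTRUCTION,
        "input": previous_email,
        "output": reply,
    }


record = to_alpaca_record(
    "Are you free to meet this afternoon?",
    "Sounds good, see you at 3pm.",
)
print(json.dumps(record, indent=2))
```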

For this demo, we'll construct the Guanaco style, in which we use Email and Response headers.

**Note to self.** Develop a better understanding of which structure to use when, and implications of this choice.

In [110]:
from tqdm.notebook import tqdm

dataset = []

t = tqdm(itertools.islice(mbox_obj.iteritems(), 0, num_entries), total=num_entries)

for key, email_obj in t:

    email_data = GmailMboxMessage(email_obj)

    try:
        email_data.parse_email()
    except IndexError:
        # print('Skipping - email had no body, such as an unsubscribe.')
        continue
    except AuthorError:
        # print('Skipping - not sent by me.')
        continue
    except IsNotReply:
        # print('Skipping - sent by me but not a response.')
        continue
    except IsEmptyForward:
        # print('Skipping - I forwarded with no extra text.')
        continue

    if email_data.email_previous != '' and email_data.email_reply != '':
        example = {
            "text": ''.join(['### From: ', email_data.email_to if email_data.email_to is not None else '',
                           '\n### Email: ', email_data.email_previous,
                           '\n### Response: ', email_data.email_reply])
        }

        dataset.append(example)



In [111]:
print(f"Final dataset size: {len(dataset)}")

Final dataset size: 5979


Finally, we'll save this dataset as JSON back to our filepath (Drive or local).

In [112]:
import json

with open(path/'finetune-dataset.json', 'w') as fh:
    json.dump(dataset, fh)