In [1]:
#imports
import re
import pandas as pd
from typing import List, Dict, Any
from datetime import datetime
import html

Dataset Definition

In [2]:
email_thread_summaries_dataset = pd.read_csv("email_thread_summaries.csv")
email_thread_summaries_dataset.head()

Unnamed: 0,thread_id,summary
0,1,The email thread discusses the Master Terminat...
1,2,A lunch meeting has been scheduled for May 5th...
2,3,Ben is updating a friend on his progress with ...
3,4,The recipient of the email thread initially ex...
4,5,The email thread discusses the long form confi...


In [3]:
email_thread_details_dataset = pd.read_csv("email_thread_details.csv")
email_thread_details_dataset.head()

Unnamed: 0,thread_id,subject,timestamp,from,to,body
0,1,FW: Master Termination Log,2002-01-29 11:23:42,"Gossett, Jeffrey C. JGOSSET","['Giron', 'Darron C. Dgiron', 'Love', 'Phillip...",\n\n -----Original Message-----\nFrom: =09Ther...
1,1,FW: Master Termination Log,2002-01-31 12:50:00,"Theriot, Kim S. KTHERIO","['Murphy', 'Melissa Mmurphy', 'Gossett', 'Jeff...",\n\n -----Original Message-----\nFrom: =09Panu...
2,1,FW: Master Termination Log,2002-02-05 15:03:35,"Theriot, Kim S. KTHERIO","['Murphy', 'Melissa Mmurphy', 'Anderson', 'Dia...",Note to Stephanie Panus....\n\nStephanie...ple...
3,1,FW: Master Termination Log,2002-02-05 15:06:25,"Theriot, Kim S. KTHERIO","['Hall', 'D. Todd Thall', 'Sweeney', 'Kevin Ks...",\n\n -----Original Message-----\nFrom: =09Panu...
4,1,FW: Master Termination Log,2002-05-28 07:20:35,"Kelly, Katherine L. KKELLY","['Germany', 'Chris Cgerman']",\n\n -----Original Message-----\nFrom: =09McMi...


In [4]:
len(email_thread_details_dataset)

21684

Explore the Datasets

In [5]:
print(email_thread_details_dataset['timestamp'].dtype)

object


In [6]:
email_thread_details_dataset['timestamp']=pd.to_datetime(email_thread_details_dataset['timestamp'])

In [7]:
FilteredDataset= email_thread_details_dataset[(email_thread_details_dataset['thread_id'] ==27)].sort_values(by='timestamp', ascending=True)
FilteredDataset.head(20)

Unnamed: 0,thread_id,subject,timestamp,from,to,body
147,27,RE: Admission Visit,2000-02-10 06:36:00,Benjamin Rogers,['Meg Brooks <Meg.Brooks@bus.utexas.edu'],Thanks for the fast reply. I changed it onlin...
148,27,RE: Admission Visit,2000-02-21 11:23:00,Benjamin Rogers,['Meg Brooks <Meg.Brooks@bus.utexas.edu> @ ENR...,Thanks for the informative infomation session...
149,27,RE: Admission Visit,2000-03-05 23:49:00,Benjamin Rogers,['Meg Brooks <Meg.Brooks@bus.utexas.edu> @ ENR...,Meg:\nI was wondering if you are able to give ...
150,27,RE: Admission Visit,2000-03-06 04:48:00,Benjamin Rogers,['Meg Brooks <Meg.Brooks@bus.utexas.edu> @ ENR...,Thanks for your fast response. I really hope ...
151,27,RE: Admission Visit,2000-03-07 11:10:00,Benjamin Rogers,['Meg Brooks <Meg.Brooks@bus.utexas.edu> @ ENR...,Meg:\nI would like to make sure that I have fu...


Data Cleaning

In [8]:
def threads_preprocess(email_thread_details_dataset):

#normalizing subject
  email_thread_details_dataset['subject'] = (
      email_thread_details_dataset['subject']
      .str.replace(r'^\s*((re|fw|fwd)\s*:\s*)+', '', regex=True, case=False)
      .str.strip()
  )

  #converting timestamp and sort
  email_thread_details_dataset['timestamp'] = pd.to_datetime(email_thread_details_dataset['timestamp'])
  email_thread_details_dataset = email_thread_details_dataset.sort_values(['thread_id', 'timestamp'])

#deduplicating by sender+timestamp+recipient
  email_thread_details_dataset = email_thread_details_dataset.drop_duplicates(
      subset=['thread_id', 'from', 'timestamp', 'to'],
      keep='first'
  )

  #removing threads with only repeated content
  threads = email_thread_details_dataset.groupby('thread_id', group_keys=False)
  email_thread_details_dataset = threads.filter(lambda x: x['body'].nunique() > 1)

  #anonymization

  # def anonymize_text(text):
  #     # Keep only first names
  #     text = re.sub(r'\b([A-Z][a-z]+)\s+[A-Z][a-z]+\b', r'\1', text)
  #     # Remove sensitive words
  #     text = re.sub(r'\b(password|pwd|confidential)\b', '', text, flags=re.I)
  #     # Replace email, phone, URL, IP, path, numbers
  #     text = re.sub(r'\b[\w\.-]+@[\w\.-]+\.\w+\b', 'USERNAME@DOMAIN.COM', text)
  #     text = re.sub(r'\b\d{10,}\b', 'PHONENUMBER', text)
  #     text = re.sub(r'\b\d+\b', 'NUMBER', text)
  #     text = re.sub(r'http[s]?://\S+', 'HTTP://LINK', text)
  #     text = re.sub(r'\b\d{1,3}(?:\.\d{1,3}){3}\b', 'IPADDRESS', text)
  #     text = re.sub(r'(?:[A-Za-z]:)?[/\\][\w/\\.-]+', 'PATH', text)
  #     return text

  # email_thread_details_dataset['body'] = email_thread_details_dataset['body'].apply(anonymize_text)


  #email_thread_details_dataset=email_thread_details_dataset.groupby("thread_id")

  return email_thread_details_dataset

In [9]:
class EmailBodyPreprocessor:
    def __init__(self):
#common email patterns
        self.forward_patterns = [
            r'----- Forwarded by .*? on \d{2}/\d{2}/\d{4}.*?-----',
            r'-----Original Message-----',
            r'----- Forwarded Message -----',
            r'From:.*?Sent:.*?To:.*?Subject:',
        ]

        self.signature_patterns = [
            r'\nStephanie Panus.*?\d{3}\.\d{3}\.\d{4}',
            r'\nBrian.*?\n\n',
            r'\nThanks,?\n.*',
            r'\nBest regards,?\n.*',
            r'\nSincerely,?\n.*',
            r'\nRegards,?\n.*',
            r'\n-\s*\n.*',
            r'ph:\s*\d{3}\.\d{3}\.\d{4}.*?fax:\s*\d{3}\.\d{3}\.\d{4}',
        ]

#email header patterns
        self.header_patterns = [
            r'From:\s*(.*?)\n',
            r'Sent:\s*(.*?)\n',
            r'To:\s*(.*?)\n',
            r'Cc:\s*(.*?)\n',
            r'Subject:\s*(.*?)\n',
        ]

    def clean_encoding_artifacts(self, text: str) -> str:
        """Clean encoding issues from emails."""
#removing =09, =20, etc.
        text = re.sub(r'=\d{2}', ' ', text)
        # Fix line breaks with = at end
        text = re.sub(r'=\n', '', text)
#removing non-breaking spaces and other special chars
        text = re.sub(r'[\xa0\u200b\u200c\u200d]', ' ', text)
#HTML entities
        text = html.unescape(text)
        return text

    def extract_email_headers(self, text: str) -> Dict[str, Any]:
        """Extract email header information from body."""
        headers = {
            'from': None,
            'sent_date': None,
            'to': [],
            'cc': [],
            'subject': None,
            'is_forwarded': False,
            'is_replied': False
        }

#checking if this is forwarded/replied email
        if '-----Original Message-----' in text or '----- Forwarded by' in text:
            headers['is_forwarded'] = True

        if 'Re:' in text[:100] or 'RE:' in text[:100] or 'Fwd:' in text[:100] or 'FW:' in text[:100]:
            headers['is_replied'] = True

#trying to extract headers
        for pattern in self.header_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                key = pattern.split(':')[0].lower()
                value = match.group(1).strip()
                if key == 'from':
                    headers['from'] = value
                elif key == 'sent':
                    headers['sent_date'] = value
                elif key == 'to':
                    # Split multiple recipients
                    recipients = re.split(r'[;,]\s*', value)
                    headers['to'] = [r.strip() for r in recipients if r.strip()]
                elif key == 'cc':
                    recipients = re.split(r'[;,]\s*', value)
                    headers['cc'] = [r.strip() for r in recipients if r.strip()]
                elif key == 'subject':
                    headers['subject'] = value

        return headers

    def remove_quoted_text(self, text: str) -> str:
        """Remove quoted/replied text from email body."""
        lines = text.split('\n')
        cleaned_lines = []
        in_quoted_section = False
        quote_depth = 0

        for line in lines:
#checking for forwarded/quote markers
            if any(pattern in line for pattern in [
                '-----Original Message-----',
                '----- Forwarded by',
                'From: ',
                'Sent: ',
                'To: ',
                'Subject: '
            ]):
                if '-----Original Message-----' in line or '----- Forwarded by' in line:
                    in_quoted_section = True
                    quote_depth += 1
                continue

#checking for email header lines in quotes
            if in_quoted_section:
                if line.strip() == '' and quote_depth == 1:
#emptying line might end the header section
                    continue
                elif re.match(r'^\s*On.*wrote:$', line):
#common reply pattern
                    continue
                elif line.strip().startswith('>') or line.strip().startswith('|'):
#quoted text markers
                    continue
                elif quote_depth > 0 and not line.strip():
#decreasing depth on empty lines in quotes
                    quote_depth -= 1
                    if quote_depth == 0:
                        in_quoted_section = False
                    continue

            if not in_quoted_section:
                cleaned_lines.append(line)

        return '\n'.join(cleaned_lines)

    def remove_signatures(self, text: str) -> str:
        """Remove email signatures."""
        for pattern in self.signature_patterns:
            text = re.sub(pattern, '', text, flags=re.DOTALL | re.IGNORECASE)

#removing common closing patterns
        closing_patterns = [
            r'\n\s*--\s*\n.*',
            r'\n\s*---\s*\n.*',
            r'\nSent from my.*',
            r'\nConfidentiality Notice.*',
        ]

        for pattern in closing_patterns:
            text = re.sub(pattern, '', text, flags=re.DOTALL | re.IGNORECASE)

        return text

    def clean_email_body(self, text: str) -> str:
        """Main cleaning function for email body."""
        if not isinstance(text, str):
            return ""

#cleaning encoding artifacts
        text = self.clean_encoding_artifacts(text)

#extracting headers (store separately if needed)
        headers = self.extract_email_headers(text)

#removing quoted/replied text
        text = self.remove_quoted_text(text)

#removing signatures
        text = self.remove_signatures(text)

#normalizing whitespace
        text = re.sub(r'\s+', ' ', text)
        text = re.sub(r'\n\s*\n', '\n\n', text)  # Preserve paragraph breaks

#cleaning up common email artifacts
        text = re.sub(r'\s*<\s*File:.*?>\s*', ' [ATTACHMENT] ', text)
        text = re.sub(r'\[.*?@.*?\]', '', text)

#triming and returning
        text = text.strip()

        return text


    def preprocess_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """Preprocess the entire dataframe."""
        processed_df = df.copy()

        processed_df['body'] = processed_df['body'].apply(self.clean_email_body)

        return processed_df

    def preprocess_threads(self, df: pd.Series) -> pd.Series:
        """Preprocess the entire dataframe."""
        processed_df = df.copy()

        processed_df = processed_df.apply(self.clean_email_body)

        return processed_df


def clean_email(df_email_details):
    df = df_email_details.copy()

    preprocessor = EmailBodyPreprocessor()

    preprocess_df = preprocessor.preprocess_dataframe(df)

    return preprocess_df

def clean_thread_series(thread):

    preprocessor = EmailBodyPreprocessor()

    preprocess_df = preprocessor.preprocess_threads(thread)

    return preprocess_df

In [10]:
cleaned_threads = threads_preprocess(email_thread_details_dataset)
cleaned_threads

Unnamed: 0,thread_id,subject,timestamp,from,to,body
0,1,Master Termination Log,2002-01-29 11:23:42,"Gossett, Jeffrey C. JGOSSET","['Giron', 'Darron C. Dgiron', 'Love', 'Phillip...",\n\n -----Original Message-----\nFrom: =09Ther...
1,1,Master Termination Log,2002-01-31 12:50:00,"Theriot, Kim S. KTHERIO","['Murphy', 'Melissa Mmurphy', 'Gossett', 'Jeff...",\n\n -----Original Message-----\nFrom: =09Panu...
2,1,Master Termination Log,2002-02-05 15:03:35,"Theriot, Kim S. KTHERIO","['Murphy', 'Melissa Mmurphy', 'Anderson', 'Dia...",Note to Stephanie Panus....\n\nStephanie...ple...
3,1,Master Termination Log,2002-02-05 15:06:25,"Theriot, Kim S. KTHERIO","['Hall', 'D. Todd Thall', 'Sweeney', 'Kevin Ks...",\n\n -----Original Message-----\nFrom: =09Panu...
4,1,Master Termination Log,2002-05-28 07:20:35,"Kelly, Katherine L. KKELLY","['Germany', 'Chris Cgerman']",\n\n -----Original Message-----\nFrom: =09McMi...
...,...,...,...,...,...,...
21679,4166,vacation,2000-10-04 11:32:00,Sara Shackleton,"['Gary Hickerson', 'Sheila Glover', 'Laurel Ad...",I will be on vacation from October 6- 13. Als...
21680,4167,web file,2001-03-18 22:57:00,Matt Smith,['Amanda Huble'],"Amanda,\n\nCan you put this file in the approp..."
21681,4167,web file,2001-03-19 04:42:00,Matt Smith,['Amanda Huble'],"Amanda,\n\nPlease move the file i sent you fro..."
21682,4167,web file,2001-03-19 09:57:00,Matt Smith,['Amanda Huble <Amanda Huble/NA/Enron@Enron'],"Amanda,\n\nCan you put this file in the approp..."


In [11]:
cleaned_df = clean_email(cleaned_threads)
cleaned_df

Unnamed: 0,thread_id,subject,timestamp,from,to,body
0,1,Master Termination Log,2002-01-29 11:23:42,"Gossett, Jeffrey C. JGOSSET","['Giron', 'Darron C. Dgiron', 'Love', 'Phillip...",
1,1,Master Termination Log,2002-01-31 12:50:00,"Theriot, Kim S. KTHERIO","['Murphy', 'Melissa Mmurphy', 'Gossett', 'Jeff...",
2,1,Master Termination Log,2002-02-05 15:03:35,"Theriot, Kim S. KTHERIO","['Murphy', 'Melissa Mmurphy', 'Anderson', 'Dia...",Note to Stephanie Panus.... Stephanie...please...
3,1,Master Termination Log,2002-02-05 15:06:25,"Theriot, Kim S. KTHERIO","['Hall', 'D. Todd Thall', 'Sweeney', 'Kevin Ks...",
4,1,Master Termination Log,2002-05-28 07:20:35,"Kelly, Katherine L. KKELLY","['Germany', 'Chris Cgerman']",
...,...,...,...,...,...,...
21679,4166,vacation,2000-10-04 11:32:00,Sara Shackleton,"['Gary Hickerson', 'Sheila Glover', 'Laurel Ad...",I will be on vacation from October 6- 13. Also...
21680,4167,web file,2001-03-18 22:57:00,Matt Smith,['Amanda Huble'],"Amanda, Can you put this file in the appropria..."
21681,4167,web file,2001-03-19 04:42:00,Matt Smith,['Amanda Huble'],"Amanda, Please move the file i sent you from t..."
21682,4167,web file,2001-03-19 09:57:00,Matt Smith,['Amanda Huble <Amanda Huble/NA/Enron@Enron'],"Amanda, Can you put this file in the appropria..."


In [12]:
thread_cleaned_grouped=email_thread_details_dataset.groupby("thread_id")
thread_cleaned_grouped

<pandas.core.groupby.generic.DataFrameGroupBy object at 0x7d6a1b52d160>

In [13]:
#combining emails per thread
def thread_to_text(thread_df):
    parts = []
    for _, row in thread_df.iterrows():
        part = f"From: {row['from']} To: {''.join(row['to'])} Time: {row['timestamp']} Body: {row['body']} "
        parts.append(part)
    return " ".join(parts)

thread_texts = thread_cleaned_grouped.apply(thread_to_text)

  thread_texts = thread_cleaned_grouped.apply(thread_to_text)


In [14]:
thread_texts

Unnamed: 0_level_0,0
thread_id,Unnamed: 1_level_1
1,"From: Gossett, Jeffrey C. JGOSSET To: ['Giron'..."
2,From: Tana Jones To: ['Suzanne Adams'] Time: 2...
3,"From: Benjamin Rogers To: ['""CHOBY', 'C."" <G7P..."
4,From: Phillip M Love To: ['Julie Ferrara'] Tim...
5,From: Kay Mann To: ['Reagan Rorschach'] Time: ...
...,...
4163,"From: Kay Mann To: ['Sheila Tweed', 'Dale Rasm..."
4164,From: Elizabeth Sager To: ['Genia FitzGerald']...
4165,"From: Watson, Kimberly KWATSON To: [""'john.wat..."
4166,"From: Susan Scott To: ['Drew Fossum@ENRON', 'J..."


In [15]:
def merge_thread_summary(thread_texts,df):
    thread_df = thread_texts.to_frame('thread_text')

    #reseting index if needed (if Series has an index you want to keep)
    thread_df = thread_texts.reset_index()

    thread_df['thread_text'] = thread_df[0]
    thread_df = thread_df.drop(columns=[0])

    #joining using merge
    result = pd.merge(df, thread_df,  how='left' , on = 'thread_id')

    return result

In [16]:
merged_df = merge_thread_summary(thread_texts,email_thread_summaries_dataset)
merged_df

Unnamed: 0,thread_id,summary,thread_text
0,1,The email thread discusses the Master Terminat...,"From: Gossett, Jeffrey C. JGOSSET To: ['Giron'..."
1,2,A lunch meeting has been scheduled for May 5th...,From: Tana Jones To: ['Suzanne Adams'] Time: 2...
2,3,Ben is updating a friend on his progress with ...,"From: Benjamin Rogers To: ['""CHOBY', 'C."" <G7P..."
3,4,The recipient of the email thread initially ex...,From: Phillip M Love To: ['Julie Ferrara'] Tim...
4,5,The email thread discusses the long form confi...,From: Kay Mann To: ['Reagan Rorschach'] Time: ...
...,...,...,...
4162,4163,Peter Thompson has sent a memo to Kay Mann and...,"From: Kay Mann To: ['Sheila Tweed', 'Dale Rasm..."
4163,4164,The email thread revolves around the sharing a...,From: Elizabeth Sager To: ['Genia FitzGerald']...
4164,4165,Susan asks Emily about her plans for the weeke...,"From: Watson, Kimberly KWATSON To: [""'john.wat..."
4165,4166,Several employees will be on vacation during d...,"From: Susan Scott To: ['Drew Fossum@ENRON', 'J..."


Model:


In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Sample data
#texts = ["Dear team, please review the attached report on vulnerabilities. We need feedback by Friday."]
#summaries = ["Review report on vulnerabilities by Friday."]
texts = merged_df['thread_text']
summaries = merged_df['summary']

# Tokenizer for input
src_tokenizer = Tokenizer(filters='', oov_token='<unk>')
src_tokenizer.fit_on_texts(texts)
src_sequences = src_tokenizer.texts_to_sequences(texts)
src_sequences = pad_sequences(src_sequences, padding='post')

# Tokenizer for target (summary)
trg_tokenizer = Tokenizer(filters='', oov_token='<unk>')
trg_tokenizer.fit_on_texts(summaries)
trg_sequences = trg_tokenizer.texts_to_sequences(summaries)
trg_sequences = pad_sequences(trg_sequences, padding='post')

# Vocabulary sizes
src_vocab_size = len(src_tokenizer.word_index) + 1
trg_vocab_size = len(trg_tokenizer.word_index) + 1


In [None]:
from tensorflow.keras import layers

embedding_dim = 128
hidden_units = 256

# Encoder
encoder_inputs = tf.keras.Input(shape=(None,))
enc_emb = layers.Embedding(src_vocab_size, embedding_dim)(encoder_inputs)
encoder_lstm = layers.LSTM(hidden_units, return_state=True)
_, state_h, state_c = encoder_lstm(enc_emb)
encoder_states = [state_h, state_c]


In [None]:
decoder_inputs = tf.keras.Input(shape=(None,))
dec_emb = layers.Embedding(trg_vocab_size, embedding_dim)(decoder_inputs)
decoder_lstm = layers.LSTM(hidden_units, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(dec_emb, initial_state=encoder_states)
decoder_dense = layers.Dense(trg_vocab_size, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)


In [None]:
model = tf.keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()


In [None]:
import numpy as np

trg_input = trg_sequences[:, :-1]   # all tokens except last
trg_output = trg_sequences[:, 1:]   # all tokens except first
trg_output = np.expand_dims(trg_output, -1)  # required for sparse_categorical_crossentropy


In [None]:
model.fit(
    [src_sequences, trg_input],
    trg_output,
    batch_size=16,
    epochs=50
)


Epoch 1/50
[1m261/261[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m477s[0m 2s/step - accuracy: 0.5392 - loss: 4.9267
Epoch 2/50
[1m 89/261[0m [32m━━━━━━[0m[37m━━━━━━━━━━━━━━[0m [1m5:12[0m 2s/step - accuracy: 0.5795 - loss: 3.1791

New Model

In [19]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.6


In [21]:
!pip install rouge_score

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=807c3b848d2c60309268399b8de9f1afdd4e2c7377aa67bcec91b785531e6fb1
  Stored in directory: /root/.cache/pip/wheels/85/9d/af/01feefbe7d55ef5468796f0c68225b6788e85d9d0a281e7a70
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [20]:
!pip -q install evaluate rouge_score

import os
os.environ["WANDB_DISABLED"] = "true"
import gc
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from transformers import EarlyStoppingCallback
from datasets import Dataset
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
)

import evaluate

#preparING dataframe
df = merged_df.copy()

TEXT_COL = "thread_text"
SUMMARY_COL = "summary"

df = df.dropna(subset=[TEXT_COL, SUMMARY_COL]).copy()
df[TEXT_COL] = df[TEXT_COL].astype(str)
df[SUMMARY_COL] = df[SUMMARY_COL].astype(str)

#splitting by thread_id (no leakage)
thread_ids = df["thread_id"].unique()
train_ids, temp_ids = train_test_split(thread_ids, test_size=0.2, random_state=42)
val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=42)

train_df = df[df["thread_id"].isin(train_ids)].sample(800, random_state=42)
val_df   = df[df["thread_id"].isin(val_ids)].sample(150, random_state=42)
test_df  = df[df["thread_id"].isin(test_ids)].sample(150, random_state=42)

train_ds = Dataset.from_pandas(train_df[[TEXT_COL, SUMMARY_COL]])
val_ds   = Dataset.from_pandas(val_df[[TEXT_COL, SUMMARY_COL]])
test_ds  = Dataset.from_pandas(test_df[[TEXT_COL, SUMMARY_COL]])

#model (FAST + GOOD)
model_name = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

max_source_length = 512
max_target_length = 96

def preprocess(batch):
    inputs = tokenizer(
        batch[TEXT_COL],
        truncation=True,
        max_length=max_source_length
    )
    labels = tokenizer(
        batch[SUMMARY_COL],
        truncation=True,
        max_length=max_target_length
    )
    inputs["labels"] = labels["input_ids"]
    return inputs

train_tok = train_ds.map(preprocess, batched=True, remove_columns=train_ds.column_names)
val_tok   = val_ds.map(preprocess, batched=True, remove_columns=val_ds.column_names)
test_tok  = test_ds.map(preprocess, batched=True, remove_columns=test_ds.column_names)

del df, train_df, val_df, test_df
del train_ds, val_ds, test_ds
gc.collect()

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

#metrics
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    preds, labels = eval_pred
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True
    )
    return {k: round(v, 4) for k, v in result.items()}

#training argumentss
args = Seq2SeqTrainingArguments(
    output_dir="email_thread_bart_base",

    learning_rate=2e-5,
    num_train_epochs=5,

    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,

    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,

    predict_with_generate=False,

    logging_steps=20,
    logging_first_step=True,

    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    fp16=False,
    report_to="none"
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=val_tok,
    tokenizer=tokenizer,
    data_collator=data_collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=1)]
)

#baseline Fine-tuned
baseline = trainer.evaluate(test_tok, metric_key_prefix="baseline_test")
print("Baseline:", baseline)

trainer.train()

#enabling generation for evaluation only
trainer.args.predict_with_generate = True

finetuned = trainer.evaluate(test_tok, metric_key_prefix="finetuned_test")
print("Fine-tuned:", finetuned)

#generating examples
sample_df = merged_df.sample(5, random_state=1)

inputs = tokenizer(
    sample_df[TEXT_COL].tolist(),
    return_tensors="pt",
    truncation=True,
    max_length=max_source_length,
    padding=True
).to(trainer.model.device)

summary_ids = trainer.model.generate(
    **inputs,
    max_new_tokens=max_target_length
)

print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/150 [00:00<?, ? examples/s]

Map:   0%|          | 0/150 [00:00<?, ? examples/s]

  trainer = Seq2SeqTrainer(


early stopping required metric_for_best_model, but did not find eval_loss so early stopping is disabled


Baseline: {'baseline_test_loss': 4.256063461303711, 'baseline_test_model_preparation_time': 0.0052, 'baseline_test_runtime': 4.6442, 'baseline_test_samples_per_second': 32.298, 'baseline_test_steps_per_second': 16.149}


Epoch,Training Loss,Validation Loss,Model Preparation Time
1,2.7753,2.321552,0.0052
2,2.4276,2.261322,0.0052
3,2.5237,2.216999,0.0052
4,2.1105,2.231141,0.0052


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].


early stopping required metric_for_best_model, but did not find eval_loss so early stopping is disabled


Fine-tuned: {'finetuned_test_loss': 2.3480231761932373, 'finetuned_test_model_preparation_time': 0.0052, 'finetuned_test_runtime': 3.6118, 'finetuned_test_samples_per_second': 41.531, 'finetuned_test_steps_per_second': 20.765, 'epoch': 4.0}
['The email thread discusses the Arthur Andersen model validation request. Gillian Boyer asks Vince Kaminski to provide the inputs for a particular deal and have him recalculate the deal value. Vince responds that the request is to validate that the Enron Global Market bookadministrators are accurately using the "spread option model" as developed by the Research Group. Vince also mentions that two Koch deals have been chosen due to their substantial P/L effect.', 'Chris Germany has developed a super scientific method for pricing VNG space for Ogy at Doyle on Transco. Poke all the holes in it and determine the price to Ogy. The higher of the two offers is $.35TCO Offer plus $.14TCO offer.', 'Kimberly Kupiecki sent an email to Jeff Dasovich regarding 

In [21]:
trainer.args.predict_with_generate = True
trainer.evaluate(test_tok)

{'eval_loss': 2.3480231761932373,
 'eval_model_preparation_time': 0.0052,
 'eval_runtime': 4.1241,
 'eval_samples_per_second': 36.372,
 'eval_steps_per_second': 18.186,
 'epoch': 4.0}

In [22]:
print(model)

BartForConditionalGeneration(
  (model): BartModel(
    (shared): BartScaledWordEmbedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_n

In [23]:
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

Total parameters: 139,420,416
Trainable parameters: 139,420,416


In [24]:
print(model.config.to_json_string())

{
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartForConditionalGeneration"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "dtype": "float32",
  "early_stopping": null,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "model_type": "bart",
  "no_repeat_ngram_size": 