In [1]:
import pandas as pd
import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration
from sklearn.model_selection import train_test_split
from transformers import AdamW
from tqdm import tqdm
from torch.optim import AdamW


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
df = pd.read_csv('/Users/abhishekwaghchaure/Desktop/Datasets/email/preprocessed_emails.csv')
df.head()

Unnamed: 0,file,headers,body,subject,from,to,clean_body,processed_body,date,response
0,semperger-c/deleted_items/46.,Message-ID: <30978077.1075841544706.JavaMail.e...,Time is running very short. Is your company p...,!!! OATI Etag 1.7 Minimum Requirements !!!,frank.billington@oatiinc.com,cara.semperger@enron.com,Time is running very short Is your company pre...,"[['Time', 'running', 'short', 'Is', 'company',...",2002-02-01 11:25:00-08:00,"[['Time', 'running', 'short', 'Is', 'company',..."
1,king-j/deleted_items/19.,Message-ID: <15244269.1075840797931.JavaMail.e...,Time is running very short. Is your company p...,!!! OATI Etag 1.7 Minimum Requirements !!!,frank.billington@oatiinc.com,jeff.king@enron.com,Time is running very short Is your company pre...,"[['Time', 'running', 'short', 'Is', 'company',...",2002-02-01 11:28:51-08:00,"[['Time', 'running', 'short', 'Is', 'company',..."
2,platter-p/inbox/43.,Message-ID: <394365.1075841413683.JavaMail.eva...,Time is running very short. Is your company p...,!!! OATI Etag 1.7 Minimum Requirements !!!,frank.billington@oatiinc.com,phillip.platter@enron.com,Time is running very short Is your company pre...,"[['Time', 'running', 'short', 'Is', 'company',...",2002-02-01 11:32:28-08:00,"[['Time', 'running', 'short', 'Is', 'company',..."
3,salisbury-h/inbox/196.,Message-ID: <19201127.1075841505530.JavaMail.e...,Immediately delete and DO NOT OPEN email \n \n...,!!!!!!!!!!!GONE.SCR VIRUS Warning!!!!!!!!!!!11,david.steiner@enron.com,center.dl-portland@enron.com,Immediately delete and DO NOT OPEN email From ...,"[['Immediately', 'delete', 'DO', 'NOT', 'OPEN'...",2001-12-04 11:49:46-08:00,"[['Immediately', 'delete', 'DO', 'NOT', 'OPEN'..."
4,kaminski-v/all_documents/1055.,Message-ID: <8575423.1075856206811.JavaMail.ev...,HENWOOD ANNOUNCES A MAJOR NEW RELEASE AND FUNC...,""" Henwood's Rationalizing Midwest Power Market...",cfarrell@hesinet.com,vkamins@ect.enron.com,HENWOOD ANNOUNCES A MAJOR NEW RELEASE AND FUNC...,"[['HENWOOD', 'ANNOUNCES', 'A', 'MAJOR', 'NEW',...",2001-03-19 03:17:00-08:00,"[['HENWOOD', 'ANNOUNCES', 'A', 'MAJOR', 'NEW',..."


In [3]:
df.isnull().sum()

file                 0
headers              0
body                 0
subject              0
from                 0
to                4756
clean_body          32
processed_body       0
date                 0
response             0
dtype: int64

In [4]:
df = df.dropna(subset=['subject', 'processed_body'])
df.isnull().sum()

file                 0
headers              0
body                 0
subject              0
from                 0
to                4756
clean_body          32
processed_body       0
date                 0
response             0
dtype: int64

In [5]:
class EmailDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len = 512):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]
        input_text = f"Email Body : {row['processed_body']} Subject : {row['subject']}"
        target_text = row['response']

        # Tokenization
        input_encodings = self.tokenizer(
            input_text, max_length = self.max_len, truncation = True, padding = 'max_length', return_tensors = 'pt'
        )
        target_encodings = self.tokenizer(
            target_text, max_length = self.max_len, truncation = True, padding = 'max_length', return_tensors = 'pt'
        )

        return{
            "input_ids": input_encodings["input_ids"].squeeze(),
            "attention_mask": input_encodings["attention_mask"].squeeze(),
            "labels": target_encodings["input_ids"].squeeze(),
        }



In [6]:
tokenizer = T5Tokenizer.from_pretrained('t5-small')

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [7]:
train_df, test_df = train_test_split(df, test_size=0.2, random_state=32)

train_dataset = EmailDataset(train_df, tokenizer)
test_dataset = EmailDataset(test_df, tokenizer)

In [8]:
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8)

In [9]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)

: 

In [None]:
optimizer = AdamW(model.parameters(), lr=5e-5)
epochs = 3

model.train()
for epoch in range(epochs):
    total_loss = 0
    with tqdm(train_loader, unit = 'batch') as tepoch:
        for batch in tepoch:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()

            # Synchronize gradients for MPS
            if device.type == "mps":
                torch.mps.sync()

            optimizer.step()

            total_loss += loss.item()

            tepoch.set_description(f"Epoch {epoch + 1}")
            tepoch.set_postfix(loss=f"{loss.item():.4f}")

        print(f"Epoch {epoch + 1}/{epochs}, Total Loss: {total_loss:.4f}")

  0%|          | 0/33936 [00:00<?, ?batch/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
