In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from sentence_transformers import SentenceTransformer




class MailDataset(Dataset):
    """
    Torch dataset over a list of mail records.

    Every item is a ``(subject, category_id, action_id)`` triple in which the
    two ids are scalar tensors looked up from ``class_labels``.
    """

    def __init__(self, data, class_labels):
        """
        Parameters:
        -----------
        data (List): list of the mail objects; each one is indexed with the
            keys 'subject', 'category' and 'action'.
        class_labels (dict): integer labels for the different classes, stored
            under the keys 'category' and 'action'.
        """
        self.mails = data
        self.class_labels = class_labels

    def __len__(self):
        """Number of mail records held by the dataset."""
        return len(self.mails)

    def __getitem__(self, index):
        """Return the raw subject plus the encoded category and action labels."""
        record = self.mails[index]
        subject = record['subject']
        category_id = self.class_labels['category'][record['category']]
        action_id = self.class_labels['action'][record['action']]
        return subject, torch.tensor(category_id), torch.tensor(action_id)
    
    
def mail_collate_fn(batch):
    """
    Collate MailDataset items into one batch.

    Parameters:
    -----------
    batch (List): tuples of (subject, category_tensor, action_tensor) as
        produced by the dataset's __getitem__ method.

    Returns:
    --------
    - subjects: list with the first element of every item, in batch order
    - categories: torch.Tensor built by stacking the category labels
    - actions: torch.Tensor built by stacking the action labels
    """
    subjects, category_items, action_items = [], [], []
    for item in batch:
        subjects.append(item[0])
        category_items.append(item[1])
        action_items.append(item[2])

    categories = torch.stack(category_items)
    actions = torch.stack(action_items)
    return subjects, categories, actions


class MailClassifier(nn.Module):
    """
    Two-headed classifier on top of frozen sentence embeddings.

    Mail subjects are embedded with the "all-MiniLM-L6-v2" SentenceTransformer,
    passed through a shared three-layer trunk (384 -> 192 -> 96 -> 48 with ReLU)
    and then through two separate heads producing category and action logits.
    """

    def __init__(self, cache_folder=None, category_classes=7, action_classes=3):
        """ 
        Pytorch Model Class for the Mail Classification
        
        Parameters:
        -----------
        cache_folder (str|pathlike): path to the sentence transformer model weights cache_folder
        category_classes (int): number of classes in category
        action_classes (int): number of classes in action
        """
        super(MailClassifier, self).__init__()
        self.st_model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=cache_folder)
        
        # Freeze the SentenceTransformer layers
        for param in self.st_model.parameters():
            param.requires_grad = False
            
        # Shared trunk; attribute names are part of the state_dict layout.
        self.common_layer_1 = nn.Linear(384, 192)
        self.common_layer_2 = nn.Linear(192, 96)
        self.common_layer_3 = nn.Linear(96, 48)
        # NOTE(review): both heads stack Linear layers without activations, so
        # each head is a composition of affine maps. Left as-is because adding
        # modules would shift the Sequential indices stored in checkpoints.
        self.category_classifier = nn.Sequential(
            nn.Linear(48, 24),
            nn.Linear(24, 12),
            nn.Linear(12, category_classes)
        )
        self.action_classifier = nn.Sequential(
            nn.Linear(48, 24),
            nn.Linear(24, 12),
            nn.Linear(12, action_classes)
        )
    
    
    def forward(self, mail_subjects):
        """
        Parameters:
        -----------
        mail_subjects (list[str]): List of mail subjects.
        
        Returns:
        --------
        category_logits (torch.Tensor): Logits for category classification.
        action_logits (torch.Tensor): Logits for action classification.
        """
        embeddings = self.st_model.encode(mail_subjects, convert_to_tensor=True)

        x = self.common_layer_1(embeddings)
        x = torch.relu(x)
        x = self.common_layer_2(x)
        x = torch.relu(x)
        x = self.common_layer_3(x)
        x = torch.relu(x)
        
        category_logits = self.category_classifier(x)
        action_logits = self.action_classifier(x)
        
        return category_logits, action_logits

    @staticmethod
    def _logits_to_classes(logits):
        """
        Convert logits to class indices by arg-max over the last (class) axis.

        Reducing over the last axis gives one index per row for batched logits
        and a single int for 1-D logits. Reducing over axis 0 would collapse
        the batch axis and return one value per class for batched input.
        """
        return torch.argmax(logits, dim=-1).tolist()
    
    def predict_category(self, mail_subjects):
        """
        Predict the category class for given mail subjects.

        Parameters:
        -----------
        mail_subjects (list[str] | str): List of mail subjects, or one subject.
        
        Returns:
        --------
        category_predictions (list[int] | int): Predicted category class index
            per subject; a single int when the logits are 1-D.
        """
        self.eval()
        with torch.no_grad():
            category_logits, _ = self.forward(mail_subjects)
            category_predictions = self._logits_to_classes(category_logits)
        return category_predictions


    def predict_action(self, mail_subjects):
        """
        Predict the action class for given mail subjects.

        Parameters:
        -----------
        mail_subjects (list[str] | str): List of mail subjects, or one subject.
        
        Returns:
        --------
        action_predictions (list[int] | int): Predicted action class index per
            subject; a single int when the logits are 1-D.
        """
        self.eval()
        with torch.no_grad():
            _, action_logits = self.forward(mail_subjects)
            action_predictions = self._logits_to_classes(action_logits)
        return action_predictions


In [2]:
# Load the bare "all-MiniLM-L6-v2" SentenceTransformer (weights cached under
# ../mount/models) for the embedding experiments in the following cells.
model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder="../mount/models")

In [3]:
# Embed one sample sentence as a tensor; `l` is the reference embedding that
# the next cell perturbs.
l = model.encode('This is a test sentence', convert_to_tensor=True)

In [None]:
import torch.nn.functional as F

# Add Gaussian noise to the reference embedding `l`, scaled by 0.05 times its
# norm, then report how far the noisy copy drifted.
noise = torch.randn_like(l) * 0.05 * torch.norm(l)
l1 = noise + l

distance = torch.norm(l - l1)
print(distance)
F.cosine_similarity(l, l1, dim=0)

In [None]:
# Count every scalar weight held by the sentence-transformer model.
total_params = sum(map(torch.numel, model.parameters()))
print(f"Total parameters: {total_params}")

In [53]:
# Instantiate the classifier with 7 category classes and 3 action classes,
# matching the sizes of the class_labels mappings used in this notebook.
mail_classifier = MailClassifier(cache_folder="../mount/models", category_classes=7, action_classes=3)

In [60]:
import pandas as pd
# Read the training CSV and turn every row into a plain dict record.
df = pd.read_csv("../mount/train_data.csv")
data = df.to_dict(orient="records")

# Integer ids for each label, numbered in listing order.
class_labels = {
    "category": {
        label: idx
        for idx, label in enumerate(
            ["Education", "Newsletters", "Personal", "Promotions", "Social", "Work", "Unknown"]
        )
    },
    "action": {
        label: idx
        for idx, label in enumerate(["READ", "IGNORE", "ACT"])
    },
}

dataset = MailDataset(data, class_labels=class_labels)
train_dataloader = DataLoader(dataset, batch_size=5, shuffle=True)

In [None]:
# tqdm is only imported by a later cell; import it here so this cell also runs
# on a fresh kernel executed top to bottom.
from tqdm import tqdm

# Smoke test: iterate the whole dataloader once and show the subjects of the
# last batch that was produced.
epoch = 0
NUM_EPOCHS = 10
for b in tqdm(train_dataloader, desc=f"Epoch {epoch}/{NUM_EPOCHS} - Training"):
    pass
print(b[0])

In [68]:
# Training Function
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm

def train(training_data_file, 
          class_labels,
          mail_classifier,
          BATCH_SIZE=16,
          NUM_EPOCHS=100,
          LEARNING_RATE=1e-4,
          device=torch.device('cuda:0')):
    """
    Train function for the MailClassifier model.

    The CSV is split 90/10 into train and validation records. Each epoch runs
    one optimisation pass over the training split and one gradient-free pass
    over the validation split, printing the average of the summed
    (category + action) cross-entropy loss for both.

    Parameters:
    -----------
    training_data_file (str): Path to the CSV file containing training data.
    class_labels (dict): Mapping of category and action labels to integers.
    mail_classifier (nn.Module): The PyTorch model to train.
    BATCH_SIZE (int): Batch size for DataLoader.
    NUM_EPOCHS (int): Number of epochs to train.
    LEARNING_RATE (float): the learning rate for the optimization
    device (torch.device): Device for training (CPU or GPU).
    """
    df = pd.read_csv(training_data_file)
    data = df.to_dict(orient="records")
    train_data, validation_data = train_test_split(data, test_size=0.1)
    train_dataset = MailDataset(train_data, class_labels=class_labels)
    val_dataset = MailDataset(validation_data, class_labels=class_labels)
    
    train_dataloader = DataLoader(train_dataset, 
                                  batch_size=BATCH_SIZE, 
                                  shuffle=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=True)
    
    mail_classifier.to(device)
    
    # loss function (shared by both heads)
    loss_function = nn.CrossEntropyLoss()
    
    # optimizer
    optimizer = torch.optim.Adam(mail_classifier.parameters(), lr=LEARNING_RATE)
    
    for epoch in range(1, NUM_EPOCHS+1):
        mail_classifier.train()  # Set model to training mode
        train_loss = 0.0
        
        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch}/{NUM_EPOCHS} - Training"):
            X, Y_category, Y_action = batch
            # Tensor.to() returns a new tensor rather than moving in place, so
            # the result must be rebound or the labels stay on the CPU.
            Y_category = Y_category.to(device)
            Y_action = Y_action.to(device)
            category_logits, action_logits = mail_classifier(list(X))
            category_loss = loss_function(category_logits, Y_category)
            action_loss = loss_function(action_logits, Y_action)
            loss = category_loss + action_loss
           
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        avg_train_loss = train_loss / len(train_dataloader)
        print(f"Epoch {epoch}: Avg Train Loss = {avg_train_loss:.4f}")
    
        # Validation step
        mail_classifier.eval()  # Set model to evaluation mode
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc=f"Epoch {epoch}/{NUM_EPOCHS} - Validation"):
                X, Y_category, Y_action = batch
                X = list(X)  
                Y_category = Y_category.to(device)
                Y_action = Y_action.to(device)
                category_logits, action_logits = mail_classifier(X)
                category_loss = loss_function(category_logits, Y_category)
                action_loss = loss_function(action_logits, Y_action)
                loss = category_loss + action_loss
                val_loss += loss.item()
        
        avg_val_loss = val_loss / len(val_dataloader)
        print(f"Epoch {epoch}: Avg Val Loss = {avg_val_loss:.4f}")


In [None]:
# Inputs for the training run: the CSV path and the label -> integer id maps.
training_data_file = "../mount/train_data.csv"
class_labels = dict(
    category=dict(
        Education=0,
        Newsletters=1,
        Personal=2,
        Promotions=3,
        Social=4,
        Work=5,
        Unknown=6,
    ),
    action=dict(READ=0, IGNORE=1, ACT=2),
)

# Launch training on the first CUDA device.
train(
    training_data_file,
    class_labels,
    mail_classifier,
    BATCH_SIZE=16,
    NUM_EPOCHS=100,
    LEARNING_RATE=1e-4,
    device=torch.device('cuda:0'),
)

In [11]:
import pandas as pd 

# Load the test CSV into a DataFrame.
test_df = pd.read_csv('../mount/test_data.csv')

In [12]:
# One dict per test row.
records = test_df.to_dict(orient='records')

# Label -> integer id maps, numbered in listing order.
class_labels = {
    "category": dict(
        zip(
            ("Education", "Newsletters", "Personal", "Promotions", "Social", "Work", "Unknown"),
            range(7),
        )
    ),
    "action": dict(zip(("READ", "IGNORE", "ACT"), range(3))),
}

In [None]:
# Peek at the first test record.
records[0]

In [8]:
from tqdm import tqdm 

def check_test(data, model, class_labels):
    """
    Score a model on labelled mail records, one record at a time.

    Parameters:
    -----------
    data (list[dict]): records indexed with 'subject', 'category' and 'action'.
    model: object exposing predict_category(subject) and predict_action(subject);
        each prediction is compared for equality with the integer label.
    class_labels (dict): label -> integer id maps under 'category' and 'action'.

    Returns:
    --------
    category_accuracy (float): fraction of records with the correct category.
    action_accuracy (float): fraction of records with the correct action.
    mis_classified_category (list[dict]): one entry per wrong category
        prediction (record_id, correct_category, predicted_category, subject).
    mis_classified_action (list[dict]): one entry per wrong action prediction
        (record_id, correct_action, predicted_action, subject).

    For empty `data` both accuracies are reported as 0.0 and both lists are
    empty, instead of dividing by zero.
    """
    total = len(data)
    if total == 0:
        return 0.0, 0.0, [], []

    correct_category = 0
    correct_action = 0
    mis_classified_category = []
    mis_classified_action = []
    for i, record in enumerate(tqdm(data)):
        subject = record['subject']
        category_tag = class_labels['category'][record['category']]
        action_tag = class_labels['action'][record['action']]
        pred_category = model.predict_category(subject)
        pred_action = model.predict_action(subject)
        if category_tag == pred_category:
            correct_category += 1
        else:
            mis_classified_category.append({'record_id':i, 'correct_category': category_tag, 'predicted_category': pred_category, 'subject': subject})
        if action_tag == pred_action:
            correct_action += 1
        else:
            mis_classified_action.append({'record_id':i, 'correct_action': action_tag, 'predicted_action': pred_action, 'subject': subject})
    
    return correct_category / total, correct_action / total, mis_classified_category, mis_classified_action

In [6]:
import torch 

mail_classifier = MailClassifier(cache_folder="../mount/models", category_classes=7, action_classes=3)

# Deserialize onto the CPU first so the checkpoint loads no matter which device
# it was saved from; the model is moved to the inference device afterwards.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint_state = torch.load("../mount/model_checkpoints/epoch-027_val-loss-1.2099.pt",
                              map_location="cpu")
mail_classifier.load_state_dict(checkpoint_state)

# Prefer the first GPU, but fall back to the CPU when CUDA is unavailable.
inference_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
mail_classifier.to(inference_device)
mail_classifier.predict_action("50% off on new podcasts!!")

1

In [13]:
# Evaluate the loaded classifier on every test record; check_test returns the
# two accuracy fractions followed by the category and action misclassification lists.
category_score, action_score, mis_classified_category, mis_classified_action = check_test(records, mail_classifier, class_labels)

  0%|          | 0/640 [00:00<?, ?it/s]

100%|██████████| 640/640 [00:25<00:00, 25.56it/s]


In [14]:
# Report the category and action accuracy fractions side by side.
print(f"category_score: {category_score} || action_score: {action_score}")

category_score: 0.7515625 || action_score: 0.85625


In [None]:
# Number of misclassified records per task: (action, category).
len(mis_classified_action), len(mis_classified_category)

In [None]:
# Integer id -> action name, used to make the misclassification list readable.
action_dict = dict(enumerate(["READ", "IGNORE", "ACT"]))

mis_classified_action_new = [
    {
        'subject': entry['subject'],
        'correct_action': action_dict[entry['correct_action']],
        'predicted_action': action_dict[entry['predicted_action']],
    }
    for entry in mis_classified_action
]
mis_classified_action_new

In [None]:
# Integer id -> category name, used to make the misclassification list readable.
category_dict = dict(
    enumerate(
        ["Education", "Newsletters", "Personal", "Promotions", "Social", "Work", "Unknown"]
    )
)

mis_classified_category_new = [
    {
        'subject': entry['subject'],
        'correct_category': category_dict[entry['correct_category']],
        'predicted_category': category_dict[entry['predicted_category']],
    }
    for entry in mis_classified_category
]
mis_classified_category_new