
## **Cloning From Repository**



Clone the repository that contains the Enron1, Enron2, and SMS datasets along with the spam-classification code used in this project.

In [None]:
!git clone --branch main https://github.com/paulinaeb/IDaSec-project.git

Cloning into 'IDaSec-project'...
remote: Enumerating objects: 433, done.
remote: Counting objects: 100% (267/267), done.
remote: Compressing objects: 100% (211/211), done.
remote: Total 433 (delta 137), reused 162 (delta 48), pack-reused 166 (from 1)
Receiving objects: 100% (433/433), 41.69 MiB | 9.84 MiB/s, done.
Resolving deltas: 100% (197/197), done.


# **Importing Required Libraries**
Various Python libraries and modules are imported for data processing, model building, training, and evaluation.

**Libraries:**

**torch:** For building and training neural networks.

**pandas:** Used for data manipulation.

**matplotlib:** For plotting graphs.

**transformers:** Hugging Face library to use DistilBERT for sequence classification.

**sklearn:** For model evaluation and data preprocessing.

**google.colab:** For mounting Google Drive if necessary for saving/loading models.

**Device Selection:** The device is set to CUDA if a GPU is available, otherwise falls back to CPU.

In [None]:
import torch
import os
import re
import pandas as pd
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader, TensorDataset
from transformers import DistilBertTokenizer
from sklearn.metrics import accuracy_score
from transformers import DistilBertForSequenceClassification
from sklearn.model_selection import train_test_split
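# Note: `from datasets import Dataset` (Hugging Face) is omitted because it would shadow
# torch.utils.data.Dataset imported above, which SpamDataset subclasses.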
from datetime import datetime
from sklearn.preprocessing import StandardScaler
from google.colab import drive
from torch.nn import CrossEntropyLoss
from transformers import DistilBertModel
from sklearn.metrics import precision_recall_fscore_support, classification_report
from transformers.optimization import get_linear_schedule_with_warmup



# Determine device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


# **Dataset Overview and Inspection**


*   The paths to each dataset split (train, validation, and test) are defined, and each CSV file is loaded into a pandas DataFrame.
*   For each dataset, the block then prints the shape, the column names, and the first three rows to confirm that the data loaded correctly.








In [None]:

# Dataset file paths
dataset_paths = {
    "enron1_train": './IDaSec-project/dataset/enron1/enron1_train.csv',
    "enron1_val": './IDaSec-project/dataset/enron1/enron1_val.csv',
    "enron1_test": './IDaSec-project/dataset/enron1/enron1_test.csv',
    "enron2_train": './IDaSec-project/dataset/enron2/enron2_train.csv',
    "enron2_val": './IDaSec-project/dataset/enron2/enron2_val.csv',
    "enron2_test": './IDaSec-project/dataset/enron2/enron2_test.csv',
    "sms_train": './IDaSec-project/dataset/sms/train.csv',
    "sms_val": './IDaSec-project/dataset/sms/val.csv',
    "sms_test": './IDaSec-project/dataset/sms/test.csv',
}

# Load datasets into a dictionary
datasets = {name: pd.read_csv(path) for name, path in dataset_paths.items()}

# Display summary info
for name, df in datasets.items():
    print(f"\n=== {name.upper()} ===")
    print(f"Shape: {df.shape}")
    print("Columns:", list(df.columns))
    print("Sample rows:")
    print(df.head(3))



=== ENRON1_TRAIN ===
Shape: (3196, 2)
Columns: ['email', 'target']
Sample rows:
                                               email target
0  Subject: prom dress shopping hi , just wanted ...    ham
1  Subject: hi agaain hello , welcome to pharm la...   spam
2  Subject: feedback monitor error - meter 984132...    ham

=== ENRON1_VAL ===
Shape: (799, 2)
Columns: ['email', 'target']
Sample rows:
                                               email target
0  Subject: union carbide - seadrift hpl meter # ...    ham
1  Subject: best choice rx - free online prescrip...   spam
2  Subject: new nat gas delivery location pursuan...    ham

=== ENRON1_TEST ===
Shape: (999, 2)
Columns: ['email', 'target']
Sample rows:
                                               email target
0  Subject: unify / sitara enhancements i am comp...    ham
1  Subject: weekend activity dated : june 2 thru ...    ham
2  Subject: re : tittletattle secrets dn ' t ie n...   spam

=== ENRON2_TRAIN ===
Shape: (3727, 2)
Col

# **Function to Load and Clean Data**

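The **load_and_clean** function loads a CSV file, normalizes the label strings (stripping whitespace and lowercasing), encodes them as integers (ham = 0, spam = 1), and returns a DataFrame with only the email and target columns. For each dataset, the preview loop then prints: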

**Shape:** Displays the number of rows and columns.

**Label Distribution:** Shows the count of ham and spam labels.

**Sample Rows:** Displays the first three rows to inspect the data.

In [None]:
# Function to load and convert labels
def load_and_clean(path):
    df = pd.read_csv(path)
    df['target'] = df['target'].str.strip().str.lower().map({'ham': 0, 'spam': 1})
    return df[['email', 'target']]

# Load and process all datasets
datasets = {name: load_and_clean(path) for name, path in dataset_paths.items()}

# Preview cleaned data
for name, df in datasets.items():
    print(f"\n=== {name.upper()} ===")
    print(f"Shape: {df.shape}")
    print(df['target'].value_counts())
    print(df.head(3))


=== ENRON1_TRAIN ===
Shape: (3196, 2)
target
0    2260
1     936
Name: count, dtype: int64
                                               email  target
0  Subject: prom dress shopping hi , just wanted ...       0
1  Subject: hi agaain hello , welcome to pharm la...       1
2  Subject: feedback monitor error - meter 984132...       0

=== ENRON1_VAL ===
Shape: (799, 2)
target
0    565
1    234
Name: count, dtype: int64
                                               email  target
0  Subject: union carbide - seadrift hpl meter # ...       0
1  Subject: best choice rx - free online prescrip...       1
2  Subject: new nat gas delivery location pursuan...       0

=== ENRON1_TEST ===
Shape: (999, 2)
target
0    706
1    293
Name: count, dtype: int64
                                               email  target
0  Subject: unify / sitara enhancements i am comp...       0
1  Subject: weekend activity dated : june 2 thru ...       0
2  Subject: re : tittletattle secrets dn ' t ie n...       1




# **Data Preprocessing and Merging Datasets**

The **preprocess_df** function cleans the text, encodes the labels, extracts two metadata features, and normalizes them. The preprocessed datasets are then combined into a single **training**, **validation**, and **test** dataset.

**Preprocessing Function (preprocess_df):**

**Text Cleaning:** The email text column is lowercased, and every character that is neither a word character nor whitespace is removed with a regular expression.

**Label Encoding:** The target labels are encoded as binary values (1 for spam, 0 for ham).

**Metadata Extraction:**

*   Extracts the length of the subject line from the email.

*   Calculates the total length of the email text.

**Normalization:** Both metadata columns are standardized with a StandardScaler (zero mean, unit variance); a separate scaler is fitted for each dataset.

**Return:** The function returns the processed DataFrame and the scaler used for normalization.
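
To make the cleaning and subject-length steps concrete, here is a minimal sketch that uses only the standard library. The sample string and the helper names `clean_text` and `subject_length` are made up for illustration; the two regular expressions are the ones used in this section. The sketch shows that the subject pattern depends on the colon in "subject:" and on a newline after the subject line, so the order of cleaning and extraction matters.

```python
import re

# Same pattern as in the preprocessing step: "subject:" up to (not including) the first newline
SUBJECT_PATTERN = re.compile(r'subject:.*?(?=\n)', re.DOTALL)

def clean_text(text):
    """Lowercase and drop every character that is neither a word character nor whitespace."""
    return re.sub(r'[^\w\s]', '', text.lower())

def subject_length(text):
    """Length of the leading 'subject:...' span up to the first newline, or 0 if absent."""
    match = SUBJECT_PATTERN.match(text)
    return len(match.group(0)) if match else 0

# Hypothetical sample message (not taken from the datasets)
raw = "Subject: Meeting at 10!\nSee you there."

lowered = raw.lower()
cleaned = clean_text(raw)

length_before_cleaning = subject_length(lowered)  # colon still present, so the pattern matches
length_after_cleaning = subject_length(cleaned)   # colon was stripped, so there is no match

print(repr(cleaned))
print(length_before_cleaning, length_after_cleaning)
```

A message without a newline after the subject line also yields a length of 0, because the lookahead in the pattern requires one.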




In [None]:

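# Preprocessing function
def preprocess_df(df, text_col='email', label_col='target'):
    df = df.copy()
    # Lowercase first; punctuation is stripped only after the subject line has been
    # measured, because the pattern below relies on the colon in "subject:"
    df[text_col] = df[text_col].str.lower()

    # Encode labels only if they are still strings; load_and_clean may already have
    # mapped them to 0/1, and mapping a second time would turn every label into NaN
    if not pd.api.types.is_numeric_dtype(df[label_col]):
        df[label_col] = df[label_col].map({'spam': 1, 'ham': 0})

    # Extract metadata: length of the "subject:..." prefix up to the first newline (0 if absent)
    subject_pattern = re.compile(r'subject:.*?(?=\n)', re.DOTALL)

    def subject_length(text):
        match = subject_pattern.match(text)
        return len(match.group(0)) if match else 0

    df['subject_length'] = df[text_col].apply(subject_length)

    # Clean text: drop everything that is neither a word character nor whitespace
    df[text_col] = df[text_col].str.replace(r'[^\w\s]', '', regex=True)
    df['text_length'] = df[text_col].str.len()

    # Normalize metadata (a separate scaler is fitted on each dataset passed in)
    scaler = StandardScaler()
    metadata_cols = ['subject_length', 'text_length']
    df[metadata_cols] = scaler.fit_transform(df[metadata_cols])

    return df, scaler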

# Apply preprocessing
preprocessed_datasets = {}
scalers = {}
for name, df in datasets.items():
    preprocessed_datasets[name], scalers[name] = preprocess_df(df)

# Combine datasets
train_df = pd.concat([
    preprocessed_datasets['enron1_train'],
    preprocessed_datasets['enron2_train'],
    preprocessed_datasets['sms_train']
]).reset_index(drop=True)
val_df = pd.concat([
    preprocessed_datasets['enron1_val'],
    preprocessed_datasets['enron2_val'],
    preprocessed_datasets['sms_val']
]).reset_index(drop=True)
test_df = pd.concat([
    preprocessed_datasets['enron1_test'],
    preprocessed_datasets['enron2_test'],
    preprocessed_datasets['sms_test']
]).reset_index(drop=True)
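
StandardScaler standardizes each column as z = (x - mean) / std, using the population standard deviation. The sketch below reproduces that arithmetic with the standard library. As an alternative to fitting one scaler per dataset, which is what the cell above does, it applies the statistics of a training column to a validation column so that both splits share one scale. The numbers and the function names `fit_standardizer` and `standardize` are invented for illustration.

```python
from statistics import fmean, pstdev

def fit_standardizer(values):
    """Return (mean, std) of a column; std falls back to 1.0 for a constant column."""
    mean = fmean(values)
    std = pstdev(values, mu=mean)
    return mean, (std if std > 0 else 1.0)

def standardize(values, mean, std):
    """Apply z = (x - mean) / std to every value."""
    return [(v - mean) / std for v in values]

train_lengths = [100.0, 200.0, 300.0]  # hypothetical text lengths
val_lengths = [150.0, 400.0]

# Fit on the training column only, then reuse the same statistics for validation
mean, std = fit_standardizer(train_lengths)
train_scaled = standardize(train_lengths, mean, std)
val_scaled = standardize(val_lengths, mean, std)

print(train_scaled)
print(val_scaled)
```

With scikit-learn, the equivalent pattern would be to call `fit` on the training metadata and `transform` on the other splits.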

# **Tokenization of Text Data**
The email text is tokenized with the **DistilBERT** tokenizer so that it can be fed to the **DistilBERT** model.

**DistilBERT Tokenizer:** The pre-trained tokenizer (distilbert-base-uncased) is loaded from the Hugging Face model hub. It splits text into WordPiece tokens and maps them to vocabulary IDs.

**Padding and Truncation:** Sequences longer than max_length (128 tokens here) are truncated, and shorter ones are padded to the longest sequence in the call, so each split is returned as a rectangular tensor.
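
The effect of padding and truncation can be sketched without the tokenizer itself. The helper `pad_and_truncate`, the token IDs, and the pad ID of 0 below are assumptions made for this illustration; a real tokenizer also adds special tokens such as [CLS] and [SEP], which the sketch leaves out.

```python
def pad_and_truncate(batch_ids, max_length, pad_id=0):
    """Truncate each sequence to max_length, then pad to the longest remaining sequence."""
    truncated = [ids[:max_length] for ids in batch_ids]
    longest = max(len(ids) for ids in truncated)
    input_ids = [ids + [pad_id] * (longest - len(ids)) for ids in truncated]
    # 1 marks a real token, 0 marks padding that the model should ignore
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in truncated]
    return input_ids, attention_mask

# Hypothetical token IDs for three messages of different lengths
batch = [[7, 8, 9, 10, 11, 12], [7, 8], [7, 8, 9]]
input_ids, attention_mask = pad_and_truncate(batch, max_length=4)

print(input_ids)
print(attention_mask)
```

The attention mask is what lets the model distinguish real tokens from padding once all rows have the same length.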



In [None]:


tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

def tokenize_data(texts, max_length=128):
    return tokenizer(
        texts.tolist(),
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )

# Tokenize combined datasets
train_encodings = tokenize_data(train_df['email'])
val_encodings = tokenize_data(val_df['email'])
test_encodings = tokenize_data(test_df['email'])

# **Custom Dataset Class and DataLoader Preparation**
A custom Dataset class handles the **tokenized** data and **metadata**, and **DataLoader** objects are prepared for efficient batching during model training and evaluation.

**SpamDataset Class:**

A custom subclass of torch.utils.data.Dataset that holds the tokenized data (encodings), the additional metadata (metadata), and the target labels (labels).

**__init__ method:** Initializes the dataset with tokenized data, metadata, and labels.

**__len__ method:** Returns the length of the dataset (the number of samples).

**__getitem__ method:** Retrieves the item at the specified index (idx) as a dictionary. This includes:

*   the tokenizer outputs for that sample (input_ids and attention_mask),
*   metadata: the two metadata features as a float tensor,
*   labels: the target label as a long tensor.

**DataLoaders:** All three loaders use a batch size of 16; only the training loader shuffles its samples.



In [None]:

class SpamDataset(Dataset):
    def __init__(self, encodings, metadata, labels):
        self.encodings = encodings
        self.metadata = metadata
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['metadata'] = torch.tensor(self.metadata[idx], dtype=torch.float)
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

# Prepare metadata
metadata_cols = ['subject_length', 'text_length']
train_metadata = train_df[metadata_cols].values
val_metadata = val_df[metadata_cols].values
test_metadata = test_df[metadata_cols].values

# Create datasets
train_dataset = SpamDataset(train_encodings, train_metadata, train_df['target'].values)
val_dataset = SpamDataset(val_encodings, val_metadata, val_df['target'].values)
test_dataset = SpamDataset(test_encodings, test_metadata, test_df['target'].values)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)
test_loader = DataLoader(test_dataset, batch_size=16)
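
The per-index lookup in `__getitem__` and the grouping into batches can be illustrated with a plain-Python stand-in. `ToySpamDataset` and `batches` are hypothetical names, the values are invented, and lists are used instead of tensors; a real DataLoader additionally collates each batch into stacked tensors, which this sketch skips.

```python
class ToySpamDataset:
    """Plain-Python stand-in for SpamDataset that returns lists instead of tensors."""

    def __init__(self, encodings, metadata, labels):
        self.encodings = encodings
        self.metadata = metadata
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Take row idx from every tokenizer output, then attach metadata and label
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['metadata'] = self.metadata[idx]
        item['labels'] = self.labels[idx]
        return item

def batches(dataset, batch_size):
    """Yield consecutive lists of samples, in the order an unshuffled loader would visit them."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        yield [dataset[i] for i in range(start, stop)]

encodings = {
    'input_ids': [[101, 7, 102], [101, 8, 102], [101, 9, 102]],
    'attention_mask': [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
}
toy = ToySpamDataset(encodings, metadata=[[0.1, -0.3], [1.2, 0.4], [-0.5, 0.0]], labels=[0, 1, 0])

first_item = toy[1]
batch_sizes = [len(batch) for batch in batches(toy, batch_size=2)]
print(first_item)
print(batch_sizes)
```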

# **Defining the Model Class**



**Model Description:**

A custom model class that uses DistilBERT for text processing and incorporates metadata for enhanced prediction. The text and metadata representations are scaled by fixed weights (0.8 and 0.2) and concatenated before classification.

**Model Class (DistilBERTWithMetadata):**

**DistilBertModel:** Loads the pre-trained **DistilBERT** model for text representation; the hidden state of the [CLS] token (768 dimensions) is used as the text embedding.

**Metadata Layer:** A fully connected layer followed by ReLU that projects the **metadata** (with metadata_dim features) to 64 dimensions.

**Dropout:** Dropout with probability 0.3 is applied to both representations.

**Classifier:** A fully connected layer that takes the combined **DistilBERT** and **metadata** features (768 + 64 = 832 dimensions) and outputs logits for 2 classes (ham and spam).
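
The shapes involved in combining the two representations can be checked without downloading DistilBERT. In the sketch below, `FusionHead` is a hypothetical stand-in for the layers that follow the encoder, a random tensor takes the place of the [CLS] hidden state, and dropout is left out for brevity.

```python
import torch
import torch.nn as nn

class FusionHead(nn.Module):
    """Hypothetical stand-in for the part of the model that follows DistilBERT."""

    def __init__(self, text_dim=768, metadata_dim=2, metadata_hidden=64, num_classes=2,
                 text_weight=0.8, metadata_weight=0.2):
        super().__init__()
        self.metadata_fc = nn.Linear(metadata_dim, metadata_hidden)
        self.classifier = nn.Linear(text_dim + metadata_hidden, num_classes)
        self.text_weight = text_weight
        self.metadata_weight = metadata_weight

    def forward(self, text_embedding, metadata):
        metadata_features = torch.relu(self.metadata_fc(metadata))
        # Scale each representation by its fixed weight, then concatenate along the feature axis
        combined = torch.cat(
            (self.text_weight * text_embedding, self.metadata_weight * metadata_features),
            dim=-1,
        )
        return combined, self.classifier(combined)

torch.manual_seed(0)
head = FusionHead()
text_embedding = torch.randn(4, 768)  # random stand-in for the [CLS] state of 4 emails
metadata = torch.randn(4, 2)          # random stand-in for 2 standardized metadata features
combined, logits = head(text_embedding, metadata)

print(tuple(combined.shape), tuple(logits.shape))
```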




In [None]:

# Define the model class with increased dropout
class DistilBERTWithMetadata(nn.Module):
    def __init__(self, metadata_dim, dropout=0.3):  # Increased dropout from 0.1 to 0.3
        super(DistilBERTWithMetadata, self).__init__()
        self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.dropout = nn.Dropout(dropout)
        self.metadata_fc = nn.Linear(metadata_dim, 64)  # Project metadata to 64 dims
        self.classifier = nn.Linear(768 + 64, 2)  # 768 (BERT) + 64 (metadata)
        self.text_weight = 0.8
        self.metadata_weight = 0.2

    def forward(self, input_ids, attention_mask, metadata):
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = bert_output.last_hidden_state[:, 0]  # [CLS] token
        pooled_output = self.dropout(pooled_output)

        metadata_output = torch.relu(self.metadata_fc(metadata))
        metadata_output = self.dropout(metadata_output)

        weighted_text = self.text_weight * pooled_output
        weighted_metadata = self.metadata_weight * metadata_output
        combined = torch.cat((weighted_text, weighted_metadata), dim=-1)

        logits = self.classifier(combined)
        return logits

model = DistilBERTWithMetadata(metadata_dim=len(metadata_cols))

# **Model Initialization, Training, and Evaluation Setup**

Initialize the model with the required configurations, set up the optimizer and scheduler, and define the functions for training and evaluating the model. The training loop includes early stopping and model checkpoint saving based on validation loss.

### **Model Class (DistilBERTWithMetadata):**

A custom neural network that combines **DistilBERT** with additional metadata (subject length, text length) for classification.



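Model, Optimizer, and Loss Function Initialization:

**Model:** The **DistilBERTWithMetadata** class is initialized with the number of metadata features.

**Optimizer:** Uses AdamW with a learning rate of 2e-5 and a weight decay of 0.01 for regularization.

**Scheduler:** A linear schedule without warmup decays the learning rate toward zero over the total number of training steps.

**Loss Function:** Cross-entropy loss over the two classes (ham and spam).


Training and Evaluation Functions:

**train_epoch:** Runs the training loop for a single epoch, including loss computation, backpropagation, and the optimizer and scheduler steps, and returns the average loss and accuracy.

**evaluate:** Computes the average loss and accuracy on a data loader without updating the model.

**Model Saving:** Saves the model whenever the validation loss improves.

**Early Stopping:** Training stops once the validation loss has failed to improve for 2 consecutive epochs (patience = 2).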
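
The learning-rate schedule used in the next cell can be approximated with a few lines of plain Python. This is a sketch of a linear warmup-then-decay rule as I understand it, not the library's source; `linear_schedule_factor` is a hypothetical name and the step count is invented.

```python
def linear_schedule_factor(step, num_training_steps, num_warmup_steps=0):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Linear ramp from 0 to 1 during warmup
        return step / max(1, num_warmup_steps)
    # Linear decay from 1 to 0 over the remaining steps
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))

base_lr = 2e-5
total_steps = 1000  # hypothetical; the notebook uses len(train_loader) * num_epochs
lrs = [base_lr * linear_schedule_factor(s, total_steps) for s in (0, 500, 1000)]
print(lrs)
```

With no warmup steps, the factor starts at 1.0 and falls linearly to 0.0 at the final step.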




In [None]:

# Mount Google Drive
drive.mount('/content/drive', force_remount=True)  # Force remount to avoid "already mounted" issue

# Initialize model, optimizer, and loss function
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DistilBERTWithMetadata(metadata_dim=len(metadata_cols)).to(device)
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)  # Added weight decay
loss_fn = CrossEntropyLoss()

# Learning rate scheduler
num_epochs = 10
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

# Training and evaluation functions
def train_epoch(model, data_loader, optimizer, device, scheduler):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for batch in data_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        metadata = batch['metadata'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask, metadata)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()  # Update learning rate

        total_loss += loss.item()
        _, preds = torch.max(outputs, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return total_loss / len(data_loader), correct / total

def evaluate(model, data_loader, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            metadata = batch['metadata'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask, metadata)
            loss = loss_fn(outputs, labels)

            total_loss += loss.item()
            _, preds = torch.max(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return total_loss / len(data_loader), correct / total

# Training loop with model and loss saving, including early stopping
train_losses, val_losses = [], []
train_accs, val_accs = [], []
best_val_loss = float('inf')
patience = 2
epochs_no_improve = 0
save_path = '/content/drive/My Drive/distilbert_finetuned.pt'
train_losses_path = '/content/drive/My Drive/train_losses.txt'
val_losses_path = '/content/drive/My Drive/val_losses.txt'

for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, device, scheduler)
    val_loss, val_acc = evaluate(model, val_loader, device)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)

    # Print only the four average metrics
    print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

    # Save model if validation loss improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), save_path)
        print(f'Model saved to {save_path}')
    else:
        epochs_no_improve += 1

    # Early stopping check
    if epochs_no_improve >= patience:
        print(f'Early stopping triggered after {epoch+1} epochs')
        break

# Save average training and validation losses to separate text files
with open(train_losses_path, 'w') as f:
    for loss in train_losses:
        f.write(f'{loss:.4f}\n')
print(f'Training losses saved to {train_losses_path}')

with open(val_losses_path, 'w') as f:
    for loss in val_losses:
        f.write(f'{loss:.4f}\n')
print(f'Validation losses saved to {val_losses_path}')



Mounted at /content/drive
Epoch 1: Train Loss: 0.1061, Train Acc: 0.9591, Val Loss: 0.0704, Val Acc: 0.9821
Model saved to /content/drive/My Drive/distilbert_finetuned.pt
Epoch 2: Train Loss: 0.0232, Train Acc: 0.9932, Val Loss: 0.0707, Val Acc: 0.9825
Epoch 3: Train Loss: 0.0079, Train Acc: 0.9980, Val Loss: 0.0679, Val Acc: 0.9813
Model saved to /content/drive/My Drive/distilbert_finetuned.pt
Epoch 4: Train Loss: 0.0050, Train Acc: 0.9985, Val Loss: 0.0620, Val Acc: 0.9863
Model saved to /content/drive/My Drive/distilbert_finetuned.pt
Epoch 5: Train Loss: 0.0025, Train Acc: 0.9993, Val Loss: 0.1528, Val Acc: 0.9722
Epoch 6: Train Loss: 0.0028, Train Acc: 0.9995, Val Loss: 0.0783, Val Acc: 0.9874
Early stopping triggered after 6 epochs
Training losses saved to /content/drive/My Drive/train_losses.txt
Validation losses saved to /content/drive/My Drive/val_losses.txt
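
The early-stopping behavior in this log can be replayed from the validation losses alone. The sketch below applies the same rule as the training loop (strict improvement resets the counter, patience of 2) to the six logged values; `replay_early_stopping` is a hypothetical helper written for this illustration.

```python
def replay_early_stopping(val_losses, patience):
    """Return (epoch of the best checkpoint, epoch at which training stops), both 1-based."""
    best_loss = float('inf')
    best_epoch = None
    epochs_no_improve = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # Strict improvement: this is where the checkpoint would be overwritten
            best_loss, best_epoch, epochs_no_improve = loss, epoch, 0
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= patience:
            return best_epoch, epoch
    return best_epoch, len(val_losses)

# Validation losses copied from the training log above
logged_val_losses = [0.0704, 0.0707, 0.0679, 0.0620, 0.1528, 0.0783]
best_epoch, last_epoch = replay_early_stopping(logged_val_losses, patience=2)
print(best_epoch, last_epoch)
```

The checkpoint left on disk is therefore the one from epoch 4, the last epoch whose validation loss improved, and not the one from the final epoch.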


# **Model Evaluation and Classification Report**

**Purpose:** Evaluate the model's performance on the test dataset and print the loss, accuracy, and classification report to assess the model's prediction quality.

**Model Class (DistilBERTWithMetadata):**

The unchanged model class is used here, which combines DistilBERT for text processing with metadata features for classification.


**Loss Function:**

Cross-entropy loss over the two classes (ham and spam) is used, as in training.



**Evaluation Mode:** The model is set to evaluation mode (**model.eval()**), which disables dropout; gradient computation is switched off separately by **torch.no_grad()**.

**Loss and Accuracy Calculation:**

The model is evaluated on the test dataset **(test_loader)**, and the results (loss, accuracy, classification report) are printed.
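
For reference, the cross-entropy of a single sample is the negative log of the softmax probability assigned to the true class, and a batch loss is the mean over samples. The following standard-library sketch computes it for invented logits; the function names are hypothetical.

```python
import math

def cross_entropy(logits, label):
    """Negative log of the softmax probability of the true class, computed stably."""
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[label]

def mean_cross_entropy(batch_logits, labels):
    """Average the per-sample losses over a batch."""
    losses = [cross_entropy(logits, label) for logits, label in zip(batch_logits, labels)]
    return sum(losses) / len(losses)

# Invented logits in the order [ham, spam]
confident_correct = cross_entropy([2.0, 0.0], label=0)
confident_wrong = cross_entropy([2.0, 0.0], label=1)
undecided = cross_entropy([0.0, 0.0], label=0)
batch_loss = mean_cross_entropy([[2.0, 0.0], [0.0, 0.0]], [0, 0])

print(confident_correct, confident_wrong, undecided, batch_loss)
```

A confident correct prediction gives a loss near zero, an undecided one gives log 2, and a confident wrong one is penalized most heavily.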



In [None]:

# Mount Google Drive
drive.mount('/content/drive', force_remount=True)

# Initialize model and load saved weights
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DistilBERTWithMetadata(metadata_dim=len(metadata_cols)).to(device)
# map_location lets weights saved on a GPU load on a CPU-only runtime as well
model.load_state_dict(torch.load('/content/drive/My Drive/distilbert_finetuned.pt', map_location=device))
model.eval()

# Loss function
loss_fn = CrossEntropyLoss()

# Evaluation function with classification report
def evaluate(model, data_loader, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            metadata = batch['metadata'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask, metadata)
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()

            _, preds = torch.max(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Calculate average loss
    avg_loss = total_loss / len(data_loader)

    # Generate classification report
    report = classification_report(all_labels, all_preds, target_names=['ham', 'spam'])
    accuracy = (sum(p == l for p, l in zip(all_preds, all_labels)) / len(all_labels))

    return avg_loss, accuracy, report

# Evaluate on test set
test_loss, test_acc, test_report = evaluate(model, test_loader, device)
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}')
print('\nDetailed Classification Report:')
print(test_report)

Mounted at /content/drive
Test Loss: 0.0382, Test Accuracy: 0.9912

Detailed Classification Report:
              precision    recall  f1-score   support

         ham       0.99      0.99      0.99      2538
        spam       0.98      0.98      0.98       741

    accuracy                           0.99      3279
   macro avg       0.99      0.99      0.99      3279
weighted avg       0.99      0.99      0.99      3279
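
The per-class columns of the report follow directly from counts of true positives, false positives, and false negatives. Below is a minimal sketch on invented labels, not the notebook's predictions, with `binary_metrics` as a hypothetical helper name.

```python
def binary_metrics(y_true, y_pred, positive=1):
    """Precision, recall, and F1 for one class, treating `positive` as the class of interest."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Invented example: 0 = ham, 1 = spam
y_true = [0, 0, 0, 1, 1, 1, 1, 0]
y_pred = [0, 0, 1, 1, 1, 0, 0, 0]

spam_precision, spam_recall, spam_f1 = binary_metrics(y_true, y_pred, positive=1)
print(spam_precision, spam_recall, spam_f1)
```

Calling the helper with `positive=0` gives the corresponding row for ham; the support column is simply the number of true samples in each class.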



# **Inference on Test Data**


*   This block loads the fine-tuned model weights from Google Drive and the pre-trained tokenizer from the Hugging Face model hub, then makes predictions on a small hand-written batch of email texts and metadata features.

*   The trained model classifies each text as spam or ham, which also shows how text and metadata inputs are prepared for prediction.





In [None]:

# Mount Google Drive and load the tokenizer
drive.mount('/content/drive', force_remount=True)
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Initialize model with the same metadata_dim as in training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
metadata_cols = ['subject_length', 'text_length']  # The two metadata features used in training
model = DistilBERTWithMetadata(metadata_dim=len(metadata_cols)).to(device)
# map_location lets weights saved on a GPU load on a CPU-only runtime as well
model.load_state_dict(torch.load('/content/drive/My Drive/distilbert_finetuned.pt', map_location=device))
model.eval()

# Example test batch with 2 metadata features
test_texts = [
    "This is a legitimate email about your order confirmation.",
    "Win a free prize now! Click here immediately!!!"
]
labels = torch.tensor([0, 1])  # 0 = ham, 1 = spam
# Illustrative placeholder values only; real inputs should be the subject_length and
# text_length of each message, standardized the same way as the training data
metadata = torch.tensor([
    [1.0, 0.5],  # Placeholder metadata for the ham example
    [0.1, 0.9]   # Placeholder metadata for the spam example
]).float()  # Ensure float type for linear layer

# Tokenize text
inputs = tokenizer(test_texts, return_tensors="pt", max_length=512, padding="max_length", truncation=True)
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)
metadata = metadata.to(device)
labels = labels.to(device)

# Get predictions
with torch.no_grad():
    outputs = model(input_ids, attention_mask, metadata)
    _, predicted = torch.max(outputs, dim=1)

# Print results
for i in range(len(test_texts)):
    print(f"Text: {test_texts[i]}")
    print(f"True Label: {labels[i].item()} (0 = ham, 1 = spam)")
    print(f"Predicted Label: {predicted[i].item()} (0 = ham, 1 = spam)")
    print("---")

Mounted at /content/drive
Text: This is a legitimate email about your order confirmation.
True Label: 0 (0 = ham, 1 = spam)
Predicted Label: 0 (0 = ham, 1 = spam)
---
Text: Win a free prize now! Click here immediately!!!
True Label: 1 (0 = ham, 1 = spam)
Predicted Label: 1 (0 = ham, 1 = spam)
---
