In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertTokenizer
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet
import random
import os
import glob

In [2]:
# Build the CSV path with os.path.join so the notebook also runs on
# non-Windows machines (a literal backslash path only resolves on Windows).
df = pd.read_csv(os.path.join('data', 'full_text', 'balanced2_df.csv'))

# Split the dataset into training (80%) and validation (20%) sets, stratified
# on the state column, with a fixed random_state.
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['Closest_State'])

# Work on explicit copies: new columns are assigned below, and assigning into
# a frame derived from `df` can trigger pandas' SettingWithCopyWarning.
train_df = train_df.copy()
val_df = val_df.copy()

# Fit the LabelEncoder on the training states only, then reuse the same
# state -> integer mapping for the validation set.
label_encoder = LabelEncoder()
train_df['State_Label'] = label_encoder.fit_transform(train_df['Closest_State'])
val_df['State_Label'] = label_encoder.transform(val_df['Closest_State'])


In [3]:
# Instantiate the pretrained (uncased) DistilBERT tokenizer
tokenizer = DistilBertTokenizer.from_pretrained(
    pretrained_model_name_or_path='distilbert-base-uncased'
)

# Data augmentation: Synonym replacement
def synonym_replacement(text):
    """Return ``text`` with one randomly chosen word swapped for a WordNet synonym.

    Every occurrence of the chosen word is replaced by the same synonym.
    Words are split on whitespace and re-joined with single spaces.

    Args:
        text: The text to augment. Non-string values (e.g. NaN) yield "".

    Returns:
        The augmented text, or the whitespace-normalized original when no
        word has a synonym that differs from the word itself.
    """
    if not isinstance(text, str):
        return ""

    words = text.split()

    # Look each distinct word up once. A dict keeps first-seen order, so the
    # random pick below is reproducible under random.seed (iterating a set
    # of strings is not stable across interpreter runs).
    synsets_by_word = {}
    for word in words:
        if word not in synsets_by_word:
            synsets_by_word[word] = wordnet.synsets(word)
    candidates = [word for word, synsets in synsets_by_word.items() if synsets]
    if not candidates:
        return ' '.join(words)

    target = random.choice(candidates)

    # Gather every lemma that is actually a different word. The first lemma
    # of a synset is frequently the word itself, which would make the
    # augmentation a no-op. Multi-word lemmas are stored with underscores
    # ("motor_vehicle"), so convert them back to spaces.
    synonyms = sorted({
        lemma.name().replace('_', ' ')
        for synset in synsets_by_word[target]
        for lemma in synset.lemmas()
        if lemma.name().lower() != target.lower()
    })
    if not synonyms:
        return ' '.join(words)

    replacement = random.choice(synonyms)
    return ' '.join(replacement if word == target else word for word in words)

# Discard training rows that are missing the tweet text, closest city, or closest state
train_df = train_df.dropna(
    axis=0,
    how='any',
    subset=['TweetText', 'Closest_City', 'Closest_State'],
)

# Tokenization function
def tokenize_tweets(text):
    """Encode one tweet with the module-level ``tokenizer``.

    Special tokens ([CLS]/[SEP]) are added, the sequence is truncated or
    padded to exactly 100 tokens, and the ``encode_plus`` result is returned
    with PyTorch tensors.
    """
    encoding_options = dict(
        add_special_tokens=True,
        max_length=100,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    return tokenizer.encode_plus(text, **encoding_options)

# Augment training data: append a synonym-replaced copy of each tweet to the
# original text.
# NOTE(review): 'TweetText' is overwritten in place, so re-running this cell
# appends yet another augmented copy. Restart the kernel before re-running.
train_df['Augmented_TweetText'] = train_df['TweetText'].apply(synonym_replacement)
train_df['TweetText'] = train_df['TweetText'] + ' ' + train_df['Augmented_TweetText']  # Concatenate original and augmented text

# Ensure the required columns have no NaN values in both training and validation sets
train_df = train_df.dropna(subset=['TweetText', 'Closest_City', 'Closest_State'])
val_df = val_df.dropna(subset=['TweetText', 'Closest_City', 'Closest_State'])  # Drop NaNs in validation set

# Generate TF-IDF features for training tweets
vectorizer = TfidfVectorizer(max_features=1000)  # Limit features for simplicity
train_tfidf = vectorizer.fit_transform(train_df['TweetText']).toarray()

# Generate TF-IDF features for validation tweets
val_tfidf = vectorizer.transform(val_df['TweetText']).toarray()  # Transform using the fitted vectorizer


def _encode_texts(texts):
    """Tokenize each text once and return stacked (input_ids, attention_masks).

    The previous code called tokenize_tweets twice per text (once for the
    ids, once for the masks); a single pass yields identical tensors for
    half the tokenization work.
    """
    encoded = [tokenize_tweets(str(text)) for text in texts]
    input_ids = torch.cat([enc['input_ids'] for enc in encoded])
    attention_masks = torch.cat([enc['attention_mask'] for enc in encoded])
    return input_ids, attention_masks


# Create input IDs and attention masks for training set
train_input_ids, train_attention_masks = _encode_texts(train_df['TweetText'])
train_labels = train_df['State_Label'].values

# Create input IDs and attention masks for validation set
val_input_ids, val_attention_masks = _encode_texts(val_df['TweetText'])
val_labels = val_df['State_Label'].values


# Create Hybrid Dataset class
class HybridTweetDataset(Dataset):
    """Dataset that serves token ids, attention masks, TF-IDF vectors and labels.

    ``tfidf_features`` is converted to a float32 tensor at construction time;
    the other inputs are stored as given and indexed per sample.
    """

    def __init__(self, input_ids, attention_masks, tfidf_features, labels):
        self.labels = labels
        self.tfidf_features = torch.tensor(tfidf_features, dtype=torch.float32)
        self.attention_masks = attention_masks
        self.input_ids = input_ids

    def __len__(self):
        # The number of samples is defined by the label sequence.
        return len(self.labels)

    def __getitem__(self, idx):
        # Assemble the per-sample dictionary; the label is wrapped in a
        # long tensor on every access.
        sample = dict(
            input_ids=self.input_ids[idx],
            attention_mask=self.attention_masks[idx],
            tfidf_features=self.tfidf_features[idx],
        )
        sample['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample
           
# Wrap each split in a HybridTweetDataset
train_dataset = HybridTweetDataset(
    input_ids=train_input_ids,
    attention_masks=train_attention_masks,
    tfidf_features=train_tfidf,
    labels=train_labels,
)
val_dataset = HybridTweetDataset(
    input_ids=val_input_ids,
    attention_masks=val_attention_masks,
    tfidf_features=val_tfidf,
    labels=val_labels,
)

# Batch both splits in groups of 32; only the training loader shuffles
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)


# Define hybrid DistilBERT + TF-IDF classifier
class HybridDistilBertClassifier(nn.Module):
    """Classifier that fuses a DistilBERT sentence vector with TF-IDF features.

    The TF-IDF vector is projected to 64 dimensions, concatenated with the
    first-token hidden state from DistilBERT, and mapped to ``num_classes``
    logits by a single linear layer.
    """

    def __init__(self, num_classes, tfidf_size=1000):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        # Dense projection of the TF-IDF vector
        self.fc_tfidf = nn.Linear(tfidf_size, 64)
        # Final layer sees the DistilBERT hidden state plus the 64 TF-IDF units
        fused_width = self.distilbert.config.hidden_size + 64
        self.fc_combined = nn.Linear(fused_width, num_classes)

    def forward(self, input_ids, attention_mask, tfidf_features):
        encoder_out = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of the sequence is the [CLS] token representation
        first_token = encoder_out.last_hidden_state[:, 0, :]
        tfidf_hidden = F.relu(self.fc_tfidf(tfidf_features))
        # Join both views along the feature dimension before classifying
        fused = torch.cat((first_token, tfidf_hidden), dim=1)
        return self.fc_combined(fused)

# Early stopping implementation
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Call the instance once per epoch with the validation loss. After
    ``patience`` consecutive calls without improvement, ``early_stop`` is
    set to True.

    Args:
        patience: Number of consecutive non-improving epochs tolerated.
        min_delta: Minimum decrease of the loss (relative to the best value
            seen so far) that counts as an improvement.
    """

    def __init__(self, patience=3, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_score = None   # lowest validation loss seen so far
        self.counter = 0         # consecutive epochs without improvement
        self.early_stop = False  # flipped to True once patience is exhausted

    def __call__(self, val_loss):
        """Record ``val_loss`` and update ``counter`` / ``early_stop``."""
        import math  # local import keeps this class self-contained

        if math.isnan(val_loss):
            # A NaN loss (diverged training) must never become the best
            # score: every later comparison against NaN is False, which
            # previously looked like an improvement each epoch and silently
            # disabled early stopping.
            improved = False
        elif self.best_score is None:
            improved = True
        else:
            improved = val_loss <= self.best_score - self.min_delta

        if improved:
            self.best_score = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

# Training loop with early stopping
def train_model(model, train_loader, val_loader, epochs=10, learning_rate=1e-5, start_epoch=0, patience=3, checkpoint_path='models//'):
    """Fine-tune ``model`` with AdamW and cross-entropy, checkpointing every epoch.

    After each epoch the model's state dict is written to
    ``<checkpoint_path>/distilbert_model_epoch_<n>.pt``, the model is
    evaluated on ``val_loader`` via ``evaluate_model``, and training stops
    early once the validation loss has not improved for ``patience`` epochs.

    Args:
        model: Module called as ``model(input_ids, attention_mask, tfidf_features)``.
        train_loader: Iterable of batches with keys 'input_ids',
            'attention_mask', 'tfidf_features' and 'labels'.
        val_loader: Validation batches in the same format.
        epochs: Epoch index (exclusive) at which training ends.
        learning_rate: Learning rate passed to AdamW.
        start_epoch: Zero-based epoch index to start from (for resuming).
        patience: Patience handed to ``EarlyStopping``.
        checkpoint_path: Directory that receives the per-epoch checkpoints.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    # Create the checkpoint directory up front. Without this, torch.save
    # fails only after the first (long) epoch has already been trained.
    if checkpoint_path:
        os.makedirs(checkpoint_path, exist_ok=True)

    early_stopping = EarlyStopping(patience=patience)

    for epoch in range(start_epoch, epochs):
        model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}"):
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            tfidf_features = batch['tfidf_features'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            outputs = model(input_ids, attention_mask, tfidf_features)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            # Backward pass
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(train_loader):.4f}")

        # Save model after each epoch
        model_save_path = os.path.join(checkpoint_path, f'distilbert_model_epoch_{epoch + 1}.pt')
        torch.save(model.state_dict(), model_save_path)
        print(f'Model saved to {model_save_path}')

        # Evaluate on validation set
        val_loss = evaluate_model(model, val_loader, criterion)

        # Early stopping check
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print(f"Early stopping at epoch {epoch + 1}")
            break

# Evaluation function
def evaluate_model(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` and return the mean batch loss.

    Prints accuracy plus macro- and micro-averaged F1 over all validation
    samples.

    Args:
        model: Module called as ``model(input_ids, attention_mask, tfidf_features)``.
        val_loader: Iterable of batches with keys 'input_ids',
            'attention_mask', 'tfidf_features' and 'labels'.
        criterion: Loss function applied to ``(outputs, labels)``.

    Returns:
        The summed batch losses divided by the number of batches.
    """
    model.eval()
    correct = 0
    total = 0
    total_loss = 0
    all_preds = []
    all_labels = []

    # Run on whatever device the model already lives on. Picking the device
    # from CUDA availability alone breaks when this is called on a
    # CPU-resident model on a GPU machine (inputs would go to 'cuda' while
    # the weights stay on 'cpu').
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            tfidf_features = batch['tfidf_features'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask, tfidf_features)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            # Predicted class = index of the largest logit
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    f1_macro = f1_score(all_labels, all_preds, average='macro')
    f1_micro = f1_score(all_labels, all_preds, average='micro')

    print(f"Validation Accuracy: {accuracy:.4f}, F1 Macro: {f1_macro:.4f}, F1 Micro: {f1_micro:.4f}")
    return total_loss / len(val_loader)

# Load the model from the latest checkpoint if available
def load_model(model, checkpoint_path):
    """Load the highest-epoch checkpoint from ``checkpoint_path`` into ``model``.

    Checkpoints are files named ``distilbert_model_epoch_<n>.pt``. The one
    with the largest ``<n>`` is loaded; files whose name carries no numeric
    epoch are ignored.

    Args:
        model: Module whose ``load_state_dict`` receives the checkpoint.
        checkpoint_path: Directory to search for checkpoint files.

    Returns:
        ``(model, checkpoint_file, start_epoch)`` on success, where
        ``start_epoch`` is the epoch number parsed from the file name.
        ``(model, None, 0)`` when no checkpoint exists or loading fails.
    """
    import re  # local import keeps this function self-contained

    # Rank checkpoints by the epoch number in the file name rather than by
    # file ctime (which changes when files are copied or restored), and
    # parse the basename only so directory names containing 'epoch_' cannot
    # corrupt the parsed epoch.
    candidates = []
    for path in glob.glob(os.path.join(checkpoint_path, 'distilbert_model_epoch_*.pt')):
        match = re.fullmatch(r'distilbert_model_epoch_(\d+)\.pt', os.path.basename(path))
        if match:
            candidates.append((int(match.group(1)), path))

    if not candidates:
        print("No checkpoint files found.")
        return model, None, 0

    start_epoch, latest_checkpoint = max(candidates)
    print(f"Loading model from {latest_checkpoint}")

    try:
        # Load the model state onto the appropriate device
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        model.load_state_dict(torch.load(latest_checkpoint, map_location=device))
        return model, latest_checkpoint, start_epoch
    except Exception as e:
        # Best effort: report the problem and let the caller start fresh.
        print(f"Error loading the model: {e}")
        return model, None, 0

# Initialize the model with the number of unique states
num_classes = len(df['Closest_State'].unique())
# Size the TF-IDF input from the fitted feature matrix instead of hard-coding
# 1000: max_features is only an upper bound, so a smaller vocabulary would
# otherwise produce a shape mismatch in the model's first TF-IDF layer.
model = HybridDistilBertClassifier(num_classes, tfidf_size=train_tfidf.shape[1])


In [6]:
# Directory searched for saved checkpoints
checkpoint_path = 'models//'

# Restore the latest checkpoint when one exists; otherwise the model comes
# back unchanged together with (None, 0)
model, latest_checkpoint, start_epoch = load_model(
    model=model,
    checkpoint_path=checkpoint_path,
)

No checkpoint files found.


In [7]:
# Train the model, resuming after the last checkpointed epoch. start_epoch is
# 0 when no checkpoint was loaded; hard-coding 0 here made a resumed run
# retrain from epoch 1 and overwrite the existing checkpoints.
train_model(model, train_loader, val_loader, epochs=10, learning_rate=2e-5, start_epoch=start_epoch, patience=3, checkpoint_path=checkpoint_path)

Training Epoch 1: 100%|██████████| 1531/1531 [13:13<00:00,  1.93it/s]


Epoch 1/10, Loss: 3.1187
Model saved to models//distilbert_model_epoch_1.pt
Validation Accuracy: 0.1243, F1 Macro: 0.0280, F1 Micro: 0.1243


Training Epoch 2: 100%|██████████| 1531/1531 [13:09<00:00,  1.94it/s]


Epoch 2/10, Loss: 2.9956
Model saved to models//distilbert_model_epoch_2.pt
Validation Accuracy: 0.1286, F1 Macro: 0.0339, F1 Micro: 0.1286


Training Epoch 3: 100%|██████████| 1531/1531 [13:08<00:00,  1.94it/s]


Epoch 3/10, Loss: 2.7403
Model saved to models//distilbert_model_epoch_3.pt
Validation Accuracy: 0.1194, F1 Macro: 0.0336, F1 Micro: 0.1194


Training Epoch 4: 100%|██████████| 1531/1531 [13:07<00:00,  1.94it/s]


Epoch 4/10, Loss: 2.3677
Model saved to models//distilbert_model_epoch_4.pt
Validation Accuracy: 0.1075, F1 Macro: 0.0383, F1 Micro: 0.1075
Early stopping at epoch 4
