In [1]:
import os
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import mean_squared_error, mean_absolute_error, accuracy_score, f1_score
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer, AutoModel
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle


In [2]:
# Seed NumPy and PyTorch so repeated runs draw the same random numbers
np.random.seed(42)
torch.manual_seed(42)

# Use CUDA when PyTorch reports a usable GPU, otherwise stay on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [None]:
import os
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import mean_squared_error, mean_absolute_error, accuracy_score, f1_score
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer, AutoModel
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle

# Seed NumPy and PyTorch so repeated runs draw the same random numbers
np.random.seed(42)
torch.manual_seed(42)

# Use CUDA when PyTorch reports a usable GPU, otherwise stay on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Locations for the input JSON files and for everything this notebook writes
DATA_DIR = "path/to/your/data/"  # Update this to your data directory
MODELS_DIR = "saved_models/"
METRICS_DIR = "metrics/"

# Output folders must exist before checkpoints / metrics are written into them
os.makedirs(METRICS_DIR, exist_ok=True)
os.makedirs(MODELS_DIR, exist_ok=True)

# Load data
def load_data(data_dir):
    """Load the train / validation / test splits from JSON files.

    Reads ``sanitized_train.json``, ``sanitized_validate.json`` and
    ``sanitized_test.json`` from ``data_dir``.

    Args:
        data_dir: Directory containing the three JSON files.

    Returns:
        A ``(train_data, val_data, test_data)`` tuple holding the parsed
        content of each file, in that order.

    Raises:
        FileNotFoundError: If any of the three files is missing.
        json.JSONDecodeError: If a file does not contain valid JSON.
    """
    split_files = ("sanitized_train.json", "sanitized_validate.json", "sanitized_test.json")

    splits = []
    for filename in split_files:
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # default encoding, which is not UTF-8 on every system.
        with open(os.path.join(data_dir, filename), 'r', encoding='utf-8') as f:
            splits.append(json.load(f))

    train_data, val_data, test_data = splits
    return train_data, val_data, test_data

# Custom Dataset class
class ResponseLengthDataset(Dataset):
    """Dataset that pairs a tokenized prompt and a one-hot model id with a length target.

    Every element of ``data`` is indexed as a mapping and must provide
    ``'prompt_text'`` and ``'model'``, plus ``'output_mean'`` / ``'output_std'``
    for the regression task or ``'output_percentiles'['99']`` for the
    classification task.

    Args:
        data: Indexable collection of items as described above.
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            padding='max_length', truncation=True, return_tensors='pt')`` whose
            result exposes ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_length: Maximum sequence length passed to the tokenizer.
        task: ``'regression'`` selects the (mean, std) target; any other value
            selects the binned p99 class target.
        bin_size: Width of one p99 class bin.
        model_encoder: Optional ``{model name: index}`` mapping to reuse (for
            example the one built from the training split) so that every split
            encodes a model with the same one-hot position and width. When
            omitted, the mapping is derived from ``data``.

    Raises:
        ValueError: If ``model_encoder`` is given and ``data`` contains a model
            name that the mapping does not cover.
    """

    def __init__(self, data, tokenizer, max_length=512, task='regression', bin_size=50,
                 model_encoder=None):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.task = task
        self.bin_size = bin_size

        if model_encoder is None:
            # Sort the unique names: plain set iteration order depends on string
            # hashing, so indices would otherwise differ between runs and splits.
            self.models = sorted({item['model'] for item in data})
            self.model_encoder = {model: idx for idx, model in enumerate(self.models)}
        else:
            self.model_encoder = dict(model_encoder)
            # Keep `models` ordered by index so models[i] matches one-hot slot i.
            self.models = sorted(self.model_encoder, key=self.model_encoder.get)
            unknown = sorted({item['model'] for item in data} - set(self.model_encoder), key=str)
            if unknown:
                raise ValueError(
                    f"Data contains models missing from the supplied model_encoder: {unknown}"
                )

    def __len__(self):
        """Return the number of items in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Build the model inputs and target for item ``idx``.

        Returns:
            A dict with ``'input_ids'`` and ``'attention_mask'`` (tokenizer
            output with the leading batch dimension squeezed away), ``'model'``
            (one-hot float tensor of length ``len(self.models)``) and
            ``'target'`` (float tensor ``[output_mean, output_std]`` for
            regression, or a long scalar class index for classification).
        """
        item = self.data[idx]

        # Tokenize text
        encoding = self.tokenizer(
            item['prompt_text'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # return_tensors='pt' yields a batch of one; drop that batch dimension
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # One-hot encode model
        model_idx = self.model_encoder[item['model']]
        model_tensor = torch.zeros(len(self.models))
        model_tensor[model_idx] = 1

        if self.task == 'regression':
            # Targets for regression
            target = torch.tensor([item['output_mean'], item['output_std']], dtype=torch.float)
        else:  # Classification
            # Bucket the 99th-percentile length into a class of width bin_size
            p99 = item['output_percentiles']['99']
            target = torch.tensor(int(p99 // self.bin_size), dtype=torch.long)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'model': model_tensor,
            'target': target
        }

# Model definitions
class TextEmbeddingModel(nn.Module):
    """Pretrained transformer wrapper that returns one embedding vector per input sequence.

    The backbone is loaded with ``AutoModel.from_pretrained`` and its configured
    hidden size is exposed as ``hidden_size``.
    """

    def __init__(self, pretrained_model_name="distilbert-base-uncased"):
        super().__init__()
        self.bert = AutoModel.from_pretrained(pretrained_model_name)
        self.hidden_size = self.bert.config.hidden_size

    def forward(self, input_ids, attention_mask):
        """Return the hidden state at sequence position 0 for every batch element."""
        hidden_states = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # Position 0 is used as the sequence summary (the CLS slot)
        return hidden_states[:, 0, :]

class RegressionPredictor(nn.Module):
    """Text encoder plus a three-layer MLP head that emits two values (mean and std).

    Args:
        text_embedding_size: Accepted for interface compatibility; the value is
            overwritten with the encoder's own ``hidden_size``.
        num_models: Width of the one-hot model vector concatenated to the text
            embedding.
        hidden_size: Width of the first hidden layer; the second is half of it.
    """

    def __init__(self, text_embedding_size, num_models, hidden_size=128):
        super().__init__()
        self.text_encoder = TextEmbeddingModel()
        # The encoder is the source of truth for the embedding width
        text_embedding_size = self.text_encoder.hidden_size
        half_hidden = hidden_size // 2

        self.fc1 = nn.Linear(text_embedding_size + num_models, hidden_size)
        self.fc2 = nn.Linear(hidden_size, half_hidden)
        self.fc3 = nn.Linear(half_hidden, 2)  # Output mean and std
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, input_ids, attention_mask, model_enc):
        """Encode the prompt, append the one-hot model vector and run the MLP head."""
        features = torch.cat(
            (self.text_encoder(input_ids, attention_mask), model_enc),
            dim=1,
        )

        hidden = self.dropout(self.relu(self.fc1(features)))
        hidden = self.dropout(self.relu(self.fc2(hidden)))
        return self.fc3(hidden)

class ClassificationPredictor(nn.Module):
    """Text encoder plus a three-layer MLP head that emits one logit per class.

    Args:
        text_embedding_size: Accepted for interface compatibility; the value is
            overwritten with the encoder's own ``hidden_size``.
        num_models: Width of the one-hot model vector concatenated to the text
            embedding.
        num_classes: Number of output logits.
        hidden_size: Width of the first hidden layer; the second is half of it.
    """

    def __init__(self, text_embedding_size, num_models, num_classes, hidden_size=128):
        super().__init__()
        self.text_encoder = TextEmbeddingModel()
        # The encoder is the source of truth for the embedding width
        text_embedding_size = self.text_encoder.hidden_size
        half_hidden = hidden_size // 2

        self.fc1 = nn.Linear(text_embedding_size + num_models, hidden_size)
        self.fc2 = nn.Linear(hidden_size, half_hidden)
        self.fc3 = nn.Linear(half_hidden, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, input_ids, attention_mask, model_enc):
        """Encode the prompt, append the one-hot model vector and return raw class logits."""
        features = torch.cat(
            (self.text_encoder(input_ids, attention_mask), model_enc),
            dim=1,
        )

        hidden = self.dropout(self.relu(self.fc1(features)))
        hidden = self.dropout(self.relu(self.fc2(hidden)))
        return self.fc3(hidden)

# Training function
def _run_epoch(model, loader, criterion, task, desc, optimizer=None):
    """Run one pass over ``loader`` and collect loss, predictions and targets.

    When ``optimizer`` is given, every batch is followed by a backward pass and
    an optimizer step; otherwise the pass only evaluates. The caller is
    responsible for ``model.train()`` / ``model.eval()`` and for wrapping
    evaluation passes in ``torch.no_grad()``.

    Returns:
        ``(average batch loss, predictions, targets)``. For regression the two
        lists hold one numpy row per sample; for classification they hold the
        arg-max class index and the true class index per sample.
    """
    total_loss = 0.0
    preds = []
    trues = []

    for batch in tqdm(loader, desc=desc):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        model_tensor = batch['model'].to(device)
        targets = batch['target'].to(device)

        # Forward pass and loss
        outputs = model(input_ids, attention_mask, model_tensor)
        loss = criterion(outputs, targets)

        # Backward and optimize (training passes only)
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        total_loss += loss.item()

        # Store predictions and targets for metrics
        if task == 'regression':
            preds.extend(outputs.detach().cpu().numpy())
            trues.extend(targets.detach().cpu().numpy())
        else:  # classification
            _, predicted = torch.max(outputs, 1)
            preds.extend(predicted.cpu().numpy())
            trues.extend(targets.cpu().numpy())

    return total_loss / len(loader), preds, trues


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5, task='regression'):
    """Train ``model``, validate after every epoch and keep the best checkpoint.

    Whenever the average validation loss improves, the model's ``state_dict``
    is written to ``MODELS_DIR/best_{task}_model.pt``. After the last epoch the
    best checkpoint is loaded back into ``model``, so the returned model is the
    same one that was saved to disk (not merely the last-epoch weights).

    Args:
        model: Module called as ``model(input_ids, attention_mask, model_onehot)``.
        train_loader: Iterable of batches with keys ``'input_ids'``,
            ``'attention_mask'``, ``'model'`` and ``'target'``.
        val_loader: Same batch format, used for validation.
        criterion: Loss callable applied to ``(outputs, targets)``.
        optimizer: Optimizer stepping the model's parameters.
        num_epochs: Number of passes over ``train_loader``.
        task: ``'regression'`` tracks MSE/MAE of the first output column; any
            other value tracks accuracy and weighted F1.

    Returns:
        ``(model, metrics)`` where ``metrics`` maps metric names to per-epoch
        lists (``'epoch'``, ``'train_loss'``, ``'val_loss'`` plus the
        task-specific entries).
    """
    best_val_loss = float('inf')
    best_model_path = os.path.join(MODELS_DIR, f"best_{task}_model.pt")
    checkpoint_saved = False

    metrics = {
        'train_loss': [],
        'val_loss': [],
        'epoch': []
    }

    if task == 'regression':
        metrics.update({
            'train_mse': [],
            'train_mae': [],
            'val_mse': [],
            'val_mae': []
        })
    else:  # classification
        metrics.update({
            'train_accuracy': [],
            'train_f1': [],
            'val_accuracy': [],
            'val_f1': []
        })

    for epoch in range(num_epochs):
        # Training pass
        model.train()
        avg_train_loss, all_preds, all_targets = _run_epoch(
            model, train_loader, criterion, task,
            desc=f"Epoch {epoch+1}/{num_epochs} [Train]",
            optimizer=optimizer,
        )

        # Validation pass (no gradient tracking, dropout disabled)
        model.eval()
        with torch.no_grad():
            avg_val_loss, val_preds, val_targets = _run_epoch(
                model, val_loader, criterion, task,
                desc=f"Epoch {epoch+1}/{num_epochs} [Val]",
            )

        # Calculate and store metrics
        metrics['train_loss'].append(avg_train_loss)
        metrics['val_loss'].append(avg_val_loss)
        metrics['epoch'].append(epoch + 1)

        if task == 'regression':
            # Convert lists to numpy arrays for metric calculation
            all_preds_np = np.array(all_preds)
            all_targets_np = np.array(all_targets)
            val_preds_np = np.array(val_preds)
            val_targets_np = np.array(val_targets)

            # Calculate MSE and MAE for mean predictions (first column)
            train_mse_mean = mean_squared_error(all_targets_np[:, 0], all_preds_np[:, 0])
            train_mae_mean = mean_absolute_error(all_targets_np[:, 0], all_preds_np[:, 0])
            val_mse_mean = mean_squared_error(val_targets_np[:, 0], val_preds_np[:, 0])
            val_mae_mean = mean_absolute_error(val_targets_np[:, 0], val_preds_np[:, 0])

            metrics['train_mse'].append(train_mse_mean)
            metrics['train_mae'].append(train_mae_mean)
            metrics['val_mse'].append(val_mse_mean)
            metrics['val_mae'].append(val_mae_mean)

            print(f"Epoch {epoch+1}/{num_epochs} - "
                  f"Train Loss: {avg_train_loss:.4f}, "
                  f"Train MSE (mean): {train_mse_mean:.4f}, "
                  f"Val Loss: {avg_val_loss:.4f}, "
                  f"Val MSE (mean): {val_mse_mean:.4f}")

        else:  # classification
            train_acc = accuracy_score(all_targets, all_preds)
            train_f1 = f1_score(all_targets, all_preds, average='weighted')
            val_acc = accuracy_score(val_targets, val_preds)
            val_f1 = f1_score(val_targets, val_preds, average='weighted')

            metrics['train_accuracy'].append(train_acc)
            metrics['train_f1'].append(train_f1)
            metrics['val_accuracy'].append(val_acc)
            metrics['val_f1'].append(val_f1)

            print(f"Epoch {epoch+1}/{num_epochs} - "
                  f"Train Loss: {avg_train_loss:.4f}, "
                  f"Train Acc: {train_acc:.4f}, "
                  f"Val Loss: {avg_val_loss:.4f}, "
                  f"Val Acc: {val_acc:.4f}")

        # Save the best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            print(f"Saving best model with validation loss: {best_val_loss:.4f}")
            torch.save(model.state_dict(), best_model_path)
            checkpoint_saved = True

    # Hand back the best weights rather than whatever the last epoch produced.
    # Skipped when no checkpoint was written (zero epochs, or a validation loss
    # that never compared below infinity, e.g. NaN).
    if checkpoint_saved:
        model.load_state_dict(torch.load(best_model_path, map_location=device))

    return model, metrics

# Evaluate the model on test set
def evaluate_model(model, test_loader, criterion, task='regression'):
    """Score ``model`` on ``test_loader`` and print/return the resulting metrics.

    Args:
        model: Module called as ``model(input_ids, attention_mask, model_onehot)``.
        test_loader: Iterable of batches with keys ``'input_ids'``,
            ``'attention_mask'``, ``'model'`` and ``'target'``.
        criterion: Loss callable applied to ``(outputs, targets)``.
        task: ``'regression'`` reports MSE/MAE for both output columns; any
            other value reports accuracy and weighted F1.

    Returns:
        Dict with ``'test_loss'`` (average batch loss) plus the task metrics.
    """
    is_regression = task == 'regression'

    model.eval()
    running_loss = 0.0
    predictions, references = [], []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            model_onehot = batch['model'].to(device)
            targets = batch['target'].to(device)

            outputs = model(ids, mask, model_onehot)
            running_loss += criterion(outputs, targets).item()

            # Regression keeps the raw outputs; classification keeps the arg-max class
            if is_regression:
                batch_preds = outputs.detach()
            else:
                batch_preds = torch.max(outputs, 1)[1]
            predictions.extend(batch_preds.cpu().numpy())
            references.extend(targets.detach().cpu().numpy())

    avg_test_loss = running_loss / len(test_loader)
    metrics = {'test_loss': avg_test_loss}

    if is_regression:
        pred_matrix = np.array(predictions)
        true_matrix = np.array(references)

        # Column 0 holds the mean, column 1 the std
        true_mean, pred_mean = true_matrix[:, 0], pred_matrix[:, 0]
        true_std, pred_std = true_matrix[:, 1], pred_matrix[:, 1]

        test_mse_mean = mean_squared_error(true_mean, pred_mean)
        test_mae_mean = mean_absolute_error(true_mean, pred_mean)
        test_mse_std = mean_squared_error(true_std, pred_std)
        test_mae_std = mean_absolute_error(true_std, pred_std)

        metrics['test_mse_mean'] = test_mse_mean
        metrics['test_mae_mean'] = test_mae_mean
        metrics['test_mse_std'] = test_mse_std
        metrics['test_mae_std'] = test_mae_std

        print(f"Test results - "
              f"Loss: {avg_test_loss:.4f}, "
              f"MSE (mean): {test_mse_mean:.4f}, "
              f"MAE (mean): {test_mae_mean:.4f}, "
              f"MSE (std): {test_mse_std:.4f}, "
              f"MAE (std): {test_mae_std:.4f}")
    else:
        test_acc = accuracy_score(references, predictions)
        test_f1 = f1_score(references, predictions, average='weighted')

        metrics['test_accuracy'] = test_acc
        metrics['test_f1'] = test_f1

        print(f"Test results - "
              f"Loss: {avg_test_loss:.4f}, "
              f"Accuracy: {test_acc:.4f}, "
              f"F1 Score: {test_f1:.4f}")

    return metrics

# Plot metrics
def plot_metrics(metrics, task='regression'):
    """Save a two-panel training figure to ``METRICS_DIR/{task}_training_metrics.png``.

    The top panel shows train/validation loss per epoch. The bottom panel shows
    train/validation MSE for the regression task, or accuracy for any other
    task value. The figure is closed after saving.
    """
    epochs = metrics['epoch']

    if task == 'regression':
        secondary_curves = (('train_mse', 'Train MSE'), ('val_mse', 'Validation MSE'))
        secondary_ylabel = 'MSE'
        secondary_title = 'Mean Squared Error (for Mean Prediction)'
    else:  # classification
        secondary_curves = (('train_accuracy', 'Train Accuracy'), ('val_accuracy', 'Validation Accuracy'))
        secondary_ylabel = 'Accuracy'
        secondary_title = 'Classification Accuracy'

    plt.figure(figsize=(12, 8))

    # Top panel: loss curves
    plt.subplot(2, 1, 1)
    for key, label in (('train_loss', 'Train Loss'), ('val_loss', 'Validation Loss')):
        plt.plot(epochs, metrics[key], label=label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'{task.capitalize()} Model - Training and Validation Loss')
    plt.legend()

    # Bottom panel: task-specific quality metric
    plt.subplot(2, 1, 2)
    for key, label in secondary_curves:
        plt.plot(epochs, metrics[key], label=label)
    plt.xlabel('Epoch')
    plt.ylabel(secondary_ylabel)
    plt.title(secondary_title)
    plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(METRICS_DIR, f'{task}_training_metrics.png'))
    plt.close()

# Main execution function
def _build_task_loaders(train_data, val_data, test_data, tokenizer, task, batch_size, bin_size=50):
    """Create datasets and loaders for one task with a single, shared model encoding.

    The validation and test datasets are forced to reuse the training dataset's
    ``models`` / ``model_encoder`` so that a given model name always maps to the
    same one-hot position and the one-hot width always matches the network's
    input layer.

    Returns:
        ``(train_dataset, train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If the validation or test split contains a model name that
            never appears in the training split.
    """
    train_dataset = ResponseLengthDataset(train_data, tokenizer, task=task, bin_size=bin_size)
    val_dataset = ResponseLengthDataset(val_data, tokenizer, task=task, bin_size=bin_size)
    test_dataset = ResponseLengthDataset(test_data, tokenizer, task=task, bin_size=bin_size)

    for split_name, dataset in (("validation", val_dataset), ("test", test_dataset)):
        unseen = sorted(
            {item['model'] for item in dataset.data} - set(train_dataset.model_encoder),
            key=str,
        )
        if unseen:
            raise ValueError(
                f"The {split_name} split contains models absent from the training split: {unseen}"
            )
        # Each dataset derives its own encoding by default; override it with
        # the training one so indices and vector width agree across splits.
        dataset.models = train_dataset.models
        dataset.model_encoder = train_dataset.model_encoder

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    return train_dataset, train_loader, val_loader, test_loader


def main():
    """Train, evaluate and persist the regression and classification predictors.

    Steps, in order: load the three data splits from ``DATA_DIR``; train the
    regression model and evaluate the model returned by ``train_model`` on the
    test split; then do the same for the p99-bin classification model. Model
    metadata is written to ``MODELS_DIR`` and metrics/plots to ``METRICS_DIR``.
    """
    # Load data
    print("Loading data...")
    train_data, val_data, test_data = load_data(DATA_DIR)

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    # Setup for regression task
    print("\nSetting up regression predictor...")
    batch_size = 16
    train_dataset_reg, train_loader_reg, val_loader_reg, test_loader_reg = _build_task_loaders(
        train_data, val_data, test_data, tokenizer, task='regression', batch_size=batch_size
    )

    # Determine number of models
    num_models = len(train_dataset_reg.models)
    text_embedding_size = 768  # For distilbert

    # Initialize regression model
    regression_model = RegressionPredictor(text_embedding_size, num_models).to(device)

    # Save model architecture info
    with open(os.path.join(MODELS_DIR, 'regression_model_info.json'), 'w') as f:
        json.dump({
            'num_models': num_models,
            'model_encoder': train_dataset_reg.model_encoder,
            'models': train_dataset_reg.models
        }, f)

    # Training parameters
    regression_criterion = nn.MSELoss()
    regression_optimizer = optim.Adam(regression_model.parameters(), lr=2e-5)

    # Train regression model
    print("\nTraining regression model...")
    trained_reg_model, reg_metrics = train_model(
        regression_model,
        train_loader_reg,
        val_loader_reg,
        regression_criterion,
        regression_optimizer,
        num_epochs=5,
        task='regression'
    )

    # Evaluate regression model
    print("\nEvaluating regression model on test set...")
    reg_test_metrics = evaluate_model(
        trained_reg_model,
        test_loader_reg,
        regression_criterion,
        task='regression'
    )

    # Save regression metrics
    with open(os.path.join(METRICS_DIR, 'regression_training_metrics.pkl'), 'wb') as f:
        pickle.dump(reg_metrics, f)

    with open(os.path.join(METRICS_DIR, 'regression_test_metrics.json'), 'w') as f:
        json.dump(reg_test_metrics, f)

    # Plot regression metrics
    plot_metrics(reg_metrics, task='regression')

    # Setup for classification task
    print("\nSetting up classification predictor...")

    # Define bin size
    bin_size = 50

    # The class count must cover the largest p99 in ANY split, otherwise a
    # validation/test label could fall outside the network's output range.
    max_p99 = 0
    for dataset in [train_data, val_data, test_data]:
        for item in dataset:
            p99 = item['output_percentiles']['99']
            max_p99 = max(max_p99, p99)

    num_classes = int(max_p99 // bin_size) + 1
    print(f"Maximum p99 value: {max_p99}, Number of classes: {num_classes}")

    train_dataset_cls, train_loader_cls, val_loader_cls, test_loader_cls = _build_task_loaders(
        train_data, val_data, test_data, tokenizer,
        task='classification', batch_size=batch_size, bin_size=bin_size
    )

    # Initialize classification model
    classification_model = ClassificationPredictor(text_embedding_size, num_models, num_classes).to(device)

    # Save model architecture info
    with open(os.path.join(MODELS_DIR, 'classification_model_info.json'), 'w') as f:
        json.dump({
            'num_models': num_models,
            'num_classes': num_classes,
            'bin_size': bin_size,
            'model_encoder': train_dataset_cls.model_encoder,
            'models': train_dataset_cls.models
        }, f)

    # Training parameters
    classification_criterion = nn.CrossEntropyLoss()
    classification_optimizer = optim.Adam(classification_model.parameters(), lr=2e-5)

    # Train classification model
    print("\nTraining classification model...")
    trained_cls_model, cls_metrics = train_model(
        classification_model,
        train_loader_cls,
        val_loader_cls,
        classification_criterion,
        classification_optimizer,
        num_epochs=5,
        task='classification'
    )

    # Evaluate classification model
    print("\nEvaluating classification model on test set...")
    cls_test_metrics = evaluate_model(
        trained_cls_model,
        test_loader_cls,
        classification_criterion,
        task='classification'
    )

    # Save classification metrics
    with open(os.path.join(METRICS_DIR, 'classification_training_metrics.pkl'), 'wb') as f:
        pickle.dump(cls_metrics, f)

    with open(os.path.join(METRICS_DIR, 'classification_test_metrics.json'), 'w') as f:
        json.dump(cls_test_metrics, f)

    # Plot classification metrics
    plot_metrics(cls_metrics, task='classification')

    print("\nTraining and evaluation complete!")
    print(f"Models saved to: {MODELS_DIR}")
    print(f"Metrics saved to: {METRICS_DIR}")

# Inference functions
def load_regression_model(model_path):
    """Rebuild the regression predictor and load its weights from ``model_path``.

    The architecture metadata is read from
    ``MODELS_DIR/regression_model_info.json``.

    Args:
        model_path: Path to a ``state_dict`` checkpoint for ``RegressionPredictor``.

    Returns:
        ``(model, model_info)``: the model on ``device`` in eval mode, and the
        parsed metadata dict.
    """
    # Load model info
    with open(os.path.join(MODELS_DIR, 'regression_model_info.json'), 'r') as f:
        model_info = json.load(f)

    # Initialize model
    model = RegressionPredictor(768, model_info['num_models']).to(device)
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine also loads on a CPU-only one.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    return model, model_info

def load_classification_model(model_path):
    """Rebuild the classification predictor and load its weights from ``model_path``.

    The architecture metadata is read from
    ``MODELS_DIR/classification_model_info.json``.

    Args:
        model_path: Path to a ``state_dict`` checkpoint for ``ClassificationPredictor``.

    Returns:
        ``(model, model_info)``: the model on ``device`` in eval mode, and the
        parsed metadata dict.
    """
    # Load model info
    with open(os.path.join(MODELS_DIR, 'classification_model_info.json'), 'r') as f:
        model_info = json.load(f)

    # Initialize model
    model = ClassificationPredictor(768, model_info['num_models'], model_info['num_classes']).to(device)
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine also loads on a CPU-only one.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    return model, model_info

def predict_response_length(prompt_text, model_name, task='regression'):
    """Predict response-length statistics for one prompt / model pair.

    Loads the best saved checkpoint for ``task`` from ``MODELS_DIR`` on every
    call. If ``model_name`` is not in the saved model encoder, a warning is
    printed and the encoder's first key is used instead.

    Args:
        prompt_text: Prompt to score.
        model_name: Name of the model whose response length is predicted.
        task: ``'regression'`` or (for any other value) classification.

    Returns:
        Regression: ``{'predicted_mean', 'predicted_std'}``.
        Classification: ``{'predicted_p99_class', 'predicted_p99_range'}``.
    """
    is_regression = task == 'regression'

    # Load appropriate model, then the tokenizer (same tokenizer for both tasks)
    if is_regression:
        model, model_info = load_regression_model(
            os.path.join(MODELS_DIR, 'best_regression_model.pt')
        )
    else:  # classification
        model, model_info = load_classification_model(
            os.path.join(MODELS_DIR, 'best_classification_model.pt')
        )
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    # Fall back to the first known model when the requested one was never seen
    model_encoder = model_info['model_encoder']
    if model_name not in model_encoder:
        print(f"Warning: Model '{model_name}' not found in training data. Using fallback.")
        model_name = list(model_encoder)[0]

    # Prepare inputs
    encoding = tokenizer(
        prompt_text,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # One-hot model vector with a leading batch dimension of one
    model_tensor = torch.zeros(1, len(model_encoder)).to(device)
    model_tensor[0, model_encoder[model_name]] = 1

    # Prediction
    with torch.no_grad():
        output = model(input_ids, attention_mask, model_tensor)

    if is_regression:
        mean_pred, std_pred = output[0].cpu().numpy()
        return {
            'predicted_mean': mean_pred,
            'predicted_std': std_pred
        }

    # Classification: translate the arg-max class back into a token range
    class_idx = torch.max(output, 1)[1].item()
    bin_size = model_info['bin_size']
    lower_bound = class_idx * bin_size
    upper_bound = (class_idx + 1) * bin_size
    return {
        'predicted_p99_class': class_idx,
        'predicted_p99_range': f"{lower_bound}-{upper_bound} tokens"
    }

if __name__ == "__main__":
    # Full pipeline: train, evaluate and save both predictors
    main()

    # Smoke-test inference using the checkpoints written above
    print("\nExample predictions:")

    example_model = "gpt-3.5-turbo"
    example_prompt = "Write a comprehensive essay about the impact of artificial intelligence on society."

    reg_prediction = predict_response_length(example_prompt, example_model, task='regression')
    print(
        f"Regression prediction: Expected response mean length: {reg_prediction['predicted_mean']:.2f} tokens, "
        f"Expected std: {reg_prediction['predicted_std']:.2f} tokens"
    )

    cls_prediction = predict_response_length(example_prompt, example_model, task='classification')
    print(
        f"Classification prediction: P99 token length class: {cls_prediction['predicted_p99_class']}, "
        f"Range: {cls_prediction['predicted_p99_range']}"
    )