## Please Use FSC8 lab outside computers
pip3 install transformers
pip uninstall torch torchvision torchaudiopip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu129


In [1]:
"""
Declaration
"""
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.optim import AdamW 
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
import os
from tqdm import tqdm
from datetime import datetime
import random
import json
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import time, psutil

"""
system settings
"""
def set_overall_seed(seed=16):
    """Seed every RNG used by this script with the same value.

    Covers Python's ``random`` module, NumPy's global generator, PyTorch's
    CPU generator and the generators of all CUDA devices.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


set_overall_seed(16)

"""
Define a class for importing Dataset
"""
class ProcessedIMDbDataset(Dataset):
    """Map-style dataset over pre-tokenised IMDb tensors.

    Holds four parallel containers (token ids, attention masks, labels and
    lengths) and returns the ``idx``-th slice of each as a dict. When no
    lengths are supplied, ``torch.ones_like(labels)`` is used as a placeholder.
    """

    def __init__(self, sequences, attention_masks, labels, lengths=None):
        self.sequences = sequences
        self.attention_masks = attention_masks
        self.labels = labels
        if lengths is None:
            # placeholder lengths so every item still carries a 'lengths' entry
            lengths = torch.ones_like(labels)
        self.lengths = lengths

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        item = {}
        item['input_ids'] = self.sequences[idx]
        item['attention_mask'] = self.attention_masks[idx]
        item['labels'] = self.labels[idx]
        item['lengths'] = self.lengths[idx]
        return item


"""
Model Trainer based on Bert Finetuning
"""
class BertSentimentTrainer:
    """Fine-tune a pretrained BERT-style sequence classifier for sentiment analysis.

    Wraps a Hugging Face ``AutoModelForSequenceClassification`` and its tokenizer.
    Provides:
    - loading of pre-tokenised train/val/test tensors from a ``.pt`` file;
    - strategy-specific optimizer and scheduler setup;
    - training with gradient accumulation and patience-based early stopping;
    - evaluation, prediction, and save/load helpers.
    """

    def __init__(self, model_name='bert-base-uncased', num_labels=2, max_length=512):
        """Load the tokenizer and model, and move the model to the GPU if one is available.

        Args:
            model_name: Hugging Face model id or local directory.
            num_labels: number of classes for the classification head.
            max_length: maximum token length used when tokenising in ``predict``.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f'using device {self.device}')

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels)

        self.model.to(self.device)

        self.max_length = max_length
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None

    def _get_memory_usage(self):
        """Return memory usage in GB.

        On CUDA this is ``torch.cuda.max_memory_allocated()``, i.e. the peak
        allocated memory since the last peak reset. On CPU it is the current
        resident set size of this process.
        """
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            return torch.cuda.max_memory_allocated() / 1024**3  # GB
        process = psutil.Process()
        return process.memory_info().rss / 1024**3  # GB

    def load_data(self, file_path='./all_data.pt', batch_size=20):
        """Load pre-tokenised splits from *file_path* and build the DataLoaders.

        The file must contain, for each split in train/val/test, the keys
        ``<split>_sequences``, ``<split>_masks``, ``<split>_labels`` and
        ``<split>_lengths``. Only the training loader is shuffled.

        Returns:
            (train_dataset, val_dataset, test_dataset)
        """
        print(f"loading dataset")

        data = torch.load(file_path, weights_only=True)

        train_dataset = ProcessedIMDbDataset(
            data['train_sequences'],
            data['train_masks'],
            data['train_labels'],
            data['train_lengths']
        )

        val_dataset = ProcessedIMDbDataset(
            data['val_sequences'],
            data['val_masks'],
            data['val_labels'],
            data['val_lengths']
        )

        test_dataset = ProcessedIMDbDataset(
            data['test_sequences'],
            data['test_masks'],
            data['test_labels'],
            data['test_lengths']
        )

        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        self.val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        print(f"loaded done, train dataset:{len(train_dataset)}, val_dataset:{len(val_dataset)}, test_dataset:{len(test_dataset)}")
        return train_dataset, val_dataset, test_dataset

    def _optimizer_steps_per_epoch(self, accumulation_steps):
        """Number of ``optimizer.step()`` calls ``train`` makes per epoch.

        ``train`` steps once every *accumulation_steps* batches, plus once more
        for any leftover partial window, i.e. ceil(batches / accumulation_steps).
        """
        n_batches = len(self.train_loader)
        return (n_batches + accumulation_steps - 1) // accumulation_steps

    def _differential_lr_groups(self, lr):
        """Build AdamW parameter groups with layer-wise learning rates.

        Tiers (modules -> LR):
        - classifier -> lr
        - top 4 encoder layers and pooler -> lr / 2
        - remaining encoder layers and embeddings -> lr / 10

        Within each tier, biases and LayerNorm weights get no weight decay.
        Every model parameter lands in exactly one group. Empty groups (e.g.
        the lower encoder layers of a model with <= 4 layers) are dropped.
        """
        no_decay = ("bias", "LayerNorm.weight")
        bert = self.model.bert
        top_layers = bert.encoder.layer[-4:]
        bottom_layers = bert.encoder.layer[:-4]

        tiers = [
            ([self.model.classifier], lr),
            ([top_layers, getattr(bert, 'pooler', None)], lr / 2),
            ([bottom_layers, bert.embeddings], lr / 10),
        ]

        groups = []
        for modules, group_lr in tiers:
            named = [
                (name, param)
                for module in modules if module is not None
                for name, param in module.named_parameters()
            ]
            decay = [p for n, p in named if not any(nd in n for nd in no_decay)]
            no_decay_params = [p for n, p in named if any(nd in n for nd in no_decay)]
            for params, weight_decay in ((decay, 0.01), (no_decay_params, 0.0)):
                if params:
                    groups.append({"params": params, "lr": group_lr, "weight_decay": weight_decay})
        return groups

    def _build_optimizer(self, strategy_name, lr):
        """Return an AdamW optimizer configured for *strategy_name*.

        "differential_lr" uses layer-wise LRs. Anything other than "standard"
        prints a notice and falls back to a single AdamW group.
        """
        if strategy_name == "differential_lr":
            return AdamW(self._differential_lr_groups(lr))
        if strategy_name != "standard":
            print(f"no implemention for {strategy_name}, using standard strategy.")
        return AdamW(self.model.parameters(), lr=lr, weight_decay=0.01)

    def set_finetune_param(self, strategy_name, epochs=4, lr=2e-5, warmup_steps=0, accumulation_steps=1):
        """Build the optimizer and linear warmup/decay scheduler for a strategy.

        Args:
            strategy_name: "standard" or "differential_lr". Any other name
                falls back to "standard".
            epochs: number of epochs the schedule must span.
            lr: base (highest) learning rate.
            warmup_steps: number of scheduler steps of linear warmup.
            accumulation_steps: gradient-accumulation window that ``train`` will
                use. The schedule is counted in optimizer steps, so this must
                match the value passed to ``train``, or the LR will not decay
                to zero.

        Returns:
            (optimizer, scheduler)

        Raises:
            ValueError: if ``load_data`` has not been called.
        """
        if self.train_loader is None:
            raise ValueError("datasets not ready!")

        optimizer = self._build_optimizer(strategy_name, lr)
        total_steps = self._optimizer_steps_per_epoch(accumulation_steps) * epochs

        # scheduler with warmup, then linear decay to 0 at total_steps
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        return optimizer, scheduler

    def finetune_by_strategy(self, strategy_name, epochs=4, learning_rate=2e-5, warmup_steps=0, logging_steps=50,
                             accumulation_steps=2):
        """Set up, train and time one fine-tuning strategy.

        Returns:
            (training_stats, training_time_seconds, memory_used_gb). On CUDA,
            memory_used_gb is the peak allocated during this run minus the
            memory allocated at its start. On CPU it is the RSS difference.
        """
        print(f"start {strategy_name} finetuning")

        if torch.cuda.is_available():
            # max_memory_allocated() is a running peak. Reset it so this run is
            # not masked by a peak reached during an earlier run.
            torch.cuda.reset_peak_memory_stats()

        start_time = time.time()
        initial_memory = self._get_memory_usage()

        optimizer, scheduler = self.set_finetune_param(
            strategy_name, epochs, learning_rate, warmup_steps, accumulation_steps)
        training_stats = self.train(strategy_name, optimizer, scheduler, epochs, accumulation_steps,
                                    learning_rate, warmup_steps, logging_steps)

        end_time = time.time()
        final_memory = self._get_memory_usage()

        training_time = end_time - start_time
        memory_used = final_memory - initial_memory

        print(f"{strategy_name} training completed in {training_time:.2f} seconds")
        print(f"Memory used: {memory_used:.2f} GB")

        return training_stats, training_time, memory_used

    def train(self, strategy_name, optimizer, scheduler, epochs=4, accumulation_steps=8, learning_rate=2e-5,
              warmup_steps=0, logging_steps=50, patience=2):
        """Train with gradient accumulation and early stopping on validation accuracy.

        After each epoch the model is evaluated on the validation set. The
        best model is saved to ``best_bert_model_<strategy_name>``. Training
        stops after *patience* epochs without improvement. The best model is
        reloaded before returning.

        Args:
            strategy_name: label used for progress bars, saved paths and stats.
            optimizer, scheduler: as returned by ``set_finetune_param``.
            epochs: maximum number of epochs.
            accumulation_steps: number of batches whose gradients are summed
                before each optimizer step.
            learning_rate, warmup_steps: accepted for interface compatibility
                only. They are already baked into *optimizer* / *scheduler*.
            logging_steps: refresh the progress-bar postfix every N batches.
            patience: epochs without validation improvement before stopping.

        Returns:
            list of per-epoch dicts with keys 'epoch', 'train_loss',
            'train_accuracy', 'val_loss', 'val_accuracy', 'strategy_name'
            and 'memory_usage'.
        """
        if self.train_loader is None or self.val_loader is None:
            raise ValueError("datasets not ready!")
        if accumulation_steps < 1:
            raise ValueError("accumulation_steps must be >= 1")

        self.calculate_model_parameters()

        training_stats = []
        # -inf guarantees the first epoch is always saved as the best model
        best_val_accuracy = float('-inf')
        patience_counter = 0
        best_model_path = f'best_bert_model_{strategy_name}'

        for epoch in range(epochs):
            print(f"Epoch {epoch+1}/{epochs}")
            print("=" * 80)

            self.model.train()
            total_train_loss = 0.0
            train_correct = 0
            train_total = 0

            optimizer.zero_grad()

            batch_progress = tqdm(self.train_loader, desc=f'training {strategy_name}')
            for step, batch in enumerate(batch_progress):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)

                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                batch_loss = outputs.loss
                logits = outputs.logits

                # Scale only the loss fed to backward(), so the accumulated
                # gradient averages over the window. Track the unscaled loss,
                # so the reported train loss is comparable with val loss.
                (batch_loss / accumulation_steps).backward()
                total_train_loss += batch_loss.item()

                predictions = torch.argmax(logits, dim=1)
                train_correct += (predictions == labels).sum().item()
                train_total += labels.size(0)

                if (step + 1) % accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

                if step % logging_steps == 0:
                    batch_progress.set_postfix(
                        {
                            'loss': f"{batch_loss.item():.4f}",
                            'acc': f"{train_correct/train_total:.4f}",
                            'mem': f"{self._get_memory_usage():.2f}GB"
                        }
                    )

            # flush gradients of a trailing, incomplete accumulation window
            if len(self.train_loader) % accumulation_steps != 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            avg_train_loss = total_train_loss / len(self.train_loader)
            train_accuracy = train_correct / train_total
            val_accuracy, val_loss = self.evaluate(self.val_loader)
            print(f"training loss: {avg_train_loss:.4f}, training accuracy:{train_accuracy:.4f}")
            print(f"validation loss: {val_loss:.4f}, validation accuracy:{val_accuracy:.4f}")

            # Record before the early-stopping check so the last epoch is kept.
            training_stats.append(
                {
                    'epoch': epoch + 1,
                    'train_loss': avg_train_loss,
                    'train_accuracy': train_accuracy,
                    'val_loss': val_loss,
                    'val_accuracy': val_accuracy,
                    'strategy_name': strategy_name,
                    'memory_usage': self._get_memory_usage()
                }
            )

            if val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                self.save_model(best_model_path)
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"early stop, we've got best model")
                break

        # reload the best-performing checkpoint for further processing
        self.load_model(best_model_path)
        return training_stats

    def evaluate(self, dataloader):
        """Return (accuracy, mean per-batch loss) of the model on *dataloader*."""
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="evaluating"):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)

                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

                loss = outputs.loss
                logits = outputs.logits

                total_loss += loss.item()
                predictions = torch.argmax(logits, dim=1)
                correct += (predictions == labels).sum().item()
                total += labels.size(0)

        accuracy = correct / total
        avg_loss = total_loss / len(dataloader)
        return accuracy, avg_loss

    def predict(self, texts):
        """Classify raw *texts* one at a time.

        Returns:
            (predictions, probabilities). predictions is a list of int class
            ids. Each probabilities entry is the softmax output for one text,
            as a numpy array of shape (1, num_labels).
        """
        self.model.eval()
        predictions = []
        probabilities = []

        with torch.no_grad():
            for text in texts:
                encoding = self.tokenizer(
                    text,
                    truncation=True,
                    padding='max_length',
                    max_length=self.max_length,
                    return_tensors='pt'
                )

                input_ids = encoding['input_ids'].to(self.device)
                attention_mask = encoding['attention_mask'].to(self.device)

                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                probs = torch.softmax(logits, dim=1)
                pred = torch.argmax(logits, dim=1)

                predictions.append(pred.cpu().item())
                probabilities.append(probs.cpu().numpy())

        return predictions, probabilities

    def calculate_model_parameters(self):
        """Count total and trainable parameters, cache them on self, print and return them."""
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

        self.total_params = total_params
        self.trainable_params = trainable_params

        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")
        return total_params, trainable_params

    def overall_evaluation(self):
        """Evaluate on the test set and return metrics, predictions and cost figures.

        Returns:
            dict with keys 'accuracy', 'f1_score' (weighted),
            'classification_report', 'confusion_matrix', 'predictions',
            'probabilities', 'inference_time', 'model_size_mb' (assuming 4
            bytes per parameter), 'total_parameters' and 'trainable_parameters'.

        Raises:
            ValueError: if the test loader has not been built.
        """
        if self.test_loader is None:
            raise ValueError("test datasets not ready")

        # parameter counts are normally cached by train(). Compute them here if
        # evaluation is run on a model that was never trained in this session.
        if not hasattr(self, 'total_params'):
            self.calculate_model_parameters()

        self.model.eval()
        all_predictions = []
        all_labels = []
        all_probabilities = []

        # Measure inference time
        start_time = time.time()

        with torch.no_grad():
            for batch in tqdm(self.test_loader, desc="testing"):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)

                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                probabilities = torch.softmax(logits, dim=1)
                predictions = torch.argmax(logits, dim=1)

                all_predictions.extend(predictions.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probabilities.extend(probabilities.cpu().numpy())

        inference_time = time.time() - start_time

        accuracy = accuracy_score(all_labels, all_predictions)
        f1 = f1_score(all_labels, all_predictions, average='weighted')
        class_report = classification_report(all_labels, all_predictions, target_names=['negative', 'positive'])
        conf_matrix = confusion_matrix(all_labels, all_predictions)

        results = {
            'accuracy': accuracy,
            'f1_score': f1,
            'classification_report': class_report,
            'confusion_matrix': conf_matrix,
            'predictions': all_predictions,
            'probabilities': all_probabilities,
            'inference_time': inference_time,
            'model_size_mb': self.total_params * 4 / (1024**2),  # 4 bytes per parameter
            'total_parameters': self.total_params,
            'trainable_parameters': self.trainable_params
        }

        print(f"test accuracy: {accuracy:.4f}")
        print(f"F1 score: {f1:.4f}")
        print(f"Inference time: {inference_time:.2f} seconds")
        print(f"Model size: {results['model_size_mb']:.2f} MB")
        print("\n class report:")
        print(class_report)
        print("\n confusion matrix")
        print(conf_matrix)

        return results

    def save_model(self, path):
        """Save model and tokenizer to directory *path*, creating it if needed."""
        os.makedirs(path, exist_ok=True)

        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)
        print(f"save model to: {path}")

    def load_model(self, path):
        """Replace model and tokenizer with those saved at *path*.

        If *path* does not exist, print a message and keep the current model.
        """
        if not os.path.exists(path):
            print(f"model file not exist!")
            return

        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model.to(self.device)
        print(f"loaded model from {path}")


"""
Visualization Part for bert
"""

def plt_training_history_data(all_training_stats):
    """Plot training history for every strategy in a 2x2 figure and save it.

    Top row: per-epoch train (solid) and validation (dashed) loss and
    accuracy curves. Bottom row: the last recorded epoch's validation accuracy
    and loss per strategy, as bars. Each strategy uses the same colour in all
    four panels. The figure is saved to ``bert_strategy_comparison.png`` and
    shown.

    Args:
        all_training_stats: one entry per strategy. Each entry is a list of
            per-epoch dicts with keys 'epoch', 'train_loss', 'val_loss',
            'train_accuracy', 'val_accuracy' and 'strategy_name'. Empty
            entries are skipped.
    """
    _, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    colors = {'standard': 'blue', 'differential_lr': 'red'}
    default_color = 'green'

    # last-epoch validation metrics per strategy, collected for the bar charts
    strategies = []
    final_val_accuracies = []
    final_val_losses = []

    for strategy_stats in all_training_stats:
        if not strategy_stats:
            continue

        strategy_name = strategy_stats[0]['strategy_name']
        df_stats = pd.DataFrame(strategy_stats)
        color = colors.get(strategy_name, default_color)

        # Training / validation loss
        ax1.plot(df_stats['epoch'], df_stats['train_loss'],
                 color=color, linestyle='-', label=f'{strategy_name} train')
        ax1.plot(df_stats['epoch'], df_stats['val_loss'],
                 color=color, linestyle='--', label=f'{strategy_name} val')

        # Training / validation accuracy
        ax2.plot(df_stats['epoch'], df_stats['train_accuracy'],
                 color=color, linestyle='-', label=f'{strategy_name} train')
        ax2.plot(df_stats['epoch'], df_stats['val_accuracy'],
                 color=color, linestyle='--', label=f'{strategy_name} val')

        final_epoch = strategy_stats[-1]
        strategies.append(strategy_name)
        final_val_accuracies.append(final_epoch['val_accuracy'])
        final_val_losses.append(final_epoch['val_loss'])

    ax1.set_title('Training and Validation Loss Comparison')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)

    ax2.set_title('Training and Validation Accuracy Comparison')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.legend()
    ax2.grid(True)

    if strategies:
        x = range(len(strategies))
        # Use the same per-strategy colours as the curves. A fixed ['blue', 'red']
        # list would mislabel strategies whose order or names differ.
        bar_colors = [colors.get(name, default_color) for name in strategies]

        ax3.bar(x, final_val_accuracies, color=bar_colors, alpha=0.7)
        ax3.set_title('Final Validation Accuracy by Strategy')
        ax3.set_xlabel('Strategy')
        ax3.set_ylabel('Accuracy')
        ax3.set_xticks(x)
        ax3.set_xticklabels(strategies)

        ax4.bar(x, final_val_losses, color=bar_colors, alpha=0.7)
        ax4.set_title('Final Validation Loss by Strategy')
        ax4.set_xlabel('Strategy')
        ax4.set_ylabel('Loss')
        ax4.set_xticks(x)
        ax4.set_xticklabels(strategies)

    plt.tight_layout()
    plt.savefig('bert_strategy_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

def plot_confusion_matrix(conf_matrix, strategy_name):
    """Draw *conf_matrix* as an annotated heatmap, save it as a PNG and show it.

    Rows are labelled "Ground Truth" and columns "Prediction". Both axes use
    the class names ``negative`` and ``positive``. The image is written to
    ``bert_<strategy_name>_confusion_matrix.png``.
    """
    labels = ['negative', 'positive']
    out_file = f'bert_{strategy_name}_confusion_matrix.png'

    plt.figure(figsize=(6, 5))
    sns.heatmap(
        conf_matrix,
        annot=True,
        fmt='d',  # integer counts
        cmap='Blues',
        xticklabels=labels,
        yticklabels=labels,
    )
    plt.title(f'BERT {strategy_name} Confusion Matrix')
    plt.ylabel('Ground Truth')
    plt.xlabel('Prediction')
    plt.savefig(out_file, dpi=300, bbox_inches='tight')
    plt.show()

def plot_performance_comparison(strategy_results):
    """Draw grouped accuracy / F1 bars per strategy, label each bar, save and show.

    Args:
        strategy_results: mapping of strategy name -> result dict containing
            at least 'accuracy' and 'f1_score'.

    The figure is saved to ``bert_performance_comparison.png``.
    """
    names = list(strategy_results.keys())
    acc_values = [res['accuracy'] for res in strategy_results.values()]
    f1_values = [res['f1_score'] for res in strategy_results.values()]

    positions = np.arange(len(names))
    bar_width = 0.35

    _, ax = plt.subplots(figsize=(10, 6))
    acc_bars = ax.bar(positions - bar_width / 2, acc_values, bar_width, label='Accuracy', alpha=0.7)
    f1_bars = ax.bar(positions + bar_width / 2, f1_values, bar_width, label='F1 Score', alpha=0.7)

    ax.set_xlabel('Finetuning Strategy')
    ax.set_ylabel('Score')
    ax.set_title('Performance Comparison of Different Finetuning Strategies')
    ax.set_xticks(positions)
    ax.set_xticklabels(names)
    ax.legend()

    # Write each bar's value 3 points above its top: all accuracy bars first,
    # then all F1 bars.
    for rect in [*acc_bars, *f1_bars]:
        top = rect.get_height()
        ax.annotate(
            f'{top:.4f}',
            xy=(rect.get_x() + rect.get_width() / 2, top),
            xytext=(0, 3),
            textcoords="offset points",
            ha='center',
            va='bottom',
        )

    plt.tight_layout()
    plt.savefig('bert_performance_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()


"""
Compare performance/cost/parameters with previous LSTM Model
"""
def plot_overall_comparison(bert_results, lstm_results):
    """Plot a 2x2 LSTM-vs-BERT comparison and save it to ``bert_lstm_comparison.png``.

    The four panels are:
    - accuracy and F1;
    - training and inference time;
    - model size (MB);
    - total parameter count.

    ``bert_results`` must contain 'accuracy' and 'f1_score'. Every other
    metric, and every LSTM metric, defaults to 0 when missing.
    """
    bert_accuracy = bert_results['accuracy']
    bert_f1 = bert_results['f1_score']

    models = ['LSTM', 'BERT']

    def _pair(key):
        # [LSTM value, BERT value], matching the order of `models`
        return [lstm_results.get(key, 0), bert_results.get(key, 0)]

    accuracies = [lstm_results.get('accuracy', 0), bert_accuracy]
    f1_scores = [lstm_results.get('f1_score', 0), bert_f1]
    training_times = _pair('training_time')
    inference_times = _pair('inference_time')
    model_sizes = _pair('model_size_mb')
    parameters = _pair('total_parameters')

    _, ((ax_perf, ax_time), (ax_size, ax_params)) = plt.subplots(2, 2, figsize=(15, 12))

    x = np.arange(len(models))
    width = 0.35

    # side-by-side bar pairs: (axis, left values, right values, left label,
    # right label, y label, title)
    grouped_panels = [
        (ax_perf, accuracies, f1_scores, 'Accuracy', 'F1 Score',
         'Score', 'Performance Comparison: LSTM vs BERT'),
        (ax_time, training_times, inference_times, 'Training Time (s)', 'Inference Time (s)',
         'Time (seconds)', 'Computational Time Comparison'),
    ]
    for ax, left, right, left_label, right_label, y_label, title in grouped_panels:
        ax.bar(x - width / 2, left, width, label=left_label, alpha=0.7)
        ax.bar(x + width / 2, right, width, label=right_label, alpha=0.7)
        ax.set_xlabel('Model')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.set_xticks(x)
        ax.set_xticklabels(models)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # one bar per model: (axis, values, colours, y label, title)
    single_panels = [
        (ax_size, model_sizes, ['orange', 'green'], 'Model Size (MB)', 'Model Size Comparison'),
        (ax_params, parameters, ['purple', 'brown'], 'Number of Parameters', 'Parameter Count Comparison'),
    ]
    for ax, values, bar_colors, y_label, title in single_panels:
        ax.bar(models, values, color=bar_colors, alpha=0.7)
        ax.set_xlabel('Model')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('bert_lstm_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()


def load_lstm_results(path):
    """Load the LSTM baseline metrics from a JSON file.

    Falls back to hard-coded placeholder metrics, and says so, when the file
    is missing or unreadable, is not valid JSON, or does not hold a JSON
    object.

    Args:
        path: path to the JSON results file.

    Returns:
        dict of LSTM metrics. The placeholder has keys 'accuracy',
        'f1_score', 'training_time', 'inference_time', 'model_size_mb',
        'total_parameters' and 'memory_usage'.
    """
    lstm_results = None

    try:
        with open(path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError / UnicodeDecodeError
        print(f"load lstm results failed: {str(e)}")
    else:
        if isinstance(loaded, dict):
            lstm_results = loaded
        else:
            # downstream code calls .get() on the result, so only a dict is usable
            print(f"load lstm results failed: expected a JSON object, got {type(loaded).__name__}")

    if lstm_results is None:
        print("using placeholder LSTM metrics for demonstration")
        lstm_results = {
            'accuracy': 0.8718,
            'f1_score': 0.8728,
            'training_time': 1200,  # 20 minutes
            'inference_time': 30,
            'model_size_mb': 50,
            'total_parameters': 13000000,
            'memory_usage': 2.5
        }

    return lstm_results


"""
Program Main Process
"""
def perform_training():
    """Fine-tune and evaluate the BERT model once per fine-tuning strategy.

    For each strategy the function:
    - trains the model, restarting from the pretrained checkpoint for every
      strategy after the first;
    - evaluates it on the test set and plots its confusion matrix;
    - saves the model to a timestamped directory.

    It then plots the cross-strategy comparisons and prints a summary.

    Returns:
        tuple: (trainer, strategy_results, all_training_stats, training_times,
        memory_usage). strategy_results, training_times and memory_usage are
        dicts keyed by strategy name.
    """
    work_path = os.getcwd()
    data_file = 'all_data.pt'
    data_file_path = os.path.join(work_path, "processed_data_binary", data_file)

    model_name = 'prajjwal1/bert-mini'
    num_labels = 2

    trainer = BertSentimentTrainer(
        model_name=model_name,
        num_labels=num_labels,
        max_length=128
    )

    trainer.load_data(data_file_path, batch_size=32)  # maximum batch size

    # Not implemented yet: "gradual_unfreeze", "bitfit" (bias-term fine-tuning).
    finetune_strategies = ["standard", "differential_lr"]

    all_training_stats = []
    strategy_results = {}
    training_times = {}
    memory_usage = {}

    for index, strategy in enumerate(finetune_strategies):
        print(f"\n{'='*60}")
        print(f"Training with strategy: {strategy}")
        print(f"{'='*60}")

        if index > 0:
            # After train(), trainer.model holds the previous strategy's best
            # fine-tuned weights. Re-seed and reload the pretrained checkpoint
            # so every strategy starts from the same point and the comparison
            # is fair.
            set_overall_seed(16)
            trainer.model = AutoModelForSequenceClassification.from_pretrained(
                model_name, num_labels=num_labels)
            trainer.model.to(trainer.device)

        training_stats, training_time, memory_used = trainer.finetune_by_strategy(
            strategy_name=strategy,
            epochs=2,  # 2 epochs give better convergence
            learning_rate=3e-5,
            warmup_steps=50,
            logging_steps=100
        )

        all_training_stats.append(training_stats)
        training_times[strategy] = training_time
        memory_usage[strategy] = memory_used

        # Evaluate on test set
        results = trainer.overall_evaluation()
        # Add computational metrics to results
        results['training_time'] = training_time
        results['memory_usage'] = memory_used
        strategy_results[strategy] = results

        # Plot confusion matrix for this strategy
        plot_confusion_matrix(results['confusion_matrix'], strategy_name=strategy)

        # Save model
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_save_path = f"sentiment_bert_{strategy}_model_{timestamp}"
        trainer.save_model(model_save_path)

    plt_training_history_data(all_training_stats)
    plot_performance_comparison(strategy_results)

    print("\n" + "="*80)
    print("STRATEGY COMPARISON SUMMARY")
    print("="*80)

    for strategy in finetune_strategies:
        results = strategy_results[strategy]
        print(f"\n{strategy.upper()} Strategy:")
        print(f"  Accuracy: {results['accuracy']:.4f}")
        print(f"  F1 Score: {results['f1_score']:.4f}")
        print(f"  Training Time: {training_times[strategy]:.2f}s")
        print(f"  GPU Memory Usage: {memory_usage[strategy]:.2f}GB")
        print(f"  Model Size: {results['model_size_mb']:.2f}MB")
        print(f"  Parameters: {results['total_parameters']:,}")

    return trainer, strategy_results, all_training_stats, training_times, memory_usage

def test_texts(trainer):
    """Classify two fixed sample reviews with *trainer* and print the results.

    For each sample, prints the predicted label and the highest class
    probability from the first row of that sample's probability array.
    """
    test_samples = [
        "quite a good movie, but the ending could be better!",
        "it's a fantastic film I've ever enjoyed!"
    ]

    predictions, probabilities = trainer.predict(test_samples)

    label_names = {0: 'negative', 1: 'positive'}
    banner = "=" * 60

    print("\n" + banner)
    print("SAMPLE TEXT PREDICTIONS")
    print(banner)

    for sample, label_id, prob_rows in zip(test_samples, predictions, probabilities):
        top_prob = max(prob_rows[0])
        print(f"testing text: {sample}")
        print(f"prediction: {label_names[label_id]}, {top_prob:.4f}")
        print("-" * 50)

def main():
    """Run the full pipeline.

    Steps:
    1. Train and evaluate all BERT strategies.
    2. Compare the most accurate strategy with the LSTM baseline.
    3. Print predictions for the sample texts.

    Any exception is reported and its traceback printed instead of
    propagating.
    """
    import traceback

    try:
        trainer, bert_results, _stats, _times, _memory = perform_training()

        lstm_results = load_lstm_results("./lstm_results.json")

        if lstm_results:
            banner = "=" * 80
            print("\n" + banner)
            print("COMPREHENSIVE COMPARISON WITH LSTM BASELINE")
            print(banner)
            # compare the most accurate BERT strategy against the LSTM
            best_strategy = max(bert_results, key=lambda name: bert_results[name]['accuracy'])
            plot_overall_comparison(bert_results[best_strategy], lstm_results)

        test_texts(trainer)
    except Exception as err:
        print(f"error happened: {str(err)}")
        traceback.print_exc()

# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()