<a href="https://colab.research.google.com/github/Kishan-prajapati-242/ATCTM/blob/main/notebooks/EC_demo_8_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, mean_squared_error
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
import json
import os
warnings.filterwarnings('ignore')

# Mount Google Drive so the dataset and the save directory under
# /content/drive are reachable. google.colab is imported here, so this cell
# presumably only runs inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Report the PyTorch version and whether CUDA is available.
print("✓ All libraries imported successfully!")
print(f"✓ PyTorch version: {torch.__version__}")
print(f"✓ Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")


class Config:
    """Central hyper-parameter and path settings for the multi-task event classifier.

    All settings are class attributes; a single instance (``config``) is created
    right after the class and read as a module-level global by the rest of the file.
    """

    # Model
    model_name = 'microsoft/deberta-v3-base'  # passed to AutoModel / AutoTokenizer.from_pretrained
    max_length = 256                          # tokenizer truncation/padding length

    # Training
    batch_size = 16        # DataLoader batch size
    learning_rate = 2e-5   # AdamW learning rate
    num_epochs = 20        # upper bound on epochs; also used to size the LR schedule
    warmup_steps = 500     # num_warmup_steps of the linear schedule
    weight_decay = 0.01    # AdamW weight decay

    # Architecture
    hidden_dim = 768           # output width of the shared layer feeding every head
    num_attention_heads = 12   # heads of the extra MultiheadAttention layer
    dropout_rate = 0.3         # dropout used in attention, shared layer and heads
    label_smoothing = 0.1      # passed to MultiTaskLoss (CrossEntropyLoss label_smoothing)

    # Multi-task loss weights (MultiTaskLoss multiplies each task loss by its weight;
    # train_epoch/evaluate also use these keys as the list of tasks)
    task_weights = {
        'event_type': 1.0,      # Primary task
        'event_group': 0.8,
        'sentiment_valence': 0.6,
        'emotion': 0.7,
        'sarcasm': 0.5,
        'tense': 0.5,
        'certainty': 0.5
    }

    # Training strategy
    gradient_accumulation_steps = 2  # NOTE(review): not referenced anywhere in the visible code
    max_grad_norm = 1.0              # gradient-clipping norm used by train_epoch
    early_stopping_patience = 5      # epochs without val-loss improvement before stopping

    # Multi-event threshold
    # NOTE(review): only written to inference_config.json in the visible code;
    # predict_event_multi has its own certainty_threshold default.
    multi_event_threshold = 0.7

    # Device (evaluated once, when the class body runs)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Save paths
    save_dir = '/content/drive/MyDrive/EC-model-v2'

config = Config()
print(f"Configuration loaded. Device: {config.device}")

# Create save directory (no error if it already exists)
os.makedirs(config.save_dir, exist_ok=True)
print(f"Save directory: {config.save_dir}")


# Load the labelled dataset from Google Drive into the module-level `df`.
df = pd.read_csv('/content/drive/MyDrive/ATCTM/EVENT_CLASSIFICATION/EC-demo.csv')
print(f"Dataset loaded: {len(df)} samples")

# Data analysis: per-column missing-value counts
print("\nMissing values:")
print(df.isnull().sum())
print("\n" + "="*50)

# Per-column dtypes
print("Data types:")
print(df.dtypes)
print("\n" + "="*50)

# Fix boolean columns if needed: when a column was read with bool dtype,
# store it as the strings 'TRUE'/'FALSE' instead.
if df['SARCASM'].dtype == 'bool':
    df['SARCASM'] = df['SARCASM'].map({True: 'TRUE', False: 'FALSE'})
    print("Converted SARCASM from bool to string")

# GENERATED is optional, hence the column-existence check.
if 'GENERATED' in df.columns and df['GENERATED'].dtype == 'bool':
    df['GENERATED'] = df['GENERATED'].map({True: 'TRUE', False: 'FALSE'})
    print("Converted GENERATED from bool to string")

# Show distribution of the 10 most frequent event types
print("\nEvent type distribution:")
print(df['EVENT_TYPE'].value_counts().head(10))

# import pandas as pd

# # Load only first 200 rows
# df = pd.read_csv('/content/drive/MyDrive/ATCTM/EVENT_CLASSIFICATION/EC-demo.csv').head(400)
# print(f"Dataset loaded: {len(df)} samples")

# # Data analysis
# print("\nMissing values:")
# print(df.isnull().sum())
# print("\n" + "="*50)

# print("Data types:")
# print(df.dtypes)
# print("\n" + "="*50)

# # Fix boolean columns if needed
# if df['SARCASM'].dtype == 'bool':
#     df['SARCASM'] = df['SARCASM'].map({True: 'TRUE', False: 'FALSE'})
#     print("Converted SARCASM from bool to string")

# if 'GENERATED' in df.columns and df['GENERATED'].dtype == 'bool':
#     df['GENERATED'] = df['GENERATED'].map({True: 'TRUE', False: 'FALSE'})
#     print("Converted GENERATED from bool to string")

# # Show distribution
# print("\nEvent type distribution:")
# print(df['EVENT_TYPE'].value_counts().head(10))


class MultiTaskLabelEncoders:
    """Per-task label encoding for the multi-task model.

    Holds one sklearn ``LabelEncoder`` per categorical task (``event_type``,
    ``event_group``, ``emotion``, ``tense``) in ``self.encoders`` and the number
    of classes per classification task in ``self.num_classes``. ``sarcasm`` is
    encoded as a 0/1 flag and the regression targets (``sentiment_valence``,
    ``certainty``) are passed through as floats, so none of those has an encoder.
    """

    # Label substituted for a missing value, per task. The same fallback is used
    # when fitting and when transforming, so the fallback label is always a
    # known class whenever the fitted data contained missing values.
    _missing_defaults = {'tense': 'present', 'emotion': 'neutral'}

    def __init__(self):
        self.encoders = {}
        self.num_classes = {}

    def fit(self, df):
        """Fit all encoders on ``df`` and print a short summary.

        Args:
            df: DataFrame with EVENT_TYPE, EVENT_GROUP, EMOTION and TENSE columns.
        """
        # EVENT_TYPE encoder
        self.encoders['event_type'] = LabelEncoder()
        self.encoders['event_type'].fit(df['EVENT_TYPE'])
        self.num_classes['event_type'] = len(self.encoders['event_type'].classes_)

        # EVENT_GROUP encoder
        self.encoders['event_group'] = LabelEncoder()
        self.encoders['event_group'].fit(df['EVENT_GROUP'])
        self.num_classes['event_group'] = len(self.encoders['event_group'].classes_)

        # EMOTION encoder (missing values become the fallback label)
        self.encoders['emotion'] = LabelEncoder()
        self.encoders['emotion'].fit(df['EMOTION'].fillna(self._missing_defaults['emotion']))
        self.num_classes['emotion'] = len(self.encoders['emotion'].classes_)

        # TENSE encoder. Missing values are filled with the same fallback that
        # transform() substitutes; dropping them instead could leave the
        # fallback outside the fitted classes and make transform() raise.
        self.encoders['tense'] = LabelEncoder()
        self.encoders['tense'].fit(df['TENSE'].fillna(self._missing_defaults['tense']))
        self.num_classes['tense'] = len(self.encoders['tense'].classes_)

        # Binary classification for SARCASM
        self.num_classes['sarcasm'] = 2

        print(f"Encoders fitted:")
        for key, num in self.num_classes.items():
            if key in self.encoders:
                print(f"  {key}: {num} classes - {list(self.encoders[key].classes_)[:5]}...")
            else:
                print(f"  {key}: {num} classes")

    def transform(self, value, task, default=None):
        """Encode a single raw value for ``task``.

        Args:
            value: Raw cell value (may be missing).
            task: Task name, e.g. ``'event_type'`` or ``'sarcasm'``.
            default: Fallback label for a missing value of a task without a
                built-in fallback; when falsy, the encoder's first class is used.

        Returns:
            0/1 for sarcasm, a float for the regression tasks (0.5 when the
            value is missing), otherwise the encoder's integer class index.
        """
        if task == 'sarcasm':
            # `== True` (not `is True`) so that non-builtin boolean values
            # comparing equal to True are accepted as well.
            return 1 if str(value).upper() == 'TRUE' or value == True else 0
        elif task in ['sentiment_valence', 'certainty']:
            return float(value) if pd.notna(value) else 0.5
        else:
            if pd.isna(value):
                if task in self._missing_defaults:
                    value = self._missing_defaults[task]
                else:
                    value = default if default else list(self.encoders[task].classes_)[0]
            return self.encoders[task].transform([value])[0]

    def inverse_transform(self, labels, task):
        """Map encoded ``labels`` of ``task`` back to their raw representation.

        Sarcasm indices become 'FALSE'/'TRUE' strings, regression values are
        returned unchanged, and everything else goes through the task's encoder.
        """
        if task == 'sarcasm':
            return ['FALSE' if l == 0 else 'TRUE' for l in labels]
        elif task in ['sentiment_valence', 'certainty']:
            return labels
        else:
            return self.encoders[task].inverse_transform(labels)

# Initialize encoders and fit them on the full dataset `df`
# (this happens before the train/validation split further down).
label_encoders = MultiTaskLabelEncoders()
label_encoders.fit(df)


class MultiTaskEventDataset(Dataset):
    """Torch dataset yielding tokenized text plus the encoded labels of every task.

    Args:
        texts: Sequence of raw texts (list, NumPy array or pandas Series).
        labels_df: DataFrame with one row per text, in the same order, holding
            the EVENT_TYPE, EVENT_GROUP, EMOTION, TENSE, SARCASM,
            SENTIMENT_VALENCE and CERTAINTY columns.
        tokenizer: Callable tokenizer (HuggingFace-style call signature).
        label_encoders: Object whose ``transform(value, task)`` encodes one label.
        max_length: Sequence length texts are truncated/padded to.
    """

    def __init__(self, texts, labels_df, tokenizer, label_encoders, max_length=256):
        # Materialise the texts positionally. labels_df is re-indexed to
        # 0..n-1 below, so texts must be addressable the same way; a pandas
        # Series carrying the original DataFrame index would otherwise raise
        # or return the wrong row on `texts[idx]`.
        self.texts = list(texts)
        self.labels_df = labels_df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.label_encoders = label_encoders
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and a ``labels`` dict for item ``idx``."""
        text = str(self.texts[idx])
        row = self.labels_df.iloc[idx]

        # Tokenize to a fixed length so samples can be stacked into a batch
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Get all labels
        labels = {
            'event_type': self.label_encoders.transform(row['EVENT_TYPE'], 'event_type'),
            'event_group': self.label_encoders.transform(row['EVENT_GROUP'], 'event_group'),
            'emotion': self.label_encoders.transform(row['EMOTION'], 'emotion'),
            'tense': self.label_encoders.transform(row['TENSE'], 'tense'),
            'sarcasm': self.label_encoders.transform(row['SARCASM'], 'sarcasm'),
            'sentiment_valence': self.label_encoders.transform(row['SENTIMENT_VALENCE'], 'sentiment_valence'),
            'certainty': self.label_encoders.transform(row['CERTAINTY'], 'certainty')
        }

        # return_tensors='pt' adds a leading batch dimension; flatten it away.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': labels
        }

print("✓ Dataset class defined")


class MultiTaskEventClassifier(nn.Module):
    """Pre-trained transformer encoder with one prediction head per task.

    The encoder output passes through an extra self-attention layer, is
    mean-pooled over the non-padding tokens and projected by a shared layer.
    Five classification heads and two sigmoid regression heads read the
    shared features.
    """

    def __init__(self, config, label_encoders):
        super().__init__()
        self.config = config
        self.label_encoders = label_encoders

        # Pre-trained backbone; its embedding parameters are kept frozen.
        self.transformer = AutoModel.from_pretrained(config.model_name)
        for embedding_param in self.transformer.embeddings.parameters():
            embedding_param.requires_grad_(False)

        encoder_width = self.transformer.config.hidden_size
        head_width = config.hidden_dim
        class_counts = label_encoders.num_classes

        # Self-attention over the encoder's token states.
        self.text_attention = nn.MultiheadAttention(
            embed_dim=encoder_width,
            num_heads=config.num_attention_heads,
            dropout=config.dropout_rate,
            batch_first=True,
        )

        # Projection shared by every task head.
        self.shared_layer = nn.Sequential(
            nn.Linear(encoder_width, head_width),
            nn.LayerNorm(head_width),
            nn.GELU(),
            nn.Dropout(config.dropout_rate),
        )

        # Classification heads (sarcasm is always binary).
        self.event_type_classifier = self._make_classifier(head_width, class_counts['event_type'])
        self.event_group_classifier = self._make_classifier(head_width, class_counts['event_group'])
        self.emotion_classifier = self._make_classifier(head_width, class_counts['emotion'])
        self.tense_classifier = self._make_classifier(head_width, class_counts['tense'])
        self.sarcasm_classifier = self._make_classifier(head_width, 2)

        # Regression heads with sigmoid outputs.
        self.sentiment_regressor = self._make_regressor(head_width)
        self.certainty_regressor = self._make_regressor(head_width)

    def _make_classifier(self, input_dim, num_classes):
        """Build a two-layer head: input_dim -> input_dim // 2 -> num_classes logits."""
        half_width = input_dim // 2
        layers = [
            nn.Linear(input_dim, half_width),
            nn.LayerNorm(half_width),
            nn.GELU(),
            nn.Dropout(self.config.dropout_rate),
            nn.Linear(half_width, num_classes),
        ]
        return nn.Sequential(*layers)

    def _make_regressor(self, input_dim):
        """Build a head like the classifier one, with a single sigmoid-squashed output."""
        return nn.Sequential(*self._make_classifier(input_dim, 1), nn.Sigmoid())

    def forward(self, input_ids, attention_mask):
        """Run the model and return a dict with one entry per task.

        Classification entries hold logits; ``sentiment_valence`` and
        ``certainty`` hold sigmoid outputs with the trailing dimension squeezed.
        """
        encoded = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        token_states = encoded.last_hidden_state

        # Padding positions (mask == 0) are excluded as attention keys.
        padding_positions = torch.logical_not(attention_mask.bool())
        attended, _ = self.text_attention(
            token_states, token_states, token_states,
            key_padding_mask=padding_positions,
        )

        # Mean over the real tokens only.
        masked_sum = (attended * attention_mask.unsqueeze(-1)).sum(dim=1)
        token_counts = attention_mask.sum(dim=1, keepdim=True)
        text_features = masked_sum / token_counts

        shared_features = self.shared_layer(text_features)

        return {
            'event_type': self.event_type_classifier(shared_features),
            'event_group': self.event_group_classifier(shared_features),
            'emotion': self.emotion_classifier(shared_features),
            'tense': self.tense_classifier(shared_features),
            'sarcasm': self.sarcasm_classifier(shared_features),
            'sentiment_valence': self.sentiment_regressor(shared_features).squeeze(-1),
            'certainty': self.certainty_regressor(shared_features).squeeze(-1),
        }

print("✓ Model architecture defined")

class MultiTaskLoss(nn.Module):
    """Weighted sum of per-task losses.

    Cross-entropy (with label smoothing) is used for the five classification
    tasks and mean-squared error for the two regression tasks. Each task loss
    is scaled by ``config.task_weights[task]`` (1.0 when the task has no entry).
    """

    def __init__(self, config, label_smoothing=0.1):
        super().__init__()
        self.config = config
        self.ce_loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        self.mse_loss = nn.MSELoss()

    def forward(self, outputs, labels):
        """Return ``(total_loss, losses)`` where ``losses`` maps task name to its unweighted loss."""
        classification_tasks = ('event_type', 'event_group', 'emotion', 'tense', 'sarcasm')
        regression_tasks = ('sentiment_valence', 'certainty')

        # Classification terms first, then regression terms (targets cast to float).
        losses = {
            name: self.ce_loss(outputs[name], labels[name])
            for name in classification_tasks
        }
        for name in regression_tasks:
            losses[name] = self.mse_loss(outputs[name], labels[name].float())

        # Accumulate the weighted total in task order.
        weights = self.config.task_weights
        total_loss = 0
        for name, value in losses.items():
            total_loss = total_loss + weights.get(name, 1.0) * value

        return total_loss, losses

print("✓ Loss function defined")

def train_epoch(model, dataloader, optimizer, scheduler, criterion, device):
    """Train ``model`` for one pass over ``dataloader``.

    Every batch performs a full optimizer and scheduler step, with gradients
    clipped to ``config.max_grad_norm`` (``config`` is the module-level
    configuration object).

    Returns:
        ``(avg_loss, avg_task_losses)``: the mean total loss per batch and a
        dict with the mean unweighted loss per task.
    """
    model.train()
    running_loss = 0
    per_task_totals = dict.fromkeys(config.task_weights, 0)

    for batch in tqdm(dataloader, desc='Training'):
        token_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)

        # Labels arrive either as plain lists or as already-collated tensors.
        labels = {
            name: torch.tensor(raw, device=device) if isinstance(raw, list) else raw.to(device)
            for name, raw in batch['labels'].items()
        }

        outputs = model(input_ids=token_ids, attention_mask=mask)
        loss, losses = criterion(outputs, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        running_loss += loss.item()
        for name, component in losses.items():
            per_task_totals[name] += component.item()

    avg_loss = running_loss / len(dataloader)
    avg_task_losses = {
        name: total / len(dataloader) for name, total in per_task_totals.items()
    }
    return avg_loss, avg_task_losses

def evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns:
        ``(avg_loss, accuracies, mse_scores, predictions, true_labels)``:
        the mean total loss per batch, accuracy per classification task,
        mean-squared error per regression task, and the collected
        predictions / ground-truth values per task.
    """
    classification_tasks = ('event_type', 'event_group', 'emotion', 'tense', 'sarcasm')
    regression_tasks = ('sentiment_valence', 'certainty')

    model.eval()
    running_loss = 0
    predictions = {task: [] for task in config.task_weights}
    true_labels = {task: [] for task in config.task_weights}

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Evaluating'):
            token_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)

            # Move labels to the device and record them as ground truth.
            labels = {}
            for name, raw in batch['labels'].items():
                target = torch.tensor(raw, device=device) if isinstance(raw, list) else raw.to(device)
                labels[name] = target
                true_labels[name].extend(target.cpu().numpy())

            outputs = model(input_ids=token_ids, attention_mask=mask)
            batch_loss, _ = criterion(outputs, labels)
            running_loss += batch_loss.item()

            # Arg-max class for classification heads, raw value for regression heads.
            for task in classification_tasks:
                predictions[task].extend(outputs[task].argmax(dim=1).cpu().numpy())
            for task in regression_tasks:
                predictions[task].extend(outputs[task].cpu().numpy())

    accuracies = {
        task: accuracy_score(true_labels[task], predictions[task])
        for task in classification_tasks
    }
    mse_scores = {
        task: mean_squared_error(true_labels[task], predictions[task])
        for task in regression_tasks
    }

    return running_loss / len(dataloader), accuracies, mse_scores, predictions, true_labels

print("✓ Training functions defined")

# Stratified 80/20 train/validation split on EVENT_TYPE with a fixed seed.
# NOTE: the same split, tokenizer, datasets and loaders are built again
# further down in this file with num_workers=0; the names assigned here are
# rebound there.
train_indices, val_indices = train_test_split(
    range(len(df)), test_size=0.2, random_state=42,
    stratify=df['EVENT_TYPE']
)

# Texts as NumPy arrays; labels as the corresponding full DataFrame rows.
train_texts = df.iloc[train_indices]['TEXT'].values
val_texts = df.iloc[val_indices]['TEXT'].values
train_labels = df.iloc[train_indices]
val_labels = df.iloc[val_indices]

print(f"Training samples: {len(train_texts)}")
print(f"Validation samples: {len(val_texts)}")

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

# Create datasets
train_dataset = MultiTaskEventDataset(train_texts, train_labels, tokenizer, label_encoders, config.max_length)
val_dataset = MultiTaskEventDataset(val_texts, val_labels, tokenizer, label_encoders, config.max_length)

# Create dataloaders (two worker processes each; only the training loader shuffles)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=2)

print(f"✓ DataLoaders created")
print(f"  - Train batches: {len(train_loader)}")
print(f"  - Val batches: {len(val_loader)}")


# Build the model and move it to the configured device.
model = MultiTaskEventClassifier(config, label_encoders)
model.to(config.device)

# Model info: total vs. trainable (requires_grad) parameter counts
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Optimizer and scheduler.
# The schedule length assumes one scheduler step per training batch for
# config.num_epochs epochs.
optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
total_steps = len(train_loader) * config.num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup_steps, num_training_steps=total_steps)
criterion = MultiTaskLoss(config, label_smoothing=config.label_smoothing)

print("✓ Model initialized")

# Cell 9 REPLACEMENT - Fix DataLoader Issues
# Repeats the earlier split / tokenizer / dataset cell unchanged and rebuilds
# the loaders without worker processes.
# Split data (stratified on EVENT_TYPE, same seed as before)
train_indices, val_indices = train_test_split(
    range(len(df)), test_size=0.2, random_state=42,
    stratify=df['EVENT_TYPE']
)

train_texts = df.iloc[train_indices]['TEXT'].values
val_texts = df.iloc[val_indices]['TEXT'].values
train_labels = df.iloc[train_indices]
val_labels = df.iloc[val_indices]

print(f"Training samples: {len(train_texts)}")
print(f"Validation samples: {len(val_texts)}")

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

# Create datasets
train_dataset = MultiTaskEventDataset(train_texts, train_labels, tokenizer, label_encoders, config.max_length)
val_dataset = MultiTaskEventDataset(val_texts, val_labels, tokenizer, label_encoders, config.max_length)

# Create dataloaders with num_workers=0 (batches are loaded in the main
# process) to avoid hanging
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=0,  # THIS IS THE KEY FIX
    pin_memory=False  # Also disable pin_memory
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=0,  # THIS IS THE KEY FIX
    pin_memory=False
)

print(f"✓ DataLoaders created (num_workers=0)")
print(f"  - Train batches: {len(train_loader)}")
print(f"  - Val batches: {len(val_loader)}")

# %%
# Cell 11 REPLACEMENT - Simple Working Training Loop
# Trains for up to config.num_epochs epochs, evaluates after every epoch,
# writes a checkpoint per epoch plus the best model so far (lowest validation
# loss), and stops early after config.early_stopping_patience epochs without
# improvement.
import gc
import sys

print("\nStarting training...")
train_losses, val_losses = [], []
best_val_loss = float('inf')
patience_counter = 0

for epoch in range(config.num_epochs):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch + 1}/{config.num_epochs}")
    print(f"{'='*60}")

    # TRAINING
    model.train()
    total_loss = 0
    num_batches = len(train_loader)

    for batch_idx, batch in enumerate(train_loader):
        # Print progress immediately (carriage return rewrites the same line)
        sys.stdout.write(f'\rTraining: {batch_idx+1}/{num_batches} [{(batch_idx+1)/num_batches*100:.1f}%]')
        sys.stdout.flush()

        # Get data
        input_ids = batch['input_ids'].to(config.device)
        attention_mask = batch['attention_mask'].to(config.device)

        labels = {}
        for key, value in batch['labels'].items():
            labels[key] = value.to(config.device)

        # Forward
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        loss, _ = criterion(outputs, labels)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        # Clip with the configured norm rather than a hard-coded constant,
        # matching train_epoch.
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()

        # Free memory
        del input_ids, attention_mask, labels, outputs, loss

    avg_train_loss = total_loss / num_batches
    train_losses.append(avg_train_loss)
    print(f'\nTrain Loss: {avg_train_loss:.4f}')

    # EVALUATION
    model.eval()
    total_val_loss = 0
    # Running count of correct predictions / seen samples per classification task
    val_accuracies = {'event_type': 0, 'event_group': 0, 'emotion': 0, 'tense': 0, 'sarcasm': 0}
    val_counts = {'event_type': 0, 'event_group': 0, 'emotion': 0, 'tense': 0, 'sarcasm': 0}

    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader):
            sys.stdout.write(f'\rEvaluating: {batch_idx+1}/{len(val_loader)} [{(batch_idx+1)/len(val_loader)*100:.1f}%]')
            sys.stdout.flush()

            input_ids = batch['input_ids'].to(config.device)
            attention_mask = batch['attention_mask'].to(config.device)

            labels = {}
            for key, value in batch['labels'].items():
                labels[key] = value.to(config.device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            loss, _ = criterion(outputs, labels)
            total_val_loss += loss.item()

            # Calculate accuracies
            for task in ['event_type', 'event_group', 'emotion', 'tense', 'sarcasm']:
                preds = torch.argmax(outputs[task], dim=1)
                correct = (preds == labels[task]).sum().item()
                val_accuracies[task] += correct
                val_counts[task] += labels[task].size(0)

            del input_ids, attention_mask, labels, outputs

    # Calculate final metrics
    avg_val_loss = total_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    # Turn the correct-prediction counts into accuracies
    for task in val_accuracies:
        val_accuracies[task] = val_accuracies[task] / val_counts[task]

    print(f'\nVal Loss: {avg_val_loss:.4f}')
    print('Val Accuracies:', {k: f'{v:.3f}' for k, v in val_accuracies.items()})

    # Save checkpoint (one file per epoch)
    checkpoint_path = os.path.join(config.save_dir, f'checkpoint_epoch_{epoch+1}.pth')
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'label_encoders': label_encoders,
        'config': config,
        'val_loss': avg_val_loss,
        'val_accuracies': val_accuracies
    }, checkpoint_path)
    print(f'✓ Saved checkpoint: {checkpoint_path}')

    # Save best model (lowest validation loss so far) and reset the patience
    # counter; otherwise count one more epoch without improvement.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save({
            'model_state_dict': model.state_dict(),
            'label_encoders': label_encoders,
            'config': config,
            'best_val_loss': best_val_loss
        }, os.path.join(config.save_dir, 'best_model.pth'))
        print('✓ New best model saved!')
        patience_counter = 0
    else:
        patience_counter += 1

    # Clear memory
    gc.collect()
    torch.cuda.empty_cache()

    if patience_counter >= config.early_stopping_patience:
        print(f'\nEarly stopping at epoch {epoch+1}')
        break

print(f'\n✓ Training complete! Best val loss: {best_val_loss:.4f}')

In [None]:
# %%
# Cell 12: Plot Training History
# Left panel: per-epoch train/validation loss. Right panel (only when the
# training loop has defined val_accuracies): accuracies of the last evaluated
# epoch. The figure is also written to the save directory.
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss', linewidth=2)
plt.plot(val_losses, label='Val Loss', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')
plt.grid(True, alpha=0.3)

# Plot accuracies if available
if 'val_accuracies' in locals():
    plt.subplot(1, 2, 2)
    tasks = list(val_accuracies.keys())
    values = list(val_accuracies.values())
    plt.bar(tasks, values)
    plt.xlabel('Task')
    plt.ylabel('Validation Accuracy')
    plt.title('Final Validation Accuracies')
    plt.xticks(rotation=45)
    plt.ylim(0, 1.0)

plt.tight_layout()
plt.savefig(os.path.join(config.save_dir, 'training_history.png'))
plt.show()

# %%
# Cell 13: Prediction Functions
def predict_event(text, model, tokenizer, label_encoders, device='cpu'):
    """Predict every task for a single text.

    Args:
        text: Input text.
        model: Trained multi-task model; it is switched to eval mode.
        tokenizer: Tokenizer matching the model.
        label_encoders: Fitted encoders used to decode class indices.
        device: Device the inputs are moved to; must match the model's device.

    Returns:
        dict with
        - ``predicted_event``: most likely EVENT_TYPE label,
        - ``confidence``: its softmax probability,
        - ``top_predictions``: up to three ``(label, probability)`` pairs in
          descending probability order,
        - ``other_outputs``: per-task results; ``{'prediction', 'confidence'}``
          dicts for event_group / emotion / tense / sarcasm and plain floats
          for sentiment_valence / certainty.
    """
    model.eval()

    # Tokenize (the sequence length comes from the module-level config)
    encoding = tokenizer(text, truncation=True, padding='max_length',
                        max_length=config.max_length, return_tensors='pt')

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Predict
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        # Get EVENT_TYPE predictions (primary task)
        probabilities = torch.softmax(outputs['event_type'], dim=1)
        prediction = torch.argmax(outputs['event_type'], dim=1)

    predicted_event = label_encoders.inverse_transform(prediction.cpu().numpy(), 'event_type')[0]
    confidence = probabilities[0, prediction[0]].item()

    # Top 3 predictions for EVENT_TYPE. k is clamped to the number of classes
    # because torch.topk raises when k exceeds the dimension size.
    top_k = min(3, probabilities.size(1))
    top_probs, top_indices = torch.topk(probabilities[0], top_k)
    top_predictions = [(label_encoders.inverse_transform([idx.item()], 'event_type')[0], prob.item())
                      for idx, prob in zip(top_indices, top_probs)]

    # Get other predictions
    other_predictions = {}

    # Classifications
    for task in ['event_group', 'emotion', 'tense']:
        task_probs = torch.softmax(outputs[task], dim=1)
        task_pred = torch.argmax(outputs[task], dim=1)
        other_predictions[task] = {
            'prediction': label_encoders.inverse_transform(task_pred.cpu().numpy(), task)[0],
            'confidence': task_probs[0, task_pred[0]].item()
        }

    # Sarcasm (binary head: index 1 means sarcastic)
    sarcasm_probs = torch.softmax(outputs['sarcasm'], dim=1)
    sarcasm_pred = torch.argmax(outputs['sarcasm'], dim=1)
    other_predictions['sarcasm'] = {
        'prediction': 'TRUE' if sarcasm_pred.item() == 1 else 'FALSE',
        'confidence': sarcasm_probs[0, sarcasm_pred[0]].item()
    }

    # Regressions
    for task in ['sentiment_valence', 'certainty']:
        other_predictions[task] = outputs[task][0].item()

    return {
        'predicted_event': predicted_event,
        'confidence': confidence,
        'top_predictions': top_predictions,
        'other_outputs': other_predictions
    }

def predict_event_multi(text, model, tokenizer, label_encoders, device='cpu', certainty_threshold=0.7):
    """Predict one or several plausible event types for a text.

    When the top EVENT_TYPE probability reaches ``certainty_threshold`` a
    single event is reported. Otherwise up to three of the most likely events
    are collected until their combined certainty (each new probability ``p``
    raises it by ``p * (1 - combined)``) reaches the threshold.

    Args:
        text: Input text.
        model: Trained multi-task model.
        tokenizer: Tokenizer matching the model.
        label_encoders: Fitted encoders used to decode class indices.
        device: Device the inputs are moved to; must match the model's device.
        certainty_threshold: Probability needed to stop adding events.

    Returns:
        dict with ``events`` (list of ``{'event', 'confidence'}``),
        ``combined_certainty``, ``interpretation`` (human-readable summary)
        and ``all_predictions`` (the other tasks' outputs from predict_event).
    """
    # A single forward pass serves both the event ranking and the other
    # tasks: predict_event already returns the top EVENT_TYPE candidates in
    # descending probability order, and at most three are ever used here.
    all_outputs = predict_event(text, model, tokenizer, label_encoders, device)
    ranked_predictions = all_outputs['top_predictions'][:3]

    # Determine primary event and certainty
    primary_event, primary_certainty = ranked_predictions[0]

    # If certainty is high enough, return single event
    if primary_certainty >= certainty_threshold:
        interpretation = "Clear single event detected"
        events_info = [{
            'event': primary_event,
            'confidence': primary_certainty
        }]
        combined_certainty = primary_certainty
    else:
        # Accumulate events until threshold is met
        selected_events = []
        selected_certainties = []
        cumulative_certainty = 0.0

        for event, certainty in ranked_predictions:
            if cumulative_certainty >= certainty_threshold or len(selected_events) >= 3:
                break
            selected_events.append(event)
            selected_certainties.append(certainty)
            cumulative_certainty += certainty * (1 - cumulative_certainty)

        # Create interpretation
        if len(selected_events) == 2:
            interpretation = f"Could be {selected_events[0]} or {selected_events[1]}"
        else:
            interpretation = "Multiple possible events detected"

        events_info = [{'event': e, 'confidence': c}
                      for e, c in zip(selected_events, selected_certainties)]
        combined_certainty = cumulative_certainty

    return {
        'events': events_info,
        'combined_certainty': combined_certainty,
        'interpretation': interpretation,
        'all_predictions': all_outputs['other_outputs']
    }

print("✓ Prediction functions defined")

# %%
# Cell 14: Load Best Model and Test
# Load best model (if the training loop saved one) into the current model.
best_model_path = os.path.join(config.save_dir, 'best_model.pth')
if os.path.exists(best_model_path):
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU runtime also loads on a CPU-only one.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects
    # (the checkpoint also holds the label encoders and config); only load
    # checkpoint files you created yourself.
    checkpoint = torch.load(best_model_path, map_location=config.device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f"✓ Best model loaded")
    print(f"  Best val loss: {checkpoint['best_val_loss']:.4f}")
else:
    print("⚠️  No best model found, using current model")

# %%
# Cell 15: Test Examples
# Runs predict_event_multi on a few hand-written sentences and prints the
# detected event(s) together with the other tasks' predictions.
print("\nTesting predictions:")
print("="*60)

test_examples = [
    # Easy examples
    "I just got promoted to senior manager!",
    "They laid me off after 10 years with the company",
    "We're getting married next month!",
    "I was diagnosed with diabetes yesterday",

    # Medium difficulty
    "Starting my freelance journey next month!",
    "Got the promotion but no salary increase",
    "We're taking a break to figure things out",

    # Hard examples
    "Oh great, another pay cut. Just what I needed.",
    "They called it a 'mutual decision' but we all know what that means",
    "Everyone says I'm lucky to still have a job"
]

for i, text in enumerate(test_examples):
    # Uses predict_event_multi's default certainty_threshold.
    result = predict_event_multi(text, model, tokenizer, label_encoders, config.device)

    print(f"\n{i+1}. Text: \"{text}\"")
    print(f"   Interpretation: {result['interpretation']}")

    # Show events
    for event_info in result['events']:
        print(f"   EVENT: {event_info['event']} ({event_info['confidence']:.2%})")

    # Show all predictions (classification tasks are dicts, regressions plain floats)
    preds = result['all_predictions']
    print(f"   GROUP: {preds['event_group']['prediction']} ({preds['event_group']['confidence']:.2%})")
    print(f"   EMOTION: {preds['emotion']['prediction']} ({preds['emotion']['confidence']:.2%})")
    print(f"   SENTIMENT: {preds['sentiment_valence']:.2f}")
    print(f"   SARCASM: {preds['sarcasm']['prediction']}")
    print(f"   TENSE: {preds['tense']['prediction']}")
    print(f"   CERTAINTY: {preds['certainty']:.2f}")
    print("-"*60)

# %%
# Cell 16: Save Final Model Package
# Save everything in an organized way: weights plus encoders/config in one
# .pth file, the tokenizer, one pickle per label encoder and a JSON file with
# the settings needed at inference time.
final_save_path = os.path.join(config.save_dir, 'final_model_package.pth')
# NOTE(review): Config keeps its settings as class attributes, so the pickled
# 'config' instance presumably carries no values of its own; confirm before
# relying on it when loading in another process.
torch.save({
    'model_state_dict': model.state_dict(),
    'label_encoders': label_encoders,
    'config': config,
    'tokenizer_name': config.model_name,
    'event_classes': list(label_encoders.encoders['event_type'].classes_),
    'emotion_classes': list(label_encoders.encoders['emotion'].classes_),
    'tense_classes': list(label_encoders.encoders['tense'].classes_),
    'event_group_classes': list(label_encoders.encoders['event_group'].classes_),
}, final_save_path)

# Save tokenizer
tokenizer_path = os.path.join(config.save_dir, 'tokenizer')
tokenizer.save_pretrained(tokenizer_path)

# Save label encoders separately for compatibility (one le_<task>.pkl per encoder)
import pickle
for name, encoder in label_encoders.encoders.items():
    with open(os.path.join(config.save_dir, f'le_{name}.pkl'), 'wb') as f:
        pickle.dump(encoder, f)

# Save inference config as JSON (device is stored as its string form)
inference_config = {
    'model_name': config.model_name,
    'max_length': config.max_length,
    'multi_event_threshold': config.multi_event_threshold,
    'device': str(config.device),
    'num_classes': label_encoders.num_classes,
    'save_dir': config.save_dir
}

with open(os.path.join(config.save_dir, 'inference_config.json'), 'w') as f:
    json.dump(inference_config, f, indent=2)

print(f"\n✓ Everything saved to: {config.save_dir}")
print(f"  - Final model package: final_model_package.pth")
print(f"  - Tokenizer: tokenizer/")
print(f"  - Label encoders: le_*.pkl")
print(f"  - Best model: best_model.pth")
print(f"  - Inference config: inference_config.json")

# %%
# Cell 17: Quick Helper Functions
def quick_predict(text):
    """Run the multi-event predictor on ``text`` using the notebook's globals.

    Forwards the module-level ``model``, ``tokenizer``, ``label_encoders`` and
    ``config.device`` to ``predict_event_multi`` and returns its result
    unchanged.
    """
    prediction = predict_event_multi(
        text, model, tokenizer, label_encoders, config.device
    )
    return prediction

def show_confidence_distribution(text, top_k=10):
    """Print the most probable event types for ``text``.

    Tokenizes ``text``, runs the module-level ``model`` without gradients,
    applies a softmax over the ``event_type`` logits and prints the
    highest-probability classes, named via the ``event_type`` label encoder.

    Args:
        text: Input text to classify.
        top_k: Maximum number of event types to list (default 10). The value
            is clamped to the number of event classes, because ``torch.topk``
            raises when asked for more entries than the tensor holds.

    Returns:
        None. The ranking is written to stdout.
    """
    model.eval()
    encoding = tokenizer(text, truncation=True, padding='max_length',
                        max_length=config.max_length, return_tensors='pt')

    with torch.no_grad():
        outputs = model(encoding['input_ids'].to(config.device),
                       encoding['attention_mask'].to(config.device))
        # [0]: only one text was encoded, so take the first row of the batch.
        probs = torch.softmax(outputs['event_type'], dim=1)[0]

    # Never request more entries than there are classes.
    k = min(top_k, probs.numel())
    top_probs, top_indices = torch.topk(probs, k)

    event_classes = label_encoders.encoders['event_type'].classes_

    print(f"\nTop {k} event predictions for: \"{text}\"")
    # Convert to plain Python numbers once so that class lookup and formatting
    # do not depend on tensor device or tensor-to-index coercion.
    ranked = zip(top_indices.tolist(), top_probs.tolist())
    for rank, (idx, prob) in enumerate(ranked, start=1):
        event = event_classes[idx]
        print(f"{rank:2d}. {event:30s} {prob:.2%}")

# Smoke-test the helper on one mixed-signal sentence
print(f"\n{'=' * 60}")
print("Testing helper functions:")
show_confidence_distribution(
    "I got laid off but already have 3 interviews lined up"
)

# %%
# Cell 18: Interactive User Testing
# Session banner followed by the list of supported commands.
rule = "=" * 60
print(f"\n{rule}")
print("INTERACTIVE TESTING MODE")
print(rule)
print("\nEnter text to analyze (or 'quit' to exit)")
for help_line in (
    "Commands:",
    "  - 'quit' or 'q': exit",
    "  - 'top10': show top 10 predictions for last text",
    "  - 'examples': show example texts",
    rule,
):
    print(help_line)

# Sample inputs listed by the 'examples' command.
example_texts = [
    "I finally got the promotion I've been working towards!",
    "My contract wasn't renewed, time to look for something new",
    "Starting my own business after 15 years in corporate",
    "The doctor says I need surgery next month",
    "We're expecting our first baby!",
    "Another reorganization, and I'm being moved to a different department",
]

# Most recently analyzed text; stays empty until the first prediction is made.
last_text = ""

# Emoji shown next to the predicted emotion. Built once here because the table
# never changes between turns.
emotion_emojis = {
    'joy': '😊', 'sadness': '😢', 'anger': '😠', 'fear': '😨',
    'surprise': '😲', 'disgust': '🤢', 'neutral': '😐', 'anxiety': '😰',
    'hope': '🤗', 'pride': '😌', 'disappointment': '😞', 'relief': '😌'
}
bar_length = 20

# Read-evaluate-print loop: each turn is either a command ('quit', 'examples',
# 'top10') or a text that is classified and reported.
while True:
    try:
        user_input = input("\nEnter text: ").strip()
    except EOFError:
        # Input stream closed: end the session the same way 'quit' does.
        print("\nGoodbye!")
        break

    command = user_input.lower()

    if command in ('quit', 'exit', 'q'):
        print("Goodbye!")
        break

    if command == 'examples':
        print("\nExample texts:")
        for i, ex in enumerate(example_texts, 1):
            print(f"{i}. {ex}")
        continue

    if command == 'top10':
        if last_text:
            show_confidence_distribution(last_text)
        else:
            # Without this guard the literal word "top10" would be sent to the
            # model as if it were a text to analyze.
            print("No text analyzed yet. Enter some text first, then type 'top10'.")
        continue

    if not user_input:
        continue

    # Make prediction
    last_text = user_input
    result = predict_event_multi(user_input, model, tokenizer, label_encoders, config.device)

    # Display results
    print(f"\n{'='*60}")
    print("ANALYSIS RESULTS")
    print(f"{'='*60}")

    print(f"\n📝 Text: \"{user_input}\"")
    print(f"\n🎯 {result['interpretation']}")

    # Event predictions (the first entry is highlighted as the primary event)
    print("\n📊 Event Predictions:")
    for i, event_info in enumerate(result['events'], 1):
        emoji = "✅" if i == 1 else "🔸"
        print(f"   {emoji} {event_info['event']:25s} ({event_info['confidence']:.1%})")

    # Other predictions
    preds = result['all_predictions']

    print(f"\n📁 Event Group: {preds['event_group']['prediction']} ({preds['event_group']['confidence']:.1%})")

    # Emotion with emoji ('🔵' for emotions that have no entry in the table)
    emotion = preds['emotion']['prediction']
    emoji = emotion_emojis.get(emotion, '🔵')
    print(f"💭 Emotion: {emotion} {emoji} ({preds['emotion']['confidence']:.1%})")

    # Sentiment with bar. The fill is clamped so that a value outside [0, 1]
    # still renders a bar of exactly `bar_length` cells.
    sentiment = preds['sentiment_valence']
    filled = max(0, min(bar_length, int(sentiment * bar_length)))
    bar = '█' * filled + '░' * (bar_length - filled)
    print(f"📈 Sentiment: [{bar}] {sentiment:.2f}")

    # Sarcasm
    if preds['sarcasm']['prediction'] == 'TRUE':
        print(f"🙄 Sarcasm: DETECTED! ({preds['sarcasm']['confidence']:.1%})")
    else:
        print("📌 Sarcasm: Not detected")

    # Tense and Certainty
    print(f"⏰ Tense: {preds['tense']['prediction']}")
    print(f"🎲 Certainty: {preds['certainty']:.2f}")

    print("\n💡 Tip: Type 'top10' to see confidence distribution for all events")

print("\n✅ Session ended. Model and results saved to:", config.save_dir)

In [None]:
# Load your already trained model and save clean weights
import torch
import json
import os
import numpy as np
import torch.nn as nn

# Mount drive if not already mounted. The checkpoint directory used further
# down in this cell lives under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

# Define minimal classes that the checkpoint expects (just empty shells)
class MultiTaskLabelEncoders:
    """Empty stand-in so the checkpoint's pickled encoder object can be loaded.

    Only the two attributes are provided; no encoding logic is defined here.
    """

    def __init__(self):
        # Two separate, initially empty mappings.
        self.encoders, self.num_classes = {}, {}

class Config:
    """Attribute-less stand-in so the checkpoint's pickled config can be loaded."""

# Source checkpoint and destination for the weights-only file.
save_dir = '/content/drive/MyDrive/EC-model-v2'
best_model_path = os.path.join(save_dir, 'best_model.pth')
clean_path = os.path.join(save_dir, 'model_weights_clean.pth')

print("Loading your existing trained model...")
# weights_only=False: the checkpoint holds pickled Python objects in addition
# to tensors. Only acceptable because this is your own trusted file.
checkpoint = torch.load(
    best_model_path,
    map_location='cpu',
    weights_only=False,
)

# Keep just the state dict and drop every other object in the checkpoint.
print("Extracting clean weights...")
model_weights = checkpoint['model_state_dict']

# Write ONLY the weights (no classes or other objects)
torch.save(model_weights, clean_path)
print(f"✓ Saved clean weights to: {clean_path}")

# Report the size of the new file in MiB
size_mb = os.path.getsize(clean_path) / 1024 ** 2
print(f"✓ File size: {size_mb:.1f} MB")

# Quick sanity check: tensor count and the first few parameter names
print(f"✓ Total weight tensors: {len(model_weights)}")
print("\nSample layers:")
for key in list(model_weights)[:5]:
    print(f"  - {key}")
print("  ...")
print("✓ Ready to download and use locally!")