# MBTI Personality Prediction - Model Training

This notebook trains four transformer-based classifiers on each of the four binary MBTI traits (I/E, N/S, T/F, P/J), for 16 training runs in total.

**Features:**
- Automatic checkpointing after each model/trait
- Resumes after a disconnect: model/trait runs that already completed are skipped
- Saves best models to Google Drive
- Progress tracking

## 1. Setup & Mount Google Drive

In [None]:
# Mount Google Drive; checkpoints and trained models are written there
import os
from google.colab import drive

drive.mount('/content/drive')

# All notebook outputs live under one project folder on Drive
PROJECT_DIR = '/content/drive/MyDrive/MBTI_AI'
CHECKPOINT_DIR = f'{PROJECT_DIR}/checkpoints'
MODEL_DIR = f'{PROJECT_DIR}/models'

# Create both output folders (no error if they already exist)
for output_dir in (CHECKPOINT_DIR, MODEL_DIR):
    os.makedirs(output_dir, exist_ok=True)

print(f"Project directory: {PROJECT_DIR}")

In [None]:
# Install required packages
# `-q` keeps pip's output quiet. No versions are pinned in this command.
!pip install -q transformers datasets accelerate

In [None]:
# Make sure the dataset (mbti_1.csv) is available in the Drive project folder.
# If it is not there yet, ask for an upload from the local machine and move
# the uploaded file into place.
import shutil

from google.colab import files

# Check if dataset already exists in Drive
DATASET_PATH = f'{PROJECT_DIR}/mbti_1.csv'

if not os.path.exists(DATASET_PATH):
    print("Please upload mbti_1.csv")
    uploaded = files.upload()
    if not uploaded:
        # Fail here with a clear message instead of later inside pd.read_csv
        raise FileNotFoundError("No file was uploaded; mbti_1.csv is required to continue.")
    # Use the file name reported by the upload rather than assuming it is
    # exactly 'mbti_1.csv' (the saved name can differ, e.g. 'mbti_1 (1).csv').
    # NOTE(review): assumes files.upload() returns {saved_name: content} and
    # writes the files to the current working directory - confirm.
    uploaded_name = 'mbti_1.csv' if 'mbti_1.csv' in uploaded else next(iter(uploaded))
    # Move to project directory
    shutil.move(uploaded_name, DATASET_PATH)
else:
    print(f"Dataset found at {DATASET_PATH}")

## 2. Imports & Configuration

In [None]:
import re
import json
import pandas as pd
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score
from tqdm.notebook import tqdm
import warnings

# Silence all Python warnings for the rest of the notebook
warnings.filterwarnings('ignore')

has_cuda = torch.cuda.is_available()

# Fixed seeds for NumPy and PyTorch (CPU, plus every GPU when available)
np.random.seed(42)
torch.manual_seed(42)
if has_cuda:
    torch.cuda.manual_seed_all(42)

# Prefer the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda" if has_cuda else "cpu")
print(f"Using device: {device}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Configuration
# Pretrained checkpoints passed to `from_pretrained` for the encoders and tokenizers.
BERT_MODEL_NAME = 'bert-base-uncased'
DEBERTA_MODEL_NAME = 'microsoft/deberta-v3-small'

MAX_LEN = 256     # tokenizer max_length; every text is padded/truncated to this many tokens
BATCH_SIZE = 16   # batch size of both the training and validation DataLoader
EPOCHS = 3        # epochs per training run; also used as T_0 of the cosine scheduler

# Learning rates
# BASIC_LR is the single learning rate for Basic_BERT. The other models use two
# optimizer parameter groups: *_BERT_LR for the encoder and *_HEAD_LR for the head
# (and, for the attention-pooling model, the pooling layer as well).
# *_DROPOUT is the default dropout rate of the corresponding model class.
BASIC_LR = 2e-5
BERT_DEEP_HEAD_LR = 1e-4
BERT_DEEP_BERT_LR = 1e-5
BERT_DEEP_DROPOUT = 0.55
DEBERTA_DEEP_HEAD_LR = 1e-4
DEBERTA_DEEP_BERT_LR = 1e-5
DEBERTA_DEEP_DROPOUT = 0.55
ABLATION_HEAD_LR = 1e-4
ABLATION_BERT_LR = 1e-5
ABLATION_DROPOUT = 0.55

# Traits to train
# Each entry is a binary label column created by load_and_preprocess_data.
TRAITS = ['is_I', 'is_N', 'is_T', 'is_P']
# Display names used in log output and in the final report.
TRAIT_NAMES = {
    'is_I': 'Mind (I/E)', 
    'is_N': 'Energy (N/S)', 
    'is_T': 'Nature (T/F)', 
    'is_P': 'Tactics (P/J)'
}

# Models to train
# train_model dispatches on these exact strings; they are also part of the
# progress keys and of the saved model file names.
MODEL_NAMES = [
    'Basic_BERT',
    'BERT_Deep_Head',
    'DeBERTa_Deep_Head',
    'DeBERTa_AttnPool_Deep'
]

## 3. Data Loading & Preprocessing

In [None]:
def clean_text(text):
    """Normalize one raw posts string.

    Steps, in order: replace the '|||' separator with a space, drop URLs
    (anything starting with 'http') and @mentions, lowercase, then collapse
    whitespace runs to single spaces and strip both ends.
    """
    cleaned = text.replace('|||', ' ')
    # Patterns are matched before lowercasing, so they stay case-sensitive
    for noise_pattern in (r'http\S+', r'@\w+'):
        cleaned = re.sub(noise_pattern, '', cleaned)
    return re.sub(r'\s+', ' ', cleaned.lower()).strip()

def load_and_preprocess_data(path):
    """Read the MBTI CSV and add the cleaned text and binary trait columns.

    The CSV must provide a 'posts' and a 'type' column. The returned frame
    gains 'cleaned_posts' plus one 0/1 column per letter position of 'type':
    'is_I', 'is_N', 'is_T' and 'is_P'.
    """
    frame = pd.read_csv(path)
    frame['cleaned_posts'] = frame['posts'].apply(clean_text)

    # (label column, index into the type string, letter that maps to 1)
    trait_letters = (
        ('is_I', 0, 'I'),
        ('is_N', 1, 'N'),
        ('is_T', 2, 'T'),
        ('is_P', 3, 'P'),
    )
    for column, position, positive_letter in trait_letters:
        frame[column] = frame['type'].apply(
            lambda mbti_type, i=position, letter=positive_letter: 1 if mbti_type[i] == letter else 0
        )
    return frame

# Load the dataset once, then report its size and the ten most frequent types
print("Loading data...")
df = load_and_preprocess_data(DATASET_PATH)
sample_count = len(df)
print(f"Loaded {sample_count} samples")
print(f"\nMBTI Type distribution:")
type_counts = df['type'].value_counts()
print(type_counts.head(10))

In [None]:
class MBTITraitDataset(Dataset):
    """Dataset pairing post texts with the labels of one binary MBTI trait.

    Texts are tokenized lazily, one item at a time, in ``__getitem__``.

    Args:
        texts: Indexable collection of texts (each item is passed through ``str``).
        labels: Indexable collection of labels aligned with ``texts``.
        tokenizer: Callable tokenizer accepting the Hugging Face tokenizer
            keyword arguments used below.
        max_len: Length every encoded sequence is padded or truncated to.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the tensors for one sample.

        Returns:
            Dict with 1-D 'input_ids' and 'attention_mask' tensors and a
            float scalar 'label' tensor.
        """
        text = str(self.texts[idx])
        label = self.labels[idx]
        # Call the tokenizer directly: `encode_plus` is deprecated in favour of
        # `__call__`, which takes the same arguments for a single text.
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            # return_tensors='pt' adds a leading batch axis; flatten removes it
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            # float dtype because the label feeds BCEWithLogitsLoss
            'label': torch.tensor(label, dtype=torch.float)
        }

## 4. Model Architectures

In [None]:
def create_very_deep_head(input_size, output_size, hidden_sizes, dropout_rate):
    """Build an MLP head as an ``nn.Sequential``.

    Every hidden size contributes a Linear -> GELU -> Dropout triple; a final
    Linear projects to ``output_size``. With no hidden sizes the head is a
    single Linear layer.
    """
    widths = [input_size, *hidden_sizes]
    modules = []
    # Consecutive width pairs give the (in, out) sizes of each hidden layer
    for fan_in, fan_out in zip(widths, widths[1:]):
        modules += [
            nn.Linear(fan_in, fan_out),
            nn.GELU(),
            nn.Dropout(dropout_rate),
        ]
    modules.append(nn.Linear(widths[-1], output_size))
    return nn.Sequential(*modules)


class BaselineBERTClassifier(nn.Module):
    """Model 1: BERT encoder followed by dropout and a single linear layer.

    The classifier reads the encoder's ``pooler_output`` and returns raw
    logits with ``n_out`` values per sample.
    """

    def __init__(self, n_out=1, dropout_rate=0.3):
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
        self.dropout = nn.Dropout(p=dropout_rate)
        encoder_width = self.bert.config.hidden_size
        self.classifier = nn.Linear(encoder_width, n_out)

    def forward(self, input_ids, attention_mask):
        """Return logits for a batch of token ids and their attention mask."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        regularized = self.dropout(encoded.pooler_output)
        return self.classifier(regularized)


class BERTDeepHeadClassifier(nn.Module):
    """Model 2: BERT encoder with a deep MLP head (6 linear layers).

    Dropout is applied to the encoder's ``pooler_output`` before it enters
    the head built by ``create_very_deep_head``.
    """

    def __init__(self, n_out=1, dropout_rate=BERT_DEEP_DROPOUT):
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
        # Five hidden layers plus the output projection
        self.head = create_very_deep_head(
            self.bert.config.hidden_size,
            n_out,
            [512, 256, 128, 64, 32],
            dropout_rate,
        )
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, input_ids, attention_mask):
        """Return logits for a batch of token ids and their attention mask."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        regularized = self.dropout(encoded.pooler_output)
        return self.head(regularized)


class DeBERTaClassifier(nn.Module):
    """Model 3: DeBERTa encoder with a deep MLP head (6 linear layers).

    The hidden state at sequence position 0 (the CLS token) is used as the
    pooled representation; dropout is applied before the head.
    """

    def __init__(self, n_out=1, dropout_rate=DEBERTA_DEEP_DROPOUT):
        super().__init__()
        self.bert = AutoModel.from_pretrained(DEBERTA_MODEL_NAME)
        # Five hidden layers plus the output projection
        self.head = create_very_deep_head(
            self.bert.config.hidden_size,
            n_out,
            [512, 256, 128, 64, 32],
            dropout_rate,
        )
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, input_ids, attention_mask):
        """Return logits for a batch of token ids and their attention mask."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_state = encoded.last_hidden_state[:, 0]  # CLS token
        return self.head(self.dropout(cls_state))


class AttentionPooling(nn.Module):
    """Pool a sequence of hidden states into one vector with learned weights.

    A linear layer scores every position, the scores are softmax-normalized
    along dim 1 (treated as the sequence axis), and the hidden states are
    summed with those weights. Positions whose attention mask is 0 get
    weight 0.

    Args:
        hidden_size: Size of the last dimension of the hidden states.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attention_net = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, attention_mask):
        """Return the attention-weighted sum of ``hidden_states`` over dim 1.

        Args:
            hidden_states: Tensor whose dim 1 is the sequence axis and whose
                last dim has size ``hidden_size``.
            attention_mask: Tensor matching the leading two dims of
                ``hidden_states``; zero marks positions to ignore.

        Returns:
            Tensor with the sequence axis summed out. Rows in which every
            position is masked yield an all-zero vector instead of NaN.
        """
        keep = attention_mask.unsqueeze(-1) != 0
        scores = self.attention_net(hidden_states)
        # Out-of-place fill with the most negative finite value: masked
        # positions still receive zero weight, but a fully masked row no longer
        # turns into NaN the way softmax over all -inf does.
        scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
        attn_weights = F.softmax(scores, dim=1)
        # Zero masked weights explicitly; this only changes fully masked rows,
        # where the softmax above would otherwise be uniform over padding.
        attn_weights = attn_weights * keep.to(attn_weights.dtype)
        context = torch.sum(attn_weights * hidden_states, dim=1)
        return context


class DeBERTaAblationModel(nn.Module):
    """Model 4: DeBERTa encoder + attention pooling + deep MLP head.

    All token states of the last layer are pooled by ``AttentionPooling``
    (using the attention mask) and the pooled vector is fed to the head.
    """

    def __init__(self, n_out=1, dropout_rate=ABLATION_DROPOUT):
        super().__init__()
        self.bert = AutoModel.from_pretrained(DEBERTA_MODEL_NAME)
        encoder_width = self.bert.config.hidden_size
        self.attention_pooling = AttentionPooling(encoder_width)
        # Five hidden layers plus the output projection
        self.head = create_very_deep_head(
            encoder_width,
            n_out,
            [512, 256, 128, 64, 32],
            dropout_rate,
        )

    def forward(self, input_ids, attention_mask):
        """Return logits for a batch of token ids and their attention mask."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.attention_pooling(encoded.last_hidden_state, attention_mask)
        return self.head(pooled)

## 5. Training & Evaluation Functions

In [None]:
def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler=None):
    """Run one training pass over ``data_loader`` and return the mean batch loss.

    Args:
        model: Module called with ``input_ids`` and ``attention_mask`` keywords.
        data_loader: Iterable of batches with 'input_ids', 'attention_mask'
            and 'label' entries.
        loss_fn: Loss applied to the logits (last dim squeezed) and the labels.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        scheduler: Optional scheduler stepped once per batch, after the optimizer.

    Returns:
        Sum of the batch losses divided by the number of batches (0.0 for an
        empty loader).
    """
    model.train()
    total_loss = 0.0
    pbar = tqdm(data_loader, desc="Training")

    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # Squeeze only the last axis: a bare squeeze() would also drop the batch
        # axis of a final batch of size 1 and make the loss reject the
        # logits/labels shape mismatch.
        loss = loss_fn(outputs.squeeze(-1), labels)
        batch_loss = loss.item()
        total_loss += batch_loss

        loss.backward()
        # Clip the global gradient norm before the update
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    # max(..., 1) avoids a ZeroDivisionError when the loader has no batches
    return total_loss / max(len(data_loader), 1)


def eval_model(model, data_loader, device):
    """Evaluate a binary classifier and return its validation metrics.

    Logits are turned into probabilities with a sigmoid and thresholded at 0.5.

    Args:
        model: Module called with ``input_ids`` and ``attention_mask`` keywords.
        data_loader: Iterable of batches with 'input_ids', 'attention_mask'
            and 'label' entries.
        device: Device the model inputs are moved to.

    Returns:
        Dict with the keys 'Accuracy', 'Precision', 'Recall', 'F1' and 'AUC-ROC'.
    """
    model.eval()
    all_labels, all_probs, all_preds = [], [], []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            probs = torch.sigmoid(outputs).cpu().numpy().flatten()
            preds = (probs > 0.5).astype(int)

            # Labels are only needed on the host, so they are not sent to the
            # device first; they are stored as integer class ids.
            all_labels.extend(batch['label'].cpu().numpy().astype(int).flatten())
            all_probs.extend(probs)
            all_preds.extend(preds)

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='binary', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='binary', zero_division=0)
    # Same zero_division handling as precision/recall, for consistency
    f1 = f1_score(all_labels, all_preds, average='binary', zero_division=0)

    try:
        auc_roc = roc_auc_score(all_labels, all_probs)
    except ValueError:
        # Fall back to the chance-level value when the score cannot be computed
        # (presumably when only one class is present in the labels).
        auc_roc = 0.5

    return {
        'Accuracy': accuracy,
        'Precision': precision,
        'Recall': recall,
        'F1': f1,
        'AUC-ROC': auc_roc
    }

## 6. Checkpoint Management

In [None]:
def get_progress_file():
    """Return the path of the JSON progress file inside CHECKPOINT_DIR."""
    return '{}/training_progress.json'.format(CHECKPOINT_DIR)

def load_progress():
    """Load training progress from checkpoint.

    Returns:
        Dict with a 'completed' list of run keys and a 'results' dict mapping
        run keys to metrics. An empty progress dict is returned when the file
        does not exist or cannot be parsed.
    """
    progress_file = get_progress_file()
    try:
        with open(progress_file, 'r') as handle:
            progress = json.load(handle)
    except FileNotFoundError:
        return {'completed': [], 'results': {}}
    except json.JSONDecodeError as err:
        # An interrupted write can leave a truncated file behind; report it and
        # start over instead of blocking every later resume with a traceback.
        print(f"[WARN] Could not parse {progress_file} ({err}); starting with empty progress.")
        return {'completed': [], 'results': {}}

    if not isinstance(progress, dict):
        print(f"[WARN] Unexpected content in {progress_file}; starting with empty progress.")
        return {'completed': [], 'results': {}}

    # Guarantee both keys so callers can index them unconditionally
    progress.setdefault('completed', [])
    progress.setdefault('results', {})
    return progress

def save_progress(progress):
    """Save training progress to checkpoint.

    The JSON is written to a temporary sibling file first and then moved over
    the real progress file, so an interruption during the write leaves the
    previous progress file untouched.

    Args:
        progress: JSON-serializable progress dict.
    """
    progress_file = get_progress_file()
    tmp_file = f'{progress_file}.tmp'
    with open(tmp_file, 'w') as f:
        json.dump(progress, f, indent=2)
    # Replace the old file only after the new content was written completely
    os.replace(tmp_file, progress_file)

def is_completed(trait, model_name, progress):
    """Return True when the '<trait>_<model_name>' run is listed as completed."""
    finished_runs = progress['completed']
    run_key = "{}_{}".format(trait, model_name)
    return run_key in finished_runs

def mark_completed(trait, model_name, metrics, progress):
    """Mark a trait/model combination as completed and persist the progress.

    Args:
        trait: Trait column name; first part of the run key.
        model_name: Model name; second part of the run key.
        metrics: Metrics stored under the run key in ``progress['results']``.
        progress: Progress dict, updated in place and written via save_progress.
    """
    key = f"{trait}_{model_name}"
    # Keep the list free of duplicates so the completed count stays correct
    # when the same run is marked more than once.
    if key not in progress['completed']:
        progress['completed'].append(key)
    progress['results'][key] = metrics
    save_progress(progress)

def save_model(model, trait, model_name):
    """Write the model's state dict to '<MODEL_DIR>/<trait>_<model_name>.pt'."""
    destination = "{}/{}_{}.pt".format(MODEL_DIR, trait, model_name)
    weights = model.state_dict()
    torch.save(weights, destination)
    print(f"Model saved to {destination}")

# Load existing progress
progress = load_progress()
# Derive the total from the configuration instead of hard-coding 16, so the
# message stays correct when TRAITS or MODEL_NAMES change.
total_runs = len(TRAITS) * len(MODEL_NAMES)
print(f"Completed: {len(progress['completed'])}/{total_runs} training runs")
if progress['completed']:
    print(f"Already completed: {progress['completed']}")

## 7. Main Training Loop

In [None]:
# Initialize tokenizers
# Both tokenizers are loaded once at notebook level; train_model picks one of
# them based on the model name.
print("Loading tokenizers...")
bert_tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)  # for the BERT-based models
deberta_tokenizer = AutoTokenizer.from_pretrained(DEBERTA_MODEL_NAME)  # for the DeBERTa-based models
print("Tokenizers loaded!")

In [None]:
def train_model(model_name, trait, df_train, df_val, progress):
    """Train a single model on a single trait and keep the best epoch.

    The run is skipped when it is already recorded in ``progress``. Otherwise
    the model is trained for EPOCHS epochs, evaluated after each one, and the
    weights of the epoch with the highest validation F1 are saved.

    Args:
        model_name: One of 'Basic_BERT', 'BERT_Deep_Head', 'DeBERTa_Deep_Head'
            or 'DeBERTa_AttnPool_Deep'.
        trait: Label column to predict; also a key of TRAIT_NAMES.
        df_train: DataFrame with 'cleaned_posts' and the trait column (training).
        df_val: DataFrame with the same columns (validation).
        progress: Progress dict; updated and persisted via mark_completed.

    Returns:
        Metrics dict of the best epoch, or the stored metrics of a run that
        was already completed.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """

    # Check if already completed
    if is_completed(trait, model_name, progress):
        print(f"\n[SKIP] {model_name} on {TRAIT_NAMES[trait]} - Already completed")
        return progress['results'][f"{trait}_{model_name}"]

    print(f"\n{'='*60}")
    print(f"Training {model_name} on {TRAIT_NAMES[trait]}")
    print(f"{'='*60}")

    # Select tokenizer and create data loaders.
    # 'DeBERTa' contains the substring 'BERT', hence the explicit exclusion.
    if 'BERT' in model_name and 'DeBERTa' not in model_name:
        tokenizer = bert_tokenizer
    else:
        tokenizer = deberta_tokenizer

    train_dataset = MBTITraitDataset(df_train['cleaned_posts'].values, df_train[trait].values, tokenizer, MAX_LEN)
    val_dataset = MBTITraitDataset(df_val['cleaned_posts'].values, df_val[trait].values, tokenizer, MAX_LEN)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

    # Create model, optimizer and scheduler.
    # use_epoch_scheduler: True  -> scheduler is stepped once per epoch,
    #                      False -> scheduler is stepped once per batch.
    if model_name == 'Basic_BERT':
        model = BaselineBERTClassifier().to(device)
        # NOTE(review): `correct_bias` looks specific to the AdamW imported from
        # transformers at the top of the file - confirm before swapping optimizers.
        optimizer = AdamW(model.parameters(), lr=BASIC_LR, correct_bias=False)
        total_steps = len(train_loader) * EPOCHS
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
        use_epoch_scheduler = False

    elif model_name == 'BERT_Deep_Head':
        model = BERTDeepHeadClassifier().to(device)
        optimizer = AdamW([
            {"params": model.bert.parameters(), "lr": BERT_DEEP_BERT_LR},
            {"params": model.head.parameters(), "lr": BERT_DEEP_HEAD_LR}
        ])
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=EPOCHS)
        use_epoch_scheduler = True

    elif model_name == 'DeBERTa_Deep_Head':
        model = DeBERTaClassifier().to(device)
        optimizer = AdamW([
            {"params": model.bert.parameters(), "lr": DEBERTA_DEEP_BERT_LR},
            {"params": model.head.parameters(), "lr": DEBERTA_DEEP_HEAD_LR}
        ])
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=EPOCHS)
        use_epoch_scheduler = True

    elif model_name == 'DeBERTa_AttnPool_Deep':
        model = DeBERTaAblationModel().to(device)
        optimizer = AdamW([
            {"params": model.bert.parameters(), "lr": ABLATION_BERT_LR},
            {"params": model.attention_pooling.parameters(), "lr": ABLATION_HEAD_LR},
            {"params": model.head.parameters(), "lr": ABLATION_HEAD_LR}
        ])
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=EPOCHS)
        use_epoch_scheduler = True

    else:
        # Fail with a clear message instead of an UnboundLocalError further down
        raise ValueError(f"Unknown model name: {model_name!r}")

    loss_fn = nn.BCEWithLogitsLoss().to(device)
    # Start below any reachable F1 so the first epoch always counts as the
    # best so far. Starting at 0 with a strict '>' saved no model and recorded
    # empty metrics whenever F1 stayed at 0 in every epoch.
    best_f1 = float('-inf')
    best_metrics = {}

    # Training loop
    for epoch in range(EPOCHS):
        print(f"\nEpoch {epoch + 1}/{EPOCHS}")

        if use_epoch_scheduler:
            train_loss = train_epoch(model, train_loader, loss_fn, optimizer, device, scheduler=None)
            scheduler.step()
        else:
            train_loss = train_epoch(model, train_loader, loss_fn, optimizer, device, scheduler=scheduler)

        metrics = eval_model(model, val_loader, device)
        print(f"Val F1: {metrics['F1']:.4f} | Val Acc: {metrics['Accuracy']:.4f} | Val AUC: {metrics['AUC-ROC']:.4f}")

        if metrics['F1'] > best_f1:
            best_f1 = metrics['F1']
            best_metrics = metrics.copy()
            # Save best model
            save_model(model, trait, model_name)

    # Mark as completed and save progress
    mark_completed(trait, model_name, best_metrics, progress)
    print(f"\nBest F1 for {model_name} on {TRAIT_NAMES[trait]}: {best_f1:.4f}")

    # Cleanup: drop the references and release cached GPU memory before the next run
    del model, optimizer, scheduler
    torch.cuda.empty_cache()

    return best_metrics

In [None]:
# Main training loop with checkpointing: every trait is trained with every model
banner = "=" * 70
print("\n" + banner)
print("STARTING TRAINING - 4 Models x 4 Traits = 16 Training Runs")
print("Progress is saved after each model. You can resume if disconnected.")
print(banner)

all_results = {}

for trait in TRAITS:
    trait_name = TRAIT_NAMES[trait]
    print(f"\n\n{'#'*70}")
    print(f"### TRAIT: {trait_name}")
    print(f"{'#'*70}")

    # 90/10 split, stratified on the current trait so both parts keep its class ratio
    df_train, df_val = train_test_split(
        df,
        test_size=0.1,
        random_state=42,
        stratify=df[trait],
    )
    print(f"Train: {len(df_train)}, Val: {len(df_val)}")

    # train_model skips runs that are already recorded in `progress`
    for model_name in MODEL_NAMES:
        all_results[(trait_name, model_name)] = train_model(
            model_name, trait, df_train, df_val, progress
        )

print("\n" + banner)
print("TRAINING COMPLETE!")
print(banner)

## 8. Final Results

In [None]:
# Display final results
rule = "=" * 70
print("\n" + rule)
print("FINAL COMPARISON REPORT")
print(rule + "\n")

# Re-read the results from the progress file
progress = load_progress()

# One row per stored run: display trait name, model name, then the metrics
results_data = []
for key, metrics in progress['results'].items():
    for trait_key in TRAITS:
        if key.startswith(trait_key):
            # Key layout is '<trait>_<model>'; skip the trait and the underscore
            trait_label = TRAIT_NAMES[trait_key]
            model_label = key[len(trait_key)+1:]
            break
    else:
        # No known trait prefix: split at the underscores instead
        trait_label = key.rsplit('_', 1)[0]
        model_label = key.split('_')[-1]
    results_data.append({'Trait': trait_label, 'Model': model_label, **metrics})

results_df = pd.DataFrame(results_data).round(4)

# Display without truncating rows, columns or line width
for display_option in ('display.max_rows', 'display.max_columns', 'display.width'):
    pd.set_option(display_option, None)
print(results_df.to_string(index=False))

# Save results to CSV
results_csv_path = f"{PROJECT_DIR}/final_results.csv"
results_df.to_csv(results_csv_path, index=False)
print(f"\nResults saved to {results_csv_path}")

In [None]:
# Best model per trait: the row with the highest F1 among each trait's results
rule = "=" * 70
print("\n" + rule)
print("BEST MODEL PER TRAIT (by F1 Score)")
print(rule + "\n")

for trait in TRAIT_NAMES.values():
    trait_results = results_df[results_df['Trait'] == trait]
    if not len(trait_results):
        # Nothing recorded for this trait yet
        continue
    best = trait_results.loc[trait_results['F1'].idxmax()]
    print(f"{trait}: {best['Model']} (F1: {best['F1']:.4f}, Acc: {best['Accuracy']:.4f})")

## 9. Download Models (Optional)

In [None]:
# List every '.pt' file in MODEL_DIR together with its size in MiB
print("Saved models:")
for filename in os.listdir(MODEL_DIR):
    if not filename.endswith('.pt'):
        continue
    size_mb = os.path.getsize(f"{MODEL_DIR}/{filename}") / (1024*1024)
    print(f"  - {filename} ({size_mb:.1f} MB)")

In [None]:
# Optional: Download a specific model
# from google.colab import files
# files.download(f"{MODEL_DIR}/is_I_Basic_BERT.pt")