In [5]:
# Imports
import gc
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import BertTokenizer, BertForSequenceClassification, T5Tokenizer, T5ForConditionalGeneration
import optuna
import os
import json
from tqdm import tqdm
import numpy as np

In [6]:
# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"Number of GPUs available: {torch.cuda.device_count()}")

# Load dataset
with open('/kaggle/input/moe-dataset/combined_scientific_papers.json', 'r') as f:
    data = json.load(f)

# Domain to label mapping
domain_to_label = {domain: idx for idx, domain in enumerate(set(entry['domain'] for entry in data))}
num_labels = len(domain_to_label)
# Define num_experts based on the number of unique domains/labels
num_experts = num_labels
print(f"Number of experts/domains: {num_experts}")

Using device: cuda
Number of GPUs available: 2
Number of experts/domains: 3


In [7]:
class ScientificDataset(Dataset):
    def __init__(self, data, domain_to_label):
        self.queries = [entry['text'][:100] for entry in data]
        self.labels = [domain_to_label[entry['domain']] for entry in data]
        self.responses = [entry['text'] for entry in data]
    
    def __len__(self):
        return len(self.queries)
    
    def __getitem__(self, idx):
        return self.queries[idx], self.labels[idx], self.responses[idx]

# Collate functions
def gating_collate_fn(batch):
    queries, labels, _ = zip(*batch)
    tokenized = bert_tokenizer(list(queries), padding=True, truncation=True, return_tensors='pt')
    return tokenized, torch.tensor(labels)

def expert_collate_fn(batch):
    queries, responses = zip(*batch)
    inputs = t5_tokenizer(list(queries), padding=True, truncation=True, return_tensors='pt')
    targets = t5_tokenizer(list(responses), padding=True, truncation=True, return_tensors='pt')
    return inputs, targets['input_ids']

# Expert dataset
class TextDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

# Extract expert data
def get_expert_data(dataset, expert_id):
    return [(query, response) for query, label, response in dataset if label == expert_id]

# Training functions with gradient accumulation
def train_gating_model(model, train_loader, val_loader, lr, epochs, accumulation_steps=4):
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    model = nn.DataParallel(model)  # Utilize multiple GPUs
    model.to(device)
    loss_fct = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        optimizer.zero_grad()
        for i, batch in enumerate(train_loader):
            inputs, labels = batch
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)
            # Call the model without labels to get logits
            outputs = model(input_ids=inputs['input_ids'], 
                            attention_mask=inputs['attention_mask'])
            logits = outputs.logits
            loss = loss_fct(logits, labels) / accumulation_steps
            loss.backward()
            if (i + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
            total_loss += loss.item() * accumulation_steps
        avg_loss = total_loss / len(train_loader)
        print(f"Gating Epoch {epoch+1}, Train Loss: {avg_loss:.4f}")
        gc.collect()
        torch.cuda.empty_cache()
    return model

def train_expert(model, train_loader, val_loader, lr, epochs, expert_id, accumulation_steps=4):
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    model = nn.DataParallel(model)  # Utilize multiple GPUs
    model.to(device)
    pad_token_id = model.module.config.pad_token_id  # Access underlying model's config
    loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        optimizer.zero_grad()
        for i, batch in enumerate(train_loader):
            inputs, labels = batch
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)
            # Prepare decoder input and target labels
            decoder_input_ids = labels[:, :-1].clone()
            target_labels = labels[:, 1:].clone()
            # Call model without labels to get logits
            outputs = model(input_ids=inputs['input_ids'], 
                            attention_mask=inputs['attention_mask'], 
                            decoder_input_ids=decoder_input_ids)
            logits = outputs.logits
            # Compute loss
            loss = loss_fct(logits.view(-1, logits.size(-1)), target_labels.view(-1)) / accumulation_steps
            loss.backward()
            if (i + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
            total_loss += loss.item() * accumulation_steps
        avg_loss = total_loss / len(train_loader)
        print(f"Expert {expert_id} Epoch {epoch+1}, Train Loss: {avg_loss:.4f}")
        gc.collect()
        torch.cuda.empty_cache()
    return model

# Evaluation functions
def evaluate_gating_model(model, val_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in val_loader:
            inputs, labels = batch
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)
            # Explicitly specify input parameters
            outputs = model(input_ids=inputs['input_ids'], 
                           attention_mask=inputs['attention_mask'])
            _, predicted = torch.max(outputs.logits, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

def evaluate_expert(model, val_loader):
    if len(val_loader) == 0:
        return 0.0  # Return 0 loss if validation loader is empty
    model.eval()
    total_loss = 0
    pad_token_id = model.module.config.pad_token_id
    loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
    with torch.no_grad():
        for batch in val_loader:
            inputs, labels = batch
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)
            decoder_input_ids = labels[:, :-1].clone()
            target_labels = labels[:, 1:].clone()
            outputs = model(input_ids=inputs['input_ids'], 
                            attention_mask=inputs['attention_mask'], 
                            decoder_input_ids=decoder_input_ids)
            logits = outputs.logits
            loss = loss_fct(logits.view(-1, logits.size(-1)), target_labels.view(-1))
            total_loss += loss.item()
    return total_loss / len(val_loader)

# Optuna objective
def objective(trial):
    # Hyperparameter suggestions
    gating_lr = trial.suggest_float('gating_lr', 1e-5, 1e-3, log=True)
    gating_epochs = trial.suggest_int('gating_epochs', 3, 10)
    expert_lr = trial.suggest_float('expert_lr', 1e-5, 1e-3, log=True)
    expert_epochs = trial.suggest_int('expert_epochs', 3, 10)
    
    # Load and train gating model
    gating_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=num_experts).to(device)
    gating_model = train_gating_model(gating_model, train_loader_gating, val_loader_gating, gating_lr, gating_epochs)
    gating_accuracy = evaluate_gating_model(gating_model, val_loader_gating)
    gating_model.to('cpu')  # Move gating model to CPU
    torch.cuda.empty_cache()  # Clear GPU memory
    
    # Initialize experts dictionary
    experts = {}
    expert_losses = []
    
    # Train each expert sequentially
    for expert_id in range(num_experts):
        train_data = train_expert_data[expert_id]
        val_data = val_expert_data[expert_id]
        if not train_data:
            continue
        # Load expert model onto GPU
        expert_model = T5ForConditionalGeneration.from_pretrained('t5-small').to(device)
        train_dataset_expert = TextDataset(train_data)
        val_dataset_expert = TextDataset(val_data)
        train_loader_expert = DataLoader(train_dataset_expert, batch_size=2, shuffle=True, collate_fn=expert_collate_fn)
        val_loader_expert = DataLoader(val_dataset_expert, batch_size=2, shuffle=False, collate_fn=expert_collate_fn)
        expert_model = train_expert(expert_model, train_loader_expert, val_loader_expert, expert_lr, expert_epochs, expert_id)
        loss = evaluate_expert(expert_model, val_loader_expert)
        expert_losses.append(loss)
        expert_model.to('cpu')  # Move expert model to CPU
        experts[expert_id] = expert_model  # Store on CPU
        torch.cuda.empty_cache()  # Clear GPU memory
    
    # Compute average expert loss
    avg_expert_loss = sum(expert_losses) / len(expert_losses) if expert_losses else 0
    
    # Compute combined metric
    combined_metric = gating_accuracy - 0.1 * avg_expert_loss
    
    # Set user attributes (models are already on CPU)
    trial.set_user_attr('gating_accuracy', gating_accuracy)
    trial.set_user_attr('avg_expert_loss', avg_expert_loss)
    trial.set_user_attr('experts', experts)
    trial.set_user_attr('gating_model', gating_model)
    
    return combined_metric

In [None]:
if __name__ == "__main__":
    # Setup dataset and dataloaders
    full_dataset = ScientificDataset(data, domain_to_label)
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    
    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    t5_tokenizer = T5Tokenizer.from_pretrained('t5-small', legacy=False)
    
    train_loader_gating = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=gating_collate_fn)
    val_loader_gating = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=gating_collate_fn)
    
    train_expert_data = {i: get_expert_data(train_dataset, i) for i in range(num_experts)}
    val_expert_data = {i: get_expert_data(val_dataset, i) for i in range(num_experts)}
    
    # Run Optuna study
    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=20)
    
    # Benchmark and compare
    print("\nBenchmarking Top 5 Configurations:")
    top_trials = sorted(study.trials, key=lambda t: t.value, reverse=True)[:5]
    for i, trial in enumerate(top_trials):
        print(f"\nMOE {i+1}:")
        print(f"  Trial Number: {trial.number}")
        print(f"  Combined Metric: {trial.value:.4f}")
        print(f"  Gating Accuracy: {trial.user_attrs['gating_accuracy']:.4f}")
        print(f"  Avg Expert Loss: {trial.user_attrs['avg_expert_loss']:.4f}")
        print(f"  Hyperparameters: {trial.params}")
    
    # Save top 5 MOE models
    for i, trial in enumerate(top_trials):
        moe_dir = f'MOE_{i+1}'
        os.makedirs(moe_dir, exist_ok=True)
        
        gating_model = trial.user_attrs['gating_model']
        experts = trial.user_attrs['experts']
        
        torch.save(gating_model.state_dict(), os.path.join(moe_dir, 'gating.pt'))
        for expert_id, expert in experts.items():
            torch.save(expert.state_dict(), os.path.join(moe_dir, f'expert_{expert_id}.pt'))
        
        metrics = {
            'gating_accuracy': trial.user_attrs['gating_accuracy'],
            'avg_expert_loss': trial.user_attrs['avg_expert_loss'],
            'combined_metric': trial.value
        }
        with open(os.path.join(moe_dir, 'metrics.json'), 'w') as f:
            json.dump(metrics, f)
        with open(os.path.join(moe_dir, 'hyperparams.json'), 'w') as f:
            json.dump(trial.params, f)
    
    print("\nTop 5 MOE models saved.")

[I 2025-04-15 20:54:11,575] A new study created in memory with name: no-name-dca24d30-514f-4b88-a086-0900ac680d04
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Gating Epoch 1, Train Loss: 1.1737
Gating Epoch 2, Train Loss: 1.0351
Gating Epoch 3, Train Loss: 1.0402
Gating Epoch 4, Train Loss: 1.0824
Gating Epoch 5, Train Loss: 1.2443
Expert 0 Epoch 1, Train Loss: 8.6673
Expert 0 Epoch 2, Train Loss: 7.5167
Expert 0 Epoch 3, Train Loss: 6.6618
Expert 0 Epoch 4, Train Loss: 6.6367
Expert 0 Epoch 5, Train Loss: 6.2782
Expert 1 Epoch 1, Train Loss: 9.0753
Expert 1 Epoch 2, Train Loss: 8.4852
Expert 1 Epoch 3, Train Loss: 8.7325
Expert 1 Epoch 4, Train Loss: 7.8988
Expert 1 Epoch 5, Train Loss: 8.8625
Expert 2 Epoch 1, Train Loss: 8.5271
Expert 2 Epoch 2, Train Loss: 8.5603
Expert 2 Epoch 3, Train Loss: 8.0651
Expert 2 Epoch 4, Train Loss: 8.1677
Expert 2 Epoch 5, Train Loss: 8.6763


[I 2025-04-15 20:54:34,592] Trial 0 finished with value: -0.32766657670338944 and parameters: {'gating_lr': 1.101847007202224e-05, 'gating_epochs': 5, 'expert_lr': 8.98230572282742e-05, 'expert_epochs': 5}. Best is trial 0 with value: -0.32766657670338944.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Gating Epoch 1, Train Loss: 1.2531
Gating Epoch 2, Train Loss: 0.9784
Gating Epoch 3, Train Loss: 1.0459
Gating Epoch 4, Train Loss: 0.9803
Gating Epoch 5, Train Loss: 1.3713
Gating Epoch 6, Train Loss: 1.3111
Gating Epoch 7, Train Loss: 1.3390
Gating Epoch 8, Train Loss: 1.3813
Expert 0 Epoch 1, Train Loss: 9.4270
Expert 0 Epoch 2, Train Loss: 7.0293
Expert 0 Epoch 3, Train Loss: 6.6949
Expert 0 Epoch 4, Train Loss: 6.3036
Expert 0 Epoch 5, Train Loss: 6.1333
Expert 0 Epoch 6, Train Loss: 6.0103
Expert 0 Epoch 7, Train Loss: 6.0199
Expert 0 Epoch 8, Train Loss: 5.8100
Expert 0 Epoch 9, Train Loss: 5.8444
Expert 0 Epoch 10, Train Loss: 5.7876
Expert 1 Epoch 1, Train Loss: 8.5045
Expert 1 Epoch 2, Train Loss: 7.9095
Expert 1 Epoch 3, Train Loss: 8.4441
Expert 1 Epoch 4, Train Loss: 8.7526
Expert 1 Epoch 5, Train Loss: 9.1628
Expert 1 Epoch 6, Train Loss: 8.4946
Expert 1 Epoch 7, Train Loss: 8.1799
Expert 1 Epoch 8, Train Loss: 8.7275
Expert 1 Epoch 9, Train Loss: 8.4489
Expert 1 Epoch 1

[I 2025-04-15 20:55:14,473] Trial 1 finished with value: -0.32766657670338944 and parameters: {'gating_lr': 0.00022329585625252254, 'gating_epochs': 8, 'expert_lr': 0.00011410160679719041, 'expert_epochs': 10}. Best is trial 0 with value: -0.32766657670338944.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Gating Epoch 1, Train Loss: 1.1637
Gating Epoch 2, Train Loss: 1.1620
Gating Epoch 3, Train Loss: 1.1692
Gating Epoch 4, Train Loss: 1.1596
Gating Epoch 5, Train Loss: 1.2054
Gating Epoch 6, Train Loss: 1.1765
Gating Epoch 7, Train Loss: 1.1189
Gating Epoch 8, Train Loss: 1.1145
Gating Epoch 9, Train Loss: 1.0725
Expert 0 Epoch 1, Train Loss: 8.5783
Expert 0 Epoch 2, Train Loss: 8.2485
Expert 0 Epoch 3, Train Loss: 7.4153
Expert 0 Epoch 4, Train Loss: 6.7987
Expert 0 Epoch 5, Train Loss: 6.6722
Expert 0 Epoch 6, Train Loss: 6.5171
Expert 1 Epoch 1, Train Loss: 8.5231
Expert 1 Epoch 2, Train Loss: 8.5096
Expert 1 Epoch 3, Train Loss: 8.3391
Expert 1 Epoch 4, Train Loss: 9.4782
Expert 1 Epoch 5, Train Loss: 7.9984
Expert 1 Epoch 6, Train Loss: 9.0284
Expert 2 Epoch 1, Train Loss: 8.4273
Expert 2 Epoch 2, Train Loss: 8.3770
Expert 2 Epoch 3, Train Loss: 7.6669
Expert 2 Epoch 4, Train Loss: 8.1926
Expert 2 Epoch 5, Train Loss: 8.0312
Expert 2 Epoch 6, Train Loss: 8.1981


[I 2025-04-15 20:55:42,867] Trial 2 finished with value: 0.2723334232966106 and parameters: {'gating_lr': 1.3373924091049982e-05, 'gating_epochs': 9, 'expert_lr': 5.378527196272705e-05, 'expert_epochs': 6}. Best is trial 2 with value: 0.2723334232966106.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Gating Epoch 1, Train Loss: 1.1618
Gating Epoch 2, Train Loss: 1.0874
Gating Epoch 3, Train Loss: 1.1540
Gating Epoch 4, Train Loss: 1.0243
Gating Epoch 5, Train Loss: 1.0117
Gating Epoch 6, Train Loss: 1.1699
Expert 0 Epoch 1, Train Loss: 8.4911
Expert 0 Epoch 2, Train Loss: 7.6335
Expert 0 Epoch 3, Train Loss: 7.9650
Expert 1 Epoch 1, Train Loss: 7.8268
Expert 1 Epoch 2, Train Loss: 8.3831
Expert 1 Epoch 3, Train Loss: 9.1865
Expert 2 Epoch 1, Train Loss: 8.8025
Expert 2 Epoch 2, Train Loss: 8.8494
Expert 2 Epoch 3, Train Loss: 8.8332


[I 2025-04-15 20:56:00,115] Trial 3 finished with value: 0.2723334232966106 and parameters: {'gating_lr': 2.0300149566087104e-05, 'gating_epochs': 6, 'expert_lr': 3.0255410006207283e-05, 'expert_epochs': 3}. Best is trial 2 with value: 0.2723334232966106.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Gating Epoch 1, Train Loss: 1.1567
Gating Epoch 2, Train Loss: 1.2704
Gating Epoch 3, Train Loss: 1.1117
Gating Epoch 4, Train Loss: 1.1872
Gating Epoch 5, Train Loss: 1.1143
Gating Epoch 6, Train Loss: 1.1386
Gating Epoch 7, Train Loss: 1.1457
Gating Epoch 8, Train Loss: 1.2420
Gating Epoch 9, Train Loss: 1.2019
Gating Epoch 10, Train Loss: 1.2455
Expert 0 Epoch 1, Train Loss: 9.0788
Expert 0 Epoch 2, Train Loss: 7.7817
Expert 0 Epoch 3, Train Loss: 7.1346
Expert 0 Epoch 4, Train Loss: 6.9690
Expert 0 Epoch 5, Train Loss: 6.8768
Expert 0 Epoch 6, Train Loss: 6.3883
Expert 1 Epoch 1, Train Loss: 8.5003
Expert 1 Epoch 2, Train Loss: 7.8517
Expert 1 Epoch 3, Train Loss: 7.9393
Expert 1 Epoch 4, Train Loss: 8.5147
Expert 1 Epoch 5, Train Loss: 8.4649
Expert 1 Epoch 6, Train Loss: 8.1259
Expert 2 Epoch 1, Train Loss: 9.0304
Expert 2 Epoch 2, Train Loss: 8.1939
Expert 2 Epoch 3, Train Loss: 8.5570
Expert 2 Epoch 4, Train Loss: 8.2146
Expert 2 Epoch 5, Train Loss: 8.3347
Expert 2 Epoch 6, Tr

[I 2025-04-15 20:56:29,842] Trial 4 finished with value: 0.2723334232966106 and parameters: {'gating_lr': 4.7122299629652036e-05, 'gating_epochs': 10, 'expert_lr': 5.4244736163590344e-05, 'expert_epochs': 6}. Best is trial 2 with value: 0.2723334232966106.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Gating Epoch 1, Train Loss: 1.1325
Gating Epoch 2, Train Loss: 1.1495
Gating Epoch 3, Train Loss: 1.0698
Gating Epoch 4, Train Loss: 1.2660
Gating Epoch 5, Train Loss: 1.2952
Expert 0 Epoch 1, Train Loss: 10.0311
Expert 0 Epoch 2, Train Loss: 5.9555
Expert 0 Epoch 3, Train Loss: 5.8558
Expert 0 Epoch 4, Train Loss: 5.4818
Expert 0 Epoch 5, Train Loss: 5.3929
Expert 0 Epoch 6, Train Loss: 5.2637
Expert 0 Epoch 7, Train Loss: 5.1762
Expert 0 Epoch 8, Train Loss: 5.1043
Expert 0 Epoch 9, Train Loss: 4.9578
Expert 1 Epoch 1, Train Loss: 8.5043
Expert 1 Epoch 2, Train Loss: 7.8339
Expert 1 Epoch 3, Train Loss: 8.9376
Expert 1 Epoch 4, Train Loss: 8.7681
Expert 1 Epoch 5, Train Loss: 8.3820
Expert 1 Epoch 6, Train Loss: 8.7119
Expert 1 Epoch 7, Train Loss: 8.6810
Expert 1 Epoch 8, Train Loss: 8.6591
Expert 1 Epoch 9, Train Loss: 9.0675
Expert 2 Epoch 1, Train Loss: 7.8978
Expert 2 Epoch 2, Train Loss: 7.6901
Expert 2 Epoch 3, Train Loss: 8.4966
Expert 2 Epoch 4, Train Loss: 8.3452
Expert 2 E

[I 2025-04-15 20:57:04,630] Trial 5 finished with value: -0.32766657670338944 and parameters: {'gating_lr': 0.0007685339203285668, 'gating_epochs': 5, 'expert_lr': 0.0008806969767732265, 'expert_epochs': 9}. Best is trial 2 with value: 0.2723334232966106.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Gating Epoch 1, Train Loss: 1.0796
Gating Epoch 2, Train Loss: 1.1052
Gating Epoch 3, Train Loss: 1.1449
Gating Epoch 4, Train Loss: 1.1308
Gating Epoch 5, Train Loss: 1.1830
Gating Epoch 6, Train Loss: 1.1104
Gating Epoch 7, Train Loss: 1.1247
Gating Epoch 8, Train Loss: 1.1197
Gating Epoch 9, Train Loss: 1.1501
Gating Epoch 10, Train Loss: 1.1547
Expert 0 Epoch 1, Train Loss: 8.3208
Expert 0 Epoch 2, Train Loss: 9.3641
Expert 0 Epoch 3, Train Loss: 8.2834
Expert 0 Epoch 4, Train Loss: 7.6718
Expert 0 Epoch 5, Train Loss: 7.2179
Expert 1 Epoch 1, Train Loss: 9.1766
Expert 1 Epoch 2, Train Loss: 8.4944
Expert 1 Epoch 3, Train Loss: 7.9644
Expert 1 Epoch 4, Train Loss: 9.0719
Expert 1 Epoch 5, Train Loss: 8.9040
Expert 2 Epoch 1, Train Loss: 8.6533
Expert 2 Epoch 2, Train Loss: 8.9811
Expert 2 Epoch 3, Train Loss: 8.0847
Expert 2 Epoch 4, Train Loss: 8.8539
Expert 2 Epoch 5, Train Loss: 8.9097


[I 2025-04-15 20:57:30,843] Trial 6 finished with value: -0.32766657670338944 and parameters: {'gating_lr': 7.035190432362874e-05, 'gating_epochs': 10, 'expert_lr': 2.2483756650338913e-05, 'expert_epochs': 5}. Best is trial 2 with value: 0.2723334232966106.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Gating Epoch 1, Train Loss: 1.2174
Gating Epoch 2, Train Loss: 1.1029
Gating Epoch 3, Train Loss: 1.1444
Gating Epoch 4, Train Loss: 1.1685
Gating Epoch 5, Train Loss: 1.0776
Gating Epoch 6, Train Loss: 1.3123
Gating Epoch 7, Train Loss: 1.0959
Gating Epoch 8, Train Loss: 1.0957
Gating Epoch 9, Train Loss: 1.0967
Expert 0 Epoch 1, Train Loss: 9.0240
Expert 0 Epoch 2, Train Loss: 8.3266
Expert 0 Epoch 3, Train Loss: 7.7788
Expert 0 Epoch 4, Train Loss: 7.7969
Expert 0 Epoch 5, Train Loss: 8.0650
Expert 0 Epoch 6, Train Loss: 6.7992
Expert 0 Epoch 7, Train Loss: 7.4439
Expert 0 Epoch 8, Train Loss: 6.5733
Expert 1 Epoch 1, Train Loss: 9.0224
Expert 1 Epoch 2, Train Loss: 8.8159
Expert 1 Epoch 3, Train Loss: 8.2515
Expert 1 Epoch 4, Train Loss: 8.9365
Expert 1 Epoch 5, Train Loss: 8.3065
Expert 1 Epoch 6, Train Loss: 8.3160
Expert 1 Epoch 7, Train Loss: 7.9756
Expert 1 Epoch 8, Train Loss: 9.0135
Expert 2 Epoch 1, Train Loss: 8.7916
Expert 2 Epoch 2, Train Loss: 9.3881
Expert 2 Epoch 3, T

[I 2025-04-15 20:58:05,865] Trial 7 finished with value: -0.32766657670338944 and parameters: {'gating_lr': 5.973444157027659e-05, 'gating_epochs': 9, 'expert_lr': 3.064725374736926e-05, 'expert_epochs': 8}. Best is trial 2 with value: 0.2723334232966106.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Gating Epoch 1, Train Loss: 1.2590
Gating Epoch 2, Train Loss: 1.0353
Gating Epoch 3, Train Loss: 1.2603
Gating Epoch 4, Train Loss: 1.2102
Gating Epoch 5, Train Loss: 1.1734
Gating Epoch 6, Train Loss: 1.0104
Gating Epoch 7, Train Loss: 1.2556
Gating Epoch 8, Train Loss: 1.2266
Gating Epoch 9, Train Loss: 1.2191
Gating Epoch 10, Train Loss: 1.1154
Expert 0 Epoch 1, Train Loss: 8.3001
Expert 0 Epoch 2, Train Loss: 6.0499
Expert 0 Epoch 3, Train Loss: 5.7407
Expert 0 Epoch 4, Train Loss: 5.5405
Expert 0 Epoch 5, Train Loss: 5.3863
Expert 0 Epoch 6, Train Loss: 5.2240
Expert 0 Epoch 7, Train Loss: 5.1692
Expert 1 Epoch 1, Train Loss: 8.4841
Expert 1 Epoch 2, Train Loss: 8.0414
Expert 1 Epoch 3, Train Loss: 8.6922
Expert 1 Epoch 4, Train Loss: 8.2716
Expert 1 Epoch 5, Train Loss: 9.1645
Expert 1 Epoch 6, Train Loss: 8.8625
Expert 1 Epoch 7, Train Loss: 8.4552
Expert 2 Epoch 1, Train Loss: 8.0035
Expert 2 Epoch 2, Train Loss: 8.2022
Expert 2 Epoch 3, Train Loss: 7.9322
Expert 2 Epoch 4, Tr

[I 2025-04-15 20:58:38,997] Trial 8 finished with value: -0.32766657670338944 and parameters: {'gating_lr': 1.5906958698915196e-05, 'gating_epochs': 10, 'expert_lr': 0.0009534210073054425, 'expert_epochs': 7}. Best is trial 2 with value: 0.2723334232966106.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Gating Epoch 1, Train Loss: 1.1080
Gating Epoch 2, Train Loss: 1.0907
Gating Epoch 3, Train Loss: 1.1378
Gating Epoch 4, Train Loss: 1.0137
Gating Epoch 5, Train Loss: 1.0308
Gating Epoch 6, Train Loss: 1.1355
Expert 0 Epoch 1, Train Loss: 8.3170
Expert 0 Epoch 2, Train Loss: 8.5138
Expert 0 Epoch 3, Train Loss: 8.1230
Expert 0 Epoch 4, Train Loss: 7.8453
Expert 0 Epoch 5, Train Loss: 8.4285
Expert 0 Epoch 6, Train Loss: 7.6027
Expert 0 Epoch 7, Train Loss: 7.9420
Expert 0 Epoch 8, Train Loss: 7.3979
Expert 1 Epoch 1, Train Loss: 8.0707
Expert 1 Epoch 2, Train Loss: 9.2668
Expert 1 Epoch 3, Train Loss: 8.2195
Expert 1 Epoch 4, Train Loss: 8.4675
Expert 1 Epoch 5, Train Loss: 9.0194
Expert 1 Epoch 6, Train Loss: 8.0658
Expert 1 Epoch 7, Train Loss: 8.8575
Expert 1 Epoch 8, Train Loss: 8.7023
Expert 2 Epoch 1, Train Loss: 8.5094
Expert 2 Epoch 2, Train Loss: 8.3339
Expert 2 Epoch 3, Train Loss: 9.1252
Expert 2 Epoch 4, Train Loss: 8.1975
Expert 2 Epoch 5, Train Loss: 8.1516
Expert 2 Epoc

[I 2025-04-15 20:59:12,842] Trial 9 finished with value: -0.32766657670338944 and parameters: {'gating_lr': 1.2369144642351554e-05, 'gating_epochs': 6, 'expert_lr': 1.2989445609013305e-05, 'expert_epochs': 8}. Best is trial 2 with value: 0.2723334232966106.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Gating Epoch 1, Train Loss: 1.2002
Gating Epoch 2, Train Loss: 1.1283
Gating Epoch 3, Train Loss: 1.2340
Gating Epoch 4, Train Loss: 1.1191
Gating Epoch 5, Train Loss: 1.1778
Gating Epoch 6, Train Loss: 1.0604
Gating Epoch 7, Train Loss: 1.1183
Gating Epoch 8, Train Loss: 1.0758
Expert 0 Epoch 1, Train Loss: 8.1719
Expert 0 Epoch 2, Train Loss: 6.3017
Expert 0 Epoch 3, Train Loss: 6.0650
Expert 1 Epoch 1, Train Loss: 7.7945
Expert 1 Epoch 2, Train Loss: 8.3396
Expert 1 Epoch 3, Train Loss: 8.3464
Expert 2 Epoch 1, Train Loss: 8.5369
Expert 2 Epoch 2, Train Loss: 8.8786
Expert 2 Epoch 3, Train Loss: 9.3162
