<a href="https://colab.research.google.com/github/FlewRr/moe/blob/main/exps_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from typing import Optional
from transformers import BertPreTrainedModel, BertModel
from transformers.models.bert.modeling_bert import (
    BertLayer,
    BertOutput,
    BertLMPredictionHead,
)


class MoEFFN(nn.Module):
    """Sparse mixture-of-experts feed-forward block with top-k token routing.

    Every token is scored by a bias-free linear gate. The ``k`` best-scoring
    experts process the token, and their outputs are summed using weights
    obtained from a softmax over those ``k`` gate logits only.

    Args:
        hidden_size: Size of the last dimension of the input and the output.
        num_experts: Number of expert MLPs.
        expert_size: Inner width of each expert; ``4 * hidden_size`` is used
            when this is ``None`` (or 0).
        k: Number of experts evaluated per token.
        dropout_prob: Dropout probability applied inside each expert, after
            the GELU activation.

    Raises:
        ValueError: If ``k`` is not in ``[1, num_experts]``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_experts: int = 4,
        expert_size: Optional[int] = None,
        k: int = 2,
        dropout_prob: float = 0.1,
    ):
        super().__init__()
        # Fail at construction time instead of inside torch.topk on the
        # first forward pass.
        if not 1 <= k <= num_experts:
            raise ValueError(
                f"k must be in [1, num_experts={num_experts}], got {k}"
            )

        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.k = k
        self.expert_size = expert_size or hidden_size * 4

        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, self.expert_size),
                nn.GELU(),
                nn.Dropout(dropout_prob),
                nn.Linear(self.expert_size, hidden_size),
            )
            for _ in range(num_experts)
        ])

        self.gate = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, hidden_states):
        """Route every token through its top-k experts.

        Args:
            hidden_states: Tensor whose last dimension equals ``hidden_size``;
                any number of leading dimensions is accepted.

        Returns:
            Tensor with the same shape as ``hidden_states``.

        Raises:
            ValueError: If the last dimension is not ``hidden_size``.
        """
        original_shape = hidden_states.shape
        hidden_dim = original_shape[-1]
        if hidden_dim != self.hidden_size:
            raise ValueError(
                f"Expected last dimension {self.hidden_size}, got {hidden_dim}"
            )

        # reshape (not view) so non-contiguous inputs are accepted too.
        x = hidden_states.reshape(-1, hidden_dim)
        gate_logits = self.gate(x)
        top_k_logits, top_k_indices = torch.topk(gate_logits, self.k, dim=1)
        # Softmax over the selected logits only, so each token's k weights
        # sum to one.
        top_k_weights = torch.softmax(top_k_logits, dim=1)

        final_output = torch.zeros_like(x)

        for expert_id, expert in enumerate(self.experts):
            # Row index = token, column index = slot inside that token's top-k.
            token_indices, pos_in_topk = (
                top_k_indices == expert_id
            ).nonzero(as_tuple=True)
            if token_indices.numel() == 0:
                continue  # no token was routed to this expert

            expert_out = expert(x[token_indices])
            expert_weights = top_k_weights[token_indices, pos_in_topk]
            weighted_out = expert_out * expert_weights.unsqueeze(-1)

            # index_add_ requires matching dtypes; cast in case the expert
            # output dtype differs from the accumulator's.
            final_output.index_add_(
                0, token_indices, weighted_out.to(final_output.dtype)
            )

        return final_output.reshape(original_shape)


from transformers.models.bert.modeling_bert import BertLayer
import torch.nn as nn

class BertLayerWithMoE(BertLayer):
    """BertLayer whose dense feed-forward sub-block is replaced by an MoEFFN.

    The parent's ``intermediate`` and ``output`` modules are removed after
    construction. The layer keeps the parent's attention module and adds its
    own dropout, residual connection and LayerNorm around the MoE output.
    """

    def __init__(self, config):
        super().__init__(config)
        # Drop the dense FFN built by the parent; the MoE block replaces it.
        del self.intermediate, self.output

        moe_settings = {
            "hidden_size": config.hidden_size,
            # Optional config fields: fall back to 4 experts, top-2 routing.
            "num_experts": getattr(config, "num_experts", 4),
            "expert_size": config.intermediate_size,
            "k": getattr(config, "moe_k", 2),
            "dropout_prob": config.hidden_dropout_prob,
        }
        self.moe_ffn = MoEFFN(**moe_settings)

        self.moe_output_layer_norm = nn.LayerNorm(
            config.hidden_size,
            eps=config.layer_norm_eps,
        )
        self.moe_output_dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        **kwargs,
    ):
        """Apply self-attention, then the MoE feed-forward with a residual.

        ``encoder_hidden_states``, ``encoder_attention_mask`` and any extra
        keyword arguments are accepted but not used.

        Returns:
            Tuple whose first element is the layer output, followed by every
            element after the first that the attention module returned.
        """
        attention_outputs = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        attended = attention_outputs[0]

        # Dropout on the expert mixture, then residual + LayerNorm.
        expert_mix = self.moe_output_dropout(self.moe_ffn(attended))
        normalized = self.moe_output_layer_norm(attended + expert_mix)

        return (normalized,) + attention_outputs[1:]


class BertMoEForMaskedLM(BertPreTrainedModel):
    """BERT encoder built from ``BertLayerWithMoE`` layers plus an LM head.

    The encoder is a ``BertModel`` without a pooling layer whose transformer
    layers are MoE layers; predictions come from a ``BertLMPredictionHead``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.bert = BertModel(config, add_pooling_layer=False)
        # Build the MoE layers directly instead of reassigning __class__ on
        # the stock layers and re-running __init__ on them; one MoE layer
        # per original encoder layer.
        self.bert.encoder.layer = nn.ModuleList(
            [BertLayerWithMoE(config) for _ in self.bert.encoder.layer]
        )

        self.cls = BertLMPredictionHead(config)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        **kwargs,
    ):
        """Run the encoder and the LM head; compute the MLM loss if asked.

        Args:
            labels: Optional target token ids. When given, a cross-entropy
                loss over the vocabulary is computed.
            return_dict: Accepted for interface compatibility. This method
                always returns a dict, so the value is not forwarded.

        Returns:
            Dict with keys ``"loss"`` (``None`` without labels), ``"logits"``,
            ``"hidden_states"`` and ``"attentions"``.
        """
        # Forward only the arguments the caller actually supplied.
        encoder_inputs = {
            name: value for name, value in {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
                "position_ids": position_ids,
                "head_mask": head_mask,
                "inputs_embeds": inputs_embeds,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
            }.items() if value is not None
        }
        # The encoder result is read by attribute below, so it must be a
        # ModelOutput. Forwarding return_dict=False made the encoder return
        # a plain tuple and crashed on `.last_hidden_state`.
        encoder_inputs["return_dict"] = True

        outputs = self.bert(**encoder_inputs)

        sequence_output = outputs.last_hidden_state
        prediction_scores = self.cls(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1)
            )

        return {
            "loss": loss,
            "logits": prediction_scores,
            "hidden_states": outputs.hidden_states,
            "attentions": outputs.attentions,
        }

In [None]:
class PretrainConfig:
    """Plain namespace of settings for the masked-LM pretraining run."""

    # --- run / data identification ---
    model_name = "your-moe-bert"
    output_dir = "."
    dataset_name = "wikimedia/wikipedia"
    dataset_config = "20231101.en"
    text_column = "text"
    tokenizer = "bert-base-uncased"

    # --- batching and masking ---
    seq_len = 128
    batch_size = 32
    masking_prob = 0.15

    # --- optimisation ---
    lr = 5e-5
    weight_decay = 0.01
    warmup_steps = 1000
    max_steps = 10_000

    # --- bookkeeping intervals, in optimizer steps ---
    logging_steps = 100
    eval_steps = 2000
    save_steps = 5_000

    # --- model architecture ---
    bert_hidden_size = 256
    bert_intermediate_size = 1024
    bert_num_hidden_layers = 4
    bert_num_attention_heads = 4
    num_experts = 4

In [None]:
import os
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers import BertConfig


def main():
    """Pretrain the MoE-BERT masked language model on a streamed dataset.

    Builds the model from ``PretrainConfig``, tokenizes the streamed text
    column with truncation, drops examples shorter than half of ``seq_len``,
    trains with ``Trainer`` for ``max_steps`` steps, and saves the model and
    tokenizer to ``<output_dir>/final_model``.
    """
    set_seed(42)
    cfg = PretrainConfig()

    # Load the tokenizer first so the embedding table size and padding id
    # come from it instead of being hard-coded.
    tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer)

    model_config = BertConfig(
        vocab_size=len(tokenizer),
        hidden_size=cfg.bert_hidden_size,
        num_hidden_layers=cfg.bert_num_hidden_layers,
        num_attention_heads=cfg.bert_num_attention_heads,
        intermediate_size=cfg.bert_intermediate_size,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        pad_token_id=tokenizer.pad_token_id,
        num_experts=cfg.num_experts,
        moe_k=2,
    )

    model = BertMoEForMaskedLM(model_config)

    print(f"Model initialized with {cfg.num_experts} experts.")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    print("Loading dataset in streaming mode...")
    dataset = load_dataset(
        cfg.dataset_name,
        cfg.dataset_config,
        split="train",
        streaming=True
    )

    def tokenize_function(examples):
        # No padding here: the collator pads each batch dynamically.
        return tokenizer(
            examples[cfg.text_column],
            truncation=True,
            padding=False,
            max_length=cfg.seq_len,
            return_special_tokens_mask=True,
        )

    original_columns = dataset.column_names
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=original_columns,
    )

    def filter_short(example):
        # Keep only examples with at least half of the maximum length.
        return len(example["input_ids"]) >= cfg.seq_len // 2

    tokenized_dataset = tokenized_dataset.filter(filter_short)

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=cfg.masking_prob,
    )

    training_args = TrainingArguments(
        output_dir=cfg.output_dir,
        overwrite_output_dir=True,
        max_steps=cfg.max_steps,
        per_device_train_batch_size=cfg.batch_size,
        gradient_accumulation_steps=1,
        learning_rate=cfg.lr,
        weight_decay=cfg.weight_decay,
        warmup_steps=cfg.warmup_steps,
        logging_steps=cfg.logging_steps,
        save_steps=cfg.save_steps,
        save_strategy="steps",
        load_best_model_at_end=False,
        # Mixed precision only when a GPU is present, so the script also
        # runs on CPU-only machines.
        fp16=torch.cuda.is_available(),
        dataloader_num_workers=2,
        remove_unused_columns=False,
        report_to="none",
        dataloader_drop_last=True,
        save_safetensors=False
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
        # `tokenizer=` is deprecated in Trainer; `processing_class` is its
        # replacement.
        processing_class=tokenizer,
    )

    print("Starting pretraining (streaming)...")
    trainer.train()

    final_dir = os.path.join(cfg.output_dir, "final_model")
    trainer.save_model(final_dir)
    tokenizer.save_pretrained(final_dir)
    print(f"Model saved to {final_dir}")


if __name__ == "__main__":
    main()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Model initialized with 4 experts.
Total parameters: 25,326,138
Loading dataset in streaming mode...


README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

  trainer = Trainer(


Starting pretraining (streaming)...


Step,Training Loss
100,10.2965
200,9.9951
300,9.5963
400,9.1732
500,8.7537
600,8.2938
700,7.8631
800,7.5573
900,7.3662
1000,7.3379


Model saved to ./final_model


In [None]:
!tar -xvf /content/moe_bert_checkpoints_and_model.tar

final_model/
final_model/pytorch_model.bin
final_model/training_args.bin
final_model/config.json
final_model/tokenizer.json
final_model/tokenizer_config.json
final_model/special_tokens_map.json
final_model/vocab.txt


In [None]:
import torch
from transformers import AutoTokenizer

model_dir = "/content/final_model"

# Load the tokenizer and the pretrained MoE masked-LM from the model directory.
tokenizer = AutoTokenizer.from_pretrained(model_dir)

model = BertMoEForMaskedLM.from_pretrained(model_dir)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

text = "Paris is the [MASK] of France."

inputs = tokenizer(
    text,
    return_tensors="pt"
)

inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]

mask_token_id = tokenizer.mask_token_id
# Each row of mask_positions is (batch index, sequence position).
mask_positions = (inputs["input_ids"] == mask_token_id).nonzero(as_tuple=False)
if mask_positions.numel() == 0:
    # Fail with a clear message instead of an IndexError on the next line.
    raise ValueError(
        f"Input text contains no {tokenizer.mask_token} token: {text!r}"
    )

# Only the first masked position is inspected.
batch_idx, mask_pos = mask_positions[0].tolist()

mask_logits = logits[batch_idx, mask_pos, :]
top_k = torch.topk(mask_logits, k=10)
top_ids = top_k.indices.tolist()
top_scores = top_k.values.tolist()

print("Input:", text)
print("Top predictions for [MASK]:")
for token_id, score in zip(top_ids, top_scores):
    token = tokenizer.decode([token_id])
    print(f"{token!r}  logit={score:.3f}")


Input: Paris is the [MASK] of France.
Top predictions for [MASK]:
'capital'  logit=9.794
'department'  logit=9.573
'arrondissement'  logit=9.051
'list'  logit=8.876
'commune'  logit=8.847
'republic'  logit=8.844
'name'  logit=8.790
'kingdom'  logit=8.455
'communes'  logit=8.389
'chateau'  logit=8.104


## Классификатор


In [None]:
import torch
import torch.nn as nn
from transformers import BertConfig
from transformers.modeling_outputs import SequenceClassifierOutput


class BertMoEForMultiLabelClassification(BertMoEForMaskedLM):
    """MoE-BERT encoder with a linear multi-label classification head.

    Subclasses ``BertMoEForMaskedLM``, so the encoder is the ``self.bert``
    module built by the parent. Classification uses the hidden state of the
    first token, followed by dropout and a linear layer with
    ``config.num_labels`` outputs.
    """

    def __init__(self, config: BertConfig):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.config.problem_type = "multi_label_classification"

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        **kwargs,
    ):
        """Encode the input and classify from the first token's hidden state.

        Args:
            labels: Optional multi-hot targets, cast to float for the loss.
                When given, a ``BCEWithLogitsLoss`` is computed.
            return_dict: Accepted for interface compatibility. A
                ``SequenceClassifierOutput`` is always returned, so the value
                is not forwarded to the encoder.

        Returns:
            ``SequenceClassifierOutput`` with ``loss`` (``None`` without
            labels), ``logits``, ``hidden_states`` and ``attentions``.
        """
        # Forward only the arguments the caller actually supplied.
        encoder_inputs = {
            name: value for name, value in {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
                "position_ids": position_ids,
                "head_mask": head_mask,
                "inputs_embeds": inputs_embeds,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
            }.items() if value is not None
        }
        # The encoder result is read by attribute below, so it must be a
        # ModelOutput. Forwarding return_dict=False produced a tuple and
        # crashed on `.last_hidden_state`.
        encoder_inputs["return_dict"] = True

        outputs = self.bert(**encoder_inputs)
        sequence_output = outputs.last_hidden_state

        # First position of every sequence.
        cls_output = sequence_output[:, 0, :]
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)

        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels.float())

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [None]:
# Names of the six binary target columns, in the order the labels are emitted.
LABEL_COLUMNS = [
    "Computer Science", "Physics", "Mathematics",
    "Statistics", "Quantitative Biology", "Quantitative Finance",
]

In [None]:
from dataclasses import dataclass, field
from typing import List


@dataclass
class MultiLabelConfig:
    """Settings for fine-tuning and running the multi-label classifier."""

    # Directory holding the pretrained masked-LM checkpoint.
    pretrained_mlm_path: str = "/content/final_model"

    train_csv: str = "/content/data/train.csv"
    # Previously this defaulted to the training CSV, so "test" predictions
    # were made on training rows.
    test_csv: str = "/content/data/test.csv"

    # Text columns that are tokenized as a sentence pair.
    title_column: str = "TITLE"
    abstract_column: str = "ABSTRACT"
    # Binary target columns; a fresh list is created for every instance.
    label_columns: List[str] = field(default_factory=lambda: [
        "Computer Science",
        "Physics",
        "Mathematics",
        "Statistics",
        "Quantitative Biology",
        "Quantitative Finance",
    ])

    tokenizer_name: str = "bert-base-uncased"

    # Tokenization and training hyper-parameters.
    max_length: int = 256
    train_batch_size: int = 16
    eval_batch_size: int = 16
    num_train_epochs: int = 3
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    logging_steps: int = 50
    save_steps: int = 500
    output_dir: str = "./moe_multilabel"


In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer

cfg = MultiLabelConfig()
tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_name, use_fast=True)

# Load the training CSV as a single "train" split.
raw_dataset = load_dataset("csv", data_files={"train": cfg.train_csv})["train"]


def preprocess_function(examples):
    """Tokenize title/abstract pairs and attach one label row per example.

    ``examples`` is a batch mapping column name -> list of values. The result
    is the tokenizer output with an extra ``"labels"`` entry holding, for each
    example, its values from ``cfg.label_columns`` in that order.
    """
    encoded = tokenizer(
        examples[cfg.title_column],
        examples[cfg.abstract_column],
        truncation=True,
        max_length=cfg.max_length,
        padding=False,
    )

    batch_len = len(examples[cfg.title_column])
    encoded["labels"] = [
        [examples[column][row] for column in cfg.label_columns]
        for row in range(batch_len)
    ]
    return encoded


# Replace the raw columns with the tokenized features and labels.
processed_dataset = raw_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=raw_dataset.column_names,
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/20972 [00:00<?, ? examples/s]

In [None]:
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    set_seed,
)
from sklearn.metrics import precision_recall_fscore_support
from transformers import EvalPrediction
import numpy as np

def build_compute_metrics_fn(threshold: float = 0.5):
    """Return a metrics callback that thresholds sigmoid probabilities.

    The callback unpacks ``(logits, labels)``, applies a sigmoid to the
    logits, marks a label as predicted when its probability is at least
    ``threshold``, and reports micro- and macro-averaged precision, recall
    and F1.
    """

    def compute_metrics(eval_pred: EvalPrediction):
        logits, labels = eval_pred
        probabilities = 1 / (1 + np.exp(-logits))

        predicted = (probabilities >= threshold).astype(int)
        expected = labels.astype(int)

        scores = {}
        for averaging in ("micro", "macro"):
            precision, recall, f1, _ = precision_recall_fscore_support(
                expected, predicted, average=averaging, zero_division=0
            )
            scores[f"precision_{averaging}"] = precision
            scores[f"recall_{averaging}"] = recall
            scores[f"f1_{averaging}"] = f1

        return scores

    return compute_metrics

def main():
    """Fine-tune the pretrained MoE-BERT as a multi-label classifier.

    Loads the training CSV, holds out 10% for evaluation, initializes the
    classifier from the masked-LM checkpoint, trains with ``Trainer``, and
    saves the model and tokenizer to ``<output_dir>/final_model``.
    """
    set_seed(42)
    cfg = MultiLabelConfig()

    tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_name, use_fast=True)

    raw_dataset = load_dataset("csv", data_files={"train": cfg.train_csv})["train"]

    dataset_splits = raw_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset = dataset_splits["train"]
    eval_dataset = dataset_splits["test"]

    # Lower-case local name so it does not shadow a module-level constant.
    label_columns = cfg.label_columns
    print("–ö–æ–ª–æ–Ω–∫–∏:", raw_dataset.column_names)
    print("Label-–∫–æ–ª–æ–Ω–∫–∏:", label_columns)

    base_config = AutoConfig.from_pretrained(cfg.pretrained_mlm_path)
    base_config.num_labels = len(label_columns)
    base_config.problem_type = "multi_label_classification"

    model = BertMoEForMultiLabelClassification.from_pretrained(
        cfg.pretrained_mlm_path,
        config=base_config,
        ignore_mismatched_sizes=True,
    )

    def preprocess_function(examples):
        # Title and abstract are encoded together as a sentence pair.
        tokenized = tokenizer(
            examples[cfg.title_column],
            examples[cfg.abstract_column],
            truncation=True,
            max_length=cfg.max_length,
            padding=False,
        )

        batch_len = len(examples[cfg.title_column])
        tokenized["labels"] = [
            [examples[col][i] for col in label_columns]
            for i in range(batch_len)
        ]
        return tokenized

    train_dataset = train_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_dataset.column_names,
    )
    eval_dataset = eval_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_dataset.column_names,
    )

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    training_args = TrainingArguments(
        output_dir=cfg.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=cfg.num_train_epochs,
        per_device_train_batch_size=cfg.train_batch_size,
        per_device_eval_batch_size=cfg.eval_batch_size,
        learning_rate=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
        logging_steps=cfg.logging_steps,
        save_steps=cfg.save_steps,
        eval_strategy="steps",
        # Evaluate at the same interval as checkpointing.
        eval_steps=cfg.save_steps,
        save_total_limit=2,
        load_best_model_at_end=True,
        fp16=torch.cuda.is_available(),
        report_to="none",
        save_safetensors=False
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        # `tokenizer=` is deprecated in Trainer; `processing_class` is its
        # replacement.
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=build_compute_metrics_fn(threshold=0.5)
    )

    trainer.train()

    final_dir = os.path.join(cfg.output_dir, "final_model")
    trainer.save_model(final_dir)
    tokenizer.save_pretrained(final_dir)
    print(f"Model saved to {final_dir}")


if __name__ == "__main__":
    main()


Some weights of BertMoEForMultiLabelClassification were not initialized from the model checkpoint at /content/final_model and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


–ö–æ–ª–æ–Ω–∫–∏: ['ID', 'TITLE', 'ABSTRACT', 'Computer Science', 'Physics', 'Mathematics', 'Statistics', 'Quantitative Biology', 'Quantitative Finance']
Label-–∫–æ–ª–æ–Ω–∫–∏: ['Computer Science', 'Physics', 'Mathematics', 'Statistics', 'Quantitative Biology', 'Quantitative Finance']


Map:   0%|          | 0/2098 [00:00<?, ? examples/s]

  trainer = Trainer(


Step,Training Loss,Validation Loss,Precision Micro,Recall Micro,F1 Micro,Precision Macro,Recall Macro,F1 Macro
500,0.2844,0.272263,0.765871,0.719058,0.741727,0.50783,0.481146,0.493374
1000,0.2467,0.2428,0.820844,0.723614,0.769169,0.550729,0.47767,0.506334
1500,0.2316,0.226443,0.829526,0.744495,0.784714,0.54996,0.503301,0.525272
2000,0.2109,0.218956,0.831807,0.758542,0.793487,0.555408,0.507483,0.52876
2500,0.2146,0.214082,0.813486,0.783219,0.798066,0.628465,0.547177,0.567458


Step,Training Loss,Validation Loss,Precision Micro,Recall Micro,F1 Micro,Precision Macro,Recall Macro,F1 Macro
500,0.2844,0.272263,0.765871,0.719058,0.741727,0.50783,0.481146,0.493374
1000,0.2467,0.2428,0.820844,0.723614,0.769169,0.550729,0.47767,0.506334
1500,0.2316,0.226443,0.829526,0.744495,0.784714,0.54996,0.503301,0.525272
2000,0.2109,0.218956,0.831807,0.758542,0.793487,0.555408,0.507483,0.52876
2500,0.2146,0.214082,0.813486,0.783219,0.798066,0.628465,0.547177,0.567458
3000,0.2053,0.211894,0.823386,0.769932,0.795762,0.630308,0.542726,0.567393
3500,0.2097,0.209608,0.827141,0.777525,0.801566,0.634389,0.547533,0.573763


Model saved to ./moe_multilabel/final_model


In [None]:
import os
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoConfig


def load_model_for_inference(model_dir: str, cfg: MultiLabelConfig):
    """Load config, tokenizer and classifier from ``model_dir`` for inference.

    The model is moved to CUDA when available (CPU otherwise) and switched to
    eval mode. ``cfg`` is accepted but not read.

    Returns:
        Tuple ``(model, tokenizer, device)``.
    """
    model_config = AutoConfig.from_pretrained(model_dir)
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    classifier = BertMoEForMultiLabelClassification.from_pretrained(
        model_dir,
        config=model_config,
    )

    use_cuda = torch.cuda.is_available()
    target_device = torch.device("cuda" if use_cuda else "cpu")
    classifier.to(target_device)
    classifier.eval()

    return classifier, loaded_tokenizer, target_device


def predict_on_test_csv(
    model,
    tokenizer,
    device,
    cfg: MultiLabelConfig,
    threshold: float = 0.5,
    batch_size: int = 32,
    output_path: str = "sample_submission.csv",
):
    """Predict labels for every row of ``cfg.test_csv`` and write a CSV.

    Rows are tokenized as (title, abstract) pairs in batches of
    ``batch_size``. A label is predicted when its sigmoid probability is at
    least ``threshold``. The output file has an ``ID`` column followed by one
    0/1 column per entry of ``cfg.label_columns``.

    Args:
        model: Classifier called as ``model(**encoded)``; its output must
            support ``["logits"]``.
        tokenizer: Tokenizer used to encode the text pairs.
        device: Device the encoded tensors are moved to.
        cfg: Provides the CSV path, column names and ``max_length``.
        threshold: Probability cut-off for a positive prediction.
        batch_size: Number of rows per forward pass.
        output_path: Destination CSV path.
    """
    test_df = pd.read_csv(cfg.test_csv)
    print("–¢–µ—Å—Ç–æ–≤—ã–π –¥–∞—Ç–∞—Å–µ—Ç:", test_df.shape)
    print("–ö–æ–ª–æ–Ω–∫–∏ –≤ —Ç–µ—Å—Ç–µ:", list(test_df.columns))

    id_col = "ID"

    all_preds = []

    num_labels = len(cfg.label_columns)
    n = len(test_df)

    # Missing text cells come back from pandas as float NaN, which the
    # tokenizer rejects; treat them as empty strings.
    titles = test_df[cfg.title_column].fillna("").astype(str).tolist()
    abstracts = test_df[cfg.abstract_column].fillna("").astype(str).tolist()

    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)

        encoded = tokenizer(
            titles[start:end],
            abstracts[start:end],
            truncation=True,
            max_length=cfg.max_length,
            padding=True,
            return_tensors="pt",
        )

        encoded = {k: v.to(device) for k, v in encoded.items()}

        with torch.no_grad():
            outputs = model(**encoded)
            logits = outputs["logits"]
            probs = torch.sigmoid(logits).cpu().numpy()

        preds = (probs >= threshold).astype(int)
        all_preds.append(preds)

    if all_preds:
        all_preds = np.vstack(all_preds)
    else:
        # np.vstack raises on an empty list; an empty CSV yields an empty
        # but well-formed submission instead.
        all_preds = np.zeros((0, num_labels), dtype=int)
    assert all_preds.shape == (len(test_df), num_labels)

    submission_df = pd.DataFrame(
        all_preds,
        columns=cfg.label_columns,
    )
    submission_df.insert(0, id_col, test_df[id_col].values)

    submission_df.to_csv(output_path, index=False)
    print(f"–°–æ—Ö—Ä–∞–Ω–µ–Ω–æ –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–µ –≤ {output_path}")
    print(submission_df.head())


if __name__ == "__main__":
    # Score the test CSV with a saved fine-tuning checkpoint.
    cfg = MultiLabelConfig()
    model_dir = "/content/moe_multilabel/checkpoint-3500"

    model, tokenizer, device = load_model_for_inference(model_dir, cfg)

    predict_on_test_csv(
        model,
        tokenizer,
        device,
        cfg,
        threshold=0.5,
        batch_size=32,
        output_path="sample_submission.csv",
    )


–¢–µ—Å—Ç–æ–≤—ã–π –¥–∞—Ç–∞—Å–µ—Ç: (20972, 9)
–ö–æ–ª–æ–Ω–∫–∏ –≤ —Ç–µ—Å—Ç–µ: ['ID', 'TITLE', 'ABSTRACT', 'Computer Science', 'Physics', 'Mathematics', 'Statistics', 'Quantitative Biology', 'Quantitative Finance']
–°–æ—Ö—Ä–∞–Ω–µ–Ω–æ –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–µ –≤ sample_submission.csv
   ID  Computer Science  Physics  Mathematics  Statistics  \
0   1                 1        0            0           1   
1   2                 1        0            0           1   
2   3                 0        0            1           0   
3   4                 0        0            1           0   
4   5                 1        0            0           1   

   Quantitative Biology  Quantitative Finance  
0                     0                     0  
1                     0                     0  
2                     0                     0  
3                     0                     0  
4                     0                     0  


In [None]:
import os
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    set_seed,
    EarlyStoppingCallback,
)
from sklearn.metrics import precision_recall_fscore_support
from transformers import EvalPrediction
import numpy as np
import math
from dataclasses import dataclass, field
from typing import List

from test import BertMoEForMultiLabelClassification

def _default_label_columns() -> List[str]:
    """Return a fresh list of the six target column names."""
    return [
        "Computer Science",
        "Physics",
        "Mathematics",
        "Statistics",
        "Quantitative Biology",
        "Quantitative Finance",
    ]


@dataclass
class ImprovedMultiLabelConfig:
    """Settings for the more heavily regularized fine-tuning run."""

    # --- inputs ---
    pretrained_mlm_path: str = "final_model"

    train_csv: str = "train.csv"
    test_csv: str = "test.csv"

    title_column: str = "TITLE"
    abstract_column: str = "ABSTRACT"
    label_columns: List[str] = field(default_factory=_default_label_columns)

    tokenizer_name: str = "bert-base-uncased"

    # --- optimisation ---
    max_length: int = 256
    train_batch_size: int = 16
    eval_batch_size: int = 32
    num_train_epochs: int = 20
    learning_rate: float = 3e-5
    weight_decay: float = 0.1
    warmup_ratio: float = 0.1
    max_grad_norm: float = 1.0

    # --- early stopping / best-model selection ---
    early_stopping_patience: int = 3
    early_stopping_threshold: float = 0.001
    metric_for_best_model: str = "eval_f1_micro"
    greater_is_better: bool = True
    load_best_model_at_end: bool = True

    # --- logging and checkpointing ---
    logging_steps: int = 50
    save_steps: int = 50
    eval_steps: int = 50
    save_total_limit: int = 3
    output_dir: str = "./improved_moe_multilabel"

    # --- dropout rates ---
    hidden_dropout_prob: float = 0.3
    attention_probs_dropout_prob: float = 0.2
    classifier_dropout: float = 0.4

    lr_scheduler_type: str = "cosine"

    # --- text augmentation ---
    use_data_augmentation: bool = True
    augmentation_prob: float = 0.1

class ImprovedBertMoEForMultiLabelClassification(BertMoEForMultiLabelClassification):
    """Improved version of the model with additional regularization.

    The parent's linear classifier is replaced by a small MLP head:
    LayerNorm -> dropout -> dense -> GELU -> dropout -> linear projection to
    ``config.num_labels``.
    """

    def __init__(self, config):
        super().__init__(config)

        # An attribute-presence check is not enough here: BertConfig defines
        # `classifier_dropout` with a default of None, and nn.Dropout(None)
        # raises. Fall back to 0.4 when the value is missing or None.
        dropout_rate = getattr(config, "classifier_dropout", None)
        if dropout_rate is None:
            dropout_rate = 0.4
        self.classifier_dropout = nn.Dropout(dropout_rate)

        self.additional_dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.additional_layer_norm = nn.LayerNorm(config.hidden_size)

        # NOTE: the three modules above are also attributes of `self`, so
        # they are registered both under their own names and inside this
        # Sequential.
        self.classifier = nn.Sequential(
            self.additional_layer_norm,
            self.classifier_dropout,
            self.additional_dense,
            nn.GELU(),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, config.num_labels)
        )

        # Applies to every submodule of the model, not only the new head.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-init Linear weights (zero bias); reset LayerNorm to identity."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

def cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps, min_lr=0):
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The learning-rate multiplier rises linearly from 0 to 1 over
    ``warmup_steps``, then follows half a cosine period down to 0 at
    ``total_steps``. Steps beyond ``total_steps`` keep the final value
    instead of climbing back up the cosine.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        warmup_steps: Number of linear warmup steps.
        total_steps: Step at which the cosine decay ends.
        min_lr: Lower bound on the *multiplier* after warmup, i.e. a fraction
            of the base learning rate rather than an absolute rate.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` instance.
    """
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))

        # Clamp to 1.0: without it, cos() rises again once progress > 1 and
        # the learning rate would grow back after the schedule ends.
        progress = min(1.0, float(current_step - warmup_steps) / float(decay_steps))
        return max(min_lr, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

def build_compute_metrics_fn(threshold: float = 0.5):
    """Return a metrics callback with averaged and per-class scores.

    The callback unpacks ``(logits, labels)``, applies a sigmoid to the
    logits, marks a label as predicted when its probability is at least
    ``threshold``, and reports micro- and macro-averaged precision, recall
    and F1 plus the per-class values as plain lists.
    """

    def compute_metrics(eval_pred: EvalPrediction):
        logits, labels = eval_pred

        probabilities = 1 / (1 + np.exp(-logits))

        predicted = (probabilities >= threshold).astype(int)
        expected = labels.astype(int)

        scores = {}
        for averaging in ("micro", "macro"):
            precision, recall, f1, _ = precision_recall_fscore_support(
                expected, predicted, average=averaging, zero_division=0
            )
            scores[f"precision_{averaging}"] = precision
            scores[f"recall_{averaging}"] = recall
            scores[f"f1_{averaging}"] = f1

        # average=None yields one value per label column.
        class_precision, class_recall, class_f1, _ = precision_recall_fscore_support(
            expected, predicted, average=None, zero_division=0
        )
        scores["precision_per_class"] = class_precision.tolist()
        scores["recall_per_class"] = class_recall.tolist()
        scores["f1_per_class"] = class_f1.tolist()

        return scores

    return compute_metrics

def apply_text_augmentation(texts, augmentation_prob=0.1):
    """Randomly shuffle the middle half of the words in some texts.

    For each text one draw from ``np.random.random()`` decides whether it is
    augmented. Augmented texts longer than three whitespace-separated words
    get the words between the first and last quarter shuffled in place and
    are re-joined with single spaces; all other texts are returned unchanged.

    Returns:
        A new list with one entry per input text, in the same order.
    """
    result = []
    for original in texts:
        # One draw per text, made before any other check.
        if not np.random.random() < augmentation_prob:
            result.append(original)
            continue

        tokens = original.split()
        if len(tokens) <= 3:
            # Too short to have a meaningful middle section.
            result.append(original)
            continue

        lower = len(tokens) // 4
        upper = 3 * len(tokens) // 4
        middle = tokens[lower:upper]
        np.random.shuffle(middle)
        tokens[lower:upper] = middle
        result.append(' '.join(tokens))
    return result

class ImprovedTrainer(Trainer):
    """Trainer that uses a BCE-with-logits loss with optional label smoothing."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Bookkeeping attributes; nothing in this class reads them.
        self.patience_counter = 0
        self.best_metric = -float('inf')

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the training loss for one batch.

        When ``inputs`` carries ``"labels"``, the loss is
        ``BCEWithLogitsLoss`` on the model logits; if
        ``self.args.label_smoothing_factor`` is positive the targets are
        first pulled towards 0.5 by that factor. Without labels the model's
        own ``outputs.loss`` is used.

        Returns:
            ``(loss, outputs)`` if ``return_outputs`` is true, else ``loss``.
        """
        targets = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        if targets is None:
            loss = outputs.loss
        else:
            criterion = nn.BCEWithLogitsLoss()

            if hasattr(self.args, 'label_smoothing_factor') and self.args.label_smoothing_factor > 0:
                eps = self.args.label_smoothing_factor
                targets = targets * (1 - eps) + 0.5 * eps

            loss = criterion(logits, targets.float())

        if return_outputs:
            return (loss, outputs)
        return loss

def main():
    """Fine-tune the MoE BERT model for multi-label classification.

    Loads the CSV named in ``ImprovedMultiLabelConfig``, holds out 15% of it
    for evaluation, trains with ``ImprovedTrainer``, runs a final evaluation,
    and writes the model, tokenizer and ``training_metrics.json`` to
    ``<output_dir>/final_model``.
    """
    set_seed(42)
    cfg = ImprovedMultiLabelConfig()

    print("Configuration:")
    print(f"  - Learning rate: {cfg.learning_rate}")
    print(f"  - Weight decay: {cfg.weight_decay}")
    print(f"  - Dropout: {cfg.hidden_dropout_prob}")
    print(f"  - Early stopping patience: {cfg.early_stopping_patience}")
    print(f"  - Max epochs: {cfg.num_train_epochs}")

    tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_name, use_fast=True)

    raw_dataset = load_dataset("csv", data_files={"train": cfg.train_csv})["train"]

    dataset_splits = raw_dataset.train_test_split(test_size=0.15, seed=42)
    train_dataset = dataset_splits["train"]
    eval_dataset = dataset_splits["test"]

    LABEL_COLUMNS = cfg.label_columns
    print("üìä Dataset info:")
    print(f"  - Total samples: {len(raw_dataset)}")
    print(f"  - Train samples: {len(train_dataset)}")
    print(f"  - Eval samples: {len(eval_dataset)}")
    print(f"  - Label columns: {LABEL_COLUMNS}")

    base_config = AutoConfig.from_pretrained(cfg.pretrained_mlm_path)
    base_config.num_labels = len(LABEL_COLUMNS)
    base_config.problem_type = "multi_label_classification"

    base_config.hidden_dropout_prob = cfg.hidden_dropout_prob
    base_config.attention_probs_dropout_prob = cfg.attention_probs_dropout_prob
    base_config.classifier_dropout = cfg.classifier_dropout

    model = ImprovedBertMoEForMultiLabelClassification.from_pretrained(
        cfg.pretrained_mlm_path,
        config=base_config,
        ignore_mismatched_sizes=True,
    )

    print(f"üß† Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters")

    def make_preprocess_function(augment):
        """Build a batched tokenisation function, optionally with augmentation."""

        def preprocess_function(examples):
            titles = examples[cfg.title_column]
            abstracts = examples[cfg.abstract_column]
            if augment:
                titles = apply_text_augmentation(titles, cfg.augmentation_prob)
                abstracts = apply_text_augmentation(abstracts, cfg.augmentation_prob)

            # Padding is left to the data collator (dynamic per-batch padding).
            tokenized = tokenizer(
                titles,
                abstracts,
                truncation=True,
                max_length=cfg.max_length,
                padding=False,
            )

            # One multi-hot row per example, in LABEL_COLUMNS order.
            tokenized["labels"] = [
                [examples[col][i] for col in LABEL_COLUMNS]
                for i in range(len(titles))
            ]
            return tokenized

        return preprocess_function

    train_dataset = train_dataset.map(
        make_preprocess_function(augment=cfg.use_data_augmentation),
        batched=True,
        remove_columns=raw_dataset.column_names,
    )
    # The held-out split is never augmented: shuffling words in evaluation
    # texts would make the reported metrics describe perturbed inputs.
    eval_dataset = eval_dataset.map(
        make_preprocess_function(augment=False),
        batched=True,
        remove_columns=raw_dataset.column_names,
    )

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    training_args = TrainingArguments(
        output_dir=cfg.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=cfg.num_train_epochs,
        per_device_train_batch_size=cfg.train_batch_size,
        per_device_eval_batch_size=cfg.eval_batch_size,
        learning_rate=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
        warmup_ratio=cfg.warmup_ratio,
        max_grad_norm=cfg.max_grad_norm,

        lr_scheduler_type=cfg.lr_scheduler_type,

        load_best_model_at_end=cfg.load_best_model_at_end,
        metric_for_best_model=cfg.metric_for_best_model,
        greater_is_better=cfg.greater_is_better,

        logging_steps=cfg.logging_steps,
        save_steps=cfg.save_steps,
        eval_steps=cfg.eval_steps,
        save_total_limit=cfg.save_total_limit,
        eval_strategy="steps",

        fp16=torch.cuda.is_available(),
        dataloader_num_workers=2,
        gradient_accumulation_steps=2,

        report_to="none",
        save_safetensors=False,

        # Read by ImprovedTrainer.compute_loss to smooth the BCE targets.
        label_smoothing_factor=0.01,
    )

    trainer = ImprovedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=build_compute_metrics_fn(threshold=0.5),
    )

    print("üèÉ‚Äç‚ôÇÔ∏è Starting training...")
    trainer.train()

    print("üìä Final evaluation...")
    eval_results = trainer.evaluate()
    print("Final results:", eval_results)

    final_dir = os.path.join(cfg.output_dir, "final_model")
    trainer.save_model(final_dir)
    tokenizer.save_pretrained(final_dir)

    import json
    with open(os.path.join(final_dir, "training_metrics.json"), "w") as f:
        json.dump(eval_results, f, indent=2)

    def format_metric(value):
        """Format a metric with four decimals, or 'N/A' if it is not numeric."""
        # Formatting the string fallback 'N/A' with ':.4f' raises ValueError,
        # so a missing metric must not reach the float format spec.
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            return "N/A"

    print(f"‚úÖ Model saved to {final_dir}")
    print(f"üéØ Best F1 Micro: {format_metric(eval_results.get('eval_f1_micro'))}")
    print(f"üéØ Best F1 Macro: {format_metric(eval_results.get('eval_f1_macro'))}")

if __name__ == "__main__":
    main()

KeyboardInterrupt: 

In [None]:
!unzip improved_moe_multilabel.zip

Archive:  improved_moe_multilabel.zip
   creating: improved_moe_multilabel/
  inflating: __MACOSX/._improved_moe_multilabel  
   creating: improved_moe_multilabel/checkpoint-4000/
  inflating: __MACOSX/improved_moe_multilabel/._checkpoint-4000  
  inflating: improved_moe_multilabel/.DS_Store  
  inflating: __MACOSX/improved_moe_multilabel/._.DS_Store  
   creating: improved_moe_multilabel/checkpoint-11160/
  inflating: __MACOSX/improved_moe_multilabel/._checkpoint-11160  
   creating: improved_moe_multilabel/checkpoint-11000/
  inflating: __MACOSX/improved_moe_multilabel/._checkpoint-11000  
   creating: improved_moe_multilabel/final_model/
  inflating: __MACOSX/improved_moe_multilabel/._final_model  
  inflating: improved_moe_multilabel/checkpoint-4000/rng_state.pth  
  inflating: __MACOSX/improved_moe_multilabel/checkpoint-4000/._rng_state.pth  
  inflating: improved_moe_multilabel/checkpoint-4000/tokenizer_config.json  
  inflating: __MACOSX/improved_moe_multilabel/checkpoint-4000/.

In [None]:
import os
import csv
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from dataclasses import dataclass, field
from typing import List, Dict

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer, DataCollatorWithPadding, set_seed

import re

class ImprovedBertMoEForMultiLabelClassification(BertMoEForMultiLabelClassification):
    """MoE BERT multi-label classifier with a deeper classification head.

    Replaces ``self.classifier`` with the stack
    LayerNorm -> Dropout -> Linear(hidden, hidden) -> GELU -> Dropout ->
    Linear(hidden, num_labels). The first three modules are also registered
    as direct attributes, so they appear in the state dict under both the
    attribute name and their position inside ``classifier``.
    """

    def __init__(self, config):
        super().__init__(config)

        # Falls back to 0.4 only when the config has no ``classifier_dropout``
        # attribute at all.
        # NOTE(review): a config whose ``classifier_dropout`` is None would be
        # passed to nn.Dropout as-is — confirm the saved config always sets it.
        self.classifier_dropout = nn.Dropout(getattr(config, "classifier_dropout", 0.4))
        self.additional_dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.additional_layer_norm = nn.LayerNorm(config.hidden_size)

        self.classifier = nn.Sequential(
            self.additional_layer_norm,
            self.classifier_dropout,
            self.additional_dense,
            nn.GELU(),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, config.num_labels)
        )
        # ``apply`` visits every submodule of the model, not only the new head.
        # NOTE(review): presumably ``from_pretrained`` overwrites these values
        # with checkpoint weights afterwards — confirm.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform init for Linear layers; identity init for LayerNorm.

        Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)


# ---- Configuration for the expert-routing analysis below ----

# Directory the tokenizer, config and model weights are loaded from.
model_dir = "/content/improved_moe_multilabel/final_model"
# CSV that is loaded and then split; only the "test" part of the split is analysed.
train_csv = "test_experts.csv"
# Directory that receives the generated PNG figures.
outdir = "./moe_report_figures"

# Seed and held-out fraction passed to ``train_test_split``.
seed = 42
test_size = 0.15

# CSV column names for the two text fields and the binary label columns.
title_column = "TITLE"
abstract_column = "ABSTRACT"
label_columns = [
    "Computer Science", "Physics", "Mathematics",
    "Statistics", "Quantitative Biology", "Quantitative Finance"
]

# Tokenizer truncation length and DataLoader batch size.
max_length = 256
batch_size = 32

# Sigmoid probability at or above which a class counts as predicted.
threshold = 0.5
# "true" groups examples by their ground-truth labels; any other value
# groups them by the thresholded predictions.
aggregate_by = "true"

# Switches for the individual figure groups.
make_layer_heatmaps = True
make_aggregate_heatmaps = True
make_expert_profiles = True
make_top_experts_per_class = True


def ensure_dir(p):
    """Create directory ``p`` (and any missing parents) if it does not exist.

    An empty path is a no-op: callers pass ``os.path.dirname(path)``, which
    is ``""`` for a bare file name, and ``os.makedirs("")`` would raise.
    """
    if p:
        os.makedirs(p, exist_ok=True)


def save_heatmap(path, mat, row_names, col_names, title):
    """Render ``mat`` as a heatmap with a colour bar and save it to ``path``.

    Rows are labelled with ``row_names`` on the y axis and columns with
    ``col_names`` on the x axis; the figure size grows with both counts.
    The parent directory of ``path`` is created first.
    """
    ensure_dir(os.path.dirname(path))

    n_rows, n_cols = len(row_names), len(col_names)
    fig_width = 1 + 0.65 * n_cols
    fig_height = 1 + 0.45 * n_rows

    plt.figure(figsize=(fig_width, fig_height))
    plt.imshow(mat, aspect="auto")
    plt.colorbar()
    plt.yticks(range(n_rows), row_names)
    plt.xticks(range(n_cols), col_names, rotation=45, ha="right")
    plt.title(title)
    plt.tight_layout()
    plt.savefig(path, dpi=200)
    plt.close()


def save_bar_topk(path, values, labels, title, ylabel="score"):
    """Save a bar chart of ``values`` with one labelled bar per entry.

    The parent directory of ``path`` is created first.
    """
    ensure_dir(os.path.dirname(path))

    positions = np.arange(len(labels))

    plt.figure(figsize=(9, 4))
    plt.bar(positions, values)
    plt.xticks(positions, labels, rotation=45, ha="right")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True, axis="y", alpha=0.25)
    plt.tight_layout()
    plt.savefig(path, dpi=200)
    plt.close()


def save_stacked_bar(path, mat, row_names, col_names, title):
    """Save a stacked bar chart: one bar per row of ``mat``, one segment per column.

    ``mat[:, j]`` supplies the segment heights for ``col_names[j]``. The
    parent directory of ``path`` is created first.
    """
    ensure_dir(os.path.dirname(path))
    plt.figure(figsize=(11, 5))

    positions = np.arange(len(row_names))
    # Running height of each bar; every new segment is drawn on top of it.
    stacked = np.zeros(len(row_names), dtype=np.float64)

    for col, segment_label in enumerate(col_names):
        heights = mat[:, col]
        plt.bar(positions, heights, bottom=stacked, label=segment_label)
        stacked += heights

    plt.xticks(positions, row_names, rotation=0)
    plt.ylabel("share")
    plt.title(title)
    plt.legend(ncol=3, fontsize=9)
    plt.grid(True, axis="y", alpha=0.25)
    plt.tight_layout()
    plt.savefig(path, dpi=200)
    plt.close()


class GateHookCollector:
    """Record the output of each MoE gate module during a forward pass.

    After ``install()``, every forward call of a hooked gate stores its
    detached output in ``layer_logits`` under the encoder layer's index,
    overwriting the previous entry for that layer.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self.handles = []
        self.layer_logits: Dict[int, torch.Tensor] = {}

    def install(self):
        """Hook the ``moe_ffn.gate`` of every encoder layer that has one."""
        for layer_idx, layer in enumerate(self.model.bert.encoder.layer):
            if not (hasattr(layer, "moe_ffn") and hasattr(layer.moe_ffn, "gate")):
                continue
            handle = layer.moe_ffn.gate.register_forward_hook(self._make_hook(layer_idx))
            self.handles.append(handle)

    def _make_hook(self, layer_idx: int):
        """Return a forward hook that stores the gate output for ``layer_idx``."""

        def hook(module, inp, out):
            # Detach so the stored tensor does not keep the autograd graph alive.
            self.layer_logits[layer_idx] = out.detach()

        return hook

    def clear(self):
        """Forget all captured gate outputs."""
        self.layer_logits = {}

    def remove(self):
        """Detach every installed hook from the model."""
        for handle in self.handles:
            handle.remove()
        self.handles = []


def preprocess_builder(tokenizer):
    """Return a batched preprocessing function bound to ``tokenizer``.

    The returned function tokenises (title, abstract) pairs without padding,
    truncated to ``max_length``, and attaches a ``"labels"`` list holding one
    row per example with the values of ``label_columns`` in order.
    """

    def preprocess(examples):
        titles = examples[title_column]
        encoded = tokenizer(
            titles,
            examples[abstract_column],
            truncation=True,
            max_length=max_length,
            padding=False,
        )
        encoded["labels"] = [
            [examples[col][row] for col in label_columns]
            for row in range(len(titles))
        ]
        return encoded

    return preprocess


def compute_topk_contrib(gate_logits: torch.Tensor, k: int, num_experts: int) -> torch.Tensor:
    """Expand top-k routing weights into a dense per-expert matrix.

    For each row of ``gate_logits`` the ``k`` largest logits are selected and
    softmax-normalised among themselves; the result has one column per expert
    and is zero for every expert outside the row's top-k.

    Returns:
        Tensor of shape ``(gate_logits.size(0), num_experts)``.
    """
    selected_logits, selected_experts = gate_logits.topk(k, dim=-1)
    routing_weights = selected_logits.softmax(dim=-1)

    dense = torch.zeros(
        (gate_logits.shape[0], num_experts),
        dtype=routing_weights.dtype,
        device=gate_logits.device,
    )
    # scatter_add_ works in place and returns the same tensor.
    return dense.scatter_add_(1, selected_experts, routing_weights)


# ---- Load model and data, install gate hooks, allocate accumulators ----
ensure_dir(outdir)
set_seed(seed)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
model_cfg = AutoConfig.from_pretrained(model_dir)

model = ImprovedBertMoEForMultiLabelClassification.from_pretrained(
    model_dir,
    config=model_cfg,
    ignore_mismatched_sizes=True
).to(device)
# Inference mode: disables dropout so routing is deterministic.
model.eval()

# MoE hyper-parameters come from the saved config; 4 experts / top-2 are
# the fallbacks when the config does not carry them.
num_experts = int(getattr(model_cfg, "num_experts", 4))
moe_k = int(getattr(model_cfg, "moe_k", 2))
n_layers = len(model.bert.encoder.layer)
n_classes = len(label_columns)

# Axis labels used by the figures.
experts = [f"expert_{i}" for i in range(num_experts)]
classes = label_columns

print(f"layers={n_layers}, num_experts={num_experts}, moe_k={moe_k}, classes={n_classes}")

# Only the held-out part of the split is tokenised and analysed.
raw = load_dataset("csv", data_files={"train": train_csv})["train"]
spl = raw.train_test_split(test_size=test_size, seed=seed)
eval_ds = spl["test"].map(
    preprocess_builder(tokenizer),
    batched=True,
    remove_columns=raw.column_names
)

# shuffle=False keeps the example order fixed; the collator pads per batch.
loader = DataLoader(
    eval_ds,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=DataCollatorWithPadding(tokenizer=tokenizer),
)

collector = GateHookCollector(model)
collector.install()

# Accumulators indexed as [layer, class, expert]:
#   weighted_sum - summed top-k routing weights
#   top1_count   - number of tokens whose highest gate logit is that expert
# class_counts holds a per-class example counter.
weighted_sum = np.zeros((n_layers, n_classes, num_experts), dtype=np.float64)
top1_count = np.zeros((n_layers, n_classes, num_experts), dtype=np.int64)
class_counts = np.zeros((n_classes,), dtype=np.int64)

# ---- Accumulate per-class routing statistics over the evaluation set ----
# The hooks are removed in ``finally`` so that an error in the middle of the
# loop does not leave them attached to the model.
try:
    with torch.no_grad():
        for step, batch in enumerate(loader):
            collector.clear()

            labels_np = batch["labels"].numpy().astype(int)

            batch = {k: v.to(device) for k, v in batch.items()}
            out = model(**batch)

            logits = out.logits.detach().cpu().numpy()
            probs = 1.0 / (1.0 + np.exp(-logits))
            preds_np = (probs >= threshold).astype(int)

            # Group examples by ground truth or by thresholded predictions.
            mask = labels_np if aggregate_by == "true" else preds_np
            B, L = batch["attention_mask"].shape

            if len(collector.layer_logits) == 0:
                raise RuntimeError("–ù–µ –ø–æ–π–º–∞–ª–∏ gate_logits. –ü—Ä–æ–≤–µ—Ä—å, —á—Ç–æ —Å–ª–æ–∏ –¥–µ–π—Å—Ç–≤–∏—Ç–µ–ª—å–Ω–æ BertLayerWithMoE.")

            # Row indices of the examples that belong to each class.
            class_members = [np.where(mask[:, c] == 1)[0] for c in range(n_classes)]

            # Count every example once per batch. Counting inside the layer
            # loop would multiply the totals by the number of hooked layers
            # and shrink the per-class means by the same factor.
            for c, idx in enumerate(class_members):
                class_counts[c] += idx.size

            attn_t = batch["attention_mask"].long()  # (B, L), 1 for real tokens
            attn_mask = attn_t.unsqueeze(-1)         # (B, L, 1) for broadcasting

            for layer_idx, gate_logits in collector.layer_logits.items():
                gate_logits = gate_logits.to(device)

                contrib = compute_topk_contrib(gate_logits, k=moe_k, num_experts=num_experts)
                top1 = gate_logits.argmax(dim=-1)

                # Zero out padding positions before summing over the sequence.
                contrib = contrib.view(B, L, num_experts) * attn_mask
                top1 = top1.view(B, L)

                per_example = contrib.sum(dim=1).detach().cpu().numpy()

                # Per-example histogram of top-1 experts over real tokens:
                # adding the 0/1 attention mask at each token's top-1 expert
                # index counts exactly the non-padding positions.
                top1_hist = torch.zeros((B, num_experts), dtype=torch.long, device=top1.device)
                top1_hist.scatter_add_(1, top1, attn_t.to(top1.device))
                top1_hist_np = top1_hist.cpu().numpy()

                for c, idx in enumerate(class_members):
                    if idx.size == 0:
                        continue

                    weighted_sum[layer_idx, c] += per_example[idx].sum(axis=0)
                    top1_count[layer_idx, c] += top1_hist_np[idx].sum(axis=0)
finally:
    collector.remove()

# ---- Normalise the accumulated statistics and render the report figures ----
print("class_counts:", dict(zip(classes, class_counts.tolist())))

# Divide the summed routing weights by the per-class counter; the floor of 1
# avoids dividing by zero for classes that never occurred.
denom = np.maximum(class_counts, 1).reshape(-1, 1)
weighted_mean = weighted_sum / denom[np.newaxis, :, :]

# Turn top-1 counts into a distribution over experts per (layer, class).
top1_share = top1_count.astype(np.float64)
top1_share /= np.maximum(top1_share.sum(axis=2, keepdims=True), 1.0)

# Average both statistics over the layer axis.
agg_weighted = weighted_mean.mean(axis=0)
agg_top1 = top1_share.mean(axis=0)
ensure_dir(outdir)

if make_aggregate_heatmaps:
    for fig_name, fig_matrix, fig_title in (
        (
            "FIG01_ALL_LAYERS_top1_share_by_class.png",
            agg_top1,
            f"Top-1 expert share by class (mean over layers) | by={aggregate_by}",
        ),
        (
            "FIG02_ALL_LAYERS_weighted_usage_by_class.png",
            agg_weighted,
            f"Weighted expert usage by class (mean over layers) | by={aggregate_by}",
        ),
    ):
        save_heatmap(os.path.join(outdir, fig_name), fig_matrix, classes, experts, title=fig_title)

if make_layer_heatmaps:
    for layer_no in range(n_layers):
        for fig_name, fig_matrix, fig_title in (
            (
                f"FIG_layer_{layer_no:02d}_top1_share_by_class.png",
                top1_share[layer_no],
                f"Layer {layer_no}: Top-1 expert share by class | by={aggregate_by}",
            ),
            (
                f"FIG_layer_{layer_no:02d}_weighted_usage_by_class.png",
                weighted_mean[layer_no],
                f"Layer {layer_no}: Weighted expert usage by class | by={aggregate_by}",
            ),
        ):
            save_heatmap(os.path.join(outdir, fig_name), fig_matrix, classes, experts, title=fig_title)

if make_expert_profiles:
    # Transpose to (expert, class) and rescale each expert's row to sum to 1.
    exp_class = agg_top1.T
    exp_class = exp_class / np.maximum(exp_class.sum(axis=1, keepdims=True), 1e-12)
    save_stacked_bar(
        os.path.join(outdir, "FIG03_expert_profile_from_top1.png"),
        exp_class,
        row_names=experts,
        col_names=classes,
        title="Expert proficiency profile (from Top-1 routing)\n(each expert bar sums to 1)"
    )

    exp_class_w = agg_weighted.T
    exp_class_w = exp_class_w / np.maximum(exp_class_w.sum(axis=1, keepdims=True), 1e-12)
    save_stacked_bar(
        os.path.join(outdir, "FIG04_expert_profile_from_weighted.png"),
        exp_class_w,
        row_names=experts,
        col_names=classes,
        title="Expert proficiency profile (from weighted routing)\n(each expert bar sums to 1)"
    )

if make_top_experts_per_class:
    for class_no, class_name in enumerate(classes):
        shares = agg_top1[class_no]
        # Experts ordered from highest to lowest share for this class.
        ranking = np.argsort(-shares)
        file_tag = re.sub(r'[^a-zA-Z0-9]+','_',class_name)
        save_bar_topk(
            os.path.join(outdir, f"FIG_class_{class_no:02d}_{file_tag}_top1_experts.png"),
            shares[ranking],
            [experts[i] for i in ranking],
            title=f"Class: {class_name} | Top-1 expert share (agg over layers)",
            ylabel="share"
        )

print("Saved figures to:", outdir)
png_files = sorted([f for f in os.listdir(outdir) if f.endswith(".png")])
print("Files:", png_files[:10], "...")


device: cuda
layers=4, num_experts=4, moe_k=2, classes=6


Map:   0%|          | 0/3146 [00:00<?, ? examples/s]

class_counts: {'Computer Science': 5252, 'Physics': 3572, 'Mathematics': 3388, 'Statistics': 3072, 'Quantitative Biology': 312, 'Quantitative Finance': 136}
Saved figures to: ./moe_report_figures
Files: ['FIG01_ALL_LAYERS_top1_share_by_class.png', 'FIG02_ALL_LAYERS_weighted_usage_by_class.png', 'FIG03_expert_profile_from_top1.png', 'FIG04_expert_profile_from_weighted.png', 'FIG_class_00_Computer_Science_top1_experts.png', 'FIG_class_01_Physics_top1_experts.png', 'FIG_class_02_Mathematics_top1_experts.png', 'FIG_class_03_Statistics_top1_experts.png', 'FIG_class_04_Quantitative_Biology_top1_experts.png', 'FIG_class_05_Quantitative_Finance_top1_experts.png'] ...


In [None]:
import ast
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt


def iter_dict_blocks(text: str):
    """Yield every top-level ``{...}`` span found in ``text``.

    Braces are matched by nesting depth, so nested dicts stay inside their
    enclosing block. Text outside any block is skipped, including a stray
    ``}``: without that guard one unmatched closing brace would push the
    depth below zero and every later block would be lost. A block that is
    still open at the end of ``text`` is dropped.

    Braces inside string literals are not treated specially.

    Yields:
        The text of each block, from its opening to its matching closing
        brace, stripped of surrounding whitespace.
    """
    buf = []
    depth = 0

    for ch in text:
        if ch == "{":
            depth += 1
        elif depth == 0:
            # Outside any block: ignore the character (also a stray "}").
            continue

        buf.append(ch)

        if ch == "}":
            depth -= 1
            if depth == 0:
                block = "".join(buf).strip()
                buf = []
                yield block


def parse_records(log_path: str) -> List[Dict[str, Any]]:
    """Extract every dict literal from a text log.

    The file is read as UTF-8 with undecodable bytes ignored. Each
    brace-delimited block has its line breaks removed and is evaluated with
    ``ast.literal_eval``; blocks that fail to evaluate are counted as bad,
    and values that are not dicts are dropped without being counted. A
    summary line is printed before returning.

    Returns:
        The successfully parsed dicts, in file order.
    """
    with open(log_path, "r", encoding="utf-8", errors="ignore") as handle:
        content = handle.read()

    parsed = []
    failures = 0
    for raw_block in iter_dict_blocks(content):
        # Log output may wrap a dict across lines; join it back together.
        flat = raw_block.replace("\r", "").replace("\n", "")

        try:
            value = ast.literal_eval(flat)
        except Exception:
            failures += 1
        else:
            if isinstance(value, dict):
                parsed.append(value)

    print(f"parsed dicts: {len(parsed)}, skipped(bad): {failures}")
    return parsed



def split_train_eval(records: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Partition log records into training and evaluation entries.

    A record containing ``"eval_loss"`` is an evaluation entry; otherwise a
    record containing ``"loss"`` is a training entry. Records with neither
    key are dropped.

    Returns:
        ``(train_records, eval_records)``, each in input order.
    """
    eval_records = [rec for rec in records if "eval_loss" in rec]
    train_records = [
        rec for rec in records
        if "eval_loss" not in rec and "loss" in rec
    ]
    return train_records, eval_records


def get_xy(records: List[Dict[str, Any]], xkey: str, ykey: str) -> Tuple[List[float], List[float]]:
    """Collect ``(x, y)`` float pairs from records and return them sorted by x.

    Only records that contain both keys are considered. A record is skipped
    as a whole when either value cannot be converted to ``float``: both
    values are converted before anything is stored, so a failed y conversion
    can never leave an unmatched x behind and shift every later pair.

    Returns:
        ``(xs, ys)`` as two lists of equal length, ordered by ascending x;
        two empty lists when no record qualifies.
    """
    pairs = []
    for r in records:
        if xkey in r and ykey in r:
            try:
                pair = (float(r[xkey]), float(r[ykey]))
            except (TypeError, ValueError, OverflowError):
                # Non-numeric value (e.g. None or a list): skip the record.
                continue
            pairs.append(pair)

    if not pairs:
        return [], []

    # Stable sort on x only, so records with equal x keep their input order.
    pairs.sort(key=lambda t: t[0])
    xs, ys = zip(*pairs)
    return list(xs), list(ys)


def rolling_mean(y: List[float], window: int) -> List[float]:
    """Trailing moving average of ``y`` over at most ``window`` values.

    The first ``window - 1`` outputs average only the values seen so far.
    When ``window <= 1`` or ``y`` is empty, ``y`` itself is returned.
    """
    if window <= 1 or len(y) == 0:
        return y

    averaged = []
    running = 0.0
    for pos, value in enumerate(y):
        running += value
        if pos >= window:
            # Drop the value that just slid out of the window.
            running -= y[pos - window]
        averaged.append(running / min(pos + 1, window))
    return averaged


def save_line_plot(
    out_path: str,
    series: List[Tuple[List[float], List[float], str]],
    title: str,
    xlabel: str = "epoch",
    ylabel: str = "",
    ylim: Optional[Tuple[float, float]] = None,
):
    """Draw one or more labelled line series and save the figure as an image.

    Args:
        out_path: Destination file.
        series: ``(xs, ys, label)`` triples; a triple with empty ``xs`` or
            ``ys`` is not drawn.
        title: Figure title.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        ylim: Optional ``(low, high)`` y-axis limits.
    """
    plt.figure(figsize=(10, 5))

    for xs, ys, name in series:
        if not (xs and ys):
            continue
        plt.plot(xs, ys, label=name)

    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if ylim is not None:
        low, high = ylim
        plt.ylim(low, high)
    plt.grid(True, alpha=0.3)

    # The legend depends on how many series were passed, not how many were drawn.
    if len(series) > 1:
        plt.legend()

    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()


def save_multiline_per_class(
    out_path: str,
    eval_records: List[Dict[str, Any]],
    metric_key: str,
    class_names: Optional[List[str]] = None,
    title: Optional[str] = None,
):
    """Plot a per-class metric against epoch, one line per class.

    Uses only records that have an ``"epoch"`` and whose ``metric_key`` value
    is a list or tuple; records whose values cannot be converted to float are
    skipped. Nothing is drawn or written when no record qualifies.

    Args:
        out_path: Destination image file.
        eval_records: Evaluation log records.
        metric_key: Key of the per-class list inside each record.
        class_names: Legend labels; replaced by ``class_<i>`` when missing or
            when their number differs from the number of classes.
        title: Figure title; defaults to ``metric_key``.
    """
    history = []
    for record in eval_records:
        if "epoch" not in record or metric_key not in record:
            continue
        raw_values = record[metric_key]
        if not isinstance(raw_values, (list, tuple)):
            continue
        try:
            history.append((float(record["epoch"]), [float(v) for v in raw_values]))
        except Exception:
            pass

    history.sort(key=lambda item: item[0])
    if not history:
        return

    epochs = [epoch for epoch, _ in history]
    # The number of classes is taken from the earliest record.
    n_series = len(history[0][1])

    if class_names is None or len(class_names) != n_series:
        class_names = [f"class_{i}" for i in range(n_series)]

    plt.figure(figsize=(11, 6))
    for cls_idx, cls_label in zip(range(n_series), class_names):
        curve = [values[cls_idx] for _, values in history]
        plt.plot(epochs, curve, label=cls_label)

    plt.title(title or metric_key)
    plt.xlabel("epoch")
    plt.ylabel(metric_key)
    plt.grid(True, alpha=0.3)
    plt.legend(ncol=2, fontsize=9)
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()


def main():
    """Parse the training log and write all training / evaluation plots.

    Reads ``/content/logs.txt``, splits the records into training and
    evaluation entries, and saves the PNG figures into ``logs``. A figure is
    skipped when the log holds no data for it.
    """
    outdir = "logs"
    log = "/content/logs.txt"
    smooth = 25
    os.makedirs(outdir, exist_ok=True)

    stripped_names = [part.strip() for part in "CS, Physics, Math, Stats, QBio, QFin".split(",")]
    class_names = [name for name in stripped_names if name] or None

    def out_file(name):
        return os.path.join(outdir, name)

    train, evals = split_train_eval(parse_records(log))

    # --- training curves ---
    loss_epochs, loss_values = get_xy(train, "epoch", "loss")
    if loss_epochs and loss_values:
        smoothed = rolling_mean(loss_values, smooth)
        save_line_plot(
            out_file("train_loss.png"),
            [
                (loss_epochs, loss_values, "train loss (raw)"),
                (loss_epochs, smoothed, f"train loss (ma={smooth})"),
            ],
            title="Training loss",
            ylabel="loss",
        )

    for key, file_name, heading in (
        ("grad_norm", "train_grad_norm.png", "Grad norm"),
        ("learning_rate", "learning_rate.png", "Learning rate schedule"),
    ):
        xs, ys = get_xy(train, "epoch", key)
        if xs and ys:
            save_line_plot(
                out_file(file_name),
                [(xs, ys, key)],
                title=heading,
                ylabel=key,
            )

    # --- evaluation curves ---
    eval_epochs, eval_losses = get_xy(evals, "epoch", "eval_loss")
    if eval_epochs and eval_losses:
        save_line_plot(
            out_file("eval_loss.png"),
            [(eval_epochs, eval_losses, "eval_loss")],
            title="Eval loss",
            ylabel="eval_loss",
        )

    micro_epochs, micro_f1 = get_xy(evals, "epoch", "eval_f1_micro")
    macro_epochs, macro_f1 = get_xy(evals, "epoch", "eval_f1_macro")
    if micro_epochs and macro_epochs:
        save_line_plot(
            out_file("eval_f1_micro_macro.png"),
            [
                (micro_epochs, micro_f1, "F1 micro"),
                (macro_epochs, macro_f1, "F1 macro"),
            ],
            title="Eval F1 (micro & macro)",
            ylabel="F1",
            ylim=(0.0, 1.0),
        )

    pr_specs = (
        ("eval_precision_micro", "precision micro"),
        ("eval_recall_micro", "recall micro"),
        ("eval_precision_macro", "precision macro"),
        ("eval_recall_macro", "recall macro"),
    )
    pr_xy = [get_xy(evals, "epoch", key) for key, _ in pr_specs]
    # All four curves are drawn against the epochs of the first metric.
    shared_epochs = pr_xy[0][0]
    pr_values = [ys for _, ys in pr_xy]
    if shared_epochs and all(pr_values):
        save_line_plot(
            out_file("eval_precision_recall.png"),
            [
                (shared_epochs, ys, label)
                for ys, (_, label) in zip(pr_values, pr_specs)
            ],
            title="Eval precision/recall",
            ylabel="score",
            ylim=(0.0, 1.0),
        )

    # --- per-class curves ---
    for metric, pretty in (("f1", "F1"), ("precision", "Precision"), ("recall", "Recall")):
        save_multiline_per_class(
            out_file(f"eval_{metric}_per_class.png"),
            evals,
            f"eval_{metric}_per_class",
            class_names=class_names,
            title=f"Eval {pretty} per class",
        )

    print(f"Done. Saved PNGs to: {outdir}")


if __name__ == "__main__":
    main()
