In [None]:
# moe.py
import torch
import torch.nn as nn
from typing import Optional
from transformers import BertPreTrainedModel, BertModel
from transformers.models.bert.modeling_bert import (
    BertLayer,
    BertOutput,
    BertLMPredictionHead,
)


class MoEFFN(nn.Module):
    """Sparse mixture-of-experts feed-forward block with top-k token routing.

    Every token is scored by a linear gate; the ``k`` best-scoring experts
    process the token and their outputs are combined using a softmax taken
    over the selected gate logits only.

    Args:
        hidden_size: Size of the last dimension of the input and output.
        num_experts: Number of expert MLPs.
        expert_size: Inner width of each expert; defaults to ``4 * hidden_size``
            when ``None`` (or 0).
        k: Number of experts each token is routed to.
        dropout_prob: Dropout probability applied inside each expert.

    Raises:
        ValueError: If ``num_experts < 1`` or ``k`` is outside
            ``[1, num_experts]``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_experts: int = 4,
        expert_size: Optional[int] = None,
        k: int = 2,
        dropout_prob: float = 0.1,
    ):
        super().__init__()
        if num_experts < 1:
            raise ValueError(f"num_experts must be >= 1, got {num_experts}")
        # Fail at construction time rather than inside torch.topk on the
        # first forward pass.
        if not 1 <= k <= num_experts:
            raise ValueError(
                f"k must be in [1, num_experts={num_experts}], got {k}"
            )

        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.k = k
        self.expert_size = expert_size or hidden_size * 4

        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, self.expert_size),
                nn.GELU(),
                nn.Dropout(dropout_prob),
                nn.Linear(self.expert_size, hidden_size),
            )
            for _ in range(num_experts)
        ])

        # Router: one logit per expert for every token.
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, hidden_states):
        """Route each token through its top-k experts.

        Args:
            hidden_states: Tensor of shape ``[batch, seq_len, hidden_size]``.

        Returns:
            Tensor with the same shape as ``hidden_states``.

        Raises:
            ValueError: If the last dimension differs from ``hidden_size``.
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        if hidden_dim != self.hidden_size:
            raise ValueError(
                f"expected last dimension {self.hidden_size}, got {hidden_dim}"
            )

        # reshape (not view) so non-contiguous inputs are accepted.
        x = hidden_states.reshape(-1, hidden_dim)  # [N, H]
        gate_logits = self.gate(x)  # [N, E]
        top_k_logits, top_k_indices = torch.topk(gate_logits, self.k, dim=1)  # [N, k]
        # Softmax over the selected logits only, so weights sum to 1 per token.
        top_k_weights = torch.softmax(top_k_logits, dim=1)  # [N, k]

        final_output = torch.zeros_like(x)

        for i, expert in enumerate(self.experts):
            # Rows are token positions, columns are the slot within the top-k.
            token_indices, pos_in_topk = (top_k_indices == i).nonzero(as_tuple=True)
            if token_indices.numel() == 0:
                continue  # no token selected this expert

            expert_inputs = x[token_indices]  # [M, H]
            expert_weights = top_k_weights[token_indices, pos_in_topk]  # [M]
            expert_out = expert(expert_inputs)  # [M, H]
            weighted_out = expert_out * expert_weights.unsqueeze(-1)  # [M, H]

            # index_add_ requires matching dtypes; under mixed precision the
            # product above may be promoted relative to the accumulator.
            final_output.index_add_(
                0, token_indices, weighted_out.to(final_output.dtype)
            )

        return final_output.view(batch_size, seq_len, hidden_dim)


from transformers.models.bert.modeling_bert import BertLayer
import torch.nn as nn

class BertLayerWithMoE(BertLayer):
    """BERT encoder layer whose dense feed-forward block is replaced by MoEFFN.

    The parent constructor builds the regular layer; its ``intermediate`` and
    ``output`` submodules are then removed and a mixture-of-experts block plus
    a dedicated dropout/LayerNorm pair are registered instead.
    """

    def __init__(self, config):
        super().__init__(config)

        # Drop the standard feed-forward submodules built by the parent.
        for stale_name in ("intermediate", "output"):
            delattr(self, stale_name)

        # Optional MoE settings fall back to defaults when absent from config.
        expert_count = getattr(config, "num_experts", 4)
        experts_per_token = getattr(config, "moe_k", 2)
        self.moe_ffn = MoEFFN(
            hidden_size=config.hidden_size,
            num_experts=expert_count,
            expert_size=config.intermediate_size,  # inner width of each expert
            k=experts_per_token,
            dropout_prob=config.hidden_dropout_prob,
        )

        # Stand-in for BertOutput: a plain LayerNorm + Dropout pair.
        self.moe_output_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.moe_output_dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        **kwargs,
    ):
        """Apply self-attention, then the MoE feed-forward with a residual.

        ``encoder_hidden_states``, ``encoder_attention_mask`` and ``**kwargs``
        are accepted but not used by this implementation.

        Returns:
            Tuple whose first element is the layer output, followed by any
            extra items returned by the attention module.
        """
        attention_results = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        attended = attention_results[0]

        # Dropout on the expert mixture, then residual add + LayerNorm.
        expert_mix = self.moe_output_dropout(self.moe_ffn(attended))  # [B, L, hidden_size]
        layer_output = self.moe_output_layer_norm(attended + expert_mix)

        return (layer_output,) + attention_results[1:]


class BertMoEForMaskedLM(BertPreTrainedModel):
    """BERT encoder with mixture-of-experts layers and a masked-LM head.

    The encoder is a ``BertModel`` (without pooler) whose layers are all
    ``BertLayerWithMoE`` instances; ``self.cls`` is a ``BertLMPredictionHead``
    applied to the final hidden states.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Build a regular BERT, then swap every encoder layer for an MoE layer.
        # Rebuilding the ModuleList keeps the same parameter names
        # (bert.encoder.layer.<i>...) without mutating __class__ in place.
        self.bert = BertModel(config, add_pooling_layer=False)
        self.bert.encoder.layer = nn.ModuleList(
            BertLayerWithMoE(config) for _ in range(config.num_hidden_layers)
        )

        self.cls = BertLMPredictionHead(config)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        **kwargs,
    ):
        """Run the encoder and the LM head; optionally compute the MLM loss.

        Args:
            labels: Optional target token ids, flattened together with the
                logits and passed to ``nn.CrossEntropyLoss``.
            return_dict: Accepted for API compatibility only. This method
                always returns a plain dict, and the encoder is always run
                with ``return_dict=True`` because its output is read by
                attribute below.
            **kwargs: Accepted and ignored.

        Returns:
            Dict with keys ``"loss"`` (``None`` without labels), ``"logits"``,
            ``"hidden_states"`` and ``"attentions"``.
        """
        # Forward only the arguments that were actually provided.
        encoder_inputs = {
            name: value for name, value in {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
                "position_ids": position_ids,
                "head_mask": head_mask,
                "inputs_embeds": inputs_embeds,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
            }.items() if value is not None
        }

        # Force a structured output: with return_dict=False the encoder
        # returns a tuple and `.last_hidden_state` would raise.
        outputs = self.bert(**encoder_inputs, return_dict=True)

        sequence_output = outputs.last_hidden_state
        prediction_scores = self.cls(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1)
            )

        return {
            "loss": loss,
            "logits": prediction_scores,
            "hidden_states": outputs.hidden_states,
            "attentions": outputs.attentions,
        }

In [None]:
class PretrainConfig:
    """Settings for masked-LM pretraining, stored as plain class attributes."""

    # --- run identity and data source ---
    model_name: str = "your-moe-bert"
    dataset_name: str = "wikimedia/wikipedia"
    dataset_config: str = "20231101.en"
    text_column: str = "text"
    tokenizer: str = "bert-base-uncased"
    output_dir: str = "."
    seq_len: int = 128
    batch_size: int = 32

    # --- masking ---
    masking_prob: float = 0.15

    # --- optimisation schedule ---
    lr: float = 5e-5
    weight_decay: float = 0.01
    warmup_steps: int = 1000
    max_steps: int = 10_000

    # --- checkpointing / logging cadence (in optimizer steps) ---
    save_steps: int = 5_000
    logging_steps: int = 100
    eval_steps: int = 2000

    # --- BERT / MoE parameters ---
    bert_hidden_size: int = 256
    bert_intermediate_size: int = 1024
    bert_num_hidden_layers: int = 4
    bert_num_attention_heads: int = 4
    num_experts: int = 4

In [None]:
import os
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers import BertConfig


def main():
    """Pretrain the MoE-BERT masked language model on a streamed corpus.

    Builds the model from ``PretrainConfig``, streams and tokenizes the
    dataset lazily, trains for ``max_steps`` optimizer steps with the Hugging
    Face ``Trainer``, then saves model and tokenizer to
    ``<output_dir>/final_model``.
    """
    set_seed(42)
    cfg = PretrainConfig()

    # --- Tokenizer and model ---
    # Load the tokenizer first so the embedding table is sized from it
    # instead of a hard-coded vocabulary size.
    tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer)

    model_config = BertConfig(
        vocab_size=len(tokenizer),
        hidden_size=cfg.bert_hidden_size,
        num_hidden_layers=cfg.bert_num_hidden_layers,
        num_attention_heads=cfg.bert_num_attention_heads,
        intermediate_size=cfg.bert_intermediate_size,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        pad_token_id=tokenizer.pad_token_id,
        num_experts=cfg.num_experts,
        moe_k=2,
    )

    model = BertMoEForMaskedLM(model_config)

    print(f"Model initialized with {cfg.num_experts} experts.")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    # --- Load the dataset in streaming mode ---
    print("Loading dataset in streaming mode...")
    dataset = load_dataset(
        cfg.dataset_name,
        cfg.dataset_config,
        split="train",
        streaming=True  # key setting: stream instead of downloading everything
    )

    # --- Tokenization function (applied lazily) ---
    def tokenize_function(examples):
        return tokenizer(
            examples[cfg.text_column],
            truncation=True,
            padding=False,  # the collator pads each batch to its own max length
            max_length=cfg.seq_len,
            return_special_tokens_mask=True,
        )

    # Tokenize and drop ALL original columns, keeping only tokenizer output.
    original_columns = dataset.column_names
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=original_columns,
    )

    # --- Filter out very short examples (optional; use with care in streaming) ---
    def filter_short(example):
        return len(example["input_ids"]) >= cfg.seq_len // 2

    tokenized_dataset = tokenized_dataset.filter(filter_short)

    # --- Data collator ---
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=cfg.masking_prob,
    )

    # --- Training args ---
    training_args = TrainingArguments(
        output_dir=cfg.output_dir,
        overwrite_output_dir=True,
        max_steps=cfg.max_steps,
        per_device_train_batch_size=cfg.batch_size,
        gradient_accumulation_steps=1,
        learning_rate=cfg.lr,
        weight_decay=cfg.weight_decay,
        warmup_steps=cfg.warmup_steps,
        logging_steps=cfg.logging_steps,
        save_steps=cfg.save_steps,
        save_strategy="steps",
        load_best_model_at_end=False,
        # Only request fp16 when CUDA is present, matching the fine-tuning
        # script; an unconditional fp16=True fails on CPU-only hosts.
        fp16=torch.cuda.is_available(),
        dataloader_num_workers=2,  # 0-4 works; 0-2 recommended for streaming
        remove_unused_columns=False,
        report_to="none",
        # NOTE: no shuffling is configured for the streaming dataset
        # (a shuffle buffer could be added if needed).
        dataloader_drop_last=True,
        save_safetensors=False
    )

    # --- Build the Trainer ---
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,  # streaming dataset
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    # --- Training ---
    print("Starting pretraining (streaming)...")
    trainer.train()

    # --- Saving ---
    final_dir = os.path.join(cfg.output_dir, "final_model")
    trainer.save_model(final_dir)
    tokenizer.save_pretrained(final_dir)
    print(f"Model saved to {final_dir}")


if __name__ == "__main__":
    main()

In [None]:
import torch
from transformers import AutoTokenizer

# Directory holding the saved model and tokenizer.
model_dir = "/content/final_model"

tokenizer = AutoTokenizer.from_pretrained(model_dir)

model = BertMoEForMaskedLM.from_pretrained(model_dir)
model.eval()

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

text = "Paris is the [MASK] of France."

# Tokenize and move every tensor to the model's device in one step.
inputs = {
    k: v.to(device)
    for k, v in tokenizer(text, return_tensors="pt").items()
}

with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]

# Locate the [MASK] token; only the first occurrence is inspected.
mask_token_id = tokenizer.mask_token_id
mask_positions = torch.nonzero(inputs["input_ids"] == mask_token_id, as_tuple=False)

batch_idx, mask_pos = mask_positions[0].tolist()

# Ten highest-scoring vocabulary entries at the masked position.
mask_logits = logits[batch_idx, mask_pos, :]
top_k = torch.topk(mask_logits, k=10)
top_ids = top_k.indices.tolist()
top_scores = top_k.values.tolist()

print("Input:", text)
print("Top predictions for [MASK]:")
for token_id, score in zip(top_ids, top_scores):
    token = tokenizer.decode([token_id])
    print(f"{token!r}  logit={score:.3f}")


In [None]:
# moe_multilabel_model.py

import torch
import torch.nn as nn
from transformers import BertConfig
from transformers.modeling_outputs import SequenceClassifierOutput


class BertMoEForMultiLabelClassification(BertMoEForMaskedLM):
    """Multi-label classifier built on top of ``BertMoEForMaskedLM``.

    - the BERT + MoE encoder comes from the parent class;
    - the MLM head (``self.cls``) is kept but NOT used by this ``forward``;
    - a linear multi-label head trained with ``BCEWithLogitsLoss`` is added.
    """

    def __init__(self, config: BertConfig):
        super().__init__(config)

        self.num_labels = config.num_labels
        # Mark the task type explicitly on the config.
        self.config.problem_type = "multi_label_classification"

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)

        # post_init() was already called by the parent inside super().__init__.
        # NOTE(review): the classifier is created after that call, so it keeps
        # nn.Linear's default initialisation unless the loading path
        # re-initialises it — confirm this is intended.

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,  # [batch, num_labels] with 0/1 entries
        **kwargs,
    ):
        """Classify each sequence from the hidden state of its first token.

        Args:
            labels: Optional 0/1 targets of shape ``[batch, num_labels]``;
                cast to float before ``BCEWithLogitsLoss``.
            return_dict: Accepted for API compatibility only. A
                ``SequenceClassifierOutput`` is always returned, and the
                encoder is always run with ``return_dict=True`` because its
                output is read by attribute below.
            **kwargs: Accepted and ignored.

        Returns:
            ``SequenceClassifierOutput`` with ``loss`` (``None`` without
            labels), ``logits``, ``hidden_states`` and ``attentions``.
        """
        # 1) Run BERT+MoE (as in the MLM class, but without the MLM head),
        #    forwarding only the arguments that were actually provided.
        encoder_inputs = {
            name: value for name, value in {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
                "position_ids": position_ids,
                "head_mask": head_mask,
                "inputs_embeds": inputs_embeds,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
            }.items() if value is not None
        }

        # Force a structured output: with return_dict=False the encoder
        # returns a tuple and `.last_hidden_state` would raise.
        outputs = self.bert(**encoder_inputs, return_dict=True)
        sequence_output = outputs.last_hidden_state  # [B, L, H]

        # 2) Use the first ([CLS]) position as the sequence representation.
        cls_output = sequence_output[:, 0, :]        # [B, H]
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)         # [B, num_labels]

        # 3) Multi-label loss.
        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels.float())

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [None]:
# Names of the label columns; the same six names, in the same order, are the
# default of MultiLabelConfig.label_columns.
LABEL_COLUMNS = [
    "Computer Science",
    "Physics",
    "Mathematics",
    "Statistics",
    "Quantitative Biology",
    "Quantitative Finance",
]

In [None]:
# multilabel_config.py

from dataclasses import dataclass, field
from typing import List


@dataclass
class MultiLabelConfig:
    """Paths, column names and hyper-parameters for multi-label fine-tuning and inference."""

    # Path to the pretrained MLM-MoE model (BertMoEForMaskedLM)
    pretrained_mlm_path: str = "/content/final_model"  # <- replace with your own path

    # Paths to the Kaggle CSV files
    train_csv: str = "/content/data/train.csv"
    # NOTE(review): defaults to the same file as train_csv — confirm whether a
    # separate test file was intended.
    test_csv: str = "/content/data/train.csv"

    # Column names
    title_column: str = "TITLE"
    abstract_column: str = "ABSTRACT"
    label_columns: List[str] = field(default_factory=lambda: [
        "Computer Science",
        "Physics",
        "Mathematics",
        "Statistics",
        "Quantitative Biology",
        "Quantitative Finance",
    ])

    # Tokenizer
    tokenizer_name: str = "bert-base-uncased"  # or whichever tokenizer was used for MLM pretraining

    # Sequence length, batch sizes and optimisation settings
    max_length: int = 256
    train_batch_size: int = 16
    eval_batch_size: int = 16
    num_train_epochs: int = 3
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    logging_steps: int = 50
    save_steps: int = 500
    output_dir: str = "./moe_multilabel"


In [None]:
# preprocess_multilabel.py (the same logic is used inside the train script, see below)

from datasets import load_dataset
from transformers import AutoTokenizer

# Default configuration and the fast tokenizer it names.
cfg = MultiLabelConfig()
tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_name, use_fast=True)

# Load the training CSV as a single "train" split.
raw_dataset = load_dataset("csv", data_files={"train": cfg.train_csv})["train"]

def preprocess_function(examples):
    """Tokenize title/abstract pairs and attach one label row per example.

    Reads the module-level ``cfg`` and ``tokenizer``. ``examples`` is a
    batched mapping from column name to a list of values.
    """
    titles = examples[cfg.title_column]

    # Title and abstract are passed to the tokenizer as a text pair.
    encoded = tokenizer(
        titles,
        examples[cfg.abstract_column],
        truncation=True,
        max_length=cfg.max_length,
        padding=False,
    )

    # One list per row, holding that row's value from every label column.
    encoded["labels"] = [
        [examples[column][row] for column in cfg.label_columns]
        for row in range(len(titles))
    ]
    return encoded

# Apply the preprocessing in batches, keeping only the columns it returns.
processed_dataset = raw_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=raw_dataset.column_names,  # drop ID, TITLE, ABSTRACT and the label columns
)

In [None]:
# train_multilabel_moe.py

import os
import torch
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    set_seed,
)
from sklearn.metrics import precision_recall_fscore_support
from transformers import EvalPrediction
import numpy as np

def _stable_sigmoid(values):
    """Element-wise logistic sigmoid in float64 that never overflows."""
    x = np.asarray(values, dtype=np.float64)
    # exp(-|x|) lies in (0, 1], so neither branch can overflow.
    decay = np.exp(-np.abs(x))
    return np.where(x >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))


def build_compute_metrics_fn(threshold: float = 0.5):
    """Build a ``compute_metrics`` callable for multi-label evaluation.

    The returned function:
    - maps logits to probabilities with a sigmoid,
    - binarises them with ``probability >= threshold``,
    - reports micro- and macro-averaged precision / recall / F1.

    Args:
        threshold: Probability cut-off for predicting a label as positive.

    Returns:
        A function taking an ``EvalPrediction`` and returning a dict with the
        keys ``precision_micro``, ``recall_micro``, ``f1_micro``,
        ``precision_macro``, ``recall_macro`` and ``f1_macro``.
    """
    def compute_metrics(eval_pred: EvalPrediction):
        logits, labels = eval_pred
        # If the model emitted several prediction arrays, the logits come first.
        if isinstance(logits, (tuple, list)):
            logits = logits[0]
        # logits and labels are both [N, num_labels].

        # 1) sigmoid, computed in float64 without overflow warnings
        probs = _stable_sigmoid(logits)

        # 2) binarisation
        y_pred = (probs >= threshold).astype(int)
        y_true = np.asarray(labels).astype(int)

        # 3) micro = pooled over all labels, macro = mean over labels
        metrics = {}
        for average in ("micro", "macro"):
            precision, recall, f1, _ = precision_recall_fscore_support(
                y_true, y_pred, average=average, zero_division=0
            )
            metrics[f"precision_{average}"] = precision
            metrics[f"recall_{average}"] = recall
            metrics[f"f1_{average}"] = f1
        return metrics

    return compute_metrics

def main():
    """Fine-tune the pretrained MoE-BERT as a multi-label classifier.

    Loads the training CSV, holds out 10% for evaluation, initialises
    ``BertMoEForMultiLabelClassification`` from the pretrained checkpoint,
    trains with the Hugging Face ``Trainer`` and saves model and tokenizer to
    ``<output_dir>/final_model``.
    """
    set_seed(42)
    cfg = MultiLabelConfig()

    tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_name, use_fast=True)

    raw_dataset = load_dataset("csv", data_files={"train": cfg.train_csv})["train"]

    # 90/10 train/eval split with a fixed seed.
    splits = raw_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset, eval_dataset = splits["train"], splits["test"]

    label_columns = cfg.label_columns
    print("–ö–æ–ª–æ–Ω–∫–∏:", raw_dataset.column_names)
    print("Label-–∫–æ–ª–æ–Ω–∫–∏:", label_columns)

    # Reuse the pretrained config, switched to the multi-label task.
    base_config = AutoConfig.from_pretrained(cfg.pretrained_mlm_path)
    base_config.num_labels = len(label_columns)
    base_config.problem_type = "multi_label_classification"

    model = BertMoEForMultiLabelClassification.from_pretrained(
        cfg.pretrained_mlm_path,
        config=base_config,
        ignore_mismatched_sizes=True,
    )

    def preprocess_function(examples):
        # Tokenize TITLE + ABSTRACT as a text pair.
        titles = examples[cfg.title_column]
        encoded = tokenizer(
            titles,
            examples[cfg.abstract_column],
            truncation=True,
            max_length=cfg.max_length,
            padding=False,
        )
        # One list per row, holding that row's value from every label column.
        encoded["labels"] = [
            [examples[column][row] for column in label_columns]
            for row in range(len(titles))
        ]
        return encoded

    # Both splits are mapped identically; all raw columns are dropped.
    map_options = {
        "batched": True,
        "remove_columns": raw_dataset.column_names,
    }
    train_dataset = train_dataset.map(preprocess_function, **map_options)
    eval_dataset = eval_dataset.map(preprocess_function, **map_options)

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    training_args = TrainingArguments(
        output_dir=cfg.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=cfg.num_train_epochs,
        per_device_train_batch_size=cfg.train_batch_size,
        per_device_eval_batch_size=cfg.eval_batch_size,
        learning_rate=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
        logging_steps=cfg.logging_steps,
        save_steps=cfg.save_steps,
        eval_strategy="steps",
        eval_steps=cfg.save_steps,  # evaluate on the same cadence as saving
        save_total_limit=2,
        load_best_model_at_end=True,
        fp16=torch.cuda.is_available(),
        report_to="none",
        save_safetensors=False
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=build_compute_metrics_fn(threshold=0.5)
    )

    trainer.train()

    # Persist the final model next to its tokenizer.
    final_dir = os.path.join(cfg.output_dir, "final_model")
    trainer.save_model(final_dir)
    tokenizer.save_pretrained(final_dir)
    print(f"Model saved to {final_dir}")


if __name__ == "__main__":
    main()


In [None]:
# infer_multilabel_from_test_csv.py

import os
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoConfig


def load_model_for_inference(model_dir: str, cfg: MultiLabelConfig):
    """Load the classifier and tokenizer from ``model_dir`` in eval mode.

    ``cfg`` is accepted but not read by this function.

    Returns:
        Tuple ``(model, tokenizer, device)``; the device is CUDA when
        available, otherwise CPU, and the model has been moved to it.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_config = AutoConfig.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    model = BertMoEForMultiLabelClassification.from_pretrained(
        model_dir,
        config=model_config,
    )

    model.to(device)
    model.eval()
    return model, tokenizer, device


def predict_on_test_csv(
    model,
    tokenizer,
    device,
    cfg: MultiLabelConfig,
    threshold: float = 0.5,
    batch_size: int = 32,
    output_path: str = "sample_submission.csv",
):
    """Predict labels for every row of ``cfg.test_csv`` and write a CSV.

    Args:
        model: Classifier whose output exposes ``"logits"``.
        tokenizer: Tokenizer called on (title, abstract) text pairs.
        device: Device the encoded tensors are moved to.
        cfg: Supplies ``test_csv``, ``title_column``, ``abstract_column``,
            ``label_columns`` and ``max_length``.
        threshold: A label is predicted when ``sigmoid(logit) >= threshold``.
        batch_size: Number of rows per forward pass.
        output_path: Destination CSV with an ``ID`` column followed by one
            0/1 column per label.

    Raises:
        ValueError: If ``batch_size`` is not positive or the stacked
            predictions do not have shape ``(rows, len(cfg.label_columns))``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # 1) Load the test CSV
    test_df = pd.read_csv(cfg.test_csv)
    print("–¢–µ—Å—Ç–æ–≤—ã–π –¥–∞—Ç–∞—Å–µ—Ç:", test_df.shape)
    print("–ö–æ–ª–æ–Ω–∫–∏ –≤ —Ç–µ—Å—Ç–µ:", list(test_df.columns))

    # The file is expected to contain at least ID, TITLE and ABSTRACT.
    id_col = "ID"
    num_labels = len(cfg.label_columns)

    # pandas reads empty cells as NaN (a float), which the tokenizer rejects;
    # treat missing text as an empty string.
    titles = test_df[cfg.title_column].fillna("").astype(str).tolist()
    abstracts = test_df[cfg.abstract_column].fillna("").astype(str).tolist()

    # 2) Run the data through the model in batches
    all_preds = []
    for start in range(0, len(test_df), batch_size):
        stop = start + batch_size  # list slicing clamps at the end

        encoded = tokenizer(
            titles[start:stop],
            abstracts[start:stop],
            truncation=True,
            max_length=cfg.max_length,
            padding=True,
            return_tensors="pt",
        )
        encoded = {k: v.to(device) for k, v in encoded.items()}

        with torch.no_grad():
            outputs = model(**encoded)
            logits = outputs["logits"]  # [B, num_labels]
            probs = torch.sigmoid(logits).cpu().numpy()

        all_preds.append((probs >= threshold).astype(int))  # 0/1

    # np.vstack rejects an empty list, so build the empty result explicitly.
    if all_preds:
        all_preds = np.vstack(all_preds)  # [N, num_labels]
    else:
        all_preds = np.zeros((0, num_labels), dtype=int)

    if all_preds.shape != (len(test_df), num_labels):
        raise ValueError(
            f"prediction shape {all_preds.shape} does not match "
            f"({len(test_df)}, {num_labels})"
        )

    # 3) Build the prediction table
    submission_df = pd.DataFrame(
        all_preds,
        columns=cfg.label_columns,
    )
    submission_df.insert(0, id_col, test_df[id_col].values)

    # 4) Save to output_path
    submission_df.to_csv(output_path, index=False)
    print(f"–°–æ—Ö—Ä–∞–Ω–µ–Ω–æ –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–µ –≤ {output_path}")
    print(submission_df.head())


if __name__ == "__main__":
    # Script entry point: load a fine-tuned checkpoint and write predictions.
    cfg = MultiLabelConfig()

    # Directory holding the trained model
    model_dir = "/content/moe_multilabel/checkpoint-3500"  # as in the train script

    model, tokenizer, device = load_model_for_inference(model_dir, cfg)

    predict_on_test_csv(
        model=model,
        tokenizer=tokenizer,
        device=device,
        cfg=cfg,
        threshold=0.5,
        batch_size=32,
        output_path="sample_submission.csv",
    )
