## Loading a custom model


In [1]:
# `from __future__` imports must be the first statement of a module; IPython only
# tolerates them mid-cell, and the notebook breaks when exported to a .py script.
from __future__ import annotations

import os

# CUDA debugging switches. Set them before torch (or any project module that may
# import it) is loaded, so they are already in the environment when CUDA initialises.
# NOTE(review): PyTorch's own error messages describe TORCH_USE_CUDA_DSA as a
# compile-time flag, so setting it at runtime probably has no effect — confirm.
os.environ["TORCH_USE_CUDA_DSA"] = "1"
# Synchronous kernel launches, so CUDA errors are reported at the failing call.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import json
from collections import Counter
from itertools import chain

import matplotlib.pyplot as plt
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from mycelia.shared.config import MinerConfig, ValidatorConfig
from mycelia.shared.dataloader import get_dataloader
from mycelia.shared.expert_manager import ExpertManager, is_expert_param
from mycelia.shared.helper import route_tokens_to_experts
from mycelia.shared.modeling.custom_qwen3_next import CustomQwen3NextModel, get_moe_model_config
from mycelia.shared.modeling.mycelia import get_layer_expert_id

  from .autonotebook import tqdm as notebook_tqdm


PyTorch version 2.9.1 available.


In [2]:
# ---- Validator config with small data settings for quick inspection ----
# `rank` is forwarded to get_dataloader later; fixed to 0 in this notebook.
rank = 0
config = ValidatorConfig()
config.task.data.batch_size, config.task.data.sequence_length = 1, 100



In [3]:
# Switch to the miner-side config, which is used by ExpertManager, the model
# builders and the dataloader below. Rebinding `config` discards the
# ValidatorConfig from the previous cell. Its data overrides are re-applied here
# so they are not silently lost.
config = MinerConfig()
config.task.data.batch_size = 1
config.task.data.sequence_length = 100

# Expert bookkeeping, including `expert_group_assignment` and `num_experts`,
# which are used when building the models.
em = ExpertManager(config)

[2m2025-11-20 10:35:22[0m [[32m[1minfo     [0m] [1mloading task folder         [0m [[0m[1m[34mmycelia.shared.expert_manager[0m][0m [36mpositional_args[0m=[35m(PosixPath('/home/isabella/crucible/subnet-MoE/expert_groups/exp_dummy'),)[0m
[2m2025-11-20 10:35:22[0m [[32m[1minfo     [0m] [1mloading task folder         [0m [[0m[1m[34mmycelia.shared.expert_manager[0m][0m [36mpositional_args[0m=[35m(PosixPath('/home/isabella/crucible/subnet-MoE/expert_groups/exp_math'),)[0m


In [None]:
# ---- Full model: no expert-group restriction (group_ids=None) ----
topk, group_ids = 10, None
moe_config = get_moe_model_config(config, topk, group_ids, em.expert_group_assignment)
model = CustomQwen3NextModel(moe_config)

In [None]:
# ---- Partial model: restricted to expert group 0 ----
topk, group_ids = 10, [0]
moe_config = get_moe_model_config(config, topk, group_ids, em.expert_group_assignment)

# Override the expert count with the ExpertManager's total.
# NOTE(review): presumably this keeps expert ids aligned with the full model's
# numbering even though only group 0 is loaded — confirm.
moe_config.num_experts = em.num_experts
partial_model = CustomQwen3NextModel(moe_config)


## Model check

In [5]:
# The `mlp.gate` module of the first layer (presumably the MoE router), probed below.
first_layer = model.layers[0]
gate = first_layer.mlp.gate

In [8]:
# Route random hidden states through the layer-0 gate and count how often each
# expert is selected.
torch.manual_seed(0)  # make this stochastic check reproducible

with torch.no_grad():  # inspection only; no autograd graph needed
    # Call the module itself (not `.forward`) so any registered hooks still run.
    # NOTE(review): 2048 is assumed to equal the model's hidden size — confirm.
    _, routing_weights, selected_experts = gate(torch.rand(200, 2048))

# One-hot over experts, then permute. Assuming selected_experts is
# (tokens, top_k), the mask becomes (expert, top_k slot, token).
# NOTE(review): 25 matches the expert ids 0-24 seen in the outputs below;
# consider deriving it from the config instead.
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=25).permute(2, 1, 0)

# Indices of experts selected at least once.
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

# Per-expert, per-token selection counts (summed over the top-k slots).
expert_mask.sum(dim=1)

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 1, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 0, 0, 0],
        [1, 0, 1,  ..., 1, 1, 0]])

## Test forward pass

In [9]:
# ---- Tokenizer and dataloader ----
tokenizer = AutoTokenizer.from_pretrained(config.model.model_path)

# `rank` and the configured world size are forwarded to get_dataloader
# (presumably for sharding).
train_dataloader = get_dataloader(
    config,
    rank=rank,
    world_size=config.task.data.world_size,
    tokenizer=tokenizer,
)
# Explicit iterator so later cells can pull batches one at a time with next().
iter_dataloader = iter(train_dataloader)

Too many dataloader workers: 4 (max is dataset.n_shards=1). Stopping 3 dataloader workers.


Too many dataloader workers: 4 (max is dataset.n_shards=1). Stopping 3 dataloader workers.


In [10]:
# Push a few real batches through the full model and keep the outputs for inspection.
NUM_FORWARD_BATCHES = 20

model.eval()
outputs = []
with torch.no_grad():  # inference only
    for _ in range(NUM_FORWARD_BATCHES):
        tokens = {}
        for key, value in next(iter_dataloader).items():
            # 'labels' is not passed to the model. Skipping it (rather than `del`)
            # also avoids a KeyError for batches that carry no labels.
            if key == "labels":
                continue
            # Keep only the first entry of each field, as the original cell did.
            # NOTE(review): for (batch, seq) tensors this drops the batch
            # dimension — confirm the model expects that.
            sample = value[0]
            # Use the return value of .to(): Tensor.to is not in-place, and a
            # plain dict batch has no .to() method at all.
            tokens[key] = sample.to(model.device) if isinstance(sample, torch.Tensor) else sample
        outputs.append(model(**tokens))

expert keys dict_keys(['5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24'])
expert_hit tensor([[ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11],
        [12],
        [13],
        [14],
        [15],
        [16],
        [17],
        [18],
        [19],
        [20],
        [21],
        [22],
        [23],
        [24]])
expert keys dict_keys(['5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24'])
expert_hit tensor([[ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11],
        [12],
        [13],
        [14],
        [15],
        [16],
        [17],
        [18],
        [19],
        [20],
        [21],
        [22],
        [23],
        [24]]