# SHINE: A Scalable In-Context Hypernetwork for Mapping Context to LoRA in a Single Pass

### This is an example of generating a LoRA with SHINE and using the generated LoRA for multi-turn conversations.

## Imports and Setup

In [1]:
import os
import sys
import gc
import math
import time
import json
import random
import re
import logging
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Any
from functools import partial
from collections import OrderedDict, Counter

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
)
from omegaconf import OmegaConf, DictConfig



# --- Project Specific Imports ---
# These modules are project-local files; they resolve only when the notebook
# is started from the project root directory.
try:
    from metanetwork_family import Metanetwork
    from utils.mydataset import HumanDataset, HumanCollator
    from utils.myseed import set_seed
    from utils.mylogging import get_logger
    from utils.mysaveload import (
        save_checkpoint,
        load_checkpoint,
        get_latest_checkpoint,
    )
    from utils.myfreeze import freeze
    from utils.myinit import _resolve_device, _import_class
except ImportError as e:
    print(f"Error importing local modules: {e}")
    print("Please ensure you are running this notebook from the project root directory.")
    # Fail fast: the later cells call set_seed, _import_class, Metanetwork,
    # freeze, load_checkpoint and HumanDataset, so carrying on would only
    # postpone the failure to a less informative NameError.
    raise

# Setup Environment
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Permit TF32 in CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True

# Initialize simple logger for notebook
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("notebook_test")

  from .autonotebook import tqdm as notebook_tqdm


Select which GPU to use.

In [2]:
# GPU index exposed to this process; edit this value to run on another card.
VISIBLE_GPU_ID = "2"
os.environ["CUDA_VISIBLE_DEVICES"] = VISIBLE_GPU_ID

## Configuration

In [3]:
# Configuration for this notebook.
# Entries tagged "From bash script" carry the overrides taken from the launch script.

# encoder_cfg and couple_encoder_cfg use the same layer settings, so they are
# written once here and copied into both slots below.
_encoder_layer_cfg = {
    "d_model": 4096,
    "nhead": 32,
    "dim_feedforward": 8192,
    "dropout": 0,
    "activation": "gelu",
    "layer_norm_eps": 0.00001,
    "batch_first": True,
    "norm_first": False,
    "bias": True,
}

conf_dict = dict(
    name="8gpu_8lora_128metalora_lr5e-5_grouppretrain_1150",
    mode="train",
    resume_global_step=-1,
    test_global_step="epoch-2",  # From bash script
    run=dict(
        seed=42,
        use_amp=False,
        gradient_accumulation_steps=4,
        device="cuda",  # auto | cuda | cpu
        use_gradient_checkpoint=False,
    ),
    paths=dict(
        model_path="./models/Qwen3-8B",
    ),
    data=dict(
        context_max_length=1024,
        conversation_max_length=1024,
        train_batch_size=1,
        eval_batch_size=1,
        num_workers=4,
        source="squad",
    ),
    model=dict(
        lora_r=8,                      # From bash script
        metalora_r=128,                # From bash script
        ift_additional_metalora_r=-1,
        num_mem_token=4,               # Placeholder, calculated later
        metamodel_class_path="LoraQwen.LoraQwen3ForCausalLM",
        config_class_path="LoraQwen.Qwen3Config",
        tokenizer_from="./models/Qwen3-8B",
        model_from="./models/Qwen3-8B",
    ),
    metanetwork=dict(
        type="transformer",
        method="rl",                   # From bash script
        transformer_cfg=dict(
            encoder_cfg=dict(_encoder_layer_cfg),
            couple_encoder_cfg=dict(_encoder_layer_cfg),
            layer_transformer_first=True,
            mean_pool_size=1,
            num_layers=4,              # From bash script
            couple_num_layers=0,
            scale=0.001,
        ),
    ),
    test=dict(
        context_avg_len=512,
        context_max_length=1550,       # From bash script
        conversation_max_length=5000,  # From bash script
        max_new_tokens=128,
    ),
    # The next two are placeholders (-1) that the model-initialization cell
    # overwrites from the loaded model config; num_mem_token is recomputed there too.
    hidden_size=-1,
    num_layers=-1,
    num_mem_token=4,
)

# Wrap the dict in OmegaConf so values can be read with attribute access
# (cfg.model.lora_r), matching the access pattern of the original code.
cfg = OmegaConf.create(conf_dict)
logger.info(f"Configuration loaded. Seed: {cfg.run.seed}")


2026-02-08 23:57:38,765 - notebook_test - INFO - Configuration loaded. Seed: 42


## Helper Functions

In [4]:
def extract_think_and_answer(text: str) -> Tuple[str, str]:
    """
    Split raw model output into ``(think_part, answer_part)``.

    Four tag layouts are handled:

    * ``<think>...</think>answer`` -> reasoning and answer are separated;
      anything before ``<think>`` is discarded.
    * ``<think>...`` with no closing tag (generation stopped while thinking)
      -> everything after the tag is reasoning and the answer is empty.
    * ``...</think>answer`` with no opening tag (for instance when the opening
      tag was part of the prompt rather than of the generated text)
      -> text before the closing tag is reasoning, the rest is the answer.
    * no tags at all -> empty reasoning, the whole text is the answer.

    A leading ``answer:`` / ``final answer:`` prefix (case-insensitive) is
    removed from the answer.

    Args:
        text: Decoded model output.

    Returns:
        Tuple ``(think, answer)``, both stripped of surrounding whitespace.
    """
    open_tag, close_tag = "<think>", "</think>"

    if open_tag in text:
        # Only the text after the first opening tag is considered.
        after_open = text.split(open_tag, 1)[1]
        if close_tag in after_open:
            think, answer = after_open.split(close_tag, 1)
        else:
            # Thinking block started but never closed (incomplete generation).
            think, answer = after_open, ""
    elif close_tag in text:
        # Orphan closing tag: without this branch the reasoning and the tag
        # itself would be returned as part of the answer.
        think, answer = text.split(close_tag, 1)
    else:
        think, answer = "", text

    think = think.strip()
    answer = answer.strip()

    # Clean up standard answer prefixes.
    answer = re.sub(r"^(final answer|answer)\s*:\s*", "", answer, flags=re.IGNORECASE).strip()

    return think, answer

@torch.no_grad()
def generate_multiturn(
    metanetwork,
    dataloader,
    tokenizer,
    device,
    use_metanet: bool = True,
    metalora: Optional[torch.Tensor] = None,
    max_new_tokens: int = 500,
    max_conversation_length: int = 3000,
):
    """Play every conversation in ``dataloader`` turn by turn and return the transcripts.

    Only element 0 of each batch field is read for the questions, the initial
    message and the generated output, so the loader is expected to yield one
    conversation per batch.

    Args:
        metanetwork: Object providing ``eval()``, ``generate_lora_dict(...)`` and
            a ``metamodel`` whose ``generate`` accepts ``loradict`` and
            ``ignore_mem_token`` keyword arguments.
        dataloader: Yields batches with the keys ``questions``,
            ``initial_messages``, ``evidence``, ``evidence_ids`` and
            ``evidence_attention_mask``.
        tokenizer: Tokenizer used for the chat template and for decoding.
        device: Device the input tensors are moved to.
        use_metanet: When True, a LoRA dict is generated from the evidence and
            passed to ``generate``; when False, ``loradict`` is ``None``.
        metalora: Forwarded unchanged to ``generate_lora_dict``.
        max_new_tokens: Generation budget per turn.
        max_conversation_length: Length every prompt is truncated/padded to.

    Returns:
        A list with one entry per batch. Each entry is a list whose first item
        is ``{"initial message": ..., "error_count": ...}`` followed by one dict
        per question with the keys ``turn``, ``question``, ``think``, ``answer``.
    """
    metanetwork.eval()
    transcripts = []

    for batch in tqdm(dataloader, total=len(dataloader), desc="Generating"):
        turn_questions = batch['questions'][0]

        # A non-empty dict is used as the first message of the history;
        # anything else means the conversation starts empty.
        first_message = batch["initial_messages"][0]
        starts_with_message = isinstance(first_message, dict) and bool(first_message)
        history = [first_message] if starts_with_message else []

        # Raw evidence entry: read here but not used further in this function.
        evidence = batch["evidence"][0]
        context_ids = batch["evidence_ids"].to(device, non_blocking=True)
        context_mask = batch["evidence_attention_mask"].to(device, non_blocking=True)

        # The LoRA dict is produced once per conversation and reused for every turn.
        generated_lora = (
            metanetwork.generate_lora_dict(context_ids, context_mask, metalora)
            if use_metanet
            else None
        )

        header = {"initial message": deepcopy(history)}
        turn_records = []
        failed_turns = 0

        for turn_no, question in enumerate(turn_questions, start=1):
            history.append({"role": "user", "content": question})

            # Render the whole history with the chat template; each prompt is
            # padded to max_conversation_length.
            encoded = tokenizer.apply_chat_template(
                history,
                add_generation_prompt=True,
                tokenize=True,
                return_tensors="pt",
                max_length=max_conversation_length,
                truncation=True,
                return_dict=True,
                padding="max_length",
            )
            prompt_ids = encoded["input_ids"].to(device)
            prompt_mask = encoded["attention_mask"].to(device)

            # Greedy decoding; `loradict` carries the dynamically generated LoRA
            # (or None) into the custom generate implementation.
            generated = metanetwork.metamodel.generate(
                input_ids=prompt_ids,
                attention_mask=prompt_mask,
                max_new_tokens=max_new_tokens,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=False,
                ignore_mem_token=True,
                loradict=generated_lora,
            )

            # Keep only the tokens produced after the prompt.
            prompt_len = prompt_ids.shape[1]
            completion = tokenizer.decode(generated[0, prompt_len:], skip_special_tokens=True)
            think_text, answer_text = extract_think_and_answer(completion)

            # Nothing usable came back: count it and mark the turn.
            if not (think_text or answer_text):
                failed_turns += 1
                think_text = "[error]"

            # Only the answer (not the reasoning) is fed back into the history.
            history.append({"role": "assistant", "content": answer_text})
            turn_records.append({
                "turn": turn_no,
                "question": question,
                "think": think_text,
                "answer": answer_text,
            })

        header["error_count"] = failed_turns
        transcripts.append([header, *turn_records])

    return transcripts


## Model Initialization & Checkpoint Loading

Here we load the checkpoint obtained after mqa instruction fine-tuning, which is intended for multi-turn conversations.
If you want better single-turn performance, change the resume path in step 8 to load the checkpoint obtained after both mqa and 1qa instruction fine-tuning.

In [5]:
print(f"Current Working Directory: {os.getcwd()}")

# 1. Setup Seed and Device
set_seed(int(cfg.run.seed))
# The config documents "auto | cuda | cpu" for run.device; "auto" is treated
# like "cuda" (use the GPU when one is available) instead of silently
# selecting the CPU.
wants_cuda = cfg.run.device in ("cuda", "auto")
device = torch.device("cuda" if wants_cuda and torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# 2. Dynamic Import of Model Classes
logger.info("Loading model classes...")
try:
    MetaModelCls = _import_class(cfg.model.metamodel_class_path)
    ConfigCls = _import_class(cfg.model.config_class_path)
except Exception as e:
    # Chain explicitly so the original failure stays attached as the cause.
    raise ImportError(
        f"Could not import model classes: {cfg.model.metamodel_class_path}. Error: {e}"
    ) from e

# 3. Load Config
logger.info(f"Loading config from {cfg.model.model_from}")
config = ConfigCls.from_pretrained(cfg.model.model_from)
config.num_mem_token = -1  # placeholder until step 4 computes the real value
cfg.hidden_size = config.hidden_size
cfg.num_layers = config.num_hidden_layers

# 4. Calculate num_mem_token for Transformer Metanetwork
logger.info("Calculating memory tokens for Transformer Metanetwork...")
# Temporarily load model to calculate params
tmp_model = MetaModelCls.from_pretrained(cfg.model.model_from, config=config)

lora_params = tmp_model.lora_params_numel(cfg.model.lora_r)
base_params = cfg.hidden_size * cfg.num_layers

assert lora_params % base_params == 0, \
    f"lora_params ({lora_params}) must be divisible by hidden*layers ({base_params})"

# num_mem_token = LoRA parameter count divided by (hidden_size * num_layers).
config.num_mem_token = lora_params // base_params
cfg.num_mem_token = config.num_mem_token
# Drop the temporary model and release memory before the real model is loaded.
del tmp_model
gc.collect()
torch.cuda.empty_cache()
logger.info(f"Set num_mem_token to {config.num_mem_token}")

# 5. Load Tokenizer (a single load; the second identical call was redundant)
tokenizer = AutoTokenizer.from_pretrained(
    cfg.model.tokenizer_from, padding_side="left", use_fast=True
)
tokenizer.add_tokens(['<RECON>', '<COMP>', '<NOTHING>'])
# Replace the tokenizer's chat template with the Jinja template below.
# As written, when add_generation_prompt is set the prompt ends with an empty
# "<think>\n\n</think>\n\n" block unless the caller passes enable_thinking=False.
tokenizer.chat_template = "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].content + '\\n\\n' }}\n    {%- endif %}\n    {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0].role == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n    {%- set index = (messages|length - 1) - loop.index0 %}\n    {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n        {%- set ns.multi_step_tool = false %}\n        {%- set ns.last_query_index = index %}\n    {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n    {%- if message.content is string %}\n        {%- set content = message.content %}\n    {%- else %}\n        {%- set content = '' %}\n    {%- endif %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n        {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {%- set reasoning_content = '' %}\n        {%- if message.reasoning_content is string %}\n            {%- set reasoning_content = message.reasoning_content 
%}\n        {%- else %}\n            {%- if '</think>' in content %}\n                {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n                {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n            {%- endif %}\n        {%- endif %}\n        {%- if loop.index0 > ns.last_query_index %}\n            {%- if (loop.last or (not loop.last and reasoning_content)) and (enable_thinking is not defined or enable_thinking != false) %}\n                {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n            {%- else %}\n                {{- '<|im_start|>' + message.role + '\\n' + content }}\n            {%- endif %}\n        {%- else %}\n            {{- '<|im_start|>' + message.role + '\\n' + content }}\n        {%- endif %}\n        {%- if message.tool_calls %}\n            {%- for tool_call in message.tool_calls %}\n                {%- if (loop.first and content) or (not loop.first) %}\n                    {{- '\\n' }}\n                {%- endif %}\n                {%- if tool_call.function %}\n                    {%- set tool_call = tool_call.function %}\n                {%- endif %}\n                {{- '<tool_call>\\n{\"name\": \"' }}\n                {{- tool_call.name }}\n                {{- '\", \"arguments\": ' }}\n                {%- if tool_call.arguments is string %}\n                    {{- tool_call.arguments }}\n                {%- else %}\n                    {{- tool_call.arguments | tojson }}\n                {%- endif %}\n                {{- '}\\n</tool_call>' }}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- content 
}}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n    {%- if enable_thinking is not defined or enable_thinking != false %}\n        {{- '<think>\\n\\n</think>\\n\\n' }}\n    {%- endif %}\n{%- endif %}"
    

# 6. Load Actual Model
logger.info("Loading main model...")
metamodel = MetaModelCls.from_pretrained(cfg.model.model_from, config=config)
metamodel.reset_mem_tokens()
# Three tokens were added to the tokenizer, so the embedding table is resized to match.
metamodel.resize_token_embeddings(len(tokenizer))

# 7. Initialize Metanetwork
logger.info("Initializing Metanetwork...")
lora_numel = metamodel.lora_params_numel(cfg.model.lora_r)
metanetwork = Metanetwork(metamodel, cfg, lora_numel)
metanetwork.to(device)
# The base model is frozen; only the metanetwork is trained/used.
freeze(metamodel)

# 8. Load Checkpoint
# This path points at the checkpoint saved after mqa instruction fine-tuning
# (multi-turn conversations). For better single-turn performance, point it at
# the checkpoint saved after both mqa and 1qa instruction fine-tuning instead.
ckpt_root = os.path.join("checkpoints", f"{cfg.name}", "iftpwc")
resume_dir = os.path.join(ckpt_root, f"checkpoint-{cfg.test_global_step}")

logger.info(f"Attempting to load checkpoint from: {resume_dir}")

if not os.path.exists(resume_dir):
    # No checkpoint on disk: continue with the freshly initialized metanetwork.
    logger.warning(f"Checkpoint directory {resume_dir} not found! Running with initialized weights.")
    metalora = None
else:
    metanetwork, metalora, _ = load_checkpoint(metanetwork, resume_dir, device)
    logger.info("Checkpoint loaded successfully.")


Current Working Directory: /home/liuyewei/SHINE


2026-02-08 23:57:39,178 - notebook_test - INFO - Using device: cuda
2026-02-08 23:57:39,179 - notebook_test - INFO - Loading model classes...
2026-02-08 23:57:39,644 - notebook_test - INFO - Loading config from ./models/Qwen3-8B
2026-02-08 23:57:39,648 - notebook_test - INFO - Calculating memory tokens for Transformer Metanetwork...
Loading checkpoint shards: 100%|██████████| 5/5 [00:07<00:00,  1.59s/it]
2026-02-08 23:57:53,937 - notebook_test - INFO - Set num_mem_token to 148
2026-02-08 23:57:57,330 - notebook_test - INFO - Loading main model...
Loading checkpoint shards: 100%|██████████| 5/5 [00:07<00:00,  1.58s/it]
Some weights of LoraQwen3ForCausalLM were not initialized from the model checkpoint at ./models/Qwen3-8B and are newly initialized: ['model.mem_tokens']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2026-02-08 23:58:30,618 - notebook_test - INFO - Initializing Metanetwork...
2026-02-08 23:59:28,872 - noteboo

## Data Loading

You can edit the data below to try your own contexts and questions.

In [6]:
logger.info("Preparing Data...")

# Each entry pairs one context with the list of questions asked about it, in order.
# Change the contexts and questions as you like.
data = [
    {
        "context": "Apple is green.",
        "questions": [
            "What color is an apple?",
            "What color is a banana",
        ],
    },
    {
        "context": "Chinese food is the best food on earth.",
        "questions": [
            "Which food is the best?",
            "What do you want to eat?",
        ],
    },
    {
        "context": "If the light is on, somebody must be at home. If the light is off, often nobody is at home. But this holds true only during the day. In the night people are all sleeping so there will always be no lights.",
        "questions": [
            "What does it mean if the light is on?",
            "What does it mean if the light is off?",
            "Why in the night this rule might not hold true?",
        ],
    },
    {
        "context": "Whoever organizes cheating in a national examination prescribed by law shall be sentenced to fixed-term imprisonment of not more than three years or criminal detention and shall also be fined, or shall be fined only; if the circumstances are serious, he shall be sentenced to fixed-term imprisonment of not less than three years but not more than seven years and shall also be fined.",
        "questions": [
            "What will happen if one organize cheating?",
            "How long will one be imprisoned if he cheated in an exam and the situation is very serious?",
        ],
    },
    {
        "context": "When someone says \"fair enough\", it can mean two slightly different things, and you usually understand which one it is from the tone and the moment. Sometimes it means real agreement — the person has heard your reason, it makes sense to them, and they are genuinely okay with it. Other times, it does not mean they agree at all. It is more like a polite way of saying, \"I do not think the same, but I am done arguing.\" In that case, fair enough is about keeping the conversation calm and moving on, not about changing their mind.",
        "questions": [
            "What does \"fair enough\" mean?",
            "Does \"fair enough\" have the agree meaning?",
            "Does \"fair enough\" have disagree meaning?",
            "OK, fair enough.",
        ],
    },
]

human_dataset = HumanDataset(data)

# All three collators share the tokenizer, the test length limits and cfg;
# they differ only in the sys_msg / no_evidence flags.
_make_collator = partial(
    HumanCollator,
    tokenizer,
    context_max_length=cfg.test.context_max_length,
    conversation_max_length=cfg.test.conversation_max_length,
    cfg=cfg,
)
test_collator = _make_collator()
test_prompt_collator = _make_collator(sys_msg=True)
test_only_question_collator = _make_collator(sys_msg=True, no_evidence=True)


def _make_test_loader(collator):
    """Build a sequential, batch-size-1 DataLoader over ``human_dataset`` using ``collator``."""
    return DataLoader(
        human_dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=collator,
        num_workers=0,
        pin_memory=(device.type == "cuda"),
    )


generate_test_loader = _make_test_loader(test_collator)
generate_prompt_test_loader = _make_test_loader(test_prompt_collator)
generate_only_question_test_loader = _make_test_loader(test_only_question_collator)


logger.info("Data Loader ready.")


2026-02-09 00:00:55,176 - notebook_test - INFO - Preparing Data...
2026-02-09 00:00:55,180 - notebook_test - INFO - Data Loader ready.


## Inference

In [7]:
logger.info("Starting Inference...")

# Run generation
# `metalora` comes from the checkpoint-loading step (step 8) of the model
# initialization cell.

# Limits shared by all three runs.
# NOTE(review): cfg.test.max_new_tokens (128) and cfg.test.conversation_max_length
# (5000) are defined in the config but not used here — confirm which values are intended.
generation_limits = dict(max_new_tokens=500, max_conversation_length=3000)

# SHINE: the metanetwork turns the context into a LoRA that is applied during generation.
results = generate_multiturn(
    metanetwork,
    generate_test_loader,
    tokenizer,
    device,
    use_metanet=True,
    metalora=metalora,
    **generation_limits,
)

# In-context run: no generated LoRA; uses the loader built with sys_msg=True.
results_in_context = generate_multiturn(
    metanetwork,
    generate_prompt_test_loader,
    tokenizer,
    device,
    use_metanet=False,
    **generation_limits,
)

# Only-question run: no generated LoRA; uses the loader built with
# sys_msg=True and no_evidence=True.
results_only_question = generate_multiturn(
    metanetwork,
    generate_only_question_test_loader,
    tokenizer,
    device,
    use_metanet=False,
    **generation_limits,
)

# (label, transcripts) pairs in the order they are printed.
labelled_runs = (
    ("SHINE", results),
    ("In-Context", results_in_context),
    ("Only Question", results_only_question),
)

# The f-strings below never reuse the enclosing quote character inside a
# replacement field, so this cell also parses on Python versions before 3.12.
for i in range(len(results)):
    print(f"--- Conversation {i + 1} ---")
    print("Initial message:")
    for label, run in labelled_runs:
        print(f"[{label:<14}]: {run[i][0]['initial message']}")
    print(f"Context: {data[i]['context']}")
    # Entry 0 of each transcript is the header; turns start at index 1.
    for j in range(1, len(results[i])):
        print(f"Turn: {j}")
        print(f"Question: {results[i][j]['question']}")
        for label, run in labelled_runs:
            print(f"[{label:<14}]: {run[i][j]['answer']}")
    print("\n\n")


2026-02-09 00:00:55,197 - notebook_test - INFO - Starting Inference...
Generating:   0%|          | 0/5 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Generating: 100%|██████████| 5/5 [00:39<00:00,  7.82s/it]
Generating: 100%|██████████| 5/5 [00:43<00:00,  8.67s/it]
Generating: 100%|██████████| 5/5 [00:24<00:00,  4.80s/it]

--- Conversation 1 ---
Initial message:
[SHINE         ]: []
[In-Context    ]: [{'role': 'system', 'content': 'You are a helpful assistant, answer the questions based on the given context. Each answer must be directly extractable from the context (i.e., an exact span or minor paraphrase for fluency). Do not invent information. Answer the question directly and output nothing else. Never enter think mode.\n\nContext: Apple is green.'}]
[Only Question ]: [{'role': 'system', 'content': "You are a helpful assistant. Answer the question concisely with short words or phrases. Answer the question directly and output nothing else. Never say you don't know the answer. Never enter think mode."}]
Context: Apple is green.
Turn: 1
Question: What color is an apple?
[SHINE         ]: An apple can be green, red, or yellow, depending on the variety.
[In-Context    ]: An apple is green.
[Only Question ]: Red.
Turn: 2
Question: What color is a banana
[SHINE         ]: A banana is typically yellow when rip


