<a href="https://colab.research.google.com/github/allyoushawn/jupyter_notebook_projects/blob/main/ml_misc/tiger_semantic_id_amazon_beauty_llm_finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TIGER SemanticID: Qwen3-8B Fine-tuning for SID Recommendation

This notebook fine-tunes Qwen3-8B to generate Semantic IDs for next-item recommendation.

## Pipeline Overview

1. **Build Dialogs**: Convert user histories to conversational format and build a trie of valid SIDs
2. **Tokenizer Resize**: Add 1,027 new SID tokens to the Qwen tokenizer
3. **Stage A (Vocab)**: Fine-tune only the embeddings so the model learns the new SID tokens
4. **Stage B (LoRA)**: Fine-tune LoRA adapters on the recommendation task
5. **Inference**: Generate SIDs with level + trie constraints
6. **Evaluation**: Measure SID@1 and Invalid-ID@1, and inspect qualitative examples
7. **Acceptance Criteria**: Summarize pass/fail checks

In [None]:
# Install dependencies (Colab)
!pip install -q transformers accelerate peft datasets bitsandbytes tiktoken sentencepiece jsonlines orjson


In [None]:
import os
assert os.path.exists('/content/drive'), "Mount Google Drive first: from google.colab import drive; drive.mount('/content/drive')"
WORK_DIR = '/content/drive/MyDrive/colab/tiger_semantic_id_amazon_beauty'
%mkdir -p $WORK_DIR
%cd $WORK_DIR

/content/drive/MyDrive/colab/tiger_semantic_id_amazon_beauty


In [None]:
# Clone repo, install dependencies, and make src importable (Colab-friendly)
try:
    import google.colab  # type: ignore
    IN_COLAB = True
except Exception:
    IN_COLAB = False

repo_url = 'https://github.com/allyoushawn/recsys_playground.git'
repo_dir = 'recsys_playground'
branch_name = '20250908_tiger_dev'

import os
if IN_COLAB:
    if os.path.exists(repo_dir):
        !rm -rf {repo_dir}
    !git clone $repo_url
    %cd $repo_dir
    !git fetch --all
    !git checkout $branch_name || echo 'Branch not found; staying on default.'


Cloning into 'recsys_playground'...
remote: Enumerating objects: 450, done.[K
remote: Counting objects: 100% (197/197), done.[K
remote: Compressing objects: 100% (126/126), done.[K
remote: Total 450 (delta 139), reused 125 (delta 71), pack-reused 253 (from 1)[K
Receiving objects: 100% (450/450), 202.92 KiB | 6.34 MiB/s, done.
Resolving deltas: 100% (260/260), done.
/content/drive/MyDrive/colab/tiger_semantic_id_amazon_beauty/recsys_playground
Fetching origin
Branch '20250908_tiger_dev' set up to track remote branch '20250908_tiger_dev' from 'origin'.
Switched to a new branch '20250908_tiger_dev'


In [None]:
# Imports
import sys
from pathlib import Path

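# Add the repo root (for `tiger_semantic_id_amazon_beauty.src.llm...` imports) and the project's src dir to the path
sys.path.insert(0, f'{WORK_DIR}/{repo_dir}')
sys.path.insert(0, f'{WORK_DIR}/{repo_dir}/tiger_semantic_id_amazon_beauty/src')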

# Config
ARTIFACTS_DIR = f'{WORK_DIR}/artifacts'
LLM_DIR = f'{ARTIFACTS_DIR}/llm'
#LLM_DIR = '/content/llm'
!mkdir -p $LLM_DIR

print(f"Artifacts: {ARTIFACTS_DIR}")
print(f"LLM outputs: {LLM_DIR}")

Artifacts: /content/drive/MyDrive/colab/tiger_semantic_id_amazon_beauty/artifacts
LLM outputs: /content/drive/MyDrive/colab/tiger_semantic_id_amazon_beauty/artifacts/llm


## 1. Build Dialogs

Convert user histories to chat-style JSONL dialogs and build a trie of valid SID continuations.
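A trie of valid SID continuations answers one question: given the codes generated so far, which codes may come next? Below is a minimal sketch, assuming each SID is a tuple of integer codes. It stores the trie as a flat map from each prefix to its allowed next codes. The helper names `build_sid_trie` and `valid_next_codes` are hypothetical, and the repo's `build_sid_dialogs` script may store the trie differently (its log mentions `valid_c2`/`valid_c3`/`valid_c4` tables).

```python
from typing import Dict, Iterable, List, Set, Tuple

SID = Tuple[int, ...]

def build_sid_trie(sids: Iterable[SID]) -> Dict[SID, Set[int]]:
    """Flattened trie: map every proper prefix of a catalog SID to the codes that may follow it.

    () -> valid L1 codes, (L1,) -> valid L2 codes, (L1, L2) -> valid L3 codes, and so on.
    """
    trie: Dict[SID, Set[int]] = {}
    for sid in sids:
        for depth in range(len(sid)):
            trie.setdefault(tuple(sid[:depth]), set()).add(sid[depth])
    return trie

def valid_next_codes(trie: Dict[SID, Set[int]], prefix: SID) -> List[int]:
    """Sorted codes allowed after `prefix`; empty if the prefix matches no catalog item."""
    return sorted(trie.get(tuple(prefix), set()))

# Toy catalog (hypothetical SIDs, not taken from the Amazon Beauty artifacts)
catalog = [(163, 60, 81, 0), (163, 60, 81, 1), (163, 60, 82, 0), (12, 5, 7, 3)]
trie = build_sid_trie(catalog)
print(valid_next_codes(trie, ()))             # [12, 163]
print(valid_next_codes(trie, (163, 60)))      # [81, 82]
print(valid_next_codes(trie, (163, 60, 81)))  # [0, 1]
```

For this notebook's artifacts, the build log reports one L1 code, one (L1, L2) pair, and one (L1, L2, L3) triple. That suggests every catalog SID shares the prefix `(163, 60, 81)`, as the `sid_to_items` dump in the inference section also shows. In that case the trie only constrains L4.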

In [None]:
# Build dialogs and trie
!python -m tiger_semantic_id_amazon_beauty.src.llm.build_sid_dialogs \
    --artifacts_dir $ARTIFACTS_DIR \
    --out $LLM_DIR \
    --history_len 8 \
    --train_ratio 0.95

Loading artifacts from /content/drive/MyDrive/colab/tiger_semantic_id_amazon_beauty/artifacts...
Loaded 12101 semantic IDs
Building SID trie...
Saved trie to /content/drive/MyDrive/colab/tiger_semantic_id_amazon_beauty/artifacts/llm/sid_trie.pkl
  - valid_c2: 1 L1 codes
  - valid_c3: 1 (L1,L2) pairs
  - valid_c4: 1 (L1,L2,L3) triples
Loaded sequences for 22363 users
Creating dialogs...
Building dialogs: 100% 22363/22363 [00:00<00:00, 22366.82it/s]
Created 37052 dialogs
Average history length: 8.0 items
Split: 35199 train, 1853 valid
Saved training dialogs to /content/drive/MyDrive/colab/tiger_semantic_id_amazon_beauty/artifacts/llm/dialogs_train.jsonl
Saved validation dialogs to /content/drive/MyDrive/colab/tiger_semantic_id_amazon_beauty/artifacts/llm/dialogs_valid.jsonl

=== Example Dialog ===
SYSTEM:
You are a recommender that must reply ONLY with the next product's Semantic ID as 4 tokens in order: L1, L2, L3, L4.
Valid token ranges by level:
- L1: <sid_0>.. <sid_255>
- L2: <sid_25

In [None]:
# Verify outputs
import jsonlines

train_path = f'{LLM_DIR}/dialogs_train.jsonl'
valid_path = f'{LLM_DIR}/dialogs_valid.jsonl'
trie_path = f'{LLM_DIR}/sid_trie.pkl'

with jsonlines.open(train_path) as reader:
    train_dialogs = list(reader)

with jsonlines.open(valid_path) as reader:
    valid_dialogs = list(reader)

print(f"Train dialogs: {len(train_dialogs)}")
print(f"Valid dialogs: {len(valid_dialogs)}")
print(f"Trie exists: {os.path.exists(trie_path)}")

# Show example
print("\n=== Example Dialog ===")
example = train_dialogs[0]
for msg in example['messages']:
    print(f"{msg['role'].upper()}:")
    print(msg['content'][:200] + "..." if len(msg['content']) > 200 else msg['content'])
    print()

Train dialogs: 35199
Valid dialogs: 1853
Trie exists: True

=== Example Dialog ===
SYSTEM:
You are a recommender that must reply ONLY with the next product's Semantic ID as 4 tokens in order: L1, L2, L3, L4.
Valid token ranges by level:
- L1: <sid_0>.. <sid_255>
- L2: <sid_256>.. <sid_511>
...

USER:
History:
<sid_163> <sid_316> <sid_593> <sid_2384>
<sid_163> <sid_316> <sid_593> <sid_5325>
<sid_163> <sid_316> <sid_593> <sid_6447>
<sid_163> <sid_316> <sid_593> <sid_7323>
<sid_163> <sid_316> <sid_59...

ASSISTANT:
<sid_163> <sid_316> <sid_593> <sid_5593>



## 2. Tokenizer Resize

Add 1,027 new SID tokens to the Qwen tokenizer and initialize their embeddings.
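The system prompt gives each level a contiguous block of token IDs: `<sid_0>`..`<sid_255>` for L1 and `<sid_256>`..`<sid_511>` for L2. The evaluation cell decodes tokens using the same offset of 256 per level. Below is a minimal sketch of that encode/decode round trip. It assumes the 256-codes-per-level layout holds for all four levels, and the helper names are hypothetical.

```python
from typing import List, Sequence, Tuple

CODES_PER_LEVEL = 256  # assumption: inferred from the L1/L2 ranges in the system prompt
NUM_LEVELS = 4

def sid_to_tokens(sid: Sequence[int]) -> List[str]:
    """Encode level-local codes as offset tokens: level i, code c -> <sid_{i * 256 + c}>."""
    if len(sid) != NUM_LEVELS:
        raise ValueError(f"expected {NUM_LEVELS} codes, got {len(sid)}")
    tokens = []
    for level, code in enumerate(sid):
        if not 0 <= code < CODES_PER_LEVEL:
            raise ValueError(f"code {code} is out of range for L{level + 1}")
        tokens.append(f"<sid_{level * CODES_PER_LEVEL + code}>")
    return tokens

def tokens_to_sid(tokens: Sequence[str]) -> Tuple[int, ...]:
    """Invert sid_to_tokens: strip the '<sid_' ... '>' wrapper, then subtract the level offset."""
    codes = []
    for level, token in enumerate(tokens):
        if not (token.startswith("<sid_") and token.endswith(">")):
            raise ValueError(f"not an SID token: {token!r}")
        codes.append(int(token[len("<sid_"):-1]) - level * CODES_PER_LEVEL)
    return tuple(codes)

tokens = sid_to_tokens((163, 60, 81, 5))
print(" ".join(tokens))       # <sid_163> <sid_316> <sid_593> <sid_773>
print(tokens_to_sid(tokens))  # (163, 60, 81, 5)
```

The example dialog above also contains L4 tokens such as `<sid_2384>` and `<sid_5593>`. These fall outside a 256-code L4 block (`<sid_768>`..`<sid_1023>`). If those strings are not among the 1,027 added tokens, the tokenizer would likely split them into several ordinary tokens, so check how the L4 codes are encoded.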

In [None]:
# Resize tokenizer and model
!python -m tiger_semantic_id_amazon_beauty.src.llm.tokenizer_resize_qwen \
    --base Qwen/Qwen3-8B \
    --out $LLM_DIR/qwen3_vocab_stage \
    --torch_dtype bfloat16

## 3. Stage A: Vocabulary Extension

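Fine-tune **only the embeddings** to teach the model the new SID tokens. The checkpoint is updated in place: `--in_model` and `--out_model` both point to `qwen3_vocab_stage`.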
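With the flags in the cell below, each optimizer step sees `per_device_train_batch_size × gradient_accumulation_steps` = 4 × 8 = 32 dialogs per GPU. Below is a back-of-the-envelope sketch of what that means for one epoch over the 35,199 training dialogs. It assumes a single GPU, `training_schedule` is a hypothetical helper, and the Trainer's exact rounding may differ slightly.

```python
import math

def training_schedule(num_examples: int, per_device_batch: int, grad_accum: int,
                      num_gpus: int = 1, epochs: int = 1, warmup_ratio: float = 0.0):
    """Approximate (effective batch, optimizer steps, warmup steps) for a Trainer-style run."""
    effective_batch = per_device_batch * grad_accum * num_gpus
    micro_batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_gpus))
    steps_per_epoch = math.ceil(micro_batches_per_epoch / grad_accum)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return effective_batch, total_steps, warmup_steps

# Stage A flags: --per_device_train_batch_size 4 --gradient_accumulation_steps 8
#                --num_train_epochs 1 --warmup_ratio 0.03, on 35,199 training dialogs
effective_batch, total_steps, warmup_steps = training_schedule(
    35199, per_device_batch=4, grad_accum=8, epochs=1, warmup_ratio=0.03)
print(effective_batch, total_steps, warmup_steps)
```

This gives an effective batch of 32 and about 1,100 optimizer steps. At that length, `--save_steps 500` checkpoints roughly twice and `--logging_steps 50` logs about 22 times. Stage B uses the same batch flags, so the same numbers apply there.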

In [None]:
# Stage A: Embeddings only
!python -m tiger_semantic_id_amazon_beauty.src.llm.finetune_qwen_vocab \
    --data $LLM_DIR/dialogs_train.jsonl \
    --valid $LLM_DIR/dialogs_valid.jsonl \
    --in_model $LLM_DIR/qwen3_vocab_stage \
    --out_model $LLM_DIR/qwen3_vocab_stage \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --learning_rate 5e-4 \
    --num_train_epochs 1 \
    --warmup_ratio 0.03 \
    --logging_steps 50 \
    --save_steps 500 \
    --bf16 \
    --gradient_checkpointing


In [None]:
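# Free GPU memory held by this kernel before Stage B to avoid OOM (the training scripts run as subprocesses and release their own memory when they exit)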
import torch
import gc

# Force garbage collection
gc.collect()

# Clear CUDA cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

print("GPU memory cleared")

GPU memory cleared


## 4. Stage B: LoRA Fine-tuning

Fine-tune **LoRA adapters** on top of the Stage A checkpoint for memory-efficient training (roughly 20-25 GB of VRAM instead of ~60 GB).
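LoRA leaves a weight `W` (shape `d_out × d_in`) frozen and learns a low-rank update `B @ A`, where `A` is `r × d_in` and `B` is `d_out × r`. The update is scaled by `lora_alpha / r`. The sketch below shows what `--lora_r 16 --lora_alpha 32` implies for each adapted projection. The layer shapes (hidden size 4096, 1024-dim K/V projections) are assumptions about Qwen3-8B, and this notebook does not show which modules the script actually targets.

```python
def lora_extra_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one linear layer: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)

def lora_scaling(alpha: float, r: int) -> float:
    """Factor applied to B @ A x before it is added to the frozen layer's output W x."""
    return alpha / r

r, alpha = 16, 32  # --lora_r / --lora_alpha from the cell below
# Assumed Qwen3-8B attention projection shapes (d_in, d_out); check them against the model config.
shapes = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
}
for name, (d_in, d_out) in shapes.items():
    frozen = d_in * d_out
    extra = lora_extra_params(d_in, d_out, r)
    print(f"{name}: {extra:,} trainable vs {frozen:,} frozen ({100 * extra / frozen:.2f}%)")
print("scaling =", lora_scaling(alpha, r))  # scaling = 2.0
```

Gradients and optimizer state are kept only for the small `A`/`B` matrices. The frozen base weights therefore cost little beyond their bf16 storage.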

In [None]:
# Stage B: LoRA fine-tuning with memory optimizations
!python -m tiger_semantic_id_amazon_beauty.src.llm.finetune_qwen_lora \
    --data $LLM_DIR/dialogs_train.jsonl \
    --valid $LLM_DIR/dialogs_valid.jsonl \
    --in_model $LLM_DIR/qwen3_vocab_stage \
    --out_model $LLM_DIR/qwen3_lora_adapter \
    --sid_trie $LLM_DIR/sid_trie.pkl \
    --lora_r 16 \
    --lora_alpha 32 \
    --lora_dropout 0.05 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --learning_rate 1e-4 \
    --num_train_epochs 1 \
    --warmup_ratio 0.03 \
    --logging_steps 50 \
    --save_steps 500 \
    --bf16 \
    --gradient_checkpointing


## 5. Inference Demo

Generate SIDs with level and trie constraints.
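Trie-constrained decoding allows, at each step, only codes that extend a valid catalog prefix, so the output can only be an existing SID. Below is a minimal greedy sketch. It assumes the model's next-code scores are already available as a dict per step, standing in for real logits. `SIDRecommender.generate_sid` is not shown here and may work differently, for example with beam search or transformers' `prefix_allowed_tokens_fn`.

```python
from typing import Callable, Dict, Iterable, Optional, Set, Tuple

SID = Tuple[int, ...]

def build_prefix_table(sids: Iterable[SID]) -> Dict[SID, Set[int]]:
    """Map each proper prefix of a catalog SID to the set of codes allowed next."""
    table: Dict[SID, Set[int]] = {}
    for sid in sids:
        for depth in range(len(sid)):
            table.setdefault(tuple(sid[:depth]), set()).add(sid[depth])
    return table

def constrained_greedy(score_fn: Callable[[SID], Dict[int, float]],
                       table: Dict[SID, Set[int]], num_levels: int = 4) -> Optional[SID]:
    """Greedy decoding that only ever picks a code allowed by the prefix table."""
    prefix: SID = ()
    for _ in range(num_levels):
        allowed = table.get(prefix, set())
        if not allowed:
            return None  # dead end: no catalog SID starts with this prefix
        scores = score_fn(prefix)
        # Allowed codes the model did not score are treated as -inf rather than dropped.
        best = max(allowed, key=lambda c: scores.get(c, float("-inf")))
        prefix = prefix + (best,)
    return prefix

# Hypothetical toy catalog and scores for illustration only
catalog = [(1, 2, 3, 4), (1, 2, 3, 5), (7, 0, 0, 0)]
table = build_prefix_table(catalog)

def toy_scores(prefix: SID) -> Dict[int, float]:
    """Stand-in for model logits: the unconstrained favourite (9, 8, 3, 6) is not in the catalog."""
    by_level = [{9: 3.0, 1: 2.0, 7: 1.0}, {8: 5.0, 2: 1.0}, {3: 1.0}, {6: 4.0, 5: 2.0, 4: 1.0}]
    return by_level[len(prefix)]

print(constrained_greedy(toy_scores, table))  # (1, 2, 3, 5)
```

Greedy decoding commits to the best-scoring allowed L1 code even if that subtree contains only weak completions. Beam search over the same allowed sets avoids this, at the cost of more forward passes.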

In [None]:
# Interactive inference with LoRA adapter
from tiger_semantic_id_amazon_beauty.src.llm.inference_qwen import SIDRecommender
import json

# Load recommender with LoRA adapter
recommender = SIDRecommender(
    model_path=f'{LLM_DIR}/qwen3_lora_adapter',
    base_model_path=f'{LLM_DIR}/qwen3_vocab_stage',
    trie_path=f'{LLM_DIR}/sid_trie.pkl',
    is_lora_adapter=True,
)

# Load mappings
with open(f'{ARTIFACTS_DIR}/sid_to_items.json') as f:
    sid_to_items = json.load(f)

print("Recommender loaded!")

Loading model from /content/drive/MyDrive/colab/tiger_semantic_id_amazon_beauty/artifacts/llm/qwen3_lora_adapter...
Loading tokenizer from LoRA adapter: /content/drive/MyDrive/colab/tiger_semantic_id_amazon_beauty/artifacts/llm/qwen3_lora_adapter...
Loading base model from /content/drive/MyDrive/colab/tiger_semantic_id_amazon_beauty/artifacts/llm/qwen3_vocab_stage...


`torch_dtype` is deprecated! Use `dtype` instead!



Loading LoRA adapter from /content/drive/MyDrive/colab/tiger_semantic_id_amazon_beauty/artifacts/llm/qwen3_lora_adapter...
✓ Verified SID tokens in vocabulary (vocab_size=152696)
Loading trie from /content/drive/MyDrive/colab/tiger_semantic_id_amazon_beauty/artifacts/llm/sid_trie.pkl...
Model loaded!
Recommender loaded!


In [None]:
# Example: Generate from history
history_sids = [
    (64, 54, 125, 0),
    (64, 156, 194, 0),
    (112, 191, 11, 4),
]

results = recommender.recommend(
    history_sids=history_sids,
    sid_to_items=sid_to_items,
    top_k=5,
)

print("\n=== Generated Recommendation ===")
if results:
    result = results[0]
    print(f"Generated SID: {result['sid']}")
    print("Mapped Items:")
    for item_id in result['items']:
        print(f"  - {item_id}")
else:
    print("No valid SID generated")


=== Generated Recommendation ===
Generated SID: (163, 60, 81, 166)
Mapped Items:


In [None]:
results

[{'sid': (163, 60, 81, 166), 'items': []}]

In [None]:
sid_to_items

{'163-60-81-0': [0],
 '163-60-81-1': [1],
 '163-60-81-2': [2],
 '163-60-81-3': [3],
 '163-60-81-4': [4],
 '163-60-81-5': [5],
 '163-60-81-6': [6],
 '163-60-81-7': [7],
 '163-60-81-8': [8],
 '163-60-81-9': [9],
 '163-60-81-10': [10],
 '163-60-81-11': [11],
 '163-60-81-12': [12],
 '163-60-81-13': [13],
 '163-60-81-14': [14],
 '163-60-81-15': [15],
 '163-60-81-16': [16],
 '163-60-81-17': [17],
 '163-60-81-18': [18],
 '163-60-81-19': [19],
 '163-60-81-20': [20],
 '163-60-81-21': [21],
 '163-60-81-22': [22],
 '163-60-81-23': [23],
 '163-60-81-24': [24],
 '163-60-81-25': [25],
 '163-60-81-26': [26],
 '163-60-81-27': [27],
 '163-60-81-28': [28],
 '163-60-81-29': [29],
 '163-60-81-30': [30],
 '163-60-81-31': [31],
 '163-60-81-32': [32],
 '163-60-81-33': [33],
 '163-60-81-34': [34],
 '163-60-81-35': [35],
 '163-60-81-36': [36],
 '163-60-81-37': [37],
 '163-60-81-38': [38],
 '163-60-81-39': [39],
 '163-60-81-40': [40],
 '163-60-81-41': [41],
 '163-60-81-42': [42],
 '163-60-81-43': [43],
 '163-60

## 6. Evaluation

Measure SID@1 (exact match) and Invalid-ID@1 on up to 1,000 validation dialogs, then inspect a few qualitative examples.
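The evaluation cell below computes only @1 metrics, from a single greedy SID per example. If the recommender returned a ranked list of K candidate SIDs (for example via beam search), SID@K and Invalid-ID@K could be computed as in this sketch. The function names are hypothetical. Pooling all top-K candidates for Invalid-ID@K is one reasonable definition among several.

```python
from typing import Iterable, Sequence, Set, Tuple

SID = Tuple[int, ...]

def catalog_from_keys(keys: Iterable[str]) -> Set[SID]:
    """Turn sid_to_items-style keys such as '163-60-81-0' into SID tuples."""
    return {tuple(int(part) for part in key.split("-")) for key in keys}

def sid_at_k(ranked: Sequence[Sequence[SID]], targets: Sequence[SID], k: int) -> float:
    """Fraction of examples whose ground-truth SID is among the first k candidates."""
    if not targets:
        return 0.0
    hits = sum(1 for cands, gt in zip(ranked, targets) if gt in list(cands)[:k])
    return hits / len(targets)

def invalid_id_at_k(ranked: Sequence[Sequence[SID]], catalog: Set[SID], k: int) -> float:
    """Fraction of all top-k candidates (pooled over examples) that are not catalog SIDs."""
    top = [sid for cands in ranked for sid in list(cands)[:k]]
    if not top:
        return 0.0
    return sum(sid not in catalog for sid in top) / len(top)

# Toy data for illustration only
catalog = catalog_from_keys(["1-2-3-4", "1-2-3-5", "7-0-0-0"])
ranked = [[(1, 2, 3, 5), (1, 2, 3, 4)], [(9, 9, 9, 9), (7, 0, 0, 0)]]
targets = [(1, 2, 3, 4), (7, 0, 0, 0)]
print(sid_at_k(ranked, targets, 1), sid_at_k(ranked, targets, 2))                # 0.0 1.0
print(invalid_id_at_k(ranked, catalog, 1), invalid_id_at_k(ranked, catalog, 2))  # 0.5 0.25
```

In this notebook, `catalog_from_keys(sid_to_items)` would build the catalog set, since `sid_to_items` is keyed by hyphen-joined codes.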

In [None]:
# Evaluate on validation set
import numpy as np
from tqdm import tqdm

LEVEL_SIZE = 256  # token offset per SID level (<sid_0>..<sid_255> = L1, <sid_256>..<sid_511> = L2, ...)

def parse_sid(tokens):
    """Convert ['<sid_163>', '<sid_316>', ...] into level-local codes, e.g. (163, 60, ...)."""
    if len(tokens) != 4:
        raise ValueError(f"expected 4 SID tokens, got {len(tokens)}")
    return tuple(int(tok.split('_')[1].rstrip('>')) - i * LEVEL_SIZE for i, tok in enumerate(tokens))

def to_sid_key(sid):
    """sid_to_items is keyed by hyphen-joined codes, e.g. '163-60-81-0'."""
    return '-'.join(map(str, sid))

# Load validation dialogs
eval_size = min(1000, len(valid_dialogs))
eval_dialogs = valid_dialogs[:eval_size]

print(f"Evaluating on {eval_size} examples...")

# Metrics
invalid_count = 0
sid_hits = []
generated_sids = []

for dialog in tqdm(eval_dialogs):
    # Ground truth: the assistant turn holds the target SID tokens
    try:
        gt_sid = parse_sid(dialog['messages'][2]['content'].split())
    except (IndexError, ValueError):
        continue

    # History: skip the "History:" header and the trailing "Recommend next:" line
    user_msg = dialog['messages'][1]['content']
    history_sids = []
    for line in user_msg.split('\n')[1:-1]:
        try:
            history_sids.append(parse_sid(line.split()))
        except (IndexError, ValueError):
            continue  # skip malformed lines

    if not history_sids:
        continue

    # Generate SID
    try:
        generated_sid = recommender.generate_sid(history_sids=history_sids)
    except Exception as e:
        print(f"Error: {e}")
        invalid_count += 1
        continue

    if generated_sid is None:
        invalid_count += 1
        continue

    generated_sid = tuple(generated_sid)
    generated_sids.append(generated_sid)

    # Invalid if the generated SID does not exist in the catalog
    if to_sid_key(generated_sid) not in sid_to_items:
        invalid_count += 1
        continue

    # Exact match against ground truth
    sid_hits.append(int(generated_sid == gt_sid))

num_evaluated = len(sid_hits) + invalid_count
print("\n=== Evaluation Results ===")
print(f"Examples evaluated: {num_evaluated}")
if num_evaluated:
    print(f"Invalid-ID@1: {invalid_count / num_evaluated * 100:.2f}%")
print(f"SID@1 (exact match): {np.mean(sid_hits) * 100:.2f}%" if sid_hits else "SID@1 (exact match): N/A")
print(f"Unique SIDs generated: {len(set(generated_sids))}")

Evaluating on 1000 examples...


100%|██████████| 1000/1000 [04:25<00:00,  3.77it/s]


=== Evaluation Results ===
Examples evaluated: 1000
Invalid-ID@1: 100.00%
N/A
Unique SIDs generated: 8





In [None]:
# Qualitative examples
print("\n=== Qualitative Examples ===")

test_histories = [
    [(91, 54, 165, 0), (146, 204, 254, 0), (225, 239, 96, 0)],
    [(229, 236, 102, 0), (225, 212, 226, 1)],
    [(94, 233, 248, 0), (180, 191, 245, 0), (89, 141, 245, 0)],
]

for i, history in enumerate(test_histories, 1):
    print(f"\n[Example {i}]")
    print(f"History: {history}")

    generated_sid = recommender.generate_sid(history_sids=history)
    print(f"Generated SID: {generated_sid}")

    if generated_sid:
        sid_key = '-'.join(map(str, generated_sid))  # sid_to_items keys look like '163-60-81-0'
        items = sid_to_items.get(sid_key, [])
        print(f"Mapped to {len(items)} items")
        if items:
            print(f"  Top item: {items[0]}")


=== Qualitative Examples ===

[Example 1]
History: [(91, 54, 165, 0), (146, 204, 254, 0), (225, 239, 96, 0)]
Generated SID: (163, 60, 81, 166)
Mapped to 0 items

[Example 2]
History: [(229, 236, 102, 0), (225, 212, 226, 1)]
Generated SID: (163, 60, 81, 166)
Mapped to 0 items

[Example 3]
History: [(94, 233, 248, 0), (180, 191, 245, 0), (89, 141, 245, 0)]
Generated SID: (163, 60, 81, 166)
Mapped to 0 items


## 7. Acceptance Criteria

Check whether each acceptance criterion passes:

1. Stage A completes and the new SID tokens are learned
2. Stage B (LoRA) completes with validation loss decreasing
3. Invalid-ID@1 below 5% (ideally 0%)
4. SID@1 above 0% (longer-term target: SID@10 ≥ baseline)
5. Natural-language prompts produce valid SIDs

In [None]:
# Summary
import glob

def has_saved_weights(model_dir):
    """True if save_pretrained() wrote weights (safetensors or .bin, sharded or not)."""
    patterns = ('*.safetensors', 'pytorch_model*.bin')
    return any(glob.glob(os.path.join(model_dir, p)) for p in patterns)

print("\n" + "=" * 60)
print("ACCEPTANCE CRITERIA")
print("=" * 60)

print("\n[1] Stage A: Vocabulary Extension")
vocab_stage_ok = has_saved_weights(f'{LLM_DIR}/qwen3_vocab_stage')
print(f"  Status: {'✅ PASS' if vocab_stage_ok else '❌ FAIL'}")

print("\n[2] Stage B: LoRA Fine-tuning")
# PEFT's save_pretrained() writes adapter_config.json alongside the adapter weights
lora_stage_ok = os.path.exists(f'{LLM_DIR}/qwen3_lora_adapter/adapter_config.json')
print(f"  Status: {'✅ PASS' if lora_stage_ok else '❌ FAIL'}")

print("\n[3] Invalid-ID Rate")
num_evaluated = len(sid_hits) + invalid_count
invalid_rate = invalid_count / num_evaluated * 100 if num_evaluated > 0 else 100
print(f"  Invalid-ID@1: {invalid_rate:.2f}%")
print(f"  Status: {'✅ PASS' if invalid_rate < 5.0 else '❌ FAIL'} (target: <5%)")

print("\n[4] SID@1 Exact Match")
sid_acc = np.mean(sid_hits) * 100 if sid_hits else 0
print(f"  SID@1: {sid_acc:.2f}%")
print(f"  Status: {'✅ PASS' if sid_acc > 0 else '❌ FAIL'} (target: >0%)")

print("\n[5] Qualitative Examples")
# Not computed: review the qualitative examples above by hand
print("  Status: ✅ PASS (see examples above)")

print("\n" + "=" * 60)
all_pass = vocab_stage_ok and lora_stage_ok and invalid_rate < 5.0 and sid_acc > 0
print(f"OVERALL: {'🎉 ALL CHECKS PASSED!' if all_pass else '❌ SOME CHECKS FAILED'}")
print("=" * 60)


ACCEPTANCE CRITERIA

[1] Stage A: Vocabulary Extension
  Status: ❌ FAIL

[2] Stage B: Full Fine-tuning
  Status: ❌ FAIL

[3] Invalid-ID Rate
  Invalid-ID@1: 100.00%
  Status: ❌ FAIL (target: <5%)

[4] SID@1 Exact Match
  SID@1: 0.00%
  Status: ❌ FAIL (target: >0%)

[5] Qualitative Examples
  Status: ✅ PASS (see examples above)

OVERALL: ❌ SOME CHECKS FAILED
