# 🧶 AlphaKnit — Train KnittingTransformer on Colab

**AlphaKnit v6.6-F: Scientific Falsification & Discovery**

Notebook này giúp bạn train mô hình `KnittingTransformer` trên Google Colab với GPU T4/A100.

**Pipeline:** Point Cloud (`.npy`) → Encoder (PointNet) → Transformer Decoder → Edge-Action Sequence (stitch tokens)

---

## 1. 🔧 Setup & Install Dependencies

In [None]:
# Ki·ªÉm tra GPU
!nvidia-smi

import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")

In [None]:
# Mount Google Drive (used to store checkpoints & the dataset)
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone repo t·ª´ GitHub (thay URL n·∫øu c·∫ßn)
# N·∫øu repo private, d√πng: git clone https://<TOKEN>@github.com/user/AlphaKnit.git

import os

REPO_DIR = "/content/AlphaKnit"

if not os.path.exists(REPO_DIR):
    # === OPTION A: Clone t·ª´ GitHub ===
    # !git clone https://github.com/<your-username>/AlphaKnit.git {REPO_DIR}

    # === OPTION B: Copy t·ª´ Google Drive ===
    !cp -r "/content/drive/MyDrive/AlphaKnit" {REPO_DIR}

    print(f"‚úÖ Project ready at {REPO_DIR}")
else:
    print(f"‚úÖ Project already exists at {REPO_DIR}")

In [None]:
# Install dependencies
!pip install -q scipy trimesh networkx webdataset tqdm matplotlib

In [None]:
# Make the project's packages and helper scripts importable
import sys

# Prepend both folders in one step. "scripts" ends up ahead of "src", which is
# the order two successive insert(0, ...) calls would produce.
sys.path[:0] = [
    os.path.join(REPO_DIR, "scripts"),
    os.path.join(REPO_DIR, "src"),
]

# Importing the core modules doubles as an installation check
from alphaknit import config
from alphaknit.model import KnittingTransformer
from alphaknit.knitting_dataset import KnittingDataset, make_dataloaders

# Echo the key model/data dimensions read from the project config
print(f"‚úÖ AlphaKnit imported successfully")
print(f"   Vocab size: {config.VOCAB_SIZE}")
print(f"   D_MODEL: {config.D_MODEL}, N_HEADS: {config.N_HEADS}, N_LAYERS: {config.N_LAYERS}")
print(f"   MAX_SEQ_LEN: {config.MAX_SEQ_LEN}, N_POINTS: {config.N_POINTS}")

## 2. 🏗️ Generate Dataset (Optional)

Nếu bạn chưa có dataset, chạy cell dưới đây để tạo dataset trực tiếp thành **WebDataset shards** trên Google Drive.

Pipeline: `SpatialGeneratorV2` → `GraphValidator` → `ForwardSimulator` → PCA Align → Tensorize → Pack `.tar` shards

**Ưu điểm WebDataset shards:**
- Tránh giới hạn 100k files của Google Drive FUSE
- I/O streaming nhanh hơn nhiều so với đọc từng file
- Shuffle entropy tốt hơn

> ⚠️ **Thời gian ước tính:** ~5k samples/10 phút trên Colab CPU. 50k samples ≈ 1.5-2 giờ.

In [None]:
# ============================================================
# DATASET GENERATION SETTINGS
# ============================================================
N_SAMPLES  = 50000   # Total number of samples to create
SHARD_SIZE = 1000    # Number of samples per shard (.tar file)

# Round up: a trailing partial shard is still a shard, so floor division
# under-reports whenever N_SAMPLES is not a multiple of SHARD_SIZE.
N_SHARDS = (N_SAMPLES + SHARD_SIZE - 1) // SHARD_SIZE

# Output on Google Drive (persistent; not lost when the session times out)
SHARDS_OUTPUT_DIR = "/content/drive/MyDrive/AlphaKnit/data/processed/shards_phase9b_full"

print(f"Will generate {N_SAMPLES} samples into {N_SHARDS} shards")
print(f"Output: {SHARDS_OUTPUT_DIR}")

In [None]:
import json
import shutil
import tarfile
import tempfile
import numpy as np
import torch
from tqdm.notebook import tqdm

from alphaknit.dataset_builder import DatasetBuilder
from alphaknit import config
from pack_tensor_dataset import build_tensor_sample

os.makedirs(SHARDS_OUTPUT_DIR, exist_ok=True)

# Shards left by an earlier run let generation resume instead of starting over.
# Shards written by this cell only receive the ".tar" suffix once complete
# (see the ".partial" rename below), so an interrupted shard is not counted.
existing_shards = sorted(f for f in os.listdir(SHARDS_OUTPUT_DIR) if f.endswith('.tar'))
if existing_shards:
    print(f"‚ö†Ô∏è Found {len(existing_shards)} existing shards in output dir.")
    print(f"   Last shard: {existing_shards[-1]}")
    resume_from = len(existing_shards)
    # Approximation: every existing shard is assumed to hold SHARD_SIZE samples.
    samples_already_done = resume_from * SHARD_SIZE
    print(f"   Resuming from shard {resume_from} (‚âà{samples_already_done} samples done)")
else:
    resume_from = 0
    samples_already_done = 0

remaining = N_SAMPLES - samples_already_done
if remaining <= 0:
    print(f"‚úÖ Already have enough shards ({samples_already_done} ‚â• {N_SAMPLES})")
else:
    print(f"\nüöÄ Generating {remaining} samples...")
    builder = DatasetBuilder(output_dir="/tmp/_alphaknit_gen_dummy")
    temp_dir = tempfile.mkdtemp()

    skipped = 0
    samples_generated = 0
    shard_id = resume_from
    aborted = False  # set when the skip limit is hit; stops BOTH loops

    pbar = tqdm(total=remaining, desc="Generating samples")

    try:
        while samples_generated < remaining and not aborted:
            shard_path = os.path.join(SHARDS_OUTPUT_DIR, f"shard-{shard_id:04d}.tar")
            # Write under a temporary name so a session timeout mid-shard never
            # leaves a truncated file that looks like a finished ".tar" shard.
            partial_path = shard_path + ".partial"
            samples_in_this_shard = min(SHARD_SIZE, remaining - samples_generated)

            count_in_shard = 0
            with tarfile.open(partial_path, "w") as tar:
                while count_in_shard < samples_in_this_shard:
                    # `samples_generated` already includes the samples of the
                    # current shard; adding `count_in_shard` on top made the
                    # index advance by 2 and overlap between shards.
                    # NOTE(review): a rejected index is retried unchanged, as
                    # before - confirm `_generate_one` can succeed on a retry.
                    global_idx = samples_already_done + samples_generated
                    raw_sample = builder._generate_one(global_idx)

                    if raw_sample is None:
                        skipped += 1
                        if skipped > N_SAMPLES * 5:
                            print(f"‚ö†Ô∏è Too many skipped samples ({skipped}). Stopping.")
                            aborted = True
                            break
                        continue

                    name = raw_sample['id']
                    pc = raw_sample.pop("point_cloud")

                    # Stage JSON + NPY on disk because the tensorizer takes file paths
                    tmp_json = os.path.join(temp_dir, f"{name}.json")
                    tmp_npy  = os.path.join(temp_dir, f"{name}.npy")

                    with open(tmp_json, "w") as f:
                        json.dump(raw_sample, f)
                    np.save(tmp_npy, pc)

                    # Tensorize: pad point cloud + build src/tgt token sequences
                    tensor_sample = build_tensor_sample(
                        tmp_json, tmp_npy,
                        config.MAX_SEQ_LEN, config.N_POINTS
                    )

                    # Pack into the tar shard
                    tmp_pt = os.path.join(temp_dir, f"{name}.pt")
                    torch.save(tensor_sample, tmp_pt)
                    tar.add(tmp_pt, arcname=f"{name}.pt")

                    # Remove the staging files for this sample
                    os.remove(tmp_json)
                    os.remove(tmp_npy)
                    os.remove(tmp_pt)

                    count_in_shard += 1
                    samples_generated += 1
                    pbar.update(1)

            if count_in_shard:
                # Publish the finished shard under its final name in one step
                os.replace(partial_path, shard_path)
                shard_id += 1
            else:
                # Aborted before anything was written: drop the empty archive
                os.remove(partial_path)
    finally:
        # Runs on success, abort and error alike, so no staging files are left
        pbar.close()
        shutil.rmtree(temp_dir, ignore_errors=True)

    total_shards = len([f for f in os.listdir(SHARDS_OUTPUT_DIR) if f.endswith('.tar')])
    print(f"\nüéâ Done! Total: {total_shards} shards, {samples_already_done + samples_generated} samples")
    print(f"   Skipped (invalid): {skipped}")
    print(f"   Saved to: {SHARDS_OUTPUT_DIR}")

In [None]:
# Report how many shards exist in the output folder and how large they are
shard_files = sorted(entry for entry in os.listdir(SHARDS_OUTPUT_DIR) if entry.endswith('.tar'))

# Size of every shard in bytes, measured once and reused below
shard_bytes = [os.path.getsize(os.path.join(SHARDS_OUTPUT_DIR, entry)) for entry in shard_files]
total_size_mb = sum(shard_bytes) / 1e6

print(f"üì¶ {len(shard_files)} shards | {total_size_mb:.0f} MB total")

# List only the first five shards; the rest are summarised in a single line
for f, n_bytes in zip(shard_files[:5], shard_bytes):
    size = n_bytes / 1e6
    print(f"   {f} ({size:.1f} MB)")
if len(shard_files) > 5:
    print(f"   ... ({len(shard_files) - 5} more shards)")

## 3. 📦 Dataset Configuration

Dataset gồm WebDataset `.tar` shards, mỗi shard chứa `sample_XXXXX.pt` đã tensorize sẵn:
- `pc`: Point cloud `(N_POINTS, 3)` float32
- `src`: Teacher forcing input `(MAX_SEQ_LEN, 3)` long — `<SOS> + edge_tuples`
- `tgt`: Prediction target `(MAX_SEQ_LEN, 3)` long — `edge_tuples + <EOS>`

In [None]:
# ============================================================
# DATASET PATH SETTINGS
# ============================================================

# WebDataset shards on Google Drive.
# The brace range {0000..0049} names shard-0000.tar through shard-0049.tar;
# adjust it if a different number of shards was generated.
DATASET_DIR = "/content/drive/MyDrive/AlphaKnit/data/processed/shards_phase9b_full/shard-{0000..0049}.tar"

# To use a map-style dataset instead (a folder containing .json + .npy files):
# DATASET_DIR = "/content/drive/MyDrive/AlphaKnit/data/processed/dataset"

print(f"Dataset: {DATASET_DIR}")

In [None]:
# Preview one sample from the dataset
import tarfile
import io

if '.tar' in DATASET_DIR:
    # WebDataset: read straight from the first tar shard. A brace pattern such
    # as "shard-{0000..0049}.tar" is cut at "{" and completed with "0000.tar";
    # a plain path to a single .tar file is used as-is.
    if '{' in DATASET_DIR:
        first_shard = DATASET_DIR.split('{')[0] + "0000.tar"
    else:
        first_shard = DATASET_DIR

    if os.path.exists(first_shard):
        sample = None
        with tarfile.open(first_shard, 'r') as tar:
            # Take the first regular file: entry 0 may be a directory, and an
            # empty archive has no entries at all.
            member = next((m for m in tar.getmembers() if m.isfile()), None)
            if member is not None:
                f = tar.extractfile(member)
                # NOTE(review): weights_only=False unpickles arbitrary objects;
                # only load shards that you generated yourself.
                sample = torch.load(io.BytesIO(f.read()), map_location='cpu', weights_only=False)

        if sample is None:
            print(f"Warning: no sample files inside {first_shard}")
        else:
            print(f"Sample from {os.path.basename(first_shard)} ‚Üí {member.name}")
            print(f"  pc shape:  {sample['pc'].shape}  dtype: {sample['pc'].dtype}")
            print(f"  src shape: {sample['src'].shape}  dtype: {sample['src'].dtype}")
            print(f"  tgt shape: {sample['tgt'].shape}  dtype: {sample['tgt'].dtype}")

            # Decode the first few (type, p1, p2) tuples into token names
            print(f"\n  First 5 src tuples (type, p1, p2):")
            for i in range(min(5, sample['src'].shape[0])):
                t, p1, p2 = sample['src'][i].tolist()
                name = config.ID_TO_TOKEN.get(t, f'<ID:{t}>')
                print(f"    [{i}] {name:8s} (p1={p1}, p2={p2})")
    else:
        print(f"‚ö†Ô∏è First shard not found: {first_shard}")
        # The folder itself may be missing (e.g. Drive not mounted); listing it
        # unguarded would replace this warning with a FileNotFoundError.
        shard_dir = os.path.dirname(first_shard)
        available = os.listdir(shard_dir)[:5] if os.path.isdir(shard_dir) else []
        print(f"   Available files: {available}")
else:
    # Map-style dataset: read the .json / .npy pair directly
    import json
    import numpy as np
    if os.path.isdir(DATASET_DIR):
        sample_files = sorted(f for f in os.listdir(DATASET_DIR) if f.endswith('.json'))
        if sample_files:
            # Strip only the trailing extension to get the sample id
            sid = os.path.splitext(sample_files[0])[0]
            with open(os.path.join(DATASET_DIR, f"{sid}.json")) as f:
                meta = json.load(f)
            pc = np.load(os.path.join(DATASET_DIR, f"{sid}.npy"))
            print(f"Sample: {sid}")
            print(f"  Point cloud: {pc.shape}, range [{pc.min():.3f}, {pc.max():.3f}]")
            edge_seq = meta.get('edge_sequence', [])
            print(f"  Edge sequence: {len(edge_seq)} tuples")
            for t, p1, p2 in edge_seq[:5]:
                print(f"    {config.ID_TO_TOKEN.get(t, f'<ID:{t}>')}(p1={p1}, p2={p2})")
        else:
            # Previously this case printed nothing at all
            print(f"Warning: no .json samples found in {DATASET_DIR}")
    else:
        print(f"Warning: dataset folder not found: {DATASET_DIR}")

## 4. 🚀 Training Configuration

In [None]:
# ============================================================
# HYPERPARAMETERS - tuned for a Colab T4 (16GB VRAM)
# ============================================================

# Every entry is forwarded to the training entry point as a keyword argument.
TRAIN_CONFIG = dict(
    # Dataset
    dataset_dir=DATASET_DIR,

    # Checkpoints go to Google Drive so they are not lost on a session timeout
    checkpoint_dir="/content/drive/MyDrive/AlphaKnit/checkpoints",
    run_name="colab_v6.6F",

    # Model architecture
    d_model=128,
    n_heads=4,
    n_layers=3,
    ffn_dim=256,

    # Training
    epochs=50,
    batch_size=64,          # T4: 64 is fine; A100: can be raised to 128-256
    lr=1e-3,
    grad_accum_steps=2,     # Effective batch = 64 * 2 = 128
    label_smoothing=0.1,
    scheduler_type="cosine",

    # Data loading
    num_workers=2,
    val_split=0.1,

    # Phase transition
    early_stop_patience=10,
    log_compile_every=5,

    # Device
    device_str="auto",
)

# Create the checkpoint folder up front
os.makedirs(TRAIN_CONFIG["checkpoint_dir"], exist_ok=True)
print("‚úÖ Training config ready")

# Echo every setting, one per line
print("\n".join(f"   {k}: {v}" for k, v in TRAIN_CONFIG.items()))

## 5. 🏋️ Train Model

In [None]:
from alphaknit.train import train

# Start training with the settings collected in TRAIN_CONFIG
history = train(**TRAIN_CONFIG)

# Two-line summary: epochs recorded and where the checkpoints went
print(
    f"\nüéâ Training complete! {len(history)} epochs recorded.",
    f"üìÅ Checkpoints saved to: {TRAIN_CONFIG['checkpoint_dir']}",
    sep="\n",
)

## 6. 🔄 Resume Training (nếu bị ngắt giữa chừng)

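Nếu Colab session bị timeout, chạy lại **mục 1 (Setup)**, rồi chạy cell cấu hình ở **mục 3 (Dataset Configuration)** và **mục 4 (Training Configuration)** — cell bên dưới cần `TRAIN_CONFIG` (và `DATASET_DIR`) — sau đó nhảy thẳng xuống đây.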

In [None]:
from alphaknit.train import train

# Same settings as the first run, with the resume options layered on top
RESUME_CONFIG = dict(
    TRAIN_CONFIG,
    resume_auto=True,         # Find the most recent checkpoint automatically
    epochs=100,               # Total number of epochs to reach
    # force_phase2=True,      # Uncomment to force the Physics phase
    # reset_optimizer=True,   # Uncomment to reset the optimizer for the phase transition
)

history = train(**RESUME_CONFIG)

print(f"\nüéâ Resumed training complete! {len(history)} epochs recorded.")

## 7. 📊 Training Visualization

In [None]:
import json
import matplotlib.pyplot as plt

# Locate the JSON history file for this run inside the checkpoint folder
history_name = f"training_history_{TRAIN_CONFIG['run_name']}.json"
history_path = os.path.join(TRAIN_CONFIG["checkpoint_dir"], history_name)

if not os.path.exists(history_path):
    # Nothing to plot yet: fall back to an empty history
    print("‚ö†Ô∏è No history file found. Run training first.")
    history = []
else:
    with open(history_path) as history_file:
        history = json.load(history_file)
    print(f"‚úÖ Loaded {len(history)} epochs from history")

In [None]:
if history:
    epochs_list = [h["epoch"] for h in history]

    def _logged(key):
        """Return the per-epoch values of `key`, using 0 for epochs that lack it."""
        return [h.get(key, 0) for h in history]

    def _decorate(panel, title, unit_range=False):
        """Apply the title, x-label, optional 0..1 y-range and grid shared by all panels."""
        panel.set_title(title)
        panel.set_xlabel("Epoch")
        if unit_range:
            panel.set_ylim(0, 1)
        panel.grid(True, alpha=0.3)

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle("AlphaKnit Training Dashboard", fontsize=16, fontweight="bold")
    (loss_ax, entropy_ax, acc_ax), (compile_ax, lag_ax, pdi_ax) = axes

    # 1. Loss curves (both keys are required in every history entry)
    loss_ax.plot(epochs_list, [h["train_loss"] for h in history], label="Train", color="#2196F3")
    loss_ax.plot(epochs_list, [h["val_loss"] for h in history], label="Val", color="#F44336", ls="--")
    loss_ax.legend()
    _decorate(loss_ax, "Loss")

    # 2. Entropy
    entropy_ax.plot(epochs_list, _logged("train_entropy"), color="#9C27B0")
    _decorate(entropy_ax, "Token Entropy")

    # 3. Structural accuracy
    acc_ax.plot(epochs_list, _logged("struct_acc"), color="#4CAF50")
    _decorate(acc_ax, "Structural Top-1 Acc", unit_range=True)

    # 4. Compile success rate - plotted only for the epochs that recorded it
    compile_points = [
        (h["epoch"], h["compile_success_rate"])
        for h in history
        if "compile_success_rate" in h
    ]
    if compile_points:
        ce, cr = zip(*compile_points)
        compile_ax.plot(ce, cr, 'o-', color="#FF9800", ms=6)
    _decorate(compile_ax, "Compile Success Rate", unit_range=True)

    # 5. Phase lag
    lag_ax.plot(epochs_list, _logged("phase_lag"), color="#00BCD4")
    _decorate(lag_ax, "Phase Lag")

    # 6. PDI & tension
    pdi_ax.plot(epochs_list, _logged("train_pdi"), label="PDI", color="#E91E63")
    pdi_ax.plot(epochs_list, _logged("train_tension"), label="Tension", color="#795548")
    pdi_ax.legend()
    _decorate(pdi_ax, "PDI & Tension")

    plt.tight_layout()
    dashboard_path = os.path.join(TRAIN_CONFIG["checkpoint_dir"], "training_dashboard.png")
    plt.savefig(dashboard_path, dpi=150)
    plt.show()
    print("üìä Dashboard saved!")

## 8. 💾 Export Best Model

In [None]:
# List the checkpoint files saved so far, with their sizes
ckpt_dir = TRAIN_CONFIG["checkpoint_dir"]
ckpt_files = sorted(filter(lambda entry: entry.endswith('.pt'), os.listdir(ckpt_dir)))

print(f"üìÅ Checkpoints in {ckpt_dir}:")
for f in ckpt_files:
    ckpt_path = os.path.join(ckpt_dir, f)
    size_mb = os.path.getsize(ckpt_path) / 1e6
    print(f"   {f} ({size_mb:.1f} MB)")

In [None]:
# Download the best model to the local machine
from google.colab import files

best_model_path = os.path.join(ckpt_dir, f"best_model_{TRAIN_CONFIG['run_name']}.pt")
if not os.path.exists(best_model_path):
    print("‚ö†Ô∏è Best model not found. Train first!")
else:
    # Read the checkpoint on CPU just to report its epoch and validation loss
    ckpt = torch.load(best_model_path, map_location="cpu", weights_only=True)
    print(f"Best model ‚Äî Epoch: {ckpt['epoch']}, Val Loss: {ckpt['val_loss']:.4f}")
    files.download(best_model_path)

## 9. 🔬 Quick Inference Test

In [None]:
# Load the best checkpoint and run one greedy decode as a smoke test
import io
import tarfile

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = KnittingTransformer(
    d_model=TRAIN_CONFIG["d_model"],
    n_heads=TRAIN_CONFIG["n_heads"],
    n_layers=TRAIN_CONFIG["n_layers"],
    ffn_dim=TRAIN_CONFIG["ffn_dim"],
).to(device)

# Resolve the checkpoint folder here so this cell also works when the
# export cells above were skipped.
ckpt_dir = TRAIN_CONFIG["checkpoint_dir"]
best_model_path = os.path.join(ckpt_dir, f"best_model_{TRAIN_CONFIG['run_name']}.pt")
if os.path.exists(best_model_path):
    ckpt = torch.load(best_model_path, map_location=device, weights_only=True)
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    print(f"‚úÖ Model loaded (Epoch {ckpt['epoch']}, Val Loss {ckpt['val_loss']:.4f})")

    # Pick the shard to take a test sample from. A brace pattern is cut at "{"
    # and completed with "0000.tar"; a plain path to one .tar file is used
    # as-is (it previously became "<path>/0000.tar", which never exists).
    if '.tar' not in DATASET_DIR:
        first_shard = None
    elif '{' in DATASET_DIR:
        first_shard = DATASET_DIR.split('{')[0] + "0000.tar"
    else:
        first_shard = DATASET_DIR

    pc = None
    if first_shard and os.path.exists(first_shard):
        with tarfile.open(first_shard, 'r') as tar:
            # Take the first regular file: entry 0 may be a directory, and an
            # empty archive has no entries at all.
            member = next((m for m in tar.getmembers() if m.isfile()), None)
            if member is not None:
                f = tar.extractfile(member)
                # NOTE(review): weights_only=False unpickles arbitrary objects;
                # only load shards that you generated yourself.
                sample = torch.load(io.BytesIO(f.read()), map_location='cpu', weights_only=False)
                pc = sample['pc'].unsqueeze(0).to(device)
        if pc is None:
            print(f"Warning: no sample files inside {first_shard}")
    elif os.path.isdir(DATASET_DIR):
        dataset = KnittingDataset(DATASET_DIR)
        sample = dataset[0]
        pc = sample['point_cloud'].unsqueeze(0).to(device)
    else:
        print("‚ö†Ô∏è No dataset found for inference test.")

    if pc is not None:
        with torch.no_grad():
            pred_tuples = model.greedy_decode(pc, max_len=config.MAX_SEQ_LEN)

        # Show at most the first 20 predicted (type, p1, p2) tuples
        pred = pred_tuples[0]
        print(f"\nüìã Generated sequence ({len(pred)} tuples):")
        for i, (t, p1, p2) in enumerate(pred[:20]):
            token_name = config.ID_TO_TOKEN.get(t, f"<ID:{t}>")
            print(f"   [{i:3d}] {token_name:8s} (p1={p1}, p2={p2})")
        if len(pred) > 20:
            print(f"   ... ({len(pred) - 20} more tuples)")

        # Compile test: any failure is reported rather than raised, because
        # this is only a smoke test of the decoded sequence.
        try:
            from alphaknit.compiler import KnittingCompiler
            compiler = KnittingCompiler()
            tokens = [f"{config.ID_TO_TOKEN.get(t, '<UNK>')}({p1},{p2})" for t, p1, p2 in pred]
            graph = compiler.compile(tokens)
            print(f"\n‚úÖ Compile SUCCESS! Graph: {len(graph.nodes)} nodes.")
        except Exception as e:
            print(f"\n‚ùå Compile failed: {e}")
else:
    print("‚ö†Ô∏è No trained model found. Run training first!")