# Generative Molecular Captioning with Qwen3-0.6B

Two-stage training pipeline:
1. **Stage 1 (Alignment)**: Train the projector to align GNN molecule embeddings with the LLM's text embedding space
2. **Stage 2 (SFT)**: Fine-tune the projector and LoRA adapters for caption generation

In [None]:
# Cell 1.5: Google Colab Setup
try:
    from google.colab import drive
    import os
    print("Running on Google Colab. Setting up repository...")

    # Clone the repository if not already present
    REPO_DIR = "/content/altegrad_kaggle"
    REPO_URL = "https://github.com/AxENSRennes/altegrad_kaggle.git"
    BRANCH = "axel"

    if not os.path.exists(REPO_DIR):
        print(f"Cloning {REPO_URL} (branch: {BRANCH})...")
        !git lfs install
        !git clone -b {BRANCH} {REPO_URL} {REPO_DIR}
        print("Repository cloned successfully.")
    else:
        print(f"Repository already exists at {REPO_DIR}")
        %cd {REPO_DIR}
        !git pull origin {BRANCH}
        !git lfs pull

    %cd {REPO_DIR}
    print(f"Working directory: {os.getcwd()}")

except ImportError:
    print("Not running on Google Colab (ImportError).")
except Exception as e:
    print(f"Error during Colab setup: {e}")

In [None]:
# Cell 1: Install Dependencies
!pip install -q transformers>=4.36 peft bitsandbytes accelerate wandb rich nltk torch-geometric

In [None]:
# Cell 2: Imports and Path Setup
import sys
import os

# Auto-detect environment and pick the directory holding the project modules.
if os.path.exists("/kaggle/input/mol-caption-code"):
    CODE_DIR = "/kaggle/input/mol-caption-code"
    print("Running on Kaggle")
elif os.path.exists("/content/altegrad_kaggle/mol-caption-code"):
    CODE_DIR = "/content/altegrad_kaggle/mol-caption-code"
    print("Running on Colab")
else:
    # Use an absolute path: a relative sys.path entry is resolved against the
    # *current* working directory on every import, so it breaks if any later
    # code changes directory.
    CODE_DIR = os.path.abspath("./mol-caption-code")
    print("Running locally")
    if not os.path.isdir(CODE_DIR):
        print(f"Warning: {CODE_DIR} not found; the project imports below will fail.")

# Avoid piling up duplicate entries when this cell is re-run.
if CODE_DIR not in sys.path:
    sys.path.insert(0, CODE_DIR)

import torch
from config import get_config
from model_wrapper import create_model
from train_stage1 import train_stage1
from train_stage2 import train_stage2
from inference import run_inference
from metrics import compute_metrics
from report import print_training_report
from utils import set_seed, WandBLogger

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Cell 3: Configure Experiment Mode
# Presets: "quick" (~5 min smoke test), "medium" (~1 h), "full" (~9 h).
EXPERIMENT_MODE = "quick"  # switch to "medium" or "full" for longer training
USE_WANDB = False  # True turns on Weights & Biases logging

config = get_config(mode=EXPERIMENT_MODE, use_wandb=USE_WANDB)

# Echo the settings that matter most for this run.
for label, value in (
    ("Experiment mode", config.experiment_mode),
    ("Stage 1 epochs", config.stage1_epochs),
    ("Stage 2 epochs", config.stage2_epochs),
    ("Train subset", config.train_subset or 'all'),
):
    print(f"{label}: {value}")

In [None]:
# Cell 4: W&B Initialization (Optional)
# Logging is opt-in via config.use_wandb; when disabled, `logger` is None.
if not config.use_wandb:
    logger = None
    print("W&B disabled - set config.use_wandb=True to enable")
else:
    import wandb

    wandb.login()
    # Build the logger, then start a run tagged with the experiment mode.
    logger = WandBLogger(enabled=True)
    logger.init(config.wandb_project, config, tags=[config.experiment_mode])
    print("W&B initialized")

In [None]:
# Cell 5: Set Seed & Create Model
# Seed before building the model (presumably so random initialization is reproducible).
set_seed(config.seed)

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

model = create_model(config, device)
model.print_trainable_parameters()

In [None]:
# Cell 6: Stage 1 Training (Alignment)
print("=" * 50, "STAGE 1: Alignment Training", "=" * 50, sep="\n")

# Alignment stage; training metrics are summarized by the report helper.
stage1_metrics = train_stage1(model, config, logger=logger)
print_training_report("Stage 1", stage1_metrics, config)

In [None]:
# Cell 7: Stage 2 Training (SFT)
print("=" * 50, "STAGE 2: Supervised Fine-Tuning", "=" * 50, sep="\n")

# Supervised fine-tuning, requested to start from the Stage 1 result (load_stage1=True).
stage2_metrics = train_stage2(model, config, load_stage1=True, logger=logger)
print_training_report("Stage 2", stage2_metrics, config)

In [None]:
# Cell 8: Generate Submission (Full Mode)
# Only a "full" run writes a submission file; other modes skip this step.
if config.experiment_mode != "full":
    print("Skipping submission generation (not in full mode)")
    print("Run with mode='full' for full training and submission")
else:
    print("=" * 50, "Generating Submission", "=" * 50, sep="\n")
    # NOTE(review): run_inference receives only `config`, not the in-memory
    # `model` — presumably it reloads trained weights from disk; confirm.
    results = run_inference(config)
    print(f"Submission saved to: {config.submission_path}")

In [None]:
# Cell 9: Cleanup
# `logger` is either None (W&B disabled) or the WandBLogger created in the W&B
# cell; test against None explicitly instead of relying on the object's truthiness.
if logger is not None:
    logger.finish()
    print("W&B run finished")

print("Done!")