In [None]:
# Check GPU status by running the `nvidia-smi` command in a shell.
!nvidia-smi


In [None]:
# Install the Kaggle package; its CLI is used later to download the dataset.
# `%pip` (rather than `!pip`) targets the environment of the running kernel.
%pip install kaggle


In [None]:
# Configure Kaggle API credentials.
#
# Credentials are taken from the KAGGLE_USERNAME / KAGGLE_KEY environment
# variables so that no secret is stored in the notebook, and an existing
# ~/.kaggle/kaggle.json (e.g. one the user uploaded) is never overwritten.
import os
import json

kaggle_dir = os.path.expanduser('~/.kaggle')
kaggle_json = os.path.join(kaggle_dir, 'kaggle.json')

# Kept as a module-level name for backward compatibility with the original cell.
kaggle_config = {
    "username": os.environ.get("KAGGLE_USERNAME", ""),
    "key": os.environ.get("KAGGLE_KEY", ""),
}

if os.path.exists(kaggle_json):
    # The user already installed credentials: keep them, only tighten permissions.
    os.chmod(kaggle_json, 0o600)
    print("✓ Kaggle API配置完成")
elif all(kaggle_config.values()):
    os.makedirs(kaggle_dir, exist_ok=True)
    # Create the file with mode 0o600 from the start so the key is never
    # readable by other users, not even briefly before a later chmod.
    fd = os.open(kaggle_json, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'w') as f:
        json.dump(kaggle_config, f)
    print("✓ Kaggle API配置完成")
else:
    # Writing placeholder credentials would only make the later download fail
    # with a confusing authentication error, so stop here with a clear message.
    print("❌ 未找到Kaggle凭据：请设置环境变量 KAGGLE_USERNAME 和 KAGGLE_KEY，"
          "或将 kaggle.json 放到 ~/.kaggle/ 目录后重新运行")


In [None]:
# 克隆LLaVA仓库
import os

if os.path.exists("LLaVA"):
    print("✓ LLaVA仓库已存在")
else:
    print("正在克隆LLaVA仓库...")
    !git clone https://github.com/haotian-liu/LLaVA.git


In [None]:
# 进入LLaVA目录并安装
%cd LLaVA

try:
    import llava
    print("✓ LLaVA已安装")
except ImportError:
    print("正在安装LLaVA...")
    %pip install -e .


In [None]:
# 下载和准备Flickr30k数据集
import os
import glob

if os.path.exists('./data/flickr30k/Images/flickr30k_images'):
    print("✓ Flickr30k数据集已存在")
    print(f"找到 {len(glob.glob('./data/flickr30k/Images/flickr30k_images/*.jpg'))} 个图片文件")
else:
    print("正在下载Flickr30k数据集...")
    !kaggle datasets download -d hsankesara/flickr-image-dataset
    
    if os.path.exists('flickr-image-dataset.zip'):
        print("正在解压数据集...")
        !unzip -q flickr-image-dataset.zip -d ./data/
        !rm flickr-image-dataset.zip
    else:
        print("警告: 数据集文件未找到")


In [None]:
# Convert the Flickr30k data to LLaVA format (skipped when the output exists).
import os
import sys
import json
import subprocess

converted_file = "./data/flickr30k/train.json"

if os.path.exists(converted_file):
    print("✓ Dataset already converted to LLaVA format.")
    try:
        # JSON is read as UTF-8 explicitly instead of the locale default.
        with open(converted_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"Found {len(data)} data records.")
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Could not read {converted_file}: {e}")
else:
    print("Converting Flickr30k dataset to LLaVA format...")

    # Conversion script expected inside the LLaVA checkout.
    # NOTE(review): confirm this path exists in the cloned repository version.
    convert_script = "llava/data/datasets/convert_flickr30k.py"
    if os.path.exists(convert_script):
        print("Using official LLaVA conversion script...")
        try:
            # sys.executable runs the script with the kernel's own interpreter,
            # so it sees the packages installed by the earlier `%pip` cells.
            subprocess.run(
                [sys.executable, convert_script, "--root", "./data/flickr30k"],
                check=True,
            )
        except (subprocess.CalledProcessError, OSError) as e:
            print(f"Error executing official script: {e}")
        else:
            print("✓ Official script executed successfully.")
            # A zero exit code does not prove the expected file was written.
            if not os.path.exists(converted_file):
                print(f"Warning: {converted_file} was not created by the script.")
    else:
        print("Official conversion script not found. Please check LLaVA installation.")


In [None]:
# LoRA fine-tuning command. This cell only prints the command and a short
# explanation of some of its flags; it does not launch training.
# NOTE(review): the flag names below should be checked against the argument
# parser of llava/train/train_mem.py in the cloned repository — confirm.
training_command = """
python llava/train/train_mem.py \\
    --model-name llava-hf/llava-1.6-vicuna-7b-hf \\
    --data ./data/flickr30k/train.json \\
    --val-data ./data/flickr30k/val.json \\
    --lora-r 16 --lora-alpha 32 --lora-dropout 0.05 \\
    --output-dir checkpoints/lora_flickr_ep1 \\
    --num-epochs 1 --batch-size 4 --lr 1e-4 --fp16
"""

# Every line to display, in order: heading, command, then the flag notes.
display_lines = (
    "LoRA 微调训练命令：",
    training_command,
    "\n参数详解：",
    "--model-name: 使用的基础模型",
    "--data: 训练数据路径",
    "--val-data: 验证数据路径",
    "--lora-r: LoRA rank参数，控制低秩矩阵的维度",
    "--lora-alpha: LoRA alpha参数，缩放因子",
    "--lora-dropout: LoRA dropout率，防止过拟合",
)
for text in display_lines:
    print(text)


In [None]:
# Pre-training environment check: GPU availability, presence of the data
# files, and creation of the checkpoint output directory. The final banner
# reflects whether the GPU and data checks actually passed.
import os
import torch

print("=== 训练环境检查 ===")

# 1. GPU status
print("\n1. GPU状态检查：")
gpu_ok = torch.cuda.is_available()
if gpu_ok:
    print("✅ GPU可用")
    print(f"GPU数量: {torch.cuda.device_count()}")
    print(f"当前GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; convert to GiB for display.
    print(f"显存总量: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
else:
    print("❌ GPU不可用，建议在GPU环境中运行")

# 2. Data files
print("\n2. 数据文件检查：")
data_files = [
    "./data/flickr30k/train.json",
    "./data/flickr30k/val.json"
]

missing_files = []
for file_path in data_files:
    if os.path.exists(file_path):
        print(f"✅ {file_path} 存在")
    else:
        missing_files.append(file_path)
        print(f"❌ {file_path} 不存在，请确保数据已准备")

# 3. Output directory
print("\n3. 创建输出目录：")
output_dir = "checkpoints/lora_flickr_ep1"
os.makedirs(output_dir, exist_ok=True)
print(f"✅ 输出目录已创建: {output_dir}")

# Only announce readiness when every check above succeeded; previously the
# "ready" banner was printed even with no GPU or missing data files.
ready = gpu_ok and not missing_files
if ready:
    print("\n=== 准备就绪，可以开始训练 ===")
else:
    print("\n=== 环境检查未通过，请先解决上述 ❌ 项 ===")


In [None]:
# Evaluation walkthrough: prints the command that generates predictions and
# the command that computes caption metrics. Nothing is executed here.
# NOTE(review): confirm that both referenced scripts exist in the checkout.

# Step 1: generate predictions
eval_command = """
python llava/eval/cli_eval.py \\
    --ckpt checkpoints/lora_flickr_ep1 \\
    --val-data ./data/flickr30k/val.json \\
    --outfile pred.json
"""

# Step 2: compute metrics
metrics_command = """
python scripts/eval_caption.py \\
    --pred-file pred.json \\
    --gt-file ./data/flickr30k/val.json
"""

# Emit everything in display order.
report = (
    "📊 评估步骤：",
    "1. 生成预测:",
    eval_command,
    "\n2. 计算指标:",
    metrics_command,
    "\n📋 记录对比: 原模型 vs LoRA微调后提升值",
)
for entry in report:
    print(entry)


In [None]:
# Generate demo material: pick 5 images and write a markdown template in which
# the base-model and LoRA-model captions can be filled in later.
import os, glob, random

# Create the demo directory and collect candidate images.
os.makedirs("./assets", exist_ok=True)
# glob order depends on the filesystem; sort so the selection is reproducible.
image_files = sorted(glob.glob('./data/flickr30k/Images/flickr30k_images/*.jpg'))

if len(image_files) >= 5:
    # A locally seeded generator gives the same 5 images on every run without
    # touching the global `random` state used elsewhere.
    demo_images = random.Random(42).sample(image_files, 5)

    # Write the template. UTF-8 is set explicitly because the text contains
    # non-ASCII characters and the platform default encoding may not be UTF-8.
    with open("./assets/week1_demo.md", 'w', encoding='utf-8') as f:
        f.write("# LLaVA LoRA 微调演示\n\n")
        for i, img_path in enumerate(demo_images, 1):
            img_name = os.path.basename(img_path)
            f.write(f"### 示例 {i}: {img_name}\n")
            f.write("**原模型**: [待填入]\n")
            f.write("**LoRA微调**: [待填入]\n\n")

    print("✅ 演示文档已创建: ./assets/week1_demo.md")
    print(f"✅ 选中演示图片: {[os.path.basename(p) for p in demo_images]}")
    print("📋 后续: 用训练后模型生成描述，填入对比结果")
else:
    print("❌ 图片不足，请检查数据集")
