# PartGraph Improved Training

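## Main Improvements
1. **Multi-task learning**: classification loss + reconstruction loss (classification is the primary objective)
2. **Increased capacity**: num_slots=8, slot_dim=256
3. **Diversity loss**: discourages slot collapse
4. **Improved initialization**: each slot is initialized independently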
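
The four items above can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the code in `src/train_improved.py`: the class `IndependentSlotInit`, the functions `diversity_loss` and `total_loss`, the use of cross-entropy and MSE, and the `logits` argument are all hypothetical. Only the default weights (1.0, 0.1, 0.01) and the sizes (8 slots of dimension 256) come from this notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class IndependentSlotInit(nn.Module):
    """One learnable starting vector per slot (an assumed reading of item 4)."""

    def __init__(self, num_slots: int = 8, slot_dim: int = 256):
        super().__init__()
        # Each slot owns its own row, so slots start from different points.
        self.slots = nn.Parameter(torch.randn(num_slots, slot_dim) * slot_dim ** -0.5)

    def forward(self, batch_size: int) -> torch.Tensor:
        # Broadcast the initial slots to every sample: (B, K, D).
        return self.slots.unsqueeze(0).expand(batch_size, -1, -1)


def diversity_loss(slots: torch.Tensor) -> torch.Tensor:
    """Mean squared off-diagonal cosine similarity between slots.

    slots: (B, K, D). The value is 0 when the slots of each sample are
    mutually orthogonal and 1 when they are all identical (collapsed).
    """
    normed = F.normalize(slots, dim=-1)                  # unit-length slots
    sim = normed @ normed.transpose(1, 2)                # (B, K, K)
    k = slots.shape[1]
    off_diag = sim - torch.eye(k, device=slots.device)   # zero the diagonal
    return off_diag.pow(2).sum(dim=(1, 2)).mean() / (k * (k - 1))


def total_loss(logits, labels, recon, images, slots,
               lambda_cls=1.0, lambda_recon=0.1, lambda_div=0.01):
    """Weighted sum of the three terms; classification dominates by weight."""
    cls = F.cross_entropy(logits, labels)   # assumed classification loss
    rec = F.mse_loss(recon, images)         # assumed reconstruction loss
    div = diversity_loss(slots)
    return lambda_cls * cls + lambda_recon * rec + lambda_div * div
```

With these weights the reconstruction and diversity terms act as regularizers: a collapsed set of slots adds at most 0.01 to the total, so the diversity weight may need tuning if collapse persists.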

In [1]:
# 1. Check the GPU
!nvidia-smi

Sat Nov 29 06:45:52 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   50C    P8             10W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
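
The table above is meant for reading by eye. When a later cell needs the same facts programmatically (for example to choose a batch size), `nvidia-smi` can emit CSV instead. The sketch below uses the `--query-gpu` and `--format=csv,noheader,nounits` options; the helper names `parse_gpu_csv` and `query_gpus` are illustrative and not part of the repository.

```python
import shutil
import subprocess


def parse_gpu_csv(text: str) -> list:
    """Parse lines of the form 'name, memory.total, memory.used' (MiB)."""
    gpus = []
    for line in text.strip().splitlines():
        name, total, used = [field.strip() for field in line.split(",")]
        gpus.append({"name": name, "total_mib": int(total), "used_mib": int(used)})
    return gpus


def query_gpus() -> list:
    """Return one dict per visible GPU, or an empty list without nvidia-smi."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi",
         "--query-gpu=name,memory.total,memory.used",
         "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    )
    return parse_gpu_csv(result.stdout)
```

On the runtime shown above, `query_gpus()` would be expected to report a single Tesla T4 with 15360 MiB in total.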
                                                

In [2]:
# 2. Clone or update the code
!git clone https://github.com/alltobebetter/PartGraph-FewShot.git 2>/dev/null || (cd PartGraph-FewShot && git pull)
%cd PartGraph-FewShot

/content/PartGraph-FewShot


In [3]:
# 3. Install dependencies
!pip install -q -r requirements.txt

In [4]:
# 4. Download the CUB-200 dataset
!mkdir -p data
%cd data
!wget -q -N https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz
!test -d CUB_200_2011 && echo 'Already extracted' || tar -xzf CUB_200_2011.tgz
%cd ..

/content/PartGraph-FewShot/data
/content/PartGraph-FewShot
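
Because the download and extraction run quietly, a failed step produces no visible error. A quick sanity check is to count the official split, sketched below. It assumes the archive unpacks to `data/CUB_200_2011/` and relies on the `train_test_split.txt` file of the standard CUB-200-2011 release, where each line is `<image_id> <is_training_image>`. The function name `count_split` is illustrative.

```python
from pathlib import Path


def count_split(dataset_dir) -> tuple:
    """Count (train, test) images listed in train_test_split.txt."""
    split_file = Path(dataset_dir) / "train_test_split.txt"
    train = test = 0
    for line in split_file.read_text().splitlines():
        if not line.strip():
            continue  # tolerate blank lines
        _image_id, is_train = line.split()
        if is_train == "1":
            train += 1
        else:
            test += 1
    return train, test


if __name__ == "__main__":
    root = Path("data/CUB_200_2011")  # assumed extraction path
    if root.is_dir():
        print(count_split(root))
    else:
        print(f"{root} not found; was the archive extracted?")
```

A complete extraction should give 5994 training and 5794 test images, the same counts the training and evaluation scripts print in the cells that follow.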


In [5]:
# 5. Run improved training (batch_size=16 to save GPU memory)
!python src/train_improved.py \
    --data_root ./data \
    --output_dir ./checkpoints_improved \
    --batch_size 16 \
    --epochs 30 \
    --num_slots 8 \
    --slot_dim 256 \
    --lambda_cls 1.0 \
    --lambda_recon 0.1 \
    --lambda_div 0.01

Using device: cuda
Config: num_slots=8, slot_dim=256
Dataset loaded. 5994 training images, 200 classes.
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100% 44.7M/44.7M [00:00<00:00, 180MB/s] 
Epoch 1/30:   0% 0/187 [00:03<?, ?it/s]
Traceback (most recent call last):
  File "/content/PartGraph-FewShot/src/train_improved.py", line 242, in <module>
    train(args)
  File "/content/PartGraph-FewShot/src/train_improved.py", line 152, in train
    recon, slots, attn, alpha = model(images)
                                ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
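
Apart from the truncated traceback, the log above contains a detail worth checking: the progress bar reports 187 batches per epoch, although the command passes `--batch_size 16`. The arithmetic below, with an illustrative helper `batches_per_epoch`, shows which batch sizes are consistent with 5994 training images.

```python
import math


def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of iterations a data loader yields in one pass over the data."""
    if drop_last:
        return num_samples // batch_size          # partial final batch discarded
    return math.ceil(num_samples / batch_size)    # partial final batch kept


for batch_size in (16, 32):
    for drop_last in (False, True):
        print(batch_size, drop_last, batches_per_epoch(5994, batch_size, drop_last))
# 16 False 375
# 16 True 374
# 32 False 188
# 32 True 187
```

Only a batch size of 32 with the last partial batch dropped yields 187. One possible reading, which the log does not confirm, is that the `--batch_size` flag was not applied by the script or that the output was recorded from an earlier run with different settings.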

In [6]:
# 6. Few-shot evaluation
!python src/eval_fewshot.py \
    --data_root ./data \
    --checkpoint ./checkpoints_improved/best_model.pth \
    --num_slots 8 \
    --slot_dim 256 \
    --n_way 5 \
    --k_shot 5 \
    --n_episodes 100

Using device: cuda
Test dataset: 5794 images
No checkpoint loaded, using random/pretrained weights

Evaluating 5-way 5-shot...
Running 100 episodes...

Method: backbone:   0% 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/content/PartGraph-FewShot/src/eval_fewshot.py", line 221, in <module>
    evaluate(args)
  File "/content/PartGraph-FewShot/src/eval_fewshot.py", line 180, in evaluate
    correct, total = run_episode(
                     ^^^^^^^^^^^^
  File "/content/PartGraph-FewShot/src/eval_fewshot.py", line 121, in run_episode
    support_features = extract_features(model, support_images, method)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/PartGraph-FewShot/src/eval_fewshot.py", line 43, in extract_features
    features = model.backbone(images)  # (B, C, H, W)
               ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return
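
For readers unfamiliar with the `--n_way`, `--k_shot`, and `--n_episodes` flags, the sketch below shows how a single N-way K-shot episode is typically drawn from a labeled test set. It is a generic illustration, not the `run_episode` function in `src/eval_fewshot.py`; the name `sample_episode` and the query size `n_query=15` are assumptions.

```python
import random
from collections import defaultdict


def sample_episode(labels, n_way=5, k_shot=5, n_query=15, rng=None):
    """Return (support, query) lists of (sample_index, episode_label) pairs."""
    if rng is None:
        rng = random.Random()
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)
    # Only classes with enough images for both support and query are usable.
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= k_shot + n_query]
    if len(eligible) < n_way:
        raise ValueError("not enough classes with k_shot + n_query samples")
    support, query = [], []
    # Classes are relabeled 0..n_way-1 within the episode.
    for episode_label, cls in enumerate(rng.sample(eligible, n_way)):
        picked = rng.sample(by_class[cls], k_shot + n_query)
        support += [(i, episode_label) for i in picked[:k_shot]]
        query += [(i, episode_label) for i in picked[k_shot:]]
    return support, query


# Toy example: 10 classes with 30 samples each.
toy_labels = [i // 30 for i in range(300)]
support, query = sample_episode(toy_labels, rng=random.Random(0))
print(len(support), len(query))  # 25 75
```

Accuracy is then averaged over all episodes (100 in the command above), with each query sample classified against features computed from the support set.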

In [None]:
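# 7. View the visualization results
from IPython.display import Image, display
import glob
import re

def epoch_number(path):
    # Sort numerically so epoch_10 follows epoch_9 even without zero-padding.
    match = re.search(r'epoch_(\d+)_vis\.png$', path)
    return int(match.group(1)) if match else -1

vis_files = sorted(glob.glob('./checkpoints_improved/epoch_*_vis.png'), key=epoch_number)
for f in vis_files[-5:]:  # show the last 5
    print(f)
    display(Image(filename=f))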