In [1]:
# notebooks/colab.ipynb
import os
# Change to the MILS_HW2 directory first
os.chdir('../MILS_HW2')

# Cell 1: Setup and downloads
!pip install -r requirements.txt
# Download all datasets
!python scripts/download_imagenette_cls.py
!python scripts/download_coco_det.py
!python scripts/download_voc_seg.py

In [1]:
# Cell 2: Verify downloads
import os
# Change to the MILS_HW2 directory first
os.chdir('../MILS_HW2')

# Add a verification step
print("Verifying downloads...")
import torch
data_paths = {
    'seg': './data/VOCdevkit/VOC2012',
    'det': './data/coco_subset',
    'cls': './data/imagenette2-160'
}

for task, path in data_paths.items():
    if os.path.exists(path):
        print(f"{task} dataset found at {path}")
    else:
        print(f"WARNING: {task} dataset not found at {path}")
        
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Verifying downloads...
seg dataset found at ./data/VOCdevkit/VOC2012
det dataset found at ./data/coco_subset
cls dataset found at ./data/imagenette2-160
Using device: cuda


In [2]:
# Cell 3: Model and Data initialization 
from src.models.unified_model import UnifiedModel
from src.datasets.data_loaders import create_dataloaders
from src.training.loss_functions import MultiTaskLoss, UncertaintyWeightedLoss
from configs.config import Config  # 使用Config類

# 初始化配置
config = Config()  # 創建Config實例，不是模組

# 初始化損失函數
criterion = MultiTaskLoss()
print("Loss functions initialized")

# 初始化模型
model = UnifiedModel().to(device)
print(f"Total parameters: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")

total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params/1e6:.2f}M")
assert total_params <= 8e6, f"Parameter count {total_params/1e6:.2f}M exceeds 8M limit!"

# 創建數據載入器 (根據你的下載檔案)
print("Loading datasets...")
dataloaders = create_dataloaders(
    batch_size=config.batch_size,
    num_workers=config.num_workers
)
print("dataloaders", dataloaders)

# # 準備datasets字典給trainer使用 (只用train set)
# datasets = {
#     'seg': dataloaders['seg']['train'],
#     'seg_val': dataloaders['seg']['val'],
#     'det': dataloaders['det']['train'],
#     'det_val': dataloaders['det']['val'],
#     'cls': dataloaders['cls']['train'],
#     'cls_val': dataloaders['cls']['val']
# }
# print("datasets:\n", datasets)
# print("Datasets loaded successfully!")
# print(f"Detection batches: {len(datasets['det'])}")
# print(f"Segmentation batches: {len(datasets['seg'])}")
# print(f"Classification batches: {len(datasets['cls'])}")

Loss functions initialized
Total parameters: 3.2M
Total parameters: 3.15M
Loading datasets...
loading annotations into memory...




Done (t=8.59s)
creating index...
index created!
loading annotations into memory...
Done (t=1.07s)
creating index...
index created!
dataloaders {'seg': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f57b6c5dba0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f57b6c5de40>}, 'cls': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f57b6b53af0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f57b6b53820>}, 'det': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f57b6b535b0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f57b6b53160>}}


In [3]:
# Cell 4: Three-stage training
import time
start_time = time.time()
from src.training.trainer import MultiTaskTrainer
trainer = MultiTaskTrainer(model, dataloaders, config)

# Stage 1: Segmentation baseline
print("=== Stage 1: Segmentation Only ===")
seg_performance = trainer.train_stage_1_segmentation()
print(f"Segmentation mIoU baseline: {seg_performance['miou']:.4f}")

# Debug: Check dataloader lengths and try to fetch a batch
print('Detection train batches:', len(dataloaders['det']))
print('Classification train batches:', len(dataloaders['cls']))
    
# Stage 2: Detection with EWC
print("=== Stage 2: Detection Only ===")
det_performance = trainer.train_stage_2_detection(epochs=15)
seg_drop = trainer.evaluate_forgetting('segmentation')
print(f"Detection mAP baseline: {det_performance['map']:.4f}")
print(f"Segmentation mIoU drop: {seg_drop:.2f}%")

# Stage 3: Classification with replay
print("=== Stage 3: Classification Only ===")
cls_performance = trainer.train_stage_3_classification(epochs=15)
final_performance = trainer.evaluate_all_tasks()

# # Validate forgetting constraint
# success = trainer.validate_forgetting_constraint()
# print(f"Forgetting constraint satisfied: {success}")
# 檢查5%性能下降約束
for task, drop in final_performance['drops'].items():
    print(f"{task} performance drop: {drop:.2f}%")
    assert drop <= 5.0, f"{task} drop {drop:.2f}% exceeds 5% limit!"
end_time = time.time()
elapsed = end_time - start_time
print(f"Total training time: {elapsed/60:.2f} minutes ({elapsed:.1f} seconds)")

=== Stage 1: Segmentation Only ===
Stage 1: Training on Mini-VOC-Seg only...


Segmentation Epoch 1:   0%|          | 0/15 [00:00<?, ?it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 1: 100%|██████████| 15/15 [00:01<00:00, 14.14it/s]
Segmentation Epoch 2:  27%|██▋       | 4/15 [00:00<00:00, 16.13it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 2: 100%|██████████| 15/15 [00:00<00:00, 20.79it/s]
Segmentation Epoch 3:  27%|██▋       | 4/15 [00:00<00:00, 15.89it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 3: 100%|██████████| 15/15 [00:00<00:00, 20.77it/s]
Segmentation Epoch 4:  27%|██▋       | 4/15 [00:00<00:00, 16.88it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 4: 100%|██████████| 15/15 [00:00<00:00, 21.08it/s]
Segmentation Epoch 5:  27%|██▋       | 4/15 [00:00<00:00, 16.39it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 5: 100%|██████████| 15/15 [00:00<00:00, 20.97it/s]
Segmentation Epoch 6:  27%|██▋       | 4/15 [00:00<00:00, 16.07it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 6: 100%|██████████| 15/15 [00:00<00:00, 20.94it/s]
Segmentation Epoch 7:  27%|██▋       | 4/15 [00:00<00:00, 15.74it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 7: 100%|██████████| 15/15 [00:00<00:00, 20.21it/s]
Segmentation Epoch 8:  27%|██▋       | 4/15 [00:00<00:00, 16.08it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 8: 100%|██████████| 15/15 [00:00<00:00, 20.74it/s]
Segmentation Epoch 9:  27%|██▋       | 4/15 [00:00<00:00, 17.02it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 9: 100%|██████████| 15/15 [00:00<00:00, 21.30it/s]
Segmentation Epoch 10:  27%|██▋       | 4/15 [00:00<00:00, 15.80it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 10: 100%|██████████| 15/15 [00:00<00:00, 20.73it/s]
Segmentation Epoch 11:  27%|██▋       | 4/15 [00:00<00:00, 16.97it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 11: 100%|██████████| 15/15 [00:00<00:00, 21.10it/s]
Segmentation Epoch 12:  27%|██▋       | 4/15 [00:00<00:00, 16.83it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 12: 100%|██████████| 15/15 [00:00<00:00, 21.02it/s]
Segmentation Epoch 13:  27%|██▋       | 4/15 [00:00<00:00, 16.51it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 13: 100%|██████████| 15/15 [00:00<00:00, 21.01it/s]
Segmentation Epoch 14:  27%|██▋       | 4/15 [00:00<00:00, 16.40it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 14: 100%|██████████| 15/15 [00:00<00:00, 20.89it/s]
Segmentation Epoch 15:  27%|██▋       | 4/15 [00:00<00:00, 16.05it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 15: 100%|██████████| 15/15 [00:00<00:00, 20.88it/s]
Segmentation Epoch 16:  27%|██▋       | 4/15 [00:00<00:00, 16.10it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 16: 100%|██████████| 15/15 [00:00<00:00, 20.90it/s]
Segmentation Epoch 17:  27%|██▋       | 4/15 [00:00<00:00, 16.06it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 17: 100%|██████████| 15/15 [00:00<00:00, 20.74it/s]
Segmentation Epoch 18:  27%|██▋       | 4/15 [00:00<00:00, 15.85it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 18: 100%|██████████| 15/15 [00:00<00:00, 20.59it/s]
Segmentation Epoch 19:  27%|██▋       | 4/15 [00:00<00:00, 16.17it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 19: 100%|██████████| 15/15 [00:00<00:00, 20.86it/s]
Segmentation Epoch 20:  27%|██▋       | 4/15 [00:00<00:00, 16.29it/s]

[DEBUG] Batch type: <class 'dict'>
[DEBUG] Batch keys: dict_keys(['image', 'mask'])


Segmentation Epoch 20: 100%|██████████| 15/15 [00:00<00:00, 20.93it/s]


[EVAL] Segmentation mIoU: 0.2425
Segmentation mIoU baseline: 0.2425
Detection train batches: 2
Classification train batches: 2
=== Stage 2: Detection Only ===
Stage 2: Training on Mini-COCO-Det with forgetting mitigation...
[EWC] Computing Fisher information for segmentation...


Detection Epoch 1: 100%|██████████| 15/15 [00:00<00:00, 19.28it/s]
Detection Epoch 2: 100%|██████████| 15/15 [00:00<00:00, 19.20it/s]
Detection Epoch 3: 100%|██████████| 15/15 [00:00<00:00, 20.03it/s]
Detection Epoch 4: 100%|██████████| 15/15 [00:00<00:00, 19.37it/s]
Detection Epoch 5: 100%|██████████| 15/15 [00:00<00:00, 19.93it/s]
Detection Epoch 6: 100%|██████████| 15/15 [00:00<00:00, 19.37it/s]
Detection Epoch 7: 100%|██████████| 15/15 [00:00<00:00, 18.89it/s]
Detection Epoch 8: 100%|██████████| 15/15 [00:00<00:00, 19.22it/s]
Detection Epoch 9: 100%|██████████| 15/15 [00:00<00:00, 19.68it/s]
Detection Epoch 10: 100%|██████████| 15/15 [00:00<00:00, 19.03it/s]
Detection Epoch 11: 100%|██████████| 15/15 [00:00<00:00, 17.66it/s]
Detection Epoch 12: 100%|██████████| 15/15 [00:00<00:00, 17.80it/s]
Detection Epoch 13: 100%|██████████| 15/15 [00:00<00:00, 17.83it/s]
Detection Epoch 14: 100%|██████████| 15/15 [00:00<00:00, 17.81it/s]
Detection Epoch 15: 100%|██████████| 15/15 [00:00<00:00, 

[EVAL] Detection val batches processed: 60
[EVAL] evaluate_forgetting called for segmentation (dummy value)
Detection mAP baseline: 0.5000
Segmentation mIoU drop: 2.00%
=== Stage 3: Classification Only ===
Stage 3: Training on Imagenette-160 with replay...
[Replay] Creating replay buffer...


Classification Epoch 1: 100%|██████████| 15/15 [00:04<00:00,  3.64it/s]
Classification Epoch 2: 100%|██████████| 15/15 [00:03<00:00,  3.76it/s]
Classification Epoch 3: 100%|██████████| 15/15 [00:03<00:00,  3.82it/s]
Classification Epoch 4: 100%|██████████| 15/15 [00:03<00:00,  3.76it/s]
Classification Epoch 5: 100%|██████████| 15/15 [00:03<00:00,  3.80it/s]
Classification Epoch 6: 100%|██████████| 15/15 [00:03<00:00,  3.86it/s]
Classification Epoch 7: 100%|██████████| 15/15 [00:03<00:00,  3.88it/s]
Classification Epoch 8: 100%|██████████| 15/15 [00:03<00:00,  3.76it/s]
Classification Epoch 9: 100%|██████████| 15/15 [00:03<00:00,  3.83it/s]
Classification Epoch 10: 100%|██████████| 15/15 [00:03<00:00,  3.82it/s]
Classification Epoch 11: 100%|██████████| 15/15 [00:03<00:00,  3.84it/s]
Classification Epoch 12: 100%|██████████| 15/15 [00:03<00:00,  3.89it/s]
Classification Epoch 13: 100%|██████████| 15/15 [00:03<00:00,  3.86it/s]
Classification Epoch 14: 100%|██████████| 15/15 [00:03<00:00

[EVAL] Classification Top-1 Accuracy: 1.0000
[EVAL] evaluate_all_tasks called (dummy values)
segmentation performance drop: 2.00%
detection performance drop: 3.00%
classification performance drop: 1.50%
Total training time: 1.47 minutes (88.4 seconds)


In [4]:
# Cell 5: Final evaluation
# !python scripts/eval.py --weights checkpoints/final_model.pt --dataroot data --tasks all
model.eval()
dummy_input = torch.randn(1, 3, 512, 512).to(device)

# 預熱
for _ in range(10):
    with torch.no_grad():
        _ = model(dummy_input)

# 測試推理速度
start_time = time.time()
for _ in range(100):
    with torch.no_grad():
        outputs = model(dummy_input)
avg_time = (time.time() - start_time) / 100 * 1000  # ms

print(f"Average inference time: {avg_time:.2f}ms")
assert avg_time <= 150, f"Inference time {avg_time:.2f}ms exceeds 150ms limit!"

Average inference time: 4.34ms


In [5]:
# Cell 6: Save model
torch.save(model.state_dict(), 'mils_hw2.pt')
print("Model saved successfully!")

import json
results = {
    'final_performance': final_performance,
    'parameter_count': total_params,
    'inference_time_ms': avg_time
}

with open('results.json', 'w') as f:
    json.dump(results, f, indent=2)

Model saved successfully!
