In [None]:
# notebooks/colab.ipynb
import os
# Change to the MILS_HW2 directory first
os.chdir('../MILS_HW2')

# Cell 1: Setup and downloads
%pip install -r requirements.txt
# Download all datasets
%python scripts/download_imagenette_cls.py
%python scripts/download_coco_det.py
%python scripts/download_voc_seg.py

In [1]:
# Cell 2: Verify downloads
# Add a verification step
print("Verifying downloads...")
import os
data_paths = {
    'seg': './data/VOCdevkit/VOC2012',
    'det': './data/coco_subset',
    'cls': './data/imagenette2-160'
}

for task, path in data_paths.items():
    if os.path.exists(path):
        print(f"{task} dataset found at {path}")
    else:
        print(f"WARNING: {task} dataset not found at {path}")

Verifying downloads...
seg dataset found at ./data/VOCdevkit/VOC2012
det dataset found at ./data/coco_subset
cls dataset found at ./data/imagenette2-160


In [2]:
# Cell 3: Model and Data initialization 
from src.models.unified_model import UnifiedModel
from src.datasets.data_loaders import create_dataloaders
from configs.config import Config  # 使用Config類

# 初始化配置
config = Config()  # 創建Config實例，不是模組

# 初始化模型
model = UnifiedModel()
print(f"Total parameters: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")

# 創建數據載入器 (根據你的下載檔案)
print("Loading datasets...")
dataloaders = create_dataloaders(
    batch_size=config.batch_size,
    num_workers=config.num_workers
)
print("dataloaders", dataloaders)

# 準備datasets字典給trainer使用 (只用train set)
datasets = {
    'seg': dataloaders['seg']['train'],
    'seg_val': dataloaders['seg']['val'],
    'det': dataloaders['det']['train'],
    'det_val': dataloaders['det']['val'],
    'cls': dataloaders['cls']['train'],
    'cls_val': dataloaders['cls']['val']
}
print("datasets:\n", datasets)
print("Datasets loaded successfully!")
print(f"Detection batches: {len(datasets['det'])}")
print(f"Segmentation batches: {len(datasets['seg'])}")
print(f"Classification batches: {len(datasets['cls'])}")



Total parameters: 3.2M
Loading datasets...
loading annotations into memory...
Done (t=10.60s)
creating index...
index created!
loading annotations into memory...
Done (t=1.23s)
creating index...
index created!
dataloaders {'seg': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7fed2da63f70>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7fed2da630a0>}, 'cls': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7fed2da63640>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7fed2da63370>}, 'det': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7fed2da631c0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7fed2da62ef0>}}
datasets:
 {'seg': <torch.utils.data.dataloader.DataLoader object at 0x7fed2da63f70>, 'seg_val': <torch.utils.data.dataloader.DataLoader object at 0x7fed2da630a0>, 'det': <torch.utils.data.dataloader.DataLoader object at 0x7fed2da631c0>, 'det_val': <torch.utils.data.dataloader.DataLoader object at 0

In [12]:
# Cell 4: Three-stage training
from src.training.trainer import MultiTaskTrainer
trainer = MultiTaskTrainer(model, datasets, config)

# Stage 1: Segmentation baseline
# trainer.train_stage_1_segmentation()

# Debug: Check dataloader lengths and try to fetch a batch
print('Detection train batches:', len(datasets['det']))
print('Classification train batches:', len(datasets['cls']))

# Try to fetch a batch from detection
try:
    det_batch = next(iter(datasets['det']))
    print('First detection batch:', det_batch)
except Exception as e:
    print('Could not fetch detection batch:', e)
    det_batch = None

# Try to fetch a batch from classification
try:
    cls_batch = next(iter(datasets['cls']))
    print('First classification batch:', cls_batch)
except Exception as e:
    print('Could not fetch classification batch:', e)
    cls_batch = None
    
print('Type of datasets[\'det\']:', type(datasets['det']))
print('Type of datasets[\'cls\']:', type(datasets['cls']))
    
# Stage 2: Detection with EWC
if det_batch is not None and len(datasets['det']) > 0:
    trainer.train_stage_2_detection()
else:
    print('Skipping Stage 2: Detection dataloader is empty or invalid.')

# Stage 3: Classification with replay
if cls_batch is not None and len(datasets['cls']) > 0:
    trainer.train_stage_3_classification()
else:
    print('Skipping Stage 3: Classification dataloader is empty or invalid.')

# Validate forgetting constraint
success = trainer.validate_forgetting_constraint()
print(f"Forgetting constraint satisfied: {success}")

Detection train batches: 15
Classification train batches: 15
First detection batch: {'image': tensor([[[[ 1.0673,  1.0331,  0.9988,  ...,  0.9988,  0.9646,  0.9646],
          [ 1.0673,  1.0331,  1.0159,  ...,  0.9988,  0.9817,  0.9988],
          [ 1.0673,  1.0331,  1.0331,  ...,  0.9988,  0.9988,  1.0159],
          ...,
          [-0.0801, -0.0116,  0.0227,  ...,  0.1254,  0.1939,  0.0056],
          [-0.0116, -0.0458,  0.0912,  ...,  0.3652,  0.0741, -0.1828],
          [ 0.4851,  0.1939,  0.2111,  ..., -0.0458,  0.0741,  0.0398]],

         [[ 2.0434,  2.0609,  2.0609,  ...,  2.1134,  2.0959,  2.0959],
          [ 2.0609,  2.0434,  2.0609,  ...,  2.1310,  2.1134,  2.1134],
          [ 2.0609,  2.0259,  2.0259,  ...,  2.0959,  2.0784,  2.0959],
          ...,
          [ 0.1527,  0.2752,  0.2752,  ...,  0.3452,  0.3803,  0.2577],
          [ 0.2227,  0.1527,  0.2227,  ...,  0.5378,  0.2402,  0.0301],
          [ 0.6078,  0.3627,  0.4328,  ...,  0.1527,  0.2752,  0.3102]],

        

RuntimeError: size mismatch (got input: [16, 21, 112, 112] , target: [16, 224, 224]

In [None]:
# Cell 5: Final evaluation
!python scripts/eval.py --weights checkpoints/final_model.pt --dataroot data --tasks all