<a href="https://colab.research.google.com/github/AlexRaudvee/MultiArchPDD-CV/blob/main/main_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Environment Setup

In [2]:
import os
import shutil
import zipfile
from pathlib import Path
from google.colab import drive

def mount_google_drive(mount_point: Path = Path('/content/drive')) -> Path:
    """
    Mount Google Drive through `google.colab.drive`.

    Args:
        mount_point: Directory handed to `drive.mount` (converted to `str`).

    Returns:
        The `mount_point` argument, unchanged, so callers can build paths from it.
    """
    mount_target = str(mount_point)
    drive.mount(mount_target)
    return mount_point

def extract_zip(zip_path: Path, extract_to: Path) -> None:
    """
    Unpack the archive at `zip_path` into the directory `extract_to`.

    Args:
        zip_path: Location of the .zip archive; must be an existing regular file.
        extract_to: Directory that receives the archive members.

    Raises:
        FileNotFoundError: If `zip_path` is not an existing file.
    """
    archive_present = zip_path.is_file()
    if archive_present:
        with zipfile.ZipFile(zip_path, 'r') as archive:
            archive.extractall(str(extract_to))
        return
    raise FileNotFoundError(f"Could not find zip file at {zip_path}")

def move_contents(src_dir: Path, dst_dir: Path) -> None:
    """
    Move every entry of src_dir into dst_dir, then remove the emptied src_dir.

    An entry in dst_dir with the same name is replaced: real directories are
    removed recursively, while files and symlinks (including dangling ones and
    links that point at directories) are unlinked without touching the link
    target. dst_dir is created if it does not exist yet.

    Args:
        src_dir: Directory whose direct children are moved.
        dst_dir: Directory that receives the children.

    Raises:
        FileNotFoundError: If src_dir is not an existing directory.
        ValueError: If src_dir and dst_dir resolve to the same directory.
    """
    if not src_dir.is_dir():
        raise FileNotFoundError(f"{src_dir} does not exist")
    # With identical source and destination, each entry would be deleted as a
    # "conflict" before the move, losing the data; refuse up front instead.
    if src_dir.resolve() == dst_dir.resolve():
        raise ValueError(f"{src_dir} and {dst_dir} are the same directory")
    dst_dir.mkdir(parents=True, exist_ok=True)
    for item in src_dir.iterdir():
        target = dst_dir / item.name
        # exists() follows symlinks and reports False for a dangling link,
        # so is_symlink() is needed to catch that kind of name clash.
        if target.exists() or target.is_symlink():
            print(f"Warning: {target} already exists, overwriting")
            # shutil.rmtree refuses to operate on a symlink; only real
            # directories are removed recursively, links are just unlinked.
            if target.is_dir() and not target.is_symlink():
                shutil.rmtree(target)
            else:
                target.unlink()
        shutil.move(str(item), str(target))
    src_dir.rmdir()

def setup_directories(*dirs: Path) -> None:
    """
    Create each path in `dirs` as a directory.

    Missing parent directories are created as well, and a directory that
    already exists is left as it is.
    """
    for directory in dirs:
        directory.mkdir(exist_ok=True, parents=True)

def zip_folder(folder_path: Path, output_path: Path) -> None:
    """
    Recursively zip the files under folder_path into a .zip file at output_path.

    Archive member names are relative to folder_path. Only files are written;
    directories that contain no files produce no entry. If output_path lies
    inside folder_path, the archive being written is skipped so it is never
    added to itself.

    Args:
        folder_path: Directory whose files are archived.
        output_path: Destination of the archive; an existing file is overwritten.

    Raises:
        FileNotFoundError: If folder_path is not an existing directory.
    """
    folder_path = Path(folder_path)
    # Validate before opening the archive so a bad source path does not leave
    # an empty .zip behind (os.walk silently yields nothing for a missing dir).
    if not folder_path.is_dir():
        raise FileNotFoundError(f"{folder_path} does not exist")
    archive_path = Path(output_path).resolve()
    with zipfile.ZipFile(output_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
        for root, _, files in os.walk(folder_path):
            for fname in files:
                fpath = Path(root) / fname
                # The archive already exists on disk at this point; reading it
                # while appending to it would store a copy of itself.
                if fpath.resolve() == archive_path:
                    continue
                arcname = fpath.relative_to(folder_path)
                zipf.write(str(fpath), str(arcname))

Mounted at /content/drive


In [None]:
# ——— Constants ———
# Everything is unpacked under EXTRACT_TO; the archive itself lives on the mounted Drive.
EXTRACT_TO = Path('/content')
DRIVE_MOUNT_POINT = Path('/content/drive')
ZIP_PATH = DRIVE_MOUNT_POINT / 'MyDrive/.colab.zip'

# The archive unpacks into a `.colab` folder whose contents are then moved up into EXTRACT_TO.
SRC_DIR = EXTRACT_TO / '.colab'
DST_DIR = EXTRACT_TO

# Output directories created below before any run.
DISTILLED_DIR = EXTRACT_TO / 'data' / 'Distilled'
MODEL_DIR = EXTRACT_TO / 'data' / 'checkpoints'
ASSETS_DIR = EXTRACT_TO / 'assets' / 'viz_synthetic'

# ——— SetUp ———
mount_google_drive(DRIVE_MOUNT_POINT)
extract_zip(ZIP_PATH, EXTRACT_TO)
move_contents(SRC_DIR, DST_DIR)
setup_directories(DISTILLED_DIR, ASSETS_DIR, MODEL_DIR)

In [None]:
# Install matplotlib with pip through a shell escape.
!pip install matplotlib

### Launching Dataset Distillation

In [5]:
# Run the `multi-branch` subcommand of main.py on MNIST with two models (convnet, lenet).
# Results are written to data/Distilled and checkpoints to data/checkpoints.
# NOTE(review): `--debug False` passes the string "False"; if main.py parses this flag
# with argparse `type=bool`, any non-empty string is truthy — confirm against the parser.
!python main.py multi-branch \
    --dataset mnist \
    --model convnet lenet \
    --batch-size 64 \
    --ipc 2 \
    --P 1 \
    --K 1 \
    --T 1 \
    --lr-model 1e-3 \
    --lr-syn-data 1e-2 \
    --syn-optimizer adam \
    --inner-optimizer momentum \
    --debug False \
    --out-dir data/Distilled \
    --ckpt-dir data/checkpoints

[Dataloader]:
     - Loading...
     - Done.
[Distillator]:
     - Saving...                                                                
     - models saved to ['data/checkpoints/mult-branch_mnist_convnet.pth', 'data/checkpoints/mult-branch_mnist_lenet.pth']
     - distilled dataset & history saved to data/Distilled/mult-branch_mnist_convnet_lenet.pt
     - Plotted & saved stage 1 → assets/viz_synthetic/synthetic_stage_01.png
     - Done.


In [2]:
# Run scripts/run_distill.py as a module with the `mm-match` PDD core on CIFAR-10 (convnet).
# NOTE(review): the recorded output of this cell is a traceback raised while constructing
# the Adam optimizer in distillation/PDD.py, and it is cut off before the final exception
# line — rerun to see the actual error before relying on this cell.
!python -m scripts.run_distill \
  --pdd-core mm-match \
  --dataset cifar10 \
  --model convnet \
  --batch-size 64 \
  --synthetic-size 10 \
  --P 1 \
  --K 1 \
  --T 1 \
  --lr-model 1e-3 \
  --lr-syn-data 1e-2 \
  --syn-optimizer adam \
  --inner-optimizer momentum \
  --out-dir data/Distilled \
  --ckpt-dir data/checkpoints

[Dataloader]:
     - Loading...
     - Done.
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/workspaces/MultiArchPDD-CV/scripts/run_distill.py", line 141, in <module>
    main()
  File "/workspaces/MultiArchPDD-CV/scripts/run_distill.py", line 112, in main
    X_syn, Y_syn = pdd.distill()
                   ^^^^^^^^^^^^^
  File "/workspaces/MultiArchPDD-CV/distillation/PDD.py", line 101, in distill
    syn_opt = Adam(opt_params,
              ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/optim/adam.py", line 100, in __init__
    super().__init__(params, defaults)
  File "/usr/local/lib/python3.12/site-packages/torch/optim/optimizer.py", line 369, in __init__
    self.add_param_group(cast(dict, param_group))
  File "/usr/local/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^

### Benchmarking of Distilled Dataset (dev - accuracy performance)

In [1]:
# Run the `benchmark` subcommand of main.py in `real` mode with a convnet.
# Per the recorded output, this trains on a 1000-example subset of the real MNIST train
# split for 5 epochs and reports accuracy on the real MNIST test set.
# NOTE(review): --distilled-path names cmps-loss_mnist_lenet_convnet.pt, while the
# multi-branch distillation cell reports saving mult-branch_mnist_convnet_lenet.pt —
# confirm the referenced file exists in data/Distilled.
!python main.py benchmark \
    --distilled-path data/Distilled/cmps-loss_mnist_lenet_convnet.pt \
    --benchmark-mode real \
    --model convnet \
    --syn-batch-size 64 \
    --test-batch-size 256 \
    --lr 1e-3  \
    --epochs-per-stage 5 \
    --till-stage 1 \
    --real-size 1000 

[Benchmarker]:
     - Using device: cpu
     - Loading distilled data from data/Distilled/cmps-loss_mnist_lenet_convnet.pt
     - Total synthetic examples = 40; real subset size = 1000

     - [Real] Sampling 1000 examples from real mnist train split
     - [Real] Training for 5 total epochs on real data

     - Epoch 1/5 → loss 2.2248
     - Epoch 2/5 → loss 1.3308
     - Epoch 3/5 → loss 0.6539
     - Epoch 4/5 → loss 0.4139
     - Epoch 5/5 → loss 0.2905

     - Evaluating on real mnist test set…
Final test accuracy on real mnist: 88.70%
