## Multi-Golden ChangeNet Dataset Formatting 

This notebook converts the MVTec Anomaly Detection (MVTec AD) dataset into the format that TAO FTMS expects for training a Multi-Golden ChangeNet classification model.

### Download the Dataset

To get started, go to https://www.mvtec.com/company/research/datasets/mvtec-ad, agree to the license terms and click "Download The Whole Dataset". 

Place the downloaded file ```mvtec_anomaly_detection.tar.xz``` in the same folder as this notebook, then run all cells.

Once complete, upload the output dataset folder  ```mvtec_ad_mgcn``` to your cloud storage. You can then follow the [cradio_mg_changenet.ipynb](https://github.com/NVIDIA/tao_tutorials/blob/main/notebooks/tao_api_starter_kit/api/cradio_mg_changenet.ipynb) notebook to learn how to train a ChangeNet model with this dataset. 

In [None]:
!pip install tqdm 

In [None]:
!test -f mvtec_anomaly_detection.tar.xz && echo "✅ Dataset Found." || (echo "❌ File 'mvtec_anomaly_detection.tar.xz' not found."; exit 1)

### Format the Dataset

In this tutorial, we prepare the dataset for Multi-Golden ChangeNet classification. The dataset is structured with good and defective images plus 4 golden reference images for each class, along with CSV files describing the dataset structure. The expected directory structure is:

```
|-- dataset_root:
    |-- images
        |-- good:
            |-- G1.jpg
            |-- G2.jpg
        |-- defective:
            |-- D1.jpg
            |-- D2.jpg
        |-- golden:
            |-- G1.jpg
            |-- G2.jpg
            ...
    |-- dataset.csv
```

The CSV file includes the following columns:
- ``input_path``: The path to the directory containing the input image to compare.
- ``golden_path``: The path to the directory containing the corresponding golden reference images.
- ``label``: The label for the image pair (use `PASS` for non-defective components, and any other label, such as a specific defect type, for defective components).
- ``object_name``: The name of the component. It does not need to match any filenames in the golden folder.
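
To make the column layout concrete, here is a minimal sketch that parses a small CSV in this format. The two rows are made-up examples, and `summarize_rows` and the `images_root` parameter are hypothetical names used only for illustration. It assumes the paths are relative to the split's images folder, which is how the `generate_csv` function in this notebook writes them.

```python
import csv
import io
import posixpath

# Hypothetical rows in the four-column layout described above.
sample_csv = """input_path,golden_path,label,object_name
bottle/good,bottle/4golden,PASS,bottle_train_good_000
bottle/defective,bottle/4golden,NG,bottle_train_defective_001
"""

def summarize_rows(csv_text, images_root="images"):
    """Parse the CSV text and resolve each row's directories under images_root."""
    rows = []
    for row in csv.DictReader(io.StringIO(csv_text)):
        rows.append({
            "input_dir": posixpath.join(images_root, row["input_path"]),
            "golden_dir": posixpath.join(images_root, row["golden_path"]),
            # Any label other than PASS marks a defective component.
            "is_defective": row["label"] != "PASS",
            "object_name": row["object_name"],
        })
    return rows

for entry in summarize_rows(sample_csv):
    print(entry)
```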


In [None]:
# Skip this cell if the archive has already been extracted
!mkdir -p mvtec_ad
!tar -xf mvtec_anomaly_detection.tar.xz -C mvtec_ad

In [None]:
import os
import shutil
import random
import csv
from pathlib import Path
from collections import Counter
from tqdm import tqdm

In [None]:
source_dir = "mvtec_ad"  # your input dataset root
output_dir = "mvtec_ad_mgcn"   # where to create new structured dataset

# Split percentages
split_ratios = {
    "train": 0.7,
    "val": 0.15,
    "test": 0.15
}

random.seed(42)  # for reproducibility
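
The `split_data` function defined below truncates the train and validation counts with `int()` and gives the remainder to the test split. The following sketch previews the resulting sizes for a few list lengths. `split_sizes` is a hypothetical helper that mirrors that arithmetic and is not part of the notebook's pipeline.

```python
def split_sizes(total, ratios):
    """Return how many items land in each split for a list of length `total`."""
    train = int(total * ratios["train"])  # fractional part is dropped
    val = int(total * ratios["val"])      # fractional part is dropped
    test = total - train - val            # test receives whatever is left
    return {"train": train, "val": val, "test": test}

ratios = {"train": 0.7, "val": 0.15, "test": 0.15}
for total in (4, 10, 63):
    print(total, split_sizes(total, ratios))
```

Because of the truncation, a class with very few images can end up with an empty validation split (for example, 4 images produce no validation samples), so it is worth reviewing the counts printed later by `count_images`.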

In [None]:
def prepare_dirs():
    """Create split and images/ subdirs"""
    for split in ["train", "val", "test"]:
        images_path = os.path.join(output_dir, split, "images")
        os.makedirs(images_path, exist_ok=True)

def collect_images(class_path):
    good_images = []
    defective_images = []
    for split_folder in ["train", "test"]:
        split_path = os.path.join(class_path, split_folder)
        if not os.path.exists(split_path):
            continue
        for defect_type in os.listdir(split_path):
            defect_path = os.path.join(split_path, defect_type)
            if not os.path.isdir(defect_path):
                continue
            for img in os.listdir(defect_path):
                full_path = os.path.join(defect_path, img)
                if defect_type == "good":
                    good_images.append(full_path)
                else:
                    defective_images.append(full_path)
    return good_images, defective_images

def split_data(images):
    random.shuffle(images)
    total = len(images)
    train_end = int(total * split_ratios["train"])
    val_end = train_end + int(total * split_ratios["val"])
    return {
        "train": images[:train_end],
        "val": images[train_end:val_end],
        "test": images[val_end:]
    }

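def copy_images(image_list, split, class_name, label):
    """Copy images into <output_dir>/<split>/images/<class_name>/<label>."""
    output_paths = []
    dst_dir = os.path.join(output_dir, split, "images", class_name, label)
    os.makedirs(dst_dir, exist_ok=True)

    for src_path in image_list:
        src = Path(src_path)
        # MVTec AD reuses basenames such as 000.png in every source folder, so
        # add the source split and defect type to keep each filename unique.
        source_tag = f"{src.parent.parent.name}_{src.parent.name}"
        new_filename = f"{class_name}_{split}_{label}_{source_tag}_{src.name}"
        dst_path = os.path.join(dst_dir, new_filename)

        # Remove a stale copy left over from a previous run.
        if os.path.exists(dst_path):
            os.remove(dst_path)

        shutil.copy2(src_path, dst_path)
        os.chmod(dst_path, 0o644)
        output_paths.append(dst_path)
    return output_paths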

def sample_golden_images(good_images):
    if len(good_images) < 4:
        print(f"Warning: Only {len(good_images)} good images available, using all.")
        return good_images
    return random.sample(good_images, 4)

def copy_golden_images(golden_images, class_name):
    """Copy the same golden reference images into every split."""
    for split in ["train", "val", "test"]:
        dst_dir = os.path.join(output_dir, split, "images", class_name, "4golden")
        os.makedirs(dst_dir, exist_ok=True)

        for src_path in golden_images:
            src = Path(src_path)
            # Good images come from both train/good and test/good, which share
            # basenames, so tag each file with its source split to avoid overwrites.
            new_filename = f"{class_name}_{split}_4golden_{src.parent.parent.name}_{src.name}"
            dst_path = os.path.join(dst_dir, new_filename)

            if os.path.exists(dst_path):
                os.remove(dst_path)

            shutil.copy2(src_path, dst_path)
            os.chmod(dst_path, 0o644)
def generate_csv(split):
    csv_path = os.path.join(output_dir, split, "dataset.csv")
    rows = []

    images_dir = os.path.join(output_dir, split, "images")
    classes = os.listdir(images_dir)

    for class_name in classes:
        class_path = os.path.join(images_dir, class_name)
        if not os.path.isdir(class_path):
            continue
        for label in ["good", "defective"]:
            label_dir = os.path.join(class_path, label)
            if not os.path.exists(label_dir):
                continue
            for img_file in sorted(os.listdir(label_dir)):
                # Paths are relative to the split's images/ folder.
                input_path = f"{class_name}/{label}"
                golden_path = f"{class_name}/4golden"
                # object_name is the image filename without its extension.
                object_name = os.path.splitext(img_file)[0]
                label_value = "PASS" if label == "good" else "NG"
                rows.append([input_path, golden_path, label_value, object_name])

    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["input_path", "golden_path", "label", "object_name"])
        writer.writerows(rows)

def count_images(output_dir):
    counts = Counter()
    for split in ["train", "val", "test"]:
        images_dir = os.path.join(output_dir, split, "images")
        if not os.path.exists(images_dir):
            continue
        for class_name in os.listdir(images_dir):
            for label in ["good", "defective", "4golden"]:
                label_dir = os.path.join(images_dir, class_name, label)
                if not os.path.exists(label_dir):
                    continue
                num_files = len([
                    f for f in os.listdir(label_dir)
                    if os.path.isfile(os.path.join(label_dir, f))
                ])
                counts[(split, class_name, label)] += num_files

    print(f"{'Split':<10} {'Class':<20} {'Label':<12} {'Count':<6}")
    print("-" * 50)
    for (split, class_name, label), count in sorted(counts.items()):
        print(f"{split:<10} {class_name:<20} {label:<12} {count:<6}")

In [None]:
prepare_dirs()

classes = [d for d in os.listdir(source_dir) if os.path.isdir(os.path.join(source_dir, d))]

for class_name in tqdm(classes):
    class_path = os.path.join(source_dir, class_name)
    good_images, defective_images = collect_images(class_path)

    golden_images = sample_golden_images(good_images)
    split_good = split_data(good_images)
    split_defective = split_data(defective_images)

    for split in ["train", "val", "test"]:
        copy_images(split_good[split], split, class_name, "good")
        copy_images(split_defective[split], split, class_name, "defective")

    copy_golden_images(golden_images, class_name)

for split in ["train", "val", "test"]:
    generate_csv(split)

In [None]:
count_images("mvtec_ad_mgcn")

In [None]:
# Validate dataset format
def print_folder_structure(directory, prefix=""):
    """Print the folder tree, showing only the first PNG file in each directory."""
    if not os.path.exists(directory):
        print(f"Directory {directory} does not exist")
        return

    png_printed = False  # reset per directory so every folder shows one sample image
    for item in sorted(os.listdir(directory)):
        path = os.path.join(directory, item)
        if os.path.isdir(path):
            print(f"{prefix}├── {item}/")
            print_folder_structure(path, prefix + "│   ")
        elif item.endswith(".png"):
            # For PNG files, only print the first one found in this directory
            if not png_printed:
                print(f"{prefix}├── {item}")
                png_printed = True
        else:
            print(f"{prefix}├── {item}")

print_folder_structure(output_dir)



### Package the Dataset

FTMS requires each split folder to contain `dataset.csv` together with the images folder archived and compressed as `images.tar.gz`.
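
Once the packaging cells in this section have run, a check along the following lines can confirm that each split folder has the expected contents. This is a minimal sketch using Python's standard `tarfile` module. `check_split_package` is a hypothetical helper name, and the check only covers what is stated above (a `dataset.csv` plus an `images.tar.gz` that holds an `images` folder), not any further validation that FTMS may perform.

```python
import os
import tarfile

def check_split_package(split_dir):
    """Return a list of problems found in one packaged split folder (empty if none)."""
    problems = []
    csv_path = os.path.join(split_dir, "dataset.csv")
    archive_path = os.path.join(split_dir, "images.tar.gz")

    if not os.path.isfile(csv_path):
        problems.append(f"missing {csv_path}")

    if not os.path.isfile(archive_path):
        problems.append(f"missing {archive_path}")
        return problems

    # Mode "r:gz" raises tarfile.ReadError if the file is not a gzip-compressed tar.
    try:
        with tarfile.open(archive_path, "r:gz") as archive:
            names = archive.getnames()
    except tarfile.ReadError:
        problems.append(f"{archive_path} is not a readable .tar.gz archive")
        return problems

    # The tar commands in this notebook store the folder as "images/...".
    if not any(name == "images" or name.startswith("images/") for name in names):
        problems.append(f"{archive_path} has no top-level images/ folder")
    return problems


for split in ["train", "val", "test"]:
    issues = check_split_package(os.path.join("mvtec_ad_mgcn", split))
    print(split, "OK" if not issues else issues)
```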


In [None]:
!tar -C mvtec_ad_mgcn/test -zcf mvtec_ad_mgcn/test/images.tar.gz images
!rm -rf mvtec_ad_mgcn/test/images

In [None]:
!tar -C mvtec_ad_mgcn/train -zcf mvtec_ad_mgcn/train/images.tar.gz images
!rm -rf mvtec_ad_mgcn/train/images

In [None]:
!tar -C mvtec_ad_mgcn/val -zcf mvtec_ad_mgcn/val/images.tar.gz images
!rm -rf mvtec_ad_mgcn/val/images

### Upload to Cloud Storage

If using an AWS S3 bucket, you can use the following command to upload the formatted dataset through the [AWS CLI](https://aws.amazon.com/cli/): 

```aws s3 sync mvtec_ad_mgcn s3://bucket_name/datasets/mvtec_ad_mgcn```

Your cloud storage should now contain the following dataset paths:

- /bucket_name/datasets/mvtec_ad_mgcn/train
- /bucket_name/datasets/mvtec_ad_mgcn/val
- /bucket_name/datasets/mvtec_ad_mgcn/test
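
To preview which object keys the upload produces before touching the bucket, the mapping from local files to keys can be sketched in plain Python. `plan_s3_keys` is a hypothetical helper that makes no network calls. It assumes that every file under the local folder is uploaded with its relative path appended to the key prefix, and `bucket_name` is the same placeholder used above. The AWS CLI's `--dryrun` option on `aws s3 sync` serves a similar purpose.

```python
import os

def plan_s3_keys(local_root, key_prefix):
    """Map each local file under local_root to the object key it would get under key_prefix."""
    plan = {}
    for dirpath, _dirnames, filenames in os.walk(local_root):
        for filename in filenames:
            local_path = os.path.join(dirpath, filename)
            relative = os.path.relpath(local_path, local_root)
            # Object keys use forward slashes regardless of the local OS.
            key = "/".join([key_prefix.strip("/")] + relative.split(os.sep))
            plan[local_path] = key
    return plan

planned = plan_s3_keys("mvtec_ad_mgcn", "datasets/mvtec_ad_mgcn")
for local_path, key in sorted(planned.items()):
    print(f"{local_path} -> s3://bucket_name/{key}")
```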