## SSL MAE Dataset Formatting 

This notebook prepares the MVTec Anomaly Detection (MVTec AD) dataset in a format compatible with TAO FTMS, so that it can be used to pretrain and finetune an SSL MAE model.

### Download the Dataset

To get started, go to https://www.mvtec.com/company/research/datasets/mvtec-ad, agree to the license terms and click "Download The Whole Dataset". 

Place the downloaded file ```mvtec_anomaly_detection.tar.xz``` in the same folder as this notebook, then run all cells.

Once complete, upload the output dataset folder  ```mvtec_ad_classification``` to your cloud storage. You can then follow the [ssl_mae_pretrain_finetune.ipynb](https://github.com/NVIDIA/tao_tutorials/tree/main/notebooks/tao_api_starter_kit/api/ssl_mae_pretrain_finetune.ipynb) notebook to learn how to pretrain and finetune an SSL MAE model with this dataset. 

In [None]:
!pip install tqdm

In [None]:
!test -f mvtec_anomaly_detection.tar.xz && echo "✅ Dataset Found." || (echo "❌ File 'mvtec_anomaly_detection.tar.xz' not found."; exit 1)

In [None]:
!mkdir -p mvtec_ad
!tar -xf mvtec_anomaly_detection.tar.xz -C mvtec_ad

### Format the Dataset

In this tutorial, we prepare `train`, `val`, and `test` splits for image classification. Image classification expects a directory of images with the following structure, where each class has its own directory named after the class (`good` and `defect` in this notebook). More TAO dataset formats are described [here](https://docs.nvidia.com/tao/tao-toolkit/text/data_annotation_format.html).
```
DATA_DIR
├── images_train
│   ├── good
│   │   ├── image_name_1.jpg
│   │   ├── ...
│   │   ...
│   └── defect
│       ├── image_name_2.jpg
│       ├── ...
```
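A quick way to check a folder against this layout is sketched below. The helper name `check_classification_layout` and its arguments are illustrative and are not part of TAO; it only confirms that each expected class directory exists and holds at least one file.

```python
from pathlib import Path


def check_classification_layout(split_dir, classes=("good", "defect")):
    """Return {class_name: file_count} for a split directory such as images_train.

    Hypothetical helper (not a TAO API): raises ValueError if a class
    directory is missing or empty.
    """
    split_dir = Path(split_dir)
    counts = {}
    for name in classes:
        class_dir = split_dir / name
        if not class_dir.is_dir():
            raise ValueError(f"Missing class directory: {class_dir}")
        # Count regular files only; nested folders are ignored.
        n_files = sum(1 for p in class_dir.iterdir() if p.is_file())
        if n_files == 0:
            raise ValueError(f"No images found in: {class_dir}")
        counts[name] = n_files
    return counts
```

For example, `check_classification_layout("mvtec_ad_classification/train/images_train")` can be run once the formatting cells have finished and before the folders are archived.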

In [None]:
import os
import shutil
import random
from pathlib import Path
from collections import Counter
from tqdm import tqdm

In [None]:
source_dir = "mvtec_ad"
output_dir = "mvtec_ad_classification"

split_ratios = {
    "train": 0.7,
    "val": 0.15,
    "test": 0.15
}
seed = 42

In [None]:
def collect_images(root_dir):
    """Collect (image_path, label) pairs from mvtec dataset."""
    all_images = []
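    categories = sorted(d for d in Path(root_dir).iterdir() if d.is_dir())

    for category in categories:
        train_good_dir = category / 'train' / 'good'
        test_dir = category / 'test'

        # Sort every listing: filesystem order is not guaranteed, and the seeded
        # shuffle is only reproducible when the input order is stable.
        if train_good_dir.exists():
            all_images.extend((img_path, 'good') for img_path in sorted(train_good_dir.glob('*')))

        if test_dir.exists():
            for defect_type in sorted(test_dir.iterdir()):
                if not defect_type.is_dir():
                    continue
                # Every test folder other than 'good' is treated as a defect.
                label = 'good' if defect_type.name == 'good' else 'defect'
                all_images.extend((img_path, label) for img_path in sorted(defect_type.glob('*')))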

    return all_images

def split_images(all_images):
    """Shuffle and split images into train/val/test sets."""
    random.shuffle(all_images)
    n_total = len(all_images)
    n_train = int(n_total * split_ratios["train"])
    n_val = int(n_total * split_ratios["val"])

    train_set = all_images[:n_train]
    val_set = all_images[n_train:n_train + n_val]
    test_set = all_images[n_train + n_val:]

    return train_set, val_set, test_set

def copy_and_rename(images, subset):
    """Copy and rename images to the output folder."""
    for idx, (src_path, label) in tqdm(enumerate(images), total=len(images), desc=f"Copying {subset}"):
        subset_dir = Path(output_dir) / subset / label
        subset_dir.mkdir(parents=True, exist_ok=True)
        dst_path = subset_dir / f"{idx:05d}{src_path.suffix.lower()}"
        shutil.copy(src_path, dst_path)

def count_images(output_dir):
    """Print the number of images per subset and label."""
    counts = Counter()
    for subset in ['train/images_train', 'val/images_val', 'test/images_test']:
        subset_dir = Path(output_dir) / subset
        if not subset_dir.exists():
            continue
        for label in ['good', 'defect']:
            label_dir = subset_dir / label
            if not label_dir.exists():
                continue
            num_files = len(list(label_dir.glob('*')))
            counts[(subset, label)] = num_files

    print(f"{'Subset':<20} {'Label':<10} {'Count':<6}")
    print("-" * 40)
    for (subset, label), count in sorted(counts.items()):
        print(f"{subset:<20} {label:<10} {count:<6}")

In [None]:
random.seed(seed)

all_images = collect_images(source_dir)
train_set, val_set, test_set = split_images(all_images)

copy_and_rename(train_set, 'train/images_train')
copy_and_rename(val_set, 'val/images_val')
copy_and_rename(test_set, 'test/images_test')

print(f"Formatted dataset saved to '{Path(output_dir).resolve()}'!")

In [None]:
count_images(output_dir)
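The subset sizes reported above follow from how `split_images` slices the shuffled list: the train and val sizes are truncated with `int()`, and the test set receives whatever remains, so it can end up slightly larger than its nominal ratio. The sketch below reproduces only that arithmetic. `split_sizes` is an illustrative name, and the total of 1003 images is a made-up figure, not the real dataset size.

```python
def split_sizes(n_total, ratios):
    """Return (n_train, n_val, n_test) using the same truncation as split_images."""
    n_train = int(n_total * ratios["train"])
    n_val = int(n_total * ratios["val"])
    # The test subset is the remainder, so rounding losses end up here.
    n_test = n_total - n_train - n_val
    return n_train, n_val, n_test


ratios = {"train": 0.7, "val": 0.15, "test": 0.15}
n_train, n_val, n_test = split_sizes(1003, ratios)  # 1003 is a hypothetical total
print(n_train, n_val, n_test)  # 702 150 151
```

With this hypothetical total, val and test share the same ratio, yet test receives one more image than val because it absorbs the truncated remainder.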

In [None]:
# Validate dataset format
def print_folder_structure(root_dir, indent=0):
    root_name = os.path.basename(root_dir)
    print(root_name)

    for item in sorted(os.listdir(root_dir)):
        item_path = os.path.join(root_dir, item)
        if os.path.isdir(item_path):
            print(' ' * (indent + 2) + '├── ' + item)
            for subitem in sorted(os.listdir(item_path)):
                subitem_path = os.path.join(item_path, subitem)
                if os.path.isdir(subitem_path):
                    print(' ' * (indent + 4) + '├── ' + subitem)
                    # Print just one file as example
                    files = sorted(os.listdir(subitem_path))
                    if files:
                        print(' ' * (indent + 6) + '├── ' + files[0])
                        print(' ' * (indent + 6) + '├── ...')

print_folder_structure(f"{output_dir}/train")
print_folder_structure(f"{output_dir}/val")
print_folder_structure(f"{output_dir}/test")

### Package the Dataset

FTMS requires each split of the dataset to be archived and compressed under the name `images_train.tar.gz`, `images_val.tar.gz`, or `images_test.tar.gz`, matching the split it contains.

In [None]:
!tar -C mvtec_ad_classification/test -zcf mvtec_ad_classification/test/images_test.tar.gz images_test && rm -rf mvtec_ad_classification/test/images_test

In [None]:
!tar -C mvtec_ad_classification/train -zcf mvtec_ad_classification/train/images_train.tar.gz images_train && rm -rf mvtec_ad_classification/train/images_train

In [None]:
!tar -C mvtec_ad_classification/val -zcf mvtec_ad_classification/val/images_val.tar.gz images_val && rm -rf mvtec_ad_classification/val/images_val
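Before uploading, it can be worth confirming that each archive unpacks to a single top-level folder named after its split, which is what the `tar -C` commands above produce. The sketch below uses Python's standard `tarfile` module. `archive_top_level_dirs` is an illustrative helper, and whether FTMS enforces this exact internal layout is an assumption based on the naming requirement stated above.

```python
import tarfile


def archive_top_level_dirs(archive_path):
    """Return the sorted top-level entries inside a .tar.gz archive."""
    with tarfile.open(archive_path, "r:gz") as tar:
        # Member names look like 'images_train/good/00000.png';
        # the first path component is the top-level folder.
        return sorted({name.split("/")[0] for name in tar.getnames() if name})
```

For the train split, `archive_top_level_dirs("mvtec_ad_classification/train/images_train.tar.gz")` should return a list containing only `images_train`.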

### Upload to Cloud Storage

If using an AWS S3 bucket, you can use the following command to upload the formatted dataset through the [AWS CLI](https://aws.amazon.com/cli/): 

```aws s3 sync mvtec_ad_classification s3://bucket_name/datasets/mvtec_ad_classification```

Your cloud storage should now contain the following dataset paths:

- /bucket_name/datasets/mvtec_ad_classification/train
- /bucket_name/datasets/mvtec_ad_classification/val
- /bucket_name/datasets/mvtec_ad_classification/test
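To cross-check the upload, the object keys that `aws s3 sync` is expected to create can be listed locally and compared with the bucket contents. This is a minimal sketch using only the standard library. `expected_s3_keys` is a hypothetical helper, and `datasets/mvtec_ad_classification` is the prefix taken from the example command above.

```python
from pathlib import Path


def expected_s3_keys(local_dir, prefix):
    """Map every file under local_dir to the key it would get under prefix."""
    local_dir = Path(local_dir)
    keys = []
    for path in sorted(local_dir.rglob("*")):
        if path.is_file():
            # Object keys use forward slashes regardless of the local OS.
            rel = path.relative_to(local_dir).as_posix()
            keys.append(f"{prefix.rstrip('/')}/{rel}")
    return keys
```

After the packaging cells have run, `expected_s3_keys("mvtec_ad_classification", "datasets/mvtec_ad_classification")` should list one `.tar.gz` key per split, such as `datasets/mvtec_ad_classification/train/images_train.tar.gz`.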