In [1]:
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import torch
from datasets import load_dataset
from torchvision.transforms import Compose, Lambda, Normalize, RandomHorizontalFlip, RandomResizedCrop, ToTensor

In [2]:
import transformers
from transformers import (
    CONFIG_MAPPING,
    IMAGE_PROCESSOR_MAPPING,
    MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
    AutoConfig,
    AutoImageProcessor,
    AutoModelForMaskedImageModeling,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version

In [3]:
""" Pre-training a 🤗 Transformers model for simple masked image modeling (SimMIM).
Any model supported by the AutoModelForMaskedImageModeling API can be used.
"""

logger = logging.getLogger(__name__)

# Fail early when the installed Transformers release is older than this example expects.
# Remove this guard at your own risk.
check_min_version("4.36.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

# Config classes registered for masked image modeling, and the `model_type` string of each one.
MODEL_CONFIG_CLASSES = [config_class for config_class in MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING.keys()]
MODEL_TYPES = tuple(map(lambda config_class: config_class.model_type, MODEL_CONFIG_CLASSES))

In [4]:
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    Using `HfArgumentParser` we can turn this class into argparse arguments to be able to
    specify them on the command line.

    After initialization, `data_files` holds a mapping from split name ("train" / "validation")
    to the corresponding folder, or `None` when neither `train_dir` nor `validation_dir` is set.
    """

    dataset_name: Optional[str] = field(
        default="cifar10", metadata={"help": "Name of a dataset from the datasets package"}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    image_column_name: Optional[str] = field(
        default=None,
        metadata={"help": "The column name of the images in the files. If not set, will try to use 'image' or 'img'."},
    )
    train_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the training data."})
    validation_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the validation data."})
    train_val_split: Optional[float] = field(
        default=0.15, metadata={"help": "Percent to split off of train for validation."}
    )
    mask_patch_size: int = field(default=32, metadata={"help": "The size of the square patches to use for masking."})
    mask_ratio: float = field(
        default=0.6,
        metadata={"help": "Percentage of patches to mask."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )

    def __post_init__(self):
        """Gather the optional local folders into `self.data_files`, keyed by split name."""
        data_files = {}
        if self.train_dir is not None:
            data_files["train"] = self.train_dir
        if self.validation_dir is not None:
            # The key must be "validation": that is the split name the training flow looks up
            # (`"validation" in ds`, `ds["validation"]`). Under any other key the user-supplied
            # folder would be ignored and a validation set would be split off train instead.
            data_files["validation"] = self.validation_dir
        # `None` (rather than an empty dict) signals that no local files were given.
        self.data_files = data_files if data_files else None

In [5]:
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/image processor we are going to pre-train.

    Fields that default to `None` are annotated as `Optional[...]` so the annotation matches
    the value a caller actually observes when the argument is not supplied.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Can be a local path to a pytorch_model.bin or a "
                "checkpoint identifier on the hub. "
                "Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store (cache) the pretrained models/datasets downloaded from the hub"},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    image_processor_name: Optional[str] = field(
        default=None, metadata={"help": "Name or path of preprocessor config."}
    )
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    # NOTE(review): annotated `bool` although the default is None. Left unchanged on purpose:
    # the argument parser may build the command-line flag differently for `Optional[bool]`;
    # confirm against HfArgumentParser before changing this annotation.
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                # The trailing space after "option" keeps the concatenated help text readable.
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    image_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The size (resolution) of each image. If not specified, will use `image_size` of the configuration."
            )
        },
    )
    patch_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The size (resolution) of each patch. If not specified, will use `patch_size` of the configuration."
            )
        },
    )
    encoder_stride: Optional[int] = field(
        default=None,
        metadata={"help": "Stride to use for the encoder."},
    )

In [6]:
class MaskGenerator:
    """
    A class to generate boolean masks for the pretraining task.

    A mask is a 1D tensor of shape ((input_size // model_patch_size)**2,) where the value is either 0 or 1,
    where 1 indicates "masked". Masking is decided on a coarse grid of `mask_patch_size` squares and then
    expanded so that every model patch inside a masked square is marked.

    Args:
        input_size: Side length of the (square) input image, in pixels.
        mask_patch_size: Side length of the squares that are masked as a unit. Must divide `input_size`.
        model_patch_size: Side length of the model's patches. Must divide `mask_patch_size`.
        mask_ratio: Fraction of the coarse squares to mask, in the closed interval [0, 1].

    Raises:
        ValueError: If the sizes are not divisible as described above, or if `mask_ratio` lies outside [0, 1].
    """

    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        if self.input_size % self.mask_patch_size != 0:
            raise ValueError("Input size must be divisible by mask patch size")
        if self.mask_patch_size % self.model_patch_size != 0:
            raise ValueError("Mask patch size must be divisible by model patch size")
        # A ratio outside [0, 1] would otherwise be accepted silently: a negative value turns into a
        # negative slice bound in `__call__` and masks an unintended number of squares.
        if not 0.0 <= self.mask_ratio <= 1.0:
            raise ValueError("Mask ratio must be between 0 and 1")

        # Number of coarse squares per image side, and model patches per coarse-square side.
        self.rand_size = self.input_size // self.mask_patch_size
        self.scale = self.mask_patch_size // self.model_patch_size

        self.token_count = self.rand_size**2
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self):
        """Draw a fresh random mask and return it as a flattened tensor of 0/1 values."""
        # Pick `mask_count` distinct coarse squares uniformly at random.
        mask_idx = np.random.permutation(self.token_count)[: self.mask_count]
        mask = np.zeros(self.token_count, dtype=int)
        mask[mask_idx] = 1

        # Expand every coarse square to the `scale` x `scale` model patches it covers.
        mask = mask.reshape((self.rand_size, self.rand_size))
        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)

        return torch.tensor(mask.flatten())

In [7]:
def collate_fn(examples):
    """Batch per-example dicts into the keyword arguments the model's forward pass receives.

    Each example contributes its "pixel_values" and "mask" tensors; both are stacked along a
    new leading dimension. The stacked masks are returned under the key "bool_masked_pos".
    """
    return {
        "pixel_values": torch.stack([sample["pixel_values"] for sample in examples]),
        "bool_masked_pos": torch.stack([sample["mask"] for sample in examples]),
    }

In [8]:
# Run configuration, passed to `HfArgumentParser.parse_dict` in place of command-line flags.
args = {
    "model_type": "vit",
    "output_dir": "./outputs/",
    "overwrite_output_dir": True,
    "remove_unused_columns": False,
    # Must be a list of names: the dict value reaches the arguments object as-is, and a bare
    # string would be iterated character by character when it is used as a collection of keys.
    "label_names": ["bool_masked_pos"],
    "do_train": True,
    "do_eval": True,
    "learning_rate": 2e-5,
    "weight_decay": 0.05,
    "num_train_epochs": 100,
    "per_device_train_batch_size": 8,
    "per_device_eval_batch_size": 8,
    "logging_strategy": "steps",
    "logging_steps": 10,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "load_best_model_at_end": True,
    "save_total_limit": 3,
    "seed": 1337,
}

In [11]:
# Build one parser over the three argument dataclasses, then fill them from the `args` dict
# instead of from the command line.
parser = HfArgumentParser([ModelArguments, DataTrainingArguments, TrainingArguments])
(model_args, data_args, training_args) = parser.parse_dict(args)

In [13]:
# Display the parsed model arguments as this cell's output.
model_args

ModelArguments(model_name_or_path=None, model_type='vit', config_name_or_path=None, config_overrides=None, cache_dir=None, model_revision='main', image_processor_name=None, token=None, use_auth_token=None, trust_remote_code=False, image_size=None, patch_size=None, encoder_stride=None)

In [14]:
# Setup logging: send every record to stdout with a timestamped format.
stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    handlers=[stdout_handler],
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

hf_logging = transformers.utils.logging

# The default of training_args.log_level is passive, so we set log level at info here to have that default.
if training_args.should_log:
    hf_logging.set_verbosity_info()

# Apply the per-process level to this notebook's logger and to the Transformers loggers.
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
hf_logging.set_verbosity(log_level)
hf_logging.enable_default_handler()
hf_logging.enable_explicit_format()

In [15]:
# Emit a short summary from every process, followed by the full set of training arguments.
is_distributed = training_args.parallel_mode.value == "distributed"
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
    f"distributed training: {is_distributed}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")

11/30/2023 17:35:40 - INFO - __main__ - Training/evaluation parameters TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=True,
do_predict=False,
do_train=True,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=IntervalStrategy.EPOCH,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
gradient_checkpoint

In [9]:
# Look for a checkpoint to resume from, but only when training into an existing output
# directory that we were not asked to overwrite.
last_checkpoint = None
if (
    os.path.isdir(training_args.output_dir)
    and training_args.do_train
    and not training_args.overwrite_output_dir
):
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is None:
        # No checkpoint found: refuse to write into a directory that already has content.
        if len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
    elif training_args.resume_from_checkpoint is None:
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )

# Honor the deprecated `use_auth_token` argument instead of silently ignoring it: warn, then
# forward its value to `token`, which is what the loading calls below actually receive.
if model_args.use_auth_token is not None:
    warnings.warn(
        "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
        FutureWarning,
    )
    # Only a genuinely conflicting `token` is an error; an equal value (e.g. left behind by a
    # previous run of this cell) is accepted so the cell can be re-run.
    if model_args.token is not None and model_args.token != model_args.use_auth_token:
        raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
    model_args.token = model_args.use_auth_token

# Initialize our dataset.
ds = load_dataset(
    data_args.dataset_name,
    data_args.dataset_config_name,
    data_files=data_args.data_files,
    cache_dir=model_args.cache_dir,
    token=model_args.token,
)

# If we don't have a validation split, split off a percentage of train as validation.
data_args.train_val_split = None if "validation" in ds.keys() else data_args.train_val_split
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
    # Seed the split so that re-running the cell, or resuming from a checkpoint, produces the
    # same train/validation partition.
    split = ds["train"].train_test_split(data_args.train_val_split, seed=training_args.seed)
    ds["train"] = split["train"]
    ds["validation"] = split["test"]

# Create config
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
config_kwargs = dict(
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    token=model_args.token,
    trust_remote_code=model_args.trust_remote_code,
)

# An explicit config name takes precedence over the model checkpoint; with neither given,
# a fresh config is built from the requested model type.
pretrained_config_source = model_args.config_name_or_path or model_args.model_name_or_path
if pretrained_config_source:
    config = AutoConfig.from_pretrained(pretrained_config_source, **config_kwargs)
else:
    config = CONFIG_MAPPING[model_args.model_type]()
    logger.warning("You are instantiating a new config instance from scratch.")
    if model_args.config_overrides is not None:
        logger.info(f"Overriding config: {model_args.config_overrides}")
        config.update_from_string(model_args.config_overrides)
        logger.info(f"New config: {config}")

# make sure the decoder_type is "simmim" (only relevant for BEiT)
if hasattr(config, "decoder_type"):
    config.decoder_type = "simmim"

# adapt config: any size option left unset on the command line falls back to the config's value,
# and the resolved values are then written back into the config.
for size_option in ("image_size", "patch_size", "encoder_stride"):
    if getattr(model_args, size_option) is None:
        setattr(model_args, size_option, getattr(config, size_option))

config.update(
    {size_option: getattr(model_args, size_option) for size_option in ("image_size", "patch_size", "encoder_stride")}
)

# create image processor: an explicit processor name wins over the model checkpoint; with
# neither given, instantiate the default processor class registered for the model type.
image_processor_source = model_args.image_processor_name or model_args.model_name_or_path
if image_processor_source:
    image_processor = AutoImageProcessor.from_pretrained(image_processor_source, **config_kwargs)
else:
    IMAGE_PROCESSOR_TYPES = {
        config_class.model_type: processor_class for config_class, processor_class in IMAGE_PROCESSOR_MAPPING.items()
    }
    image_processor = IMAGE_PROCESSOR_TYPES[model_args.model_type]()

# create model
if not model_args.model_name_or_path:
    logger.info("Training new model from scratch")
    model = AutoModelForMaskedImageModeling.from_config(config, trust_remote_code=model_args.trust_remote_code)
else:
    model = AutoModelForMaskedImageModeling.from_pretrained(
        model_args.model_name_or_path,
        # A ".ckpt" in the name is taken to mean a TensorFlow checkpoint.
        from_tf=".ckpt" in model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )

# Read the column names from the split that will actually be used.
column_names = ds["train" if training_args.do_train else "validation"].column_names

# Resolve the image column: explicit argument first, then the conventional names "image" and
# "img", and finally the dataset's first column.
if data_args.image_column_name is not None:
    image_column_name = data_args.image_column_name
else:
    image_column_name = next(
        (candidate for candidate in ("image", "img") if candidate in column_names),
        column_names[0],
    )

# transformations as done in original SimMIM paper
# source: https://github.com/microsoft/SimMIM/blob/main/data/data_simmim.py
to_rgb = Lambda(lambda img: img if img.mode == "RGB" else img.convert("RGB"))
random_crop = RandomResizedCrop(model_args.image_size, scale=(0.67, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0))
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
transforms = Compose([to_rgb, random_crop, RandomHorizontalFlip(), ToTensor(), normalize])

# create mask generator
mask_generator_kwargs = {
    "input_size": model_args.image_size,
    "mask_patch_size": data_args.mask_patch_size,
    "model_patch_size": model_args.patch_size,
    "mask_ratio": data_args.mask_ratio,
}
mask_generator = MaskGenerator(**mask_generator_kwargs)

def preprocess_images(examples):
    """Turn a batch of raw images into model inputs.

    Adds two entries to `examples` and returns it: "pixel_values", the result of applying
    `transforms` to each image in the image column, and "mask", one freshly generated mask
    per image indicating which patches to mask.
    """
    examples["pixel_values"] = list(map(transforms, examples[image_column_name]))
    examples["mask"] = [mask_generator() for _ in range(len(examples[image_column_name]))]
    return examples

# For every enabled phase: check that its split exists, optionally truncate it to the requested
# number of (shuffled) samples, and attach the on-the-fly preprocessing.
for phase_enabled, split_name, phase_flag, sample_cap in (
    (training_args.do_train, "train", "--do_train", data_args.max_train_samples),
    (training_args.do_eval, "validation", "--do_eval", data_args.max_eval_samples),
):
    if not phase_enabled:
        continue
    if split_name not in ds:
        raise ValueError(f"{phase_flag} requires a {split_name} dataset")
    if sample_cap is not None:
        ds[split_name] = ds[split_name].shuffle(seed=training_args.seed).select(range(sample_cap))
    # Set the transforms for this split
    ds[split_name].set_transform(preprocess_images)

# Initialize our trainer; a split is only handed over when its phase is enabled.
train_dataset = ds["train"] if training_args.do_train else None
eval_dataset = ds["validation"] if training_args.do_eval else None
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=image_processor,
    data_collator=collate_fn,
)

# Training
if training_args.do_train:
    # An explicitly requested checkpoint wins over one detected in the output directory.
    checkpoint = training_args.resume_from_checkpoint
    if checkpoint is None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_model()
    train_metrics = train_result.metrics
    trainer.log_metrics("train", train_metrics)
    trainer.save_metrics("train", train_metrics)
    trainer.save_state()

# Evaluation
if training_args.do_eval:
    metrics = trainer.evaluate()
    for report_metrics in (trainer.log_metrics, trainer.save_metrics):
        report_metrics("eval", metrics)

# Write model card and (optionally) push to hub
kwargs = dict(
    finetuned_from=model_args.model_name_or_path,
    tasks="masked-image-modeling",
    dataset=data_args.dataset_name,
    tags=["masked-image-modeling"],
)
publish = trainer.push_to_hub if training_args.push_to_hub else trainer.create_model_card
publish(**kwargs)

usage: ipykernel_launcher.py [-h] [--model_name_or_path MODEL_NAME_OR_PATH]
                             [--model_type MODEL_TYPE]
                             [--config_name_or_path CONFIG_NAME_OR_PATH]
                             [--config_overrides CONFIG_OVERRIDES]
                             [--cache_dir CACHE_DIR]
                             [--model_revision MODEL_REVISION]
                             [--image_processor_name IMAGE_PROCESSOR_NAME]
                             [--token TOKEN]
                             [--use_auth_token [USE_AUTH_TOKEN]]
                             [--trust_remote_code [TRUST_REMOTE_CODE]]
                             [--image_size IMAGE_SIZE]
                             [--patch_size PATCH_SIZE]
                             [--encoder_stride ENCODER_STRIDE]
                             [--dataset_name DATASET_NAME]
                             [--dataset_config_name DATASET_CONFIG_NAME]
                             [--image_column_name IM

SystemExit: 2