# Masked Auto Encoder for pretraining

> Masked autoenocder pre-training



All codes are taken from [here](https://github.com/huggingface/transformers/blob/main/examples/pytorch/image-pretraining/run_mae.py)

In [None]:
#| default_exp pretraining.mae_pretraining

In [29]:
%load_ext autoreload
%autoreload 2

In [30]:
#| export
import torch
print(torch.__version__)

'2.4.0+cu121'

In [55]:
#| export
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
import transformers
from fastcore.imports import *

import torch
from datasets import load_dataset
from torchvision.transforms import Compose, Lambda, Normalize, RandomHorizontalFlip, RandomResizedCrop, ToTensor
from torchvision.transforms.functional import InterpolationMode

In [None]:
#| export
from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    ViTImageProcessor,
    ViTMAEConfig,
    ViTMAEForPreTraining,
)

2024-09-01 14:51:28.146053: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-01 14:51:28.146096: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-01 14:51:28.146104: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-01 14:51:28.151943: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


/home/user/miniconda3/lib/python3.11/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32


  warn("The installed version of bitsandbytes was compiled without GPU support. "


In [9]:
#| export
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version

In [10]:
#| export
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.45.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

In [11]:
#| export
@dataclass
class DataTrainingArguments:

	dataset_name: Optional[str] = field(
        default=None, metadata={"help": "Name of a dataset from the datasets package"}
    )
	dataset_config_name: Optional[str] = field(
        default=None, 
		metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    											)

	trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )

	image_column_name: Optional[str] = field(
        default=None, metadata={"help": "The column name of the images in the files."}
    )

	train_dir: Optional[str] = field(
            default=None, 
            metadata={"help": "A folder containing the training data."}
                  )
	validation_dir: Optional[str] = field(
            default=None, metadata={"help": "A folder containing the validation data."}
                  )
	train_val_split: Optional[float] = field(
        default=0.15, metadata={"help": "Percent to split off of train for validation."}
    )
	max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
	max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )

	def __post_init__(self):
		data_files = {}
		if self.train_dir is not None:
			data_files["train"] = self.train_dir
		if self.validation_dir is not None:
			data_files["val"] = self.validation_dir
		self.data_files = data_files if data_files else None

In [31]:
#| export

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/image processor we are going to pre-train.
    """

    model_name_or_path: str = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name_or_path"}
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
    token: str = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    mask_ratio: float = field(
        default=0.75, metadata={"help": "The ratio of the number of masked tokens in the input sequence."}
    )
    norm_pix_loss: bool = field(
        default=True, metadata={"help": "Whether or not to train with normalized pixel values as target."}
    )



In [93]:
HOME = os.getenv("HOME")
data_dir=Path(f'{HOME}/Schreibtisch/projects/data/easy_endline/CurrentTrainingData20240209/images')


In [94]:
ds = load_dataset(
    'imagefolder',
	data_dir=data_dir,
	split='train'
)

Resolving data files:   0%|          | 0/7901 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/7901 [00:00<?, ?files/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [95]:
ds

Dataset({
    features: ['image'],
    num_rows: 7901
})

In [98]:
ds[0]

{'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=1632x1152>}

In [97]:
import numpy as np
np.unique(ds['label'])

KeyError: "Column label not in the dataset. Current columns in the dataset: ['image']"

In [32]:
#| export
@dataclass
class CustomTrainingArguments(TrainingArguments):
    base_learning_rate: float = field(
        default=1e-3, metadata={"help": "Base learning rate: absolute_lr = base_lr * total_batch_size / 256."}
    )
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    return {"pixel_values": pixel_values}


In [44]:
#| export
def main_():

    # getting arguments and parser
	parser = HfArgumentParser((
		ModelArguments, 
		DataTrainingArguments, 
		CustomTrainingArguments ))
	if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
		model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
	else:
		model_args, data_args, training_args = parser.parse_args_into_dataclasses()


	# Setup logging
	logging.basicConfig(
		format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
		datefmt="%m/%d/%Y %H:%M:%S",
		handlers=[logging.StreamHandler(sys.stdout)],
	)

	if training_args.should_log:
		transformers.utils.logging.set_verbosity_info()

	log_level = training_args.get_process_log_level()
	logger.setLevel(log_level)
	transformers.utils.logging.set_verbosity(log_level)
	transformers.utils.logging.enable_default_handler()
	transformers.utils.logging.enable_explicit_format()

	  # Log on each process the small summary:
	logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
	logger.info(f"Training/evaluation parameters {training_args}")

	# Detecting last checkpoint.
	last_checkpoint = None
	if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
		last_checkpoint = get_last_checkpoint(training_args.output_dir)
		if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
			raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
		elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
			logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

	ds = load_dataset(
    	'imagefolder',
		data_dir=data_args.train_dir,
		split='train',
		cache_dir=model_args.cache_dir,
		token=model_args.token,

	)
	# Initialize our dataset.
	#ds = load_dataset(
        ##data_args.dataset_name,
        #data_args.dataset_config_name,
		#train_dir=data_args.train_dir,
        #data_files=data_args.data_files,
        #cache_dir=model_args.cache_dir,
        #token=model_args.token,
        #trust_remote_code=data_args.trust_remote_code,
    #)

	#If we don't have a validation split, split off a percentage of train as validation.
	data_args.train_val_split = None if "validation" in ds.keys() else data_args.train_val_split
	if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
		split = ds["train"].train_test_split(data_args.train_val_split)
		ds["train"] = split["train"]
		ds["validation"] = split["test"]


	# Config updateing
	config_kwargs = {

        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
    }

	if model_args.config_name:
		config = ViTMAEConfig.from_pretrained(model_args.config_name, **config_kwargs)
	elif model_args.model_name_or_path:
		config = ViTMAEConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
	else:
		config = ViTMAEConfig()
		logger.warning("You are instantiating a new config instance from scratch.")
		if model_args.config_overrides is not None:
			logger.info(f"Overriding config: {model_args.config_overrides}")
			config.update_from_string(model_args.config_overrides)
			logger.info(f"New config: {config}")

	# adapt config
	config.update(
    	{
        	"mask_ratio": model_args.mask_ratio,
        	"norm_pix_loss": model_args.norm_pix_loss,
    	}
	)

	# create image processor
	if model_args.image_processor_name:
		image_processor = ViTImageProcessor.from_pretrained(model_args.image_processor_name, **config_kwargs)
	elif model_args.model_name_or_path:
		image_processor = ViTImageProcessor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
	else:
		image_processor = ViTImageProcessor()

	# create model
	if model_args.model_name_or_path:
		model = ViTMAEForPreTraining.from_pretrained(
			model_args.model_name_or_path,
			from_tf=bool(".ckpt" in model_args.model_name_or_path),
			config=config,
			cache_dir=model_args.cache_dir,
			revision=model_args.model_revision,
			token=model_args.token,
		)
	else:
		logger.info("Training new model from scratch")
		model = ViTMAEForPreTraining(config)

	if training_args.do_train:
		column_names = ds["train"].column_names
	else:
		column_names = ds["validation"].column_names

	if data_args.image_column_name is not None:
		image_column_name = data_args.image_column_name
	elif "image" in column_names:
		image_column_name = "image"
	elif "img" in column_names:
		image_column_name = "img"
	else:
		image_column_name = column_names[0]

	# transformations as done in original MAE paper
	# source: https://github.com/facebookresearch/mae/blob/main/main_pretrain.py
	if "shortest_edge" in image_processor.size:
		size = image_processor.size["shortest_edge"]
	else:
		size = (image_processor.size["height"], image_processor.size["width"])
	transforms = Compose(
		[
			Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
			RandomResizedCrop(size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),
			RandomHorizontalFlip(),
			ToTensor(),
			Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
		]
	)

	def preprocess_images(examples):
		"""Preprocess a batch of images by"""
		examples["pixel_values"] = [transforms(image) for image in examples[image_column_name]]
		return examples

	if training_args.do_train:
		if "train" not in ds:
			raise ValueError("--do_train requires a train dataset")
		if data_args.max_train_samples is not None:
			ds["train"] = ds["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
		# Set the training transforms
		ds["train"].set_transform(preprocess_images)

	if training_args.do_eval:
		if "validation" not in ds:
			raise ValueError("--do_eval requires a validation dataset")
		if data_args.max_eval_samples is not None:
			ds["validation"] = (
				ds["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
			)
		# Set the validation transforms
		ds["validation"].set_transform(preprocess_images)

		# Compute absolute learning rate
		total_train_batch_size = (
			training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
		)
		if training_args.base_learning_rate is not None:
			training_args.learning_rate = training_args.base_learning_rate * total_train_batch_size / 256

		# Initialize our trainer
		trainer = Trainer(
			model=model,
			args=training_args,
			train_dataset=ds["train"] if training_args.do_train else None,
			eval_dataset=ds["validation"] if training_args.do_eval else None,
			tokenizer=image_processor,
			data_collator=collate_fn,
		)

		# Training
		if training_args.do_train:
			checkpoint = None
			if training_args.resume_from_checkpoint is not None:
				checkpoint = training_args.resume_from_checkpoint
			elif last_checkpoint is not None:
				checkpoint = last_checkpoint
			train_result = trainer.train(resume_from_checkpoint=checkpoint)
			trainer.save_model()
			trainer.log_metrics("train", train_result.metrics)
			trainer.save_metrics("train", train_result.metrics)
			trainer.save_state()

		# Evaluation
		if training_args.do_eval:
			metrics = trainer.evaluate()
			trainer.log_metrics("eval", metrics)
			trainer.save_metrics("eval", metrics)

		# Write model card and (optionally) push to hub
		kwargs = {
			"tasks": "masked-auto-encoding",
			"dataset": data_args.dataset_name,
			"tags": ["masked-auto-encoding"],
		}



In [62]:
 # Initialize our dataset.
ds = load_dataset(
	    'cifar10',
        #data_args.dataset_config_name,
        data_files=data_args['data_files'],
        cache_dir=None,
        token=model_args['token'],
        trust_remote_code=data_args['trust_remote_code'],
    )


Downloading readme:   0%|          | 0.00/5.16k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/120M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/23.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [88]:
ds

Dataset({
    features: ['image', 'label'],
    num_rows: 7901
})

In [63]:
ds

DatasetDict({
    train: Dataset({
        features: ['img', 'label'],
        num_rows: 50000
    })
    test: Dataset({
        features: ['img', 'label'],
        num_rows: 10000
    })
})

In [69]:
data_args

{'data_files': None, 'trust_remote_code': False, 'train_val_split': 0.15}

In [73]:
def check_splits(
        ds, 
        data_args
        ):
	"check splits"
	data_args['train_val_split'] = None if 'validation' in ds.keys() else data_args['train_val_split']
	if isinstance(data_args['train_val_split'], float) and data_args['train_val_split'] > 0:
		split = ds['train'].train_test_split(data_args['train_val_split'])
		ds['train'] = split['train']
		ds['validation'] = split['test']
	return ds




In [80]:
config_kwargs = {
    "cache_dir": model_args['cache_dir'],
    "revision": model_args['model_revision'],
    "token": model_args['token']
}

In [104]:
model_args['config_name'] = None
model_args['model_name_or_path'] = None
model_args['config_overrides'] = None
model_args['mask_ratio'] = 0.75
model_args['norm_pix_loss'] = True
model_args['image_processor_name'] = None

In [86]:
def get_config(model_args):
	if model_args.config_name:
		config = ViTMAEConfig.from_pretrained(model_args.config_name, **config_kwargs)
	elif model_args.model_name_or_path:
		config = ViTMAEConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
	else:
		config = ViTMAEConfig()
		logger.warning("You are instantiating a new config instance from scratch.")
		if model_args.config_overrides is not None:
			logger.info(f"Overriding config: {model_args.config_overrides}")
			config.update_from_string(model_args.config_overrides)
			logger.info(f"New config: {config}")

    # adapt config
	config.update(
        {
            "mask_ratio": model_args.mask_ratio,
            "norm_pix_loss": model_args.norm_pix_loss,
        }
    )


In [98]:
def get_config_local(model_args):
	if model_args['config_name']:
		config = ViTMAEConfig.from_pretrained(model_args['config_name'], **config_kwargs)
	elif model_args['model_name_or_path']:
		config = ViTMAEConfig.from_pretrained(model_args['model_name_or_path'], **config_kwargs)
	else:
		config = ViTMAEConfig()
		logger.warning("You are instantiating a new config instance from scratch.")
		if model_args['config_overrides'] is not None:
			logger.info(f"Overriding config: {model_args['config_overrides']}")
			config.update_from_string(model_args['config_overrides'])
			logger.info(f"New config: {config}")

    # adapt config
	config.update(
        {
            "mask_ratio": model_args['mask_ratio'],
            "norm_pix_loss": model_args['norm_pix_loss'],
        }
    )
	return config


In [99]:
config = get_config_local(model_args)

You are instantiating a new config instance from scratch.


In [100]:
config

ViTMAEConfig {
  "attention_probs_dropout_prob": 0.0,
  "decoder_hidden_size": 512,
  "decoder_intermediate_size": 2048,
  "decoder_num_attention_heads": 16,
  "decoder_num_hidden_layers": 8,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "mask_ratio": 0.75,
  "model_type": "vit_mae",
  "norm_pix_loss": true,
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.45.0.dev0"
}

In [105]:
def get_image_processor(model_args):
	if model_args.image_processor_name:
		image_processor = ViTImageProcessor.from_pretrained(model_args.image_processor_name, **config_kwargs)
	else:
		image_processor = ViTImageProcessor.from_config(config)
	return image_processor

In [110]:
def get_image_processor_local(model_args,confg):
	if model_args['image_processor_name']:
		image_processor = ViTImageProcessor.from_pretrained(model_args['image_processor_name'], **config_kwargs)
	elif model_args['model_name_or_path']:
		image_processor = ViTImageProcessor.from_pretrained(model_args['model_name_or_path'], **config_kwargs)
	else:
		image_processor = ViTImageProcessor()
	return image_processor

In [145]:
image_processor =get_image_processor_local(model_args,config)

In [146]:

image_processor

ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [115]:
def get_model(model_args, config):
	if model_args.model_name_or_path:
		model = ViTMAEForPreTraining.from_pretrained(
			model_args.model_name_or_path,
			from_tf=bool(".ckpt" in model_args.model_name_or_path),
			config=config,
			cache_dir=model_args.cache_dir,
			revision=model_args.model_revision,
			token=model_args.token,
	)
	else:
		logger.info("Training new model from scratch")
		model = ViTMAEForPreTraining(config)
	return model

In [116]:
def get_model_local(model_args, config):
	if model_args['model_name_or_path']:
		model = ViTMAEForPreTraining.from_pretrained(
			model_args['model_name_or_path'],
			from_tf=bool(".ckpt" in model_args['model_name_or_path']),
			config=config,
			cache_dir=model_args['cache_dir'],
			revision=model_args['model_revision'],
			token=model_args['token'],
	)
	else:
		logger.info("Training new model from scratch")
		model = ViTMAEForPreTraining(config)
	return model

In [117]:
model = get_model_local(model_args, config)

In [120]:
model

ViTMAEForPreTraining(
  (vit): ViTMAEModel(
    (embeddings): ViTMAEEmbeddings(
      (patch_embeddings): ViTMAEPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
    )
    (encoder): ViTMAEEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTMAELayer(
          (attention): ViTMAESdpaAttention(
            (attention): ViTMAESdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTMAESelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTMAEIntermediate(
            (dense): Linear(in_features=768, out_features=

In [125]:
ds['train'].column_names

['img', 'label']

In [45]:

def get_column_name(training_args, data_args, ds):
	if training_args.do_train:
		column_names = ds['train'].column_names
	else:	
		column_names = ds['validation'].column_names
	return column_names


In [129]:
def get_column_name_local(training_args, data_args, ds):
	if training_args['do_train']:
		column_names = ds['train'].column_names
	else:	
		column_names = ds['validation'].column_names
	return column_names

In [130]:
column_names = get_column_name_local(training_args, data_args, ds)

In [131]:
column_names

['img', 'label']

In [46]:

def get_image_column_name(data_args, column_names):

	if data_args.image_column_name is not None:
		image_column_name = data_args.image_column_name
	elif "image" in column_names:
		image_column_name = "image"
	elif "img" in column_names:
		image_column_name = "img"
	else:
		image_column_name = column_names[0]
	return image_column_name

In [47]:
def get_image_column_name_local(data_args, column_names):

	if data_args['image_column_name'] is not None:
		image_column_name = data_args['image_column_name']
	elif "image" in column_names:
		image_column_name = "image"
	elif "img" in column_names:
		image_column_name = "img"
	else:
		image_column_name = column_names[0]
	return image_column_name

In [135]:
data_args['image_column_name'] = None

In [136]:
image_column_name = get_image_column_name_local(data_args, column_names)
image_column_name

'img'

In [49]:
def get_shortest_edge(image_processor):
	if 'shortest_edge' in image_processor.size:
		return image_processor.size['shortest_edge']
	else:
		return (image_processor.size['height'], image_processor.size['width'])


NameError: name 'image_processor' is not defined

In [151]:
size

(224, 224)

In [161]:
def transform_ds(
		training_args, 
		data_args, 
		ds, 
		image_processor, 
		size
		):
	column_names = get_column_name_local(training_args, data_args, ds)
	image_column_name = get_image_column_name_local(data_args, column_names)
	return column_names, image_column_name


In [162]:
transforms = Compose(
        			[
            		Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            		RandomResizedCrop(size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),
            		RandomHorizontalFlip(),
            		ToTensor(),
            		Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
        			]
    				)

In [1]:
def preprocess_image(example):
	example['pixel_values'] = [transforms(i) for i in example[image_column_name]]
	return example

In [None]:
#| export
if __name__ == "__main__":
	main_()

In [92]:
#| hide
import nbdev; nbdev.nbdev_export('11_mae_pretraining.ipynb')