In [None]:
!pip install transformers[torch]

In [2]:
import transformers
from torchvision import datasets, transforms
import torch

In [3]:
import gc
torch.cuda.empty_cache()
gc.collect()

48

In [4]:
from datasets import load_dataset

train_dataset = load_dataset("mnist", split='train')
test_dataset = load_dataset("mnist", split='test')

In [None]:
print(train_dataset[0]['image'])
print(test_dataset[0]['image'])

In [6]:
classes = list(set(train_dataset['label']))
print(len(classes))

10


In [7]:
from transformers import ViTImageProcessor

model_name = 'google/vit-base-patch16-224-in21k'

image_processor = ViTImageProcessor.from_pretrained(model_name, num_channels=3, image_mean=0.5, image_std=0.5)
image_processor

ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": 0.5,
  "image_processor_type": "ViTImageProcessor",
  "image_std": 0.5,
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [26]:
def preprocess(batch):
    images = []
    for i, image in enumerate(batch['image']):
        # image = Image.open(image_name)
        image = image.convert('RGB')
        images.append(image)
    inputs = image_processor(images, return_tensors = 'pt')

    inputs['label'] = batch['label']

    return inputs

In [None]:
import PIL
from PIL import Image

print(train_dataset[0]['image'])
image = train_dataset[0]['image']
# image = Image.open(train_dataset[0]['image'])
image = image.convert('RGB')
example = image_processor(image, return_tensors='pt')
print(example)

In [28]:
prepared_train = train_dataset.with_transform(preprocess)
prepared_test = test_dataset.with_transform(preprocess)


In [29]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [30]:
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }

In [None]:
pip install evaluate

In [32]:
import numpy as np
import evaluate

accuracy_metric = evaluate.load("accuracy")

def compute_metric(p):
    return accuracy_metric.compute(
        predictions=np.argmax(p.predictions, axis=1),
        references= p.label_ids
    )


In [33]:
from transformers import TrainingArguments, AdamW, get_linear_schedule_with_warmup

training_args = TrainingArguments(
    output_dir='./mnistModel',
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    num_train_epochs=5,
    evaluation_strategy='steps',
    save_steps = 100,
    eval_steps=100,
    fp16=True,
    logging_steps = 10,
    learning_rate = 2e-3,
    save_total_limit = 2,
    remove_unused_columns = False,
    push_to_hub = False,
    load_best_model_at_end = True,
    resume_from_checkpoint='./mnistModel',
)


In [None]:
from transformers import ViTForImageClassification
from torch import nn

model = ViTForImageClassification.from_pretrained(model_name, num_labels = len(classes), ignore_mismatched_sizes=True)
model.to(device)

In [35]:
from transformers import Trainer

trainer = Trainer(
    model = model,
    args = training_args,
    data_collator = collate_fn,
    compute_metrics = compute_metric,
    train_dataset = prepared_train,
    eval_dataset = prepared_test,
    tokenizer = image_processor
)

In [36]:
# train_results = trainer.train(resume_from_checkpoint=True)
train_results = trainer.train()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

Step,Training Loss,Validation Loss


OutOfMemoryError: CUDA out of memory. Tried to allocate 30.00 MiB (GPU 0; 4.00 GiB total capacity; 3.17 GiB already allocated; 0 bytes free; 3.44 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
metrics = trainer.evaluate(prepared_test)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  epoch                   =        5.0
  eval_accuracy           =     0.0167
  eval_loss               =     4.8265
  eval_runtime            = 0:00:42.74
  eval_samples_per_second =     67.183
  eval_steps_per_second   =      8.398


Infer
