In [None]:
import os
import random
import math
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import torch

import tensorflow as tf
import tensorflow_datasets as tfds

from transformers import ViTFeatureExtractor, ViTForImageClassification
from transformers import TrainingArguments, Trainer

import datasets
from datasets import load_metric

In [None]:
# Hyperparameters and notebook-wide configuration.
# NOTE(review): DATA_DIR is not referenced by the cells visible here (the loading
# cell repeats the path as a literal) — confirm whether it should be used there.
DATA_DIR='../input/workoutexercises-images'
# Directory handed to TrainingArguments(output_dir=...).
OUTPUT_DIR='./trainingOutput'
# Image geometry constants.
# NOTE(review): height/width/channels/img_shape/img_size are not referenced by the
# cells visible here — confirm they are still needed.
height=256
width=256
channels=3
# Used as per_device_train_batch_size in TrainingArguments.
batch_size=64
# Seed passed to ds.shuffle(...).
seed=42
img_shape=(height, width, channels)
img_size=(height, width)
# model_id='google/vit-base-patch16-224'
# Source for both the feature extractor and the model weights.
# NOTE(review): this is a local checkpoint path, so it must already exist before the
# notebook runs; presumably produced by an earlier training run — confirm.
model_id='./trainingOutput/checkpoint-200'
# Accuracy metric consumed by compute_metrics.
# NOTE(review): `load_metric` is deprecated in newer `datasets` releases in favor of
# the `evaluate` package — confirm against the pinned version.
metric=load_metric("accuracy")
# class_names=['bench press', 'biceps curl', 'chest fly machine', 'deadlift', 'incline bench press', 'lat pulldown', 'push-up', 'tricep pushdown']

In [None]:
# Load the image preprocessor stored alongside `model_id`; gen_encoding and the
# Trainer both use this module-level object.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
# model = ViTForImageClassification.from_pretrained(model_id)

# print(feature_extractor)
# print(model)

In [4]:
def show_img(data, class_names):
    """Plot up to the first nine samples of ``data`` in a 3x3 grid.

    Parameters
    ----------
    data : mapping-like
        Must expose ``data['image']`` and ``data['label']`` columns that
        support ``len()`` and integer indexing.
    class_names : mapping
        Looked up with ``str(label)``, so its keys must be the string form of
        the label ids (e.g. an ``id2label`` dict keyed by ``'0'``, ``'1'``, ...).

    Notes
    -----
    Fewer than nine samples are plotted when ``data`` is shorter than nine
    rows; the remaining grid cells are simply left empty.
    """
    # Fetch each column once instead of once per iteration. For a Hugging Face
    # Dataset, column access can materialize (decode) the whole column, so
    # doing it inside the loop repeats that work for every plotted image.
    images = data['image']
    labels = data['label']

    # Clamp to the available rows so short datasets do not raise IndexError.
    for i in range(min(9, len(labels))):
        ax = plt.subplot(3, 3, i + 1)
        ax.imshow(np.asarray(images[i]))
        ax.set_title(class_names[str(labels[i])])
        ax.axis("off")

In [5]:
# Define a function that prints the shapes of a dataset's first batch (loaded from a keras dataset)
def dataset_shape(dataset):
    """Print the image and label shapes of the first batch yielded by ``dataset``.

    Only the first element is consumed; it is unpacked as ``(images, labels)``
    and the ``.shape`` of each is printed. Nothing is printed when ``dataset``
    yields no elements.
    """
    nothing = object()
    first_batch = next(iter(dataset), nothing)
    if first_batch is nothing:
        return
    images, labels = first_batch
    print('images: ', images.shape)
    print('labels: ', labels.shape)

In [6]:
def data_files(root='/kaggle/input'):
    """Count every file found under ``root``, recursing into sub-directories.

    Parameters
    ----------
    root : str or path-like, optional
        Directory to walk. Defaults to ``'/kaggle/input'``, the path that was
        previously hard-coded, so existing ``data_files()`` calls behave as
        before.

    Returns
    -------
    int
        Total number of files of any type below ``root``. ``os.walk`` yields
        nothing for a path that does not exist, so the result is ``0`` then.
    """
    return sum(len(filenames) for _, _, filenames in os.walk(root))

In [7]:
def gen_encoding(batch):
    """Encode a batch of examples with the module-level ``feature_extractor``.

    The images in ``batch['image']`` are passed to ``feature_extractor`` with
    ``return_tensors='pt'``; ``batch['label']`` is then stored on the result
    under the ``'label'`` key and the result is returned.
    """
    images = list(batch['image'])
    encoded = feature_extractor(images, return_tensors='pt')
    encoded['label'] = batch['label']
    return encoded

In [8]:
def collate_fn(batch):
    """Merge a sequence of per-example dicts into one batch dict.

    Each example must provide a ``'pixel_values'`` tensor and a ``'label'``
    value. Returns ``{'pixel_values': <stacked tensor>, 'labels': <tensor>}``;
    note the output key is ``'labels'`` while the input key is ``'label'``.
    """
    stacked_pixels = torch.stack(tuple(example['pixel_values'] for example in batch))
    label_tensor = torch.tensor(list(example['label'] for example in batch))

    collated = {}
    collated['pixel_values'] = stacked_pixels
    collated['labels'] = label_tensor
    return collated

In [9]:
def compute_metrics(p):
    """Score predictions with the module-level ``metric``.

    ``p.predictions`` is reduced with ``argmax`` over axis 1 to obtain the
    predicted class ids, which are compared against ``p.label_ids``. Returns
    whatever ``metric.compute`` returns.
    """
    predicted_ids = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_ids, references=p.label_ids)

In [10]:
# Count the files on disk and derive 80/10/10 split sizes from that count.
DATASET_SIZE = data_files()

train_size, val_size, test_size = (
    int(fraction * DATASET_SIZE) for fraction in (0.8, 0.1, 0.1)
)

# One value per line: total, train, validation, test.
print(DATASET_SIZE, train_size, val_size, test_size, sep='\n')

In [11]:
# Read the image folder with `datasets`, then take its shuffled 'train' split
# (seeded, so the row order is repeatable).
dataset = datasets.load_dataset('imagefolder', data_dir='../input/workoutexercises-images/')
ds = dataset['train'].shuffle(seed=seed)
# Peek at the first ten labels after shuffling.
ds['label'][:10]

In [12]:
# Split data (loaded from datasets) into train / validation / test by row range.
#
# Rows are sliced first and columns picked afterwards: `ds[a:b]` only
# materializes rows a..b, whereas `ds['image'][a:b]` builds the whole column
# (decoding every image) before slicing — and the original did that six times.
# The resulting lists are the same.
#
# NOTE(review): the boundaries come from DATASET_SIZE, i.e. the file count
# returned by data_files(), not from len(ds) — confirm the two agree, otherwise
# the 80/10/10 proportions are off.
_train_rows = ds[0:train_size]
_val_rows = ds[train_size:train_size+val_size]
_test_rows = ds[train_size+val_size:DATASET_SIZE]

train_dsX = _train_rows['image']
val_dsX = _val_rows['image']
test_dsX = _test_rows['image']

print('train_dsX: ', len(train_dsX), type(train_dsX))
print('val_dsX: ', len(val_dsX), type(val_dsX))
print('test_dsX: ', len(test_dsX), type(test_dsX))

train_dsY = _train_rows['label']
val_dsY = _val_rows['label']
test_dsY = _test_rows['label']

print('train_dsY: ', len(train_dsY), type(train_dsY))
print('val_dsY: ', len(val_dsY), type(val_dsY))
print('test_dsY: ', len(test_dsY), type(test_dsY))

# Rebuild one Dataset per split from the image/label lists.
train_ds = datasets.Dataset.from_dict({'image': train_dsX, 'label': train_dsY})
test_ds = datasets.Dataset.from_dict({'image': test_dsX, 'label': test_dsY})
val_ds = datasets.Dataset.from_dict({'image': val_dsX, 'label': val_dsY})

print('train_ds: ', train_ds)
print('test_ds: ', test_ds)
print('val_ds: ', val_ds)

In [13]:
# Class names come from the label feature of the loaded dataset. The two maps
# translate between string label ids ('0', '1', ...) and class names.
class_names = ds.features["label"].names
id2label = dict((str(index), name) for index, name in enumerate(class_names))
label2id = dict((name, index) for index, name in id2label.items())

In [14]:
# Preview a 3x3 grid of training images titled with their class names.
show_img(train_ds, id2label)

In [15]:
# Preview a 3x3 grid of test images titled with their class names.
show_img(test_ds, id2label)

In [16]:
# Preview a 3x3 grid of validation images titled with their class names.
show_img(val_ds, id2label)

In [17]:
# Attach gen_encoding to each split so examples come back as feature-extractor
# outputs plus a 'label' key. Per the `datasets` documentation, `with_transform`
# applies the function when rows are accessed rather than eagerly.
pt_train_ds = train_ds.with_transform(gen_encoding)
pt_val_ds = val_ds.with_transform(gen_encoding)
pt_test_ds = test_ds.with_transform(gen_encoding)

In [18]:
# Load the classifier weights from `model_id` and configure the classification
# head for this dataset's classes.
model = ViTForImageClassification.from_pretrained(
    model_id,
    num_labels=len(class_names),
    id2label=id2label,
    label2id=label2id,
    # Allows loading even when the checkpoint's head size differs from
    # num_labels (per the transformers documentation for this flag).
    ignore_mismatched_sizes=True
)

In [19]:
# Training configuration: 8 epochs, evaluating and checkpointing every 100 steps.
training_args = TrainingArguments(
  output_dir=OUTPUT_DIR,
  per_device_train_batch_size=batch_size,
  evaluation_strategy="steps",
  num_train_epochs=8,
  fp16=False,
  # save_steps and eval_steps are kept equal; presumably required so that
  # load_best_model_at_end can pair each checkpoint with an evaluation — confirm.
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  # Keep at most two checkpoints on disk.
  save_total_limit=2,
  # gen_encoding needs the raw 'image' column, so unused columns must not be dropped.
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  # NOTE(review): no metric_for_best_model is set, so the library default decides
  # which checkpoint counts as "best" — confirm that is the intended criterion.
  load_best_model_at_end=True,
)

In [20]:
# Wire model, arguments, data and metric function into a Trainer.
# Training uses pt_train_ds; periodic evaluation uses pt_val_ds.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=pt_train_ds,
    eval_dataset=pt_val_ds,
    # The feature extractor is passed in the tokenizer slot; presumably so it is
    # saved together with the model — confirm against the transformers version in use.
    tokenizer=feature_extractor
)

In [21]:
# Run training, then persist the model, the training metrics and the trainer state.
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

In [22]:
# Evaluate on the held-out test split.
# NOTE(review): the results are logged and saved under the "eval" name even though
# they come from pt_test_ds — confirm whether "test" was intended.
metrics = trainer.evaluate(pt_test_ds)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

In [29]:
# Reload the feature extractor and model for inference, replacing the objects
# created earlier in the notebook.
# NOTE(review): both are loaded from `model_id` (the starting checkpoint path), not
# from OUTPUT_DIR where trainer.save_model() writes — confirm this is the model
# that should be used for the prediction below.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
model = ViTForImageClassification.from_pretrained(model_id)

In [57]:
# Sanity check: classify one randomly chosen test image and print the prediction.
# The index comes from the unseeded `random` module, so the chosen image differs
# between runs.
image = test_ds[random.randrange(len(test_ds))]['image']
display(image)

inputs = feature_extractor(images=image, return_tensors="pt")
# Inference only: disable autograd so no computation graph is built and kept
# alive for the forward pass.
with torch.no_grad():
    outputs = model(**inputs)

logits = outputs.logits
# Index of the highest-scoring class, converted to a plain Python int.
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])