In [3]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
!pip install datasets
from datasets import load_dataset
food = load_dataset("rajistics/indian_food_images")

In [5]:
food

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 5328
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 941
    })
})

In [6]:
labels = food["train"].features["label"].names
labels

['burger',
 'butter_naan',
 'chai',
 'chapati',
 'chole_bhature',
 'dal_makhani',
 'dhokla',
 'fried_rice',
 'idli',
 'jalebi',
 'kaathi_rolls',
 'kadai_paneer',
 'kulfi',
 'masala_dosa',
 'momos',
 'paani_puri',
 'pakode',
 'pav_bhaji',
 'pizza',
 'samosa']

In [7]:
label2id , id2label = dict() , dict()

In [8]:
for i , label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label

In [9]:
print(label2id)

{'burger': 0, 'butter_naan': 1, 'chai': 2, 'chapati': 3, 'chole_bhature': 4, 'dal_makhani': 5, 'dhokla': 6, 'fried_rice': 7, 'idli': 8, 'jalebi': 9, 'kaathi_rolls': 10, 'kadai_paneer': 11, 'kulfi': 12, 'masala_dosa': 13, 'momos': 14, 'paani_puri': 15, 'pakode': 16, 'pav_bhaji': 17, 'pizza': 18, 'samosa': 19}


In [10]:
print(id2label)

{0: 'burger', 1: 'butter_naan', 2: 'chai', 3: 'chapati', 4: 'chole_bhature', 5: 'dal_makhani', 6: 'dhokla', 7: 'fried_rice', 8: 'idli', 9: 'jalebi', 10: 'kaathi_rolls', 11: 'kadai_paneer', 12: 'kulfi', 13: 'masala_dosa', 14: 'momos', 15: 'paani_puri', 16: 'pakode', 17: 'pav_bhaji', 18: 'pizza', 19: 'samosa'}


In [None]:
from transformers import AutoImageProcessor

In [None]:
model_ckpt = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(model_ckpt , use_fast = True)

In [15]:
from torchvision.transforms import RandomResizedCrop , Compose , Normalize , ToTensor
normalize = Normalize(mean = image_processor.image_mean , std = image_processor.image_std)

if "shorted_edge" in image_processor.size:
  size = (image_processor.size["shorted_edge"])
else:
  size = (image_processor.size["height"] , image_processor.size["width"])

In [16]:
_transforms = Compose([RandomResizedCrop(size) , ToTensor() , normalize])

In [17]:
def transforms(examples):
  examples['pixel_values'] = [_transforms(img.convert("RGB")) for img in examples['image']]
  del examples['image']
  return examples

In [18]:
food = food.with_transform(transforms)

In [None]:
!pip install evaluate
import evaluate
import numpy as np
accuracy = evaluate.load("accuracy")

In [20]:
def compute_metrics(eval_pred):
  predictions , labels = eval_pred
  predictions = np.argmax(predictions , axis = 1)
  return accuracy.compute(predictions = predictions , references = labels)

In [None]:
from transformers import AutoModelForImageClassification , TrainingArguments , Trainer
import torch

device = "cpu"

model = AutoModelForImageClassification.from_pretrained(model_ckpt , num_labels = len(labels) , id2label = id2label , label2id = label2id).to(device )

In [None]:
args = TrainingArguments(
    output_dir = "train_dir",
    remove_unused_columns=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    num_train_epochs=4,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy'
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=food['train'],
    eval_dataset=food['test'],
    tokenizer=image_processor,
    compute_metrics=compute_metrics
)

trainer.train()

In [23]:
trainer.save_model('food_classification')

In [24]:
from transformers import pipeline

pipe = pipeline("image-classification", model='food_classification', device=device)

Device set to use cpu


In [25]:
import requests
from PIL import Image
from io import BytesIO

url = 'https://www.indianhealthyrecipes.com/wp-content/uploads/2015/10/pizza-recipe-1.jpg'
response = requests.get(url)
image = Image.open(BytesIO(response.content))
image.show()

In [26]:
pipe(image)

[{'label': 'pizza', 'score': 0.9556057453155518},
 {'label': 'kadai_paneer', 'score': 0.5695244669914246},
 {'label': 'butter_naan', 'score': 0.5339012742042542},
 {'label': 'chapati', 'score': 0.5163205862045288},
 {'label': 'burger', 'score': 0.5013607144355774}]