In [1]:
!pip install torch torchvision transformers datasets accelerate



In [2]:
import torch
print(torch.backends.mps.is_available())
print(torch.backends.mps.is_built())

True
True


In [3]:
import os
import torch
from transformers import ViTImageProcessor, ViTForImageClassification, Trainer, TrainingArguments
from datasets import Dataset, load_metric
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image as PILImage
import numpy as np

data_dir = '/Users/s4meone/PycharmProjects/gestures_ai/dataset'

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

def load_data(data_dir):
    data = {"image": [], "label": []}
    label2id = {label: idx for idx, label in enumerate(os.listdir(data_dir))}
    id2label = {idx: label for label, idx in label2id.items()}
    supported_extensions = ['.jpg', '.jpeg', '.png', '.bmp']

    for label in os.listdir(data_dir):
        class_dir = os.path.join(data_dir, label)
        if not os.path.isdir(class_dir):
            continue

        for file_name in os.listdir(class_dir):
            if any(file_name.lower().endswith(ext) for ext in supported_extensions):
                img_path = os.path.join(class_dir, file_name)
                try:
                    image = PILImage.open(img_path).convert("RGB")
                    data["image"].append(image)
                    data["label"].append(label2id[label])
                except Exception as e:
                    print(f"Error loading image {img_path}: {e}")

    return data, label2id, id2label

print("–ó–∞–≥—Ä—É–∑–∫–∞ –¥–∞–Ω–Ω—ã—Ö...")
data, label2id, id2label = load_data(data_dir)
print(f"label2id: {label2id}")
print(f"–ó–∞–≥—Ä—É–∂–µ–Ω–æ {len(data['image'])} –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–π")

def preprocess_images(data, batch_size=64):
    dataset_size = len(data['image'])
    pixel_values = []
    labels = []

    for i in range(0, dataset_size, batch_size):
        batch_images = data['image'][i:i + batch_size]
        images = [image.convert("RGB") if isinstance(image, PILImage.Image) else PILImage.fromarray(image).convert("RGB") for image in batch_images]
        inputs = processor(images=images, return_tensors="pt")
        pixel_values.append(inputs["pixel_values"])
        labels.append(torch.tensor(data["label"][i:i + batch_size]))

    pixel_values = torch.cat(pixel_values)
    labels = torch.cat(labels)
    return {"pixel_values": pixel_values, "label": labels}

print("–ü—Ä–µ–æ–±—Ä–∞–∑–æ–≤–∞–Ω–∏–µ –¥–∞–Ω–Ω—ã—Ö...")
processed_data = preprocess_images(data)
print("–î–∞–Ω–Ω—ã–µ –ø—Ä–µ–æ–±—Ä–∞–∑–æ–≤–∞–Ω—ã")

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings["pixel_values"])

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        return item

print("–°–æ–∑–¥–∞–Ω–∏–µ –¥–∞—Ç–∞—Å–µ—Ç–∞...")
dataset = CustomDataset(processed_data)
print("–î–∞—Ç–∞—Å–µ—Ç —Å–æ–∑–¥–∞–Ω")

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


–ó–∞–≥—Ä—É–∑–∫–∞ –¥–∞–Ω–Ω—ã—Ö...
label2id: {'dislike': 0, 'like': 1, 'no': 2, 'ok': 3, 'peace': 4, 'one': 5, 'rock': 6, 'fist': 7, 'palm': 8}
–ó–∞–≥—Ä—É–∂–µ–Ω–æ 26149 –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–π
–ü—Ä–µ–æ–±—Ä–∞–∑–æ–≤–∞–Ω–∏–µ –¥–∞–Ω–Ω—ã—Ö...
–î–∞–Ω–Ω—ã–µ –ø—Ä–µ–æ–±—Ä–∞–∑–æ–≤–∞–Ω—ã
–°–æ–∑–¥–∞–Ω–∏–µ –¥–∞—Ç–∞—Å–µ—Ç–∞...
–î–∞—Ç–∞—Å–µ—Ç —Å–æ–∑–¥–∞–Ω


In [5]:
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train_dataset, test_dataset

(<torch.utils.data.dataset.Subset at 0x10e8b4310>,
 <torch.utils.data.dataset.Subset at 0x10fa22010>)

In [6]:
print("–°–æ–∑–¥–∞–Ω–∏–µ –º–æ–¥–µ–ª–∏...")
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=len(label2id))
model.config.label2id = label2id
model.config.id2label = id2label

–°–æ–∑–¥–∞–Ω–∏–µ –º–æ–¥–µ–ª–∏...


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
print("–ü–∞—Ä–∞–º–µ—Ç—Ä—ã –æ–±—É—á–µ–Ω–∏—è...")
training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    num_train_epochs=3, # –≠–ø–æ—Ö–∏ —Ç—É—Ç
    learning_rate=5e-5,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    remove_unused_columns=False,
    push_to_hub=False,
    load_best_model_at_end=True,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=500,
    save_total_limit=3,
)

print("–ü–∞—Ä–∞–º–µ—Ç—Ä—ã –æ–±—É—á–µ–Ω–∏—è —É—Å—Ç–∞–Ω–æ–≤–ª–µ–Ω—ã")

–ü–∞—Ä–∞–º–µ—Ç—Ä—ã –æ–±—É—á–µ–Ω–∏—è...
–ü–∞—Ä–∞–º–µ—Ç—Ä—ã –æ–±—É—á–µ–Ω–∏—è —É—Å—Ç–∞–Ω–æ–≤–ª–µ–Ω—ã


In [8]:
metric = load_metric("accuracy")

def compute_metrics(p):
    return metric.compute(predictions=p.predictions.argmax(-1), references=p.label_ids)

  metric = load_metric("accuracy")


In [9]:
print("–°–æ–∑–¥–∞–Ω–∏–µ —Ç—Ä–µ–Ω–µ—Ä–∞...")
# –°–æ–∑–¥–∞–Ω–∏–µ —Ç—Ä–µ–Ω–µ—Ä–∞
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)
print("–¢—Ä–µ–Ω–µ—Ä —Å–æ–∑–¥–∞–Ω")

–°–æ–∑–¥–∞–Ω–∏–µ —Ç—Ä–µ–Ω–µ—Ä–∞...
–¢—Ä–µ–Ω–µ—Ä —Å–æ–∑–¥–∞–Ω


dataloader_config = DataLoaderConfiguration(dispatch_batches=None)


In [None]:
print("–ù–∞—á–∞–ª–æ –æ–±—É—á–µ–Ω–∏—è...")
trainer.train()
print("–û–±—É—á–µ–Ω–∏–µ –∑–∞–≤–µ—Ä—à–µ–Ω–æ")

–ù–∞—á–∞–ª–æ –æ–±—É—á–µ–Ω–∏—è...


Epoch,Training Loss,Validation Loss,Accuracy
1,0.4182,0.084072,0.988719
2,0.0612,0.047889,0.99044
