<a href="https://colab.research.google.com/github/HatemMoushir/Shark-identification-1/blob/main/shark-vit-base-patch16-224.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

!pip install -q transformers datasets torchvision evaluate
!pip install wandb


from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader
import os

# Path to the extracted dataset folder
data_dir = "/content/Shark_project_split"

# Fail early with an actionable message: the cells that download and unzip the
# archive live further down in this notebook and must have been run already.
if not os.path.isdir(data_dir):
    raise FileNotFoundError(
        f"Dataset folder not found: {data_dir}. Run the download and unzip cells first."
    )

# Initial transforms (resize every image to 224x224 and convert it to a Tensor)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Load the data; ImageFolder uses each immediate sub-folder of data_dir as a class.
# NOTE(review): the folder name suggests the archive may already be split into
# train/val sub-folders — if so, those would be picked up as classes. Confirm the layout.
dataset = ImageFolder(data_dir, transform=transform)

# Split into training and validation subsets (80% / 20%)
import torch
from torch.utils.data import random_split

SPLIT_SEED = 42  # fixed seed so the same images land in each subset on every run

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)


#---

from transformers import ViTForImageClassification, ViTFeatureExtractor, TrainingArguments, Trainer, ViTImageProcessor
import torch
import numpy as np
from datasets import Dataset as HFDataset

# Image processor for the pretrained checkpoint
# (ViTImageProcessor is used in place of the deprecated ViTFeatureExtractor)
checkpoint = "google/vit-base-patch16-224"
feature_extractor = ViTImageProcessor.from_pretrained(checkpoint)

# Label maps built from the folder-derived class names
class_names = dataset.classes
id2label = dict((str(idx), name) for idx, name in enumerate(class_names))
label2id = dict((name, str(idx)) for idx, name in enumerate(class_names))

# Classification model: same backbone, with the head sized to our classes.
# ignore_mismatched_sizes lets the checkpoint load even though its classifier
# layer has a different number of outputs than num_labels.
model = ViTForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(class_names),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

#----

from PIL import Image

def transform_example(example):
    """Convert stored images into model-ready ``pixel_values``.

    Intended for ``Dataset.with_transform``; ``example`` is a dict with an
    ``'image'`` entry (HWC float images, one or a batch of them) and a
    ``'label'`` entry.

    Returns the processor's encoding with the labels attached under ``'label'``.
    """
    # The dataset may hand the images back as nested Python lists; rebuild
    # float arrays so the processor receives real image arrays.
    pixels = np.asarray(example['image'], dtype=np.float32)
    if pixels.ndim == 3:  # a single HWC image rather than a batch
        pixels = pixels[np.newaxis]

    # The images were produced by ToTensor(), so they are already in [0, 1].
    # Without do_rescale=False the processor would divide by 255 a second time
    # and feed the model almost-constant inputs.
    encoding = feature_extractor(
        images=list(pixels),
        do_rescale=False,
        return_tensors="pt",
    )
    encoding['label'] = example['label']
    return encoding

# Convert a PyTorch dataset into a Hugging Face Dataset
def convert_to_hf_dataset(torch_dataset):
    """Build a Hugging Face dataset with ``image`` and ``label`` columns.

    Every ``(tensor, label)`` item of ``torch_dataset`` is read once; the tensor
    is reordered from CHW to HWC and stored as a numpy array.
    """
    samples = [
        (tensor.permute(1, 2, 0).numpy(), target)  # CHW -> HWC
        for tensor, target in torch_dataset
    ]
    columns = {
        "image": [pixels for pixels, _ in samples],
        "label": [target for _, target in samples],
    }
    return HFDataset.from_dict(columns)

# Convert both splits and attach the on-the-fly preprocessing to each of them.
hf_train, hf_val = (
    convert_to_hf_dataset(split).with_transform(transform_example)
    for split in (train_dataset, val_dataset)
)


#----

# Training configuration: evaluate and checkpoint once per epoch, then reload
# the checkpoint with the best validation accuracy when training finishes.
training_args = TrainingArguments(
    output_dir="./vit-shark-classifier2",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    eval_strategy="epoch",  # named evaluation_strategy in older transformers releases
    save_strategy="epoch",
    num_train_epochs=5,
    logging_dir="./logs",
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # The datasets expose a raw "image" column that is only turned into
    # pixel_values by the with_transform hook. By default the Trainer drops
    # columns the model's forward() does not accept, which would remove "image"
    # before the transform ever runs.
    remove_unused_columns=False,
    report_to="none"
)

from evaluate import load

# Accuracy metric used to score each evaluation pass
accuracy = load("accuracy")

# Keep Weights & Biases logging switched off
os.environ["WANDB_DISABLED"] = "true"

def compute_metrics(p):
    """Return the accuracy of an evaluation pass.

    ``p.predictions`` holds the per-class scores and ``p.label_ids`` the true
    labels; the predicted class is the highest-scoring one along axis 1.
    """
    predicted_ids = np.argmax(p.predictions, axis=1)
    scores = accuracy.compute(predictions=predicted_ids, references=p.label_ids)
    return scores

import inspect
from transformers import default_data_collator

# The keyword under which Trainer accepts the image processor depends on the
# installed transformers release. Pick whichever one this Trainer actually
# declares instead of hard-coding a name that may not exist.
_trainer_params = inspect.signature(Trainer.__init__).parameters
_processor_kwargs = {}
for _name in ("processing_class", "image_processor", "tokenizer"):
    if _name in _trainer_params:
        _processor_kwargs[_name] = feature_extractor
        break

trainer = Trainer(
    model=model,
    args=training_args,
    # Set the collator explicitly so batches are stacked the same way
    # regardless of which processor keyword was chosen above.
    data_collator=default_data_collator,
    train_dataset=hf_train,
    eval_dataset=hf_val,
    compute_metrics=compute_metrics,
    **_processor_kwargs,
)

# Fine-tune the model
trainer.train()

In [2]:
# Download the dataset archive from Google Drive by its file ID.
# The recorded output shows it is saved as /content/Shark_project_split.zip.
# NOTE(review): the training cell reads the extracted data, so this cell and the
# unzip cell need to run before it on a fresh kernel.
!gdown 165LwqivtdzeXwMaj2VeGzgspqdnOiyrq

Downloading...
From (original): https://drive.google.com/uc?id=165LwqivtdzeXwMaj2VeGzgspqdnOiyrq
From (redirected): https://drive.google.com/uc?id=165LwqivtdzeXwMaj2VeGzgspqdnOiyrq&confirm=t&uuid=30d20897-6cef-4d8a-a625-a84d6a25e28f
To: /content/Shark_project_split.zip
100% 139M/139M [00:03<00:00, 40.6MB/s]


In [None]:
!unzip "/content/Shark_project_split.zip" -d "/content/Shark_project_split"