<a href="https://colab.research.google.com/github/HatemMoushir/Shark-identification-1/blob/main/shark_resnet18_trained.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from tqdm import tqdm

# ---- General setup: choose where tensors and the model will live ----------
print("🧹 التحضير: التجهيزات العامة...")
# Use the GPU when CUDA is available; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ---- Dataset definition: images are indexed by hand from class folders ----
print("📁 تحميل الصور يدويًا...")

class SharkDataset(Dataset):
    """Image-classification dataset laid out as ``root_dir/<class_name>/<image>``.

    Each immediate, non-hidden sub-directory of ``root_dir`` is one class.
    Class indices follow the sorted order of those directory names. Inside
    each class folder, only non-hidden regular files whose extension
    (case-insensitive) is in ``extensions`` are kept.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to each PIL image in
            ``__getitem__`` (e.g. a ``torchvision.transforms.Compose``).
        extensions: Accepted image-file extensions, including the leading dot.

    Attributes:
        class_to_idx: Mapping from class-folder name to integer label.
        samples: List of ``(image_path, label)`` pairs, sorted by class and
            then by file name so the order is reproducible across filesystems.
    """

    def __init__(self, root_dir, transform=None, extensions=('.png', '.jpg', '.jpeg')):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        # Only real, non-hidden directories are classes. A stray file
        # (e.g. ".DS_Store") would crash the inner os.listdir. Jupyter's
        # ".ipynb_checkpoints" would otherwise sort first, become class 0 and
        # shift every real label by one.
        class_folders = sorted(
            name for name in os.listdir(root_dir)
            if not name.startswith('.') and os.path.isdir(os.path.join(root_dir, name))
        )
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(class_folders)}

        # str.endswith needs a tuple; lower-case once so matching is case-insensitive.
        wanted = tuple(ext.lower() for ext in extensions)
        for cls_name in class_folders:
            cls_folder = os.path.join(root_dir, cls_name)
            label = self.class_to_idx[cls_name]
            # os.listdir order is arbitrary; sort for a deterministic sample order.
            for filename in sorted(os.listdir(cls_folder)):
                if filename.startswith('.') or not filename.lower().endswith(wanted):
                    continue  # skips hidden files such as macOS "._x.jpg" forks
                path = os.path.join(cls_folder, filename)
                if os.path.isfile(path):
                    self.samples.append((path, label))

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, applying ``transform`` if set."""
        path, label = self.samples[idx]
        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image, so it stays valid after the file is closed.
        with Image.open(path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# -------------------------------------------------------
print("🧱 بناء DataLoaders...")

# The pretrained ImageNet weights of ResNet18 expect inputs normalized with
# the ImageNet per-channel mean/std. Without Normalize, the [0, 1] tensors
# from ToTensor() do not match what the pretrained filters and BatchNorm
# statistics were trained on.
# NOTE: any inference code that loads the saved weights must apply the same
# Normalize step.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

batch_size = 32

train_dataset = SharkDataset("/content/Shark_project_split/train", transform)
val_dataset   = SharkDataset("/content/Shark_project_split/val", transform)
test_dataset  = SharkDataset("/content/Shark_project_split/test", transform)

# Every split derives its own class_to_idx from its own folder listing. If a
# split lacks (or adds) a class folder, its integer labels shift relative to
# train and val/test metrics become silently wrong, so fail loudly instead.
for split_name, split_dataset in (("val", val_dataset), ("test", test_dataset)):
    if split_dataset.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            f"Class folders of the {split_name} split {sorted(split_dataset.class_to_idx)} "
            f"do not match the train split {sorted(train_dataset.class_to_idx)}"
        )

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size)
test_loader  = DataLoader(test_dataset, batch_size=batch_size)

num_classes = len(train_dataset.class_to_idx)
class_names = list(train_dataset.class_to_idx)
print(f"📊 عدد الأصناف: {num_classes} - {class_names}")

# -------------------------------------------------------
print("🧠 تحميل موديل ResNet18...")

# `pretrained=True` is deprecated since torchvision 0.13. The weights enum
# explicitly selects the same ImageNet-1k checkpoint.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Swap the pretrained ImageNet classifier head for one output per class.
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
# All layers (backbone and new head) are fine-tuned with the same learning rate.
optimizer = optim.Adam(model.parameters(), lr=0.0005)
num_epochs = 10

# -------------------------------------------------------
print("🏋️ بدء التدريب...")

for epoch in range(num_epochs):
    # ---- Training pass: update the weights after every batch ----
    model.train()
    running_loss = 0.0
    correct, total = 0, 0

    for images, labels in tqdm(train_loader, desc=f"🚂 Epoch {epoch+1}/{num_epochs}"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss (default reduction) returns the batch mean. Re-weight
        # by batch size so the epoch loss is an exact per-sample average even
        # when the last batch is smaller.
        running_loss += loss.item() * images.size(0)
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    train_loss = running_loss / total
    train_acc = correct / total * 100
    print(f"📈 Epoch {epoch+1}: Loss = {train_loss:.4f}, Accuracy = {train_acc:.2f}%")

    # ---- Validation pass: held-out metrics each epoch reveal over-fitting ----
    # that training accuracy alone hides. Run in eval mode without gradients.
    model.eval()
    val_loss_sum = 0.0
    val_correct, val_total = 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            val_loss_sum += criterion(outputs, labels).item() * images.size(0)
            val_correct += (outputs.argmax(dim=1) == labels).sum().item()
            val_total += labels.size(0)

    if val_total:  # skip reporting when the validation split is empty
        val_loss = val_loss_sum / val_total
        val_acc = val_correct / val_total * 100
        print(f"🧪 Epoch {epoch+1}: Val Loss = {val_loss:.4f}, Val Accuracy = {val_acc:.2f}%")

# ---- Persist the trained weights ------------------------------------------
print("💾 حفظ النموذج...")

# Only the state_dict (parameters and buffers) is written, not the module
# object. To reuse it, rebuild a resnet18 with the same replaced `fc` head and
# call load_state_dict on it.
torch.save(obj=model.state_dict(), f="/content/shark_resnet18_trained.pth")
print("✅ النموذج تم حفظه كـ: shark_resnet18_trained.pth")

# -------------------------------------------------------
print("🔍 التقييم النهائي على بيانات الاختبار...")

# Eval mode: BatchNorm uses its running statistics instead of batch statistics.
model.eval()
correct, total = 0, 0

# No gradients are needed for evaluation; this saves memory and time.
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

# Without this guard an empty test split surfaces as a bare ZeroDivisionError.
if total == 0:
    raise ValueError("Test set is empty; cannot compute test accuracy.")
test_acc = correct / total * 100
print(f"🎯 دقة النموذج على بيانات الاختبار: {test_acc:.2f}%")