In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader

from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

from sklearn.metrics import f1_score, precision_score, recall_score


In [31]:
# ================== Configuration ==================
class Config:
    """Namespace of notebook-wide settings, read as plain class attributes."""

    # --- Image geometry ---
    img_size_global = 224  # side (pixels) of the square global view
    img_size_local = 96    # side (pixels) each local patch is resized to
    patch_size = 96        # side (pixels) of a raw patch crop
    num_patches = 4        # local patches taken per image

    # --- Model ---
    embed_dim = 384
    dropout = 0.5

    # --- Optimisation ---
    lr = 5e-5
    weight_decay = 0.05
    batch_size = 8
    epochs = 20

    # --- Regularisation ---
    use_mixup = True
    mixup_alpha = 0.2
    label_smoothing = 0.1
    early_stopping_patience = 7
    warmup_epochs = 3

    # --- Advanced features ---
    use_smart_patches = True
    use_attention = True


# Dataset locations.
# The root folder can be overridden with the LEAF_DATA_ROOT environment variable
# so the notebook is not tied to a single machine; when the variable is unset the
# original location is used.
DATA_ROOT = os.environ.get(
    "LEAF_DATA_ROOT",
    r"C:\Users\ENNHILI YASSINE\Desktop\ASMAE-ABDELOUAFI",
)
TRAIN_DIR = os.path.join(DATA_ROOT, "train")
VAL_DIR   = os.path.join(DATA_ROOT, "val")
TEST_DIR  = os.path.join(DATA_ROOT, "test")

# Report whether each split folder exists before any dataset is built.
print("üìÅ V√©rification des chemins...")
print("TRAIN:", os.path.exists(TRAIN_DIR))
print("VAL:", os.path.exists(VAL_DIR))
print("TEST:", os.path.exists(TEST_DIR))


üìÅ V√©rification des chemins...
TRAIN: True
VAL: True
TEST: True


In [32]:
# ================== Transformations ==================
def get_transforms():
    """Build the three torchvision pipelines used by the datasets.

    Returns:
        tuple: ``(train_transform, val_transform, local_transform)`` where
        the first is the augmented pipeline for the global view, the second
        is the deterministic pipeline for the global view, and the third is
        the deterministic pipeline applied to each local patch.
    """
    # Channel statistics shared by every pipeline (the standard ImageNet
    # values used with torchvision's pretrained backbones).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    def to_normalized_tensor():
        # Fresh transform instances for each pipeline.
        return [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]

    def square_resize(side):
        return transforms.Resize((side, side))

    # Augmentations operating on the PIL image, before tensor conversion.
    pil_augmentations = [
        square_resize(256),
        transforms.RandomResizedCrop(Config.img_size_global, scale=(0.7, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.RandomRotation(30),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.RandomGrayscale(p=0.1),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.2),
    ]
    # RandomErasing works on tensors, so it comes after normalisation.
    tensor_augmentations = [transforms.RandomErasing(p=0.2, scale=(0.02, 0.15))]

    train_transform = transforms.Compose(
        pil_augmentations + to_normalized_tensor() + tensor_augmentations
    )
    val_transform = transforms.Compose(
        [square_resize(Config.img_size_global)] + to_normalized_tensor()
    )
    local_transform = transforms.Compose(
        [square_resize(Config.img_size_local)] + to_normalized_tensor()
    )

    return train_transform, val_transform, local_transform


# Build the pipelines once at module level; make_dataloaders() passes them to the datasets.
transform_global_train, transform_global_val, transform_local = get_transforms()


# ================== SmartPatchExtractor ==================
class SmartPatchExtractor:
    """Select the square patches of an image with the highest mean gradient magnitude.

    Candidate windows are laid out on a grid with a stride of half the patch
    size; each is scored by the mean gradient magnitude of the grayscale
    image inside it, and the best ``num_patches`` windows are cropped.
    """

    def __init__(self, patch_size, num_patches):
        """
        Args:
            patch_size: Side length in pixels of each square patch.
            num_patches: Number of patches returned by ``extract_patches``.
        """
        self.patch_size = patch_size
        self.num_patches = num_patches

    def extract_patches(self, pil_img):
        """Return the ``num_patches`` highest-scoring crops of ``pil_img``.

        Args:
            pil_img: Image object convertible with ``np.array`` and exposing
                ``crop((left, upper, right, lower))``.

        Returns:
            list: ``num_patches`` crops ordered by decreasing score. When fewer
            candidate windows exist than requested, the last selected crop is
            repeated to fill the list.

        Raises:
            ValueError: If the image is smaller than ``patch_size`` along
                either axis, so that no window fits.
        """
        img = np.array(pil_img)
        # Collapse colour channels to one intensity plane; a 2-D array
        # (single-channel image) is already one.
        gray = img.mean(axis=2) if img.ndim == 3 else img.astype(np.float64)
        H, W = gray.shape
        ps = self.patch_size

        if H < ps or W < ps:
            # Explicit error instead of an IndexError on an empty candidate list.
            raise ValueError(
                f"image of size {W}x{H} is smaller than patch size {ps}"
            )

        gy, gx = np.gradient(gray)
        grad_mag = np.sqrt(gx ** 2 + gy ** 2)

        # Half-patch stride; never 0, which would make range() raise.
        stride = max(1, ps // 2)

        # "+ 1" keeps the last valid origin (H - ps / W - ps) reachable, so an
        # image exactly one patch wide or tall still yields a candidate.
        patch_scores = [
            (grad_mag[y:y + ps, x:x + ps].mean(), x, y)
            for y in range(0, H - ps + 1, stride)
            for x in range(0, W - ps + 1, stride)
        ]

        # Highest score first; ties are broken by the larger x, then larger y.
        patch_scores.sort(reverse=True)

        patches = [
            pil_img.crop((x, y, x + ps, y + ps))
            for _, x, y in patch_scores[:self.num_patches]
        ]

        # Pad by repeating the last crop when the grid had too few windows.
        if patches and len(patches) < self.num_patches:
            patches += [patches[-1]] * (self.num_patches - len(patches))

        return patches


# ================== Dataset ==================
class SmartLeafDataset(Dataset):
    """Image-folder dataset yielding a global view plus several local patches.

    ``root_dir`` must contain one sub-directory per class; every ``.jpg``,
    ``.jpeg`` or ``.png`` file inside a class directory becomes one sample.
    Class indices follow the alphabetical order of the directory names.
    """

    def __init__(self, root_dir, transform_global, transform_local,
                 num_patches=4, patch_size=96, use_smart_patches=True):
        """
        Args:
            root_dir: Directory holding one sub-directory per class.
            transform_global: Callable applied to the full RGB image.
            transform_local: Callable applied to each cropped patch.
            num_patches: Number of patches returned per sample.
            patch_size: Side length in pixels of each raw patch crop.
            use_smart_patches: If True, patches are chosen by
                ``SmartPatchExtractor``; otherwise they are cropped at
                random positions.
        """
        self.root_dir = root_dir
        self.transform_global = transform_global
        self.transform_local = transform_local
        self.num_patches = num_patches
        self.patch_size = patch_size
        self.use_smart_patches = use_smart_patches

        # The extractor only exists in smart mode; None otherwise.
        self.patch_extractor = (
            SmartPatchExtractor(patch_size, num_patches) if use_smart_patches else None
        )

        self.samples = []
        self.classes = sorted([d for d in os.listdir(root_dir)
                              if os.path.isdir(os.path.join(root_dir, d))])
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

        for cls in self.classes:
            cls_dir = os.path.join(root_dir, cls)
            # os.listdir order is platform-dependent; sorting makes sample
            # indices reproducible across machines and runs.
            for fname in sorted(os.listdir(cls_dir)):
                if fname.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.samples.append((os.path.join(cls_dir, fname), self.class_to_idx[cls]))

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.samples)

    def _extract_random_patches(self, pil_img):
        """Crop ``num_patches`` square patches at uniformly random positions.

        Images smaller than ``patch_size`` along an axis are first resized so
        that axis equals ``patch_size``. Positions are drawn from NumPy's
        global random state.
        """
        w, h = pil_img.size
        ps = self.patch_size
        patches = []

        if w < ps or h < ps:
            pil_img = pil_img.resize((max(w, ps), max(h, ps)))
            w, h = pil_img.size

        for _ in range(self.num_patches):
            x = np.random.randint(0, max(1, w - ps + 1))
            y = np.random.randint(0, max(1, h - ps + 1))
            patch = pil_img.crop((x, y, min(x + ps, w), min(y + ps, h)))
            if patch.size != (ps, ps):
                patch = patch.resize((ps, ps))
            patches.append(patch)

        return patches

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple: ``(img_global, patches_tensor, label)`` where ``img_global``
            is the output of ``transform_global``, ``patches_tensor`` stacks
            the ``transform_local`` output of every patch along a new first
            dimension, and ``label`` is the integer class index.
        """
        img_path, label = self.samples[idx]
        # Close the file handle deterministically; convert() returns an
        # independent in-memory image that stays valid afterwards.
        with Image.open(img_path) as handle:
            img = handle.convert('RGB')

        img_global = self.transform_global(img)

        if self.use_smart_patches:
            try:
                raw_patches = self.patch_extractor.extract_patches(img)
            except Exception:
                # Best-effort fallback (e.g. image too small for the patch
                # grid). Only Exception is caught so that KeyboardInterrupt
                # and SystemExit still propagate.
                raw_patches = self._extract_random_patches(img)
        else:
            raw_patches = self._extract_random_patches(img)

        patches_tensor = torch.stack(
            [self.transform_local(p) for p in raw_patches], dim=0
        )

        return img_global, patches_tensor, label


In [33]:
# ================== Attention mechanisms ==================
class PatchAttention(nn.Module):
    """Attention pooling over a set of patch feature vectors.

    A two-layer scorer assigns one scalar to every patch; the scalars are
    soft-maxed across dimension 1 and used to form a weighted sum of the
    patch features.
    """

    def __init__(self, patch_dim):
        """
        Args:
            patch_dim: Size of the last dimension of the patch features.
        """
        super().__init__()
        hidden = patch_dim // 2
        scorer_layers = [
            nn.Linear(patch_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
        ]
        self.attention = nn.Sequential(*scorer_layers)

    def forward(self, patch_features):
        """Pool ``patch_features`` along dimension 1.

        Returns:
            tuple: ``(aggregated, weights)`` — the weighted sum over
            dimension 1, and the softmax weights with their trailing
            singleton dimension removed.
        """
        scores = self.attention(patch_features)
        # Normalise across dimension 1 so the weights sum to one there.
        weights = torch.softmax(scores, dim=1)
        pooled = (weights * patch_features).sum(dim=1)
        return pooled, weights.squeeze(-1)


class AttentionFusion(nn.Module):
    """Gate and merge a global and a local feature vector.

    Both inputs are projected to ``hidden_dim``; a small network produces two
    softmax weights per sample, each projection is scaled by its weight, and
    the scaled pair is concatenated and reduced back to ``hidden_dim``.
    """

    def __init__(self, global_dim, local_dim, hidden_dim=256):
        """
        Args:
            global_dim: Feature size of the global input.
            local_dim: Feature size of the local input.
            hidden_dim: Common projection size and output size.
        """
        super().__init__()
        paired_dim = 2 * hidden_dim

        self.global_proj = nn.Linear(global_dim, hidden_dim)
        self.local_proj = nn.Linear(local_dim, hidden_dim)

        # Produces two weights per sample that sum to one.
        gate_layers = [
            nn.Linear(paired_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 2),
            nn.Softmax(dim=1),
        ]
        self.attention = nn.Sequential(*gate_layers)

        merge_layers = [
            nn.Linear(paired_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
        ]
        self.fusion = nn.Sequential(*merge_layers)

    def forward(self, global_feat, local_feat):
        """Fuse the two feature vectors.

        Returns:
            tuple: ``(fused, att)`` — the fused features of size
            ``hidden_dim`` and the two-column weight tensor (column 0 for
            the global branch, column 1 for the local branch).
        """
        proj_global = self.global_proj(global_feat)
        proj_local = self.local_proj(local_feat)

        att = self.attention(torch.cat((proj_global, proj_local), dim=1))

        # One single-column slice per branch, kept 2-D for broadcasting.
        gate_global, gate_local = att.split(1, dim=1)
        gated = torch.cat(
            (proj_global * gate_global, proj_local * gate_local), dim=1
        )

        return self.fusion(gated), att


In [34]:
class AdvancedLocalGlobalNet(nn.Module):
    """Two-branch classifier: ViT-B/16 on the full image, EfficientNet-B0 on patches.

    The classification heads of both backbones are replaced by identities so
    they return feature vectors. Patch features are pooled (attention or
    mean), merged with the global features (attention fusion or plain
    concatenation) and passed to an MLP classifier.
    """

    def __init__(self, num_classes, use_attention=True, pretrained=True):
        """
        Args:
            num_classes: Size of the output logit vector.
            use_attention: If True, use ``PatchAttention`` and
                ``AttentionFusion``; otherwise mean-pool the patches and
                concatenate with the global features.
            pretrained: If True (default, the original behaviour), initialise
                both backbones with their ``IMAGENET1K_V1`` weights. Pass
                False to skip that download/initialisation, e.g. when a full
                checkpoint is loaded right after construction.
        """
        super().__init__()
        self.use_attention = use_attention

        vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1 if pretrained else None)
        vit_dim = vit.heads.head.in_features
        vit.heads.head = nn.Identity()  # expose features instead of logits
        self.vit = vit

        eff = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None)
        eff_dim = eff.classifier[1].in_features
        eff.classifier = nn.Identity()  # expose features instead of logits
        self.eff = eff

        if use_attention:
            self.patch_attention = PatchAttention(eff_dim)
            self.attention_fusion = AttentionFusion(vit_dim, eff_dim, hidden_dim=384)
            classifier_input = 384  # must match hidden_dim of the fusion above
        else:
            classifier_input = vit_dim + eff_dim

        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(classifier_input, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, img_global, patches_local):
        """Compute class logits.

        Args:
            img_global: Batch of full images fed to the ViT branch.
            patches_local: 5-D tensor unpacked as ``(B, N, C, H, W)``:
                ``N`` patches per sample.

        Returns:
            Tensor of logits produced by the classifier head.
        """
        B, N, C, H, W = patches_local.shape

        feat_global = self.vit(img_global)

        # Fold the patch dimension into the batch so the CNN sees ordinary
        # 4-D input. reshape (unlike view) also accepts non-contiguous input.
        patches_flat = patches_local.reshape(B * N, C, H, W)
        feat_local_all = self.eff(patches_flat)
        feat_local_all = feat_local_all.reshape(B, N, -1)

        if self.use_attention:
            feat_local, _ = self.patch_attention(feat_local_all)
            fused, _ = self.attention_fusion(feat_global, feat_local)
        else:
            feat_local = feat_local_all.mean(dim=1)
            fused = torch.cat([feat_global, feat_local], dim=1)

        logits = self.classifier(fused)
        return logits


In [35]:
def _align_labels_to_train(split_ds, train_ds, split_name):
    """Re-express the labels of ``split_ds`` in the training set's class indexing.

    Each dataset derives its class indices from its own folder listing, so a
    split that lacks a class would otherwise number the remaining classes
    differently from the training set.

    Raises:
        ValueError: If ``split_ds`` contains a class absent from ``train_ds``.
    """
    if split_ds.classes == train_ds.classes:
        return

    unknown = sorted(set(split_ds.classes) - set(train_ds.classes))
    if unknown:
        raise ValueError(
            f"{split_name} split contains classes missing from the training set: {unknown}"
        )

    train_index = train_ds.class_to_idx
    own_classes = split_ds.classes
    split_ds.samples = [
        (path, train_index[own_classes[label]]) for path, label in split_ds.samples
    ]
    split_ds.classes = list(train_ds.classes)
    split_ds.class_to_idx = dict(train_index)


def make_dataloaders():
    """Create the train/val/test datasets and their DataLoaders.

    Patch settings and batch size are read from ``Config``. Validation and
    test labels are aligned with the training class indexing.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader, class_names,
        num_classes)`` where ``class_names`` is the sorted list of training
        class folder names.

    Raises:
        ValueError: If the val or test folder contains a class that the
            training folder does not.
    """
    print("üîÑ Chargement des datasets...")

    # Shared by all three splits; taken from Config instead of repeated literals.
    patch_kwargs = dict(
        num_patches=Config.num_patches,
        patch_size=Config.patch_size,
        use_smart_patches=Config.use_smart_patches,
    )

    train_ds = SmartLeafDataset(
        TRAIN_DIR, transform_global_train, transform_local, **patch_kwargs
    )
    val_ds = SmartLeafDataset(
        VAL_DIR, transform_global_val, transform_local, **patch_kwargs
    )
    test_ds = SmartLeafDataset(
        TEST_DIR, transform_global_val, transform_local, **patch_kwargs
    )

    _align_labels_to_train(val_ds, train_ds, "val")
    _align_labels_to_train(test_ds, train_ds, "test")

    # Only the training loader shuffles; evaluation order stays fixed.
    train_loader = DataLoader(train_ds, batch_size=Config.batch_size, shuffle=True)
    val_loader   = DataLoader(val_ds,   batch_size=Config.batch_size, shuffle=False)
    test_loader  = DataLoader(test_ds,  batch_size=Config.batch_size, shuffle=False)

    print(f"‚û°Ô∏è {len(train_ds)} train, {len(val_ds)} val, {len(test_ds)} test")
    print(f"üìå Classes d√©tect√©es : {len(train_ds.classes)}")

    return train_loader, val_loader, test_loader, train_ds.classes, len(train_ds.classes)


# Build the loaders once; the evaluation cell iterates over `test_loader`.
train_loader, val_loader, test_loader, class_names, num_classes = make_dataloaders()


üîÑ Chargement des datasets...
‚û°Ô∏è 25617 train, 8129 val, 5417 test
üìå Classes d√©tect√©es : 38


In [36]:
# Prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# weights_only=True restricts unpickling to tensors and plain containers,
# which avoids executing arbitrary code from the file and removes the
# warning torch.load emits when the argument is omitted.
# NOTE(review): if the checkpoint stores other Python objects next to the
# weights, this call raises; such a file must then be re-saved or
# explicitly trusted.
raw = torch.load(
    r"C:\Users\ENNHILI YASSINE\Desktop\ASMAE-ABDELOUAFI\best_model_final.pth",
    map_location=device,
    weights_only=True,
)

# The checkpoint is a dict; the model parameters live under this key.
state_dict = raw["model_state_dict"]
print("‚ú® Poids charg√©s depuis le fichier .pth !")


  raw = torch.load(


‚ú® Poids charg√©s depuis le fichier .pth !


In [37]:
# Size the classifier head from the classes detected in the training folder
# and take the attention switch from Config, rather than repeating literals
# that can drift from the data.
model = AdvancedLocalGlobalNet(
    num_classes=num_classes,
    use_attention=Config.use_attention
).to(device)

# Strict loading: any mismatch between checkpoint and architecture raises here.
model.load_state_dict(state_dict)
model.eval()

print("‚úÖ Mod√®le hybride reconstruit et pr√™t pour l'√©valuation !")


‚úÖ Mod√®le hybride reconstruit et pr√™t pour l'√©valuation !


In [38]:
def evaluate_model(model, test_loader):
    """Run ``model`` over ``test_loader`` and compute classification metrics.

    Args:
        model: Module called as ``model(img_global, patches_local)`` and
            returning per-class scores with classes along dimension 1.
        test_loader: Iterable yielding ``(img_global, patches_local, labels)``
            batches.

    Returns:
        dict: ``accuracy`` plus macro-averaged ``f1_macro``,
        ``precision_macro`` and ``recall_macro``.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    # Run on whatever device the model already lives on instead of relying
    # on a notebook-level global; parameter-free models fall back to CPU.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    y_true, y_pred = [], []

    model.eval()
    with torch.no_grad():
        for img_global, patches_local, labels in test_loader:
            outputs = model(img_global.to(run_device), patches_local.to(run_device))
            preds = outputs.argmax(1)

            # Labels are only compared on the host, so they never need to
            # visit the accelerator.
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    if y_true.size == 0:
        raise ValueError("test_loader yielded no samples; nothing to evaluate")

    # zero_division=0 scores a class with no predicted (or no true) samples
    # as 0 explicitly, instead of the same value accompanied by a warning.
    return {
        "accuracy": (y_true == y_pred).mean(),
        "f1_macro": f1_score(y_true, y_pred, average="macro", zero_division=0),
        "precision_macro": precision_score(y_true, y_pred, average="macro", zero_division=0),
        "recall_macro": recall_score(y_true, y_pred, average="macro", zero_division=0),
    }


In [39]:
metrics = evaluate_model(model, test_loader)

# Assemble the whole report first, then emit it with one line per entry.
report_lines = (
    "üìä R√©sultats du mod√®le final",
    "-------------------------------------",
    f"Accuracy        : {metrics['accuracy']*100:.2f}%",
    f"F1-macro        : {metrics['f1_macro']:.4f}",
    f"Pr√©cision       : {metrics['precision_macro']:.4f}",
    f"Rappel          : {metrics['recall_macro']:.4f}",
)
print(*report_lines, sep="\n")

  return torch._native_multi_head_attention(


üìä R√©sultats du mod√®le final
-------------------------------------
Accuracy        : 92.41%
F1-macro        : 0.8705
Pr√©cision       : 0.8414
Rappel          : 0.9122


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
