In [1]:
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, ConcatDataset
import glob
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch.optim as optim
import torchvision.models as models
from peft import get_peft_model, LoraConfig, TaskType

In [2]:
# Pick the compute device once; it is reused below when moving the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-50 backbone with ImageNet weights. `weights=` replaces the deprecated
# `pretrained=True` flag and selects the same IMAGENET1K_V1 checkpoint.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
num_ftrs = model.fc.in_features
# Swap the original head for a fresh linear layer with 9 outputs.
model.fc = nn.Linear(num_ftrs, 9)

# Freeze every parameter; the head is unfrozen again after LoRA wrapping.
for param in model.parameters():
    param.requires_grad = False

# Convolutions that receive LoRA adapters (second conv of each layer4 block).
target_modules = ["layer4.0.conv2", "layer4.1.conv2", "layer4.2.conv2"]

lora_config = LoraConfig(
    r=4,
    lora_alpha=1,
    lora_dropout=0.0,
    target_modules=target_modules,
)

model = get_peft_model(model, lora_config)

# The classification head is trained in full alongside the adapters.
for param in model.fc.parameters():
    param.requires_grad = True

# Check which parameters are trainable (should list only the targeted layer4 adapters and fc)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

# Reuse `device` instead of recomputing the CUDA availability check.
model = model.to(device)
print(model)
# Optimise only the parameters left trainable above.
optimizer = optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.001, momentum=0.9)

# Preprocessing: resize, centre-crop to 224x224, convert to tensor, then
# normalise each channel with the mean/std constants below.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])






base_model.model.layer4.0.conv2.lora_A.default.weight
base_model.model.layer4.0.conv2.lora_B.default.weight
base_model.model.layer4.1.conv2.lora_A.default.weight
base_model.model.layer4.1.conv2.lora_B.default.weight
base_model.model.layer4.2.conv2.lora_A.default.weight
base_model.model.layer4.2.conv2.lora_B.default.weight
base_model.model.fc.weight
base_model.model.fc.bias
PeftModel(
  (base_model): LoraModel(
    (model): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2

In [3]:
# Restore the saved weights. The loaded object is handed straight to
# `load_state_dict`, so it only needs tensors and plain containers:
# `weights_only=True` keeps `torch.load` from unpickling arbitrary objects
# (and thus from running code embedded in a tampered checkpoint file).
checkpoint = torch.load('model_final.pth', map_location=device, weights_only=True)
model.load_state_dict(checkpoint)

  checkpoint = torch.load('model_final_mixed_rand.pth', map_location=device)


<All keys matched successfully>

In [4]:
# WordNet synset ID -> class name for the nine classes used in this notebook.
WORDNET_TO_CLASS = dict(
    n02084071="Dog",
    n01503061="Bird",
    n04576211="Vehicle",
    n01661091="Reptile",
    n02075296="Carnivore",
    n02159955="Insect",
    n03800933="Instrument",
    n02469914="Primate",
    n02512053="Fish",
)

# Lower-cased class name -> integer label, numbered in ascending synset-ID order.
CLASS_TO_IDX = {
    class_name.lower(): position
    for position, (_wnid, class_name) in enumerate(sorted(WORDNET_TO_CLASS.items()))
}
print(CLASS_TO_IDX)



{'bird': 0, 'reptile': 1, 'carnivore': 2, 'dog': 3, 'insect': 4, 'primate': 5, 'fish': 6, 'instrument': 7, 'vehicle': 8}


In [5]:

# Dataset root. Each category below is read from <base_path>/<category>/val.
# Set the BG_CHALLENGE_DIR environment variable to run against another copy
# of the data; when it is unset the original location is used.
base_path = os.environ.get(
    "BG_CHALLENGE_DIR",
    "/users/eleves-b/2022/yassine.guennoun/Desktop/bg_challenge",
)

# Sub-directories of base_path to evaluate, in reporting order.
attacked_categories = [
    "original",
    "mixed_next",
    "mixed_rand",
    "mixed_same",
    "no_fg",
    "only_bg_b",
    "only_bg_t",
    "only_fg"
]

class AttackedDataset(Dataset):
    """Labelled image dataset built from a directory of per-class folders.

    Every sub-directory of ``root`` is treated as one class. The class name is
    the text after the first underscore of the folder name (the whole name when
    there is no underscore), lower-cased and passed through ``CLASS_ALIASES``.
    Folders whose class name is not a key of ``class_to_idx`` are reported on
    stdout and skipped.

    Args:
        root: Directory containing the class folders.
        class_to_idx: Mapping from lower-cased class name to integer label.
        transform: Optional callable applied to each RGB image on access.

    Attributes:
        samples: List of ``(image_path, label)`` tuples, sorted by folder name
            and then by file name so the order is reproducible.
        transform: The transform passed to the constructor.
    """

    # Accepted image suffixes; matched case-insensitively so ".JPG" or ".PNG"
    # files are found on case-sensitive filesystems and no file is listed
    # twice on case-insensitive ones.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tiff")

    # Folder class names that differ from the keys used in class_to_idx.
    CLASS_ALIASES = {
        "wheeled vehicle": "vehicle",
        "musical instrument": "instrument",
    }

    def __init__(self, root, class_to_idx, transform=None):
        self.samples = []
        self.transform = transform

        # os.listdir order is arbitrary; sort for a stable sample order.
        for class_folder in sorted(os.listdir(root)):
            class_path = os.path.join(root, class_folder)
            if not os.path.isdir(class_path):
                continue

            # partition() never raises, unlike split("_", 1)[1] on a folder
            # name that contains no underscore.
            _, separator, suffix = class_folder.partition("_")
            class_name = (suffix if separator else class_folder).lower()
            class_name = self.CLASS_ALIASES.get(class_name, class_name)
            if class_name not in class_to_idx:
                print(f"⚠️  Ignoré : {class_folder} (pas dans CLASS_TO_IDX)")
                continue

            label = class_to_idx[class_name]

            for file_name in sorted(os.listdir(class_path)):
                # Hidden files were never matched by the "*<ext>" glob either.
                if file_name.startswith("."):
                    continue
                if not file_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                img_path = os.path.join(class_path, file_name)
                if os.path.isfile(img_path):
                    self.samples.append((img_path, label))

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is opened, converted to RGB and, when a transform was
        given, passed through it.
        """
        path, label = self.samples[idx]
        # The context manager closes the file as soon as the pixels are copied.
        with Image.open(path) as handle:
            img = handle.convert("RGB")

        if self.transform:
            img = self.transform(img)

        return img, label

# Accuracy of each evaluated category, keyed by category name
category_accuracies = {}

# Running totals over every category that was found on disk
total_correct, total_samples = 0, 0

for cat in attacked_categories:
    cat_val_path = os.path.join(base_path, cat, "val")

    # Nothing to evaluate when the validation folder is missing: report and move on.
    if not os.path.isdir(cat_val_path):
        print(f"Dossier {cat_val_path} introuvable. Il sera ignoré.")
        continue

    dataset = AttackedDataset(root=cat_val_path, class_to_idx=CLASS_TO_IDX, transform=transform)
    data_loader = DataLoader(dataset, batch_size=32, shuffle=False)

    # Inference mode for the model, with gradient tracking disabled below.
    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            # Predicted class = index of the largest score along dim 1.
            preds = torch.max(outputs, dim=1).indices

            correct += int((preds == labels).sum())
            total += labels.shape[0]

    # An empty category is recorded as 0.0 rather than dividing by zero.
    if total > 0:
        accuracy = correct / total
    else:
        accuracy = 0.0
    category_accuracies[cat] = accuracy

    total_correct += correct
    total_samples += total

    print(f"✅ Accuracy sur {cat} : {accuracy:.4f}")

# Sample-weighted accuracy over all evaluated categories together.
if total_samples > 0:
    global_accuracy = total_correct / total_samples
else:
    global_accuracy = 0.0

print(f"\n Accuracy globale sur toutes les catégories attaquées : {global_accuracy:.4f}")

✅ Accuracy sur original : 0.9193
✅ Accuracy sur mixed_next : 0.8489
✅ Accuracy sur mixed_rand : 0.8464
✅ Accuracy sur mixed_same : 0.8840
✅ Accuracy sur no_fg : 0.4963
✅ Accuracy sur only_bg_b : 0.1272
✅ Accuracy sur only_bg_t : 0.1689
✅ Accuracy sur only_fg : 0.8968

🎯 Accuracy globale sur toutes les catégories attaquées : 0.6485
