<a href="https://colab.research.google.com/github/Bassendiaye/lowdiscovery/blob/main/notebooks/aug_during_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision transformers albumentations opencv-python scikit-learn tqdm

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import os
import random
import numpy as np
from PIL import Image
from collections import Counter

import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torch.nn as nn
import torch.optim as optim

from torchvision import transforms
from transformers import DeiTFeatureExtractor, DeiTForImageClassification

import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
# ---------------------- Reproducibility ----------------------
SEED = 42


def set_seed(seed=SEED):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, current CUDA device and all CUDA devices), then sets
    ``cudnn.deterministic = True`` and ``cudnn.benchmark = False``.

    Args:
        seed: Integer seed; defaults to the module-level ``SEED``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(SEED)

In [None]:
# ------------------- On-the-fly augmentations -------------------
# Pool of candidate augmentations. Each one is created with p=1 because the
# selection is done by the A.OneOf wrapper in train_transform below.
AUGMENTATIONS = [
    A.HorizontalFlip(p=1),
    A.VerticalFlip(p=1),
    # alpha_affine is intentionally not passed: it is deprecated in recent
    # albumentations releases (passing it made this cell emit a warning).
    # NOTE(review): older releases are believed to default alpha_affine to 50,
    # i.e. the value that used to be passed here — confirm for the pinned version.
    A.ElasticTransform(alpha=1, sigma=50, p=1),
    A.GridDistortion(num_steps=5, distort_limit=0.3, p=1),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1),
    A.RGBShift(r_shift_limit=20, g_shift_limit=20, b_shift_limit=20, p=1),
    A.HueSaturationValue(hue_shift_limit=15, sat_shift_limit=20, val_shift_limit=15, p=1),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=1),
    A.Defocus(radius=(3, 5), p=1),
    A.MotionBlur(blur_limit=(3, 5), p=1),
    A.GaussianBlur(blur_limit=(3, 5), p=1)
]

# Side length (in pixels) every image is resized to before entering the model.
IMAGE_SIZE = 224


def _final_steps():
    """Return fresh resize / normalize / to-tensor steps shared by both pipelines."""
    return [
        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
        A.Normalize(mean=(0.5,), std=(0.5,)),
        ToTensorV2(),
    ]


# Training pipeline: one augmentation from the pool, then the shared final steps.
train_transform = A.Compose([A.OneOf(AUGMENTATIONS, p=1), *_final_steps()])

# Validation pipeline: the shared final steps only, no augmentation.
val_transform = A.Compose(_final_steps())

  A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1),


In [None]:
# ------------------------- Custom dataset -------------------------
class ImageDataset(Dataset):
    """Dataset that loads images from disk as RGB arrays and applies a transform.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels, aligned index-by-index with ``image_paths``.
        transform: Optional callable invoked as ``transform(image=array)`` and
            expected to return a mapping with an ``"image"`` entry (the calling
            convention used by the albumentations pipelines in this notebook).

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        # Fail early instead of surfacing an IndexError in the middle of an epoch.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length "
                f"(got {len(image_paths)} and {len(labels)})"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is the transform's ``"image"`` output when a transform is set,
        otherwise the RGB NumPy array read from disk.
        """
        # The context manager closes the underlying file as soon as the pixels
        # have been copied into the array.
        with Image.open(self.image_paths[idx]) as img:
            image = np.array(img.convert("RGB"))
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image, label

In [None]:
# ------------------------- Data loading -------------------------
def load_dataset(data_dir, extensions=(".jpg", ".jpeg", ".png")):
    """Collect image paths and integer labels from a one-folder-per-class tree.

    Only sub-directories of ``data_dir`` become classes, so stray files next to
    the class folders no longer consume a class index. Entries whose name starts
    with a dot (e.g. ``.ipynb_checkpoints``) are skipped. Class folders and the
    files inside them are visited in sorted order so the result does not depend
    on the order returned by ``os.listdir``.

    Args:
        data_dir: Directory that contains one sub-directory per class.
        extensions: File extensions (with leading dot, case-insensitive) that
            count as images.

    Returns:
        Tuple ``(image_paths, labels, class_to_idx)`` where ``labels[i]`` is the
        class index of ``image_paths[i]`` and ``class_to_idx`` maps each class
        folder name to a contiguous index starting at 0.
    """
    class_names = sorted(
        entry
        for entry in os.listdir(data_dir)
        if not entry.startswith(".") and os.path.isdir(os.path.join(data_dir, entry))
    )
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}

    allowed = tuple(ext.lower() for ext in extensions)
    image_paths = []
    labels = []
    for class_name, class_idx in class_to_idx.items():
        class_dir = os.path.join(data_dir, class_name)
        for file_name in sorted(os.listdir(class_dir)):
            if file_name.startswith("."):
                continue
            if file_name.lower().endswith(allowed):
                image_paths.append(os.path.join(class_dir, file_name))
                labels.append(class_idx)

    return image_paths, labels, class_to_idx

In [None]:
# ------------------------- Oversampling Sampler -------------------------
def create_weighted_sampler(labels):
    """Build a sampler that draws each sample with weight 1 / (size of its class).

    Args:
        labels: Sequence of hashable class labels, one per sample.

    Returns:
        A ``WeightedRandomSampler`` that draws ``len(labels)`` indices with
        replacement using those per-sample weights.
    """
    frequencies = Counter(labels)
    per_sample = []
    for lbl in labels:
        per_sample.append(1.0 / frequencies[lbl])
    return WeightedRandomSampler(
        weights=per_sample,
        num_samples=len(per_sample),
        replacement=True,
    )

In [None]:
# ------------------------- Training loop -------------------------
def train(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return ``(mean_loss, accuracy)``.

    Both values are averaged over the samples actually processed, so they stay
    correct when the loader drops the last batch or uses a sampler that yields
    a different number of items than ``len(dataloader.dataset)``.

    Args:
        model: Module whose forward output exposes a ``.logits`` tensor.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer over the model parameters.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, accuracy)``; ``(0.0, 0.0)`` if no batch was produced.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        batch_size = images.size(0)

        optimizer.zero_grad()
        logits = model(images).logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so the epoch average is per sample.
        loss_sum += loss.item() * batch_size
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += batch_size

    if n_seen == 0:
        return 0.0, 0.0
    return loss_sum / n_seen, n_correct / n_seen

In [None]:
# ------------------------- Validation -------------------------
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix # Add imports for evaluation metrics

def evaluate(model, dataloader, device):
    """Score the model on a dataloader without tracking gradients.

    Args:
        model: Module whose forward output exposes a ``.logits`` tensor.
        dataloader: Iterable of ``(images, labels)`` batches; labels are read
            with ``.numpy()`` without being moved to ``device``.
        device: Device the image batches are moved to.

    Returns:
        Tuple ``(accuracy, macro_f1, confusion_matrix)`` computed with
        scikit-learn over all batches.
    """
    model.eval()
    y_pred, y_true = [], []
    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            logits = model(batch_images.to(device)).logits
            y_pred.extend(torch.argmax(logits, dim=1).cpu().numpy())
            y_true.extend(batch_labels.numpy())

    return (
        accuracy_score(y_true, y_pred),
        f1_score(y_true, y_pred, average="macro"),
        confusion_matrix(y_true, y_pred),
    )

In [None]:
# ------------------------- Main Script -------------------------
import os # Add os import here
from sklearn.model_selection import train_test_split # Already imported, keep it
from transformers import DeiTConfig, DeiTForImageClassification # Add transformers imports here
import torch.nn as nn # Add nn import here
import torch.optim as optim # Add optim import here
from torch.utils.data import DataLoader # Add DataLoader import here
from collections import Counter # Add Counter import here
from torch.utils.data import WeightedRandomSampler # Add WeightedRandomSampler import here


def main(
    data_dir="/content/drive/MyDrive/Dossier_de_Basse/Data_paper",
    epochs=100,
    batch_size=32,
    lr=0.0001,
    val_fraction=0.2,
    patience=None,
):
    """Fine-tune ``facebook/deit-small-patch16-224`` on an image-folder dataset.

    The data is split into stratified train/validation parts, the training
    loader oversamples with a class-balanced weighted sampler, and the model is
    trained with AdamW and cross-entropy. The weights of the epoch with the best
    validation macro-F1 are kept and restored before returning.

    Args:
        data_dir: Root folder with one sub-folder per class.
        epochs: Maximum number of training epochs.
        batch_size: Batch size for both loaders.
        lr: AdamW learning rate.
        val_fraction: Fraction of the data held out for validation.
        patience: If set, stop once the validation macro-F1 has not improved for
            this many consecutive epochs. ``None`` (default) always runs all
            ``epochs``.

    Returns:
        The model, loaded with the best validation macro-F1 weights seen.

    Raises:
        ValueError: If no image is found under ``data_dir``.
    """
    # Local imports keep this function usable even if the cell-level imports
    # above it were not executed.
    from sklearn.model_selection import train_test_split
    from transformers import AutoModelForImageClassification

    image_paths, labels, class_to_idx = load_dataset(data_dir)
    if not image_paths:
        raise ValueError(f"No images found under {data_dir!r}")

    train_paths, val_paths, train_labels, val_labels = train_test_split(
        image_paths, labels, test_size=val_fraction, stratify=labels, random_state=SEED)

    train_dataset = ImageDataset(train_paths, train_labels, transform=train_transform)
    val_dataset = ImageDataset(val_paths, val_labels, transform=val_transform)

    # A sampler and shuffle=True are mutually exclusive in DataLoader; the
    # weighted sampler already randomizes the order.
    sampler = create_weighted_sampler(train_labels)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = AutoModelForImageClassification.from_pretrained(
        "facebook/deit-small-patch16-224",
        num_labels=len(class_to_idx),
        # The checkpoint's classification head has a different number of
        # outputs; let it be re-initialized instead of raising.
        ignore_mismatched_sizes=True,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    best_f1 = float("-inf")
    best_state = None
    epochs_without_improvement = 0

    for epoch in range(1, epochs + 1):
        print(f"\n--- Époque {epoch} ---")
        train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
        val_acc, val_f1, val_cm = evaluate(model, val_loader, device)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Acc: {val_acc:.4f}, F1 Score: {val_f1:.4f}")

        if val_f1 > best_f1:
            best_f1 = val_f1
            epochs_without_improvement = 0
            # Snapshot on CPU so the copy does not take up accelerator memory.
            best_state = {
                key: value.detach().cpu().clone()
                for key, value in model.state_dict().items()
            }
        else:
            epochs_without_improvement += 1
            if patience is not None and epochs_without_improvement >= patience:
                print(f"Early stopping after {epoch} epochs.")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model


if __name__ == "__main__":
    main()

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/88.3M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at facebook/deit-small-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 384]) in the checkpoint and torch.Size([8, 384]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([8]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



--- Époque 1 ---


model.safetensors:   0%|          | 0.00/88.2M [00:00<?, ?B/s]

Train Loss: 0.5582, Train Acc: 0.8069
Val Acc: 0.8553, F1 Score: 0.4663

--- Époque 2 ---
Train Loss: 0.2457, Train Acc: 0.9213
Val Acc: 0.8238, F1 Score: 0.5139

--- Époque 3 ---
Train Loss: 0.1927, Train Acc: 0.9364
Val Acc: 0.8497, F1 Score: 0.4860

--- Époque 4 ---
Train Loss: 0.1443, Train Acc: 0.9549
Val Acc: 0.8433, F1 Score: 0.4722

--- Époque 5 ---
Train Loss: 0.1320, Train Acc: 0.9567
Val Acc: 0.8343, F1 Score: 0.5344

--- Époque 6 ---
Train Loss: 0.1289, Train Acc: 0.9582
Val Acc: 0.9057, F1 Score: 0.5558

--- Époque 7 ---
Train Loss: 0.1106, Train Acc: 0.9637
Val Acc: 0.9440, F1 Score: 0.5854

--- Époque 8 ---
Train Loss: 0.1104, Train Acc: 0.9635
Val Acc: 0.9443, F1 Score: 0.6338

--- Époque 9 ---
Train Loss: 0.0943, Train Acc: 0.9695
Val Acc: 0.9087, F1 Score: 0.6089

--- Époque 10 ---
Train Loss: 0.0900, Train Acc: 0.9699
Val Acc: 0.9075, F1 Score: 0.6480

--- Époque 11 ---
Train Loss: 0.0829, Train Acc: 0.9726
Val Acc: 0.9279, F1 Score: 0.6220

--- Époque 12 ---
Train L

In [None]:
# ------------------------- Main Script -------------------------
# NOTE: everything below is wrapped in a triple-quoted string, so this cell
# defines and runs nothing. It only keeps a draft of main() that adds early
# stopping on the validation F1 score (patience of 10 epochs) and reloads the
# best weights when it stops.
# NOTE(review): in the draft, best_model_wts is assigned only once val_f1 exceeds
# the initial best_val_f1 of 0.0; if that never happens, the load_state_dict call
# in the early-stop branch would reference an unassigned name. Fix before
# re-enabling this code.
"""import os
from sklearn.model_selection import train_test_split
from transformers import AutoModelForImageClassification
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from collections import Counter
from torch.utils.data import WeightedRandomSampler
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
import copy


def main():
    data_dir = "/content/drive/MyDrive/Dossier_de_Basse/Data_paper"
    image_paths, labels, class_to_idx = load_dataset(data_dir)

    train_paths, val_paths, train_labels, val_labels = train_test_split(
        image_paths, labels, test_size=0.2, stratify=labels, random_state=SEED)

    train_dataset = ImageDataset(train_paths, train_labels, transform=train_transform)
    val_dataset = ImageDataset(val_paths, val_labels, transform=val_transform)

    sampler = create_weighted_sampler(train_labels)
    train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    model = AutoModelForImageClassification.from_pretrained(
        "facebook/deit-small-patch16-224",
        num_labels=len(class_to_idx),
        ignore_mismatched_sizes=True
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=0.0001)

    best_val_f1 = 0.0
    epochs_no_improve = 0
    patience = 10  # Number of epochs to wait before stopping

    for epoch in range(1, 101):
        print(f"\n--- Époque {epoch} ---")
        train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
        val_acc, val_f1, val_cm = evaluate(model, val_loader, device)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Acc: {val_acc:.4f}, F1 Score: {val_f1:.4f}")

        # Early stopping logic
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            epochs_no_improve = 0
            # Save the best model
            best_model_wts = copy.deepcopy(model.state_dict())
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping after {epoch} epochs.")
                model.load_state_dict(best_model_wts) # Load best model weights
                break

if __name__ == "__main__":
    main()"""