In [28]:
import os
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from sklearn.model_selection import StratifiedShuffleSplit
from PIL import Image
import torchvision.transforms.functional as F

# === 1) Préparation des données ===
class ResizeWithPadding:
    def __init__(self, target_size, fill=0):
        self.target_size = target_size
        self.fill = fill

    def __call__(self, img: Image.Image):
        w, h = img.size
        scale = self.target_size / max(w, h)
        new_w, new_h = int(w*scale), int(h*scale)
        img = img.resize((new_w, new_h), Image.BILINEAR)
        pad_w = self.target_size - new_w
        pad_h = self.target_size - new_h
        padding = (pad_w//2, pad_h//2, pad_w-pad_w//2, pad_h-pad_h//2)
        return F.pad(img, padding, fill=self.fill)

def get_dataloaders(data_dir, img_size=512, batch_size=16,
                    val_split=0.2, seed=42, num_workers=4):
    transform = transforms.Compose([
        ResizeWithPadding(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406], #données de la normalization imagenet
                             std=[0.229,0.224,0.225])
    ])
    dataset = datasets.ImageFolder(root=data_dir, transform=transform)
    targets = [y for _, y in dataset]
    splitter = StratifiedShuffleSplit(n_splits=1,
                                      test_size=val_split,
                                      random_state=seed)
    train_idx, val_idx = next(splitter.split(range(len(dataset)), targets))
    train_ds = Subset(dataset, train_idx)
    val_ds   = Subset(dataset, val_idx)
    train_loader = DataLoader(train_ds, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size,
                              shuffle=False, num_workers=num_workers)
    return train_loader, val_loader, dataset.classes

# === 2) Modèle : CoCa + tête de classification ===
from vit_pytorch.simple_vit_with_patch_dropout import SimpleViT
from vit_pytorch.extractor import Extractor
from coca_pytorch import CoCa

class CoCaFishClassifier(nn.Module):
    def __init__(self, coca_model: CoCa, num_classes: int, dropout=0.3):
        super().__init__()
        self.coca = coca_model
        latent_dim = coca_model.dim  # par défaut dim_latents = dim
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, latent_dim//2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(latent_dim//2, num_classes)
        )

    def forward(self, images):
        img_emb, _ = self.coca.embed_image(images=images)
        latents   = self.coca.img_to_latents(img_emb)
        return self.classifier(latents)

# === 3) Boucle d’entraînement / validation ===
def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss, total_correct, total = 0, 0, 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total += imgs.size(0)
    return total_loss / total, total_correct / total

def validate(model, loader, criterion, device):
    model.eval()
    total_loss, total_correct, total = 0, 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)
            loss = criterion(logits, labels)
            total_loss += loss.item() * imgs.size(0)
            preds = logits.argmax(dim=1)
            total_correct += (preds == labels).sum().item()
            total += imgs.size(0)
    return total_loss / total, total_correct / total




In [30]:
print(train_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7d230008d790>


In [None]:
# === 4) Main ===
def main():
    # Config
    data_dir    = "data"
    out_dir     = "checkpoints"
    os.makedirs(out_dir, exist_ok=True)
    device      = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    epochs      = 200
    batch_size  = 16
    lr          = 1e-4
    weight_decay= 1e-2
    val_split   = 0.2

    # 1) Data
    train_loader, val_loader, class_names = get_dataloaders(
        data_dir, img_size=512, batch_size=batch_size,
        val_split=val_split, seed=42
    )
    num_classes = len(class_names)
    print(f"Classes ({num_classes}) : {class_names}")

    # 2) CoCa + Vit
    vit = SimpleViT(
        image_size=256,
        patch_size=32,
        num_classes=0,
        dim=1024, depth=6,
        heads=16, mlp_dim=2048,
        patch_dropout=0.5
    )
    vit = Extractor(vit, return_embeddings_only=True)
    coca = CoCa(
        dim=512,
        img_encoder=vit,
        image_dim=1024,
        num_tokens=20000,
        unimodal_depth=6,
        multimodal_depth=6,
        dim_head=64,
        heads=8,
    ).to(device)

    model = CoCaFishClassifier(coca_model=coca,
                               num_classes=num_classes,
                               dropout=0.3).to(device)

    # 3) Optim + criterion
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0.
    for epoch in range(1, epochs+1):
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device)
        val_loss, val_acc     = validate(
            model, val_loader, criterion, device)

        print(f"[Epoch {epoch:02d}] "
              f"Train loss: {train_loss:.4f}, acc: {train_acc:.4f} | "
              f"Val loss: {val_loss:.4f}, acc: {val_acc:.4f}")

        # Sauvegarde si meilleure val accuracy
        if val_acc > best_acc:
            best_acc = val_acc
            ckpt_path = os.path.join(out_dir, f"best_epoch{epoch:02d}.pth")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'class_names': class_names
            }, ckpt_path)
            print(f"→ New best model saved at {ckpt_path}")

if __name__ == "__main__":
    main()

Classes (10) : ['acanthurus.sp', 'caphalopholis.cruentata', 'chaetodon.capistratus', 'haemulon.chrysargyreum', 'haemulon.flavolineatum', 'haemulon.plumierii', 'holocentrus.adscensionis', 'scarus.iseri', 'sparisoma.aurofrenatum', 'sparisoma.viride']




[Epoch 01] Train loss: 2.2944, acc: 0.1477 | Val loss: 2.2875, acc: 0.1304
→ New best model saved at checkpoints/best_epoch01.pth
[Epoch 02] Train loss: 2.2810, acc: 0.1591 | Val loss: 2.2771, acc: 0.1304
[Epoch 03] Train loss: 2.2754, acc: 0.1591 | Val loss: 2.2696, acc: 0.1304
[Epoch 04] Train loss: 2.2686, acc: 0.1818 | Val loss: 2.2631, acc: 0.1304
[Epoch 05] Train loss: 2.2624, acc: 0.1932 | Val loss: 2.2573, acc: 0.1304
[Epoch 06] Train loss: 2.2557, acc: 0.2386 | Val loss: 2.2520, acc: 0.1739
→ New best model saved at checkpoints/best_epoch06.pth
[Epoch 07] Train loss: 2.2532, acc: 0.2159 | Val loss: 2.2434, acc: 0.2174
→ New best model saved at checkpoints/best_epoch07.pth
[Epoch 08] Train loss: 2.2410, acc: 0.2386 | Val loss: 2.2329, acc: 0.2174
[Epoch 09] Train loss: 2.2324, acc: 0.2614 | Val loss: 2.2206, acc: 0.3043
