In [None]:
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
from sklearn.model_selection import train_test_split

import torchvision.transforms as transforms
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from pathlib import Path
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from enum import Enum

### UTILS BASE MODEL

In [None]:
class Label(Enum):
    """Enumeration of the four label names, each member holding its lowercase string."""
    DOG = "dog"
    BIKE = "bike"
    BALL = "ball"
    WATER = "water"    


### UTILS METHODS

In [None]:
def get_label(filename: str):
    """Return the part of ``filename`` that precedes the first underscore.

    A name containing no underscore is returned unchanged.
    """
    prefix, _, _ = filename.partition("_")
    return prefix


def get_uuid(filename: str):
    """Return the first two underscore-separated fields of the file stem, joined by "_".

    The stem is the final path component without its last suffix. When the stem
    has fewer than two fields, whatever is present is returned.
    """
    # maxsplit=2 is enough: only the first two fields are ever kept.
    stem_fields = Path(filename).stem.split("_", 2)
    return "_".join(stem_fields[:2])


def build_augmented_path(img_path: Path, base_dir: Path):
    """Build ``base_dir / <uuid> / <pres> / <filename>`` for an image file.

    Args:
        img_path: Path (or string) of the image; only its final component is used.
        base_dir: Root directory of the output tree. Accepts a ``Path`` or a
            plain string.

    Returns:
        A ``Path`` where ``<uuid>`` is ``get_uuid(filename)`` and ``<pres>`` is
        every underscore-separated field after the first one, taken from the
        filename text that precedes its first dot.
    """
    # Coerce so that a plain string base directory also supports the "/" operator.
    base_dir = Path(base_dir)
    filename = Path(img_path).name
    uuid = get_uuid(filename)
    pres = "_".join(filename.split(".")[0].split("_")[1:])
    return base_dir / uuid / pres / filename


def make_class_names(dataset):
    """Return the class names of ``dataset`` ordered by their integer index.

    ``dataset.class_to_idx`` maps label string -> index. The mapping is inverted
    and read back for indices ``0 .. n-1``; a gap in the indices raises ``KeyError``.
    """
    mapping = dataset.class_to_idx
    name_by_index = dict(zip(mapping.values(), mapping.keys()))
    return [name_by_index[position] for position in range(len(name_by_index))]


def count_oov_pct(sequences, oov_id):
    """Return the percentage of non-padding tokens that equal ``oov_id``.

    Args:
        sequences: Iterable of token-id sequences. It is traversed exactly once,
            so generators and other one-shot iterables are counted correctly.
        oov_id: Token id that marks an out-of-vocabulary word.

    Returns:
        ``oov / total * 100`` where ``total`` counts tokens different from 0
        (0 is treated as padding), or ``0`` when there is no such token.
    """
    total = 0
    oov = 0
    for seq in sequences:
        for tid in seq:
            if tid != 0:
                total += 1
            if tid == oov_id:
                oov += 1
    return oov / total * 100 if total > 0 else 0

# Fixed mapping from label name to integer class index.
class_to_idx = {
    "ball":  0,
    "bike":  1,
    "dog":   2,
    "water": 3,
}

# Label names listed in index order: position i holds the name whose class_to_idx value is i.
class_names = ["ball", "bike", "dog", "water"]



In [None]:
def load_model_weights(model: nn.Module, ckpt_path: str, strict: bool = True, map_location="cpu"):
    """Load weights from a checkpoint file into ``model`` and return the model.

    Args:
        model: Module receiving the weights (modified in place).
        ckpt_path: Path of a file written with ``torch.save``.
        strict: Forwarded to ``load_state_dict``.
        map_location: Forwarded to ``torch.load``; defaults to ``"cpu"``.

    The checkpoint may be a bare state dict, or a dict that wraps it under
    ``"state_dict"`` or ``"model_state_dict"``.
    """
    ckpt = torch.load(ckpt_path, map_location=map_location)
    state = ckpt
    if isinstance(ckpt, dict):
        # Unwrap training checkpoints that store the weights under a well-known key.
        for wrapper_key in ("state_dict", "model_state_dict"):
            if wrapper_key in ckpt:
                state = ckpt[wrapper_key]
                break
    model.load_state_dict(state, strict=strict)
    return model

def load_without_classifier(model: nn.Module, ckpt_path: str, classifier_prefix="classifier.", map_location="cpu"):
    """Load every checkpoint weight except those of the classifier head.

    Keys starting with ``classifier_prefix`` are dropped, the rest is loaded
    non-strictly, and the missing / unexpected key lists are printed.
    Returns ``model`` (modified in place).
    """
    checkpoint = torch.load(ckpt_path, map_location=map_location)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        weights = checkpoint["state_dict"]
    else:
        weights = checkpoint

    kept = {}
    for key, value in weights.items():
        if key.startswith(classifier_prefix):
            continue
        kept[key] = value

    missing_keys, unexpected_keys = model.load_state_dict(kept, strict=False)
    print("Missing:", missing_keys)
    print("Unexpected:", unexpected_keys)
    return model


### DATASET CLIP MODEL

In [None]:
from typing import Tuple
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

class CLIPDataset(Dataset):
    """Dataset yielding an image tensor, its class label and its tokenized caption.

    Rows are read from a dataframe with ``image_path``, ``label`` and ``caption``
    columns. Image locations are resolved with ``build_augmented_path`` relative
    to ``base_dir``. Class indices follow the alphabetical order of the labels
    found in the dataframe.

    Args:
        df: Dataframe with the three columns listed above.
        base_dir: Root directory handed to ``build_augmented_path``.
        transform: Callable applied to the RGB PIL image (skipped when falsy).
        keras_tokenizer: Keras tokenizer; required for ``text_mode`` "keras" / "both".
        hf_tokenizer: Hugging Face tokenizer; required for ``text_mode`` "hf" / "both".
        max_seq_len: Length every token sequence is padded / truncated to.
        vocab_size: Keras ids outside ``[0, vocab_size)`` are replaced by the OOV id.
        text_mode: "keras", "hf" or "both". In "keras" and "hf" mode the sample
            carries the generic ``input_ids`` / ``attention_mask`` keys plus the
            suffixed ``*_keras`` / ``*_hf`` keys. In "both" mode only the
            suffixed keys of both tokenizers are produced.
    """

    def __init__(
        self,
        df,
        base_dir: Path,
        transform,
        keras_tokenizer=None,
        hf_tokenizer=None,
        max_seq_len: int = 64,
        vocab_size: int = 1000,
        text_mode: str = "keras"  # 'keras', 'hf', 'both'
    ):
        self.base_dir = Path(base_dir)
        self.transform = transform
        self.max_seq_len = max_seq_len
        # Honour the caller's vocabulary size (it used to be pinned to 1000).
        self.vocab_size = vocab_size
        self.text_mode = text_mode

        self.img_paths = df["image_path"].tolist()
        self.labels = df["label"].tolist()
        self.captions = df["caption"].tolist()

        self.classes = sorted(set(self.labels))
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.num_classes = len(self.classes)

        self.keras_tokenizer = keras_tokenizer
        self.hf_tokenizer = hf_tokenizer

        # Always defined so the attribute exists even without a Keras tokenizer.
        self.oov_id = 1
        if self.keras_tokenizer is not None:
            oov_tok = getattr(self.keras_tokenizer, "oov_token", "<OOV>")
            oov_id = self.keras_tokenizer.word_index.get(oov_tok, 1)
            # Keep the tokenizer's OOV id only if it fits inside the vocabulary.
            if oov_id is not None and oov_id < self.vocab_size:
                self.oov_id = oov_id

        self._validate_tokenizers()

    def _validate_tokenizers(self):
        """Raise ``ValueError`` when the tokenizer needed by ``text_mode`` is missing."""
        if self.text_mode in ["keras", "both"] and self.keras_tokenizer is None:
            raise ValueError("keras_tokenizer requis pour text_mode='keras' ou 'both'")
        if self.text_mode in ["hf", "both"] and self.hf_tokenizer is None:
            raise ValueError("hf_tokenizer requis pour text_mode='hf' ou 'both'")

    def _build_image_path(self, img_path: str) -> Path:
        """Resolve a dataframe image path to its location under ``base_dir``."""
        return Path(build_augmented_path(img_path, self.base_dir))

    def _tokenize_keras(self, text: str) -> torch.Tensor:
        """Return ``max_seq_len`` Keras token ids (0-padded on the right) as a long tensor."""
        seq = self.keras_tokenizer.texts_to_sequences([text])[0]
        # Ids outside the vocabulary are mapped to the OOV id.
        seq = [
            t if (t is not None and 0 <= t < self.vocab_size) else self.oov_id
            for t in seq
        ]
        padded = pad_sequences(
            [seq],
            maxlen=self.max_seq_len,
            padding="post",
            truncating="post",
            value=0
        )[0].astype(np.int64)

        return torch.tensor(padded, dtype=torch.long)

    def _tokenize_hf(self, text: str) -> dict:
        """Return ``input_ids`` and ``attention_mask`` long tensors from the HF tokenizer."""
        enc = self.hf_tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_seq_len,
            return_tensors="pt"
        )
        return {
            "input_ids": enc["input_ids"].squeeze(0).to(torch.long),
            "attention_mask": enc["attention_mask"].squeeze(0).to(torch.long)
        }

    def __getitem__(self, idx: int) -> dict:
        """Return the sample dict for row ``idx`` (image, label, caption, token fields)."""
        img_path = self._build_image_path(self.img_paths[idx])
        # convert() returns an independent image, so the file can be closed right away.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)

        caption = str(self.captions[idx])
        label = torch.tensor(self.class_to_idx[self.labels[idx]], dtype=torch.long)

        output = {
            "idx": torch.tensor(idx, dtype=torch.long),
            "image": img,
            "label": label,
            "caption": caption,
        }

        if self.text_mode in ("keras", "both"):
            input_ids_keras = self._tokenize_keras(caption)
            output["input_ids_keras"] = input_ids_keras
            output["attention_mask_keras"] = (input_ids_keras != 0).long()

            # The generic keys are only filled when a single tokenizer is active.
            if self.text_mode == "keras":
                output["input_ids"] = input_ids_keras
                output["attention_mask"] = output["attention_mask_keras"]

        if self.text_mode in ("hf", "both"):
            hf_tokens = self._tokenize_hf(caption)
            output["input_ids_hf"] = hf_tokens["input_ids"]
            output["attention_mask_hf"] = hf_tokens["attention_mask"]

            if self.text_mode == "hf":
                output["input_ids"] = hf_tokens["input_ids"]
                output["attention_mask"] = hf_tokens["attention_mask"]

        return output

    def __len__(self) -> int:
        """Number of rows in the dataset."""
        return len(self.img_paths)

    def _get_img_path_from_idx(self, idx: int) -> Path:
        """Resolved image path of row ``idx``."""
        return self._build_image_path(self.img_paths[idx])

    def _get_caption_from_idx(self, idx: int) -> str:
        """Raw caption of row ``idx``."""
        return self.captions[idx]

    def _get_label_from_idx(self, idx: int) -> str:
        """Label string of row ``idx``."""
        return self.labels[idx]

    def _get_img_size(self, idx: int) -> Tuple[int, int]:
        """Return ``(height, width)`` of row ``idx`` after applying any ``Resize`` step."""
        img_path = self._build_image_path(self.img_paths[idx])
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            # A Compose exposes its steps via ``.transforms``; a single transform is its own step.
            steps = getattr(self.transform, "transforms", [self.transform])
            for t in steps:
                if isinstance(t, transforms.Resize):
                    img = t(img)
        return img.height, img.width

### TRANSFORMS

In [None]:
# Image pipeline matching the 300x500 input assumed by CNNBasic: resize to
# (300, 500), convert to a tensor, then normalize every RGB channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.Resize((300, 500)),
        transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

In [None]:
# Image pipeline for the ResNet encoder: resize to 224x224, convert to a tensor,
# then normalize with the per-channel ImageNet mean / std used by torchvision's
# pretrained models.
transform_resnet = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

### VISION CNN ENCODER

In [None]:
class CNNBasic(nn.Module):
    """Three-stage convolutional classifier (conv -> batch-norm -> ReLU -> max-pool, x3).

    Args:
        num_classes: Number of output logits.
        input_size: ``(height, width)`` of the images the network will receive.
            It determines the width of the first fully connected layer and
            defaults to the original 300x500.
    """

    def __init__(self, num_classes=4, input_size=(300, 500)):
        super().__init__()
        self.input_size = tuple(input_size)

        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2,2),
            nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2,2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2,2)
        )

        self.flattened_size = self._get_flattened_size()

        self.classifier = nn.Sequential(
            nn.Linear(self.flattened_size, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, num_classes)
        )

    def _get_flattened_size(self):
        """Return the number of features produced by ``self.features`` for one image."""
        height, width = self.input_size
        was_training = self.features.training
        # Probe in eval mode: a training-mode pass would update the BatchNorm
        # running statistics with values from the dummy input.
        self.features.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, 3, height, width)
                size = self.features(probe).flatten(1).shape[1]
        finally:
            self.features.train(was_training)
        return size

    def forward(self, x):
        """Map a ``(batch, 3, H, W)`` image batch to ``(batch, num_classes)`` logits."""
        x = self.features(x)
        # flatten() also handles non-contiguous feature maps, which view() rejects.
        x = torch.flatten(x, 1)
        return self.classifier(x)


In [None]:
class ImageEncoderFromCNNBasic(nn.Module):
    """Image encoder reusing the convolutional stack of a ``CNNBasic``-like model.

    The ``features`` module of ``cnn_basic`` is shared, not copied. On top of
    it sit a hidden layer (linear -> ReLU -> dropout 0.2) of width ``embed_dim``
    and a linear projection to ``proj_dim``. Outputs are L2-normalized.
    """

    def __init__(self, cnn_basic: nn.Module, embed_dim: int = 256, proj_dim: int = 256):
        super().__init__()
        self.features = cnn_basic.features
        self.flattened_size = cnn_basic.flattened_size

        hidden_layers = [
            nn.Linear(self.flattened_size, embed_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
        ]
        self.backbone_fc = nn.Sequential(*hidden_layers)
        self.proj = nn.Sequential(nn.Linear(embed_dim, proj_dim))

    def forward(self, x):
        """Return unit-norm embeddings of shape ``(batch, proj_dim)``."""
        feature_maps = self.features(x)
        flat = feature_maps.view(feature_maps.size(0), -1)
        embedding = self.proj(self.backbone_fc(flat))
        return F.normalize(embedding, dim=-1)


### VISUAL RESNET ENCODER

In [None]:
from torchvision.models import resnet18, ResNet18_Weights

class ImageEncoderFromResNet18(nn.Module):
    """Image encoder built from a ResNet-style backbone plus a linear projection.

    Args:
        resnet: Backbone with an ``fc`` head. It is modified in place: ``fc`` is
            replaced by ``nn.Identity`` so the backbone returns pooled features.
        proj_dim: Size of the output embedding.
        train_backbone: When False, every backbone parameter is frozen.

    The projection input width is read from ``resnet.fc.in_features``, so
    backbones wider than ResNet-18 work too. If ``fc`` exposes no
    ``in_features`` (for example it is already an ``Identity``), the ResNet-18
    width of 512 is used.
    """

    def __init__(self, resnet: nn.Module, proj_dim: int = 256, train_backbone: bool = True):
        super().__init__()
        # Read the feature width before the head is replaced below.
        feature_dim = getattr(getattr(resnet, "fc", None), "in_features", 512)

        self.backbone = resnet
        self.backbone.fc = nn.Identity()

        self.proj = nn.Linear(feature_dim, proj_dim)

        if not train_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unit-norm embeddings of shape ``(batch, proj_dim)``."""
        feats = self.backbone(x)
        z = self.proj(feats)
        return F.normalize(z, dim=-1)
def load_best_resnet18(num_classes=4, ckpt_path="../notebooks/best-model-resnet.pth", device="cpu"):
    """Build a ResNet-18 classifier and load fine-tuned weights from ``ckpt_path``.

    The checkpoint may be a bare state dict or a dict wrapping it under
    ``"model_state_dict"``; it is loaded strictly. ``device`` is only used as
    ``map_location`` for ``torch.load``; the model itself is not moved here.
    """
    # weights=None: the parameters come from our own checkpoint, not from torchvision.
    net = resnet18(weights=None)
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    net.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    checkpoint = torch.load(ckpt_path, map_location=device)
    is_wrapped = isinstance(checkpoint, dict) and "model_state_dict" in checkpoint
    weights = checkpoint["model_state_dict"] if is_wrapped else checkpoint
    net.load_state_dict(weights, strict=True)
    return net

### SMALL BERT TEXT ENCODER

In [None]:
class PositionalEmbedding(nn.Module):
    """Sum of a learned token embedding and a learned position embedding."""

    def __init__(self, sequence_length: int, vocab_size:int, embed_dim:int):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.position_embeddings = nn.Embedding(sequence_length, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a ``(batch, seq_len)`` tensor of token ids."""
        batch_size, seq_len = x.shape
        # Positions 0..seq_len-1, repeated for every row of the batch.
        position_ids = torch.arange(seq_len, device=x.device)[None, :].expand(batch_size, seq_len)
        token_part = self.token_embeddings(x)
        position_part = self.position_embeddings(position_ids)
        return token_part + position_part
class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder block: self-attention, then a feed-forward net.

    Each sub-layer output goes through dropout, is added to its input and then
    layer-normalized.
    """

    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, dropout_rate: float = 0.1) -> None:
        super().__init__()
        self.att = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout_rate, batch_first=True)

        feed_forward_layers = [
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        ]
        self.ffn = nn.Sequential(*feed_forward_layers)

        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor, padding_mask: torch.Tensor = None) -> torch.Tensor:
        """Apply the block to ``x``; ``padding_mask`` is used as the attention key-padding mask."""
        attended = self.att(x, x, x, key_padding_mask=padding_mask)[0]
        x = self.layernorm1(x + self.dropout1(attended))
        return self.layernorm2(x + self.dropout2(self.ffn(x)))


In [None]:
class SmallBERT(nn.Module):
    """Small Transformer encoder: embeddings, ``num_layers`` blocks, layer norm, dropout 0.3."""

    def __init__(self, sequence_length: int, vocab_size: int, embed_dim: int,
                 num_heads: int, ff_dim: int, num_layers: int) -> None:
        super().__init__()
        self.pos_embedding = PositionalEmbedding(sequence_length, vocab_size, embed_dim)

        layers = []
        for _ in range(num_layers):
            layers.append(TransformerBlock(embed_dim, num_heads, ff_dim))
        self.blocks = nn.ModuleList(layers)

        self.layernorm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode token ids.

        x: (batch, seq_len) token ids; id 0 is treated as padding.
        """
        # True marks padding positions, which attention must ignore.
        pad_positions = x.eq(0)
        hidden = self.pos_embedding(x)

        for layer in self.blocks:
            hidden = layer(hidden, padding_mask=pad_positions)

        return self.dropout(self.layernorm(hidden))

class SmallBERTPourClassification(nn.Module):
    """``SmallBERT`` encoder followed by mean pooling, dropout 0.3 and a linear classifier."""

    def __init__(self, sequence_length: int, vocab_size: int, embed_dim: int,
                 num_heads: int, ff_dim: int, num_layers: int,
                 num_classes: int = 4) -> None:
        super().__init__()

        self.encoder = SmallBERT(sequence_length, vocab_size, embed_dim, num_heads, ff_dim, num_layers)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits (no softmax) for a batch of token ids."""
        token_states = self.encoder(x)               # (batch, seq_len, embed_dim)
        sentence = token_states.mean(dim=1)          # average over all sequence positions
        return self.classifier(self.dropout(sentence))


In [None]:
class TextEncoderFromSmallBERT(nn.Module):
    """Text encoder reusing the ``encoder`` of a SmallBERT classification model.

    The encoder output is mean-pooled over the sequence, passed through
    dropout 0.1 and a linear projection to ``proj_dim``, then L2-normalized.
    The projection input width is taken from ``text_model.classifier.in_features``.
    """

    def __init__(self, text_model: nn.Module, proj_dim: int = 256):
        super().__init__()
        self.encoder = text_model.encoder
        self.dropout = nn.Dropout(0.1)
        self.proj = nn.Linear(text_model.classifier.in_features, proj_dim)

    def forward(self, x, padding_mask=None):
        """Return unit-norm text embeddings.

        ``padding_mask`` is accepted but not used: only ``x`` is handed to the encoder.
        """
        token_states = self.encoder(x)
        sentence = self.dropout(token_states.mean(dim=1))
        return F.normalize(self.proj(sentence), dim=-1)


### TEXT DISTILBERT ENCODER MODEL


In [None]:
class TextEncoderFromDistilBERT(nn.Module):
    """Text encoder built on the ``distilbert`` backbone of a classification model.

    Token states are averaged over the positions where ``attention_mask`` is 1,
    passed through dropout 0.1 and a linear projection, then L2-normalized.
    """

    def __init__(self, distilbert_cls_model, proj_dim=256):
        super().__init__()
        self.backbone = distilbert_cls_model.distilbert
        self.proj = nn.Linear(self.backbone.config.hidden_size, proj_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask):
        """Return unit-norm embeddings of shape ``(batch, proj_dim)``."""
        token_states = self.backbone(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        # Masked mean: padded positions are zeroed and left out of the token count.
        weights = attention_mask[..., None].float()
        token_sum = (token_states * weights).sum(dim=1)
        token_count = weights.sum(dim=1).clamp(min=1e-6)   # guards against an all-padding row
        sentence = token_sum / token_count

        return F.normalize(self.proj(self.dropout(sentence)), dim=-1)

In [None]:
def load_hf_classifier_checkpoint(model, ckpt_path, map_location="cpu"):
    """Load a DistilBERT-classifier checkpoint into ``model`` (non-strict) and return it.

    Accepted checkpoint layouts, checked in this order:
      1. a bare state dict (some key starts with "distilbert." or "classifier."),
      2. a dict wrapping the state dict under "state_dict",
      3. a dict wrapping the state dict under "model_state_dict".

    Raises:
        ValueError: if the checkpoint matches none of these layouts.
    """
    checkpoint = torch.load(ckpt_path, map_location=map_location)

    if not isinstance(checkpoint, dict):
        raise ValueError("Format checkpoint non reconnu")

    if any(key.startswith("distilbert.") or key.startswith("classifier.") for key in checkpoint):
        weights = checkpoint
    elif "state_dict" in checkpoint:
        weights = checkpoint["state_dict"]
    elif "model_state_dict" in checkpoint:
        weights = checkpoint["model_state_dict"]
    else:
        raise ValueError("Format checkpoint non reconnu")

    missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)
    print("Missing:", missing_keys)
    print("Unexpected:", unexpected_keys)
    return model

### CONTRASTIVE MODEL

In [None]:
class CLIPLikeModel(nn.Module):
    """Pairs an image encoder and a text encoder with a learnable logit scale.

    ``logit_scale`` stores the natural log of the scale and is initialised from
    ``init_logit_scale``.
    """

    def __init__(self, image_encoder: nn.Module, text_encoder: nn.Module, init_logit_scale=1/0.07):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        log_scale = torch.log(torch.tensor(init_logit_scale)).item()
        self.logit_scale = nn.Parameter(torch.tensor(log_scale))

    def forward(self, images, tokens, padding_mask=None):
        """Return ``(logits, image_embeddings, text_embeddings)``.

        ``logits`` has shape (B, B): scaled image-to-text similarities, with the
        scale clamped to [1e-3, 100].
        """
        image_emb = self.image_encoder(images)                  # (B, D)
        text_emb = self.text_encoder(tokens, padding_mask)      # (B, D)
        scale = self.logit_scale.exp().clamp(1e-3, 100.0)
        similarity = image_emb @ text_emb.t()                   # (B, B)
        return scale * similarity, image_emb, text_emb


def clip_contrastive_loss(logits):
    """Symmetric contrastive loss over a square similarity matrix.

    Row ``i`` is expected to match column ``i``. The result is the average of
    the image-to-text cross-entropy (rows) and the text-to-image one (columns).
    """
    targets = torch.arange(logits.size(0), device=logits.device)
    directions = (logits, logits.t())   # image->text, then text->image
    return sum(F.cross_entropy(view, targets) for view in directions) / 2



### CONTRASTIVE HF MODEL

In [None]:
class CLIPLikeModelHF(nn.Module):
    """CLIP-style model whose text encoder takes ``(input_ids, attention_mask)``.

    ``logit_scale`` stores ``log(1 / init_temp)`` as a learnable parameter.
    """

    def __init__(self, image_encoder, text_encoder, init_temp=0.07):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        initial_log_scale = torch.log(torch.tensor(1.0/init_temp))
        self.logit_scale = nn.Parameter(initial_log_scale)

    def forward(self, images, input_ids, attention_mask):
        """Return ``(logits, image_embeddings, text_embeddings)``; scale clamped to [1e-3, 100]."""
        image_emb = self.image_encoder(images)
        text_emb = self.text_encoder(input_ids, attention_mask)
        similarity = image_emb @ text_emb.t()
        clamped_scale = self.logit_scale.exp().clamp(1e-3, 100.0)
        return clamped_scale * similarity, image_emb, text_emb

### TRAINING METHOD

In [None]:
class EarlyStopping:
    """Signal when a monitored metric has stopped improving.

    Args:
        patience: Number of consecutive non-improving steps that triggers the stop.
        min_delta: Minimum change, in the improving direction, needed to count
            as an improvement.
        mode: "max" (default) when a higher metric is better, for example
            recall; "min" when a lower metric is better, for example a
            validation loss.

    Attributes:
        best_score: Best metric seen so far (``None`` before the first valid step).
        counter: Consecutive non-improving steps.
        should_stop: True once ``counter`` has reached ``patience``.
    """

    def __init__(self, patience=3, min_delta=0.0, mode="max"):
        if mode not in ("max", "min"):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best_score = None
        self.counter = 0
        self.should_stop = False

    def _improved(self, metric):
        """Tell whether ``metric`` beats ``best_score`` by at least ``min_delta``."""
        if self.mode == "max":
            return not (metric < self.best_score + self.min_delta)
        return not (metric > self.best_score - self.min_delta)

    def step(self, metric):
        """Record a new metric value and return the updated ``should_stop`` flag."""
        # NaN never compares as worse, so it would otherwise be stored as the best
        # score and silently disable stopping; count it as a non-improving step.
        is_nan = metric != metric
        if self.best_score is None and not is_nan:
            self.best_score = metric
        elif is_nan or not self._improved(metric):
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        else:
            self.best_score = metric
            self.counter = 0
        return self.should_stop


In [None]:
def clip_collate_keras(batch):
    """Stack a list of dataset samples into one batch dict.

    Each sample must provide "image", "input_ids", "attention_mask", "label" and
    "idx" tensors. They are stacked along a new first dimension under the keys
    "images", "input_ids", "attention_mask", "labels" and "idx".
    """
    def _stack(sample_key):
        return torch.stack([sample[sample_key] for sample in batch], dim=0)

    return {
        "images": _stack("image"),                   # (B,3,H,W)
        "input_ids": _stack("input_ids"),            # (B,L)
        "attention_mask": _stack("attention_mask"),  # (B,L)
        "labels": _stack("label"),                   # (B,)
        "idx": _stack("idx"),                        # (B,)
    }


In [None]:
def train_step(batch, model, optimizer, device="cuda"):
    """Run one optimisation step of the contrastive model and return the loss value.

    Args:
        batch: Either a dict as produced by the collate functions in this file
            ("images", "input_ids" and optionally "attention_mask"), or a
            sequence ``(images, tokens)`` / ``(images, tokens, padding_mask)``.
        model: Module called as ``model(images, tokens[, padding_mask=...])``
            and returning ``(logits, image_emb, text_emb)``.
        optimizer: Optimizer updating ``model``.
        device: Device the tensors are moved to.

    Returns:
        The batch loss as a Python float.
    """
    padding_mask = None
    if isinstance(batch, dict):
        # Dict batches: same keys as val_step, so one loader can feed both steps.
        images = batch["images"]
        tokens = batch["input_ids"]
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            padding_mask = attention_mask.to(device) == 0
    elif len(batch) == 3:
        images, tokens, padding_mask = batch
        if padding_mask is not None:
            padding_mask = padding_mask.to(device)
    else:
        images, tokens = batch

    images = images.to(device)
    tokens = tokens.to(device)

    model.train()
    optimizer.zero_grad()

    if padding_mask is None:
        logits, _, _ = model(images, tokens)
    else:
        logits, _, _ = model(images, tokens, padding_mask=padding_mask)
    loss = clip_contrastive_loss(logits)

    loss.backward()
    optimizer.step()
    return loss.item()
@torch.no_grad()
def val_step(batch, model, device="cuda"):
    """Compute the contrastive loss of one validation batch without gradients.

    ``batch`` is a dict with "images", "input_ids" and "attention_mask".
    Positions whose attention mask is 0 are passed to the model as padding.
    Returns the loss as a Python float.
    """
    images = batch["images"].to(device)
    token_ids = batch["input_ids"].to(device)
    pad_positions = batch["attention_mask"].to(device) == 0

    model.eval()
    outputs = model(images, token_ids, padding_mask=pad_positions)
    logits = outputs[0]
    return clip_contrastive_loss(logits).item()
def fit(model, train_loader, val_loader, optimizer, epochs=10, device="cuda"):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Mean train and validation losses are printed per epoch. A loader that
    yields no batch gives a ``nan`` mean instead of a division by zero.

    Returns:
        ``{"train_loss": [...], "val_loss": [...]}`` with one mean per epoch.
    """
    model.to(device)
    history = {"train_loss": [], "val_loss": []}

    for epoch in range(1, epochs + 1):
        train_losses = [
            train_step(batch, model, optimizer, device)
            for batch in tqdm(train_loader, desc=f"Epoch {epoch} [train]")
        ]
        val_losses = [
            val_step(batch, model, device)
            for batch in tqdm(val_loader, desc=f"Epoch {epoch} [val]")
        ]

        # An empty loader has no mean; report nan rather than crash.
        train_loss = sum(train_losses) / len(train_losses) if train_losses else float("nan")
        val_loss = sum(val_losses) / len(val_losses) if val_losses else float("nan")
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)

        print(
            f"Epoch {epoch} | "
            f"train_loss={train_loss:.4f} | "
            f"val_loss={val_loss:.4f}"
        )

    return history



In [None]:
def train_step_hf(batch, model, optimizer, device="cuda"):
    """Run one optimisation step for the HF-tokenizer model and return the loss value.

    ``batch`` is a dict with "images", "input_ids" and "attention_mask".
    Gradients are clipped to a total norm of 1.0 before the optimizer step.
    """
    images, input_ids, attention_mask = (
        batch[key].to(device) for key in ("images", "input_ids", "attention_mask")
    )

    model.train()
    optimizer.zero_grad(set_to_none=True)

    outputs = model(images, input_ids, attention_mask=attention_mask)
    loss = clip_contrastive_loss(outputs[0])
    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    return loss.item()

@torch.no_grad()
def val_step_hf(batch, model, device="cuda"):
    """Compute the contrastive loss of one validation batch for the HF-tokenizer model.

    ``batch`` is a dict with "images", "input_ids" and "attention_mask".
    Returns the loss as a Python float.
    """
    images, input_ids, attention_mask = (
        batch[key].to(device) for key in ("images", "input_ids", "attention_mask")
    )

    model.eval()
    outputs = model(images, input_ids, attention_mask=attention_mask)
    return clip_contrastive_loss(outputs[0]).item()

def fit_hf(model, train_loader, val_loader, optimizer, epochs=10, device="cuda"):
    """Train the HF-tokenizer model for ``epochs`` epochs, validating after each one.

    Mean train and validation losses are printed per epoch. A loader that
    yields no batch gives a ``nan`` mean instead of a division by zero.

    Returns:
        ``{"train_loss": [...], "val_loss": [...]}`` with one mean per epoch.
    """
    model.to(device)
    history = {"train_loss": [], "val_loss": []}

    for epoch in range(1, epochs + 1):
        train_losses = [train_step_hf(batch, model, optimizer, device) for batch in train_loader]
        val_losses = [val_step_hf(batch, model, device) for batch in val_loader]

        # An empty loader has no mean; report nan rather than crash.
        train_loss = sum(train_losses) / len(train_losses) if train_losses else float("nan")
        val_loss = sum(val_losses) / len(val_losses) if val_losses else float("nan")
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)

        print(f"Epoch {epoch} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f}")

    return history



### EVAL METHODS

In [None]:
@torch.no_grad()
def recall_at_k_from_sims(sims: torch.Tensor, ks=(1,5,10)):
    """Compute Recall@k from a square similarity matrix.

    The correct match of row ``i`` is column ``i``. For every ``k`` in ``ks``
    the result holds, under the key ``"R@k"``, the fraction of rows whose
    correct column is among the ``k`` highest-scoring ones.
    """
    n_rows = sims.size(0)
    targets = torch.arange(n_rows, device=sims.device).unsqueeze(1)   # (N,1)
    ranking = torch.argsort(sims, dim=1, descending=True)              # (N,N)
    is_target = ranking == targets   # True where the ranked column is the correct one
    return {
        f"R@{k}": is_target[:, :k].any(dim=1).float().mean().item()
        for k in ks
    }


@torch.no_grad()
def evaluate_clip_recall(index, ks=(1,5,10), device="cpu"):
    """Return ``(text_to_image, image_to_text)`` Recall@k dicts for an embedding index.

    ``index`` must hold aligned "img_embs" and "txt_embs" matrices: row ``i``
    of one is the match of row ``i`` of the other.
    """
    image_embs = index["img_embs"].to(device)
    text_embs = index["txt_embs"].to(device)

    text_to_image = recall_at_k_from_sims(text_embs @ image_embs.t(), ks=ks)
    image_to_text = recall_at_k_from_sims(image_embs @ text_embs.t(), ks=ks)
    return text_to_image, image_to_text

### RETRIEVAL

In [None]:
def encode_text_keras(tokenizer, text: str, max_seq_len=64, vocab_size=1000, oov_id=1):
    """Tokenize one sentence with a Keras tokenizer for the text encoder.

    Ids outside ``[0, vocab_size)`` are replaced by ``oov_id``. The sequence is
    then right-padded / right-truncated to ``max_seq_len`` with 0.

    Returns:
        ``(input_ids, padding_mask)``, both of shape (1, max_seq_len).
        ``padding_mask`` is True where ``input_ids`` is 0.
    """
    raw_ids = tokenizer.texts_to_sequences([text])[0]

    cleaned = []
    for token_id in raw_ids:
        if token_id is not None and 0 <= token_id < vocab_size:
            cleaned.append(token_id)
        else:
            cleaned.append(oov_id)

    padded_row = pad_sequences(
        [cleaned], maxlen=max_seq_len, padding="post", truncating="post", value=0
    )[0].astype(np.int64)

    input_ids = torch.tensor(padded_row, dtype=torch.long).unsqueeze(0)
    padding_mask = input_ids.eq(0)
    return input_ids, padding_mask


In [None]:
def clip_collate(batch):
    """Stack a list of dataset samples into one batch dict.

    "image", "input_ids" and "attention_mask" are required and end up under
    "images", "input_ids" and "attention_mask". "label" and "idx" are stacked
    (as "labels" and "idx") only when the first sample provides them.
    """
    def _stack(sample_key):
        return torch.stack([sample[sample_key] for sample in batch], dim=0)

    collated = {
        "images": _stack("image"),
        "input_ids": _stack("input_ids"),
        "attention_mask": _stack("attention_mask"),
    }

    first_sample = batch[0]
    for sample_key, batch_key in (("label", "labels"), ("idx", "idx")):
        if sample_key in first_sample:
            collated[batch_key] = _stack(sample_key)
    return collated


In [None]:

@torch.no_grad()
def show_retrieval_from_text(
    clip_model,
    index,
    query_text: str,
    tokenizer,
    max_seq_len=64,
    vocab_size=1000,
    oov_id=1,
    topk=5,
    device="cuda",
    show_images=True
):
    """
    Retrieve the ``topk`` images that best match a sentence.

    The sentence is tokenized with ``encode_text_keras``, embedded by
    ``clip_model.text_encoder`` and scored against ``index["img_embs"]``.
    Results are printed, optionally plotted, and returned as a list of dicts
    with keys "rank", "score", "image_path" and "caption".
    """
    clip_model.eval()
    clip_model.to(device)

    token_ids, pad_mask = encode_text_keras(
        tokenizer, query_text, max_seq_len=max_seq_len,
        vocab_size=vocab_size, oov_id=oov_id
    )
    query_emb = clip_model.text_encoder(token_ids.to(device), pad_mask.to(device))
    query_emb = F.normalize(query_emb, dim=-1).cpu()

    scores = index["img_embs"] @ query_emb.squeeze(0)
    best_scores, best_rows = torch.topk(scores, k=min(topk, scores.size(0)))

    results = []
    for rank, (score, row) in enumerate(zip(best_scores.tolist(), best_rows.tolist()), start=1):
        image_path = None if index["image_paths"] is None else index["image_paths"][row]
        caption = None if index["captions"] is None else index["captions"][row]
        results.append({"rank": rank, "score": float(score), "image_path": image_path, "caption": caption})

    print(f"Query text: {query_text}\nTop-{len(results)} images:")
    for hit in results:
        print(f"#{hit['rank']} score={hit['score']:.4f} | {hit['image_path']}")

    if show_images and index["image_paths"] is not None:
        fig = plt.figure(figsize=(14, 6))
        n_panels = len(results)
        for position, hit in enumerate(results, start=1):
            ax = fig.add_subplot(1, n_panels, position)
            ax.imshow(Image.open(hit["image_path"]).convert("RGB"))
            ax.set_title(f"#{hit['rank']}\n{hit['score']:.3f}")
            ax.axis("off")
        plt.tight_layout()
        plt.show()

    return results


In [None]:
@torch.no_grad()
def show_retrieval_from_image(
    clip_model,
    index,
    query_image,            # str path OR PIL.Image OR torch.Tensor
    transform=None,         # same transform as the dataset
    topk=5,
    device="cuda",
    show=True
):
    """
    Retrieve the ``topk`` captions (and their images, when the index has paths)
    that best match a query image.

    Returns a list of dicts with keys "rank", "score", "caption" and "image_path".
    Raises ``TypeError`` for an unsupported ``query_image`` type and
    ``ValueError`` when the prepared query is not a tensor.
    """
    clip_model.eval()
    clip_model.to(device)

    # Turn the query into an image tensor, shape (3,H,W) before batching.
    if isinstance(query_image, str):
        loaded = Image.open(query_image).convert("RGB")
        pixels = loaded if transform is None else transform(loaded)
    elif isinstance(query_image, Image.Image):
        pixels = query_image if transform is None else transform(query_image)
    elif torch.is_tensor(query_image):
        pixels = query_image
    else:
        raise TypeError("query_image doit être un chemin, une PIL.Image, ou un torch.Tensor.")

    if not torch.is_tensor(pixels):
        raise ValueError("transform doit retourner un torch.Tensor (3,H,W).")

    pixels = pixels.unsqueeze(0).to(device)                               # (1,3,H,W)
    query_emb = F.normalize(clip_model.image_encoder(pixels), dim=-1).cpu()   # (1,D)

    scores = index["txt_embs"] @ query_emb.squeeze(0)                     # (N,)
    best_scores, best_rows = torch.topk(scores, k=min(topk, scores.size(0)))

    results = []
    for rank, (score, row) in enumerate(zip(best_scores.tolist(), best_rows.tolist()), start=1):
        caption = None if index["captions"] is None else index["captions"][row]
        image_path = None if index["image_paths"] is None else index["image_paths"][row]
        results.append({"rank": rank, "score": float(score), "caption": caption, "image_path": image_path})

    print("Top captions:")
    for hit in results:
        print(f"#{hit['rank']} score={hit['score']:.4f} | {hit['caption']}")

    if not show:
        return results

    n_panels = len(results) + 1   # first panel is the query itself
    fig = plt.figure(figsize=(14, 6))
    query_ax = fig.add_subplot(1, n_panels, 1)
    if isinstance(query_image, str):
        query_picture = Image.open(query_image).convert("RGB")
    elif isinstance(query_image, Image.Image):
        query_picture = query_image
    else:
        query_picture = None   # a tensor query is not drawn
    if query_picture is not None:
        query_ax.imshow(query_picture)
    query_ax.set_title("QUERY")
    query_ax.axis("off")

    if index["image_paths"] is not None:
        for position, hit in enumerate(results, start=2):
            ax = fig.add_subplot(1, n_panels, position)
            ax.imshow(Image.open(hit["image_path"]).convert("RGB"))
            ax.set_title(f"#{hit['rank']} {hit['score']:.3f}")
            ax.axis("off")

    plt.tight_layout()
    plt.show()

    return results


In [None]:
@torch.no_grad()
def build_clip_index(clip_model, dataloader, device="cuda"):
    """
    Compute L2-normalized image and text embeddings for every batch of ``dataloader``.

    Returns a dict with "img_embs" and "txt_embs" (CPU tensors, one row per
    sample) plus "image_paths", "captions" and "labels". Each of these three is
    ``None`` unless exactly one value per embedding row could be collected.
    Paths and captions need batches carrying "idx" and a dataset exposing
    ``_get_img_path_from_idx`` / ``_get_caption_from_idx``.
    """
    clip_model.eval()
    clip_model.to(device)

    dataset = dataloader.dataset
    can_lookup = hasattr(dataset, "_get_img_path_from_idx") and hasattr(dataset, "_get_caption_from_idx")

    image_chunks, text_chunks = [], []
    captions, labels, row_ids = [], [], []

    for batch in tqdm(dataloader, desc="Building CLIP index"):
        images = batch["images"].to(device)
        input_ids = batch["input_ids"].to(device)
        pad_positions = batch["attention_mask"].to(device) == 0

        image_emb = clip_model.image_encoder(images)
        text_emb = clip_model.text_encoder(input_ids, pad_positions)

        image_chunks.append(F.normalize(image_emb, dim=-1).cpu())
        text_chunks.append(F.normalize(text_emb, dim=-1).cpu())

        if "idx" in batch:
            batch_rows = batch["idx"].cpu().tolist()
            row_ids.extend(batch_rows)
            if can_lookup:
                captions.extend(str(dataset._get_caption_from_idx(row)) for row in batch_rows)
        if "labels" in batch:
            labels.extend(batch["labels"].cpu().tolist())

    img_embs = torch.cat(image_chunks, dim=0)
    txt_embs = torch.cat(text_chunks, dim=0)
    n_rows = img_embs.size(0)

    if can_lookup and len(row_ids) == n_rows:
        image_paths = [str(dataset._get_img_path_from_idx(row)) for row in row_ids]
    else:
        image_paths = None

    return {
        "img_embs": img_embs,
        "txt_embs": txt_embs,
        "image_paths": image_paths,
        "captions": captions if len(captions) == n_rows else None,
        "labels": labels if len(labels) == n_rows else None,
    }

@torch.no_grad()
def build_clip_index_hf(clip_model, dataloader, device="cuda"):
    """Compute L2-normalized image / text embeddings with the HF-tokenizer model.

    Returns a dict with "img_embs", "txt_embs" (CPU tensors), "image_paths" and
    "captions". The two lists are filled only for batches that carry "idx" and
    when the dataset exposes ``_get_img_path_from_idx`` /
    ``_get_caption_from_idx``; otherwise they stay empty.
    """
    clip_model.eval()
    clip_model.to(device)

    dataset = dataloader.dataset
    can_lookup = hasattr(dataset, "_get_img_path_from_idx") and hasattr(dataset, "_get_caption_from_idx")

    image_chunks, text_chunks = [], []
    captions, image_paths = [], []

    for batch in tqdm(dataloader, desc="Building CLIP index (HF)"):
        images = batch["images"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        image_emb = clip_model.image_encoder(images)
        text_emb = clip_model.text_encoder(input_ids, attention_mask)

        image_chunks.append(F.normalize(image_emb, dim=-1).cpu())
        text_chunks.append(F.normalize(text_emb, dim=-1).cpu())

        if can_lookup and "idx" in batch:
            for row in batch["idx"].tolist():
                image_paths.append(str(dataset._get_img_path_from_idx(row)))
                captions.append(str(dataset._get_caption_from_idx(row)))

    return {
        "img_embs": torch.cat(image_chunks, dim=0),
        "txt_embs": torch.cat(text_chunks, dim=0),
        "image_paths": image_paths,
        "captions": captions,
    }

In [None]:
@torch.no_grad()
def show_retrieval_from_text_hf(
    clip_model,
    index,
    query_text: str,
    tokenizer,
    max_len=128,
    topk=5,
    device="cuda",
    show_images=True
):
    """Rank the indexed images against a text query and optionally plot them.

    The query is tokenised (truncated / padded to ``max_len``), encoded with
    ``clip_model.text_encoder``, L2-normalised and compared to
    ``index["img_embs"]`` with a dot product.

    Args:
        clip_model: model exposing ``text_encoder(input_ids, attention_mask)``.
        index: dict holding ``"img_embs"`` and ``"image_paths"``.
        query_text: free-text query.
        tokenizer: callable returning ``"input_ids"`` and ``"attention_mask"`` tensors.
        max_len: padded / truncated token length of the query.
        topk: maximum number of hits to return.
        device: device on which the text encoder is run.
        show_images: when true, the retrieved images are opened and plotted.

    Returns:
        list of ``(score, image_path)`` tuples, best match first.
    """
    clip_model.eval()
    clip_model.to(device)

    encoded = tokenizer(
        query_text,
        truncation=True,
        padding="max_length",
        max_length=max_len,
        return_tensors="pt"
    )
    token_ids = encoded["input_ids"].to(device)
    mask = encoded["attention_mask"].to(device)

    query_emb = F.normalize(clip_model.text_encoder(token_ids, mask), dim=-1).cpu()

    # Dot product with every indexed image embedding.
    scores = index["img_embs"] @ query_emb.squeeze(0)
    n_hits = min(topk, scores.size(0))
    best_scores, best_ids = torch.topk(scores, k=n_hits)

    print(f"Query text: {query_text}\nTop-{topk} images:")

    hits = []
    ranked = zip(best_scores.tolist(), best_ids.tolist())
    for rank, (sim, pos) in enumerate(ranked, start=1):
        img_path = index["image_paths"][pos]
        print(f"#{rank} score={sim:.4f} | {img_path}")
        hits.append((sim, img_path))

    if not show_images:
        return hits

    fig = plt.figure(figsize=(14, 5))
    n_cols = len(hits)
    for col, (sim, img_path) in enumerate(hits, start=1):
        ax = fig.add_subplot(1, n_cols, col)
        ax.imshow(Image.open(img_path).convert("RGB"))
        ax.set_title(f"{sim:.3f}")
        ax.axis("off")
    plt.tight_layout()
    plt.show()

    return hits


In [None]:
@torch.no_grad()
def show_retrieval_from_image_hf(
    clip_model,
    index,
    query_image,
    transform,
    topk=5,
    device="cuda",
    show=True
):
    """Retrieve the captions closest to a query image and optionally plot the hits.

    The query is encoded with ``clip_model.image_encoder``, L2-normalised and
    compared to ``index["txt_embs"]`` with a dot product.

    Args:
        clip_model: model exposing ``image_encoder(images)``.
        index: dict holding ``"txt_embs"``, ``"captions"`` and ``"image_paths"``.
        query_image: a path (``str`` or ``Path``), a PIL image, or an already
            preprocessed tensor without the batch dimension (one is added here).
        transform: callable applied to the PIL image; unused for tensor queries.
        topk: maximum number of hits to return.
        device: device on which the image encoder is run.
        show: when true, plot the query (if a PIL image is available) next to
            the images associated with the retrieved captions.

    Returns:
        list of ``(score, caption, image_path)`` tuples, best match first.

    Raises:
        TypeError: if ``query_image`` is none of the supported types.
    """
    clip_model.eval()
    clip_model.to(device)

    if torch.is_tensor(query_image):
        # Already preprocessed: there is no PIL image to display for the query.
        x = query_image
        query_vis = None
    elif isinstance(query_image, (str, Path)):
        query_vis = Image.open(query_image).convert("RGB")
        x = transform(query_vis)
    elif isinstance(query_image, Image.Image):
        # Same RGB normalisation as for paths, so both inputs reach `transform` alike.
        query_vis = query_image.convert("RGB")
        x = transform(query_vis)
    else:
        raise TypeError("query_image doit être str, PIL.Image ou Tensor")

    x = x.unsqueeze(0).to(device)

    q = clip_model.image_encoder(x)
    q = F.normalize(q, dim=-1).cpu()

    sims = index["txt_embs"] @ q.squeeze(0)
    topv, topi = torch.topk(sims, k=min(topk, sims.size(0)))

    print("Top captions:")
    results = []
    for r, (score, idx) in enumerate(zip(topv.tolist(), topi.tolist()), 1):
        cap = index["captions"][idx]
        print(f"#{r} score={score:.4f} | {cap}")
        results.append((score, cap, index["image_paths"][idx]))

    if show:
        # Size the grid from what is actually drawn: the index may hold fewer than
        # `topk` entries, and a tensor query has no image for the QUERY panel.
        has_query_panel = query_vis is not None
        n_panels = len(results) + (1 if has_query_panel else 0)

        if n_panels > 0:
            fig = plt.figure(figsize=(14, 5))
            col = 1

            if has_query_panel:
                ax0 = fig.add_subplot(1, n_panels, col)
                ax0.imshow(query_vis)
                ax0.set_title("QUERY")
                ax0.axis("off")
                col += 1

            for score, _, path in results:
                ax = fig.add_subplot(1, n_panels, col)
                img = Image.open(path).convert("RGB")
                ax.imshow(img)
                ax.set_title(f"{score:.3f}")
                ax.axis("off")
                col += 1

            plt.tight_layout()
            plt.show()

    return results


### DATASET LOADING

In [None]:
# Dataset root; the metadata CSV sits directly inside it.
base_dir = Path("../data/final_dataset_noaug2")
metadata_path = base_dir / "metadata.csv"

df = pd.read_csv(metadata_path)

# Quick sanity check: available columns and one sample row.
print(df.columns)
print(df.iloc[1])

In [None]:

# Stratified split: hold out 30% of the rows, then cut that hold-out in half
# to obtain the validation and test sets (same seed for both steps).
df_train, df_temp = train_test_split(df, test_size=0.3, random_state=11, stratify=df["label"])
df_val, df_test = train_test_split(df_temp, test_size=0.5, random_state=11, stratify=df_temp["label"])

# Class balance (in %) of each split, to check the stratification.
print(df_train["label"].value_counts(normalize=True) * 100)
print(df_val["label"].value_counts(normalize=True) * 100)
print(df_test["label"].value_counts(normalize=True) * 100)

# Per-split columns with a fresh 0..n-1 index.
# The previous version overwrote X_* (image paths) with the captions and never
# reset the index of caption_*; each series now keeps its own content.
X_train = df_train["image_path"].reset_index(drop=True)
y_train = df_train["label"].reset_index(drop=True)
caption_train = df_train["caption"].reset_index(drop=True)

X_val = df_val["image_path"].reset_index(drop=True)
y_val = df_val["label"].reset_index(drop=True)
caption_val = df_val["caption"].reset_index(drop=True)

X_test = df_test["image_path"].reset_index(drop=True)
y_test = df_test["label"].reset_index(drop=True)
caption_test = df_test["caption"].reset_index(drop=True)

# Duplicate report. The previous version printed len(df) - len(df), i.e. always 0.
# Rows are only counted here; df itself is left untouched so the splits above
# remain valid.
n_rows = len(df)
n_unique_rows = len(df.drop_duplicates())
print("Lignes :", n_rows)
print("Lignes uniques :", n_unique_rows)
print("Doublons détectés :", n_rows - n_unique_rows)


### TRAINING CNN SMALLBERT

In [None]:
# Build the two classifiers whose checkpoints are loaded below.
cnn_basic = CNNBasic(num_classe=4)

# Weights are loaded with strict=True further down, so these hyper-parameters
# presumably have to match the saved SmallBERT checkpoint — TODO confirm.
text_cls = SmallBERTPourClassification(
    vocab_size=1000,      # must match the token ids (same value as the tokenizer's num_words)
    sequence_length=128,   # must match the sequence length of the loaded model
    embed_dim=128,
    num_heads=4,
    ff_dim=256,
    num_layers=2,
    num_classes=4
)

# Restore the trained classifier weights from disk.
cnn_basic = load_model_weights(cnn_basic, "../models/best-model-ccnbasic.pth", strict=True)
text_cls  = load_model_weights(text_cls,  "../models/best-model-smallbert.pth", strict=True)

# Wrap each classifier as an encoder with a 256-d projection, then pair them.
img_encoder = ImageEncoderFromCNNBasic(cnn_basic, embed_dim=256, proj_dim=256)
txt_encoder = TextEncoderFromSmallBERT(text_cls, proj_dim=256)

clip_model = CLIPLikeModel(img_encoder, txt_encoder)


In [None]:

# Keras word-level tokenizer limited to 1000 ids, matching vocab_size below.
# NOTE(review): the vocabulary is fit on every caption in df (train, val and
# test). Confirm this is how the tokenizer of the SmallBERT checkpoint was built.
# NOTE(review): "<OVV>" is probably a typo for "<OOV>"; left as-is because the
# string is part of the tokenizer's vocabulary.
tokenizer = Tokenizer(num_words=1000, oov_token="<OVV>")
tokenizer.fit_on_texts(df["caption"])

# Train / validation datasets in Keras text mode, captions capped at 64 tokens.
train_ds = CLIPDataset(df_train, base_dir=base_dir, transform=transform,
                       keras_tokenizer=tokenizer, max_seq_len=64,
                       vocab_size=1000, text_mode="keras")

val_ds = CLIPDataset(df_val, base_dir=base_dir, transform=transform,
                     keras_tokenizer=tokenizer, max_seq_len=64,
                     vocab_size=1000, text_mode="keras")

# Only the training loader is shuffled; both use the Keras-specific collate function.
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4,
                          pin_memory=True, collate_fn=clip_collate_keras)

val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=4,
                        pin_memory=True, collate_fn=clip_collate_keras)


In [None]:
# Fresh encoder wrappers and CLIP model for this training run. They replace the
# objects of the same names built right after the checkpoints were loaded.
img_encoder = ImageEncoderFromCNNBasic(cnn_basic, embed_dim=256, proj_dim=256)
txt_encoder = TextEncoderFromSmallBERT(text_cls, proj_dim=256)
clip_model = CLIPLikeModel(img_encoder, txt_encoder)

def set_trainable(m, flag: bool):
    """Enable (``flag=True``) or disable gradient tracking on every parameter of ``m``."""
    for param in m.parameters():
        param.requires_grad_(flag)

# Freeze the CNN feature extractor and the SmallBERT encoder; the optimizer
# below only receives the parameters that still require gradients.
set_trainable(clip_model.image_encoder.features, False)
set_trainable(clip_model.text_encoder.encoder, False)

device = "cuda" if torch.cuda.is_available() else "cpu"

trainable_params = [p for p in clip_model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-3, weight_decay=1e-4)

fit(clip_model, train_loader, val_loader, optimizer, epochs=3, device=device)

# Embed the validation split once; both retrieval demos below query this index.
index_val = build_clip_index(clip_model, val_loader, device=device)

res = show_retrieval_from_text(
    clip_model,
    index_val,
    query_text="the man playing with a ball",
    tokenizer=tokenizer,
    max_seq_len=64,
    vocab_size=1000,
    oov_id=1,
    topk=5,
    device=device,
    show_images=True,
)

# Image -> caption retrieval, using validation sample 121 as the query.
query_path = str(val_ds._get_img_path_from_idx(121))
res = show_retrieval_from_image(
    clip_model,
    index_val,
    query_image=query_path,
    transform=transform,
    topk=5,
    device=device,
    show=True,
)


### CLIP DISTILBERT CNN

In [None]:
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification


# Pretrained DistilBERT with a 4-label classification head; the fine-tuned
# weights are then restored from the local checkpoint.
MODEL_NAME = "distilbert-base-uncased"  
distilbert_cls = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=4)
distilbert_cls = load_hf_classifier_checkpoint(distilbert_cls, "../models/best-distilbert.pth")

# Tokenizer that belongs to the same pretrained model.
hf_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# NOTE: train_ds / val_ds / train_loader / val_loader are rebound here to the
# HF-tokenised versions; the Keras-tokenised objects of the same names from the
# SmallBERT section are no longer reachable after this cell has run.
train_ds = CLIPDataset(df_train, base_dir=base_dir, transform=transform,
                       hf_tokenizer=hf_tokenizer, max_seq_len=128,
                       text_mode="hf")

val_ds = CLIPDataset(df_val, base_dir=base_dir, transform=transform,
                     hf_tokenizer=hf_tokenizer, max_seq_len=128,
                     text_mode="hf")

# Only the training loader is shuffled.
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True,
                          num_workers=4, pin_memory=True,
                          collate_fn=clip_collate)

val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                        num_workers=4, pin_memory=True,
                        collate_fn=clip_collate)

In [None]:
# Image encoder: same CNNBasic wrapper as in the SmallBERT section.
img_encoder = ImageEncoderFromCNNBasic(cnn_basic, embed_dim=256, proj_dim=256)

# Text encoder: DistilBERT. The loaders of this section are HF-tokenised and the
# code below freezes `text_encoder.backbone`, so the DistilBERT wrapper is the
# one that fits here (the SmallBERT wrapper is used with `.encoder` elsewhere
# in this notebook).
txt_encoder = TextEncoderFromDistilBERT(distilbert_cls, proj_dim=256)

clip_model = CLIPLikeModelHF(img_encoder, txt_encoder, init_temp=0.07)

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model.to(device)


def set_trainable(m, flag: bool):
    """Enable (``flag=True``) or disable gradient tracking on every parameter of ``m``."""
    for p in m.parameters():
        p.requires_grad = flag


# Stage 1: freeze the CNN features and the DistilBERT backbone, optimise the rest.
set_trainable(clip_model.image_encoder.features, False)
set_trainable(clip_model.text_encoder.backbone, False)

optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, clip_model.parameters()),
                              lr=1e-3, weight_decay=1e-4)

# NOTE(review): no HF fit call (train_step_hf / val_step_hf) is made in this
# cell, so the stage-1 optimizer above is discarded unused and the model is
# never trained before the index is built. Confirm whether the training calls
# were removed on purpose.

# Stage 2: unfreeze everything and optimise all parameters with a lower learning rate.
set_trainable(clip_model.image_encoder.features, True)
set_trainable(clip_model.text_encoder.backbone, True)

optimizer = torch.optim.AdamW(clip_model.parameters(), lr=1e-4, weight_decay=1e-4)


In [None]:

# Embed the validation split once; both retrieval demos below query this index.
index_val = build_clip_index_hf(clip_model, val_loader, device=device)

res = show_retrieval_from_text_hf(
    clip_model, index_val,
    query_text="a man riding a bike",
    tokenizer=hf_tokenizer,
    max_len=128,
    topk=5,
    device=device
)

# Pass the path, not an already opened PIL image: the helper then opens the
# file and converts it to RGB itself before applying `transform`.
res = show_retrieval_from_image_hf(
    clip_model, index_val,
    query_image=str(val_ds._get_img_path_from_idx(121)),
    transform=transform,
    topk=5,
    device=device
)

### TRAINING SMALLBERT RESNET

In [None]:
import os

# Kept from the original cell; now set at the top, before any loader is iterated.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Token length shared by the SmallBERT model, the datasets and the text query.
SMALLBERT_SEQ_LEN = 128

# Keras-tokenised datasets with the ResNet image transform. text_mode, vocab_size
# and the Keras collate function are given explicitly, exactly as in the
# CNN + SmallBERT section, since `fit` / `build_clip_index` are the Keras-path helpers.
train_ds_resnet = CLIPDataset(df_train, base_dir=base_dir, transform=transform_resnet,
                              keras_tokenizer=tokenizer, max_seq_len=SMALLBERT_SEQ_LEN,
                              vocab_size=1000, text_mode="keras")

val_ds_resnet = CLIPDataset(df_val, base_dir=base_dir, transform=transform_resnet,
                            keras_tokenizer=tokenizer, max_seq_len=SMALLBERT_SEQ_LEN,
                            vocab_size=1000, text_mode="keras")

train_loader_resnet = DataLoader(train_ds_resnet, batch_size=64, shuffle=True,
                                 num_workers=4, pin_memory=True,
                                 collate_fn=clip_collate_keras)

val_loader_resnet = DataLoader(val_ds_resnet, batch_size=64, shuffle=False,
                               num_workers=4, pin_memory=True,
                               collate_fn=clip_collate_keras)

# Image side: fine-tuned ResNet-18 classifier wrapped as an encoder.
resnet_cls = load_best_resnet18(
    num_classes=4,
    ckpt_path="../notebooks/best-model-resnet.pth",
    device=device
).to(device)

img_encoder = ImageEncoderFromResNet18(resnet_cls, proj_dim=256, train_backbone=True)

# Text side: fresh SmallBERT classifier. Its checkpoint must be loaded before
# the encoder is frozen below; otherwise the frozen encoder would keep its
# random initial weights.
text_cls = SmallBERTPourClassification(
    vocab_size=1000,
    sequence_length=SMALLBERT_SEQ_LEN,
    embed_dim=128,
    num_heads=4,
    ff_dim=256,
    num_layers=2,
    num_classes=4
)
text_cls = load_model_weights(text_cls, "../models/best-model-smallbert.pth", strict=True)
txt_encoder = TextEncoderFromSmallBERT(text_cls, proj_dim=256)

clip_model = CLIPLikeModel(img_encoder, txt_encoder).to(device)


def set_trainable(m, flag: bool):
    """Enable (``flag=True``) or disable gradient tracking on every parameter of ``m``."""
    for p in m.parameters():
        p.requires_grad = flag


# Keep the pretrained SmallBERT encoder fixed; the optimizer only receives the
# parameters that still require gradients.
set_trainable(clip_model.text_encoder.encoder, False)

optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, clip_model.parameters()),
    lr=1e-3, weight_decay=1e-4
)

fit(clip_model, train_loader_resnet, val_loader_resnet, optimizer, epochs=5, device=device)
index_val = build_clip_index(clip_model, val_loader_resnet, device=device)

# The query is padded to the same length as the indexed captions.
res = show_retrieval_from_text(
    clip_model, index_val,
    query_text="the man playing with a ball",
    tokenizer=tokenizer,
    max_seq_len=SMALLBERT_SEQ_LEN, vocab_size=1000, oov_id=1,
    topk=5, device=device, show_images=True
)

res = show_retrieval_from_image(
    clip_model, index_val,
    query_image=str(val_ds_resnet._get_img_path_from_idx(121)),
    transform=transform_resnet,
    topk=5, device=device, show=True
)


                        

### TRAINING DISTILBERT RESNET


In [None]:
import os

# Kept from the original cell; now set at the top, before any loader is iterated.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# DistilBERT consumes HF-tokenised batches, so the ResNet datasets are built
# with the HF tokenizer / text_mode="hf" (as in the CNN + DistilBERT section)
# together with the ResNet image transform.
train_ds_resnet = CLIPDataset(df_train, base_dir=base_dir, transform=transform_resnet,
                              hf_tokenizer=hf_tokenizer, max_seq_len=128,
                              text_mode="hf")

val_ds_resnet = CLIPDataset(df_val, base_dir=base_dir, transform=transform_resnet,
                            hf_tokenizer=hf_tokenizer, max_seq_len=128,
                            text_mode="hf")

train_loader_resnet = DataLoader(train_ds_resnet, batch_size=64, shuffle=True,
                                 num_workers=4, pin_memory=True,
                                 collate_fn=clip_collate)

val_loader_resnet = DataLoader(val_ds_resnet, batch_size=64, shuffle=False,
                               num_workers=4, pin_memory=True,
                               collate_fn=clip_collate)

# Image side: fine-tuned ResNet-18 classifier wrapped as an encoder.
resnet_cls = load_best_resnet18(
    num_classes=4,
    ckpt_path="../notebooks/best-model-resnet.pth",
    device=device
).to(device)

img_encoder = ImageEncoderFromResNet18(resnet_cls, proj_dim=256, train_backbone=True)
txt_encoder = TextEncoderFromDistilBERT(distilbert_cls, proj_dim=256)

# HF variant of the CLIP model. The original cell overwrote it right away with
# the non-HF CLIPLikeModel and then trained that one with fit_hf.
clip_model = CLIPLikeModelHF(img_encoder, txt_encoder, init_temp=0.07).to(device)


def set_trainable(m, flag: bool):
    """Enable (``flag=True``) or disable gradient tracking on every parameter of ``m``."""
    for p in m.parameters():
        p.requires_grad = flag


# Both backbones are fine-tuned in this run.
set_trainable(clip_model.image_encoder.backbone, True)
set_trainable(clip_model.text_encoder.backbone, True)

optimizer = torch.optim.AdamW(clip_model.parameters(), lr=1e-5, weight_decay=1e-4)

# Train and index with the ResNet loaders. The original used train_loader /
# val_loader, which are built with the CNN image transform.
fit_hf(clip_model, train_loader_resnet, val_loader_resnet, optimizer, epochs=10, device=device)
index_val = build_clip_index_hf(clip_model, val_loader_resnet, device=device)

# The Keras-path retrieval calls that used to run here were removed: they
# queried the index left over from the previous model, with the Keras tokenizer.
res = show_retrieval_from_text_hf(
    clip_model, index_val,
    query_text="a man riding a bike",
    tokenizer=hf_tokenizer,
    max_len=128,
    topk=5,
    device=device
)

# Pass the path so the helper opens the file and converts it to RGB itself.
res = show_retrieval_from_image_hf(
    clip_model, index_val,
    query_image=str(val_ds_resnet._get_img_path_from_idx(121)),
    transform=transform_resnet,
    topk=5,
    device=device
)