# OCR + Détection de fraude sur documents d'identité

Pipeline complet : prétraitement (alignement), OCR multilingue, classification image, features OCR tabulaires, *late fusion*, inférence test (JSON champs + CSV classes).

In [1]:
%load_ext autoreload
%autoreload 2

## 0) Installation / utils
Exécuter ce bloc si nécessaire dans votre environnement.

In [2]:
# %pip install torch torchvision torchaudio pytorch-lightning==2.4.0 timm==1.0.9 albumentations opencv-python-headless shapely rapidfuzz python-Levenshtein paddlepaddle-gpu paddleocr lightgbm scikit-learn pandas numpy

In [3]:
# %pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [4]:
from tqdm.notebook import tqdm

utils

In [5]:
import os, json, random, numpy as np, torch

def seed_everything(seed=42):
    import random, numpy as np, torch
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def find_images(root, exts=(".jpg",".jpeg",".png",".bmp",".tif",".tiff")):
    out=[]
    for dp,_,files in os.walk(root):
        for f in files:
            if f.lower().endswith(exts):
                out.append(os.path.join(dp,f))
    return out

def save_json(obj, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as w:
        json.dump(obj, w, ensure_ascii=False, indent=2)


constant

In [6]:
CLASSES = ["normal","forgery_1","forgery_2","forgery_3","forgery_4"]
COUNTRIES = ["spain","estonia","russia","arizona"]

data

In [7]:

import os, glob, random, cv2, numpy as np
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
# from .constants import CLASSES

# def make_transforms(train=True, size=768):
#     aug = [A.LongestMaxSize(size), A.PadIfNeeded(size,size, border_mode=cv2.BORDER_CONSTANT, value=(255,255,255))]
#     if train:
#         aug += [
#             A.ImageCompression(quality_lower=40,quality_upper=90,p=0.5),
#             A.MotionBlur(3,p=0.2), A.GaussianBlur(3,p=0.2),
#             A.RandomBrightnessContrast(0.2,0.2,p=0.5),
#             A.Rotate(limit=7, border_mode=cv2.BORDER_CONSTANT, value=(255,255,255),p=0.5),
#         ]
#     aug += [A.Normalize(), ToTensorV2()]
#     return A.Compose(aug)


import inspect
import albumentations as A
import cv2
from albumentations.pytorch import ToTensorV2

def _init_with_supported_kwargs(cls, **kwargs):
    """Construit un transform Albumentations en ne gardant que les kwargs supportés."""
    sig = inspect.signature(cls.__init__)
    allowed = set(sig.parameters.keys())
    # __init__(self, ...) -> enlever 'self'
    allowed.discard('self')
    filtered = {k: v for k, v in kwargs.items() if k in allowed}
    return cls(**filtered)

def _pad_if_needed(size):
    # Essaie d’abord border_value (nouvelles versions), sinon value (anciennes).
    # On filtre automatiquement selon la signature présente en 2.0.8.
    return _init_with_supported_kwargs(
        A.PadIfNeeded,
        min_height=size, min_width=size,
        border_mode=cv2.BORDER_CONSTANT,
        border_value=(255, 255, 255),
        value=(255, 255, 255),         # au cas où ta build attend 'value'
        position='center'
    )

def _rotate(limit):
    return _init_with_supported_kwargs(
        A.Rotate,
        limit=limit,
        border_mode=cv2.BORDER_CONSTANT,
        border_value=(255, 255, 255),
        value=(255, 255, 255),
        p=0.5
    )

def _image_compression():
    return _init_with_supported_kwargs(
        A.ImageCompression,
        quality_range=(40, 90),
        quality_lower=40, quality_upper=90,
        p=0.5
    )

def make_transforms(train=True, size=768):
    aug = [
        _init_with_supported_kwargs(A.LongestMaxSize, max_size=size, p=1.0),
        _pad_if_needed(size),
    ]
    if train:
        aug += [
            _image_compression(),
            _init_with_supported_kwargs(A.MotionBlur, blur_limit=3, p=0.2),
            _init_with_supported_kwargs(A.GaussianBlur, blur_limit=3, p=0.2),
            _init_with_supported_kwargs(A.RandomBrightnessContrast,
                                        brightness_limit=0.2, contrast_limit=0.2, p=0.5),
            _rotate(limit=7),
        ]
    aug += [A.Normalize(), ToTensorV2()]
    return A.Compose(aug)



def infer_country_from_path(p):
    p=p.lower()
    if "estonia" in p or "ee" in p: return "ee"
    if "spain" in p or "es" in p: return "es"
    if "russia" in p or "ru" in p: return "ru"
    if "arizona" in p or "az" in p or "usa" in p: return "az"
    return "unknown"

class IdDocsDataset(Dataset):
    def __init__(self, root, train=True, size=768):
        self.root = root
        self.items=[]
        for country in os.listdir(root):
            cdir = os.path.join(root, country)
            if not os.path.isdir(cdir): continue
            for cls in CLASSES:
                img_dir = os.path.join(cdir, cls if cls!="normal" else "normal")
                if not os.path.isdir(img_dir): continue
                for ext in ("*.jpg","*.jpeg","*.png","*.bmp","*.tif","*.tiff"):
                    for imgp in glob.glob(os.path.join(img_dir,ext)):
                        self.items.append({"img": imgp, "label": cls, "country": infer_country_from_path(imgp)})
        random.shuffle(self.items)
        self.transforms = make_transforms(train=train, size=size)
        self.class_to_idx = {c:i for i,c in enumerate(CLASSES)}

    def __len__(self): return len(self.items)

    def __getitem__(self, i):
        x = self.items[i]
        img = cv2.imread(x["img"]); 
        if img is None:
            raise FileNotFoundError(f"Impossible de lire: {x['img']}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        t = self.transforms(image=img)["image"]
        y = self.class_to_idx[x["label"]]
        return {"image": t, "label": y, "country": x["country"], "path": x["img"]}

  from .autonotebook import tqdm as notebook_tqdm


align

In [8]:

import cv2, numpy as np

def compute_homography(img, template):
    orb = cv2.ORB_create(nfeatures=3000)
    kp1, des1 = orb.detectAndCompute(img,None)
    kp2, des2 = orb.detectAndCompute(template,None)
    if des1 is None or des2 is None: return None
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
    matches = bf.knnMatch(des1, des2, k=2)
    good=[]
    for m,n in matches:
        if m.distance < 0.75*n.distance:
            good.append(m)
    if len(good)<12: return None
    src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1,1,2)
    dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1,1,2)
    H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC,5.0)
    return H

def warp_to_template(img, template):
    H = compute_homography(img, template)
    if H is None: 
        return img, None
    h,w = template.shape[:2]
    warped = cv2.warpPerspective(img, H, (w,h), borderValue=(255,255,255))
    return warped, H


features

In [9]:

import numpy as np, re

def iou(boxA, boxB):
    xA=max(boxA[0], boxB[0]); yA=max(boxA[1], boxB[1])
    xB=min(boxA[0]+boxA[2], boxB[0]+boxB[2]); yB=min(boxA[1]+boxA[3], boxB[1]+boxB[3])
    inter=max(0,xB-xA)*max(0,yB-yA)
    union=boxA[2]*boxA[3]+boxB[2]*boxB[3]-inter
    return inter/union if union>0 else 0.0

def mrz_checksum(s):
    weights=[7,3,1]; total=0
    def val(c):
        if c.isdigit(): return int(c)
        if 'A'<=c<='Z': return ord(c)-55
        return 0
    for i,c in enumerate(s):
        total += val(c) * weights[i%3]
    return str(total % 10)

def build_features(ocr_items, expected_fields):
    feats={}
    ious=[]; confs=[]; miss=0; regex_ok=0; lens=[]
    for f in expected_fields:
        eb=f["bbox"]
        best=None; best_d=1e18
        ex=(eb[0]+eb[2]/2, eb[1]+eb[3]/2)
        for it in ocr_items:
            xs=[p[0] for p in it["box"]]; ys=[p[1] for p in it["box"]]
            cx,cy=sum(xs)/4,sum(ys)/4
            d=(cx-ex[0])**2+(cy-ex[1])**2
            if d<best_d: best=it; best_d=d
        if best is None: miss+=1; continue
        xs=[p[0] for p in best["box"]]; ys=[p[1] for p in best["box"]]
        bb=[min(xs),min(ys), max(xs)-min(xs), max(ys)-min(ys)]
        ious.append(iou(eb,bb))
        confs.append(best["conf"])
        txt=best["text"]
        lens.append(min(len(txt), 64))
    feats["iou_mean"]=float(np.mean(ious)) if ious else 0.0
    feats["conf_mean"]=float(np.mean(confs)) if confs else 0.0
    feats["missing_ratio"]=float(miss/max(1,len(expected_fields)))
    feats["text_len_mean"]=float(np.mean(lens)) if lens else 0.0
    return feats


fuse

In [10]:
import numpy as np
def fuse_probs(p_img, p_tab, alpha=0.6):
    return alpha*np.array(p_img)+(1-alpha)*np.array(p_tab)

model_img

In [11]:
import torch, timm, pytorch_lightning as pl
import torch.nn as nn
import torchmetrics

NUM_CLASSES=5

class ImgClassifier(pl.LightningModule):
    def __init__(self, lr=1e-4, wd=1e-4, class_weights=None, model_name="tf_efficientnet_b0"):
        super().__init__()
        self.save_hyperparameters()
        self.net = timm.create_model(model_name, pretrained=True, num_classes=NUM_CLASSES)
        self.crit = nn.CrossEntropyLoss(weight=class_weights)
        self.f1 = torchmetrics.F1Score(task="multiclass", num_classes=NUM_CLASSES, average="macro")

    def forward(self,x): return self.net(x)

    def step(self, batch, stage):
        y = batch["label"]; yhat = self.forward(batch["image"])
        loss = self.crit(yhat, y)
        preds = yhat.argmax(1)
        f1 = self.f1(preds, y)
        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_f1", f1, prog_bar=True)
        return loss

    def training_step(self,batch,_): return self.step(batch,"train")
    def validation_step(self,batch,_): return self.step(batch,"val")
    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.wd)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return {"optimizer":opt,"lr_scheduler":sch}


model_tab

In [12]:

import torch, torch.nn as nn, pytorch_lightning as pl
import torchmetrics

class TabClassifier(pl.LightningModule):
    def __init__(self, in_dim, lr=1e-3, wd=1e-4, hidden=128):
        super().__init__()
        self.save_hyperparameters()
        self.m = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(), nn.BatchNorm1d(hidden),
            nn.Linear(hidden, hidden), nn.ReLU(), nn.BatchNorm1d(hidden),
            nn.Linear(hidden, 5)
        )
        self.crit = nn.CrossEntropyLoss()
        self.f1 = torchmetrics.F1Score(task="multiclass", num_classes=5, average="macro")

    def forward(self,x): return self.m(x)
    def step(self,b,stage):
        y=b["y"]; yhat=self.forward(b["x"]); loss=self.crit(yhat,y)
        preds=yhat.argmax(1); f1=self.f1(preds,y)
        self.log(f"{stage}_loss",loss,prog_bar=True); self.log(f"{stage}_f1",f1,prog_bar=True)
        return loss
    def training_step(self,b,_): return self.step(b,"train")
    def validation_step(self,b,_): return self.step(b,"val")
    def configure_optimizers(self):
        opt=torch.optim.AdamW(self.parameters(),lr=self.hparams.lr,weight_decay=self.hparams.wd)
        return opt

ocr

In [13]:
# %pip install paddlepaddle-gpu==2.6.1 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html

In [14]:
from paddleocr import PaddleOCR

LANGS_BY_COUNTRY = {"es":["es","en"], "ee":["en"], "ru":["ru","en"], "az":["en"]}

class OCRWrapper:
    def __init__(self, country="es"):
        self.country = country
        # Note: 'multilang' suppose que les modèles multilingues sont installés
        self.ocr = PaddleOCR(use_angle_cls=True, lang='multilang', show_log=False)

    def run(self, img):
        res = self.ocr.ocr(img, cls=True)
        out=[]
        if res and len(res)>0:
            for line in res[0]:
                box = line[0]; text=line[1][0]; conf=line[1][1]
                out.append({"box":box, "text":text, "conf":float(conf)})
        return out

schema

In [15]:
# import os, json, numpy as np

# class SchemaManager:
#     def __init__(self):
#         self.templates = {}  # {country: {"template_image": ndarray or None, "fields":[{"key":..., "bbox":[x,y,w,h], "regex":..., "lang":...}] }}

#     def load_country_schema(self, country, gt_dir):
#         # Agrège les bboxes de gt/*.json -> bbox médiane par clé
#         boxes_by_key = {}
#         for fn in os.listdir(gt_dir):
#             if not fn.lower().endswith(".json"): continue
#             with open(os.path.join(gt_dir, fn), "r", encoding="utf-8") as r:
#                 obj = json.load(r)
#             for k,v in obj.items():
#                 bb = v.get("bbox") or v.get("box") or v.get("bbox_xywh")
#                 if bb is None: 
#                     continue
#                 boxes_by_key.setdefault(k, []).append(bb)
#         fields=[]
#         for k, arr in boxes_by_key.items():
#             arr_np = np.array(arr, dtype=float)
#             med = np.median(arr_np, axis=0).tolist()
#             fields.append({"key":k, "bbox": [float(med[0]), float(med[1]), float(med[2]), float(med[3])] })
#         self.templates[country] = {"template_image": None, "fields": fields}

#     def set_template_image(self, country, img):
#         self.templates.setdefault(country, {"template_image": None, "fields": []})
#         self.templates[country]["template_image"] = img

#     def fields_for_country(self, country):
#         return self.templates.get(country, {}).get("fields", [])

In [16]:
import os, json

class SchemaManager:
    """
    Schéma basé sur les CLÉS seulement (pas de bboxes).
    - fields = [{"key": <nom_du_champ>}]
    - has_bbox = False
    """
    def __init__(self):
        self.templates = {}  # {country: {"template_image": None, "fields":[{"key":...}], "has_bbox": False}}

    def load_country_schema(self, country, gt_dir):
        keys_counter = {}
        n_files = 0
        for fn in os.listdir(gt_dir):
            if not fn.lower().endswith(".json"):
                continue
            n_files += 1
            path = os.path.join(gt_dir, fn)
            try:
                with open(path, "r", encoding="utf-8") as r:
                    obj = json.load(r)
            except Exception as e:
                print(f"[WARN] JSON invalide ignoré: {path} ({e})")
                continue

            # On attend un JSON plat {key: value}; on collecte juste les clés
            if isinstance(obj, dict):
                for k in obj.keys():
                    keys_counter[k] = keys_counter.get(k, 0) + 1

        # Construit la liste ordonnée des clés les plus fréquentes
        fields = [{"key": k} for k, _ in sorted(keys_counter.items(), key=lambda kv: (-kv[1], kv[0]))]
        self.templates[country] = {
            "template_image": None,
            "fields": fields,
            "has_bbox": False
        }
        print(f"[{country}] gt lus: {n_files}, clés uniques: {len(fields)}, has_bbox=False")

    def set_template_image(self, country, img):
        self.templates.setdefault(country, {"template_image": None, "fields": [], "has_bbox": False})
        self.templates[country]["template_image"] = img

    def fields_for_country(self, country):
        return self.templates.get(country, {}).get("fields", [])

    def has_bbox(self, country):
        return False


## 0) helpers OCR pour chaque clé

In [17]:
import re
from datetime import datetime

DATE_RE = re.compile(r"\b(\d{2})[ ./-](\d{2})[ ./-](\d{2,4})\b")
DNI_RE  = re.compile(r"\b(\d{8})([A-Z])\b")
ALNUM_RE = re.compile(r"[A-Z0-9]{6,}")

# Calcul de la lettre de contrôle du DNI (Espagne)
_DNI_LETTERS = "TRWAGMYFPDXBNJZSQVHLCKE"
def dni_letter(num_8digits: str) -> str:
    try:
        n = int(num_8digits)
        return _DNI_LETTERS[n % 23]
    except:
        return ""

def _norm_date(d, prefer_format="DD/MM/YYYY"):
    m = DATE_RE.search(d.replace("\\", "/"))
    if not m:
        return ""
    dd, mm, yy = m.groups()
    if len(yy) == 2:
        yy = "20"+yy if int(yy) <= 30 else "19"+yy
    try:
        dt = datetime(int(yy), int(mm), int(dd))
        if prefer_format == "YYYY-MM-DD":
            return dt.strftime("%Y-%m-%d")
        return dt.strftime("%d/%m/%Y")
    except:
        return ""

def pick_text_for_key_from_ocr(key: str, ocr_items: list):
    k = key.lower()

    # country_code : 3 lettres (ESP, EST, RUS, USA/AZ…)
    if "country_code" in k:
        # chercher tokens de 3 lettres
        best = ""
        best_conf = -1
        for it in ocr_items:
            for tok in re.findall(r"\b[A-Z]{3}\b", it["text"].upper()):
                if it["conf"] > best_conf:
                    best, best_conf = tok, it["conf"]
        return best or "ESP"  # défaut si rien

    # dates
    if any(s in k for s in ["birthday", "issue_date", "expire_date", "fecha", "expiry", "expiration"]):
        for it in ocr_items:
            s = _norm_date(it["text"])
            if s:
                return s
        return ""

    # genre
    if "gender" in k or "sex" in k:
        for it in ocr_items:
            m = re.search(r"\b([MF])\b", it["text"].upper())
            if m:
                return m.group(1)
        return ""

    # DNI / card_num : 8 chiffres + 1 lettre
    if "card_num" in k or ("dni" in k) or ("document" in k and "num" in k):
        best = ""
        best_conf = -1
        for it in ocr_items:
            for m in DNI_RE.finditer(it["text"].replace(" ", "").upper()):
                num, letter = m.groups()
                if dni_letter(num) == letter and it["conf"] > best_conf:
                    best, best_conf = num + letter, it["conf"]
        if best:
            return best
        # fallback : meilleur alphanum
        for it in ocr_items:
            m = ALNUM_RE.search(it["text"].replace(" ", "").upper())
            if m and it["conf"] > best_conf:
                best, best_conf = m.group(0), it["conf"]
        return best

    # personal_num : identifiant alphanumérique
    if "personal_num" in k or "personal" in k:
        best = ""
        best_conf = -1
        for it in ocr_items:
            m = ALNUM_RE.search(it["text"].replace(" ", "").upper())
            if m and it["conf"] > best_conf:
                best, best_conf = m.group(0), it["conf"]
        return best

    # noms/prénoms : on prend la ligne la plus confiante en majuscules
    if any(s in k for s in ["surname", "given_name", "second_surname", "name"]):
        best = ""
        best_conf = -1
        for it in ocr_items:
            t = it["text"].strip()
            if t.upper() == t and len(t) >= 2 and it["conf"] > best_conf:
                best, best_conf = t, it["conf"]
        if best:
            return best

    # fallback général : meilleure confiance
    if ocr_items:
        it = max(ocr_items, key=lambda x: x.get("conf", 0.0))
        return it["text"]
    return ""


## 1) Imports & config

In [18]:
import os, json, cv2, numpy as np, pandas as pd, torch
from torch.utils.data import DataLoader, random_split, TensorDataset
import pytorch_lightning as pl
from sklearn.metrics import classification_report, confusion_matrix
from albumentations.pytorch import ToTensorV2
# from src.utils import seed_everything, save_json
# from src.data import IdDocsDataset, infer_country_from_path, make_transforms
# from src.align import warp_to_template
# from src.ocr import OCRWrapper
# from src.schema import SchemaManager
# from src.features import build_features
# from src.model_img import ImgClassifier
# from src.model_tab import TabClassifier
# from src.fuse import fuse_probs
# from src.constants import CLASSES

seed_everything(42)
DATA_ROOT = "data" 
TRAIN_ROOT = os.path.join(DATA_ROOT, "train")
TEST_ROOT  = os.path.join(DATA_ROOT, "test")
WEIGHTS_DIR = "../weights"
SUB_DIR = "../submissions"
os.makedirs(WEIGHTS_DIR, exist_ok=True)
os.makedirs(os.path.join(SUB_DIR, "fields_json"), exist_ok=True)


## 2) Lecture d'images (aperçu rapide)
Ici on recense les images d'entraînement et on affiche quelques exemples (optionnel).

In [19]:
ds_all = IdDocsDataset(TRAIN_ROOT, train=True, size=768)
# print(f"Total images train: {len(ds_all)}")

In [20]:
# Exemple item
ds_all[0]

{'image': tensor([[[-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
          [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
          [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
          ...,
          [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
          [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
          [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179]],
 
         [[-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
          [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
          [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
          ...,
          [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
          [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
          [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357]],
 
         [[-1.8044, -1.8044, -1.8044,  ..., -1.8044, -1.8044, -1.8044],
          [-1.8044,

In [21]:
ds_all.class_to_idx

{'normal': 0, 'forgery_1': 1, 'forgery_2': 2, 'forgery_3': 3, 'forgery_4': 4}

In [22]:
ds_all.items[0]

{'img': 'data\\train\\esp\\forgery_2\\generated.photos_v3_0044110.png',
 'label': 'forgery_2',
 'country': 'es'}

## 3) Gabarits & schémas depuis gt/
On construit un schéma par pays (bbox médiane par champ) à partir de `train/<country>/gt/*.json`. 

In [23]:
schema = SchemaManager()
for country_dir in os.listdir(TRAIN_ROOT):
    gt_dir = os.path.join(TRAIN_ROOT, country_dir, "gt")
    if os.path.isdir(gt_dir):
        schema.load_country_schema(country_dir, gt_dir)
print("Pays chargés:", list(schema.templates.keys()))
# Optionnel: charger des templates images par pays si disponibles (ex: une image normale de référence)

[arizona_dl] gt lus: 500, clés uniques: 12, has_bbox=False
[esp] gt lus: 500, clés uniques: 10, has_bbox=False
[est] gt lus: 500, clés uniques: 13, has_bbox=False
[rus] gt lus: 500, clés uniques: 10, has_bbox=False
Pays chargés: ['arizona_dl', 'esp', 'est', 'rus']


## 4) DataLoaders PyTorch (train/val)

In [24]:
ds = ds_all
n = len(ds); n_val = max(1, int(0.15*n)); n_train = max(1, n-n_val)
train_ds, val_ds = random_split(ds, [n_train, n_val])
# train_dl = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=2, pin_memory=True)
# val_dl   = DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=2, pin_memory=True)
use_cuda = torch.cuda.is_available()

import torch
print("CUDA dispo:", torch.cuda.is_available())
print("torch:", torch.__version__)
print("torch.version.cuda:", torch.version.cuda)
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True,
                      num_workers=4 if use_cuda else 0,
                      pin_memory=use_cuda, persistent_workers=use_cuda)
val_dl   = DataLoader(val_ds, batch_size=16, shuffle=False,
                      num_workers=4 if use_cuda else 0,
                      pin_memory=use_cuda, persistent_workers=use_cuda)
n, n_train, n_val

CUDA dispo: True
torch: 2.5.1+cu121
torch.version.cuda: 12.1
GPU: NVIDIA GeForce RTX 2050


(10000, 8500, 1500)

## 5) Entraînement modèle image (Lightning + timm)

In [25]:
import torch, os
torch.set_float32_matmul_precision("high")     # kernels matmul plus rapides
torch.backends.cudnn.benchmark = True          # profils optimisés si tailles fixes
os.environ["CUDA_VISIBLE_DEVICES"] = "0"       # au cas où (1 seul GPU)

In [None]:
# ==== Modèle ====
# (Optionnel) poids de classes si dataset déséquilibré :
# class_weights = torch.tensor([1.,1.,1.,1.,1.], dtype=torch.float32, device="cuda")
class_weights = None

model_img = ImgClassifier(
    lr=1e-4, wd=1e-4,
    class_weights=class_weights,
    model_name="tf_efficientnet_b0"
)
# petit gain mémoire/perf CNN
model_img = model_img.to(memory_format=torch.channels_last)

# ==== Callbacks & Trainer ====
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

class PrintMetricsCallback(Callback):
    def on_validation_end(self, trainer, pl_module):
        # récupère les derniers logs connus
        logs = trainer.callback_metrics
        f1 = logs.get("val_f1")
        loss = logs.get("val_loss")
        epoch = trainer.current_epoch
        if f1 is not None and loss is not None:
            print(f"\n[Epoch {epoch}] val_loss={loss:.4f} | val_f1={f1:.4f}")

ckpt = ModelCheckpoint(
    monitor="val_f1", mode="max", save_top_k=1,
    dirpath=WEIGHTS_DIR, filename="best_img"
)
es = EarlyStopping(monitor="val_f1", mode="max", patience=5)

trainer = pl.Trainer(
    max_epochs=12,                 # un peu plus que 10 pour le CosineAnnealingLR du modèle
    callbacks=[ckpt, es],
    accelerator="gpu", devices=1,  # force l'usage GPU (évite le fallback CPU silencieux)
    precision="16-mixed",          # AMP
    gradient_clip_val=1.0,         # stable sur lots un peu gros
    accumulate_grad_batches=1,     # ↑ si tu veux un batch effectif plus grand
    log_every_n_steps=20
)

trainer.fit(model_img, train_dl, val_dl)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
d:\Projet\anip-challenge-ocr\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type              | Params | Mode 
---------------------------------------------------
0 | net  | EfficientNet      | 4.0 M  | train
1 | crit | CrossEntropyLoss  | 0      | train
2 | f1   | MulticlassF1Score | 0      | train
------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

In [None]:
# from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
# class_weights=None  # éventuellement torch.tensor([...], device='cuda')
# model_img = ImgClassifier(lr=1e-4, wd=1e-4, class_weights=class_weights, model_name="tf_efficientnet_b0")
# ckpt = ModelCheckpoint(monitor="val_f1", mode="max", save_top_k=1, dirpath=WEIGHTS_DIR, filename="best_img")
# es = EarlyStopping(monitor="val_f1", mode="max", patience=5)
# trainer = pl.Trainer(max_epochs=10, callbacks=[ckpt,es], accelerator="auto", devices="auto", precision="16-mixed")
# trainer.fit(model_img, train_dl, val_dl)

## 6) Construction features OCR sur l'ensemble de validation
On aligne (si template dispo), on applique PaddleOCR, on agrège des features par image.

In [None]:
import numpy as np
rows=[]
for b in val_dl:
    ims=b["image"]; ys=b["label"]; countries=b["country"]; paths=b["path"]
    for img_t, y, country, p in zip(ims, ys, countries, paths):
        img = (img_t.permute(1,2,0).numpy()*255).astype(np.uint8)
        template = None
        if country in schema.templates and schema.templates[country]["template_image"] is not None:
            warped,_ = warp_to_template(img, schema.templates[country]["template_image"])
        else:
            warped = img
        ocr = OCRWrapper(country).run(warped)
        expected = schema.fields_for_country(country)
        feats = build_features(ocr, expected)
        feats["y"]=int(y); feats["image_id"]=os.path.basename(p)
        rows.append(feats)
import pandas as pd
df_feats = pd.DataFrame(rows).fillna(0.0)
feat_cols = [c for c in df_feats.columns if c not in ["y","image_id"]]
df_feats.head()

## 7) Modèle tabulaire (MLP) sur features OCR

In [None]:
import torch
from src.model_tab import TabClassifier
from torch.utils.data import TensorDataset, DataLoader, random_split

X = torch.tensor(df_feats[feat_cols].values, dtype=torch.float32)
Y = torch.tensor(df_feats["y"].values, dtype=torch.long)
ds_tab = TensorDataset(X,Y)
n=len(ds_tab); n_val2=max(1,int(0.2*n)); n_tr=n-n_val2
train_tab, val_tab = random_split(ds_tab,[n_tr,n_val2])
train_tab_dl=DataLoader(train_tab,batch_size=64,shuffle=True)
val_tab_dl=DataLoader(val_tab,batch_size=64,shuffle=False)

tab = TabClassifier(in_dim=len(feat_cols), lr=1e-3, wd=1e-4, hidden=128)
ckpt2 = pl.callbacks.ModelCheckpoint(monitor="val_f1", mode="max", save_top_k=1, dirpath=WEIGHTS_DIR, filename="best_tab")
trainer2 = pl.Trainer(max_epochs=20, callbacks=[ckpt2], accelerator="auto", devices="auto", precision="16-mixed")
trainer2.fit(tab, train_tab_dl, val_tab_dl)

## 8) Évaluation sur validation & fusion

In [None]:
import numpy as np
y_true=[]; y_pred_img=[]
model_img.eval()
with torch.no_grad():
    for b in val_dl:
        logits = model_img(b["image"]).cpu()
        y_true += b["label"].cpu().tolist()
        y_pred_img += logits.argmax(1).tolist()
print("Image-only:\n", classification_report(y_true, y_pred_img, target_names=CLASSES, digits=3))

# Tab-only (évaluation rapide sur df_feats déjà créé)
tab.eval()
with torch.no_grad():
    probs_tab = torch.softmax(tab(torch.tensor(df_feats[feat_cols].values, dtype=torch.float32)), dim=1).cpu().numpy()
y_pred_tab = probs_tab.argmax(1)
print("Tab-only:\n", classification_report(df_feats["y"].values, y_pred_tab, target_names=CLASSES, digits=3))

# Fusion naïve: si on veut fusionner sur un sous-ensemble commun, il faudrait aligner les indices; ici démo avec tab-only.
print("(Astuce) Pour une vraie fusion sur val: collecter proba image et proba tab sur le même split et combiner avec fuse_probs.")

## 9) Inférence sur test : JSON champs + CSV classes

In [None]:
# --- À ajouter AVANT la boucle (une seule fois) ---
from paddleocr import PaddleOCR
ocr_engine = PaddleOCR(use_angle_cls=True, lang='multilang', show_log=False)

# Sécurise le modèle tabulaire : s'il n'existe pas, on fait image-only
if 'tab' in globals():
    tab.eval()
else:
    tab = None

In [None]:
# --- À COLLER EN REMPLACEMENT de ton bloc “OCR + JSON champs … -> fin” ---

sub_rows = []
model_img.eval()

for imgp in test_images:
    image_id = os.path.basename(imgp)
    country = infer_country_from_path(imgp)

    # lecture + alignement (si tu as un template)
    img_bgr = cv2.imread(imgp)
    if img_bgr is None:
        print(f"[WARN] Image illisible: {imgp}"); 
        continue
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    template = None
    if country in schema.templates and schema.templates[country].get("template_image") is not None:
        warped_rgb, _ = warp_to_template(img_rgb, schema.templates[country]["template_image"])
    else:
        warped_rgb = img_rgb

    # ---------- OCR + JSON plat (mêmes clés que gt) ----------
    # PaddleOCR attend du BGR → on reconvertit
    ocr_res = ocr_engine.ocr(cv2.cvtColor(warped_rgb, cv2.COLOR_RGB2BGR), cls=True)
    ocr_items = []
    if ocr_res and len(ocr_res) > 0 and ocr_res[0] is not None:
        for line in ocr_res[0]:
            box = line[0]
            text = line[1][0]
            conf = float(line[1][1])
            ocr_items.append({"box": box, "text": text, "conf": conf})

    expected = schema.fields_for_country(country)  # [{"key": ...}, ...]
    flat_fields = {}
    for f in expected:
        k = f["key"]
        flat_fields[k] = pick_text_for_key_from_ocr(k, ocr_items)

    # Sauvegarde JSON plat (clé -> valeur), pas de bbox
    save_json(
        flat_fields, 
        os.path.join(SUB_DIR, "fields_json", image_id.rsplit(".", 1)[0] + ".json")
    )

    # ---------- Probas image ----------
    ti = t_infer(image=warped_rgb)["image"].unsqueeze(0)
    with torch.no_grad():
        p_img = torch.softmax(model_img(ti), dim=1).cpu().numpy()[0]

    # ---------- Probas tab (facultatif) ----------
    if tab is not None:
        # features texte simples alignées avec l'entraînement tabulaire
        confs = [it.get("conf", 0.0) for it in ocr_items]
        text_len = [min(len(it.get("text","")), 64) for it in ocr_items]
        feats = {
            "conf_mean": float(np.mean(confs)) if confs else 0.0,
            "conf_max": float(np.max(confs)) if confs else 0.0,
            "ocr_lines": float(len(ocr_items)),
            "text_len_mean": float(np.mean(text_len)) if text_len else 0.0,
        }
        # Si feat_cols n'existe pas, on le définit ici dans le même ordre
        if 'feat_cols' not in globals():
            feat_cols = ["conf_mean", "conf_max", "ocr_lines", "text_len_mean"]

        feat_vec = np.array([[feats.get(c, 0.0) for c in feat_cols]], dtype=np.float32)
        with torch.no_grad():
            p_tab = torch.softmax(tab(torch.tensor(feat_vec)), dim=1).cpu().numpy()[0]

        # Fusion tardive
        alpha = 0.6
        p = alpha * p_img + (1.0 - alpha) * p_tab
    else:
        p = p_img

    pred = int(np.argmax(p))
    sub_rows.append({"image_id": image_id, "class_pred": CLASSES[pred]})

# ---------- CSV final ----------
pd.DataFrame(sub_rows).to_csv(os.path.join(SUB_DIR, "submission.csv"), index=False)
print("Fichiers écrits dans:", SUB_DIR)


In [None]:
# import glob
# from albumentations import Compose, LongestMaxSize, PadIfNeeded, Normalize
# from albumentations.pytorch import ToTensorV2

# t_infer = Compose([
#     LongestMaxSize(768),
#     PadIfNeeded(768,768, border_mode=cv2.BORDER_CONSTANT, value=(255,255,255)),
#     Normalize(),
#     ToTensorV2()
# ])

# sub_rows=[]
# model_img.eval(); tab.eval()

# test_images = []
# for ext in ("*.jpg","*.jpeg","*.png","*.bmp"):
#     test_images += glob.glob(os.path.join(TEST_ROOT, "**", ext), recursive=True)
# print(f"Total images test: {len(test_images)}")

# for imgp in test_images:
#     image_id=os.path.basename(imgp)
#     country = infer_country_from_path(imgp)
#     img=cv2.cvtColor(cv2.imread(imgp), cv2.COLOR_BGR2RGB)
#     template=None
#     if country in schema.templates and schema.templates[country]["template_image"] is not None:
#         warped,_=warp_to_template(img, schema.templates[country]["template_image"])
#     else:
#         warped=img

#     # OCR + JSON champs (structure minimale: {key: {text, bbox}})
#     ocr_items=OCRWrapper(country).run(warped)
#     expected = schema.fields_for_country(country)
#     fields={}
#     for f in expected:
#         fields[f["key"]] = {"text":"", "bbox": f["bbox"]}
#     # Simple appariement: nearest center (démo)
#     for f in expected:
#         ex=(f["bbox"][0]+f["bbox"][2]/2, f["bbox"][1]+f["bbox"][3]/2)
#         best=None; best_d=1e18
#         for it in ocr_items:
#             xs=[p[0] for p in it["box"]]; ys=[p[1] for p in it["box"]]
#             cx,cy=sum(xs)/4,sum(ys)/4
#             d=(cx-ex[0])**2+(cy-ex[1])**2
#             if d<best_d: best=it; best_d=d
#         if best is not None:
#             fields[f["key"]]["text"] = best["text"]
#     save_json(fields, os.path.join(SUB_DIR, "fields_json", image_id.replace('.jpg','.json').replace('.png','.json')))

#     # Probas image
#     ti=t_infer(image=warped)["image"].unsqueeze(0)
#     with torch.no_grad():
#         p_img=torch.softmax(model_img(ti), dim=1).cpu().numpy()[0]

#     # Probas tab
#     feats=build_features(ocr_items, expected)
#     feat_vec=np.array([[feats.get(c,0.0) for c in feat_cols]], dtype=np.float32)
#     with torch.no_grad():
#         p_tab=torch.softmax(tab(torch.tensor(feat_vec)), dim=1).cpu().numpy()[0]

#     p=fuse_probs(p_img, p_tab, alpha=0.6)
#     pred=int(np.argmax(p))
#     sub_rows.append({"image_id":image_id, "class_pred": CLASSES[pred]})

# pd.DataFrame(sub_rows).to_csv(os.path.join(SUB_DIR, "submission.csv"), index=False)
# print("Fichiers écrits dans:", SUB_DIR)

## 10) Sauvegarde des poids au format `.pt`
Lightning génère des `.ckpt`. On exporte aussi en `.pt` si besoin.

In [None]:
torch.save(model_img.state_dict(), os.path.join(WEIGHTS_DIR, "best_img.pt"))
torch.save(tab.state_dict(), os.path.join(WEIGHTS_DIR, "best_tab.pt"))
print("Poids sauvegardés dans:", WEIGHTS_DIR)

## 11) Notes
- Adapter `DATA_ROOT` vers votre dossier de données.
- Fournir un `template_image` par pays (facultatif mais recommandé) en modifiant `schema.set_template_image(country, img)`. 