dataset : https://zenodo.org/records/10066853

In [None]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image
import os
import json
import random
import torchvision.models as models

# Disable Pillow's pixel-count limit so very large images can be opened
# without triggering a DecompressionBombWarning
Image.MAX_IMAGE_PIXELS = None

# Name of the folder that holds the HE images (adjust if needed)
he_folder_name = "HE"

# 1. Build the JSON file that maps each image to its label
def create_label_json(dataset_dir, output_json):
    """Write a JSON file mapping image file names to their parent folder name.

    Every direct sub-folder of ``dataset_dir`` is treated as one
    immunostaining type; each ``.png`` / ``.jpg`` file found directly inside
    it is recorded as ``{file name: folder name}``. Entries of ``dataset_dir``
    that are not directories are ignored.

    Args:
        dataset_dir: Root folder containing one sub-folder per staining type.
        output_json: Path of the JSON file to write (overwritten if present).
    """
    image_suffixes = (".png", ".jpg")
    name_to_marker = {}
    for marker in os.listdir(dataset_dir):
        marker_dir = os.path.join(dataset_dir, marker)
        if not os.path.isdir(marker_dir):
            continue
        # Keyed by bare file name: a later folder overwrites an earlier one
        # when two folders contain a file with the same name.
        name_to_marker.update(
            (fname, marker)
            for fname in os.listdir(marker_dir)
            if fname.endswith(image_suffixes)
        )

    with open(output_json, "w") as f:
        json.dump(name_to_marker, f, indent=4)
    print(f"Fichier JSON généré : {output_json}")

# Location of the dataset root and of the generated JSON label file
image_dir = "C:/Users/baril/Code/SiRiC/immunomarquage/dataset_VM"
label_file = "labels.json"

# Generate the JSON label file
create_label_json(image_dir, label_file)

# 2. Loading and preprocessing the data
with open(label_file, "r") as f:
    labels = json.load(f)

# Build the dictionary that converts immunostaining names to integer indices
unique_labels = sorted(set(labels.values()) - {he_folder_name})  # HE is excluded from the regular labels
label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
label_to_idx[he_folder_name] = -1  # Sentinel index reserved for HE images

print("Mapping des labels:", label_to_idx)

class TMADataset(Dataset):
    """Dataset of images paired with an immunostaining class index.

    Args:
        image_paths: Sequence of image file paths.
        labels: Dict mapping an image file name (basename) to its
            immunostaining name.
        label_to_idx: Dict mapping an immunostaining name to its integer index.
        transform: Optional callable applied to the loaded RGB PIL image.
        he_label: Name of the HE label in ``label_to_idx``. Images carrying
            this label receive a randomly drawn target index instead.
    """

    def __init__(self, image_paths, labels, label_to_idx, transform=None, he_label="HE"):
        self.image_paths = image_paths
        self.labels = labels  # file name -> immunostaining name
        self.label_to_idx = label_to_idx  # immunostaining name -> index
        self.transform = transform
        self.he_label = he_label
        # Indices that may be drawn for an HE image: every class except HE.
        # Selected by name so the result depends neither on the position of
        # HE in the dict nor on any module-level variable.
        self.target_indices = [
            idx for name, idx in label_to_idx.items() if name != he_label
        ]

    def __len__(self):
        return len(self.image_paths)

    def _resolve_label(self, img_path):
        """Return the class index for ``img_path``.

        Raises:
            ValueError: If the file name has no entry in ``labels`` or its
                label is missing from ``label_to_idx``.
        """
        label_str = self.labels.get(os.path.basename(img_path))
        if label_str is None or label_str not in self.label_to_idx:
            raise ValueError(f"Label introuvable ou incorrect pour l'image {img_path}")
        if label_str == self.he_label:
            # HE images get a random target class for training
            return random.choice(self.target_indices)
        return self.label_to_idx[label_str]

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Resolve the label first so a bad entry fails before decoding the image
        label = self._resolve_label(img_path)
        img = Image.open(img_path).convert("RGB")

        if self.transform:
            img = self.transform(img)

        return img, torch.tensor(label, dtype=torch.long)

# Image transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Collect every image under the dataset root. This list has to exist before
# the integrity check below; it was previously iterated without ever being
# defined, which raised a NameError on a fresh kernel.
image_paths = [
    os.path.join(root, fname)
    for root, _, files in os.walk(image_dir)
    for fname in files
    if fname.endswith((".png", ".jpg"))
]

valid_image_paths = []
for img_path in image_paths:
    try:
        with Image.open(img_path) as img:
            img.verify()  # Check file integrity
        valid_image_paths.append(img_path)
    except Exception as e:
        # Best effort: report the unreadable file and carry on with the rest
        print(f"Image corrompue détectée et ignorée : {img_path}, Erreur: {e}")

image_paths = valid_image_paths

# os.walk yields files folder by folder, so slicing the raw list would keep
# only the first staining folders. Shuffle (with a fixed seed, for
# reproducibility) so the subset covers every class.
random.Random(0).shuffle(image_paths)

# Train on part of the dataset (here: 50%)
train_size = int(0.5 * len(image_paths))
dataset = TMADataset(image_paths[:train_size], labels, label_to_idx, transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# 3. Conditional Consistency Model
class ConsistencyModel(torch.nn.Module):
    """Small convolutional encoder/decoder conditioned on a class label.

    The encoder maps a 3-channel image to 128 feature channels, a learned
    128-dimensional label embedding is added to every spatial position, and
    the decoder maps the result back to 3 channels. All convolutions are
    3x3 with stride 1 and padding 1, so the spatial size is preserved.

    Args:
        num_classes: Number of rows in the label embedding table.
    """

    def __init__(self, num_classes):
        super().__init__()
        nn = torch.nn
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.label_embedding = nn.Embedding(num_classes, 128)
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x, labels):
        """Return the decoded image for input ``x`` conditioned on ``labels``."""
        features = self.encoder(x)
        batch_size = labels.size(0)
        # One embedding vector per sample, broadcast over height and width
        conditioning = self.label_embedding(labels).view(batch_size, 128, 1, 1)
        return self.decoder(features + conditioning)

# Model initialisation
num_classes = len(label_to_idx) - 1  # HE is excluded from the count
# NOTE(review): same value as num_classes above; only target_classes is passed to the model
target_classes = len(label_to_idx) - 1
model = ConsistencyModel(target_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Feature extractor for the perceptual loss: the first 16 layers of a
# pretrained VGG16, kept in eval mode with gradients disabled (frozen)
vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features[:16].to(device)
vgg.eval()
for param in vgg.parameters():
    param.requires_grad = False

def perceptual_loss(output, target):
    """Return 0.1 x the MSE between VGG features of ``output`` and ``target``.

    Both tensors are normalised with the ImageNet channel statistics before
    being passed to ``vgg``: the pretrained VGG16 weights were trained on
    inputs normalised this way, while the tensors here come straight from
    ``ToTensor`` (values in [0, 1]).

    Args:
        output: Batch of generated images, shape (N, 3, H, W).
        target: Batch of reference images, same shape as ``output``.
    """
    mean = torch.tensor(
        [0.485, 0.456, 0.406], device=output.device, dtype=output.dtype
    ).view(1, 3, 1, 1)
    std = torch.tensor(
        [0.229, 0.224, 0.225], device=output.device, dtype=output.dtype
    ).view(1, 3, 1, 1)
    output_features = vgg((output - mean) / std)
    target_features = vgg((target - mean) / std)
    return 0.1 * torch.nn.functional.mse_loss(output_features, target_features)

# 4. Model training and checkpoint saving
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002, betas=(0.5, 0.999))
criterion = torch.nn.L1Loss()
# Folder that receives one checkpoint per epoch (created if missing)
checkpoint_dir = "checkpoints_TMA"
os.makedirs(checkpoint_dir, exist_ok=True)

def train_model(model, dataloader, epochs=10):
    """Train ``model`` to reconstruct its input and save a checkpoint per epoch.

    The loss is the module-level ``criterion`` plus ``perceptual_loss``, both
    computed between the model output and the input images; parameters are
    updated with the module-level ``optimizer``. After each epoch the mean
    batch loss is printed and the state dict is written to ``checkpoint_dir``.

    Args:
        model: Module called as ``model(imgs, labels)``.
        dataloader: Iterable of ``(imgs, labels)`` batches; must support ``len``.
        epochs: Number of passes over ``dataloader``.

    Raises:
        ValueError: If ``dataloader`` yields no batch.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        # Without this guard the epoch report would fail on an undefined loss
        raise ValueError("Le dataloader est vide : aucun lot à entraîner")

    model.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            output = model(imgs, labels)
            loss = criterion(output, imgs) + perceptual_loss(output, imgs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # Report the epoch average rather than the last batch only
        print(f"Epoch {epoch+1}, Loss: {epoch_loss/num_batches}")
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pth"))

# Launch training with the default number of epochs
train_model(model, dataloader)


In [None]:
# 5. Test the model on an unseen image
def test_model(model, image_path, target_label, output_path=None):
    """Run ``model`` on one image with the requested staining and save the result.

    Args:
        model: Trained module called as ``model(img, label_idx)``.
        image_path: Path of the input image.
        target_label: Immunostaining name; must be a key of ``label_to_idx``
            other than the HE label.
        output_path: Where to save the generated image. Defaults to
            ``output_image_<target_label>.png``.

    Raises:
        KeyError: If ``target_label`` is not in ``label_to_idx``.
        ValueError: If ``target_label`` is the HE label, whose index is not a
            valid row of the model's embedding table.
    """
    if target_label == he_folder_name:
        raise ValueError(f"Immunomarquage cible invalide : {target_label}")
    label_value = label_to_idx[target_label]
    if output_path is None:
        output_path = f"output_image_{target_label}.png"

    model.eval()
    img = Image.open(image_path).convert("RGB")
    img = transform(img).unsqueeze(0).to(device)
    label_idx = torch.tensor([label_value], dtype=torch.long).to(device)
    with torch.no_grad():
        output = model(img, label_idx)
    # ToPILImage multiplies float tensors by 255 and casts to uint8, so values
    # outside [0, 1] would wrap around; clamp them first.
    output = output.squeeze(0).clamp(0.0, 1.0).cpu()
    output_img = transforms.ToPILImage()(output)
    output_img.save(output_path)
    # Report the path that was actually written
    print(f"Image enregistrée sous {output_path} avec immunomarquage : {target_label}")

# Example usage
test_image = "C:/Users/baril/Code/SiRiC/immunomarquage/dataset_VM/HE/B3_TMA_15_02_IB_HE.png"  # Replace with a valid path
test_label = "p53"  # Replace with an existing immunostaining name
test_model(model, test_image, test_label)


Avec seuillage du fond blanc (seuil fixe à 240 dans le code ci-dessous, et non la méthode d'Otsu) pour garantir que le fond blanc reste blanc

In [2]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import json
import random
import torchvision.models as models
import cv2
import numpy as np

# Disable Pillow's pixel-count limit so very large images can be opened
# without triggering a DecompressionBombWarning
Image.MAX_IMAGE_PIXELS = None

# Name of the folder that holds the HE images (adjust if needed)
he_folder_name = "HE"

# 1. Build the JSON file that maps each image to its label
def create_label_json(dataset_dir, output_json):
    """Record, for every image, the staining folder it was found in.

    Each direct sub-folder of ``dataset_dir`` stands for one immunostaining
    type. The ``.png`` / ``.jpg`` files located directly inside it are stored
    as ``{file name: folder name}`` and the mapping is dumped to
    ``output_json``.

    Args:
        dataset_dir: Root folder containing one sub-folder per staining type.
        output_json: Path of the JSON file to write (overwritten if present).
    """
    mapping = {}
    for folder in os.listdir(dataset_dir):
        folder_path = os.path.join(dataset_dir, folder)
        if not os.path.isdir(folder_path):
            continue  # Only sub-folders describe a staining type
        for file_name in os.listdir(folder_path):
            if file_name.endswith((".png", ".jpg")):
                mapping[file_name] = folder

    with open(output_json, "w") as f:
        json.dump(mapping, f, indent=4)
    print(f"Fichier JSON généré : {output_json}")

# Location of the dataset root and of the generated JSON label file
image_dir = "C:/Users/baril/Code/SiRiC/immunomarquage/dataset_VM"
label_file = "labels.json"

# Generate the JSON label file
create_label_json(image_dir, label_file)

# 2. Loading and preprocessing the data
with open(label_file, "r") as f:
    labels = json.load(f)

# Build the dictionary that converts immunostaining names to integer indices
unique_labels = sorted(set(labels.values()) - {he_folder_name})  # HE is excluded from the regular labels
label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
label_to_idx[he_folder_name] = len(unique_labels)  # HE gets a dedicated index just past the regular ones

print("Mapping des labels:", label_to_idx)

class TMADataset(Dataset):
    """Dataset of images paired with an immunostaining class index.

    Args:
        image_paths: Sequence of image file paths.
        labels: Dict mapping an image file name (basename) to its
            immunostaining name.
        label_to_idx: Dict mapping an immunostaining name to its integer index.
        transform: Optional callable applied to the loaded RGB PIL image.
        he_label: Name of the HE label in ``label_to_idx``. Images carrying
            this label receive a randomly drawn target index instead.
    """

    def __init__(self, image_paths, labels, label_to_idx, transform=None, he_label="HE"):
        self.image_paths = image_paths
        self.labels = labels  # file name -> immunostaining name
        self.label_to_idx = label_to_idx  # immunostaining name -> index
        self.transform = transform
        self.he_label = he_label
        # Indices that may be drawn for an HE image: every class except HE.
        # Built from the mapping given to this instance, by name, so the
        # dataset no longer depends on module-level variables or on HE being
        # the last entry of the dict.
        self.target_indices = [
            idx for name, idx in label_to_idx.items() if name != he_label
        ]

    def __len__(self):
        return len(self.image_paths)

    def _resolve_label(self, img_path):
        """Return the class index for ``img_path``.

        Raises:
            ValueError: If the file name has no entry in ``labels`` or its
                label is missing from ``label_to_idx``.
        """
        label_str = self.labels.get(os.path.basename(img_path))
        if label_str is None or label_str not in self.label_to_idx:
            raise ValueError(f"Label introuvable ou incorrect pour l'image {img_path}")
        if label_str == self.he_label:
            # HE images get a random target class for training
            return random.choice(self.target_indices)
        return self.label_to_idx[label_str]

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Resolve the label first so a bad entry fails before decoding the image
        label = self._resolve_label(img_path)
        img = Image.open(img_path).convert("RGB")

        if self.transform:
            img = self.transform(img)

        return img, torch.tensor(label, dtype=torch.long)

# Image transformations: resize to 256x256, then convert to a tensor
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

def detect_white_background(image):
    """Return a boolean mask that is True on white-background pixels.

    The image is converted from RGB to grayscale and thresholded at 240:
    pixels brighter than the threshold are marked as background.

    Args:
        image: Array with three dimensions whose last one has size 3 (RGB).

    Raises:
        ValueError: If ``image`` does not have exactly 3 channels.
    """
    has_three_channels = len(image.shape) == 3 and image.shape[2] == 3
    if not has_three_channels:
        raise ValueError("L'image d'entrée doit avoir 3 canaux (RGB)")

    grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    # cv2.threshold returns (threshold used, thresholded image); keep the image
    binary = cv2.threshold(grayscale, 240, 255, cv2.THRESH_BINARY)[1]
    return binary.astype(bool)

# Collect every .png / .jpg file found anywhere under the dataset root
image_paths = [os.path.join(root, fname) for root, _, files in os.walk(image_dir) for fname in files if fname.endswith((".png", ".jpg"))]

# Keep only the files Pillow can verify; unreadable ones are reported and skipped
valid_image_paths = []
for img_path in image_paths:
    try:
        with Image.open(img_path) as img:
            img.verify()  # Check file integrity
        valid_image_paths.append(img_path)
    except Exception as e:
        print(f"Image corrompue détectée et ignorée : {img_path}, Erreur: {e}")

image_paths = valid_image_paths

# Build the DataLoader
dataset = TMADataset(image_paths, labels, label_to_idx, transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# 3. Conditional Consistency Model
class ConsistencyModel(torch.nn.Module):
    """Small convolutional encoder/decoder conditioned on a class label.

    The encoder maps a 3-channel image to 128 feature channels, a learned
    128-dimensional label embedding is added to every spatial position, and
    the decoder maps the result back to 3 channels.

    The label index equal to ``num_classes`` (one past the last embedding
    row) is treated as "HE / no conditioning": samples carrying it go through
    the network without any label embedding.

    Args:
        num_classes: Number of rows in the label embedding table.
    """

    def __init__(self, num_classes):
        super(ConsistencyModel, self).__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU()
        )
        self.label_embedding = torch.nn.Embedding(num_classes, 128)
        self.decoder = torch.nn.Sequential(
            torch.nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x, labels):
        """Return the decoded image for ``x``, conditioned per sample on ``labels``."""
        x = self.encoder(x)
        # The HE index is derived from the embedding table itself, not from a
        # module-level variable.
        is_he = labels == self.label_embedding.num_embeddings
        # The HE index is out of range for the table: look up row 0 instead
        # and zero the result for those samples. Conditioning is decided per
        # sample, so one HE image no longer disables it for the whole batch.
        safe_labels = labels.masked_fill(is_he, 0)
        label_embed = self.label_embedding(safe_labels)
        label_embed = label_embed * (~is_he).unsqueeze(1).to(label_embed.dtype)
        x = x + label_embed.view(labels.size(0), 128, 1, 1)  # Merge the label into the features
        x = self.decoder(x)
        return x

# Model initialisation
num_classes = len(label_to_idx) - 1  # HE is excluded from the count
model = ConsistencyModel(num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Loss function and optimizer
criterion = torch.nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training function
def train_model(model, dataloader, epochs=10, checkpoint_dir=None):
    """Train ``model`` to reconstruct its input images.

    The loss is the module-level ``criterion`` between the model output and
    the input images; parameters are updated with the module-level
    ``optimizer``. The mean batch loss is printed after every epoch.

    Args:
        model: Module called as ``model(imgs, labels)``.
        dataloader: Iterable of ``(imgs, labels)`` batches; must support ``len``.
        epochs: Number of passes over ``dataloader``.
        checkpoint_dir: Optional folder; when given, the state dict is saved
            there as ``model_epoch_<n>.pth`` after each epoch. Nothing is
            saved by default.

    Raises:
        ValueError: If ``dataloader`` yields no batch.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        # Fail up front instead of dividing by zero in the epoch report
        raise ValueError("Le dataloader est vide : aucun lot à entraîner")
    if checkpoint_dir is not None:
        os.makedirs(checkpoint_dir, exist_ok=True)

    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            output = model(imgs, labels)
            loss = criterion(output, imgs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {epoch_loss/num_batches}")
        if checkpoint_dir is not None:
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pth"))

# Start training
train_model(model, dataloader)


Fichier JSON généré : labels.json
Mapping des labels: {'AR': 0, 'CD146': 1, 'CD44': 2, 'ERG': 3, 'NKX3': 4, 'p53': 5, 'HE': 6}
Image corrompue détectée et ignorée : C:/Users/baril/Code/SiRiC/immunomarquage/dataset_VM\CD44\cd4.png, Erreur: cannot identify image file 'C:\\Users\\baril\\Code\\SiRiC\\immunomarquage\\dataset_VM\\CD44\\cd4.png'
Epoch 1, Loss: 0.07040721983096357
Epoch 2, Loss: 0.03457504056267819
Epoch 3, Loss: 0.027783501056670133
Epoch 4, Loss: 0.022905063908547164
Epoch 5, Loss: 0.020548305265858012
Epoch 6, Loss: 0.018575589708447204
Epoch 7, Loss: 0.01676779358284706
Epoch 8, Loss: 0.016219634880801127
Epoch 9, Loss: 0.0150031102221396
Epoch 10, Loss: 0.01447367490569161


In [3]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import os
import cv2
import numpy as np

# Load the model
checkpoint_path = "checkpoints_TMA/model_epoch_10.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the state_dict to detect the number of classes.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict holds, and avoids executing arbitrary pickled code.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
num_classes = checkpoint['label_embedding.weight'].shape[0]  # Number of rows of the embedding table

class ConsistencyModel(torch.nn.Module):
    """Label-conditioned convolutional encoder/decoder.

    A two-layer encoder produces 128 feature channels, the 128-dimensional
    embedding of the requested label is added at every spatial position, and
    a two-layer decoder maps the sum back to a 3-channel image. Layer names
    and order match the saved checkpoints.

    Args:
        num_classes: Number of rows in the label embedding table.
    """

    def __init__(self, num_classes):
        super().__init__()

        def conv3x3(in_channels, out_channels):
            # 3x3 convolution that preserves the spatial size
            return torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.encoder = torch.nn.Sequential(
            conv3x3(3, 64),
            torch.nn.ReLU(),
            conv3x3(64, 128),
            torch.nn.ReLU(),
        )
        self.label_embedding = torch.nn.Embedding(num_classes, 128)
        self.decoder = torch.nn.Sequential(
            conv3x3(128, 64),
            torch.nn.ReLU(),
            conv3x3(64, 3),
        )

    def forward(self, x, labels):
        """Return the decoded image for input ``x`` conditioned on ``labels``."""
        encoded = self.encoder(x)
        # One embedding vector per sample, broadcast over height and width
        embedded = self.label_embedding(labels).view(labels.size(0), 128, 1, 1)
        return self.decoder(encoded + embedded)

# Rebuild the model and load the saved weights
model = ConsistencyModel(num_classes)
model.load_state_dict(checkpoint)
model.to(device)
model.eval()

# Preprocessing applied to an HE image before inference
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

def detect_white_background(image):
    """Return a boolean mask that is True on white-background pixels.

    The image is converted from RGB to grayscale and thresholded at 240:
    pixels brighter than the threshold are marked as background.

    Args:
        image: Array with three dimensions whose last one has size 3 (RGB).

    Raises:
        ValueError: If ``image`` does not have exactly 3 channels.
    """
    # Same input check as the training-side version of this helper, so a
    # wrong input fails with an explicit message.
    if len(image.shape) != 3 or image.shape[2] != 3:
        raise ValueError("L'image d'entrée doit avoir 3 canaux (RGB)")

    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    _, mask = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY)
    return mask.astype(bool)

def apply_immunostaining(image_path, chosen_label):
    """Generate the virtually stained version of an HE image.

    The image is resized to 256x256, passed through the module-level
    ``model`` conditioned on ``chosen_label``, and the pixels detected as
    white background in the input are forced back to white in the output.

    Args:
        image_path: Path of the HE image.
        chosen_label: One of "AR", "CD146", "CD44", "ERG", "NKX3", "p53".

    Returns:
        The generated 256x256 RGB image as a PIL image.

    Raises:
        KeyError: If ``chosen_label`` is not a known immunostaining name.
    """
    label_to_idx = {"AR": 0, "CD146": 1, "CD44": 2, "ERG": 3, "NKX3": 4, "p53": 5}  # HE excluded
    # Resolve the label first so an unknown name fails before any image I/O
    try:
        label_value = label_to_idx[chosen_label]
    except KeyError:
        raise KeyError(
            f"Immunomarquage inconnu : {chosen_label!r} (attendus : {sorted(label_to_idx)})"
        ) from None

    img = Image.open(image_path).convert("RGB")
    img_resized = img.resize((256, 256))
    img_np = np.array(img_resized)
    mask = detect_white_background(img_np)
    img_tensor = transform(img_resized).unsqueeze(0).to(device)
    label_idx = torch.tensor([label_value], dtype=torch.long).to(device)

    with torch.no_grad():
        output = model(img_tensor, label_idx)

    # The decoder output is unbounded: clamp to [0, 1] before scaling, since
    # casting out-of-range values to uint8 would wrap around.
    output_img = output.squeeze(0).clamp(0.0, 1.0).cpu().permute(1, 2, 0).numpy()
    output_img = (output_img * 255).astype(np.uint8)

    # Apply the mask: keep the background white
    output_img[mask] = [255, 255, 255]

    return Image.fromarray(output_img)



  checkpoint = torch.load(checkpoint_path, map_location=device)


In [20]:
# Example usage: generate every staining for the HE test images
input_dir = 'C:/Users/baril/Code/SiRiC/immunomarquage/HE_test_VM/'  # Replace with the folder holding your HE images
imgs = os.listdir(input_dir)
labels = ['AR', 'CD146', 'CD44', 'ERG', 'NKX3', 'p53']

# NOTE(review): the first 10 directory entries are skipped — presumably to
# resume an interrupted run; confirm this is intended.
for img in imgs[10:]:
    print(img)
    image_path = os.path.join(input_dir, img)

    for chosen_label in labels:  # Staining to generate
        print(chosen_label)
        output_image = apply_immunostaining(image_path, chosen_label)
        #output_image.show()  # Display the generated image
        output_dir = os.path.join('output_immunostained', chosen_label)
        # makedirs also creates the parent folder, which os.mkdir did not
        os.makedirs(output_dir, exist_ok=True)
        output_image.save(os.path.join(output_dir, img))

E1_TMA_15_02_IIIB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
E6_TMA_15_02_IB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
E9_TMA_15_02_IIIB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
F2_TMA_15_02_IIIB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
F6_TMA_15_02_IIIB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
G11_TMA_15_02_IIIB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
G12_TMA_15_02_IIIB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
G1_TMA_15_02_IIB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
G5_TMA_15_02_IIIB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
H10_TMA_15_02_IB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
H4_TMA_15_02_IB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
I3_TMA_15_02_IIB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
I9_TMA_15_02_IIIB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
J10_TMA_15_02_IIIB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
J3_TMA_15_02_IIB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
K11_TMA_15_02_IIIB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
K12_TMA_15_02_IB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
K6_TMA_15_02_IB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
L7_TMA_15_02_IIIB_HE.png
AR
CD146
CD44
ERG
NKX3
p53
M4_TMA_15_02_IB_HE.

In [None]:
'AR': 0, 'CD146': 1, 'CD44': 2, 'ERG': 3, 'NKX3': 4, 'p53': 5, 'HE': 6