In [1]:
def detect_white_background(image):
    """Détecte les pixels de fond blanc dans une image et retourne un masque booléen."""
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    _, mask = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY)
    return mask.astype(bool)

def apply_immunostaining(image_path, chosen_label):
    img = Image.open(image_path).convert("RGB")
    img_resized = img.resize((256, 256), Image.BICUBIC)
    img_np = np.array(img_resized)
    print(f"Taille avant mask : {img_np.shape}")
    mask = detect_white_background(img_np)
    img_tensor = transform(img_resized).unsqueeze(0).to(device)
    
    label_to_idx = {"AR": 0, "CD146": 1, "CD44": 2, "ERG": 3, "NKX3": 4, "p53": 5}  # HE exclu
    label_idx = torch.tensor([label_to_idx[chosen_label]], dtype=torch.long).to(device)
    
    with torch.no_grad():
        output = model(img_tensor, label_idx)
    
    output_img = output.squeeze(0).cpu().permute(1, 2, 0).numpy()
    output_img = (output_img * 255).astype(np.uint8)
    
    # Appliquer le masque : garder le fond blanc
    output_img[mask] = [255, 255, 255]
    
    output_pil = Image.fromarray(output_img.astype(np.uint8))
    return output_pil

In [8]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision import models
from PIL import Image
import os
import json
import random
import cv2
import numpy as np

# Augmenter la limite de pixels pour éviter les warnings
Image.MAX_IMAGE_PIXELS = None

# Nom du dossier contenant les images HE (à ajuster si nécessaire)
he_folder_name = "HE"

# 1. Génération du fichier JSON des labels
def create_label_json(dataset_dir, output_json):
    labels = {}
    for immunomark in os.listdir(dataset_dir):
        immunomark_path = os.path.join(dataset_dir, immunomark)
        if os.path.isdir(immunomark_path):
            for img_name in os.listdir(immunomark_path):
                if img_name.endswith(".png") or img_name.endswith(".jpg"):
                    labels[img_name] = immunomark  # Associe l'image au type d'immunomarquage
    
    with open(output_json, "w") as f:
        json.dump(labels, f, indent=4)
    print(f"Fichier JSON généré : {output_json}")

# Définition du dataset
class TMADataset(Dataset):
    def __init__(self, dataset_dir, labels, label_to_idx, transform=None):
        self.dataset_dir = dataset_dir
        self.image_paths = [os.path.join(dataset_dir, immunomark, img) for img, immunomark in labels.items()]
        self.labels = labels
        self.label_to_idx = label_to_idx
        self.transform = transform
    
    def __len__(self):
        return len(self.image_paths)
        
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            img = Image.open(img_path).convert("RGB")
        except (IOError, OSError) as e:
            print(f"⚠️ Image corrompue ignorée : {img_path}")
            return self.__getitem__((idx + 1) % len(self.image_paths))  # Passe à l’image suivante
        
        label_str = self.labels.get(os.path.basename(img_path), None)
        if label_str is None or label_str not in self.label_to_idx:
            raise ValueError(f"Label invalide pour l'image {img_path}")
    
        label = self.label_to_idx[label_str]
        if self.transform:
            img = self.transform(img)
        
        return img, torch.tensor(label, dtype=torch.long)

# Définition des transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Chargement des données
dataset_dir = "C:/Users/baril/Code/SiRiC/immunomarquage/dataset_VM/"  # Chemin correct du dataset
label_file = "labels.json"
with open(label_file, "r") as f:
    labels = json.load(f)

label_to_idx = {label: idx for idx, label in enumerate(set(labels.values()))}

# Création du DataLoader
dataset = TMADataset(dataset_dir, labels, label_to_idx, transform=transform)
print(f"Nombre d'images dans le dataset : {len(dataset)}")
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)

# Définition du modèle U-Net amélioré
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
    
    def forward(self, x):
        batch_size, C, width, height = x.size()
        query = self.query(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        key = self.key(x).view(batch_size, -1, width * height)
        energy = torch.bmm(query, key)
        attention = torch.nn.functional.softmax(energy, dim=-1)
        value = self.value(x).view(batch_size, -1, width * height)
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, width, height)
        return self.gamma * out + x

class UNet(nn.Module):
    def __init__(self, num_classes):
        super(UNet, self).__init__()
        self.encoder = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.encoder.fc = nn.Identity()
        self.label_embedding = nn.Embedding(num_classes, 2048)
        
        self.expand_conv = nn.Sequential(
            nn.ConvTranspose2d(2048, 1024, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU()
        )
        
        self.attention = SelfAttention(512)
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid()
        )
    
    def forward(self, x, labels):
        x = self.encoder(x)
        x = x.view(x.size(0), 2048, 1, 1)
        label_embed = self.label_embedding(labels).view(labels.size(0), 2048, 1, 1)
        x = x + label_embed
        x = self.expand_conv(x)
        x = self.attention(x)
        x = self.decoder(x)
        return torch.nn.functional.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)

# Initialisation du modèle
num_classes = len(label_to_idx)
model = UNet(num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002, betas=(0.5, 0.999))

def train_model(model, dataloader, epochs=10):
    model.train()
    for epoch in range(epochs):
        print(f'Epoch {epoch+1}')
        epoch_loss = 0
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(imgs, labels)
            loss = criterion(output, torch.nn.functional.interpolate(imgs, size=(256, 256), mode='bilinear', align_corners=False))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {epoch_loss / len(dataloader)}")
        torch.save(model.state_dict(), f"unet_epoch_{epoch+1}.pth")

train_model(model, dataloader, epochs=10)


Nombre d'images dans le dataset : 4720
Epoch 1
⚠️ Image corrompue ignorée : C:/Users/baril/Code/SiRiC/immunomarquage/dataset_VM/CD44\cd4.png
Epoch 1, Loss: 0.07842466816053552
Epoch 2
⚠️ Image corrompue ignorée : C:/Users/baril/Code/SiRiC/immunomarquage/dataset_VM/CD44\cd4.png
Epoch 2, Loss: 0.05267187471233182
Epoch 3
⚠️ Image corrompue ignorée : C:/Users/baril/Code/SiRiC/immunomarquage/dataset_VM/CD44\cd4.png
Epoch 3, Loss: 0.05130449905991554
Epoch 4
⚠️ Image corrompue ignorée : C:/Users/baril/Code/SiRiC/immunomarquage/dataset_VM/CD44\cd4.png
Epoch 4, Loss: 0.050301981966753125
Epoch 5
⚠️ Image corrompue ignorée : C:/Users/baril/Code/SiRiC/immunomarquage/dataset_VM/CD44\cd4.png
Epoch 5, Loss: 0.049483979600718465
Epoch 6
⚠️ Image corrompue ignorée : C:/Users/baril/Code/SiRiC/immunomarquage/dataset_VM/CD44\cd4.png
Epoch 6, Loss: 0.048798508834788354
Epoch 7
⚠️ Image corrompue ignorée : C:/Users/baril/Code/SiRiC/immunomarquage/dataset_VM/CD44\cd4.png
Epoch 7, Loss: 0.0480779795070826


In [9]:
torch.save(model.state_dict(), "checkpoints_TMA/model_V4_epoch_10.pth")

In [10]:
model.load_state_dict(torch.load("checkpoints_TMA/model_V4_epoch_10.pth"))
model.to(device)
model.eval()

  model.load_state_dict(torch.load("checkpoints_TMA/model_V4_epoch_10.pth"))


UNet(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0)

In [12]:
# Exemple d'utilisation
imgs = os.listdir('C:/Users/baril/Code/SiRiC/immunomarquage/HE_test_VM/')
for img in imgs :
    print(img)
    image_path = "C:/Users/baril/Code/SiRiC/immunomarquage/HE_test_VM/"+img  # Remplace par le chemin de ton image HE

    labels = ['AR','CD146','CD44','ERG','NKX3','p53']
    for label in labels :
        chosen_label = label  # Choisir l'immunomarquage voulu
        print(chosen_label)
        output_image = apply_immunostaining(image_path, chosen_label)
        #output_image.show()  # Afficher l'image générée
        output_dir = 'output_immunostained_V4/'+chosen_label+'/'
        if not os.path.exists(output_dir) :
            os.mkdir(output_dir)
        output_image.save(output_dir+img)

A2_TMA_15_02_IIB_HE.png
AR
Taille avant mask : (256, 256, 3)
CD146
Taille avant mask : (256, 256, 3)
CD44
Taille avant mask : (256, 256, 3)
ERG
Taille avant mask : (256, 256, 3)
NKX3
Taille avant mask : (256, 256, 3)
p53
Taille avant mask : (256, 256, 3)
A5_TMA_15_02_IIB_HE.png
AR
Taille avant mask : (256, 256, 3)
CD146
Taille avant mask : (256, 256, 3)
CD44
Taille avant mask : (256, 256, 3)
ERG
Taille avant mask : (256, 256, 3)
NKX3
Taille avant mask : (256, 256, 3)
p53
Taille avant mask : (256, 256, 3)
A9_TMA_15_02_IVB_HE.png
AR
Taille avant mask : (256, 256, 3)
CD146
Taille avant mask : (256, 256, 3)
CD44
Taille avant mask : (256, 256, 3)
ERG
Taille avant mask : (256, 256, 3)
NKX3
Taille avant mask : (256, 256, 3)
p53
Taille avant mask : (256, 256, 3)
B2_TMA_15_02_IVB_HE.png
AR
Taille avant mask : (256, 256, 3)
CD146
Taille avant mask : (256, 256, 3)
CD44
Taille avant mask : (256, 256, 3)
ERG
Taille avant mask : (256, 256, 3)
NKX3
Taille avant mask : (256, 256, 3)
p53
Taille avant m