In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, mean_squared_error, classification_report
import shap
from skorch import NeuralNetClassifier
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import randint, uniform
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ============ 1) Chargement et Prétraitement des Données ============
path_plm = "data/matricePLM_genes.txt"
df_plm = pd.read_csv(path_plm, sep="\t", index_col=0)

path_expr = "data/expression_final.csv"
df_expr = pd.read_csv(path_expr, index_col=0)

# Alignement des données
genes_communs = df_plm.index.intersection(df_expr.index)
df_plm = df_plm.loc[genes_communs].sort_index()
df_expr = df_expr.loc[genes_communs].sort_index()

# Conversion des données en tenseurs PyTorch
plm_matrix = torch.tensor(df_plm.values, dtype=torch.float32)
df_expr.replace({-1: 2}, inplace=True)  # Remplacer -1 par 2 pour éviter l'erreur avec CrossEntropyLoss
labels = torch.tensor(df_expr.values, dtype=torch.long)
stress_conditions = torch.arange(df_expr.shape[1])  # Liste des conditions de stress

In [3]:
# ============ 2) Définition de la Classe Dataset ============
class GeneExpressionDataset(Dataset):
    def __init__(self, plm_matrix, stress_conditions, labels):
        self.plm_matrix = plm_matrix
        self.stress_conditions = stress_conditions
        self.labels = labels
        self.num_genes = plm_matrix.size(0)
        self.num_stress = stress_conditions.size(0)

    def __len__(self):
        return self.num_genes * self.num_stress

    def __getitem__(self, idx):
        gene_idx = idx // self.num_stress
        stress_idx = idx % self.num_stress
        plm_vector = self.plm_matrix[gene_idx]
        stress_id = self.stress_conditions[stress_idx]
        label = self.labels[gene_idx, stress_idx]
        return plm_vector, stress_id, label

In [4]:
# ============ 3) Définition du Modèle ============
class GeneExpressionModel(nn.Module):
    def __init__(self, num_plms, embedding_dim=16, hidden_dim=64, num_classes=3):
        super().__init__()
        self.plm_embedding = nn.Embedding(num_embeddings=num_plms, embedding_dim=embedding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, plm_indices_batch):
        plm_embedded_batch = []
        for indices in plm_indices_batch:
            if indices.numel() == 0:
                aggregated = torch.zeros(self.plm_embedding.embedding_dim, device=indices.device)
            else:
                embedded = self.plm_embedding(indices)
                aggregated = embedded.mean(dim=0)
            plm_embedded_batch.append(aggregated)
        aggregated_plm = torch.stack(plm_embedded_batch, dim=0)
        return self.mlp(aggregated_plm)

In [5]:
# ============ 4) Préparation pour l'Entraînement ============
batch_size = 64
num_epochs = 100
learning_rate = 5e-4

dataset = GeneExpressionDataset(plm_matrix, stress_conditions, labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

num_plms = plm_matrix.size(1)
model = GeneExpressionModel(num_plms)

# Calculer le nombre d'échantillons de chaque classe
num_samples = torch.bincount(labels.flatten(), minlength=3)  # Décalage -1 à 0, 0 à 1, 1 à 2
class_weights = 1.0 / num_samples.float()
class_weights /= class_weights.sum() # Normaliser la somme des poids à 1

# Utiliser la perte d'entropie croisée pondérée
criterion = nn.CrossEntropyLoss(weight=class_weights.to(plm_matrix.device))
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [6]:
# ============ 5) Boucle d'Entraînement avec Checkpoint et Early Stopping ============
best_loss = float('inf')
checkpoint_path = "checkpoints/best_model.pth"
patience = 10
counter = 0

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for plm_batch, stress_id_batch, label_batch in dataloader:
        optimizer.zero_grad()
        plm_indices_batch = [(plm_vec > 0).nonzero(as_tuple=True)[0] for plm_vec in plm_batch]
        outputs = model(plm_indices_batch)
        loss = criterion(outputs, label_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
    
    if avg_loss < best_loss:
        best_loss = avg_loss
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) 
        torch.save(model.state_dict(), checkpoint_path)
        print(f"✔️ Meilleur modèle sauvegardé à {checkpoint_path} (Loss: {best_loss:.4f})")
        counter = 0
    else:
        counter += 1
        print(f"⚠️ Pas d'amélioration depuis {counter} epoch(s)")
    
    if counter >= patience:
        print("🛑 Early stopping déclenché !")
        break



Epoch 1/100, Loss: 1.0903
✔️ Meilleur modèle sauvegardé à checkpoints/best_model.pth (Loss: 1.0903)
Epoch 2/100, Loss: 1.0833
✔️ Meilleur modèle sauvegardé à checkpoints/best_model.pth (Loss: 1.0833)
Epoch 3/100, Loss: 1.0785
✔️ Meilleur modèle sauvegardé à checkpoints/best_model.pth (Loss: 1.0785)
Epoch 4/100, Loss: 1.0757
✔️ Meilleur modèle sauvegardé à checkpoints/best_model.pth (Loss: 1.0757)
Epoch 5/100, Loss: 1.0737
✔️ Meilleur modèle sauvegardé à checkpoints/best_model.pth (Loss: 1.0737)
Epoch 6/100, Loss: 1.0721
✔️ Meilleur modèle sauvegardé à checkpoints/best_model.pth (Loss: 1.0721)
Epoch 7/100, Loss: 1.0710
✔️ Meilleur modèle sauvegardé à checkpoints/best_model.pth (Loss: 1.0710)
Epoch 8/100, Loss: 1.0694
✔️ Meilleur modèle sauvegardé à checkpoints/best_model.pth (Loss: 1.0694)
Epoch 9/100, Loss: 1.0687
✔️ Meilleur modèle sauvegardé à checkpoints/best_model.pth (Loss: 1.0687)
Epoch 10/100, Loss: 1.0680
✔️ Meilleur modèle sauvegardé à checkpoints/best_model.pth (Loss: 1.0680)

In [7]:
# ============ 6) Sauvegarde du Modèle ==========
final_model_path = "checkpoints/final_model.pth"
torch.save(model.state_dict(), final_model_path)
print(f"📌 Modèle final sauvegardé à {final_model_path}")

📌 Modèle final sauvegardé à checkpoints/final_model.pth


In [8]:
# ============ 7) Évaluation du Modèle ============
def evaluate_model(model, dataloader, criterion=None):
    model.eval()
    all_preds, all_labels, total_loss = [], [], 0.0
    with torch.no_grad():
        for plm_batch, stress_id_batch, label_batch in dataloader:
            plm_indices_batch = [(plm_vec > 0).nonzero(as_tuple=True)[0] for plm_vec in plm_batch]
            outputs = model(plm_indices_batch)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(label_batch.cpu().numpy())
            if criterion:
                loss = criterion(outputs, label_batch)
                total_loss += loss.item()
    return {"Accuracy": accuracy_score(all_labels, all_preds)}

test_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
results = evaluate_model(model, test_dataloader, criterion)
print(results)

{'Accuracy': 0.7621162572346742}


In [1]:
# ============ 8) Explication avec SHAP ============
def model_wrapper(data):
    # Conversion des données NumPy en tenseurs PyTorch
    data_tensor = torch.tensor(data, dtype=torch.float32)
    plm_indices_batch = [(row > 0).nonzero(as_tuple=True)[0] for row in data_tensor] 
    with torch.no_grad():
        return model(plm_indices_batch).cpu().numpy()

# Conversion de la matrice PLM en tableau NumPy pour SHAP
plm_matrix_np = plm_matrix.numpy()

# Initialisation de l'explainer SHAP avec KernelExplainer
explainer = shap.KernelExplainer(model_wrapper, plm_matrix_np)

# Calcul des valeurs SHAP
shap_values = explainer.shap_values(plm_matrix_np)

# Visualisation des valeurs SHAP
shap.summary_plot(shap_values, plm_matrix_np, feature_names=df_plm.columns.tolist())


NameError: name 'plm_matrix' is not defined

# ============ 9) Recherche d'Hyperparamètres avec Skorch ============
skorch_model = NeuralNetClassifier(
    module=GeneExpressionModel,
    module__num_plms=plm_matrix.size(1),
    module__num_classes=3,
    max_epochs=10,
    lr=1e-4,
    optimizer=torch.optim.Adam,
    criterion=nn.CrossEntropyLoss,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

param_distributions = {
    'lr': uniform(1e-5, 1e-2),
    'max_epochs': randint(10, 50),
    'batch_size': randint(16, 64),
}

search = RandomizedSearchCV(skorch_model, param_distributions, n_iter=20, scoring='accuracy', cv=3, random_state=42)
X, y = plm_matrix.numpy(), labels.numpy().flatten()
search.fit(X, y)
print("Meilleurs paramètres:", search.best_params_)