In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt



In [2]:
# ============ 1) Chargement des données ============

# -- A. Matrice PLM (présence (1)/absence(0)) --
path_plm = "data/matricePLM_genes.txt"
df_plm = pd.read_csv(path_plm, sep="\t", index_col=0)
df_plm

Unnamed: 0,WMCAATAATTRW_-317,TGTAAAGT_-280,CCAATGT_-275,GGATA_-263,TAACAAA_-261,HDTTAACAGAAWW_-260,AWTTAAWT_-239,RTTTTTR_-229,TATCCA_-226,GRWAAW_-212,...,SAGATCYRR_295,DYCACCGACAHH_312,GTGGWWHG_319,AGATCCAA_326,HYRGATCYRD_334,ATGTCGGYRR_344,YAGATCTR_353,CTGACY_367,ACNGCT_426,AGCAGC_444
AT5G09440,0,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,1,0
AT3G29320,0,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,1
AT5G59570,0,0,0,0,0,0,0,0,0,0,...,1,0,0,0,1,0,0,0,1,0
AT3G55830,0,0,0,0,0,0,0,0,0,1,...,1,0,0,0,0,0,0,0,0,0
AT3G16140,0,0,0,0,0,0,0,1,0,1,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
AT2G03670,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,0
AT3G12530,0,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0
AT5G40780,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
AT1G56280,0,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0


In [3]:
# -- B. Données d'expression (down (-1), no change (0), up (1))--
path_expr = "data/expression_final.csv"
df_expr = pd.read_csv(path_expr, index_col=0)
df_expr

Unnamed: 0_level_0,1101,1104,29,30,1977,1973,1980,1976,1978,1974,...,1563,796,797,799,800,798,801,1296,1677,1672
Unnamed: 0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AT1G01010,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0
AT1G01030,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AT1G01040,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AT1G01050,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AT1G01060,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,-1.0,0.0,0.0,-1.0,-1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
AT5G67550,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AT5G67560,0.0,0.0,0.0,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,-1.0,0.0,0.0,0.0
AT5G67590,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,-1.0,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AT5G67620,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [4]:
# Alignement des données
common_genes = df_plm.index.intersection(df_expr.index)

In [5]:
# Filtrer les deux DataFrames par l'ensemble common_genes et trier dans le même ordre
df_plm = df_plm.loc[common_genes].sort_index()
df_expr = df_expr.loc[common_genes].sort_index()

In [6]:
# Conversion des données en tenseurs PyTorch
plm_matrix = torch.tensor(df_plm.values, dtype=torch.float32)
labels = torch.tensor(df_expr.values, dtype=torch.long)

In [7]:
# Lister les indices des PLMs activés (valeur = 1) pour chaque gène
active_plms = [
    np.where(df_plm.iloc[gene_idx].values == 1)[0]
    for gene_idx in range(df_plm.shape[0])
]

In [8]:
# ------------- 2) Création de la classe Dataset -------------
class GeneExpressionDatasetFiltered(Dataset):
    def __init__(self, active_plms, labels):
        self.active_plms = active_plms  # Liste des indices PLMs activés pour chaque gène
        self.labels = labels  # Matrice des étiquettes

    def __len__(self):
        # Total des échantillons = nombre de gènes * nombre de conditions de stress
        return len(self.active_plms) * self.labels.shape[1]

    def __getitem__(self, idx):
        # Calculer l'indice du gène et de la condition de stress
        num_stress = self.labels.shape[1]
        gene_idx = idx // num_stress
        stress_idx = idx % num_stress

        # Récupérer les indices PLMs activés pour le gène
        plm_indices = self.active_plms[gene_idx]

        # Récupérer l'étiquette correspondante
        label = self.labels[gene_idx, stress_idx]

        return plm_indices, stress_idx, label


In [9]:
class GeneExpressionModel(nn.Module):
    def __init__(self, num_plms, embedding_dim=16, hidden_dim=64, num_classes=3):
        super(GeneExpressionModel, self).__init__()
        
        # Matrice d'embedding pour les PLMs
        self.plm_embedding = nn.Embedding(num_plms, embedding_dim)
        
        # MLP pour la classification finale
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, plm_indices_batch):
        # Embedding des PLMs pour chaque batch
        plm_embedded_batch = []
        for indices in plm_indices_batch:
            if len(indices) > 0:  # Vérifier s'il y a des PLMs activés
                indices_tensor = torch.tensor(indices, dtype=torch.long, device=self.plm_embedding.weight.device)
                embedded = self.plm_embedding(indices_tensor)  # Taille : (num_plms_actifs, embedding_dim)
                aggregated = torch.mean(embedded, dim=0)  # Moyenne des embeddings
            else:
                aggregated = torch.zeros(self.plm_embedding.embedding_dim, device=self.plm_embedding.weight.device)  # Vector nul si aucun PLM activé
            plm_embedded_batch.append(aggregated)

        # Convertir en un tensor (batch_size, embedding_dim)
        aggregated_plm = torch.stack(plm_embedded_batch, dim=0)  # Taille : (batch_size, embedding_dim)

        # Passer dans le MLP
        output = self.mlp(aggregated_plm)
        return output



In [10]:
# ------------- 4) Préparation pour l'entraînement -------------
# Hyperparamètres
batch_size = 4
num_epochs = 20
learning_rate = 0.001

# Préparer Dataset et DataLoader
dataset = GeneExpressionDatasetFiltered(active_plms, labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: (
    [item[0] for item in x],  # PLM indices
    torch.tensor([item[1] for item in x], dtype=torch.long),  # Stress indices
    torch.tensor([item[2] for item in x], dtype=torch.long)   # Labels
))

# Initialiser le modèle
num_plms = df_plm.shape[1]  # Nombre total de PLMs (colonnes de df_plm)
model = GeneExpressionModel(num_plms, embedding_dim=16, hidden_dim=64, num_classes=3)

# Fonction de perte et optimiseur
criterion = nn.CrossEntropyLoss(ignore_index=-1)  # Ignorer les échantillons avec étiquette -1
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [11]:
# ------------- 5) Boucle d'entraînement -------------
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for plm_indices, stress_idx, label in dataloader:
        optimizer.zero_grad()
        
        # Forward pass
        output = model(plm_indices)
        
        # Calcul de la perte
        loss = criterion(output, label)
        total_loss += loss.item()
        
        # Backpropagation et mise à jour des poids
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")

Epoch 1/20, Loss: nan


KeyboardInterrupt: 

In [None]:
# ------------- 6) Prédiction -------------
model.eval()
with torch.no_grad():
    # Exemple : prédire pour le Gène 1
    plm_indices = active_plms[0]  # Indices PLMs activés pour le Gène 1

    # Prédiction
    output = model([plm_indices])  # Ajouter les indices dans une liste pour le batch
    predicted_class = torch.argmax(output, dim=1)

    print(f"Prédiction pour le Gène 1 : Classe {predicted_class.item()}")