In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
# =========================================================
# 1) READING THE DATA
# =========================================================

# -- A. PLM matrix (presence (1) / absence (0)) --
# Tab-separated file; its first column becomes the row index of the DataFrame
# (gene identifiers, judging by the displayed output), and every remaining
# column is one PLM.
path_plm = "data/matricePLM_genes.txt"
df_plm = pd.read_csv(path_plm, sep="\t", index_col=0)
df_plm

Unnamed: 0,WMCAATAATTRW_-317,TGTAAAGT_-280,CCAATGT_-275,GGATA_-263,TAACAAA_-261,HDTTAACAGAAWW_-260,AWTTAAWT_-239,RTTTTTR_-229,TATCCA_-226,GRWAAW_-212,...,SAGATCYRR_295,DYCACCGACAHH_312,GTGGWWHG_319,AGATCCAA_326,HYRGATCYRD_334,ATGTCGGYRR_344,YAGATCTR_353,CTGACY_367,ACNGCT_426,AGCAGC_444
AT5G09440,0,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,1,0
AT3G29320,0,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,1
AT5G59570,0,0,0,0,0,0,0,0,0,0,...,1,0,0,0,1,0,0,0,1,0
AT3G55830,0,0,0,0,0,0,0,0,0,1,...,1,0,0,0,0,0,0,0,0,0
AT3G16140,0,0,0,0,0,0,0,1,0,1,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
AT2G03670,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,0
AT3G12530,0,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0
AT5G40780,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
AT1G56280,0,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0


In [3]:
# -- B. Expression data (down (-1), no change (0), up (1)) --
# Read with read_csv's default separator; the first column becomes the row
# index (gene identifiers, judging by the displayed output). The remaining
# columns are presumably one per stress condition -- TODO confirm against the
# file's provenance.
path_expr = "data/expression_final.csv"
df_expr = pd.read_csv(path_expr, index_col=0)
df_expr

Unnamed: 0_level_0,1101,1104,29,30,1977,1973,1980,1976,1978,1974,...,1563,796,797,799,800,798,801,1296,1677,1672
Unnamed: 0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AT1G01010,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0
AT1G01030,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AT1G01040,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AT1G01050,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AT1G01060,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,-1.0,0.0,0.0,-1.0,-1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
AT5G67550,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AT5G67560,0.0,0.0,0.0,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,-1.0,0.0,0.0,0.0
AT5G67590,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,-1.0,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AT5G67620,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [4]:
# Data alignment: row labels present in both the PLM table and the expression
# table. This index is used afterwards to filter both DataFrames.
common_genes = df_plm.index.intersection(df_expr.index)
common_genes

Index(['AT5G09440', 'AT5G59570', 'AT3G16140', 'AT5G64430', 'AT5G47550',
       'AT3G52880', 'AT4G33980', 'AT1G53170', 'AT3G63060', 'AT1G27000',
       ...
       'AT5G22520', 'AT1G66660', 'AT3G26790', 'AT3G03980', 'AT1G24575',
       'AT3G18890', 'AT1G12220', 'AT2G03670', 'AT1G56280', 'AT2G03620'],
      dtype='object', length=10769)

In [5]:
# Restrict both DataFrames to the shared genes and sort them the same way, so
# that row i of df_plm and row i of df_expr describe the same gene.
df_plm = df_plm.loc[common_genes].sort_index()
df_expr = df_expr.loc[common_genes].sort_index()

# Features and labels are later paired purely by row position, so a mismatch
# here (for instance caused by duplicated gene IDs in one of the files, which
# makes .loc return extra rows) would silently attach labels to the wrong
# genes. Fail loudly instead of training on misaligned data.
if not df_plm.index.equals(df_expr.index):
    raise ValueError(
        "df_plm and df_expr are not row-aligned after filtering; "
        "check the input files for duplicated gene IDs."
    )

In [6]:
# Convert both aligned DataFrames to float32 PyTorch tensors.

# plm_matrix: one row per common gene, one column per PLM.
plm_matrix = torch.tensor(df_plm.to_numpy(), dtype=torch.float32)
num_genes = plm_matrix.size(0)
num_plms = plm_matrix.size(1)

# expr: one row per common gene, one column per stress condition.
expr = torch.tensor(df_expr.to_numpy(), dtype=torch.float32)
num_stress = expr.shape[1]

In [7]:
# Sanity check: both tensors should have the same number of rows
# (one per common gene); columns are PLMs and stress conditions respectively.
for _caption, _tensor in (("PLM shape:", plm_matrix), ("Expr shape:", expr)):
    print(_caption, _tensor.shape)

PLM shape: torch.Size([10769, 178])
Expr shape: torch.Size([10769, 387])


In [8]:
# =========================================================
# 2) KEEP ONLY THE (GENE, STRESS) PAIRS LABELLED -1 OR 1
#    AND REMAP THE LABELS: -1 -> 0, 1 -> 1
# =========================================================

# Cast to int so the comparisons below are exact integer comparisons.
expr_int = expr.to(torch.int)

# Vectorised selection: a Python double loop over every (gene, stress) cell
# with one .item() call per cell is needlessly slow; boolean masks give the
# same result in a handful of tensor operations.
is_up = expr_int == 1
is_valid = (expr_int == -1) | is_up  # cells equal to 0 (no change) are ignored

# nonzero() lists the selected cells in row-major order (gene by gene, then
# stress by stress), i.e. the order a nested `for i: for j:` loop would visit
# them. Pairs are stored as plain (gene_idx, stress_idx) tuples of Python ints.
valid_pairs = [tuple(pair) for pair in is_valid.nonzero().tolist()]

# Boolean-mask indexing walks the matrix in the same row-major order, so
# labels_bin[k] is the label of valid_pairs[k]: 1 for "up", 0 for "down".
labels_bin = is_up[is_valid].to(torch.long)
print("Nombre total d'échantillons valides:", len(valid_pairs))

Nombre total d'échantillons valides: 280740


In [9]:
# =========================================================
# 3) BINARY DATASET (0 AND 1) FOR THE ORIGINAL LABELS -1 AND 1
# =========================================================

class GeneExpressionBinaryDataset(Dataset):
    """
    Map-style dataset over a list of (gene_idx, stress_idx) pairs.

    The caller is expected to pass only pairs whose original label was -1 or
    1, with the labels already remapped to 0 or 1.

    Args:
        plm_matrix: container indexed by gene_idx; the selected row is
            returned as the feature vector of a sample.
        valid_pairs: sequence of (gene_idx, stress_idx) pairs, one per sample.
        labels_bin: container indexed by sample position; labels_bin[k] is
            the label that goes with valid_pairs[k].
    """
    def __init__(self, plm_matrix, valid_pairs, labels_bin):
        super().__init__()
        self.labels_bin = labels_bin
        self.valid_pairs = valid_pairs
        self.plm_matrix = plm_matrix

    def __len__(self):
        """Number of samples, i.e. number of stored pairs."""
        n_samples = len(self.valid_pairs)
        return n_samples

    def __getitem__(self, idx):
        """Return the triple (plm_vector, stress_idx, label) for sample `idx`."""
        gene_idx, stress_idx = self.valid_pairs[idx]
        return self.plm_matrix[gene_idx], stress_idx, self.labels_bin[idx]

# Wrap the feature matrix, the retained (gene, stress) pairs and their binary
# labels into a single indexable dataset.
dataset_bin = GeneExpressionBinaryDataset(
    plm_matrix=plm_matrix,
    valid_pairs=valid_pairs,
    labels_bin=labels_bin,
)

In [10]:
# =========================================================
# 4) CLASS WEIGHTS TO COMPENSATE FOR CLASS IMBALANCE
# =========================================================

num_samples = labels_bin.size(0)
num_classes = 2

# bincount(minlength=num_classes) returns one count per class id, with an
# explicit 0 for a class that never occurs. (Indexing the counts returned by
# torch.unique with the class id is only correct when every lower class id is
# present; otherwise it reads the wrong entry or goes out of range.)
counts = torch.bincount(labels_bin, minlength=num_classes)

class_weights = []
for c in range(num_classes):
    count_c = counts[c].item()
    if count_c > 0:
        # "Balanced" weighting: inversely proportional to the class frequency,
        # scaled so that perfectly balanced classes all get weight 1.
        w_c = num_samples / (num_classes * count_c)
    else:
        # Class absent from the data: neutral weight, and no division by zero.
        w_c = 1.0
    class_weights.append(w_c)

weights_tensor = torch.tensor(class_weights, dtype=torch.float32)
print("Poids de classe =", weights_tensor)

Poids de classe = tensor([1.0376, 0.9650])


In [11]:
# =========================================================
# 5) MODEL DEFINITION + DATALOADER
# =========================================================

class GeneExpressionModel(nn.Module):
    """
    Classifier over the set of PLMs that are active for a sample.

    Every PLM has a learned embedding. A sample is represented by the mean of
    the embeddings of its active PLMs (an all-zero vector when it has none),
    and that vector is fed to a small MLP producing `num_classes` logits.

    Args:
        num_plms: number of rows of the embedding table (one per PLM index).
        embedding_dim: width of each PLM embedding.
        hidden_dim: width of the hidden layer of the MLP.
        num_classes: number of output logits.
    """
    def __init__(self, num_plms, embedding_dim=16, hidden_dim=64, num_classes=2):
        super().__init__()
        self.plm_embedding = nn.Embedding(num_embeddings=num_plms, embedding_dim=embedding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, plm_indices_batch):
        """
        Compute the logits for a batch of samples.

        Args:
            plm_indices_batch: sequence with one 1-D integer tensor per
                sample, holding the indices of that sample's active PLMs.
                A tensor may be empty; the sequence itself may be empty too.

        Returns:
            Tensor of shape (batch_size, num_classes) with the raw logits.
        """
        plm_indices_batch = list(plm_indices_batch)
        batch_size = len(plm_indices_batch)
        counts = [indices.numel() for indices in plm_indices_batch]
        weight = self.plm_embedding.weight

        # Samples without any active PLM keep this all-zero representation.
        pooled = weight.new_zeros((batch_size, self.plm_embedding.embedding_dim))

        if sum(counts) > 0:
            # One embedding lookup for the whole batch instead of one per
            # sample. Empty tensors are skipped so they cannot affect the
            # dtype of the concatenated index tensor.
            flat_indices = torch.cat(
                [indices.reshape(-1) for indices in plm_indices_batch if indices.numel() > 0]
            )
            counts_t = torch.tensor(counts, device=flat_indices.device)
            # owners[k] = position in the batch of the sample that flat_indices[k] belongs to.
            owners = torch.repeat_interleave(
                torch.arange(batch_size, device=flat_indices.device), counts_t
            )
            # Per-sample sum of embeddings, then division by the number of
            # active PLMs gives the mean. clamp(min=1) leaves empty samples
            # at zero without dividing by zero.
            pooled = pooled.index_add(0, owners, self.plm_embedding(flat_indices))
            pooled = pooled / counts_t.clamp(min=1).to(pooled.dtype).unsqueeze(1)

        return self.mlp(pooled)  # (batch_size, num_classes)

# Seed PyTorch's global RNG so that weight initialisation, the DataLoader
# shuffling order and dropout are repeatable when the notebook is re-run from
# a fresh kernel with the cells executed in order.
SEED = 0
torch.manual_seed(SEED)

batch_size = 4
dataloader_bin = DataLoader(dataset_bin, batch_size=batch_size, shuffle=True)

model = GeneExpressionModel(num_plms=num_plms, num_classes=2)
# Class-weighted cross-entropy on the raw logits, using the class weights
# computed for the imbalance between the two labels.
criterion = nn.CrossEntropyLoss(weight=weights_tensor)
optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [12]:
# =========================================================
# 6) TRAINING LOOP
# =========================================================

num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    # NOTE(review): stress_batch is unpacked but never given to the model, so
    # the prediction depends on the gene's PLMs only, whatever the stress
    # condition. Presumably this is why the recorded loss stays close to
    # ln(2) ~ 0.693 -- confirm and consider feeding the stress index as well.
    for plm_batch, stress_batch, label_batch in dataloader_bin:
        optimizer.zero_grad()

        # Turn each multi-hot PLM vector into the indices of its active PLMs.
        plm_indices_batch = [
            torch.nonzero(plm_vec > 0, as_tuple=True)[0] for plm_vec in plm_batch
        ]

        outputs = model(plm_indices_batch)  # (batch_size, 2)
        loss = criterion(outputs, label_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Mean of the per-batch losses over the epoch.
    avg_loss = running_loss / len(dataloader_bin)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss = {avg_loss:.6f}")

Epoch 1/20, Loss = 0.692530
Epoch 2/20, Loss = 0.691351
Epoch 3/20, Loss = 0.690653
Epoch 4/20, Loss = 0.689799
Epoch 5/20, Loss = 0.689275
Epoch 6/20, Loss = 0.688840
Epoch 7/20, Loss = 0.688620
Epoch 8/20, Loss = 0.688233
Epoch 9/20, Loss = 0.687664
Epoch 10/20, Loss = 0.687401
Epoch 11/20, Loss = 0.686926
Epoch 12/20, Loss = 0.686710
Epoch 13/20, Loss = 0.686356
Epoch 14/20, Loss = 0.686097
Epoch 15/20, Loss = 0.685762
Epoch 16/20, Loss = 0.685461
Epoch 17/20, Loss = 0.685338
Epoch 18/20, Loss = 0.684968
Epoch 19/20, Loss = 0.684609
Epoch 20/20, Loss = 0.684370


In [13]:
# =========================================================
# 7) INFERENCE EXAMPLE
# =========================================================

# Switch to evaluation mode (turns the model's dropout layer off).
model.eval()
with torch.no_grad():
    # Sample 0 comes from the dataset the model was trained on, so this is a
    # smoke test of the inference path, not an evaluation of the model.
    plm_vec_test, stress_id_test, label_test = dataset_bin[0]
    active_idx_test = torch.nonzero(plm_vec_test > 0, as_tuple=True)[0]
    logits = model([active_idx_test])  # (1, 2)
    pred_class = logits.argmax(dim=1).item()
    print(f"Étiquette réelle (0->-1, 1->1) = {label_test.item()}, Prédiction = {pred_class}")

Étiquette réelle (0->-1, 1->1) = 1, Prédiction = 1
