In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from transformers import AutoTokenizer, AutoModel
from torch.optim import AdamW
import torch.nn as nn
from tqdm import tqdm

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DEVICE

In [None]:
# Hugging Face checkpoint identifier; it is passed to AutoTokenizer.from_pretrained
# and to NTForGMO (which forwards it to AutoModel.from_pretrained).
MODEL_NAME = "InstaDeepAI/nucleotide-transformer-V2-250m-multi-species"
# NOTE(review): MAX_LEN is not referenced by any visible cell; GMODataset relies on
# its own chunk_size=512 default instead. Confirm whether it should be wired in.
MAX_LEN = 512

In [None]:
# Load the tokenizer that belongs to the MODEL_NAME checkpoint.
# trust_remote_code=True allows transformers to run Python code shipped in the
# model repository, so only use it with a checkpoint source you trust.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)


In [None]:
def chunk_sequence(seq, chunk_size=512, stride=256):
    """Yield upper-cased, fixed-length windows taken from ``seq``.

    Windows start at offsets 0, ``stride``, 2 * ``stride``, ... and each one is
    exactly ``chunk_size`` characters long. A trailing remainder that cannot
    fill a whole window is dropped, and a sequence shorter than ``chunk_size``
    yields nothing.

    Args:
        seq: Input string (any case).
        chunk_size: Length of every yielded window.
        stride: Distance between consecutive window starts.

    Yields:
        str: The next upper-cased window.
    """
    normalized = seq.upper()
    window_starts = range(0, len(normalized) - chunk_size + 1, stride)
    yield from (normalized[start:start + chunk_size] for start in window_starts)

In [None]:
import bisect

class GMODataset(Dataset):
    """Map-style dataset of fixed-length DNA windows that contain no 'N'.

    The CSV at ``csv_path`` must provide a ``sequence`` column and a ``label``
    column. Every sequence is upper-cased and cut into windows of
    ``chunk_size`` characters taken every ``stride`` characters; windows that
    contain an 'N' are skipped. A non-empty sequence shorter than
    ``chunk_size`` without any 'N' contributes exactly one item.

    Dataset items are the valid windows of all rows, in row order. Each item
    is a dict with ``input_ids``, ``attention_mask`` (tokenizer output with the
    leading batch dimension removed) and ``labels`` (the row label as a
    ``torch.long`` scalar).

    Args:
        csv_path: Path of the CSV file to read.
        tokenizer: Callable used as ``tokenizer(text, return_tensors="pt",
            padding="max_length", truncation=True, max_length=chunk_size)``.
        chunk_size: Window length in characters. It is also passed to the
            tokenizer as ``max_length``.
        stride: Offset between consecutive window starts.
        max_chunks_per_seq: Optional cap on the number of valid windows kept
            per row; ``None`` (or 0) keeps all of them.
    """

    def __init__(self, csv_path, tokenizer, chunk_size=512, stride=256, max_chunks_per_seq=None):
        self.df = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.chunk_size = chunk_size
        self.stride = stride
        self.max_chunks = max_chunks_per_seq

        # Number of usable (N-free) windows contributed by each CSV row.
        self.chunks_per_row = [
            self._count_chunks(seq.upper())
            for seq in self.df['sequence'].astype(str)
        ]

        # Running totals: cum[r] is the number of items in rows 0..r, which lets
        # __getitem__ map a flat index to its row with a binary search.
        self.cum = []
        total = 0
        for count in self.chunks_per_row:
            total += count
            self.cum.append(total)

    def _iter_valid_chunks(self, seq):
        """Yield, in order, every full-length window of ``seq`` without an 'N'."""
        for start in range(0, len(seq) - self.chunk_size + 1, self.stride):
            chunk = seq[start:start + self.chunk_size]
            if 'N' not in chunk:
                yield chunk

    def _count_chunks(self, seq):
        """Return how many dataset items the upper-cased sequence ``seq`` provides."""
        if len(seq) < self.chunk_size:
            # A short sequence is served as one (padded) item when it is
            # non-empty and free of 'N'; otherwise it provides nothing.
            return 1 if seq and 'N' not in seq else 0

        count = 0
        for _ in self._iter_valid_chunks(seq):
            count += 1
            if self.max_chunks and count >= self.max_chunks:
                break
        return count

    def __len__(self):
        return self.cum[-1] if self.cum else 0

    def __getitem__(self, idx):
        """Return the tokenized window and label stored at flat index ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            raise IndexError(f"index {idx} is out of range for a dataset of size {size}")

        # First row whose cumulative total exceeds idx; rows that contribute
        # zero items are skipped automatically by bisect_right.
        row_idx = bisect.bisect_right(self.cum, idx)
        # Items before this row are exactly cum[row_idx - 1].
        row_start = self.cum[row_idx - 1] if row_idx > 0 else 0
        local_idx = idx - row_start

        seq = str(self.df.iloc[row_idx]['sequence']).upper()
        label = int(self.df.iloc[row_idx]['label'])

        if len(seq) < self.chunk_size:
            # NOTE(review): short sequences are right-padded with the base 'A'
            # up to chunk_size (behavior kept from the original code). This
            # injects artificial nucleotides; confirm it is intended, since the
            # tokenizer already pads to max_length.
            selected = seq.ljust(self.chunk_size, 'A')
        else:
            # Walk the valid windows until the local_idx-th one is reached.
            for position, chunk in enumerate(self._iter_valid_chunks(seq)):
                if position == local_idx:
                    selected = chunk
                    break
            else:
                # Only reachable if self.df changed after construction.
                raise IndexError(
                    f"row {row_idx} has no valid chunk at local index {local_idx}"
                )

        tokens = self.tokenizer(
            selected,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.chunk_size
        )

        return {
            "input_ids": tokens["input_ids"].squeeze(0),
            "attention_mask": tokens["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
        }

Classification model: GMO vs. non-GMO

In [None]:
class NTForGMO(nn.Module):
    """Pretrained encoder followed by a small two-class classification head.

    The encoder is loaded with ``AutoModel.from_pretrained(model_name, ...)``.
    The head maps the encoder's ``hidden_size`` features to 256 units (ReLU,
    dropout 0.2) and then to 2 logits.

    Args:
        model_name: Checkpoint identifier forwarded to ``AutoModel``.
    """

    def __init__(self, model_name):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            ignore_mismatched_sizes=True,
        )
        embed_dim = self.encoder.config.hidden_size

        head_layers = [
            nn.Linear(embed_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 2),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask, labels=None):
        """Encode the batch and classify it from the first token position.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask passed to the encoder.
            labels: Optional class indices; when given, a cross-entropy loss
                against the logits is returned as well.

        Returns:
            dict: ``{"loss": loss_or_None, "logits": logits}``.
        """
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Representation at sequence position 0 (used as the CLS embedding).
        first_token = encoded.last_hidden_state[:, 0, :]
        logits = self.classifier(first_token)

        if labels is None:
            return {"loss": None, "logits": logits}

        criterion = nn.CrossEntropyLoss()
        return {"loss": criterion(logits, labels), "logits": logits}


In [None]:
# Training split: one dataset item per valid window of each sequence.
train_ds = GMODataset(csv_path="data/processed/splits/train.csv", tokenizer=tokenizer)

In [None]:
# Validation split, built with the same default windowing settings.
val_ds = GMODataset(csv_path="data/processed/splits/val.csv", tokenizer=tokenizer)

In [None]:
# Training batches are shuffled; batch_size=1 is combined with gradient
# accumulation in the training cell to reach a larger effective batch.
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)
# The evaluation cell iterates over val_loader, so it has to exist after a
# fresh "Restart & Run All". Validation order does not matter: no shuffling.
val_loader = DataLoader(val_ds, batch_size=8)

In [None]:
# Build the classifier, then move all of its parameters to the target device.
model = NTForGMO(MODEL_NAME)
model = model.to(DEVICE)

In [None]:
# Number of full passes over the training loader.
EPOCHS = 5

# AdamW over every model parameter (encoder and classification head alike).
optimizer = AdamW(params=model.parameters(), lr=2e-5)

In [None]:
def create_dataloader(dataset, batch_size=8, shuffle=True, num_workers=4, pin_memory=None):
    """Wrap ``dataset`` in a ``DataLoader`` with the notebook's usual settings.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Number of items per batch.
        shuffle: Whether to reshuffle the items every epoch.
        num_workers: Number of worker processes used for loading.
        pin_memory: Whether batches are placed in pinned host memory. ``None``
            (default) enables it only when CUDA is available, so CPU-only runs
            do not request pinned memory they cannot use.

    Returns:
        DataLoader: The configured loader.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory
    )


Training

In [None]:
from torch.amp import autocast, GradScaler

# Mixed precision (and loss scaling) is only enabled on CUDA. On CPU both the
# autocast context and the scaler are disabled, so the loop runs in full
# precision with the scaler acting as a pass-through.
use_amp = DEVICE == "cuda"
scaler = GradScaler(device=DEVICE, enabled=use_amp)
accum_steps = 4  # Effective batch size = loader batch_size * accum_steps

num_batches = len(train_loader)

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    optimizer.zero_grad()
    for step, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")):
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        # torch.amp.autocast requires the device type explicitly.
        with autocast(device_type=DEVICE, enabled=use_amp):
            outputs = model(**batch)
            loss = outputs["loss"]

        # Track the un-divided loss so the epoch average is a true per-batch mean.
        total_loss += loss.item()

        # Divide by accum_steps so the accumulated gradient is an average.
        scaler.scale(loss / accum_steps).backward()

        # Step every accum_steps micro-batches, and also on the final batch so
        # a trailing partial accumulation is not thrown away by zero_grad().
        is_last_batch = (step + 1) == num_batches
        if (step + 1) % accum_steps == 0 or is_last_batch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    print(f"Epoch {epoch+1} | Train loss: {total_loss / max(num_batches, 1):.4f}")


Evaluation

In [None]:
from sklearn.metrics import classification_report
import numpy as np

# Requires `val_loader` to be defined (a DataLoader over the validation set).
model.eval()
y_true, y_pred = [], []

with torch.no_grad():
    for batch in val_loader:
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        outputs = model(**batch)

        # Predicted class = index of the largest logit for each item.
        preds = torch.argmax(outputs["logits"], dim=1)
        y_pred.extend(preds.cpu().tolist())
        y_true.extend(batch["labels"].cpu().tolist())

# Passing labels=[0, 1] keeps the report aligned with target_names even when
# only one class occurs in y_true/y_pred (otherwise scikit-learn raises because
# the number of classes does not match target_names). zero_division=0 reports
# 0 instead of warning for a class with no predictions or no samples.
print(
    classification_report(
        y_true,
        y_pred,
        labels=[0, 1],
        target_names=["Non-GMO", "GMO"],
        zero_division=0,
    )
)


In [None]:
import gc
import torch

# Drop unreachable Python objects first, then ask PyTorch to release the CUDA
# memory it is caching.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Dataset over the full processed CSV, keeping at most 10 valid windows per sequence.
dataset = GMODataset(
    "data/processed/data.csv",
    tokenizer,
    chunk_size=512,
    stride=256,
    max_chunks_per_seq=10,
)

loader = create_dataloader(dataset=dataset, batch_size=4)