In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import classification_report
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

In [2]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [3]:
# Pretrained checkpoint on the Hugging Face Hub used for both tokenizer and encoder.
MODEL_NAME = "InstaDeepAI/nucleotide-transformer-V2-250m-multi-species"
# Intended window length (characters of DNA) for chunking sequences.
# NOTE(review): GMODataset takes its own chunk_size argument (default 512); keep the two in sync.
MAX_LEN = 512

# Training is stochastic (classifier-head initialisation, DataLoader shuffling,
# dropout); seed torch's global RNG so a Restart & Run All is reproducible.
SEED = 42
torch.manual_seed(SEED)

In [4]:
# Tokenizer that matches the pretrained encoder checkpoint.
# NOTE(review): trust_remote_code=True lets code shipped with the checkpoint run
# locally; pin a `revision=` if the source of that code matters.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=MODEL_NAME,
    trust_remote_code=True,
)


tokenizer_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/101 [00:00<?, ?B/s]

In [5]:
def chunk_sequence(seq, chunk_size=512, stride=256, include_tail=True):
    """Split a DNA sequence into overlapping, upper-cased windows.

    Windows of length ``chunk_size`` start every ``stride`` characters. A plain
    strided scan loses data in two cases, which ``include_tail`` (default
    ``True``) covers:

    * a sequence shorter than ``chunk_size`` would yield no window at all and
      silently vanish from the dataset; it is instead yielded whole (the
      tokenizer call in ``GMODataset`` pads to ``max_length``, so short windows
      are accepted);
    * when ``stride`` does not divide ``len(seq) - chunk_size``, the final
      characters would never be covered; one extra window aligned to the end of
      the sequence is yielded.

    Pass ``include_tail=False`` for the plain strided scan only.

    Args:
        seq: the sequence string (any case).
        chunk_size: window length in characters.
        stride: distance between consecutive window starts.
        include_tail: also emit the short-sequence / end-aligned windows described above.

    Yields:
        str: upper-cased windows in order of start position.

    Raises:
        ValueError: if ``chunk_size`` or ``stride`` is not positive (raised on the
            first ``next()``, because this is a generator).
    """
    if chunk_size <= 0 or stride <= 0:
        raise ValueError(f"chunk_size and stride must be positive, got {chunk_size} and {stride}")

    seq = seq.upper()
    seq_len = len(seq)
    if seq_len == 0:
        return

    if seq_len < chunk_size:
        if include_tail:
            yield seq
        return

    last_start = seq_len - chunk_size
    for start in range(0, last_start + 1, stride):
        yield seq[start:start + chunk_size]

    # The strided scan stopped short of the end; add one window flush with the end.
    if include_tail and last_start % stride != 0:
        yield seq[last_start:]

In [6]:
class GMODataset(Dataset):
    """Chunk-level dataset of DNA sequences for GMO / non-GMO classification (first version).

    NOTE(review): the next cell redefines ``GMODataset`` (adding ``seq_id`` and
    ``chunk_id`` to each sample), so this definition is shadowed as soon as that
    cell runs. Delete this cell once the extended version is confirmed.

    Each CSV row is split into overlapping windows with ``chunk_sequence``.
    Windows containing ``N`` are dropped, and every remaining window becomes one
    ``(chunk, label)`` sample.

    Args:
        csv_path: path to a CSV with ``sequence`` and ``label`` columns.
        tokenizer: Hugging Face-style tokenizer, called once per sample in ``__getitem__``.
        chunk_size: window length in characters; also the tokenizer's ``max_length``.
        stride: step between consecutive window starts.
        max_chunks_per_seq: keep at most this many windows per row (counted
            before the ``N`` filter); ``None`` keeps all.
    """

    def __init__(self, csv_path, tokenizer, chunk_size=512, stride=256, max_chunks_per_seq=None):

        self.df = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.chunk_size = chunk_size
        self.stride = stride
        self.max_chunks = max_chunks_per_seq

        self.samples = []
        self._prepare_samples()

    def _prepare_samples(self):
        """Populate ``self.samples`` with ``(chunk, label)`` tuples."""
        # Iterate the two needed columns directly; iterrows builds a Series per row.
        for seq, label in zip(self.df["sequence"], self.df["label"]):
            # pandas reads empty cells as NaN (a float); skip those rows instead of crashing.
            if not isinstance(seq, str):
                continue

            chunks = list(chunk_sequence(seq, self.chunk_size, self.stride))
            # `is not None` so that max_chunks_per_seq=0 means "no windows", not "no limit".
            if self.max_chunks is not None:
                chunks = chunks[:self.max_chunks]

            label = int(label)
            for chunk in chunks:
                # Windows with ambiguous bases are excluded from training/evaluation.
                if "N" in chunk:
                    continue
                self.samples.append((chunk, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize one window and return ``input_ids``, ``attention_mask`` and ``labels``.

        ``input_ids`` / ``attention_mask`` have the tokenizer's leading batch
        dimension squeezed away; ``labels`` is a scalar long tensor.
        """
        seq, label = self.samples[idx]

        tokens = self.tokenizer(
            seq,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.chunk_size
        )

        return {
            "input_ids": tokens["input_ids"].squeeze(0),
            "attention_mask": tokens["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long)
        }


In [7]:
class GMODataset(Dataset):
    """Chunk-level dataset of DNA sequences for GMO / non-GMO classification.

    Each CSV row is split into overlapping windows with ``chunk_sequence``.
    Windows containing ``N`` are dropped, and every remaining window becomes one
    sample. Each sample remembers the row it came from (``seq_id``) and its
    window position (``chunk_id``), so chunk predictions can be aggregated back
    per sequence.

    Expected CSV columns: ``sequence`` and ``label`` (int-convertible). An
    optional ``id`` column is used as ``seq_id``; otherwise ``seq_id`` is
    ``seq_{row index:06d}``.

    Args:
        csv_path: path to the CSV file.
        tokenizer: Hugging Face-style tokenizer, called once per sample in ``__getitem__``.
        chunk_size: window length in characters; also passed as the tokenizer's
            ``max_length`` (padding / truncation target).
        stride: step between consecutive window starts.
        max_chunks_per_seq: keep at most this many windows per row (counted
            before the ``N`` filter); ``None`` keeps all.
    """

    def __init__(self, csv_path, tokenizer, chunk_size=512, stride=256, max_chunks_per_seq=None):

        self.df = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.chunk_size = chunk_size
        self.stride = stride
        self.max_chunks = max_chunks_per_seq

        self.samples = []
        self._prepare_samples()

    def _prepare_samples(self):
        """Populate ``self.samples`` with one dict per usable window."""
        # Plain dicts per row are cheaper than the Series that iterrows builds.
        for idx, row in zip(self.df.index, self.df.to_dict("records")):
            seq = row["sequence"]
            # pandas reads empty cells as NaN (a float); skip those rows instead of crashing on .upper().
            if not isinstance(seq, str):
                continue
            seq_id = row.get("id", f"seq_{idx:06d}")
            label = int(row["label"])

            chunks = list(chunk_sequence(seq.upper(), self.chunk_size, self.stride))
            # `is not None` so that max_chunks_per_seq=0 means "no windows", not "no limit".
            if self.max_chunks is not None:
                chunks = chunks[:self.max_chunks]

            # chunk_id is the window's position before the N filter, so ids may have gaps.
            for chunk_id, chunk in enumerate(chunks):
                if "N" in chunk:
                    continue
                self.samples.append({
                    "seq_id": seq_id,
                    "chunk_id": chunk_id,
                    "sequence": chunk,
                    "label": label,  # key must match the lookup in __getitem__
                })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize one window and return model inputs plus bookkeeping fields.

        Returns a dict with ``input_ids`` / ``attention_mask`` (tokenizer output
        with the leading batch dimension squeezed away), a scalar long
        ``labels`` tensor, and the sample's ``seq_id`` and ``chunk_id``.
        ``seq_id`` / ``chunk_id`` are not model inputs; drop them before calling
        the model.
        """
        sample = self.samples[idx]

        tokens = self.tokenizer(
            sample["sequence"],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.chunk_size
        )

        return {
            "input_ids": tokens["input_ids"].squeeze(0),
            "attention_mask": tokens["attention_mask"].squeeze(0),
            "labels": torch.tensor(sample["label"], dtype=torch.long),
            "seq_id": sample["seq_id"],
            "chunk_id": sample["chunk_id"]
        }


Classification model: GMO vs. non-GMO

In [8]:
class NTForGMO(nn.Module):
    """Sequence classifier: pretrained transformer encoder + small MLP head.

    Args:
        model_name: Hugging Face model id passed to ``AutoModel.from_pretrained``.
        num_labels: number of output classes (default 2: non-GMO / GMO).
        hidden_dim: width of the hidden layer in the classification head.
        dropout: dropout probability in the classification head.
        pooling: how token embeddings are reduced to one vector per sequence:
            ``"cls"`` (default) takes the first token position;
            ``"mean"`` averages the token embeddings where ``attention_mask`` is 1,
            so padding does not dilute the average.

    Raises:
        ValueError: if ``pooling`` is not ``"cls"`` or ``"mean"``.
    """

    def __init__(self, model_name, num_labels=2, hidden_dim=256, dropout=0.2, pooling="cls"):
        super().__init__()
        if pooling not in ("cls", "mean"):
            raise ValueError(f"pooling must be 'cls' or 'mean', got {pooling!r}")
        self.pooling = pooling

        self.encoder = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        hidden_size = self.encoder.config.hidden_size

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_labels)
        )
        # The loss has no learnable state; build it once instead of on every forward pass.
        self.loss_fn = nn.CrossEntropyLoss()

    def _pool(self, hidden, attention_mask):
        """Reduce (batch, tokens, hidden) embeddings to (batch, hidden) per ``self.pooling``."""
        if self.pooling == "cls":
            # First token position (the CLS token when the tokenizer prepends one).
            return hidden[:, 0, :]
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # clamp avoids division by zero for an all-padding row.
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def forward(self, input_ids, attention_mask, labels=None):
        """Return ``{"loss": loss or None, "logits": (batch, num_labels)}``.

        ``loss`` is the cross-entropy against ``labels`` when they are given.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        pooled = self._pool(outputs.last_hidden_state, attention_mask)
        logits = self.classifier(pooled)

        loss = self.loss_fn(logits, labels) if labels is not None else None

        return {
            "loss": loss,
            "logits": logits
        }


In [9]:
# Chunk-level training samples from the train split (default chunk_size / stride).
train_ds = GMODataset(
    csv_path="data/processed/splits/train.csv",
    tokenizer=tokenizer,
)

In [None]:
# Chunk-level validation samples, chunked the same way as the training split.
val_ds = GMODataset(
    csv_path="data/processed/splits/val.csv",
    tokenizer=tokenizer,
)

In [None]:
# Training hyper-parameters, named so they are tuned in one place.
BATCH_SIZE = 8
LEARNING_RATE = 2e-5
EPOCHS = 5

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
# Validation order stays fixed (no shuffling).
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

# Move the model to the device before building the optimizer over its parameters.
model = NTForGMO(MODEL_NAME).to(DEVICE)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

In [None]:
def create_dataloader(dataset, batch_size=8, shuffle=True, num_workers=4, pin_memory=None):
    """Wrap ``dataset`` in a ``DataLoader`` with this notebook's defaults.

    Args:
        dataset: map-style dataset (anything with ``__len__`` / ``__getitem__``).
        batch_size: samples per batch.
        shuffle: reshuffle every epoch; pass ``False`` for evaluation.
        num_workers: worker processes for loading; 0 loads in the main process.
            NOTE(review): with the "spawn" start method (default on Windows/macOS)
            workers must be able to re-import the dataset class, which can fail for
            classes defined in a notebook. Use 0 if workers fail to start.
        pin_memory: page-lock host batches to speed up host-to-GPU copies.
            ``None`` (default) enables it only when CUDA is available, because
            pinning brings no benefit without a GPU.

    Returns:
        torch.utils.data.DataLoader
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


Training

In [None]:
# Batch fields accepted by NTForGMO.forward. The dataset may also yield bookkeeping
# fields (seq_id, collated into a list of strings, and chunk_id) that the model does
# not accept and that cannot all be moved to a device, so only these are forwarded.
MODEL_INPUT_KEYS = ("input_ids", "attention_mask", "labels")

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    n_batches = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{EPOCHS}"):
        optimizer.zero_grad()

        inputs = {key: batch[key].to(DEVICE) for key in MODEL_INPUT_KEYS}
        outputs = model(**inputs)

        loss = outputs["loss"]
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    # max(..., 1) guards an empty loader against division by zero.
    mean_loss = total_loss / max(n_batches, 1)
    print(f"Epoch {epoch+1} | Train loss: {mean_loss:.4f}")


Evaluation

In [None]:
LABEL_NAMES = ["Non-GMO", "GMO"]

model.eval()
y_true, y_pred = [], []
# seq_id -> list of chunk-level GMO probabilities, and seq_id -> true label.
# Only filled when the dataset yields a seq_id field.
seq_gmo_probs = {}
seq_true_label = {}

with torch.no_grad():
    for batch in val_loader:
        # Pass only the tensors forward() accepts; seq_id / chunk_id stay on the host.
        logits = model(
            input_ids=batch["input_ids"].to(DEVICE),
            attention_mask=batch["attention_mask"].to(DEVICE),
        )["logits"]
        labels = batch["labels"].tolist()

        y_pred.extend(logits.argmax(dim=1).cpu().tolist())
        y_true.extend(labels)

        if "seq_id" in batch:
            seq_ids = batch["seq_id"]
            # Numeric ids are collated into a tensor; convert so they hash by value.
            if torch.is_tensor(seq_ids):
                seq_ids = seq_ids.tolist()
            gmo_probs = torch.softmax(logits, dim=1)[:, 1].cpu().tolist()
            for sid, prob, label in zip(seq_ids, gmo_probs, labels):
                seq_gmo_probs.setdefault(sid, []).append(prob)
                seq_true_label[sid] = label

# labels=[0, 1] keeps the report valid even if one class is absent from this split.
print("Chunk-level:")
print(classification_report(y_true, y_pred, labels=[0, 1], target_names=LABEL_NAMES, zero_division=0))

if seq_gmo_probs:
    seq_keys = list(seq_gmo_probs)
    seq_true = [seq_true_label[s] for s in seq_keys]
    # A sequence is called GMO when the mean GMO probability of its chunks is >= 0.5.
    seq_pred = [int(np.mean(seq_gmo_probs[s]) >= 0.5) for s in seq_keys]
    print("Sequence-level (mean chunk probability):")
    print(classification_report(seq_true, seq_pred, labels=[0, 1], target_names=LABEL_NAMES, zero_division=0))


In [None]:
# Chunked dataset over the full processed table, at most 10 windows per sequence,
# with a shuffled loader. Window length follows MAX_LEN from the config cell and
# the stride is half a window (50% overlap); with MAX_LEN = 512 this is 512 / 256.
dataset = GMODataset(
    csv_path="data/processed/data.csv",
    tokenizer=tokenizer,
    chunk_size=MAX_LEN,
    stride=MAX_LEN // 2,
    max_chunks_per_seq=10
)
# A shuffling DataLoader rejects an empty dataset with an opaque sampler error;
# fail early with a clear message instead.
assert len(dataset) > 0, "no usable chunks in data/processed/data.csv"

loader = create_dataloader(dataset, batch_size=4)