In [22]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from tqdm import tqdm

In [23]:
# CSV with one row per sample.
# NOTE(review): this is an absolute, machine-specific path; the notebook will
# not run elsewhere until it is pointed at a local copy of the file.
path_to_data = "/Users/abdoulabdillahi/Desktop/Thesis/Bio_project/200_samples_with_encoded.csv"

# Integer id of every sequence symbol, assigned in the order A, C, G, T, '-', N.
NUCLEOTIDE_TO_INT = {symbol: code for code, symbol in enumerate("ACGT-N")}

In [24]:
# Echo the symbol-to-id mapping so its contents are visible in the notebook output.
NUCLEOTIDE_TO_INT


{'A': 0, 'C': 1, 'G': 2, 'T': 3, '-': 4, 'N': 5}

### Dataset class

In [25]:
class GeneDataset(Dataset):
    """Map-style dataset yielding per-gene token lists and an integer label.

    The CSV is loaded once in ``__init__``. Every column whose name starts
    with ``gene_prefix`` is treated as one gene; each cell is converted to a
    list of tokens when a sample is requested.

    Args:
        csv_file: Anything accepted by ``pandas.read_csv``.
        label_col: Name of the column holding the label (converted with ``int``).
        gene_prefix: Prefix that selects the gene columns.
        transform: Optional callable applied to the sample dict before it is
            returned.
    """

    # Punctuation that text renderings of lists/arrays leave attached to a
    # token, e.g. "['A' 'C']" or "[A, C]". A token such as "'A'" would never
    # match a bare symbol downstream, so it is stripped off.
    # NOTE(review): the exact cell format of the CSV is not visible here —
    # confirm that the parsed tokens are the expected nucleotide symbols.
    _TOKEN_WRAPPERS = "'\","

    def __init__(self, csv_file, label_col='Ciprofloxacin_NS', gene_prefix='gene_', transform=None):
        self.df = pd.read_csv(csv_file)
        self.label_col = label_col
        self.gene_cols = [c for c in self.df.columns if c.startswith(gene_prefix)]
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the CSV."""
        return len(self.df)

    @classmethod
    def _parse_cell(cls, cell):
        """Convert one gene cell into a list of tokens."""
        if isinstance(cell, str):
            if cell.startswith('[') and cell.endswith(']'):
                # Bracketed text: split on whitespace, peel quotes/commas off
                # every token, then drop empty leftovers and 'nan' markers.
                cleaned = (tok.strip(cls._TOKEN_WRAPPERS) for tok in cell.strip('[]').split())
                return [tok for tok in cleaned if tok and tok.lower() != 'nan']
            # Any other string is read as one character per position.
            return list(cell)

        # Only scalars are tested for missingness: pd.isna on a list/array
        # returns an array, whose truth value is ambiguous and would raise.
        if pd.api.types.is_scalar(cell) and pd.isna(cell):
            return []

        # Already a sequence (list, tuple, array): copy it into a list.
        return list(cell)

    def __getitem__(self, idx):
        """Return ``{'sequences': [...], 'label': int}`` for row ``idx``.

        ``sequences`` holds one token list per gene column, in column order;
        a missing cell yields an empty list. The dict is passed through
        ``transform`` when one was supplied.
        """
        row = self.df.iloc[idx]
        label = int(row[self.label_col])
        sequences = [self._parse_cell(row[col]) for col in self.gene_cols]

        sample = {'sequences': sequences, 'label': label}
        return self.transform(sample) if self.transform else sample

def collate_fn(batch):
    """Pad a list of samples into one index tensor plus a label tensor.

    Args:
        batch: List of dicts with keys ``'sequences'`` (one token list per
            gene) and ``'label'`` (int).

    Returns:
        ``(x, labels)`` where ``x`` is a ``torch.long`` tensor of shape
        ``(B, G, L)`` and ``labels`` is a ``torch.long`` tensor of shape
        ``(B,)``. ``L`` is the longest token list in the batch (at least 1).
        Positions past the end of a sequence, and tokens not present in
        ``NUCLEOTIDE_TO_INT``, hold the id of ``'N'``.
    """
    pad_id = NUCLEOTIDE_TO_INT['N']
    labels = torch.tensor([b['label'] for b in batch], dtype=torch.long)
    seqs = [b['sequences'] for b in batch]
    B, G = len(seqs), len(seqs[0])

    # Longest sequence over every sample and gene. `default=0` covers a batch
    # without gene columns; the floor of 1 keeps the last dimension non-empty
    # so a later mean over it cannot produce NaN.
    L_max = max((len(gene) for sample in seqs for gene in sample), default=0)
    L_max = max(L_max, 1)

    x = torch.full((B, G, L_max), fill_value=pad_id, dtype=torch.long)

    for i, sample in enumerate(seqs):
        for j, gene in enumerate(sample):
            if not gene:
                continue
            # Encode the whole gene in Python, then copy it with a single
            # slice assignment; writing tensor elements one at a time is
            # far slower for long sequences.
            ids = [NUCLEOTIDE_TO_INT.get(str(nt).upper(), pad_id) for nt in gene]
            x[i, j, :len(ids)] = torch.tensor(ids, dtype=torch.long)
    return x, labels


In [26]:
# Load the CSV-backed dataset and wrap it in a shuffling mini-batch loader.
train_dataset = GeneDataset(path_to_data)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
    # GeneClassifier contains BatchNorm1d, which raises in training mode when
    # a batch holds a single sample. With batch_size=2 that happens on the
    # final batch whenever the dataset has an odd number of rows, so the
    # incomplete trailing batch is dropped.
    drop_last=True,
)

### Simple PyTorch NN model

In [27]:
class GeneClassifier(nn.Module):
    """Embed nucleotide ids, average them per gene, and classify with an MLP.

    Args:
        num_genes: Number of genes per sample (size of dimension 1 of the input).
        vocab_size: Number of distinct token ids; the last id
            (``vocab_size - 1``) is used as the padding index.
        embed_dim: Size of each token embedding.
        hidden_dim: Width of the hidden layer of the MLP head.
        num_classes: Number of output logits.
    """

    def __init__(self,
                 num_genes: int,
                 vocab_size: int = 6,     # A/C/G/T/-/N → 6 tokens
                 embed_dim: int = 16,      # size of nucleotide embedding
                 hidden_dim: int = 128,    # MLP hidden size
                 num_classes: int = 2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=vocab_size-1)
        # after embedding: (B, G, L, E) → masked mean over L → (B, G, E)
        self.fc = nn.Sequential(
            nn.Linear(num_genes * embed_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        """Compute class logits.

        Args:
            x: LongTensor of token ids with shape ``(B, G, L)``.

        Returns:
            Tensor of shape ``(B, num_classes)``.
        """
        B, G, _ = x.shape
        # True where a position holds a real token rather than padding.
        keep = x.ne(self.embed.padding_idx)                    # (B, G, L)
        emb = self.embed(x)                                    # (B, G, L, E)

        # Average over non-padding positions only. A plain mean over L divides
        # by the padded length, which makes a sample's features depend on the
        # longest sequence that happens to share its batch.
        summed = (emb * keep.unsqueeze(-1).to(emb.dtype)).sum(dim=2)   # (B, G, E)
        # clamp avoids 0/0 for genes that contain no real token at all
        counts = keep.sum(dim=2, keepdim=True).clamp(min=1)            # (B, G, 1)
        pooled = summed / counts

        # flatten genes → (B, G*E)
        return self.fc(pooled.reshape(B, G * pooled.size(-1)))

### Instantiate model, loss, and optimizer

In [28]:
# Run on the GPU when CUDA reports one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# The classifier's input width is derived from the number of gene columns.
num_genes = len(train_dataset.gene_cols)
model = GeneClassifier(num_genes).to(device)

# Cross-entropy over the class logits, optimised with Adam (learning rate 1e-3).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [29]:
num_epochs = 1
for epoch in range(num_epochs):
    # Put the model in training mode and reset the running statistics
    # reported at the end of the epoch.
    model.train()
    total_loss, correct, total = 0.0, 0, 0

    for x_batch, y_batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # inputs (B, G, L) and labels (B,) go to the model's device
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        # one optimisation step: clear gradients, forward, backward, update
        optimizer.zero_grad()
        logits = model(x_batch)         # (B, 2)
        loss = criterion(logits, y_batch)
        loss.backward()
        optimizer.step()

        # weight the batch-mean loss by the batch size so the epoch figure
        # is a per-sample average
        batch_size = x_batch.shape[0]
        total_loss += batch_size * loss.item()
        total += batch_size

        # count correct predictions from the same forward pass
        preds = torch.argmax(logits, dim=1)
        correct += (preds == y_batch).sum().item()

    avg_loss, acc = total_loss / total, correct / total

    print(f"Epoch {epoch+1} — avg training loss: {avg_loss:.4f} — training accuracy: {acc:.4%}")

Epoch 1/1:   9%|▉         | 7/77 [02:12<22:07, 18.96s/it]


KeyboardInterrupt: 