In [28]:
# Cell 1
import random, math
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd

random.seed(1)
torch.manual_seed(1)

<torch._C.Generator at 0x7f5144758830>

In [29]:
# Cell 2
class KmerTokenizer:
    def __init__(self, k=4):
        self.k = k
        bases = ['A','C','G','T','N']
        from itertools import product
        self.vocab = {'<pad>':0, '<unk>':1}
        idx = 2
        for kmer in map(''.join, product(bases, repeat=self.k)):
            self.vocab[kmer] = idx
            idx += 1
        self.vocab_size = len(self.vocab)
    def encode(self, seq):
        seq = seq.upper().replace('U','T')
        toks = []
        for i in range(0, max(1, len(seq)-self.k+1)):
            kmer = seq[i:i+self.k]
            toks.append(self.vocab.get(kmer, self.vocab['<unk>']))
        return toks[:512]  # Truncate to 512 tokens


In [30]:
# Cell 3
def load_16s_data(csv_path='16S_sequences.csv'):
    df = pd.read_csv(csv_path)
    data = []
    taxa = set()
    for _, row in df.iterrows():
        # Extract genus from taxonomy string
        tax_parts = row['taxonomy'].split()
        genus = tax_parts[0] if tax_parts else 'Unknown'
        taxa.add(genus)

    taxa = sorted(list(taxa))
    taxa_to_idx = {t:i for i,t in enumerate(taxa)}

    for _, row in df.iterrows():
        tax_parts = row['taxonomy'].split()
        genus = tax_parts[0] if tax_parts else 'Unknown'
        gc = (row['sequence'].count('G') + row['sequence'].count('C')) / len(row['sequence'])
        novel = 1 if gc > 0.7 or gc < 0.3 else 0  # Extreme GC content as novelty indicator
        data.append({
            'seq': row['sequence'],
            'tax_idx': taxa_to_idx[genus],
            'role_idx': 0,  # Single role for simplicity
            'novel': novel
        })
    return data, taxa


In [31]:
# Cell 4
tokenizer = KmerTokenizer(k=4)

class ASVDataset(Dataset):
    def __init__(self, records, tokenizer):
        self.records = records
        self.tokenizer = tokenizer
    def __len__(self):
        return len(self.records)
    def __getitem__(self, idx):
        rec = self.records[idx]
        toks = self.tokenizer.encode(rec['seq'])
        return {
            'tokens': torch.tensor(toks, dtype=torch.long),
            'tax_idx': torch.tensor(rec['tax_idx'], dtype=torch.long),
            'role_idx': torch.tensor(rec['role_idx'], dtype=torch.long),
            'novel': torch.tensor(rec['novel'], dtype=torch.float)
        }

def collate_fn(batch):
    tokens = [b['tokens'] for b in batch]
    lengths = [t.size(0) for t in tokens]
    maxlen = max(lengths)
    padded = torch.zeros(len(tokens), maxlen, dtype=torch.long)
    mask = torch.zeros(len(tokens), maxlen, dtype=torch.bool)
    for i,t in enumerate(tokens):
        padded[i,:t.size(0)] = t
        mask[i,:t.size(0)] = 1
    tax_idx = torch.stack([b['tax_idx'] for b in batch])
    role_idx = torch.stack([b['role_idx'] for b in batch])
    novel = torch.stack([b['novel'] for b in batch])
    # src_key_padding_mask: True where padding
    src_key_padding_mask = ~mask
    return padded, src_key_padding_mask, tax_idx, role_idx, novel


In [32]:
# Cell 5
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class MultiTaskTaxonomyModel(nn.Module):
    def __init__(self, vocab_size, d_model=64, nhead=2, num_layers=1, tax_classes=3, role_classes=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos = PositionalEncoding(d_model, max_len=512)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=128, dropout=0.1, activation='relu')
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.tax_head = nn.Linear(d_model, tax_classes)
        self.role_head = nn.Linear(d_model, role_classes)
        self.novel_head = nn.Linear(d_model, 1)
    def forward(self, x, src_key_padding_mask=None):
        emb = self.embed(x) * math.sqrt(self.embed.embedding_dim)
        emb = self.pos(emb)
        emb_t = emb.transpose(0,1)  # [L,B,D]
        enc = self.encoder(emb_t, src_key_padding_mask=src_key_padding_mask)
        enc = enc.transpose(0,1)   # [B,L,D]
        if src_key_padding_mask is not None:
            mask = ~src_key_padding_mask  # True at valid positions
            mask = mask.unsqueeze(-1).float()
            enc = enc * mask
            pooled = enc.sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        else:
            pooled = enc.mean(dim=1)
        return {
            'tax_logits': self.tax_head(pooled),
            'role_logits': self.role_head(pooled),
            'novel_logits': self.novel_head(pooled).squeeze(-1)
        }


In [33]:
# Cell 6
records, taxa = load_16s_data('16S_sequences.csv')
print(f'Loaded {len(records)} sequences with {len(taxa)} taxa')
split_idx = int(0.8 * len(records))
train = records[:split_idx]
val = records[split_idx:]

train_ds = ASVDataset(train, tokenizer)
val_ds = ASVDataset(val, tokenizer)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, collate_fn=collate_fn)

Loaded 27313 sequences with 4307 taxa


In [34]:
# Cell 7
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultiTaskTaxonomyModel(vocab_size=tokenizer.vocab_size, d_model=64, nhead=2, num_layers=1, tax_classes=len(taxa)).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=5e-4)
ce = nn.CrossEntropyLoss()
bce = nn.BCEWithLogitsLoss()




In [35]:
# Cell 8
def evaluate(loader):
    model.eval()
    total = 0
    correct_tax = 0
    correct_role = 0
    with torch.no_grad():
        for padded, src_key_padding_mask, tax_idx, role_idx, novel in loader:
            padded = padded.to(device); src_key_padding_mask = src_key_padding_mask.to(device)
            tax_idx = tax_idx.to(device); role_idx = role_idx.to(device)
            out = model(padded, src_key_padding_mask=src_key_padding_mask)
            correct_tax += (out['tax_logits'].argmax(dim=1) == tax_idx).sum().item()
            correct_role += (out['role_logits'].argmax(dim=1) == role_idx).sum().item()
            total += tax_idx.size(0)
    return correct_tax/total, correct_role/total

max_epochs = 20
early_stopping_patience = 3
best_val_loss = float('inf')
epochs_no_improve = 0

for epoch in range(1, max_epochs + 1):
    model.train()
    running = 0.0
    for padded, src_key_padding_mask, tax_idx, role_idx, novel in train_loader:
        padded = padded.to(device); src_key_padding_mask = src_key_padding_mask.to(device)
        tax_idx = tax_idx.to(device); role_idx = role_idx.to(device); novel = novel.to(device)
        out = model(padded, src_key_padding_mask=src_key_padding_mask)
        loss = ce(out['tax_logits'], tax_idx) + ce(out['role_logits'], role_idx) + 1.5*bce(out['novel_logits'], novel)
        opt.zero_grad(); loss.backward(); opt.step()
        running += loss.item() * tax_idx.size(0)
    avg_train_loss = running / len(train)

    model.eval()
    running_val_loss = 0.0
    with torch.no_grad():
        for padded, src_key_padding_mask, tax_idx, role_idx, novel in val_loader:
            padded = padded.to(device); src_key_padding_mask = src_key_padding_mask.to(device)
            tax_idx = tax_idx.to(device); role_idx = role_idx.to(device); novel = novel.to(device)
            out = model(padded, src_key_padding_mask=src_key_padding_mask)
            loss = ce(out['tax_logits'], tax_idx) + ce(out['role_logits'], role_idx) + 1.5*bce(out['novel_logits'], novel)
            running_val_loss += loss.item() * tax_idx.size(0)
    avg_val_loss = running_val_loss / len(val)

    tax_acc, role_acc = evaluate(val_loader)
    print(f"Epoch {epoch} | Train loss {avg_train_loss:.4f} | Val loss {avg_val_loss:.4f} | Val tax acc {tax_acc:.3f} | Val role acc {role_acc:.3f}")

    # Early stopping
    # if avg_val_loss < best_val_loss:
    #     best_val_loss = avg_val_loss
    #     epochs_no_improve = 0
    # else:
    #     epochs_no_improve += 1

    # if epochs_no_improve == early_stopping_patience:
    #     print(f"Early stopping after {epoch} epochs.")
    #     break

Epoch 1 | Train loss 7.4639 | Val loss 7.6403 | Val tax acc 0.027 | Val role acc 1.000
Epoch 2 | Train loss 6.8033 | Val loss 7.5541 | Val tax acc 0.050 | Val role acc 1.000
Epoch 3 | Train loss 6.5395 | Val loss 7.4046 | Val tax acc 0.070 | Val role acc 1.000
Epoch 4 | Train loss 6.2054 | Val loss 7.1359 | Val tax acc 0.076 | Val role acc 1.000
Epoch 5 | Train loss 5.8735 | Val loss 6.9244 | Val tax acc 0.092 | Val role acc 1.000
Epoch 6 | Train loss 5.5808 | Val loss 6.7407 | Val tax acc 0.105 | Val role acc 1.000
Epoch 7 | Train loss 5.3100 | Val loss 6.5141 | Val tax acc 0.121 | Val role acc 1.000
Epoch 8 | Train loss 5.0574 | Val loss 6.3385 | Val tax acc 0.133 | Val role acc 1.000
Epoch 9 | Train loss 4.8228 | Val loss 6.1672 | Val tax acc 0.151 | Val role acc 1.000
Epoch 10 | Train loss 4.6038 | Val loss 5.9991 | Val tax acc 0.163 | Val role acc 1.000
Epoch 11 | Train loss 4.3980 | Val loss 5.8592 | Val tax acc 0.178 | Val role acc 1.000
Epoch 12 | Train loss 4.1944 | Val loss 5

In [36]:
# Cell 9
model.eval()
for ex in val[:6]:
    toks = torch.tensor(tokenizer.encode(ex['seq']), dtype=torch.long).unsqueeze(0).to(device)
    mask = torch.zeros(1, toks.size(1), dtype=torch.bool).to(device)
    out = model(toks, src_key_padding_mask=~mask)
    tax_pred = taxa[out['tax_logits'].argmax(dim=1).item()]
    novel_score = torch.sigmoid(out['novel_logits']).item()
    print(f"True tax: {taxa[ex['tax_idx']]} | Pred tax: {tax_pred} | Novel score: {novel_score:.3f}")


True tax: Alloscardovia | Pred tax: Eubacterium | Novel score: 0.397
True tax: Bifidobacterium | Pred tax: Eubacterium | Novel score: 0.397
True tax: Galliscardovia | Pred tax: Eubacterium | Novel score: 0.397
True tax: Roseomonas | Pred tax: Eubacterium | Novel score: 0.397
True tax: Chthonobacter | Pred tax: Eubacterium | Novel score: 0.397
True tax: Actinophytocola | Pred tax: Eubacterium | Novel score: 0.397


In [37]:
# Cell 10
torch.save({'model_state_dict': model.state_dict(), 'tokenizer_vocab': tokenizer.vocab}, 'multitask_model_demo_small.pth')
print("Saved model to multitask_model_demo_small.pth")


Saved model to multitask_model_demo_small.pth


In [38]:
# Fix Cell - Recreate everything
del tokenizer, train_ds, val_ds, train_loader, val_loader, model

tokenizer = KmerTokenizer(k=4)
records, taxa = load_16s_data('16S_sequences.csv')
print(f'Loaded {len(records)} sequences with {len(taxa)} taxa')

# Test tokenization
test_tokens = tokenizer.encode(records[0]['seq'])
print(f'First sequence tokenized to {len(test_tokens)} tokens')

split_idx = int(0.8 * len(records))
train = records[:split_idx]
val = records[split_idx:]

train_ds = ASVDataset(train, tokenizer)
val_ds = ASVDataset(val, tokenizer)

train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, collate_fn=collate_fn)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultiTaskTaxonomyModel(vocab_size=tokenizer.vocab_size, d_model=64, nhead=2, num_layers=1, tax_classes=len(taxa)).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=5e-4)
ce = nn.CrossEntropyLoss()
bce = nn.BCEWithLogitsLoss()

print('Everything recreated successfully')


Loaded 27313 sequences with 4307 taxa
First sequence tokenized to 512 tokens
Everything recreated successfully




In [43]:
# Prediction Cell
def predict_taxonomy(sequence, model, tokenizer, taxa):
    model.eval()
    tokens = tokenizer.encode(sequence)
    input_tensor = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)
        print(f"Raw tax logits: {output['tax_logits']}") # Added print statement
        tax_probs = torch.softmax(output['tax_logits'], dim=1)
        novel_score = torch.sigmoid(output['novel_logits']).item()

        top_probs, top_indices = torch.topk(tax_probs, k=min(3, len(taxa)))

        print(f'Sequence length: {len(sequence)} bp')
        print(f'Tokenized length: {len(tokens)} k-mers')
        print(f'Novel score: {novel_score:.3f}')
        print('\nTop 3 Predictions:')
        for i, (prob, idx) in enumerate(zip(top_probs[0], top_indices[0])):
            genus = taxa[idx.item()]
            confidence = prob.item() * 100
            print(f'{i+1}. {genus} ({confidence:.1f}%)')

# Example prediction
test_seq = 'GATTATGGCTCAGAACGAACGCTGGCGGCAGGCCTAACACATGCAAGTCGAGCGCGCCTTTCGGGGCGAGCGGCGGACGGGTTAGTAACGCGTGGGAATGTACCCTTTTCTACGGAATAGCCTCGGGAAACTGAGATTAATACCGTATACGCCCCCCCAATCAAATTTCATTTGATTGAATTTTCAGTCATATCAAATTCCGAATGGAATTTGATGGGGGGGGAAAGATTTATCGGAGAAGGATCAGCCCGCGTTAGATTAGATAGTTGGTGGGGTAATGGCCTACCAAGTCTACGATCTATAGCTGGTTTGAGAGGATGATCAGCAACACTGGGACTGAGACACGGCCCAGACTCCTACGGGAGGCAGCAGTGGGGAATCTTAGACAATGGGCGCAAGCCTGATCTAGCCATGCCGCGTGAGTGAAGAAGGCCTTAGGGTCGTAAAGCTCTTTCAGTGGGGAAGATAATGACGGTACCCACAGAAGAAACCCCGGCTAACTCCGTGCCAGCAGCCGCGGTAATACGGAGGGGGTTAGCGTTGTTCGGAATTACTGGGCGTAAAGCGTACGTAGGCGGATTAGCAAGTTAGAGGTGAAATCCCAGGGCTCAACCTTGGAACTGCCTTTAAAACTGCTAGTCTTGAGTTCGAGAGAGGTGAGTGGAATTCCGAGTGTAGAGGTGAAATTCGTAGATATTCGGAGGAACACCAGTGGCGAAGGCGGCTCACTGGCTCGATACTGACGCTGAGGTACGAAAGCGTGGGGAGCAAACAGGATTAGATACCCTGGTAGTCCACGCCGTAAACGATGAGAGCTAGTCGTCGGGTTGCATGCAATTCGGTGACGCAGTTAACGCATTAAGCTCTCCGCCTGGGGAGTACGGTCGCAAGATTAAAACTCAAAGGAATTGACGGGGGCCCGCACAAGCGGTGGAGCATGTGGTTTAATTCGAAGCAACGCGCAGAACCTTACCAACCCTTGACATACCTGTCGCGGCCCGAGAGATCGGGCTTTCAGTTCGGCTGGACAGGATACAGGTGCTGCATGGCTGTCGTCAGCTCGTGTCGTGAGATGTTCGGTTAAGTCCGGCAACGAGCGCAACCCCTGCCTTTAGTTGCCAGCATTCAGTTGGGCACTCTAGAGGGACCGCCGGTGATAAGCCGGAGGAAGGTGGGGATGACGTCAAGTCCTCATGGCCCTTACGGGTTGGGCTACACACGTGCTACAATGGTAGTGACAATGGGTTAATCCCAAAAAGCTATCTCAGTTCGGATTGTCCTCTGCAACTCGAGGGCATGAAGTTGGAATCGCTAGTAATCGCGTAACAGCATGACGCGGTGAATACGTTCCCGGGCCTTGTACACACCGCCCGTCACACCATGGGAATTGGATCTACCCGAAGGCCGTGCGCTAATTTGGCAGCGGACCACGGTAGGTTCAGTGACTGGGGTGAAGTCGTAACAAGG'
predict_taxonomy(test_seq, model, tokenizer, taxa)

Raw tax logits: tensor([[-4.2747, -4.3353,  0.8509,  ..., -1.0977,  1.0979, -1.4322]],
       device='cuda:0')
Sequence length: 1467 bp
Tokenized length: 512 k-mers
Novel score: 0.003

Top 3 Predictions:
1. Chitinophaga (9.4%)
2. Pedobacter (6.8%)
3. Halolamina (3.6%)


In [44]:
# Load the saved model state dictionary
import torch
import math
from itertools import product
from torch import nn
import pandas as pd # Import pandas

# Define KmerTokenizer class (assuming it's not globally available)
class KmerTokenizer:
    def __init__(self, k=4):
        self.k = k
        bases = ['A','C','G','T','N']
        from itertools import product
        self.vocab = {'<pad>':0, '<unk>':1}
        idx = 2
        for kmer in map(''.join, product(bases, repeat=self.k)):
            self.vocab[kmer] = idx
            idx += 1
        self.vocab_size = len(self.vocab)
    def encode(self, seq):
        seq = seq.upper().replace('U','T')
        toks = []
        for i in range(0, max(1, len(seq)-self.k+1)):
            kmer = seq[i:i+self.k]
            toks.append(self.vocab.get(kmer, self.vocab['<unk>']))
        return toks[:512]  # Truncate to 512 tokens

# Define PositionalEncoding class (assuming it's not globally available)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

# Define MultiTaskTaxonomyModel class (assuming it's not globally available)
class MultiTaskTaxonomyModel(nn.Module):
    def __init__(self, vocab_size, d_model=64, nhead=2, num_layers=1, tax_classes=3, role_classes=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos = PositionalEncoding(d_model, max_len=512)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=128, dropout=0.1, activation='relu')
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.tax_head = nn.Linear(d_model, tax_classes)
        self.role_head = nn.Linear(d_model, role_classes)
        self.novel_head = nn.Linear(d_model, 1)
    def forward(self, x, src_key_padding_mask=None):
        emb = self.embed(x) * math.sqrt(self.embed.embedding_dim)
        emb = self.pos(emb)
        emb_t = emb.transpose(0,1)  # [L,B,D]
        enc = self.encoder(emb_t, src_key_padding_mask=src_key_padding_mask)
        enc = enc.transpose(0,1)   # [B,L,D]
        if src_key_padding_mask is not None:
            mask = ~src_key_padding_mask  # True at valid positions
            mask = mask.unsqueeze(-1).float()
            enc = enc * mask
            pooled = enc.sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        else:
            pooled = enc.mean(dim=1)
        return {
            'tax_logits': self.tax_head(pooled),
            'role_logits': self.role_head(pooled),
            'novel_logits': self.novel_head(pooled).squeeze(-1)
        }

# Assuming 'taxa' is a list containing the unique taxa from your data
# and its length represents the correct number of tax_classes (4307)
# If 'taxa' is not defined in a previous cell, you might need to load your data again to get it.
# For now, I will assume 'taxa' is available.
# If you still get a NameError for 'taxa', you might need to add a cell to load the data.

# Load data and get taxa to determine the correct number of tax_classes
def load_16s_data(csv_path='16S_sequences.csv'):
    df = pd.read_csv(csv_path)
    data = []
    taxa = set()
    for _, row in df.iterrows():
        # Extract genus from taxonomy string
        tax_parts = row['taxonomy'].split()
        genus = tax_parts[0] if tax_parts else 'Unknown'
        taxa.add(genus)

    taxa = sorted(list(taxa))
    taxa_to_idx = {t:i for i,t in enumerate(taxa)}

    for _, row in df.iterrows():
        tax_parts = row['taxonomy'].split()
        genus = tax_parts[0] if tax_parts else 'Unknown'
        gc = (row['sequence'].count('G') + row['sequence'].count('C')) / len(row['sequence'])
        novel = 1 if gc > 0.7 or gc < 0.3 else 0  # Extreme GC content as novelty indicator
        data.append({
            'seq': row['sequence'],
            'tax_idx': taxa_to_idx[genus],
            'role_idx': 0,  # Single role for simplicity
            'novel': novel
        })
    return data, taxa

records, taxa = load_16s_data('16S_sequences.csv') # Load data and define taxa

# Re-initialize the model with the correct number of tax_classes
tokenizer = KmerTokenizer(k=4) # Define tokenizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Define device
model = MultiTaskTaxonomyModel(vocab_size=tokenizer.vocab_size, d_model=64, nhead=2, num_layers=1, tax_classes=len(taxa)).to(device)
checkpoint = torch.load('multitask_model_demo_small.pth')
model.load_state_dict(checkpoint['model_state_dict'])
print("Model state dictionary loaded successfully.")

Model state dictionary loaded successfully.




# New Section

In [47]:
from google.colab import files

files.download('multitask_model_demo_small.pth')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>