In [None]:
# Cell 1
import random, math
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader

random.seed(1)
torch.manual_seed(1)

<torch._C.Generator at 0x18fe47790f0>

In [4]:
# Cell 2
class KmerTokenizer:
    def __init__(self, k=4):
        self.k = k
        bases = ['A','C','G','T','N']
        from itertools import product
        self.vocab = {'<pad>':0, '<unk>':1}
        idx = 2
        for kmer in map(''.join, product(bases, repeat=self.k)):
            self.vocab[kmer] = idx
            idx += 1
        self.vocab_size = len(self.vocab)
    def encode(self, seq):
        seq = seq.upper().replace('U','T')
        toks = []
        for i in range(0, max(1, len(seq)-self.k+1)):
            kmer = seq[i:i+self.k]
            toks.append(self.vocab.get(kmer, self.vocab['<unk>']))
        return toks


In [5]:
# Cell 3
def random_seq(length):
    return ''.join(random.choice(['A','C','G','T']) for _ in range(length))

taxa = ['Phylum_A','Phylum_B','Phylum_C']
roles = ['primary_producer','microbial_grazer','decomposer']
taxa_to_idx = {t:i for i,t in enumerate(taxa)}
role_to_idx = {r:i for i,r in enumerate(roles)}

def build_toy_dataset(n=80, seq_len_range=(30,60)):
    data=[]
    for i in range(n):
        l=random.randint(*seq_len_range)
        s=random_seq(l)
        gc=(s.count('G')+s.count('C'))/len(s)
        if gc<0.33:
            tax='Phylum_A'; role='primary_producer'
        elif gc<0.66:
            tax='Phylum_B'; role='microbial_grazer'
        else:
            tax='Phylum_C'; role='decomposer'
        novel = 1 if random.random()<0.1 else 0
        data.append({'seq':s,'tax_idx':taxa_to_idx[tax],'role_idx':role_to_idx[role],'novel':novel})
    return data


In [6]:
# Cell 4
tokenizer = KmerTokenizer(k=4)

class ASVDataset(Dataset):
    def __init__(self, records, tokenizer):
        self.records = records
        self.tokenizer = tokenizer
    def __len__(self):
        return len(self.records)
    def __getitem__(self, idx):
        rec = self.records[idx]
        toks = self.tokenizer.encode(rec['seq'])
        return {
            'tokens': torch.tensor(toks, dtype=torch.long),
            'tax_idx': torch.tensor(rec['tax_idx'], dtype=torch.long),
            'role_idx': torch.tensor(rec['role_idx'], dtype=torch.long),
            'novel': torch.tensor(rec['novel'], dtype=torch.float)
        }

def collate_fn(batch):
    tokens = [b['tokens'] for b in batch]
    lengths = [t.size(0) for t in tokens]
    maxlen = max(lengths)
    padded = torch.zeros(len(tokens), maxlen, dtype=torch.long)
    mask = torch.zeros(len(tokens), maxlen, dtype=torch.bool)
    for i,t in enumerate(tokens):
        padded[i,:t.size(0)] = t
        mask[i,:t.size(0)] = 1
    tax_idx = torch.stack([b['tax_idx'] for b in batch])
    role_idx = torch.stack([b['role_idx'] for b in batch])
    novel = torch.stack([b['novel'] for b in batch])
    # src_key_padding_mask: True where padding
    src_key_padding_mask = ~mask
    return padded, src_key_padding_mask, tax_idx, role_idx, novel


In [7]:
# Cell 5
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class MultiTaskTaxonomyModel(nn.Module):
    def __init__(self, vocab_size, d_model=64, nhead=2, num_layers=1, tax_classes=3, role_classes=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos = PositionalEncoding(d_model, max_len=512)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=128, dropout=0.1, activation='relu')
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.tax_head = nn.Linear(d_model, tax_classes)
        self.role_head = nn.Linear(d_model, role_classes)
        self.novel_head = nn.Linear(d_model, 1)
    def forward(self, x, src_key_padding_mask=None):
        emb = self.embed(x) * math.sqrt(self.embed.embedding_dim)
        emb = self.pos(emb)
        emb_t = emb.transpose(0,1)  # [L,B,D]
        enc = self.encoder(emb_t, src_key_padding_mask=src_key_padding_mask)
        enc = enc.transpose(0,1)   # [B,L,D]
        if src_key_padding_mask is not None:
            mask = ~src_key_padding_mask  # True at valid positions
            mask = mask.unsqueeze(-1).float()
            enc = enc * mask
            pooled = enc.sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        else:
            pooled = enc.mean(dim=1)
        return {
            'tax_logits': self.tax_head(pooled),
            'role_logits': self.role_head(pooled),
            'novel_logits': self.novel_head(pooled).squeeze(-1)
        }


In [8]:
# Cell 6
records = build_toy_dataset(n=80, seq_len_range=(30,60))
train = records[:64]
val = records[64:]

train_ds = ASVDataset(train, tokenizer)
val_ds = ASVDataset(val, tokenizer)

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, collate_fn=collate_fn)


In [9]:
# Cell 7
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultiTaskTaxonomyModel(vocab_size=tokenizer.vocab_size, d_model=64, nhead=2, num_layers=1).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=5e-4)
ce = nn.CrossEntropyLoss()
bce = nn.BCEWithLogitsLoss()




In [10]:
# Cell 8
def evaluate(loader):
    model.eval()
    total = 0
    correct_tax = 0
    correct_role = 0
    with torch.no_grad():
        for padded, src_key_padding_mask, tax_idx, role_idx, novel in loader:
            padded = padded.to(device); src_key_padding_mask = src_key_padding_mask.to(device)
            tax_idx = tax_idx.to(device); role_idx = role_idx.to(device)
            out = model(padded, src_key_padding_mask=src_key_padding_mask)
            correct_tax += (out['tax_logits'].argmax(dim=1) == tax_idx).sum().item()
            correct_role += (out['role_logits'].argmax(dim=1) == role_idx).sum().item()
            total += tax_idx.size(0)
    return correct_tax/total, correct_role/total

for epoch in range(1,6):
    model.train()
    running = 0.0
    for padded, src_key_padding_mask, tax_idx, role_idx, novel in train_loader:
        padded = padded.to(device); src_key_padding_mask = src_key_padding_mask.to(device)
        tax_idx = tax_idx.to(device); role_idx = role_idx.to(device); novel = novel.to(device)
        out = model(padded, src_key_padding_mask=src_key_padding_mask)
        loss = ce(out['tax_logits'], tax_idx) + ce(out['role_logits'], role_idx) + 1.5*bce(out['novel_logits'], novel)
        opt.zero_grad(); loss.backward(); opt.step()
        running += loss.item() * tax_idx.size(0)
    avg = running / len(train)
    tax_acc, role_acc = evaluate(val_loader)
    print(f"Epoch {epoch} | Train loss {avg:.4f} | Val tax acc {tax_acc:.3f} | Val role acc {role_acc:.3f}")


Epoch 1 | Train loss 2.8019 | Val tax acc 1.000 | Val role acc 1.000
Epoch 2 | Train loss 2.3931 | Val tax acc 1.000 | Val role acc 1.000
Epoch 3 | Train loss 2.0255 | Val tax acc 1.000 | Val role acc 1.000
Epoch 4 | Train loss 1.7130 | Val tax acc 1.000 | Val role acc 1.000
Epoch 5 | Train loss 1.4204 | Val tax acc 1.000 | Val role acc 1.000


In [11]:
# Cell 9
model.eval()
for ex in val[:6]:
    toks = torch.tensor(tokenizer.encode(ex['seq']), dtype=torch.long).unsqueeze(0).to(device)
    mask = torch.zeros(1, toks.size(1), dtype=torch.bool).to(device)
    out = model(toks, src_key_padding_mask=~mask)
    tax_pred = taxa[out['tax_logits'].argmax(dim=1).item()]
    role_pred = roles[out['role_logits'].argmax(dim=1).item()]
    novel_score = torch.sigmoid(out['novel_logits']).item()
    print(f"True tax: {taxa[ex['tax_idx']]} | Pred tax: {tax_pred} | True role: {roles[ex['role_idx']]} | Pred role: {role_pred} | Novel score: {novel_score:.3f}")


True tax: Phylum_B | Pred tax: Phylum_B | True role: microbial_grazer | Pred role: decomposer | Novel score: 0.508
True tax: Phylum_B | Pred tax: Phylum_B | True role: microbial_grazer | Pred role: decomposer | Novel score: 0.508
True tax: Phylum_B | Pred tax: Phylum_B | True role: microbial_grazer | Pred role: decomposer | Novel score: 0.508
True tax: Phylum_B | Pred tax: Phylum_B | True role: microbial_grazer | Pred role: decomposer | Novel score: 0.508
True tax: Phylum_B | Pred tax: Phylum_B | True role: microbial_grazer | Pred role: decomposer | Novel score: 0.508
True tax: Phylum_B | Pred tax: Phylum_B | True role: microbial_grazer | Pred role: decomposer | Novel score: 0.508


In [12]:
# Cell 10
torch.save({'model_state_dict': model.state_dict(), 'tokenizer_vocab': tokenizer.vocab}, 'multitask_model_demo_small.pth')
print("Saved model to multitask_model_demo_small.pth")


Saved model to multitask_model_demo_small.pth
