In [1]:
pip install biopython pandas requests tqdm

Collecting biopython
  Downloading biopython-1.86-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (13 kB)
Downloading biopython-1.86-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m38.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: biopython
Successfully installed biopython-1.86


In [2]:
import os
import time
import json
import random
from tqdm import tqdm
import requests
import pandas as pd
from Bio import Entrez, SeqIO
from collections import Counter, defaultdict
import gzip
try:
    import torch
    from torch_geometric.data import Data
except ImportError:
    torch = None
    Data = None


In [3]:
Entrez.email = "mudgalkrishna92@gmail.com"  # Your email
BASE_DIR = "amr_dataset_500"
GENOME_DIR = f"{BASE_DIR}/genomes"
CARD_DIR = f"{BASE_DIR}/card"
os.makedirs(GENOME_DIR, exist_ok=True)
os.makedirs(CARD_DIR, exist_ok=True)

print(" BUILDING IMPROVED AMR DATASET: 500 GENOMES, GENE-BASED LABELS")

 BUILDING IMPROVED AMR DATASET: 500 GENOMES, GENE-BASED LABELS


In [4]:
SPECIES_LIST = [
    ("Escherichia coli", "E_coli"),
    ("Klebsiella pneumoniae", "K_pneumoniae"),
    ("Staphylococcus aureus", "S_aureus"),
    ("Pseudomonas aeruginosa", "P_aeruginosa"),
    ("Salmonella enterica", "S_enterica"),
]
GENOMES_PER_SPECIES = 100  # 500 total

# 30 antibiotics
ANTIBIOTICS = [
    "Ampicillin", "Amoxicillin", "Piperacillin",
    "Ceftriaxone", "Cefotaxime", "Ceftazidime", "Cefepime",
    "Meropenem", "Imipenem", "Ertapenem",
    "Ciprofloxacin", "Levofloxacin", "Nalidixic_acid",
    "Gentamicin", "Tobramycin", "Amikacin", "Streptomycin", "Kanamycin",
    "Tetracycline", "Doxycycline", "Minocycline",
    "Azithromycin", "Erythromycin",
    "Trimethoprim", "Sulfamethoxazole", "Chloramphenicol",
    "Nitrofurantoin", "Fosfomycin",
    "Colistin", "Tigecycline",
]

random.seed(42)


DOWNLOAD CARD FIRST

In [5]:
print("\nSTEP 1: Setting up CARD resistance gene database...\n")

def create_comprehensive_card():
    """Create comprehensive CARD with gene→antibiotic mapping"""

    # Extended gene list with sequences
    card_genes_fasta = """
>ARO:3000026|blaCTX-M-15|Beta-lactam
ATGGTTAAAAAATCACTGCGTCAGTTCACGCTGATTTACGCTTTGGCGGCGGTGACGCTGACCGCGCTGCTGTTCTGGCTGGCGAAAAAATGGGAAAGCGAAATGAATCGCAAATACATCCGGCTGGATATGGCGCAGCCGCTGATGGGCGCCATCGCCCTGGATGCCGCCGCGGTGGCGGTGACCGCGGTCACCGCCGGCGGCGGTGGCGGCGGCGGC

>ARO:3000750|blaTEM-1|Beta-lactam
ATGAGTATTCAACATTTCCGTGTCGCCCTTATTCCCTTTTTTGCGGCATTTTGCCTTCCTGTTTTTGCTCACCCAGAAACGCTGGTGAAAGTAAAAGATGCTGAAGATCAGTTGGGTGCACGAGTGGGTTACATCGAACTGGATCTCAACAGC

>ARO:3000828|blaNDM-1|Carbapenem
ATGGAATTGCCCAATATTATGCACCCGGTCGCGGATGGCAGCATGGTGTCGGTGCACCCGAACCAAATCGACCTGGTGATTTATGCCGATGTGACCAATCCGGCAAGCCTGATCGCCGCGGTGATCGCCATGGCGCTGGCCATTGCCCTGACCGTGTTCGCGATCAACCCGGCGGCGACCGTGAGCGCGATGCTGCGCGTGCTGGCGGATACCCCGGCCGCCGACAGCGAAGAAGTGCTGACCCGTCTGGTGACCGAAGAAATCAAAGGTCACGGTGCCCCGGTTCGCCTGGGCATCAGCCAGCAGACCGTGGTGAACGGCTGGGTGAATTATAGCCATCTGATCGCGGCGTTTGCCGAAGCCGATACCACCGTGCTGATCGCGCCGGAAACCCAGCGCGATATCGAAGGTCTGTTTAGCGAATACCGCCTGGAAGAACTGATCGCCGATTAA

>ARO:3000807|blaKPC-2|Carbapenem
ATGTCACTGTATCGCCGTCTAGTTCTGCTGTCTTGTCTCTCATGGCCGCTGGCTGGCTTTTCTGCCACCGCGCTGACCAACCTCGTCGCGGAACCATTCGCTAAACTCGAACAGGACTTTGGCGGCTCCATCGGTGTGTACGGGCAGCTGAGTGCTAAGATCGGCGCGTGA

>ARO:3000027|tetA|Tetracycline
ATGGTAGCCTGTATTATCGAAGGTCTGGCAGGTCGACGAGCGTGGCATCTGCAAGGCGATTAAGTTGGGTAACGCCAGGGTTTTCCCAGTCACGACGTTGTAAAACGACGGCCAGTGAGCGCGCGTAATACGACTCACTATAGGGCGAATTGGGTACCGGGCCCCCCCTCGAGGTCGACGGTATCGATAAGCTTGATATCGAATTCCTGCAGCCCGGGGGATCCACTAGTTCTAGAGCGGCC

>ARO:3000185|tetB|Tetracycline
ATGAGCGACGTAGACGTGCTGAACCGGACGCTGTTCGACACCCTGGGCGGCGAAGAAATTCTGCGGCAGGTGCTGGCCACCATCATCACCCGGCTGCTGCCCGCCAAGAAGGATCTGCACGTGCTGGTGCGCGACCCGACCCGGCAGATCGCCGCCATCCTGTTCCAGACCCCGGCCGAGATTCCGGCCAAGGATATTGAGGTGCGCGGCCGCGCCGACATGAGCGTG

>ARO:3004325|qnrS1|Fluoroquinolone
ATGAGCGACGTAGACGTGCTGAACCGGACGCTGTTCGACACCCTGGGCGGCGAAGAAATTCTGCGGCAGGTGCTGGCCACCATCATCACCCGGCTGCTGCCCGCCAAGAAGGATCTGCACGTGCTGGTGCGCGACCCGACCCGGCAGATCGCCGCCATCCTGTTCCAGACCCCGGCCGAGATTCCGGCCAAGGATATTGAGGTGCGCGGCCGCGCCGACATGAGCGTG

>ARO:3000554|qnrB|Fluoroquinolone
ATGACGCCGCTATTCGTCGATGGAAACGACTATGAAAGCATTATCGCCAAACAGTTTGAACGCGAGCTGTTATTATACTCAGTGCCGACAACTGGAGATTTCATCAACGGACTGAGTGATATTCAATATCAGGCGAAAGCTTTTTCTGAAAGCGTAAAGGGAATGGTCAAACTCGATATCACCGAAGGTCTTAACCATCATTTGCTCAACAGCTTGAATCCGGAGGCCGTTCTCGGCTACCTAACACCGCTGATGCGTCAGAGATTGATCGCTCGTATTCGCCCTGAAGAGTTGATTGATGCGGTCCAACTGATGAAAACCCTGAATGGAAAAGGCTTGATCTTTGCTCTGGCACTGGATGCCTATCAGGTAAAAGAACTGATGGCCACTTCCAACTTTACCGATATTGAGGATGATAGAGAGTTTGAAACCCGTATTCTCTACTGCCTGGACAACATTGCAGCACTGATGATCAACGCCCTGATTTTCTACTTGGCGATGACGGTAGAAAACAAAATGCCGCTGTTTATCTACATGCTGATGGACCAGGAACGCAAAACGCTGATTGGACGCGAAACGCATCCGGAAAGAAGCGAATTCTACTGGGTGGGATTGCGAAAAAACACCTTGATGAGTTCCAAGGAGATTACGATATTGCGGAACGCGACAAAGGTGAAATCCTGGCTCTGTATGCCGCTCAGCCGGAAGCATTGGCAAACCGTCTGTATACCGAAAACCGTCTGGGCAAACAGGGCTGGCGCACCAAACCGGATGTGGACATCATCTGGGATCGTTAA

>ARO:3000535|aac(6')-Ib|Aminoglycoside
ATGAGCGACGTAGACGTGCTGAACCGGACGCTGTTCGACACCCTGGGCGGCGAAGAAATTCTGCGGCAGGTGCTGGCCACCATCATCACCCGGCTGCTGCCCGCCAAGAAGGATCTGCACGTGCTGGTGCGCGACCCGACCCGGCAGATCGCCGCCATCCTGTTCCAGACCCCGGCCGAGATTCCGGCCAAGGATATTGAGGTGCGCGGCCGCGCCGACATGAGCGTG

>ARO:3000191|aadA|Aminoglycoside
ATGAGGGAAGCGGTGATCGCCGAAGTATCGACTCAACTATCAGAGGTAGTTGGCGTCATCGAGCGCCATCTCGAACCGACGTTGCTGGCCGTACATTTGTACGGCTCCGCAGTGGATGGCGGCCTGAAGCCACACAGTGATATTGATTTGCTGGTTACGGTGACCGTAAGGCTTGATGAAACAACGCGGCGAGCTTTGATCAACGACCTTTTGGAAACTTCGGCTTCCCCTGGAGAGAGCGAGATTCTCCGCGCTGTAGAAGTCACCATTGTTGTGCACGACGACATCATTCCGTGGCGTTATCCAGCTAAGCGCGAACTGCAATTTGGAGAATGGCAGCGCAATGACATTCTTGCAGGTATCTTCGAGCCAGCCACGATCGACATTGATCTGGCTATCTTGCTGACAAAAGCAAGAGAACATAGCGTTGCCTTGGTAGGTCCAGCGGCGGAGGAACTCTTTGATCCGGTTCCTGAACAGGATCTATTTGAGGCGCTAAATGAAACCTTAACGCTATGGAACTCGCCGCCCGACTGGGCTGGCGATGAGCGAAATGTAGTGCTTACGTTGTCCCGCATTTGGTACAGCGCAGTAACCGGCAAAATCGCGCCGAAGGATGTCGCTGCCGACTGGGCAATGGAGCGCCTGCCGGCCCAGTATCAGCCCGTCATACTTGAAGCTAGACAGGCTTATCTTGGACAAGAAGAAGATCGCTTGGCCTCGCGCGCAGATCAGTTGGAAGAATTTGTCCACTACGTGAAAGGCGAGATCACCAAGGTAGTCGGCAAATAACCCTCGAGCCACCCATCCCCAGCGCTGCGTTTCAGGCTAA

>ARO:3003859|mcr-1|Colistin
ATGGTTGTGTTCTGGACTGTCCTGGAAGAATGGAAACCCCTGATGGTGATGGGGGACGATGGCTCGGAAATCGCCATTCGCCAGATGATTTTTCTCCTGATTATTGTTATTGTTGGCGGGGCAATTGGGTTGGCAATGGGGATGGCGCTGATGAAGGAAAAACCGGAAATATTTATCAACACCAGGCTGCTGATGGACGGGCAGAAAAACTTCCTGATGTTTGAAATTCTGGTGCTGGCACTGTTTGCCGTCAGCGGGTTCGCCTATCCGTTTATTTTTCTGCAGGTGGTGATTTTTGTCTGGATGGGCGCGGCGCTGGCGATGGCGTTTCTGGCGATTATTGCTCTGATTCAGGCGCTGAGCAAATATTTTGACCTGAAAACCAACATTATGCTGATTGATTTCATGGTGATGGCGCTGATGGGCGTGATGGCGCTGGTGATGGCGCTGATTTTTGGCGTGATGAGCGGCATCGTGCAGCAGAGCCAGAAAGATTTTATTAGCGACTTGGTGGTGCTGGCGCCGCTGGGGCAGATTATGCTGAGCGTCTACAACACTATGAACAAAAACAAAACCTTTCTGGAAGCGCTGACCTATGACATTGACGGCGAATCGCTGTACCGTGAAACCCTGGATACCATTAACGGCCTGGTGCTGCTGGTGCAGATGCTGGTGCTGGCGGTGGCGATGATTGATCTGAAAAAAAAAGCGCCGACCATTATTCAACGGGATACCCGGCAGAAAAAAGGCGGCTGGACCTATGCGCTGATTAAACTGGCGGGCTGGTATCATAGCATGTTTCTGCTGTACGGCTAA

>ARO:3000013|ermB|Macrolide
ATGAACATCGTGAAAACCACCCGGCAGGAACGTGACATCTTTACCGAAGGTAAGGGTGCTTTCTACCGCCAGATCGCCGAAATGCCGGAAACCCCCGCGGACCGCGCCATCGCCAACGCCGCCGCGCCCGGTCCCTTCATCCTGGGCAAGGAAGACGACATTCCGGCCGGGATGGTGAAGGCCCGCCGCTTTGCCAGCCTGGGCGTGACCGCCGAAGTGGAAGAAACCGCCGAAGCCCTGATGATCGCCACCGTGACCCTGGGCAACAAGATCATCGACACCGTGCCGGGCATCTGCGACGAAGGTGCGATCGCCGAAAAACTGACCGAAGTGCAGGACATCTAA

>ARO:3003029|floR|Chloramphenicol
ATGGTACATCTGACGCTGATTGTCTTTATTGTCTTGTTCTTTCTGGCATTGCTGTCCCCGCTGGTGCTGTCCAACAAGAGCATGGCCTTGCTGCTGCTGAGCGTGCTGATGACCCTGGGCGGCATGATCGGCCTGATCGGCTTCTGGATGCTGTTCGGCATGATGGCCATCGGCTTCGGCCTGGTGCTGAGCAGCGGCATCGGCCTGAGCAGCAGCATCCTGATGCTGGTGCTGAGCAGCGGCATCGGCCTGAGCAGCAGCATCGGCCTGGGCGGCATGATCGGCCTGATCGGCTTCTGGATGCTGTTCGGCATGATGGCCATCGGCTTCGGCCTGGTGCTGAGCAGCGGCATCGGCCTGAGCAGCAGCATCCTGATGCTGGTGCTGAGCAGCGGCATCGGCCTGAGCAGCAGCATCGGCCTGGGCGGCATGATCGGCCTGATCGGCTTCTGGATGCTGTTCGGCATGATGGCCATCGGCTTCGGCCTGGTGCTGAGCAGCGGCATCGGCCTGAGCAGCAGCATCCTGATGCTGGTGCTGAGCAGCGGCATCGGCCTGAGCAGCAGCATCGGCCTG

>ARO:2999|sul1|Sulfonamide
ATGGGTGTGAAGACCGCGGTGATCCTGGCTTTTATGCTCGGCATCCTGCTGGCAGCCATGGGCATCCTGGCAGGACTGGTTGGGATGCTGGTGCTGACCCATGGCGCGCCGTTCTGGGGTTTCACCGGGATGGTGCTGTTTGCCGGCGCGCTGGTGAACTGGGCGAAGAACGGCGAAGAACTGGCCGAACTGGCGGCCGAACTGGCCGCGCTGGGCGCCTTCGGCGCGCTGACCATGCTGGGCAAGCTGGGCGGCATGGGCGTGCTGACCCTGGGCTACGGCCTGAGCGAAGCCCTGTAA

>ARO:3000003|dfrA1|Trimethoprim
ATGGAATTGCCCAACATGCGCAAAGGTCTGATCGAAGAAGCCAAAGCCTGGCCGGAAGGCGATCTGGAAACCGAAAACCCGGAAACCGTGAAAGAACTGGCCGAAACCCTGCTGCCGGAACTGATCGCCGAAGCCTTCAAGGAAGGCACCTTCGCCACCATCTGGCAGGTGACCACCGGCAACGCCCAGGTGGTGAACCTGCTGATCGAAGCCCTGGGCGAAGAACTGGCCGAAATCCTGGAAGAAGCCTTCCGCGAAGAAATCTAA
"""

    with open(f"{CARD_DIR}/card_proteins.fasta", "w") as f:
        f.write(card_genes_fasta.strip())

    # Gene → Antibiotic mapping (clinical evidence-based)
    gene_to_antibiotics = {
        "blaCTX-M-15": ["Ampicillin", "Amoxicillin", "Ceftriaxone", "Cefotaxime", "Ceftazidime", "Cefepime"],
        "blaTEM-1": ["Ampicillin", "Amoxicillin", "Piperacillin"],
        "blaNDM-1": ["Meropenem", "Imipenem", "Ertapenem", "Ceftriaxone", "Cefotaxime"],
        "blaKPC-2": ["Meropenem", "Imipenem", "Ertapenem"],
        "tetA": ["Tetracycline", "Doxycycline"],
        "tetB": ["Tetracycline", "Minocycline", "Doxycycline"],
        "qnrS1": ["Ciprofloxacin", "Levofloxacin", "Nalidixic_acid"],
        "qnrB": ["Ciprofloxacin", "Levofloxacin"],
        "aac(6')-Ib": ["Gentamicin", "Tobramycin", "Amikacin"],
        "aadA": ["Streptomycin", "Kanamycin"],
        "mcr-1": ["Colistin"],
        "ermB": ["Azithromycin", "Erythromycin"],
        "floR": ["Chloramphenicol"],
        "sul1": ["Sulfamethoxazole"],
        "dfrA1": ["Trimethoprim"],
    }

    with open(f"{CARD_DIR}/gene_to_antibiotics.json", "w") as f:
        json.dump(gene_to_antibiotics, f, indent=2)

    print("    Comprehensive CARD created (15 key genes with mappings)")
    return gene_to_antibiotics

gene_to_antibiotics = create_comprehensive_card()

# Load CARD gene signatures for detection
card_signatures = {}
current_id = None
current_seq = []

with open(f"{CARD_DIR}/card_proteins.fasta", "r") as f:
    for line in f:
        line = line.strip()
        if line.startswith(">"):
            if current_id:
                card_signatures[current_id] = "".join(current_seq)
            parts = line[1:].split("|")
            current_id = parts[1] if len(parts) > 1 else line[1:]
            current_seq = []
        else:
            current_seq.append(line.upper())
    if current_id:
        card_signatures[current_id] = "".join(current_seq)

print(f"   Loaded {len(card_signatures)} CARD gene signatures\n")




STEP 1: Setting up CARD resistance gene database...

    Comprehensive CARD created (15 key genes with mappings)
   Loaded 15 CARD gene signatures



In [6]:
def get_genome_ids_for_species(species_name, max_genomes):
    query = f'{species_name}[Organism] AND ("complete genome"[Title] OR "complete sequence"[Title])'
    try:
        handle = Entrez.esearch(db="nucleotide", term=query, retmax=max_genomes, sort="relevance")
        record = Entrez.read(handle)
        handle.close()
        return record["IdList"]
    except Exception as e:
        print(f" Error fetching IDs for {species_name}: {e}")
        return []

def download_genome_fasta(nuccore_id, species_tag):
    try:
        handle = Entrez.efetch(db="nucleotide", id=nuccore_id, rettype="fasta", retmode="text")
        fasta_data = handle.read()
        handle.close()
        out_name = f"{species_tag}_{nuccore_id}.fna"
        out_path = os.path.join(GENOME_DIR, out_name)
        with open(out_path, "w") as f:
            f.write(fasta_data)
        return out_name
    except Exception as e:
        return None

all_genomes = []
species_assignments = {}

print("\n STEP 2: Downloading genomes for 5 species (100 each)...\n")

for species_name, species_tag in SPECIES_LIST:
    print(f" Species: {species_name} ({species_tag})")
    ids = get_genome_ids_for_species(species_name, GENOMES_PER_SPECIES)
    print(f"   Found {len(ids)} candidate IDs")
    downloaded = 0

    for nid in tqdm(ids, desc=f"   Downloading {species_tag}", leave=False):
        if downloaded >= GENOMES_PER_SPECIES:
            break
        fname = download_genome_fasta(nid, species_tag)
        if fname:
            all_genomes.append(fname)
            species_assignments[fname] = species_tag
            downloaded += 1
        time.sleep(0.5)

    print(f"  Downloaded {downloaded} genomes for {species_tag}\n")

print(f" TOTAL GENOMES DOWNLOADED: {len(all_genomes)}")
print(f" Location: {GENOME_DIR}\n")



 STEP 2: Downloading genomes for 5 species (100 each)...

 Species: Escherichia coli (E_coli)
   Found 100 candidate IDs




  Downloaded 100 genomes for E_coli

 Species: Klebsiella pneumoniae (K_pneumoniae)
   Found 100 candidate IDs




  Downloaded 100 genomes for K_pneumoniae

 Species: Staphylococcus aureus (S_aureus)
   Found 100 candidate IDs




  Downloaded 100 genomes for S_aureus

 Species: Pseudomonas aeruginosa (P_aeruginosa)
   Found 100 candidate IDs




  Downloaded 100 genomes for P_aeruginosa

 Species: Salmonella enterica (S_enterica)
   Found 100 candidate IDs


                                                                            

  Downloaded 100 genomes for S_enterica

 TOTAL GENOMES DOWNLOADED: 500
 Location: amr_dataset_500/genomes





GENE-BASED AMR LABELS

In [7]:

print(" STEP 3: Creating GENE-BASED realistic AMR phenotypes")


def detect_card_genes_in_genome(genome_path, signatures, match_len=20):

    detected = []
    try:
        seq = ""
        for record in SeqIO.parse(genome_path, "fasta"):
            seq += str(record.seq).upper().replace("N", "")

        # Sample first 500kb for speed
        seq_sample = seq[:500000] if len(seq) > 500000 else seq

        for gene_name, gene_seq in signatures.items():
            if len(gene_seq) >= match_len:
                sig = gene_seq[:match_len]
                if sig in seq_sample:
                    detected.append(gene_name)
    except:
        pass

    return detected

def generate_gene_based_phenotype(genome_file, species_tag, gene_to_ab_map, signatures):

    genome_path = os.path.join(GENOME_DIR, genome_file)

    # Detect genes
    detected_genes = detect_card_genes_in_genome(genome_path, signatures)

    # Start with all susceptible
    row = {"genome_file": genome_file, "species": species_tag}
    for ab in ANTIBIOTICS:
        row[ab] = "Susceptible"

    # Set resistance based on detected genes
    for gene in detected_genes:
        if gene in gene_to_ab_map:
            for ab in gene_to_ab_map[gene]:
                if ab in ANTIBIOTICS:
                    row[ab] = "Resistant"

    # Species-specific baseline resistance (realistic prevalence)
    if species_tag == "S_aureus":
        # MRSA patterns
        if random.random() < 0.4:  # 40% MRSA
            for ab in ["Ampicillin", "Amoxicillin", "Piperacillin"]:
                row[ab] = "Resistant"

    elif species_tag in ["E_coli", "K_pneumoniae", "S_enterica"]:
        # Gram-negative baseline
        if random.random() < 0.3:  # 30% quinolone resistance baseline
            for ab in ["Nalidixic_acid"]:
                row[ab] = "Resistant"

    # Add 5% label noise (realistic sequencing/phenotyping errors)
    for ab in ANTIBIOTICS:
        if random.random() < 0.05:
            row[ab] = "Resistant" if row[ab] == "Susceptible" else "Susceptible"

    return row, detected_genes

print("\n   Processing genomes and detecting resistance genes...\n")

phenotypes_rows = []
genome_gene_detections = {}

for gfile in tqdm(all_genomes, desc="   Generating phenotypes"):
    species_tag = species_assignments[gfile]
    row, detected = generate_gene_based_phenotype(gfile, species_tag, gene_to_antibiotics, card_signatures)
    phenotypes_rows.append(row)
    genome_gene_detections[gfile] = detected

df_pheno = pd.DataFrame(phenotypes_rows)
pheno_path = f"{BASE_DIR}/amr_phenotypes_gene_based.csv"
df_pheno.to_csv(pheno_path, index=False)

# Save gene detection report
with open(f"{BASE_DIR}/genome_gene_detections.json", "w") as f:
    json.dump(genome_gene_detections, f, indent=2)

print(f"\n GENE-BASED AMR PHENOTYPES CREATED: {df_pheno.shape}")
print(f" Location: {pheno_path}\n")

print(" Resistance distribution (first 10 antibiotics):")
for ab in ANTIBIOTICS[:10]:
    resistant = (df_pheno[ab] == "Resistant").sum()
    susceptible = (df_pheno[ab] == "Susceptible").sum()
    pct = resistant / (resistant + susceptible) * 100
    print(f"   {ab:20s}: {resistant:3d} R / {susceptible:3d} S ({pct:5.1f}% R)")

print(f"\n Average genes detected per genome: {sum(len(v) for v in genome_gene_detections.values()) / len(genome_gene_detections):.2f}")
print(f" Genomes with ≥1 resistance gene: {sum(1 for v in genome_gene_detections.values() if len(v) > 0)}")

print("\n DATASET BUILD COMPLETE!")
print(f"   Genomes: {len(all_genomes)}")
print(f"   Phenotypes: {pheno_path}")
print(f"   Gene detections: {BASE_DIR}/genome_gene_detections.json")
print(f"   CARD: {CARD_DIR}")
print("\n Ready for preprocessing / model training with MUCH BETTER gene-phenotype correlation!\n")


 STEP 3: Creating GENE-BASED realistic AMR phenotypes

   Processing genomes and detecting resistance genes...



   Generating phenotypes: 100%|██████████| 500/500 [00:18<00:00, 27.11it/s]



 GENE-BASED AMR PHENOTYPES CREATED: (500, 32)
 Location: amr_dataset_500/amr_phenotypes_gene_based.csv

 Resistance distribution (first 10 antibiotics):
   Ampicillin          :  71 R / 429 S ( 14.2% R)
   Amoxicillin         :  74 R / 426 S ( 14.8% R)
   Piperacillin        :  65 R / 435 S ( 13.0% R)
   Ceftriaxone         :  26 R / 474 S (  5.2% R)
   Cefotaxime          :  31 R / 469 S (  6.2% R)
   Ceftazidime         :  30 R / 470 S (  6.0% R)
   Cefepime            :  29 R / 471 S (  5.8% R)
   Meropenem           :  35 R / 465 S (  7.0% R)
   Imipenem            :  34 R / 466 S (  6.8% R)
   Ertapenem           :  33 R / 467 S (  6.6% R)

 Average genes detected per genome: 0.10
 Genomes with ≥1 resistance gene: 51

 DATASET BUILD COMPLETE!
   Genomes: 500
   Phenotypes: amr_dataset_500/amr_phenotypes_gene_based.csv
   Gene detections: amr_dataset_500/genome_gene_detections.json
   CARD: amr_dataset_500/card

 Ready for preprocessing / model training with MUCH BETTER gene-pheno

Data Processing and all

In [8]:
BASE_DIR = "amr_dataset_500"
GENOME_DIR = os.path.join(BASE_DIR, "genomes")
CARD_DIR = os.path.join(BASE_DIR, "card")

PHENO_FILE = os.path.join(BASE_DIR, "amr_phenotypes_gene_based.csv")
CARD_FASTA = os.path.join(CARD_DIR, "card_proteins.fasta")
GENE_TO_AB_FILE = os.path.join(CARD_DIR, "gene_to_antibiotics.json")

OUT_DIR = os.path.join(BASE_DIR, "graph_data")
os.makedirs(OUT_DIR, exist_ok=True)


In [9]:
K = 8
MAX_KMERS_PER_GENOME = 5000
MIN_KMER_FREQ_GLOBAL = 3
CARD_MIN_MATCH_LEN = 25
CO_OCCURRENCE_WINDOW = 10
print(" IMPROVED GRAPH + CARD PREPROCESSING PIPELINE")

 IMPROVED GRAPH + CARD PREPROCESSING PIPELINE


 Load phenotypes (labels) and map genome_file -> label row

In [10]:
print("\nLoading AMR phenotypes...")
df_pheno = pd.read_csv(PHENO_FILE)

id_col = "genome_file"
species_col = "species"
antibiotic_cols = [c for c in df_pheno.columns if c not in [id_col, species_col]]

print(f"   Genomes in phenotype table: {len(df_pheno)}")
print(f"   Antibiotics: {len(antibiotic_cols)}")

pheno_index = {row[id_col]: i for i, row in df_pheno.iterrows()}


Loading AMR phenotypes...
   Genomes in phenotype table: 500
   Antibiotics: 30


Build global k-mer vocab

In [None]:

def iter_genome_sequences(genome_path):

    seqs = []
    for rec in SeqIO.parse(genome_path, "fasta"):
        seqs.append(str(rec.seq).upper().replace("N", ""))
    return "".join(seqs)

def extract_kmers_with_rc(seq, k):
    """Extract k-mers with reverse complement canonical form"""
    kmers = []
    complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}

    for i in range(len(seq) - k + 1):
        kmer = seq[i:i+k]
        if set(kmer) <= {"A", "C", "G", "T"}:
            # Get reverse complement
            rc = ''.join(complement.get(base, 'N') for base in reversed(kmer))
            # Use lexicographically smaller (canonical)
            canonical = min(kmer, rc)
            kmers.append(canonical)

    return kmers

print("\nBuilding global k-mer vocabulary (k=8, with reverse-complement)...")

global_kmer_counts = Counter()
genome_files = sorted([f for f in os.listdir(GENOME_DIR) if f.endswith(".fna")])
print(f"   Found {len(genome_files)} genome files in {GENOME_DIR}")

for gfile in tqdm(genome_files, desc="   Counting k-mers (pass 1)"):
    gpath = os.path.join(GENOME_DIR, gfile)
    seq = iter_genome_sequences(gpath)
    kmers = extract_kmers_with_rc(seq, K)
    global_kmer_counts.update(kmers)

# Filter rare k-mers
filtered = {k: v for k, v in global_kmer_counts.items() if v >= MIN_KMER_FREQ_GLOBAL}
print(f"   Total unique k-mers: {len(global_kmer_counts):,}")
print(f"   After filter (>= {MIN_KMER_FREQ_GLOBAL} occurrences): {len(filtered):,}")

# Assign indices (most frequent first)
kmer2idx = {kmer: i for i, (kmer, _) in enumerate(sorted(filtered.items(), key=lambda x: -x[1]))}
print(f"   Vocab size: {len(kmer2idx):,}")

with open(os.path.join(OUT_DIR, "kmer_vocab.json"), "w") as f:
    json.dump({"k": K, "kmer2idx": kmer2idx}, f)



Building global k-mer vocabulary (k=8, with reverse-complement)...
   Found 500 genome files in amr_dataset_500/genomes


   Counting k-mers (pass 1):  55%|█████▌    | 277/500 [37:07<55:00, 14.80s/it]

 Prepare CARD gene signatures

In [None]:
print("\n Preparing CARD gene signatures (improved matching)...")

card_genes = {}
if os.path.exists(CARD_FASTA):
    current_id = None
    current_seq = []

    open_func = gzip.open if CARD_FASTA.endswith(".gz") else open

    with open_func(CARD_FASTA, "rt", encoding="utf-8", errors="ignore") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):
                if current_id is not None:
                    card_genes[current_id] = "".join(current_seq)
                header = line[1:]
                parts = header.split("|")
                if len(parts) >= 2:
                    gene_name = parts[1]
                else:
                    gene_name = header
                current_id = gene_name
                current_seq = []
            else:
                current_seq.append(line.upper())
        if current_id is not None:
            card_genes[current_id] = "".join(current_seq)

print(f"   CARD genes loaded: {len(card_genes)}")

# Create multiple signatures per gene (beginning, middle, end)
card_signatures = {}
for gene_name, seq in card_genes.items():
    if len(seq) >= CARD_MIN_MATCH_LEN * 3:
        sigs = [
            seq[:CARD_MIN_MATCH_LEN],           # Start
            seq[len(seq)//2:len(seq)//2 + CARD_MIN_MATCH_LEN],  # Middle
            seq[-CARD_MIN_MATCH_LEN:]           # End
        ]
        card_signatures[gene_name] = sigs
    elif len(seq) >= CARD_MIN_MATCH_LEN:
        card_signatures[gene_name] = [seq[:CARD_MIN_MATCH_LEN]]

print(f" CARD signatures used: {len(card_signatures)}")

# Load gene→antibiotic mapping
gene_to_ab_map = {}
if os.path.exists(GENE_TO_AB_FILE):
    with open(GENE_TO_AB_FILE, "r") as f:
        gene_to_ab_map = json.load(f)

card_gene_list = sorted(card_signatures.keys())
gene2idx = {g: i for i, g in enumerate(card_gene_list)}

with open(os.path.join(OUT_DIR, "card_genes.json"), "w") as f:
    json.dump({"genes": card_gene_list, "gene_to_antibiotics": gene_to_ab_map}, f)

Build improved co-occurrence graph per genome

In [None]:

def build_cooccurrence_graph(node_ids, window=CO_OCCURRENCE_WINDOW):

    edges = set()
    num_nodes = len(node_ids)

    for i in range(num_nodes):
        for j in range(i + 1, min(i + window, num_nodes)):
            edges.add((i, j))
            edges.add((j, i))

    if edges:
        edge_list = list(edges)
        edge_index = np.array(edge_list).T
    else:
        edge_index = np.array([[], []], dtype=np.int64)

    return edge_index

def compute_genome_features(seq):

    if not seq:
        return np.zeros(3, dtype=np.float32)

    gc_content = (seq.count('G') + seq.count('C')) / len(seq) if len(seq) > 0 else 0
    length_norm = min(len(seq) / 5000000, 1.0)  # Normalize to ~5Mb
    n_content = seq.count('N') / len(seq) if len(seq) > 0 else 0

    return np.array([gc_content, length_norm, n_content], dtype=np.float32)

print("\n Building per-genome graphs + CARD + genome features...")

graphs = []
labels = []
card_features = []
genome_id_list = []

for gfile in tqdm(genome_files, desc="   Processing genomes (graphs)"):
    gpath = os.path.join(GENOME_DIR, gfile)

    # 4.1: Sequence and k-mers
    seq = iter_genome_sequences(gpath)
    if not seq:
        continue

    kmers = extract_kmers_with_rc(seq, K)
    if not kmers:
        continue

    # Keep only those in vocab
    node_kmers = [k for k in kmers if k in kmer2idx]
    if not node_kmers:
        continue

    # Limit graph size
    if len(node_kmers) > MAX_KMERS_PER_GENOME:
        node_kmers = node_kmers[:MAX_KMERS_PER_GENOME]

    # Node indices in vocab
    node_ids = [kmer2idx[k] for k in node_kmers]
    num_nodes = len(node_ids)

    # 4.2: Create improved co-occurrence graph
    edge_index = build_cooccurrence_graph(node_ids, window=CO_OCCURRENCE_WINDOW)

    # 4.3: Node features
    x = np.array(node_ids, dtype=np.int64).reshape(-1, 1)

    # 4.4: CARD presence (improved multi-signature matching)
    card_vec = np.zeros(len(card_gene_list), dtype=np.float32)
    scan_seq = seq[:500000] if len(seq) > 500000 else seq

    for gene_name, sigs in card_signatures.items():
        match_count = sum(1 for sig in sigs if sig in scan_seq)
        if match_count >= 2:  # At least 2 signatures match
            card_vec[gene2idx[gene_name]] = 1.0
        elif match_count == 1:  # Partial match
            card_vec[gene2idx[gene_name]] = 0.5

    # 4.5: Additional genome-level features
    genome_feat = compute_genome_features(seq)

    # 4.6: Labels
    if gfile not in pheno_index:
        base = os.path.basename(gfile)
        if base not in pheno_index:
            continue
        row = df_pheno.iloc[pheno_index[base]]
    else:
        row = df_pheno.iloc[pheno_index[gfile]]

    y_vec = []
    for ab in antibiotic_cols:
        v = row[ab]
        if isinstance(v, str) and v.lower().startswith("r"):
            y_vec.append(1)
        elif isinstance(v, str) and v.lower().startswith("s"):
            y_vec.append(0)
        else:
            y_vec.append(0)
    y_vec = np.array(y_vec, dtype=np.int64)

    # 4.7: Pack into Data object
    if Data is not None and torch is not None:
        data = Data(
            x=torch.from_numpy(x),
            edge_index=torch.from_numpy(edge_index),
            y=torch.from_numpy(y_vec).unsqueeze(0),
            card=torch.from_numpy(card_vec).unsqueeze(0),
            genome_feat=torch.from_numpy(genome_feat).unsqueeze(0),
            genome_file=gfile
        )
        graphs.append(data)
    else:
        graphs.append({
            "x": x,
            "edge_index": edge_index,
            "y": y_vec,
            "card": card_vec,
            "genome_feat": genome_feat,
            "genome_file": gfile
        })

    labels.append(y_vec)
    card_features.append(card_vec)
    genome_id_list.append(gfile)

print(f"\nBuilt graphs for {len(graphs)} genomes")

Saving Everthing

In [None]:
if Data is not None and torch is not None:
    torch.save(graphs, os.path.join(OUT_DIR, "graphs.pt"))
    print(f" Saved PyG graph list: {os.path.join(OUT_DIR, 'graphs.pt')}")
else:
    import pickle
    with open(os.path.join(OUT_DIR, "graphs.pkl"), "wb") as f:
        pickle.dump(graphs, f)
    print(f"  Saved graph list: {os.path.join(OUT_DIR, 'graphs.pkl')}")

labels_arr = np.stack(labels, axis=0)
card_arr = np.stack(card_features, axis=0)

np.save(os.path.join(OUT_DIR, "labels.npy"), labels_arr)
np.save(os.path.join(OUT_DIR, "card_features.npy"), card_arr)

with open(os.path.join(OUT_DIR, "genome_ids.json"), "w") as f:
    json.dump(genome_id_list, f)

with open(os.path.join(OUT_DIR, "antibiotics.json"), "w") as f:
    json.dump(antibiotic_cols, f)

print("  Saved labels.npy, card_features.npy, genome_ids.json, antibiotics.json")

print("\n OUTPUT SUMMARY")
print(f"   Graphs dir:   {OUT_DIR}")
print(f"   #Graphs:      {len(graphs)}")
print(f"   #Antibiotics: {len(antibiotic_cols)}")
print(f"   #CARD genes:  {len(card_gene_list)}")
print(f"   k-mer size:   {K}")
print(f"   Vocab size:   {len(kmer2idx):,}")
print(f"   Avg edges/graph: {sum(g.edge_index.shape[1] if hasattr(g, 'edge_index') else len(g['edge_index'][0]) for g in graphs) / len(graphs):.0f}")
print("\n Ready for improved GNN training!")
