## Convert sequences to amino acids

Create one FASTA file of translated amino-acid sequences per grouped CSV; these files are the input to the MMseqs2 clustering step below.

In [None]:
# --- User-configurable paths ---
# Set these to match your environment.
# The `NCBI_GROUPED_DIR` is produced by `data_scripts/data_curation/download_cds_clean.ipynb`.
from pathlib import Path


# Roots that the remaining paths are derived from
PROJECT_ROOT = Path("/workspace")  # Path to the repo root
MMSEQS_WORK_DIR = Path("/data/codonfm_mmseqs/ncbi")  # Contains allSeqClust.* and allSeqs.lookup

# Inputs: grouped CSVs per organism group
NCBI_GROUPED_DIR = Path("/data/ncbi_grouped")

# Dataset storage for mmap and cache
DATASET_DIR = PROJECT_ROOT / "codonfm_mmseqs" / "temp_save"
CACHE_PATH = MMSEQS_WORK_DIR / "global_index.cache.npy"

# Outputs: generated FASTA files and the final clustering results
AA_OUTPUT_DIR = Path("/data/ncbi_grouped_aa")
CLUSTERS_OUTPUT_DIR = Path("clusters")

# Ensure output directories exist
for _output_dir in (CLUSTERS_OUTPUT_DIR, AA_OUTPUT_DIR):
    _output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
import os
from glob import glob

import polars as pl
from Bio.Seq import Seq


# Grouped CSV inputs, in sorted order so every run processes them identically
csv_pattern = str(NCBI_GROUPED_DIR / "*.csv")
files = sorted(glob(csv_pattern))


def translate_cds(cds_seq):
    """Translate a coding DNA sequence into its amino-acid string.

    Args:
        cds_seq: Nucleotide sequence of the CDS, passed straight to `Seq`.

    Returns:
        The translation as a `str` (`to_stop=False`, so translation does not
        stop at the first stop codon), or `None` if `Seq`/`translate` raised.
    """
    try:
        return str(Seq(cds_seq).translate(to_stop=False))
    except Exception:
        # Best-effort by design: a sequence that cannot be translated yields None.
        # `Exception` (not a bare `except`) lets KeyboardInterrupt / SystemExit
        # propagate so a long run can still be interrupted.
        return None


# Translate every grouped CSV and write one FASTA file per group.
for fn in files:
    # Strip only the ".csv" suffix so the group name agrees with the one
    # CodonMmapDataset derives via `file_name.split(".csv")[0]`; splitting on the
    # first "." would truncate (and collide) group names that contain dots.
    group = os.path.basename(fn).split(".csv")[0]

    data = pl.read_csv(fn).with_columns(
        pl.col("cds").map_elements(translate_cds, return_dtype=pl.Utf8).alias("amino_acid")
    )

    out_fa = AA_OUTPUT_DIR / f"{group}.fa"
    with open(out_fa, "w") as f:
        # `i` is the row position inside the CSV and is kept for every written
        # record, so headers still map back to the original row. Rows without a
        # translation (null CDS or failed translation) are skipped instead of
        # being written as the literal text "None".
        f.writelines(
            f">{group}_{i}\n{amino_acid}\n"
            for i, amino_acid in enumerate(data["amino_acid"])
            if amino_acid is not None
        )

## Use MMseqs2 to cluster the sequences

MMseqs2 is required for the clustering step. If you don't have it installed, install it first:

- Recommended (Conda):
```bash
conda install -c conda-forge -c bioconda mmseqs2
```
- Ubuntu (apt, may be older version):
```bash
sudo apt-get update && sudo apt-get install -y mmseqs2
```
- Docker (no local install needed):
```bash
docker run --rm -it -v "$PWD":"/work" soedinglab/mmseqs2:latest bash
# Then run the commands below inside the container in /work
```

Once installed, run the following from the directory containing your `.fa` files (each file holds amino-acid sequences):

```bash
# Create a temporary working directory for MMseqs2
mkdir -p alltmp

# Create an MMseqs2 database from all fasta files in the folder
mmseqs createdb *.fa allSeqs

# Cluster sequences at 50% identity, 90% coverage (cov-mode 5)
mmseqs linclust allSeqs allSeqClust alltmp \
  --min-seq-id 0.5 -c 0.9 --cov-mode 5 \
  --threads "$(nproc)"
```

Notes:
- This notebook expects the MMseqs2 outputs `allSeqClust.*` and `allSeqs.lookup` to be in the directory configured as `MMSEQS_WORK_DIR`; run the commands there or copy the outputs into it.
- Increase or decrease `--min-seq-id`/`-c` as needed for your clustering granularity.
- If `mmseqs` is not found, ensure it is on your PATH (e.g., `conda activate <env>`).

The same two commands in compact form, without an explicit `--threads` setting:

```bash
mmseqs createdb *.fa allSeqs
mmseqs linclust allSeqs allSeqClust alltmp --min-seq-id 0.5 -c 0.9 --cov-mode 5
```

**Once the above are generated, you can run the following to map the sequences to clusters:**

## Map the sequences to clusters

In [None]:
import json
import sys
from glob import glob

import numpy as np
import polars as pl
from tqdm import tqdm


# NOTE: this assumes you've launched the notebook from /workspace and the
# codon-fm repo is mounted at /workspace/codon-fm
sys.path.append("/workspace/codon-fm")
from pathlib import Path
from typing import Callable

import torch

from src.tokenizer import Tokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class CodonMmapDataset(torch.utils.data.Dataset):
    """Random-access view over sequences stored in memory-mapped chunk files.

    `data_path` must contain a `metadata.json` with the two keys read here:
    `"chunks"` (per chunk, an `"index"` and a `"sequences"` descriptor, each
    with `path`, `dtype` and `shape`) and `"file_metadata"` (per source file,
    `file_name`, `start` and `end`).

    Attributes:
        group_offset: Maps a group name (file name up to ".csv") to the global
            index of the first sequence of that group.
        global_indices: uint32 array with one `(chunk_id, seq_start, seq_end)`
            row per sequence, loaded from / saved to `cache_path`.
    """

    def __init__(self, data_path: str, cache_path: str, tokenizer: Callable, seed: int = 42):
        """Read the metadata and load (or build and cache) the global index table.

        Args:
            data_path: Directory holding `metadata.json` and the memmap files.
            cache_path: File used to cache the global index table.
            tokenizer: Stored on the instance; not used by this class itself.
            seed: Accepted for interface compatibility; not used here.
        """
        self.data_path = Path(data_path)
        self.metadata_path = self.data_path / "metadata.json"
        self.tokenizer = tokenizer
        self.cache_path = Path(cache_path)

        with open(self.metadata_path, "r") as f:
            metadata = json.load(f)

        self.chunks_metadata = metadata["chunks"]
        self.group_offset = self._compute_group_offsets(metadata["file_metadata"])

        if self.cache_path.exists():
            print("Loading cached global indices...")
            # NOTE(review): the cache written below is a plain uint32 array, which
            # does not need pickle; `allow_pickle=True` is kept only so caches
            # produced by other code keep loading. Do not point `cache_path` at
            # untrusted files.
            self.global_indices = np.load(self.cache_path, allow_pickle=True)
        else:
            self.global_indices = self._build_global_indices()
            # Save through a file handle: given a path, np.save appends ".npy" when
            # the name lacks it, and the `exists()` check above would then never
            # find the cache again.
            with open(self.cache_path, "wb") as cache_file:
                np.save(cache_file, self.global_indices)
            print(f"Cached global indices saved at {cache_path}")

    @staticmethod
    def _compute_group_offsets(file_metadata):
        """Return {group name: global index of the group's first sequence}."""
        offsets = {}
        running_total = 0
        for item in file_metadata:
            group = item["file_name"].split(".csv")[0]
            # Only the first entry of a group defines its offset.
            offsets.setdefault(group, running_total)
            # `end` is treated as inclusive, hence the +1.
            running_total += item["end"] - item["start"] + 1
        return offsets

    def _build_global_indices(self):
        """Scan every chunk's index memmap into one uint32 (chunk_id, start, end) table."""
        self.indices_mmaps = [
            np.memmap(
                self.data_path / chunk["index"]["path"],
                dtype=chunk["index"]["dtype"],
                mode="r",
                shape=tuple(chunk["index"]["shape"]),
            )
            for chunk in self.chunks_metadata
        ]

        print("Computing global indices for subsequences...")
        rows = []
        for chunk_id, idx_mmap in enumerate(self.indices_mmaps):
            for seq_idx in tqdm(range(len(idx_mmap))):
                # Each index row unpacks into three values; the third is not needed here.
                seq_start, seq_end, _taxid = idx_mmap[seq_idx]
                rows.append((chunk_id, seq_start, seq_end))

        # Same dtype as the on-disk cache, so a cold start and a cache hit expose
        # identical `global_indices`.
        # NOTE(review): assumes offsets fit in uint32 - confirm for very large chunks.
        return np.array(rows, dtype=np.uint32)

    def get_by_group(self, group, group_idx):
        """Return the sequence at position `group_idx` within `group`."""
        return self[self.idx_by_group(group, group_idx)]

    def idx_by_group(self, group, group_idx):
        """Return the global index of position `group_idx` within `group`.

        Raises:
            KeyError: If `group` is not a key of `group_offset`.
        """
        return self.group_offset[group] + group_idx

    def __len__(self):
        """Number of sequences, i.e. rows in `global_indices`."""
        return len(self.global_indices)

    def __getitem__(self, idx):
        """Return the `[start:end)` slice of the owning chunk's sequence memmap."""
        chunk_id, start_token_idx, end_token_idx = self.global_indices[idx]
        chunk = self.chunks_metadata[chunk_id]
        seq_mmap = np.memmap(
            self.data_path / chunk["sequences"]["path"],
            dtype=chunk["sequences"]["dtype"],
            mode="r",
            shape=tuple(chunk["sequences"]["shape"]),
        )
        return seq_mmap[start_token_idx:end_token_idx]

In [None]:
# Special tokens plus padding/truncation settings handed to the tokenizer
tokenizer_kwargs = {
    "cls_token": "<CLS>",
    "bos_token": "<CLS>",
    "sep_token": "<SEP>",
    "unk_token": "<UNK>",
    "pad_token": "<PAD>",
    "mask_token": "<MASK>",
    "padding_side": "right",
    "truncation": "right",
    "seq_type": "dna",
}
tokenizer = Tokenizer(**tokenizer_kwargs)

# Loads the cached global index if CACHE_PATH exists, otherwise builds it
dataset = CodonMmapDataset(data_path=DATASET_DIR, cache_path=CACHE_PATH, tokenizer=tokenizer)

Loading cached global indices...


In [None]:
# Parse every numbered allSeqClust shard: entries are separated by NUL bytes and
# each entry is a whitespace-separated list of integer sequence keys.
clusters = []
cluster_shards = sorted(glob(str(MMSEQS_WORK_DIR / "allSeqClust.[0-9]*")))
for shard_path in tqdm(cluster_shards):
    with open(shard_path, "rb") as shard:
        entries = shard.read().split(b"\x00")
    # Empty entries (e.g. after the final NUL) become empty lists and are kept,
    # so positions in `clusters` stay unchanged.
    clusters.extend([int(key) for key in entry.strip().split()] for entry in entries)

100%|██████████| 96/96 [00:49<00:00,  1.94it/s]


In [None]:
# Read the headerless, tab-separated lookup table written next to the MMseqs2 database
lookup_path = MMSEQS_WORK_DIR / "allSeqs.lookup"
lookup = pl.read_csv(str(lookup_path), separator="\t", has_header=False)

In [7]:
# Name the three lookup columns, keep the key and the sequence name, and pull
# the names out as a plain list that is indexed by position further down.
# NOTE: this cell renames `lookup` in place, so re-read the lookup file before re-running it.
lookup.columns = ["seq_idx", "seq_name", "index"]
lookup = lookup.select("seq_idx", "seq_name").with_row_index()
lookup_seqs = lookup.get_column("seq_name").to_list()

In [8]:
def get_g_i(gi):
    """Split a "<group>_<index>" sequence name into `(group, index)`.

    Only the last underscore separates the two parts, so group names may
    themselves contain underscores. The index part is converted to `int`.
    """
    pieces = gi.rsplit("_", 1)
    return pieces[0], int(pieces[1])

In [21]:
# For every dataset position, record the index of the cluster it belongs to
# (-1 for positions that never appear in a cluster).
seq_name_clusters = []
global_group_idx = [-1] * len(dataset)

for cluster_i, cluster in enumerate(tqdm(clusters)):
    if not cluster:
        # Empty entries still consume a cluster index but have nothing to map.
        continue

    member_names = [lookup_seqs[member] for member in cluster]
    seq_name_clusters.append(member_names)

    for member_name in member_names:
        group, within_group_idx = get_g_i(member_name)
        global_group_idx[dataset.idx_by_group(group, within_group_idx)] = cluster_i

100%|██████████| 41422336/41422336 [03:22<00:00, 204460.04it/s]


In [None]:
# Persist the per-sequence cluster index next to the other clustering outputs
cluster_idx_path = CLUSTERS_OUTPUT_DIR / "allSeqClusterIdx.npy"
np.save(cluster_idx_path, np.array(global_group_idx))

# NOTE: this file should be moved to the directory that contains the memmap files