In [1]:
import pandas as pd 
from datasets import load_dataset, load_from_disk
import pyarrow.parquet as pq  
import time
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
parquet_path = "/gpfs/data/brandeslab/Data/uniref/uniref90_clusters.parquet"
table = pq.read_table(parquet_path)
clusters = table.to_pylist()
print(f"Loaded {len(clusters)} clusters into memory.")

In [2]:
import pyarrow.parquet as pq
import random

class CachedRowGroupClusterSampler:
    def __init__(self, parquet_path):
        self.pf = pq.ParquetFile(parquet_path)
        self.num_row_groups = self.pf.num_row_groups
        self.row_group_cache = []
        self.current_rg_index = None

    def sample_cluster(self):
        # If cache is empty, load a new row group
        if not self.row_group_cache:
            # Choose a random row group
            rg_idx = random.randint(0, self.num_row_groups - 1)
            table = self.pf.read_row_group(rg_idx)
            rows = table.to_pylist()  # 1840 rows typically
            random.shuffle(rows)  # Shuffle to maintain randomness
            self.row_group_cache.extend(rows)
            self.current_rg_index = rg_idx

        # Pop one sample from cache
        row = self.row_group_cache.pop()
        return {
            "cluster_id": row["cluster_id"],
            "representative_id": row["representative_id"],
            "members": row["members"],
        }


In [5]:
class RandomClusterSampler:
    def __init__(self, parquet_path):
        self.path = parquet_path
        
        # Read metadata once
        self.meta = pq.read_metadata(parquet_path)
        self.num_row_groups = self.meta.num_row_groups 
        
        # Open ParquetFile once 
        self.pf = pq.ParquetFile(parquet_path)

    def sample_clusters(self):
        # randomly select a row group (multiple clusters per row group)
        rg = np.random.randint(0, self.num_row_groups)

        # read only that row group 
        table = self.pf.read_row_group(rg)
        n = table.num_rows

        # Pick random row (cluster) from the row group
        cluster_idx = np.random.randint(0, n)
        row = table.slice(cluster_idx, 1)

        cluster_id = row["cluster_id"][0].as_py()
        rep_id      = row["representative_id"][0].as_py()
        members     = row["members"][0].as_py()

        return {
            "cluster_id": cluster_id,
            "representative_id": rep_id,
            "members": members
        }

In [None]:
class InMemoryClusterSampler:
    def __init__(self, parquet_path):
        print(f"Loading Parquet file into memory (Arrow table): {parquet_path}...")
        self.table = pq.read_table(parquet_path)  
        self.num_rows = self.table.num_rows
        print(f"Loaded Arrow table with {self.num_rows} clusters.")

    def sample_cluster(self):
        idx = random.randint(0, self.num_rows - 1)
        row = {k: self.table[k][idx].as_py() for k in self.table.column_names}
        return {
            "cluster_id": row["cluster_id"],
            "representative_id": row["representative_id"],
            "members": row["members"],
        }

In [9]:
parquet_path = "/gpfs/data/brandeslab/Data/uniref/uniref90_clusters.parquet"
sampeler1_start_time = time.time()
sampler1 = CachedRowGroupClusterSampler(parquet_path)
# Example usage
sampler1_samples = []
for i in range(129):
    sample = sampler1.sample_cluster()
    sampler1_samples.append(sample)
print(len(sampler1_samples))  # Print a sampled cluster
sampeler1_end_time = time.time()
print(f"Sampler1 Time taken: {sampeler1_end_time - sampeler1_start_time} seconds")

129
Sampler1 Time taken: 0.8135685920715332 seconds


In [10]:
parquet_path = "/gpfs/data/brandeslab/Data/uniref/uniref90_clusters.parquet"
sampeler2_start_time = time.time()
sampler2 = RandomClusterSampler(parquet_path)
# Example usage
sampler2_samples = []
for i in range(129):
    sample = sampler2.sample_clusters()
    sampler2_samples.append(sample)
print(len(sampler2_samples))  # Print a sampled cluster
sampeler2_end_time = time.time()
print(f"Sampler1 Time taken: {sampeler2_end_time - sampeler2_start_time} seconds")

129
Sampler1 Time taken: 5.8889000415802 seconds


In [4]:
meta = load_dataset("parquet", 
                    data_files="/gpfs/data/brandeslab/Data/uniref/uniref90_clusters.parquet")["train"]

Generating train split: 184146434 examples [00:20, 8911737.66 examples/s] 


In [5]:
ds = load_dataset(
    "parquet",
    data_files="/gpfs/data/brandeslab/Data/uniref/uniref90_clusters.parquet",
    split="train",
    streaming=True
)


In [6]:
stream = load_dataset(
    "parquet",
    data_files="/gpfs/data/brandeslab/Data/uniref/uniref90_clusters.parquet",
    streaming=True
)

In [7]:
from datasets import Dataset

ds = Dataset.from_parquet("/gpfs/data/brandeslab/Data/uniref/uniref90_clusters.parquet")

Generating train split: 184146434 examples [00:18, 9966633.86 examples/s] 


In [8]:
from datasets import load_dataset

ds = load_dataset(
    "parquet",
    data_files="/gpfs/data/brandeslab/Data/uniref/uniref90_clusters.parquet",
    split="train",      # required if not streaming
    streaming=False     # or just omit this arg
)

print(ds)
print(ds.features)


Dataset({
    features: ['cluster_id', 'representative_id', 'member_count_xml_file', 'member_count', 'members'],
    num_rows: 184146434
})
{'cluster_id': Value(dtype='string', id=None), 'representative_id': Value(dtype='string', id=None), 'member_count_xml_file': Value(dtype='int32', id=None), 'member_count': Value(dtype='int32', id=None), 'members': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)}


In [9]:
import pyarrow.parquet as pq
from datasets import Dataset

table = pq.read_table("/gpfs/data/brandeslab/Data/uniref/uniref90_clusters.parquet")
ds = Dataset(table)
print(ds)

Dataset({
    features: ['cluster_id', 'representative_id', 'member_count_xml_file', 'member_count', 'members'],
    num_rows: 184146434
})


In [10]:
import pyarrow.parquet as pq

table = pq.read_table("/gpfs/data/brandeslab/Data/uniref/uniref90_clusters.parquet")
table

pyarrow.Table
cluster_id: string
representative_id: string
member_count_xml_file: int32
member_count: int32
members: list<element: string>
  child 0, element: string
----
cluster_id: [["UniRef90_UPI002E2621C6","UniRef90_UPI00358F51CD","UniRef90_UPI00398E31D8","UniRef90_A0A5A9P0L4","UniRef90_A0ABD1JBH0",...,"UniRef90_UPI00225B580A","UniRef90_UPI00254460F6","UniRef90_UPI002796F535","UniRef90_UPI002896634E","UniRef90_UPI003D7BF59D"],["UniRef90_UPI00222E980F","UniRef90_UPI0024785306","UniRef90_UPI0035A280ED","UniRef90_UPI001B351495","UniRef90_UPI00222EC2CE",...,"UniRef90_UPI00248289A7","UniRef90_UPI0012FEA33D","UniRef90_A0A8C5NYK4","UniRef90_UPI0025ADAEBD","UniRef90_UPI003BF9D085"],...,["UniRef90_A0A3M7PZG0","UniRef90_Q9BM67","UniRef90_Q9BM66","UniRef90_E9BBM2","UniRef90_Q9BLZ4",...,"UniRef90_A0A0T5Z569","UniRef90_D4J9Q7","UniRef90_A0A6A7YBE0","UniRef90_A0A0T5Z260","UniRef90_A0A5C6IVY6"],["UniRef90_Q725V0","UniRef90_A0A4R4SIP9","UniRef90_A0ABW1B1Z7","UniRef90_A0A943LZD8","UniRef90_A0ABV6HH

In [11]:
pq.read_metadata("/gpfs/data/brandeslab/Data/uniref/uniref90_clusters.parquet").num_rows

184146434

In [12]:
table.schema

cluster_id: string
representative_id: string
member_count_xml_file: int32
member_count: int32
members: list<element: string>
  child 0, element: string

In [13]:
import pandas as pd
df = pd.read_csv("/gpfs/data/brandeslab/Data/uniref/uniref90_cluster_members.tsv", sep="\t", nrows=100)
df

Unnamed: 0,ClusterID,RepresentativeID,member_count_xml_file,member_count,MemberIDs
0,UniRef90_UPI002E2621C6,UPI002E2621C6,1,1,
1,UniRef90_UPI00358F51CD,UPI00358F51CD,1,1,
2,UniRef90_UPI00398E31D8,UPI00398E31D8,1,1,
3,UniRef90_A0A5A9P0L4,A0A5A9P0L4_9TELE,1,1,
4,UniRef90_A0ABD1JBH0,A0ABD1JBH0_9TELE,1,1,
...,...,...,...,...,...
95,UniRef90_A0A6P8RG45,A0A6P8RG45_GEOSA,1,1,
96,UniRef90_UPI003D69693D,UPI003D69693D,1,1,
97,UniRef90_UPI003EBDD6D2,UPI003EBDD6D2,3,3,"UPI003EBC14C6,UPI003EB9A168"
98,UniRef90_A0A1V4K6M4,A0A1V4K6M4_PATFA,562,562,"A0A7K4RU14_COLPI,A0A7L4G2H3_9COLU,A0A094KAD8_A..."


In [14]:
df.iloc[97]

ClusterID                     UniRef90_UPI003EBDD6D2
RepresentativeID                       UPI003EBDD6D2
member_count_xml_file                              3
member_count                                       3
MemberIDs                UPI003EBC14C6,UPI003EB9A168
Name: 97, dtype: object

In [15]:
!pip install parasail

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [16]:
import parasail

In [17]:
import sys
!{sys.executable} -m pip install parasail

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [18]:
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Split
from transformers import PreTrainedTokenizerFast

# Define vocabulary: ONLY standard 20 amino acids
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")  # Standard 20 only
special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "-"]

# Build vocabulary
vocab = {token: i for i, token in enumerate(special_tokens + amino_acids)}

print(f"Vocab: {vocab}")
print(f"'-' ID: {vocab['-']}")

# Create tokenizer
tokenizer = Tokenizer(WordLevel(vocab=vocab, unk_token="[UNK]"))
tokenizer.pre_tokenizer = Split(pattern="", behavior="isolated")

# Wrap it
hf_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    unk_token="[UNK]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    mask_token="[MASK]"
)

# Test
print("\nTests:")
print("'-' converts to:", hf_tokenizer.convert_tokens_to_ids("-"))
print("Tokenization of 'ACD-GH':", hf_tokenizer("ACD-GH"))
print("Decoding:", hf_tokenizer.decode([6, 7, 8, 5, 11, 12]))
print("Vocab size:", len(hf_tokenizer))

# hf_tokenizer.save_pretrained("phylo_char_tokenizer")

Vocab: {'[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3, '[MASK]': 4, '-': 5, 'A': 6, 'C': 7, 'D': 8, 'E': 9, 'F': 10, 'G': 11, 'H': 12, 'I': 13, 'K': 14, 'L': 15, 'M': 16, 'N': 17, 'P': 18, 'Q': 19, 'R': 20, 'S': 21, 'T': 22, 'V': 23, 'W': 24, 'Y': 25}
'-' ID: 5

Tests:
'-' converts to: 5
Tokenization of 'ACD-GH': {'input_ids': [6, 7, 8, 5, 11, 12], 'token_type_ids': [0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1]}
Decoding: A C D - G H
Vocab size: 26


In [3]:
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Split
from transformers import PreTrainedTokenizerFast

# Define vocabulary: ONLY standard 20 amino acids
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")  # Standard 20 only
special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "[GAP]", "-"]


# Build vocabulary
vocab = {token: i for i, token in enumerate(special_tokens + amino_acids)}

print(f"Vocab: {vocab}")
print(f"'-' ID: {vocab['-']}")
print(f"[GAP] ID: {vocab['[GAP]']}")
print(f"[MASK] ID: {vocab['[MASK]']}")

# Create tokenizer
tokenizer = Tokenizer(WordLevel(vocab=vocab, unk_token="[UNK]"))
tokenizer.pre_tokenizer = Split(pattern="", behavior="isolated")

# Wrap it
hf_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    unk_token="[UNK]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    mask_token="[MASK]",
    additional_special_tokens=["[GAP]"]  # ðŸ‘ˆ Register [GAP] as special
)

# Test
print("\nTests:")
print("'-' converts to:", hf_tokenizer.convert_tokens_to_ids("-"))
print("[GAP] converts to:", hf_tokenizer.convert_tokens_to_ids("[GAP]"))
print("[MASK] converts to:", hf_tokenizer.convert_tokens_to_ids("[MASK]"))
print("\n")
print("Tokenization of 'ACD-GH':", hf_tokenizer("ACD-GH", add_special_tokens=False))
print("Tokenization of  'ACD[GAP]GH':", hf_tokenizer("ACD[GAP]GH", add_special_tokens=False))
# print("Decoding:", hf_tokenizer.decode([6, 7, 8, 5, 11, 12]))
print("Vocab size:", len(hf_tokenizer))

hf_tokenizer.save_pretrained("phylo_char_tokenizer_updated")

Vocab: {'[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3, '[MASK]': 4, '[GAP]': 5, '-': 6, 'A': 7, 'C': 8, 'D': 9, 'E': 10, 'F': 11, 'G': 12, 'H': 13, 'I': 14, 'K': 15, 'L': 16, 'M': 17, 'N': 18, 'P': 19, 'Q': 20, 'R': 21, 'S': 22, 'T': 23, 'V': 24, 'W': 25, 'Y': 26}
'-' ID: 6
[GAP] ID: 5
[MASK] ID: 4

Tests:
'-' converts to: 6
[GAP] converts to: 5
[MASK] converts to: 4


Tokenization of 'ACD-GH': {'input_ids': [7, 8, 9, 6, 12, 13], 'token_type_ids': [0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1]}
Tokenization of  'ACD[GAP]GH': {'input_ids': [7, 8, 9, 5, 12, 13], 'token_type_ids': [0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1]}
Vocab size: 27


('phylo_char_tokenizer_updated/tokenizer_config.json',
 'phylo_char_tokenizer_updated/special_tokens_map.json',
 'phylo_char_tokenizer_updated/tokenizer.json')

In [19]:
import parasail 

In [20]:
import sys
print(sys.executable)

/gpfs/home/as12267/.conda/envs/huggingface_bert/bin/python


In [21]:
import sys
[p for p in sys.path if "conda" in p]

['/gpfs/home/as12267/.conda/envs/huggingface_bert/lib/python310.zip',
 '/gpfs/home/as12267/.conda/envs/huggingface_bert/lib/python3.10',
 '/gpfs/home/as12267/.conda/envs/huggingface_bert/lib/python3.10/lib-dynload',
 '/gpfs/home/as12267/.conda/envs/huggingface_bert/lib/python3.10/site-packages']

In [22]:
import parasail
result = parasail.nw_trace_scan_16("AAAA", "AAAA", 5, 1, parasail.blosum62)

print(dir(result))

['__class__', '__del__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_as_parameter_', '_cigar', '_traceback', '_traceback_args', 'cigar', 'end_query', 'end_ref', 'get_cigar', 'get_traceback', 'len_query', 'len_ref', 'length', 'length_col', 'length_row', 'length_table', 'matches', 'matches_col', 'matches_row', 'matches_table', 'matrix', 'pointer', 'query', 'ref', 'saturated', 'score', 'score_col', 'score_row', 'score_table', 'similar', 'similar_col', 'similar_row', 'similar_table', 'traceback']


In [23]:
import parasail

matrix = parasail.blosum62

def align_pair(seq1, seq2):
    # global Needlemanâ€“Wunsch alignment with traceback
    result = parasail.nw_trace_scan_16(seq1, seq2, 10, 1, matrix)

    # Extract the traceback object
    tb = result.traceback

    # Get aligned sequences
    # a1 = tb.query.replace("-", "<GAP>")
    # a2 = tb.ref.replace("-", "<GAP>")
    a1 = tb.query
    a2 = tb.ref

    return a1, a2

# Test sequences
seq1 = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIE"
seq2 = "MKTVYIAKRRQISFLKSHFSRQLDDRLGLIE"

a1, a2 = align_pair(seq1, seq2)

print("Seq1 Aligned:", a1)
print("Seq2 Aligned:", a2)


Seq1 Aligned: MKTAYIAKQRQISFVKSHFSRQLEERLGLIE
Seq2 Aligned: MKTVYIAKRRQISFLKSHFSRQLDDRLGLIE


In [5]:
from gLM.sequences import align_pair
seq1 = "MKTAYIAKQRQISFVKSHFS"
seq2 = "MKTA---QRQISFVKSHFSSS"

a1, a2 = align_pair(seq1, seq2)

print("Seq1 Aligned:", a1)
print("Seq2 Aligned:", a2)

print(f"Seq1 Aligned with [GAP]:", a1.replace("-", "[GAP]"))
print(f"Seq2 Aligned with [GAP]:", a2.replace("-", "[GAP]"))

Seq1 Aligned: MKTAYIAKQRQISFVKSHF--S
Seq2 Aligned: MKTA----QRQISFVKSHFSSS
Seq1 Aligned with [GAP]: MKTAYIAKQRQISFVKSHF[GAP][GAP]S
Seq2 Aligned with [GAP]: MKTA[GAP][GAP][GAP][GAP]QRQISFVKSHFSSS


In [25]:
from transformers import (
    PreTrainedTokenizerFast,
)

class PhyloTokenizerLoader:
    def __init__(self, tokenizer_path):
        self.tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_path)
    
    def encode(self, aligned_seq1, aligned_seq2):
        inputs = self.tokenizer(
            aligned_seq2,
            add_special_tokens=False
        )
        labels = self.tokenizer(
            aligned_seq1,
            add_special_tokens=False
        )
        return{
            "input_ids": inputs["input_ids"],
            "labels": labels["input_ids"], 
            "attention_mask": inputs["attention_mask"]
        }
    
    def __call__(self, *args, **kwargs):
        return self.tokenizer(*args, **kwargs)

# Set breakpoint inside encode() method
# Then run this:
phylo_tokenizer = PhyloTokenizerLoader("phylo_char_tokenizer")

# When you call encode, debugger will pause at your breakpoint
result = phylo_tokenizer.encode("ATCG", "ATCG")

In [26]:
result

{'input_ids': [6, 22, 7, 11],
 'labels': [6, 22, 7, 11],
 'attention_mask': [1, 1, 1, 1]}

In [27]:
from gLM.dataset import UniRefClusterIterableDataset
from gLM.tokenizers import PhyloTokenizerLoader


In [28]:
tokenizer = PhyloTokenizerLoader("./phylo_char_tokenizer").load()


In [29]:
ds = UniRefClusterIterableDataset(
    parquet_path="/gpfs/data/brandeslab/Data/uniref/uniref90_clusters.parquet",
    fasta_path="/gpfs/data/brandeslab/Data/uniref/uniref100.fasta",
    tokenizer=tokenizer,
    max_seq_len=8192,
)



TypeError: UniRefClusterIterableDataset.__init__() missing 1 required positional argument: 'index_db_path'

In [None]:
for i, x in enumerate(ds):
    print("Got item", i)
    if i == 2:
        break


In [2]:

import time
from gLM.dataset import UniRefClusterIterableDataset
from gLM.tokenizers import PhyloTokenizerLoader

parquet = "/gpfs/data/brandeslab/Data/uniref/uniref90_clusters.parquet"
fasta   = "/gpfs/data/brandeslab/Data/uniref/uniref100.fasta"
index_db_path = "/gpfs/data/brandeslab/User/as12267/uniref100.idx"

print("Loading tokenizer...")
tok = PhyloTokenizerLoader("./phylo_char_tokenizer").load()

print("Instantiating dataset...")
start = time.time()
ds = UniRefClusterIterableDataset(
    parquet_path=parquet,
    index_db_path=index_db_path,
    fasta_path=fasta,
    tokenizer=tok,
    max_seq_len=8192,
    training_type="MLM"
)
print("Dataset constructed in", time.time() - start, "sec")

print("Fetching first item...")
start = time.time()
for i, item in enumerate(ds):
    print("Got first item!")
    break
print("First item fetched in", time.time() - start, "sec")



Loading tokenizer...
Instantiating dataset...
Dataset constructed in 0.04993486404418945 sec
Fetching first item...
In the Uniref Iterable - MLM branch
Got first item!
First item fetched in 0.4544064998626709 sec


In [None]:
from gLM.dataset import UniRefClusterIterableDataset
from gLM.tokenizers import PhyloTokenizerLoader
import time

parquet = "/gpfs/data/brandeslab/Data/uniref/uniref90_clusters.parquet"
fasta   = "/gpfs/data/brandeslab/Data/uniref/uniref100.fasta"
index_db = "/gpfs/data/brandeslab/User/as12267/uniref100.idx"

print("Loading tokenizer...")
tok = PhyloTokenizerLoader("./phylo_char_tokenizer")

ds = UniRefClusterIterableDataset(
    parquet_path=parquet,
    index_db_path=index_db,
    fasta_path=fasta,
    tokenizer=tok,
    max_seq_len=8192,
)

print("Fetching first item...")
for i, x in enumerate(ds):
    print("Success!")
    print(x.keys())
    break



  from .autonotebook import tqdm as notebook_tqdm


Loading tokenizer...
Fetching first item...
REP: A0ABS9P3Y9_9GAMM MEMBERS: [] <class 'list'>
REP: A0A835XQ69_9CHLO MEMBERS: ['A0A835XMW2_9CHLO'] <class 'list'>
PAIR: A0A835XQ69_9CHLO A0A835XMW2_9CHLO
REP: UPI000C86E4D3 MEMBERS: [] <class 'list'>
REP: UPI001260CD7F MEMBERS: [] <class 'list'>
REP: A0A2I1HB31_9GLOM MEMBERS: [] <class 'list'>
REP: A0AAF0EAU2_9BASI MEMBERS: [] <class 'list'>
REP: UPI002ECFBE8A MEMBERS: [] <class 'list'>
REP: UPI0002491D06 MEMBERS: [] <class 'list'>
REP: UPI002625D505 MEMBERS: [] <class 'list'>
REP: A0A917ZIC9_9GAMM MEMBERS: [] <class 'list'>
REP: UPI00188B8639 MEMBERS: [] <class 'list'>
REP: UPI003216E2CD MEMBERS: [] <class 'list'>
REP: A0A1R1K1Z6_ALCXX MEMBERS: ['UPI0029A3963B', 'UPI0005F94960'] <class 'list'>
PAIR: UPI0005F94960 A0A1R1K1Z6_ALCXX
REP: A0ABU7LKA1_9NOCA MEMBERS: ['UPI000B257CD9', 'UPI00364A542B'] <class 'list'>
PAIR: UPI000B257CD9 A0ABU7LKA1_9NOCA
REP: UPI00112C5A48 MEMBERS: [] <class 'list'>
REP: A0A1I2CAG8_9RHOB MEMBERS: [] <class 'list'>


In [1]:
from gLM.tokenizers import PhyloTokenizerLoader

tok = PhyloTokenizerLoader("./phylo_char_tokenizer")

print(tok)                    # should show PreTrainedTokenizerFast
print(tok.mask_token)         # should NOT raise error
print(tok.pad_token_id)
print(tok.convert_tokens_to_ids("-"))
print(tok.get_vocab())
print(tok("ACD-EFG"))


  from .autonotebook import tqdm as notebook_tqdm


<gLM.tokenizers.phylo_tokenizer.PhyloTokenizerLoader object at 0x15543fe5d510>
[MASK]
0
5
{'[UNK]': 1, 'C': 7, '[MASK]': 4, 'T': 22, '[SEP]': 3, '[CLS]': 2, 'H': 12, 'E': 9, 'L': 15, 'F': 10, 'G': 11, 'Y': 25, '[PAD]': 0, 'Q': 19, 'K': 14, 'A': 6, '-': 5, 'P': 18, 'I': 13, 'W': 24, 'R': 20, 'D': 8, 'S': 21, 'M': 16, 'N': 17, 'V': 23}
{'input_ids': [6, 7, 8, 5, 9, 10, 11], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}


In [6]:
tokenizer = PhyloTokenizerLoader("./phylo_char_tokenizer_updated")
pad_id = tokenizer.pad_token_id
mask_id = tokenizer.mask_token_id
non_gap_id = tokenizer.convert_tokens_to_ids("-") 
gap_id = tokenizer.convert_tokens_to_ids("[GAP]")
print("Phylo Tokenizer loaded")
print("Mask ID:", mask_id)
print("Non-GAP ID:", non_gap_id)
print("GAP ID:", gap_id)
print("Tokenizer vocab size:", tokenizer.vocab_size)

Phylo Tokenizer loaded
Mask ID: 4
Non-GAP ID: 6
GAP ID: 5
Tokenizer vocab size: 27


In [1]:
from gLM.dataset import UniRefClusterIterableDataset
from gLM.tokenizers import PhyloTokenizerLoader
from gLM.collator import SequencePairCollator
import time
from torch.utils.data import DataLoader

parquet = "/gpfs/data/brandeslab/Data/uniref/uniref90_clusters.parquet"
fasta   = "/gpfs/data/brandeslab/Data/uniref/uniref100.fasta"
index_db = "/gpfs/data/brandeslab/User/as12267/uniref100.idx"

print("Loading tokenizer...")
# tok = PhyloTokenizerLoader("./phylo_char_tokenizer")
tok = PhyloTokenizerLoader("./phylo_char_tokenizer_updated")

ds = UniRefClusterIterableDataset(
    parquet_path=parquet,
    index_db_path=index_db,
    fasta_path=fasta,
    tokenizer=tok,
    max_seq_len=8192,
    training_type="phylo"
)


collator = SequencePairCollator(pad_id=tok.tokenizer.pad_token_id)

loader = DataLoader(
    ds, 
    batch_size=4, 
    collate_fn=collator,
    num_workers=0

)
batch = next(iter(loader))

print("input_ids shape:", batch["input_ids"].shape)
print("labels shape:", batch["labels"].shape)
print("attention_mask shape:", batch["attention_mask"].shape)
print("percent_identity:", batch["percent_identity"])

  from .autonotebook import tqdm as notebook_tqdm


Loading tokenizer...
In the Uniref Iterable - Phylo branch
In the Uniref Iterable - Phylo branch
In the Uniref Iterable - Phylo branch
Untokenize seq A1 MDTEGFPRQFARTRRFSLGVPREFTVSPDGDRVLFLRSESGVVPRVHLWMYESGGERILTDPA[GAP][GAP][GAP][GAP][GAP][GAP][GAP][GAP][GAP][GAP][GAP][GAP][GAP][GAP][GAP][GAP][GAP][GAP][GAP][GAP][GAP][GAP][GAP]AGVGTYATDRHVRVVAYTVDGSLWTVRTDGGLPRRIQTVGPVRDPRPSPDGTLIAYVTGGALRVVGTDGAGDRPLAEPETAETTYGLADYSAVASIGRSRGYWWSPDSGALLVARVDTSVVERRYLSDPSDPGQSPRSVRYPAAGTANAITSLHLVTVAGGHTPVRLPRQAPAKDAPADSWGLAFEYVVGADWQSGGPVISLQTRDQRTMWVLRVDPVHGSVEQLSRQTDQGWVEFPPGAPLHTASGVLVLPQVRGDERTIRIGGVFAPAGLQVRAVLGSVGEKVWFAASEEPTEVHVWSYEAGHGFERLTQTPGVHTATAGGDTLVLDSRTLGGHAVTVLRDGKQVGCIAVLAEQPLVTPRPVHLTLGKRQLRSRLHLPSWYEPGMAKLPVLLSPYAGPGMQVVTKAHGWYTAVCQWYAEQGFAVLATDGRGTPGRGVRWQRAILGDRLTPVLDDQIDGLHAAARRCDALDLERVGIRGWSFSGYLAAGAVLHRPDVFHAAVAGATPTDRRLYDTYWEERFLGHPDLQPHNYERSSLLPLAEKLTRPLMLVHGLADDNVAPAHTLRLSAELLAAGRPHRVLLLPGVGHLVTGEGVADTLLQLELDFLKSSLGA, length A1: 792
Untokenize length A1: 792
Untokenize s

In [3]:
import time
from torch.utils.data import DataLoader
from gLM.dataset import UniRefClusterIterableDataset
from gLM.tokenizers import PhyloTokenizerLoader
from gLM.collator import SequencePairCollator

# === File paths ===
parquet = "/gpfs/data/brandeslab/Data/uniref/uniref90_clusters.parquet"
fasta   = "/gpfs/data/brandeslab/Data/uniref/uniref100.fasta"
index_db = "/gpfs/data/brandeslab/User/as12267/uniref100.idx"

# === Load tokenizer ===
print("Loading tokenizer...")
tok = PhyloTokenizerLoader("./phylo_char_tokenizer_updated")

# === Dataset ===
ds = UniRefClusterIterableDataset(
    parquet_path=parquet,
    index_db_path=index_db,
    fasta_path=fasta,
    tokenizer=tok,
    max_seq_len=8192,
    training_type="phylo",  # or "MLM"
    batch_size=128  # this is internal batch collection, not DataLoader batch_size
)

# === Collator ===
collator = SequencePairCollator(tok)

# === Dataloader ===
loader = DataLoader(
    ds,
    batch_size=128,  # this is the final batch size the model sees
    collate_fn=collator,
    num_workers=16,  # Use >0 to test multiprocessing
    pin_memory=True
)

# === Test: Get one batch ===
start = time.time()
batch = next(iter(loader))
end = time.time()

# === Inspect batch ===
print("\n--- Batch Info ---")
print("input_ids shape:     ", batch["input_ids"].shape)
print("labels shape:        ", batch["labels"].shape)
print("attention_mask shape:", batch["attention_mask"].shape)
print("percent_identity:    ", batch["percent_identity"])
print(f"Loaded 1 batch in {end - start:.3f} seconds")


Loading tokenizer...



--- Batch Info ---
input_ids shape:      torch.Size([128, 2198])
labels shape:         torch.Size([128, 2198])
attention_mask shape: torch.Size([128, 2198])
percent_identity:     tensor([80.3828, 98.7730, 98.1013, 98.5714, 96.6057, 95.7672, 99.6068, 92.6471,
        99.9545, 99.0654, 96.6728, 98.1366, 99.7462, 98.2507, 92.7602, 95.2381,
        99.0826, 99.3644, 90.8257, 92.8158, 96.7914, 94.7917, 99.2924, 93.0000,
        98.7730, 75.1196, 93.1452, 96.5517, 90.3509, 99.9168, 99.4429, 96.6346,
        91.4286, 99.0937, 97.2414, 91.8841, 89.7674, 92.0489, 94.0887, 98.4375,
        91.1565, 99.3243, 88.2129, 99.0991, 92.6471, 97.5510, 97.5845, 99.5968,
        93.1559, 98.7730, 80.6452, 98.8166, 98.1183, 91.3979, 88.0769, 91.3043,
        97.8261, 99.4012, 74.9231, 99.4220, 93.5484, 99.6689, 99.3072, 90.2017,
        90.1408, 94.3114, 92.0038, 99.3088, 90.7990, 89.1374, 94.8718, 95.0000,
        99.7630, 98.9305, 97.0588, 95.5823, 99.6805, 96.9595, 91.1504, 90.5028,
        93.4138, 92.

In [None]:
tokenizer = PhyloTokenizerLoader("./phylo_char_tokenizer").load()
pad_id = tokenizer.pad_token_id
gap_id = tokenizer.convert_tokens_to_ids("-")
print("Phylo Tokenizer loaded. GAP ID:", gap_id)
print(f"vocab size: {tokenizer.vocab_size}")

In [None]:
from gLM.data_utils.uniref_cluster_sampler import RandomClusterSampler

sampler = RandomClusterSampler(parquet)
for _ in range(5):
    cluster = sampler.sample_clusters()
    print(cluster)


{'cluster_id': 'UniRef90_A0AAD2JM23', 'representative_id': 'A0AAD2JM23_9STRA', 'members': ['A0A9K3LU73_9STRA']}
{'cluster_id': 'UniRef90_A0A2T5GHF6', 'representative_id': 'A0A2T5GHF6_9SPHN', 'members': ['UPI0011B1CF18', 'UPI00177CA10F', 'UPI00335862B2', 'UPI001783B8B9', 'UPI00178503EB', 'UPI002A6B860E']}
{'cluster_id': 'UniRef90_A0ABD3RNB6', 'representative_id': 'A0ABD3RNB6_9LAMI', 'members': []}
{'cluster_id': 'UniRef90_A0ABN7WPN7', 'representative_id': 'A0ABN7WPN7_GIGMA', 'members': ['A0ABN7X348_GIGMA']}
{'cluster_id': 'UniRef90_UPI0031F96DD0', 'representative_id': 'UPI0031F96DD0', 'members': []}


In [None]:
import parasail
print(parasail.__version__)


1.3.4


In [1]:
from gLM.tokenizers import PhyloTokenizerLoader

tokenizer = PhyloTokenizerLoader("./phylo_char_tokenizer_updated")
pad_id = tokenizer.pad_token_id
gap_id = tokenizer.convert_tokens_to_ids("-")
print("Phylo Tokenizer loaded. pad ID:", pad_id)
print("Phylo Tokenizer loaded. GAP ID:", gap_id)
print(f"vocab size: {tokenizer.vocab_size}")

  from .autonotebook import tqdm as notebook_tqdm


Phylo Tokenizer loaded. pad ID: 0
Phylo Tokenizer loaded. GAP ID: 6
vocab size: 27


In [2]:
train_ds = "/gpfs/data/brandeslab/Data/uniref/hf_pairs_uniref90_final/train.jsonl"

import torch
import json
from torch.utils.data import Dataset 


class JsonInMemoryDataset(Dataset):
    def __init__(self, path:str):
        with open(path, 'r') as f:
            self.data = [json.loads(line) for line in f if "seq1" in line and "seq2" in line]
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ex = self.data[idx]
        return ex["seq1"], ex["seq2"]

train_data = JsonInMemoryDataset(train_ds)


In [7]:
from gLM.sequences.pairwise_align import align_pair, percent_identity
class PhyloCollator:
    def __init__(self, tokenizer, training_type: str, max_seq_len: int):
        self.tokenizer = tokenizer
        self.training_type = training_type
        self.max_seq_len = max_seq_len

    def __call__(self, batch):
        # batch is a list of (seq1 and seq2)

        if self.training_type == "MLM":
            sequences = [s1[:self.max_seq_len] for s1, _ in batch]
            return self.tokenizer(
                sequences, 
                padding="longest", 
                truncation=True,
                eturn_tensors="pt")

        elif self.training_type == "phylo_encoder_only":
            # P(Seq1 | Seq2)
            aligned_pairs = []
            pids = []
            for s1, s2 in batch:
                # per pair alignment
                a1, a2 = align_pair(s1, s2)

                if len(a1) == len(a2):
                    aligned_pairs.append((a1, a2))
                    pid = percent_identity(a1, a2)
                    pids.append(pid)

            # Tokenize aligned pairs (batch)    
            batch_out = self.tokenizer.batch_encode_aligned(
                aligned_pairs, max_length=self.max_seq_len)
            

            print("input_ids shape:", batch_out["input_ids"].shape)
            print("attention_mask shape:", batch_out["attention_mask"].shape)
            print("attention_mask sum:", batch_out["attention_mask"].sum(dim=1))
            print("labels shape:", batch_out["labels"].shape)
            
            return batch_out
        
        elif self.training_type == "phylo_encoder_decoder":
            inputs = [s1[:self.max_seq_len] for s1, _ in batch]
            targets = [s2[:self.max_seq_len] for _, s2 in batch]
            enc = self.tokenizer(
                inputs,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_len,
                return_tensors="pt",
            )
            # print("Encoder seq:", enc)

            dec = self.tokenizer(
                targets,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_len,
                return_tensors="pt",
            )

            # Conver {PAD} tokens in labels to -100
            labels = dec["input_ids"].clone()
            labels[labels == self.tokenizer.pad_token_id] = -100

            print("Encoder input_ids shape:", enc["input_ids"].shape)
            print("Encoder attention_mask sum:", enc["attention_mask"].sum(dim=1))

            print("Decoder labels shape:", dec["input_ids"].shape)
            print("Decoder attention_mask sum:", dec["attention_mask"].sum(dim=1))
            print(
                "Decoder PAD count:",
                (dec["input_ids"] == self.tokenizer.pad_token_id).sum().item()
            )

            # dec["input_ids"] is the raw target sequence with [PAD] = 0
            # labels is the same tensor with [PAD] converted to -100
        
            assert(labels == -100).sum() == (dec["input_ids"] == tokenizer.pad_token_id).sum()
            return {
                "input_ids": enc["input_ids"],
                "attention_mask": enc["attention_mask"],
                "labels": labels,
            }

        else:
            raise ValueError(f"Unsupported training_type: {self.training_type}")


# collator = PhyloCollator(
#     tokenizer=tokenizer,
#     training_type="phylo_encoder_decoder",
#     max_seq_len=512
# )

collator = PhyloCollator(
    tokenizer=tokenizer,
    training_type="phylo_encoder_only",
    max_seq_len=200
)

In [4]:
import torch
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_data,
    batch_size=2,
    shuffle=True,
    collate_fn=collator,
    pin_memory=True
)

for batch in train_loader:
    # print("input_ids:", batch["input_ids"].shape)
    # print("attention_mask:", batch["attention_mask"].shape)
    # print("labels:", batch["labels"].shape)
    break

example of tokenized input: MRASGAVTRSTLRQQIADALRDEVLAGRLQRGREFTVKQIAEQYGVSATPVREALFDLSAQGLLESDQHRGFRVREFTVADYRSMVEARTLVIDGIIRDVFHGFGPGLSAARAAVYQDGLVSVRRRAQEAARAAQGGDLDILIGYDLRFWRELGALVDNAYINDFLHRLRVQAWVFAVPYLRHDASARDWLWHGHPELVAAITLCDHDAVRAVMDDYNAHALNWADRLAAGMLALPTASAPAGPTNPANPAASGPANTAETSIPTDPAAPGPAISEDPGSGRSGS MRASGAVTRSTLRQQIADALRDEVLAGRLQRGREFTVKQIAEQYGVSATPVREALFDLSAQGLLESDQHRGFRVREFTVADYRSMVEARTLVIDGVIRDVFHGFGPGLSAARAAVYQDGLVSVRRRAQEAARAAQGGDLDILIGYDLRFWRELGALVDNAYINDFLHRLRVQAWVFAVPYLRHDASARDWLWHGHPELVAAITLCDHDAVRAVMDDYNAHALNWADRLAAGMLALPTASAPAGPTNPANPAAPGPANTAETSIPTDPAAPGPAISEEPGSGRSGS
input_ids shape: torch.Size([2, 285])
attention_mask shape: torch.Size([2, 285])
attention_mask sum: tensor([285, 157])
labels shape: torch.Size([2, 285])


In [8]:
max_length = 200  # or your actual self.max_seq_len

import torch
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_data,
    batch_size=2,
    shuffle=True,
    collate_fn=collator,
    pin_memory=True
)

for batch in train_loader:
    input_lengths = batch["input_ids"].ne(0).sum(dim=1)  # count non-padding tokens
    if (input_lengths > max_length).any():
        print("Found batch with long sequence:")
        print("input_ids shape:", batch["input_ids"].shape)
        print("Longest input length:", input_lengths.max().item())
        break


input_ids shape: torch.Size([2, 200])
attention_mask shape: torch.Size([2, 200])
attention_mask sum: tensor([200, 137])
labels shape: torch.Size([2, 200])
input_ids shape: torch.Size([2, 200])
attention_mask shape: torch.Size([2, 200])
attention_mask sum: tensor([200, 200])
labels shape: torch.Size([2, 200])
input_ids shape: torch.Size([2, 200])
attention_mask shape: torch.Size([2, 200])
attention_mask sum: tensor([129, 200])
labels shape: torch.Size([2, 200])
input_ids shape: torch.Size([2, 183])
attention_mask shape: torch.Size([2, 183])
attention_mask sum: tensor([180, 183])
labels shape: torch.Size([2, 183])
input_ids shape: torch.Size([2, 200])
attention_mask shape: torch.Size([2, 200])
attention_mask sum: tensor([200, 200])
labels shape: torch.Size([2, 200])
input_ids shape: torch.Size([2, 200])
attention_mask shape: torch.Size([2, 200])
attention_mask sum: tensor([200, 200])
labels shape: torch.Size([2, 200])
input_ids shape: torch.Size([2, 200])
attention_mask shape: torch.Size

KeyboardInterrupt: 

In [5]:
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
labels = batch["labels"]

print("input_ids shape:", input_ids.shape)
print("attention_mask shape:", attention_mask.shape)
print("labels shape:", labels.shape)

pad_id = tokenizer.pad_token_id
gap_id = tokenizer.convert_tokens_to_ids("[GAP]")
print("Pad ID:", pad_id)
print("GAP ID:", gap_id)
print("PAD count in input_ids", (input_ids == pad_id).sum().item()) # comes from seq 1
print("GAP count in input_ids", (input_ids == gap_id).sum().item()) # comes from seq 1
print("PAD count in labels", (labels == pad_id).sum().item()) # should be 0 
print("GAP count in labels", (labels == gap_id).sum().item()) # comes from seq 2
print("-100 count in labels", (labels == -100).sum().item()) # comes from seq 2


input_ids shape: torch.Size([2, 285])
attention_mask shape: torch.Size([2, 285])
labels shape: torch.Size([2, 285])
Pad ID: 0
GAP ID: 5
PAD count in input_ids 128
GAP count in input_ids 0
PAD count in labels 0
GAP count in labels 0
-100 count in labels 128


In [8]:
print(len(batch["input_ids"][1]))
print(batch["input_ids"][1])

285
tensor([17, 21, 16, 20, 18, 26, 15, 23,  7, 12, 13, 19, 14, 14, 26, 14,  9, 10,
        22, 12, 11, 22, 13,  9, 17, 19, 21, 21, 26, 12, 26, 22, 19, 15, 12, 20,
        21,  8, 26, 12, 20, 13, 18, 25, 20,  7, 21, 12, 21, 23, 18, 14, 17, 12,
         7, 16, 16, 10, 15,  7, 16, 16, 23, 24,  8,  7, 11, 23, 22, 18, 14, 18,
        22,  9, 14, 11, 13,  7, 25, 14, 23, 20,  9, 16, 16, 19, 15, 24, 19, 16,
        18, 22, 24, 17, 24, 17,  9, 18,  7, 22, 11, 13, 15, 21, 15,  9, 14, 20,
         9,  7, 14, 15, 18,  7, 12, 11, 14, 16, 10, 26, 16, 19, 24, 26, 22, 19,
         9, 16, 18, 19, 14, 10, 15, 15, 25,  7, 13,  7, 15,  7, 21, 21, 21, 15,
        10, 21,  8,  9, 24,  9, 13, 16, 11, 22, 11, 11, 18,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0, 

In [9]:
print(len(batch["labels"][1]))
print(batch["labels"][1])

285
tensor([  17,   21,   16,   20,   18,   26,   15,   23,    7,   12,   13,   19,
          14,   14,   26,   14,    9,   10,   22,   12,   11,   22,   13,    9,
          17,   19,   21,   13,   26,   12,   26,   22,   19,   15,   12,   20,
          21,    8,   26,   12,   20,   13,   18,   25,   20,    7,   21,   12,
          21,   23,   18,   14,   14,   12,    7,   16,   16,   10,   15,   10,
          16,   16,   23,   24,    8,    7,   11,   23,   22,   18,   14,   18,
          22,    9,   14,   11,   13,    7,   25,   14,   23,   20,    9,   16,
          16,   19,   15,   24,   19,   16,   18,   22,   24,   17,   24,   17,
           9,   18,    7,   22,   11,   13,   15,   21,   15,    9,   14,   20,
           9,    7,   14,   15,    9,    7,   12,   11,   14,   16,   10,   26,
          16,   19,   24,   26,   22,   19,    9,   16,   18,   19,   14,   10,
          15,   15,   25,    7,   13,    7,   15,    7,   21,   21,   21,   15,
          10,   21,    8,    9,   24