In [1]:
%%capture
# Swap whatever transformers build is installed for the current development
# version from GitHub.
# NOTE(review): the install is unpinned, so the resolved version can differ
# between runs.
!pip uninstall -y transformers
!pip install git+https://github.com/huggingface/transformers
# Biopython supplies the Bio.SeqIO FASTA parser imported in the next cell.
!pip install biopython 

# plasmids

In [1]:
import pandas as pd
from Bio import SeqIO

In [2]:
# Load only the plasmid accession and host assembly columns of the PLSDB metadata.
plsdb_columns = ['ACC_NUCCORE', 'NCBI_ids']
plsdb = pd.read_csv(
    '/global/cfs/cdirs/jgirnd/projects/LLMs/mge/plsdb_metadata_Arvind_ArgonneLab.csv.gz',
    usecols=plsdb_columns,
)

In [3]:
# Keep only rows that have both a plasmid accession and a host id, and
# collapse repeated (plasmid, host) rows into one.
plsdb = plsdb.dropna().drop_duplicates()

In [4]:
# (rows, columns) of the metadata frame at this point in the notebook.
plsdb.shape

(1410, 2)

In [5]:
# Preview the first rows: one (plasmid accession, host assembly) pair per row.
plsdb.head()

Unnamed: 0,ACC_NUCCORE,NCBI_ids
1704,NZ_AP014570.1,GCF_000828915.1
1708,NZ_AP014682.1,GCF_000829395.1
1709,NZ_AP014681.1,GCF_000829395.1
1712,NZ_AP018444.1,GCF_000534935.2
1713,NZ_AP018444.1,GCF_000534935.1


In [6]:
# Number of distinct plasmid accessions among the pairs.
len(set(plsdb['ACC_NUCCORE']))

1346

In [7]:
# Number of distinct host assembly ids among the pairs.
len(set(plsdb['NCBI_ids']))

534

# Split Train / Validation / Test

In validation: we evaluate on unknown relationships between known mge/host

In test: we evaluate on unknown relationships between unknown mge/host

In [8]:
import random

# Seed the RNG so the host split (and every later shuffle/sample in the
# notebook) is the same on each Restart & Run All.
SEED = 42
random.seed(SEED)

# Number of hosts held out for the test split.
N_TEST_HOSTS = 60

# Sort before shuffling: iteration order of a set of strings is not stable
# between interpreter sessions, so shuffling the raw set would not be
# reproducible even with a fixed seed.
all_hosts = sorted(set(plsdb['NCBI_ids']))
random.shuffle(all_hosts)
hosts_test = all_hosts[:N_TEST_HOSTS]
hosts_train_val = all_hosts[N_TEST_HOSTS:]


In [9]:
# Rows whose host was held out go to the test-side frame, the rest to train/val.
in_test_hosts = plsdb['NCBI_ids'].isin(hosts_test)
pairs_test = plsdb[in_test_hosts]
pairs_train_val = plsdb[~in_test_hosts]

data_test, data_val, data_train = [], [], []

# A held-out-host pair whose plasmid also appears with a train/val host goes
# to validation; only pairs with an unseen plasmid are kept for test.
seen_mge = set(pairs_train_val['ACC_NUCCORE'])
for mge_id, host_id in zip(pairs_test['ACC_NUCCORE'], pairs_test['NCBI_ids']):
    bucket = data_val if mge_id in seen_mge else data_test
    bucket.append((mge_id, host_id))


In [10]:
# Flatten the train/val frame into (plasmid accession, host assembly) tuples.
all_train_val = list(zip(pairs_train_val['ACC_NUCCORE'], pairs_train_val['NCBI_ids']))

random.shuffle(all_train_val)

# Top validation up to the size of the test split with shuffled train/val pairs.
# Clamp at zero: if validation is already larger than test the difference is
# negative, and a negative slice bound would move almost every pair into
# validation instead of none.
num_diff = max(0, len(data_test) - len(data_val))
# NOTE: both lists are extended in place, so re-run the previous cell before
# re-running this one or the splits accumulate duplicates.
data_val.extend(all_train_val[:num_diff])
data_train.extend(all_train_val[num_diff:])

print(f"Get {len(data_train)} for training, {len(data_val)} for validation, and {len(data_test)} for test")
# Sanity check: no test plasmid may also occur in the training or validation pairs.
test_mge = {p[0] for p in data_test}
train_val_mge = {p[0] for p in data_val} | {p[0] for p in data_train}
print(f"The overlap between test and training is {len(test_mge.intersection(train_val_mge))}")

Get 1124 for training, 143 for validation, and 143 for test
The overlap between test and training is 0


# Get MGE sequences

In [11]:
import gzip

# Build the accession lookup once. Testing membership against the column's
# array inside the loop would rescan every accession for every FASTA record.
wanted_plasmid_ids = set(plsdb['ACC_NUCCORE'])

# Keep only the FASTA records whose id is one of the plasmids in the metadata.
plasmids = []
with gzip.open('/global/cfs/cdirs/jgirnd/projects/LLMs/mge/plsdb.fna.gz', 'rt') as FASTA:
    for p in SeqIO.parse(FASTA, format='fasta'):
        if p.id in wanted_plasmid_ids:
            plasmids.append(p)

len(plasmids)

1346

In [12]:
# Map each plasmid accession to its nucleotide sequence as a plain string.
mgeid2seq = {}
for plasmid in plasmids:
    mgeid2seq[plasmid.id] = str(plasmid.seq)

# Get Host sequences

In [13]:
import os
# Window length used when slicing host genomes into subsequences.
MAX_LEN = 10000
# Directory holding the host genome files.
path = '/global/cfs/cdirs/jgirnd/projects/LLMs/mge/bacteria_hosts'
# NOTE(review): not referenced in the visible cells; the host loop further
# down calls os.listdir(path) again instead of using this list.
genome_list = os.listdir(path)

In [17]:
def generate_subsequences(sequence, sub_length=10000, num_subsequences=800, offset=100):
    """Randomly pick non-overlapping, full-length windows from ``sequence``.

    The sequence is tiled into ``len(sequence) // sub_length`` consecutive
    windows of ``sub_length`` characters, and ``num_subsequences`` of them are
    drawn without replacement using the module-level ``random`` state.

    Args:
        sequence: The string to sample from.
        sub_length: Length of every returned subsequence; must be positive.
        num_subsequences: How many windows to return.
        offset: Preferred number of leading characters to skip before the
            first window. It is reduced when the sequence does not have that
            many spare characters beyond the tiled windows, so that no window
            runs past the end of the sequence.

    Returns:
        A list of ``num_subsequences`` strings, each exactly ``sub_length``
        characters long, in random order.

    Raises:
        ValueError: If ``sub_length`` is not positive, or the sequence is too
            short to provide ``num_subsequences`` non-overlapping windows.
    """
    if sub_length <= 0:
        raise ValueError(f"sub_length must be positive, got {sub_length}.")

    total_subsequences = len(sequence) // sub_length

    # Ensure we have enough subsequences to select from
    if total_subsequences < num_subsequences:
        raise ValueError(f"Can't generate {num_subsequences} non-overlapping subsequences of length {sub_length} from the provided sequence of length {len(sequence)}.")

    # Only the characters left over after tiling can be skipped at the front;
    # shifting further would truncate the last window.
    spare = len(sequence) - total_subsequences * sub_length
    offset = max(0, min(offset, spare))

    # Exactly `total_subsequences` candidate start positions, one per window.
    window_starts = range(offset, offset + total_subsequences * sub_length, sub_length)

    # Randomly select starting indices for the subsequences
    start_indices = random.sample(window_starts, num_subsequences)

    # Extract the subsequences
    return [sequence[i:i + sub_length] for i in start_indices]

hostid2seqs = {}
# Every entry of `path` is opened as a gzipped FASTA file. Sorting the listing
# makes the processing order (and hence the random draws) deterministic.
for genome in sorted(os.listdir(path)):
    # The host id is the first two underscore-separated fields of the file name.
    hostid = "_".join(genome.split("_")[:2])
    file_path = os.path.join(path, genome)

    # Open the gzip file in text mode and read only the first FASTA record.
    # NOTE(review): any further records in the file are ignored; confirm
    # that this is intended.
    with gzip.open(file_path, 'rt') as file:
        first_record = next(SeqIO.parse(file, "fasta"), None)
        # Fail loudly on an empty file instead of silently reusing the
        # sequence left over from the previous genome.
        if first_record is None:
            raise ValueError(f"No FASTA records found in {file_path}")
        original_sequence = str(first_record.seq)

    # Take at most 100 windows per host, fewer if the sequence is too short.
    num_subsequences = min(100, len(original_sequence) // MAX_LEN)
    sequences = generate_subsequences(original_sequence, sub_length=MAX_LEN, num_subsequences=num_subsequences)
    hostid2seqs[hostid] = sequences
    

# Generate Paired Data

In [23]:
MAX_LEN = 10000
def generate_paired_sequences(pairs, mgeid2seq, hostid2seqs):
    """Expand (mge id, host id) pairs into [mge sequence, host subsequence] rows.

    Each pair produces one row per subsequence stored for its host, in the
    order the pairs and subsequences are given.
    """
    return [
        [mgeid2seq[mge_key], host_window]
        for mge_key, host_key in pairs
        for host_window in hostid2seqs[host_key]
    ]
            
# Build the paired rows for each split, in train / val / test order.
sequences_train, sequences_val, sequences_test = (
    generate_paired_sequences(split_pairs, mgeid2seq, hostid2seqs)
    for split_pairs in (data_train, data_val, data_test)
)

print(f"Get {len(sequences_train)} for training, {len(sequences_val)} for validation, and {len(sequences_test)} for test")


Get 107909 for training, 13555 for validation, and 14223 for test


In [26]:
import csv

output_path = "/pscratch/sd/z/zhihanz/data/mge_host/len10k_host100"
# Create the output directory up front so the first open() cannot fail on a
# missing folder.
os.makedirs(output_path, exist_ok=True)

# Write each split as a header-less CSV with one [mge sequence, host sequence]
# row per line.
for file_name, rows in (
    ("train.csv", sequences_train),
    ("val.csv", sequences_val),
    ("test.csv", sequences_test),
):
    # newline="" lets the csv module control line endings itself, as its
    # documentation requires for files passed to csv.writer.
    with open(os.path.join(output_path, file_name), "w", newline="") as f:
        f_w = csv.writer(f)
        f_w.writerows(rows)