In [16]:
import ijson
from tqdm import tqdm
from collections import defaultdict

In [2]:
TARGET_FRAGMENT_TYPES = [
    'terminal_N',
    'terminal_C',
    'terminal_both',
    'internal_gap',
    'mixed'
]

In [8]:
# load current valid ids
TRAIN_SET = set()
with open("../data/train.txt", "r") as f:
    for line in f.readlines():
        TRAIN_SET.add(line[:-1])

VAL_SET = set()
with open("../data/val.txt", "r") as f:
    for line in f.readlines():
        VAL_SET.add(line[:-1])

TEST_SET = set()
with open("../data/test.txt", "r") as f:
    for line in f.readlines():
        TEST_SET.add(line[:-1])

COMPLETE = TRAIN_SET.union(VAL_SET).union(TEST_SET)

print(len(TRAIN_SET))
print(len(VAL_SET))
print(len(TEST_SET))
print(len(COMPLETE))
print(len(COMPLETE) == len(TRAIN_SET) + len(VAL_SET) + len(TEST_SET))

243444
60861
76077
380382
True


In [13]:
def extract_domains(features):
    domains = []
    for feature in features:
        if feature.get('type') == 'Domain':
            location = feature.get('location', {})
            start_info = location.get('start', {})
            end_info = location.get('end', {})

            # Get start and end positions
            start = start_info.get('value')
            end = end_info.get('value')

            if start is not None and end is not None:
                domains.append({
                    "start": int(start),
                    "end": int(end),
                    "description": feature.get("description", "")
                })

    return sorted(domains, key=lambda x: x["start"])


def categorize_protein(domains, seq_length):
    if not domains:
        return []

    fragment_types = []

    has_n_term = False
    has_c_term = False
    has_intern = False

    for domain in domains:
        if domain["start"] == 1:
            has_n_term = True
        elif domain["end"] == seq_length:
            has_c_term = True
        else:
            if domain["start"] > 1 and domain["end"] < seq_length:
                has_intern = True

    if has_n_term:
        fragment_types.append("terminal_N")

    if has_c_term:
        fragment_types.append("terminal_C")

    if has_n_term and has_c_term:
        fragment_types.append("terminal_both")

    if has_intern:
        fragment_types.append("internal_gap")

    if (has_n_term or has_c_term) and has_intern:
        fragment_types.append("mixed")

    return fragment_types


def is_valid_protein(record):
    protein_desc = record.get('proteinDescription', {})
    if protein_desc.get('flag') == 'Fragment':
        return False

    sequence = record.get('sequence', {}).get('value')
    if not sequence:
        return False

    if len(sequence) < 10:
        return False

    return True

In [17]:
train_protein_pools = defaultdict(list)
train_protein_data = []

val_protein_pools = defaultdict(list)
val_protein_data = []

test_protein_pools = defaultdict(list)
test_protein_data = []

stats = {
    'total_processed': 0,
    'valid_proteins': 0,
    'fragments_skipped': 0,
    'no_sequence': 0,
    'too_short': 0,
    'with_domains': 0,
    'without_domains': 0
}

print("Loading SwissProt database...")

with open("../data/uniprotkb_reviewed_true_2025_11_07.json", 'r') as f:
    # Use tqdm for progress bar
    for record in tqdm(ijson.items(f, 'results.item'), desc="Processing proteins"):
        stats['total_processed'] += 1

        if not record["primaryAccession"] in COMPLETE:
            continue

        # Check if protein is valid
        if not is_valid_protein(record):
            protein_desc = record.get('proteinDescription', {})
            if protein_desc.get('flag') == 'Fragment':
                stats['fragments_skipped'] += 1
            elif not record.get('sequence', {}).get('value'):
                stats['no_sequence'] += 1
            elif len(record.get("sequence", {}).get("value")) < 10:
                stats['too_short'] += 1
            continue

        stats['valid_proteins'] += 1

        # Extract protein info
        acc_id = record['primaryAccession']
        sequence = record['sequence']['value']
        seq_length = len(sequence)
        features = record.get('features', [])

        # Extract domain annotations
        domains = extract_domains(features)

        if domains:
            stats['with_domains'] += 1
        else:
            stats['without_domains'] += 1

        # Store protein data
        protein_info = {
            'acc_id': acc_id,
            'sequence': sequence,
            'length': seq_length,
            'domains': domains,
            'n_domains': len(domains)
        }

        # Categorize protein by fragment generation capability
        fragment_types = categorize_protein(domains, seq_length)
        protein_info['can_generate'] = fragment_types

        if acc_id in TRAIN_SET:
            train_protein_data.append(protein_info)

            # Add to appropriate pools
            for ftype in fragment_types:
                train_protein_pools[ftype].append(protein_info)

        elif acc_id in TEST_SET:
            test_protein_data.append(protein_info)

            # Add to appropriate pools
            for ftype in fragment_types:
                test_protein_pools[ftype].append(protein_info)

        elif acc_id in VAL_SET:
            val_protein_data.append(protein_info)

            # Add to appropriate pools
            for ftype in fragment_types:
                val_protein_pools[ftype].append(protein_info)

print("\n✓ Loading complete!")

Loading SwissProt database...


Processing proteins: 573661it [01:12, 7922.96it/s] 


✓ Loading complete!





In [18]:
# stats
print("SwissProt Database Statistics:")
print("="*60)
print(f"Total records processed:      {stats['total_processed']:>10,}")
print(f"Valid complete proteins:      {stats['valid_proteins']:>10,}")
print(f"  - With domain annotations:  {stats['with_domains']:>10,}")
print(f"  - Without domain annotations:{stats['without_domains']:>10,}")
print(f"\nSkipped records:")
print(f"  - Already fragments:        {stats['fragments_skipped']:>10,}")
print(f"  - No sequence:              {stats['no_sequence']:>10,}")
print(f"  - Too short (<10 aa):       {stats['too_short']:>10,}")
print("="*60)

SwissProt Database Statistics:
Total records processed:         573,661
Valid complete proteins:         373,953
  - With domain annotations:      92,823
  - Without domain annotations:   281,130

Skipped records:
  - Already fragments:             5,980
  - No sequence:                       0
  - Too short (<50 aa):              449


In [23]:
print("\nProtein Pool Sizes by Fragment Type (TRAIN):")
print("="*60)
for ftype in TARGET_FRAGMENT_TYPES:
    count = len(train_protein_pools[ftype])
    pct = (count / len(TRAIN_SET) * 100) if stats['valid_proteins'] > 0 else 0
    print(f"{ftype:20s}: {count:>8,} proteins ({pct:>5.2f}%)")
print("="*60)
print(f"{'Total protein entries':20s}: {len(train_protein_data):>8,}")
print("\nNote: A single protein can belong to multiple pools")


Protein Pool Sizes by Fragment Type (TRAIN):
terminal_N          :    4,188 proteins ( 1.72%)
terminal_C          :    7,526 proteins ( 3.09%)
terminal_both       :      276 proteins ( 0.11%)
internal_gap        :   51,520 proteins (21.16%)
mixed               :    3,619 proteins ( 1.49%)
Total protein entries:  239,300

Note: A single protein can belong to multiple pools


In [24]:
print("\nProtein Pool Sizes by Fragment Type (VAL):")
print("="*60)
for ftype in TARGET_FRAGMENT_TYPES:
    count = len(val_protein_pools[ftype])
    pct = (count / len(VAL_SET) * 100) if stats['valid_proteins'] > 0 else 0
    print(f"{ftype:20s}: {count:>8,} proteins ({pct:>5.2f}%)")
print("="*60)
print(f"{'Total protein entries':20s}: {len(val_protein_data):>8,}")
print("\nNote: A single protein can belong to multiple pools")


Protein Pool Sizes by Fragment Type (VAL):
terminal_N          :    1,108 proteins ( 1.82%)
terminal_C          :    1,822 proteins ( 2.99%)
terminal_both       :       71 proteins ( 0.12%)
internal_gap        :   12,958 proteins (21.29%)
mixed               :      938 proteins ( 1.54%)
Total protein entries:   59,857

Note: A single protein can belong to multiple pools


In [25]:
print("\nProtein Pool Sizes by Fragment Type (TEST):")
print("="*60)
for ftype in TARGET_FRAGMENT_TYPES:
    count = len(test_protein_pools[ftype])
    pct = (count / len(TEST_SET) * 100) if stats['valid_proteins'] > 0 else 0
    print(f"{ftype:20s}: {count:>8,} proteins ({pct:>5.2f}%)")
print("="*60)
print(f"{'Total protein entries':20s}: {len(test_protein_data):>8,}")
print("\nNote: A single protein can belong to multiple pools")


Protein Pool Sizes by Fragment Type (TEST):
terminal_N          :    1,300 proteins ( 1.71%)
terminal_C          :    2,267 proteins ( 2.98%)
terminal_both       :       85 proteins ( 0.11%)
internal_gap        :   16,255 proteins (21.37%)
mixed               :    1,132 proteins ( 1.49%)
Total protein entries:   74,796

Note: A single protein can belong to multiple pools


In [27]:
import pandas as pd
import json

In [28]:
# save data

df_train_proteins = pd.DataFrame([{
    "acc_id": p["acc_id"],
    "length": p["length"],
    "n_domains": p["n_domains"],
    "can_generate": ",".join(p["can_generate"] if p["can_generate"] else ""),
    "sequence": p["sequence"],
    "domains": json.dumps(p["domains"])
} for p in train_protein_data])

df_val_proteins = pd.DataFrame([{
    "acc_id": p["acc_id"],
    "length": p["length"],
    "n_domains": p["n_domains"],
    "can_generate": ",".join(p["can_generate"] if p["can_generate"] else ""),
    "sequence": p["sequence"],
    "domains": json.dumps(p["domains"])
} for p in val_protein_data])

df_test_proteins = pd.DataFrame([{
    "acc_id": p["acc_id"],
    "length": p["length"],
    "n_domains": p["n_domains"],
    "can_generate": ",".join(p["can_generate"] if p["can_generate"] else ""),
    "sequence": p["sequence"],
    "domains": json.dumps(p["domains"])
} for p in test_protein_data])

df_train_proteins.to_csv("../data/swissprot_proteins_processed_train.csv", index=False)
df_val_proteins.to_csv("../data/swissprot_proteins_processed_val.csv", index=False)
df_test_proteins.to_csv("../data/swissprot_proteins_processed_test.csv", index=False)

train_pool_assignments = {}
val_pool_assignments = {}
test_pool_assignments = {}
for ftype in TARGET_FRAGMENT_TYPES:
    train_pool_assignments[ftype] = [p["acc_id"] for p in train_protein_pools[ftype]]
    val_pool_assignments[ftype] = [p["acc_id"] for p in val_protein_pools[ftype]]
    test_pool_assignments[ftype] = [p["acc_id"] for p in test_protein_pools[ftype]]

with open('../data/protein_pool_assignments_train.json', 'w') as f:
    json.dump(train_pool_assignments, f, indent=2)

with open('../data/protein_pool_assignments_val.json', 'w') as f:
    json.dump(val_pool_assignments, f, indent=2)

with open('../data/protein_pool_assignments_test.json', 'w') as f:
    json.dump(test_pool_assignments, f, indent=2)