In [1]:
import pandas as pd
import numpy as np
import json
from tqdm import tqdm

In [2]:
df_train_proteins = pd.read_csv("../data/swissprot_proteins_processed_train.csv")
df_val_proteins = pd.read_csv("../data/swissprot_proteins_processed_val.csv")
df_test_proteins = pd.read_csv("../data/swissprot_proteins_processed_test.csv")
dfs_proteins = [df_train_proteins, df_val_proteins, df_test_proteins]

In [3]:
# Decode the JSON-encoded 'domains' column of every split, then keep (as a
# copy) only the proteins with at least one annotated domain.
# Parsing is guarded by isinstance(str): values that are already parsed lists
# (e.g. when this cell is re-run, since the frames are updated in place) or
# non-string missing values are left untouched instead of being passed to
# json.loads, which would raise TypeError.
dfs_with_domains = []
for df in dfs_proteins:
    df["domains"] = df["domains"].apply(
        lambda value: json.loads(value) if isinstance(value, str) else value
    )
    dfs_with_domains.append(df[df["n_domains"] > 0].copy())

In [4]:
# Preview the first five rows of the training split (index 0 of dfs_with_domains).
dfs_with_domains[0].head(n=5)

Unnamed: 0,acc_id,length,n_domains,can_generate,sequence,domains
0,A0A009IHW8,269,1,internal_gap,MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENA...,"[{'start': 133, 'end': 266, 'description': 'TI..."
1,A0A023I7E1,796,1,internal_gap,MRFQVIVAAATITMITSYIPGVASQSTSDGDDLFVPVSNFDPKSIF...,"[{'start': 31, 'end': 759, 'description': 'GH8..."
2,A0A024B7W1,3423,5,internal_gap,MKNPKKKSGGFRIVNMLKRGVARVSPFGGLKRLPAGLLLGHGPIRM...,"[{'start': 1503, 'end': 1680, 'description': '..."
4,A0A044RE18,693,2,internal_gap,MYWQLVRILVLFDCLQKILAIEHDSICIADVDDACPEPSHTVMRLR...,"[{'start': 167, 'end': 482, 'description': 'Pe..."
6,A0A059WI14,161,1,internal_gap,MKYAHVGLNVTNLEKSIEFYSKLFGAEPVKVKPDYAKFLLESPGLN...,"[{'start': 2, 'end': 119, 'description': 'VOC'}]"


In [None]:
# synthetic functions
def generate_synthetic_n_terminal_cut(seq_length, existing_domains):
    """
    Generate a synthetic N-terminal cut point.

    A cut is drawn uniformly (via the global ``np.random`` state) from roughly
    10-30% of the sequence length, bounded below by residue 10 and above by the
    sequence midpoint. It is then pulled back in front of any existing domain
    it falls inside.

    Args:
        seq_length: Number of residues in the sequence.
        existing_domains: Iterable of dicts with 1-based inclusive
            'start' and 'end' keys.

    Returns:
        int: The end position of a synthetic N-terminal domain spanning
        residues 1..cut. Never smaller than 10, so for very short sequences
        it can collide with an existing domain or exceed the sequence.
    """
    # Place cut at 10-30% of sequence length
    min_cut = max(10, int(seq_length * 0.10))
    max_cut = min(seq_length // 2, int(seq_length * 0.30))
    # For short sequences (fewer than ~34 residues) the upper bound falls
    # below the 10-residue floor. np.random.randint raises ValueError on an
    # empty range, so collapse the range onto min_cut instead of crashing.
    max_cut = max(max_cut, min_cut)

    cut_point = np.random.randint(min_cut, max_cut + 1)

    # Pull the cut in front of any domain that contains it. Visiting domains
    # from the right-most start leftwards guarantees that after moving the cut
    # before one domain, every domain further left (including overlapping or
    # adjacent ones) is still checked against the new position.
    for domain in sorted(existing_domains, key=lambda d: d['start'], reverse=True):
        if domain['start'] <= cut_point <= domain['end']:
            cut_point = domain['start'] - 1

    return max(10, cut_point)  # Ensure minimum fragment size


def generate_synthetic_c_terminal_cut(seq_length, existing_domains):
    """
    Generate a synthetic C-terminal cut point.

    A cut is drawn uniformly (via the global ``np.random`` state) from roughly
    70-90% of the sequence length, bounded below by the sequence midpoint and
    above by ``seq_length - 10``. It is then pushed past any existing domain
    it falls inside.

    Args:
        seq_length: Number of residues in the sequence.
        existing_domains: Iterable of dicts with 1-based inclusive
            'start' and 'end' keys.

    Returns:
        int: The start position of a synthetic C-terminal domain spanning
        residues cut..seq_length. Never larger than ``seq_length - 10``, so the
        final clamp can move it back inside an existing domain.
    """
    # Place cut at 70-90% of sequence length
    min_cut = max(seq_length // 2, int(seq_length * 0.70))
    max_cut = min(seq_length - 10, int(seq_length * 0.90))
    # For short sequences (fewer than ~34 residues) seq_length - 10 falls
    # below the 70% bound. np.random.randint raises ValueError on an empty
    # range, so collapse the range onto min_cut; the final clamp still applies.
    max_cut = max(max_cut, min_cut)

    cut_point = np.random.randint(min_cut, max_cut + 1)

    # Push the cut past any domain that contains it. Visiting domains in order
    # of increasing end guarantees that after moving the cut behind one domain,
    # every domain ending further right (including overlapping or adjacent
    # ones) is still checked against the new position.
    for domain in sorted(existing_domains, key=lambda d: d['end']):
        if domain['start'] <= cut_point <= domain['end']:
            cut_point = domain['end'] + 1

    return min(seq_length - 10, cut_point)  # Ensure minimum fragment size


def augment_domains(domains, seq_length):
    """
    Add synthetic terminal domains where the real annotation lacks them.

    If no domain starts at residue 1, a synthetic N-terminal domain
    (1..cut) is added. If no domain ends at ``seq_length``, a synthetic
    C-terminal domain (cut..seq_length) is added. Both cut points are computed
    against the original ``domains``. Synthetic entries carry
    ``'synthetic': True``. The input list is not modified.

    Returns:
        tuple: (augmented domain list sorted by 'start',
                whether a synthetic N-terminal domain was added,
                whether a synthetic C-terminal domain was added)
    """
    needs_n = all(d['start'] != 1 for d in domains)
    needs_c = all(d['end'] != seq_length for d in domains)

    extra = []
    if needs_n:
        extra.append({
            'start': 1,
            'end': generate_synthetic_n_terminal_cut(seq_length, domains),
            'description': 'SYNTHETIC_N_TERMINAL',
            'synthetic': True,
        })
    if needs_c:
        extra.append({
            'start': generate_synthetic_c_terminal_cut(seq_length, domains),
            'end': seq_length,
            'description': 'SYNTHETIC_C_TERMINAL',
            'synthetic': True,
        })

    combined = domains.copy() + extra
    return sorted(combined, key=lambda d: d['start']), needs_n, needs_c

In [None]:
# normal fragment generation functions
def generate_terminal_n_fragments(acc_id, sequence, domains):
    """
    Build fragments that are missing an N-terminal domain.

    For each domain starting at residue 1, keep the residues after its
    (1-based, inclusive) end. Fragments shorter than 10 residues are dropped.

    Returns:
        list[dict]: one 'terminal_N' record per kept fragment.
    """
    records = []
    for dom in domains:
        if dom['start'] != 1:
            continue
        remainder = sequence[dom['end']:]
        if len(remainder) < 10:
            continue
        records.append({
            'source_accession': acc_id,
            'fragment_type': 'terminal_N',
            'sequence': remainder,
            'is_fragment': 1,
            'removed_region': f"1-{dom['end']}",
            'is_synthetic': dom.get('synthetic', False),
        })
    return records

def generate_terminal_c_fragments(acc_id, sequence, domains):
    """
    Build fragments that are missing a C-terminal domain.

    For each domain ending at the last residue, keep the residues before its
    (1-based) start. Fragments shorter than 10 residues are dropped.

    Returns:
        list[dict]: one 'terminal_C' record per kept fragment.
    """
    total = len(sequence)
    records = []
    for dom in (d for d in domains if d['end'] == total):
        head = sequence[:dom['start'] - 1]
        if len(head) >= 10:
            records.append({
                'source_accession': acc_id,
                'fragment_type': 'terminal_C',
                'sequence': head,
                'is_fragment': 1,
                'removed_region': f"{dom['start']}-{total}",
                'is_synthetic': dom.get('synthetic', False),
            })
    return records

def generate_terminal_both_fragments(acc_id, sequence, domains):
    """
    Build fragments that are missing both an N- and a C-terminal domain.

    Every (N-terminal, C-terminal) pair whose regions do not touch contributes
    the residues strictly between them (1-based, inclusive coordinates).
    Fragments shorter than 10 residues are dropped. A fragment is marked
    synthetic when either terminal domain is synthetic.

    Returns:
        list[dict]: one 'terminal_both' record per kept fragment.
    """
    total = len(sequence)
    heads = [d for d in domains if d['start'] == 1]
    tails = [d for d in domains if d['end'] == total]

    records = []
    for head in heads:
        for tail in tails:
            if head['end'] >= tail['start']:
                continue
            core = sequence[head['end']:tail['start'] - 1]
            if len(core) < 10:
                continue
            records.append({
                'source_accession': acc_id,
                'fragment_type': 'terminal_both',
                'sequence': core,
                'is_fragment': 1,
                'removed_region': f"1-{head['end']},{tail['start']}-{total}",
                'is_synthetic': head.get('synthetic', False) or tail.get('synthetic', False),
            })
    return records

def generate_internal_gap_fragments(acc_id, sequence, domains):
    """
    Build fragments with one internal domain spliced out.

    For each domain that touches neither terminus, the residues before and
    after it are concatenated. Fragments shorter than 10 residues are dropped.
    Internal domains never come from synthetic augmentation, so
    'is_synthetic' is always False.

    Returns:
        list[dict]: one 'internal_gap' record per kept fragment.
    """
    total = len(sequence)
    records = []
    for dom in domains:
        if not (dom['start'] > 1 and dom['end'] < total):
            continue
        spliced = sequence[:dom['start'] - 1] + sequence[dom['end']:]
        if len(spliced) < 10:
            continue
        records.append({
            'source_accession': acc_id,
            'fragment_type': 'internal_gap',
            'sequence': spliced,
            'is_fragment': 1,
            'removed_region': f"{dom['start']}-{dom['end']}",
            'is_synthetic': False,  # Internal domains are always real
        })
    return records

def generate_mixed_fragments(acc_id, sequence, domains):
    """
    Build fragments missing a terminal domain AND an internal domain.

    Three combinations are produced, all with fragment_type 'mixed':
      * N-terminal + internal domain removed,
      * internal + C-terminal domain removed,
      * N-terminal + internal + C-terminal domain removed.
    A combination is used only when the removed regions occur in that order
    without overlapping. Coordinates are 1-based and inclusive. Fragments
    shorter than 10 residues are dropped. 'is_synthetic' reflects the terminal
    domain(s) involved, because internal domains are never synthetic.

    Returns:
        list[dict]: fragment records, with N+internal combinations first,
        then internal+C, then all three.
    """
    fragments = []
    seq_length = len(sequence)

    n_terminal_domains = [d for d in domains if d['start'] == 1]
    c_terminal_domains = [d for d in domains if d['end'] == seq_length]
    internal_domains = [d for d in domains if d['start'] > 1 and d['end'] < seq_length]

    def _add_fragment(fragment_seq, removed_region, is_synthetic):
        # Shared record builder: every combination applies the same
        # 10-residue minimum and produces the same record layout.
        if len(fragment_seq) >= 10:
            fragments.append({
                'source_accession': acc_id,
                'fragment_type': 'mixed',
                'sequence': fragment_seq,
                'is_fragment': 1,
                'removed_region': removed_region,
                'is_synthetic': is_synthetic
            })

    # N-terminal + internal: keep the stretch between the two removed domains
    # plus everything after the internal domain.
    for n_dom in n_terminal_domains:
        for int_dom in internal_domains:
            if n_dom['end'] < int_dom['start']:
                _add_fragment(
                    sequence[n_dom['end']:int_dom['start'] - 1] + sequence[int_dom['end']:],
                    f"1-{n_dom['end']},{int_dom['start']}-{int_dom['end']}",
                    n_dom.get('synthetic', False),
                )

    # C-terminal + internal: keep everything before the internal domain plus
    # the stretch between the two removed domains.
    for c_dom in c_terminal_domains:
        for int_dom in internal_domains:
            if int_dom['end'] < c_dom['start']:
                _add_fragment(
                    sequence[:int_dom['start'] - 1] + sequence[int_dom['end']:c_dom['start'] - 1],
                    f"{int_dom['start']}-{int_dom['end']},{c_dom['start']}-{seq_length}",
                    c_dom.get('synthetic', False),
                )

    # All three: keep the two stretches between the three removed domains.
    for n_dom in n_terminal_domains:
        for c_dom in c_terminal_domains:
            for int_dom in internal_domains:
                # Apply the same two ordering checks as the pairwise cases
                # above. Do not also require int start < int end: a
                # single-residue internal domain (start == end) is accepted
                # by the pairwise cases and must be accepted here too.
                if n_dom['end'] < int_dom['start'] and int_dom['end'] < c_dom['start']:
                    _add_fragment(
                        sequence[n_dom['end']:int_dom['start'] - 1]
                        + sequence[int_dom['end']:c_dom['start'] - 1],
                        f"1-{n_dom['end']},{int_dom['start']}-{int_dom['end']},{c_dom['start']}-{seq_length}",
                        n_dom.get('synthetic', False) or c_dom.get('synthetic', False),
                    )
    return fragments

def generate_all_fragments_augmented(acc_id, sequence, domains):
    """
    Run every fragment generator on one protein and concatenate the results.

    Output order: terminal_N, terminal_C, terminal_both, internal_gap, and
    then mixed records. ``domains`` is passed unchanged to each generator, so
    any synthetic entries it carries are used too.
    """
    generators = (
        generate_terminal_n_fragments,
        generate_terminal_c_fragments,
        generate_terminal_both_fragments,
        generate_internal_gap_fragments,
        generate_mixed_fragments,
    )
    return [
        fragment
        for generator in generators
        for fragment in generator(acc_id, sequence, domains)
    ]