# Format Data from Train/Test Split

## Retrieve Full Sequences

ID files are in the format `{train/test}_ids.txt`. Load the ID files and strip anything after and including the "/" on each line.

In [4]:
def format_data(data_dir, file):
    """De-duplicate the sequence IDs of an ID file and write them to a new file.

    Reads ``{data_dir}/{file}``, reduces every line to the text before its
    first "/", drops blank lines and repeated IDs (first occurrence wins, input
    order is kept), and writes the remaining IDs one per line to
    ``{data_dir}/{file_stem}_unique.txt``, where ``file_stem`` is the part of
    ``file`` before its first ".".  The number of IDs before and after
    de-duplication is printed.

    Args:
        data_dir: Directory that contains the ID file; the output is written
            to the same directory.
        file: Name of the ID file inside ``data_dir`` (e.g. "train_ids.txt").
    """
    file_stem = file.split(".")[0]

    seq_list = []
    with open(data_dir + "/" + file, "r") as f:
        for line in f:
            # Keep only the ID: everything from the first "/" onward is dropped,
            # along with surrounding whitespace and the line terminator.
            seq_id = line.split("/", 1)[0].strip()
            if seq_id:  # ignore blank lines
                seq_list.append(seq_id)

    print(len(seq_list))
    # dict.fromkeys removes duplicates but, unlike set(), keeps first-seen
    # order, so the output file is identical from run to run.
    unique_ids = list(dict.fromkeys(seq_list))
    print(len(unique_ids))

    with open(data_dir + f"/{file_stem}_unique.txt", "w") as output_file:
        # The newline is added explicitly so that a final input line without
        # "\n" can never be fused with the ID written after it.
        output_file.writelines(seq_id + "\n" for seq_id in unique_ids)


In [4]:
# De-duplicate the training IDs; format_data writes the result next to the
# input as ../data/train_ids_unique.txt.
data_dir, file = "../data", "train_ids.txt"
format_data(data_dir, file)

557150
521296


~36k sequences in train had more than one domain and were thus represented more than once

In [5]:
# De-duplicate the test IDs; format_data writes the result next to the
# input as ../data/test_ids_unique.txt.
data_dir, file = "../data", "test_ids.txt"
format_data(data_dir, file)

187258
180997


Since the train/test split was originally made on domains, some proteins may have domains in both train and test. These sequences should be removed from test.

In [6]:
# Load both de-duplicated ID lists as sets of whitespace-stripped lines.
with open('../data/train_ids_unique.txt', 'r') as train_file:
    train_sequences = {row.strip() for row in train_file}
with open('../data/test_ids_unique.txt', 'r') as test_file:
    test_sequences = {row.strip() for row in test_file}

# IDs that occur in both the train and the test list
common_sequences = train_sequences.intersection(test_sequences)

# Copy test_ids_unique.txt to test_ids_unique_no_common.txt, leaving out
# every line whose ID also occurs in train.
with open('../data/test_ids_unique.txt', 'r') as test_file, \
        open('../data/test_ids_unique_no_common.txt', 'w') as test_file_no_common:
    for line in test_file:
        if line.strip() in common_sequences:
            continue
        test_file_no_common.write(line)

Use `esl-sfetch` to retrieve the full sequences from UniProt using the `{train/test}_ids.txt` seqfile. We use the UniProt release from April, 2022

In [9]:
# Submit the sequence-fetch script as a batch job via sbatch.
!sbatch ../bash_scripts/fetch_seqs.sh

Submitted batch job 20744256


Retrieved sequences are in the format `{train/test}_ids_full.fasta`

## Split Data

In [10]:
import sys
sys.path.insert(0, '../library')
import hmmscan_utils as utils

# Inputs for the split: data directory, FASTA file name, and job count.
data_dir = "../data"
fasta_file = "test_ids_full.fasta"
num_jobs = 50  # also read by the later `sbatch --array=1-$num_jobs` cell

# presumably splits fasta_file into num_jobs pieces inside data_dir —
# TODO confirm against hmmscan_utils.split_fasta_file
utils.split_fasta_file(fasta_file, data_dir, num_jobs)

Split files will be named `split_{i}_{train/test}_ids_full.fasta` (for the cell above, `split_{i}_test_ids_full.fasta`)

## HMM Scan the split FASTA files

In [11]:
# Submit hmmscan.sh as a job array with task IDs 1..num_jobs. IPython expands
# $num_jobs, $data_dir and $fasta_file from the Python namespace; num_jobs is
# the value set in the "Split Data" cell, so that cell must have been run.
data_dir="../data"
fasta_file="test_ids_full.fasta"
!sbatch --array=1-$num_jobs ../bash_scripts/hmmscan.sh $data_dir $fasta_file

Submitted batch job 20747024


In [1]:
import pickle
from pathlib import Path

In [2]:
# Read the pickled lookup tables from ../data_esm_decoder/maps.pkl.
with Path('../data_esm_decoder', 'maps.pkl').open('rb') as f:
    maps = pickle.load(f)

In [35]:
maps.keys()

dict_keys(['fam_idx', 'clan_idx', 'fam_clan', 'clan_count', 'clan_fam', 'idx_clan', 'idx_fam'])

In [5]:
maps["clan_idx"].keys()

dict_keys(['CL0001', 'CL0003', 'CL0004', 'CL0005', 'CL0006', 'CL0007', 'CL0009', 'CL0010', 'CL0011', 'CL0012', 'CL0013', 'CL0014', 'CL0015', 'CL0016', 'CL0018', 'CL0020', 'CL0021', 'CL0022', 'CL0023', 'CL0025', 'CL0026', 'CL0027', 'CL0028', 'CL0029', 'CL0030', 'CL0031', 'CL0032', 'CL0033', 'CL0034', 'CL0035', 'CL0036', 'CL0037', 'CL0039', 'CL0040', 'CL0041', 'CL0042', 'CL0043', 'CL0044', 'CL0045', 'CL0046', 'CL0047', 'CL0048', 'CL0049', 'CL0050', 'CL0051', 'CL0052', 'CL0053', 'CL0054', 'CL0055', 'CL0056', 'CL0057', 'CL0058', 'CL0059', 'CL0060', 'CL0061', 'CL0062', 'CL0063', 'CL0064', 'CL0065', 'CL0066', 'CL0067', 'CL0068', 'CL0069', 'CL0070', 'CL0071', 'CL0072', 'CL0073', 'CL0074', 'CL0075', 'CL0076', 'CL0077', 'CL0078', 'CL0079', 'CL0080', 'CL0081', 'CL0082', 'CL0083', 'CL0084', 'CL0085', 'CL0086', 'CL0087', 'CL0088', 'CL0089', 'CL0090', 'CL0091', 'CL0092', 'CL0093', 'CL0094', 'CL0095', 'CL0096', 'CL0097', 'CL0098', 'CL0099', 'CL0100', 'CL0101', 'CL0103', 'CL0104', 'CL0105', 'CL0106',

In [7]:
maps["fam_clan"].keys()

dict_keys(['PF00008.30', 'PF00053.27', 'PF00084.23', 'PF00594.23', 'PF01414.22', 'PF04863.16', 'PF06247.14', 'PF07645.18', 'PF07699.16', 'PF07974.16', 'PF09014.13', 'PF09064.13', 'PF09289.13', 'PF09443.13', 'PF12661.10', 'PF12662.10', 'PF12946.10', 'PF12947.10', 'PF14670.9', 'PF18193.4', 'PF18372.4', 'PF18720.4', 'PF00536.33', 'PF01698.19', 'PF02198.19', 'PF04904.16', 'PF07647.20', 'PF09235.13', 'PF09597.13', 'PF13543.9', 'PF18016.4', 'PF18017.4', 'PF18255.4', 'PF18609.4', 'PF00054.26', 'PF00139.22', 'PF00262.21', 'PF00337.25', 'PF00354.20', 'PF00426.21', 'PF00457.20', 'PF00622.31', 'PF00629.26', 'PF00722.24', 'PF00840.23', 'PF01670.19', 'PF01828.20', 'PF02210.27', 'PF02973.19', 'PF03388.16', 'PF03935.18', 'PF05592.14', 'PF05735.15', 'PF06439.14', 'PF06955.15', 'PF07081.14', 'PF07177.15', 'PF07622.14', 'PF07675.14', 'PF07953.15', 'PF08244.15', 'PF08787.14', 'PF08978.13', 'PF09101.13', 'PF09206.14', 'PF09224.14', 'PF09264.13', 'PF10287.12', 'PF11476.11', 'PF11958.11', 'PF12248.11', 'PF1

In [9]:
# Invert maps["fam_clan"] (family -> clan) into clan -> list of its families.
reverse_map = {}
for family, clan in maps["fam_clan"].items():
    # setdefault creates the clan's list the first time the clan is seen
    reverse_map.setdefault(clan, []).append(family)

print("Reverse map:", reverse_map)

Reverse map: {'CL0001': ['PF00008.30', 'PF00053.27', 'PF00084.23', 'PF00594.23', 'PF01414.22', 'PF04863.16', 'PF06247.14', 'PF07645.18', 'PF07699.16', 'PF07974.16', 'PF09014.13', 'PF09064.13', 'PF09289.13', 'PF09443.13', 'PF12661.10', 'PF12662.10', 'PF12946.10', 'PF12947.10', 'PF14670.9', 'PF18193.4', 'PF18372.4', 'PF18720.4'], 'CL0003': ['PF00536.33', 'PF01698.19', 'PF02198.19', 'PF04904.16', 'PF07647.20', 'PF09235.13', 'PF09597.13', 'PF13543.9', 'PF18016.4', 'PF18017.4', 'PF18255.4', 'PF18609.4'], 'CL0004': ['PF00054.26', 'PF00139.22', 'PF00262.21', 'PF00337.25', 'PF00354.20', 'PF00426.21', 'PF00457.20', 'PF00622.31', 'PF00629.26', 'PF00722.24', 'PF00840.23', 'PF01670.19', 'PF01828.20', 'PF02210.27', 'PF02973.19', 'PF03388.16', 'PF03935.18', 'PF05592.14', 'PF05735.15', 'PF06439.14', 'PF06955.15', 'PF07081.14', 'PF07177.15', 'PF07622.14', 'PF07675.14', 'PF07953.15', 'PF08244.15', 'PF08787.14', 'PF08978.13', 'PF09101.13', 'PF09206.14', 'PF09224.14', 'PF09264.13', 'PF10287.12', 'PF11476

In [11]:
# Keep the clan -> families map in the maps dict under "clan_fam".
maps["clan_fam"] = reverse_map

In [26]:
# Inverse of maps["clan_idx"]: index -> clan
idx_clan = dict((idx, clan) for clan, idx in maps["clan_idx"].items())

In [28]:
# Inverse of maps["fam_idx"]: index -> family
idx_fam = dict((idx, fam) for fam, idx in maps["fam_idx"].items())

In [32]:
# Add both inverse lookups to the maps dict.
maps.update(idx_clan=idx_clan, idx_fam=idx_fam)

In [33]:
# Write the updated maps dict back to ../data_esm_decoder/maps.pkl,
# replacing the previous file contents.
with Path('../data_esm_decoder', 'maps.pkl').open('wb') as f:
    pickle.dump(maps, f)

In [14]:
# f_list[i] = number of families in the clan whose index is i
f_list = [0] * len(maps["clan_idx"])
for clan, idx in maps["clan_idx"].items():
    f_list[idx] = len(maps["clan_fam"][clan])

In [15]:
f_list

[22,
 12,
 51,
 3,
 4,
 14,
 3,
 47,
 34,
 17,
 6,
 19,
 26,
 40,
 3,
 252,
 113,
 18,
 245,
 11,
 15,
 12,
 75,
 69,
 8,
 16,
 30,
 6,
 17,
 17,
 61,
 21,
 32,
 11,
 7,
 7,
 7,
 27,
 5,
 14,
 3,
 20,
 33,
 15,
 46,
 17,
 29,
 15,
 42,
 13,
 43,
 59,
 30,
 13,
 16,
 21,
 209,
 15,
 12,
 23,
 8,
 2,
 2,
 12,
 2,
 79,
 10,
 6,
 15,
 5,
 2,
 5,
 14,
 7,
 10,
 6,
 36,
 23,
 10,
 3,
 4,
 10,
 12,
 8,
 4,
 2,
 9,
 8,
 3,
 2,
 6,
 14,
 4,
 7,
 7,
 32,
 14,
 26,
 14,
 7,
 34,
 33,
 53,
 24,
 6,
 51,
 8,
 6,
 44,
 10,
 5,
 13,
 4,
 381,
 26,
 78,
 74,
 10,
 19,
 14,
 4,
 11,
 4,
 3,
 7,
 16,
 25,
 18,
 3,
 5,
 7,
 10,
 11,
 12,
 2,
 3,
 5,
 9,
 11,
 7,
 18,
 2,
 2,
 2,
 11,
 257,
 3,
 18,
 2,
 9,
 2,
 23,
 2,
 89,
 6,
 15,
 7,
 3,
 65,
 6,
 42,
 13,
 4,
 27,
 19,
 26,
 17,
 22,
 17,
 22,
 112,
 7,
 9,
 6,
 6,
 53,
 117,
 5,
 2,
 9,
 6,
 23,
 5,
 3,
 6,
 71,
 5,
 17,
 2,
 4,
 4,
 14,
 25,
 7,
 3,
 2,
 21,
 2,
 70,
 31,
 33,
 8,
 8,
 4,
 4,
 5,
 4,
 5,
 44,
 5,
 6,
 2,
 4,
 2,
 3,
 149,
 13,
 3,


In [18]:
# Number of clans that contain more than 300 families
count = len([size for size in f_list if size > 300])
count


2

In [5]:
# Re-key clan_fam by clan index: index -> list of families in that clan
intermediate = {maps['clan_idx'][clan]: families for clan, families in maps["clan_fam"].items()}

In [8]:
# Concatenate the per-clan family lists in clan-index order ...
fam_list = []
for clan_index in range(len(intermediate)):
    fam_list.extend(intermediate[clan_index])

# ... and number every family by its position in that concatenation.
fam_idx = {family: position for position, family in enumerate(fam_list)}

In [7]:
len(intermediate)

657

In [12]:
fam_idx['IDR']

19632

In [13]:
maps["fam_idx"] == fam_idx

True

In [19]:
import torch

In [20]:
# Get the number of clans and families
num_clans = len(maps['clan_idx'])
num_families = len(maps['fam_idx'])

# Create a tensor of shape (c, f) filled with zeros
clan_fam_matrix = torch.zeros(num_clans, num_families)

# Iterate over the clan_fam dictionary
for clan_id, fam_ids in maps['clan_fam'].items():
    # Get the index of the current clan
    clan_idx = maps['clan_idx'][clan_id]
    # Iterate over the family ids in the current clan
    fam_idxs = [maps['fam_idx'][fam_id] for fam_id in fam_ids]
    clan_fam_matrix[clan_idx, fam_idxs] = 1


In [26]:
# Keep the clan x family membership tensor in the maps dict.
maps['clan_family_matrix'] = clan_fam_matrix

In [28]:
# Write the maps dict (now including 'clan_family_matrix') back to
# ../data_esm_decoder/maps.pkl, replacing the previous file contents.
with Path('../data_esm_decoder', 'maps.pkl').open('wb') as f:
    pickle.dump(maps, f)

In [27]:
maps.keys()

dict_keys(['fam_idx', 'clan_idx', 'fam_clan', 'clan_count', 'clan_fam', 'idx_clan', 'idx_fam', 'clan_family_matrix'])

In [24]:
# Row sums of the membership matrix: number of families flagged per clan
sum_matrix = clan_fam_matrix.sum(dim=1)


In [25]:
sum_matrix

tensor([2.2000e+01, 1.2000e+01, 5.1000e+01, 3.0000e+00, 4.0000e+00, 1.4000e+01,
        3.0000e+00, 4.7000e+01, 3.4000e+01, 1.7000e+01, 6.0000e+00, 1.9000e+01,
        2.6000e+01, 4.0000e+01, 3.0000e+00, 2.5200e+02, 1.1300e+02, 1.8000e+01,
        2.4500e+02, 1.1000e+01, 1.5000e+01, 1.2000e+01, 7.5000e+01, 6.9000e+01,
        8.0000e+00, 1.6000e+01, 3.0000e+01, 6.0000e+00, 1.7000e+01, 1.7000e+01,
        6.1000e+01, 2.1000e+01, 3.2000e+01, 1.1000e+01, 7.0000e+00, 7.0000e+00,
        7.0000e+00, 2.7000e+01, 5.0000e+00, 1.4000e+01, 3.0000e+00, 2.0000e+01,
        3.3000e+01, 1.5000e+01, 4.6000e+01, 1.7000e+01, 2.9000e+01, 1.5000e+01,
        4.2000e+01, 1.3000e+01, 4.3000e+01, 5.9000e+01, 3.0000e+01, 1.3000e+01,
        1.6000e+01, 2.1000e+01, 2.0900e+02, 1.5000e+01, 1.2000e+01, 2.3000e+01,
        8.0000e+00, 2.0000e+00, 2.0000e+00, 1.2000e+01, 2.0000e+00, 7.9000e+01,
        1.0000e+01, 6.0000e+00, 1.5000e+01, 5.0000e+00, 2.0000e+00, 5.0000e+00,
        1.4000e+01, 7.0000e+00, 1.0000e+

In [1]:
import torch

In [None]:
# Cross-entropy scale check: standard-normal logits on the first 11863
# classes, zeros on the remaining ones, true class = 0.
preds = torch.zeros((19632))
# The right-hand side must have exactly as many elements as the slice it
# fills; torch.randn(150) cannot be broadcast into 11863 slots and raised a
# RuntimeError.
preds[:11863] = torch.randn(11863)
target = torch.tensor(0)
loss = torch.nn.functional.cross_entropy(preds,target)

In [None]:
loss

tensor(10.9370)

In [23]:
# Same check with a constant logit of 100 on the first 11863 classes and 0
# on the rest; with target class 0 the loss comes out at about ln(11863).
preds = torch.cat((torch.full((11863,), 100.0), torch.zeros(19632 - 11863)))
target = torch.tensor(0)
loss = torch.nn.functional.cross_entropy(preds, target)

In [24]:
loss

tensor(9.3812)