In [16]:
import os
import random
from random import shuffle

from torch.utils.data import DataLoader, random_split

from src.data import RNADataset


In [17]:
def create_data_set(base_path, data_list, seq_line=3, dbn_line=4):
    """Build an ``RNADataset`` from a list of structure files.

    Every file is read as text. The sequence is taken from line index
    ``seq_line`` and the dot-bracket string from line index ``dbn_line``;
    surrounding whitespace (including the newline) is stripped from both.

    Args:
        base_path: Directory that contains the files.
        data_list: Iterable of file names, relative to ``base_path``.
        seq_line: Non-negative 0-based index of the line holding the
            sequence. Defaults to 3, the layout this notebook relies on.
        dbn_line: Non-negative 0-based index of the line holding the
            structure. Defaults to 4.

    Returns:
        ``RNADataset(sequences, dbns)``, with entries in ``data_list`` order.

    Raises:
        IndexError: If a file has fewer lines than required. The message
            names the offending file so a bad input can be located.
        OSError: Propagated from ``open`` if a path cannot be read.
    """
    sequences = []
    dbns = []
    # Number of lines a file must have for both indices to be valid.
    required_lines = max(seq_line, dbn_line) + 1

    for data_point in data_list:
        path = os.path.join(base_path, data_point)
        with open(path, 'r') as f:
            all_lines = f.readlines()

        # Fail with the file name instead of an anonymous "list index out of
        # range" raised somewhere in the middle of a 100k-file loop.
        if len(all_lines) < required_lines:
            raise IndexError(
                f"{path}: expected at least {required_lines} lines, "
                f"found {len(all_lines)}"
            )

        sequences.append(all_lines[seq_line].strip())
        dbns.append(all_lines[dbn_line].strip())

    dataset = RNADataset(sequences, dbns)
    return dataset

In [18]:
# Directory holding the input structure files (one RNA per file).
data_path = "dbnFiles"

# Fixed seed so train/val/test membership is identical on every run.
SPLIT_SEED = 42

# Keep regular files only: a stray sub-directory (e.g. ".ipynb_checkpoints")
# would otherwise make open() fail later. Sort before shuffling because
# os.listdir() returns entries in arbitrary order, so seeding alone would not
# make the split reproducible.
all_data = sorted(
    name
    for name in os.listdir(data_path)
    if os.path.isfile(os.path.join(data_path, name))
)
random.Random(SPLIT_SEED).shuffle(all_data)
total_data_length = len(all_data)


In [41]:
# Split sizes: ~98% of the files go to training. The remaining ~2% is
# floor-divided by 1.1, so validation receives about 10/11 of it (~1.8% of all
# files) and test receives whatever is left (~0.2%).
# NOTE(review): this comment used to describe a 90/5/5 split, which the code
# does not implement. The tiny test set looks intentional (it is saved as
# "test_dataset_small.pkl") — confirm.
TRAIN_FRACTION = 0.98
VAL_DIVISOR = 1.1  # holdout_size // VAL_DIVISOR files go to validation

train_size = int(total_data_length * TRAIN_FRACTION)
holdout_size = total_data_length - train_size
val_size = int(holdout_size // VAL_DIVISOR)

print(train_size, val_size)

100271 1860


In [42]:
# Carve the shuffled file list into three consecutive, non-overlapping slices:
# [0, train_size) -> train, [train_size, val_end) -> validation, rest -> test.
val_end = train_size + val_size
train_data, val_data, test_data = (
    all_data[:train_size],
    all_data[train_size:val_end],
    all_data[val_end:],
)

print(len(train_data), len(val_data), len(test_data))

100271 1860 187


In [43]:
# Build one dataset per split, in train / validation / test order, by reading
# the listed files from `data_path`.
train_ds, val_ds, test_ds = (
    create_data_set(data_path, split_files)
    for split_files in (train_data, val_data, test_data)
)

print(len(train_ds), len(val_ds), len(test_ds))

100271 1860 187


In [44]:
# Print the dataset's summary tables for the test split (basic length
# statistics, nucleotide composition, structure composition, base pairs).
test_ds.info()


=== Basic Dataset Statistics ===
+-------------------+--------+
|      Metric       | Value  |
+-------------------+--------+
|  Total Sequences  |  187   |
| Total Nucleotides | 57366  |
|  Average Length   | 306.77 |
|    Min Length     |   15   |
|    Max Length     |  1723  |
|  Std Dev Length   | 467.57 |
+-------------------+--------+

=== Nucleotide Composition ===
+------------+-------+-----------+
| Nucleotide | Count | Frequency |
+------------+-------+-----------+
|     A      | 13791 |  24.04%   |
|     U      | 12403 |  21.62%   |
|     G      | 17624 |  30.72%   |
|     C      | 13489 |  23.51%   |
| GC Content |   -   |  54.24%   |
+------------+-------+-----------+

=== Structure Composition ===
+---------+-------+-----------+
| Element | Count | Frequency |
+---------+-------+-----------+
|    .    | 24624 |  42.92%   |
|    (    | 16074 |  28.02%   |
|    )    | 16074 |  28.02%   |
+---------+-------+-----------+

=== Base Pair Statistics ===
+-----------------------

In [45]:
# Persist the splits as pickle files. Only the small test split is written
# here; the train/validation saves are disabled and must be uncommented to
# regenerate those files.
# train_ds.save("data/train_dataset.pkl", format="pickle")
# val_ds.save("data/val_dataset.pkl", format="pickle")
test_ds.save("data/test_dataset_small.pkl", format="pickle")

Dataset saved to data/test_dataset_small.pkl


In [9]:
# Fetch a single validation item (index 100, chosen arbitrarily) and list the
# fields it exposes.
# NOTE(review): this cell's execution count ([9]) is lower than that of the
# cells above ([41]-[45]), so the recorded output may come from an earlier
# split; re-run the notebook top to bottom before trusting it.
res = val_ds[100]
res.keys()

dict_keys(['sequence', 'structure', 'attention_mask', 'length', 'raw_sequence', 'raw_structure'])

In [10]:
# Shape of the encoded sequence tensor for this item.
# NOTE(review): the recorded shape is (5000, 5) — presumably padded length x
# per-position encoding width; confirm against RNADataset.
res['sequence'].shape

torch.Size([5000, 5])

In [11]:
# Shape of the encoded structure tensor for this item.
# NOTE(review): the recorded shape is (5000,) — presumably one label per
# padded position; confirm against RNADataset.
res["structure"].shape

torch.Size([5000])

In [15]:
# Show the dataset's mapping from structure symbol to integer index
# (presumably the labels stored in item["structure"] — TODO confirm).
val_ds.struct_to_idx

{'.': 0,
 '(': 1,
 '>': 2,
 '{': 3,
 ']': 4,
 '[': 5,
 '}': 6,
 '<': 7,
 ')': 8,
 'A': 9,
 'a': 10,
 'B': 11,
 'b': 12,
 'C': 13,
 'c': 14,
 'D': 15,
 'd': 16,
 'E': 17,
 'e': 18,
 'F': 19,
 'f': 20,
 'G': 21,
 'g': 22,
 'H': 23,
 'h': 24,
 'I': 25,
 'i': 26,
 'J': 27,
 'j': 28,
 'K': 29,
 'k': 30,
 'L': 31,
 'l': 32,
 'M': 33,
 'm': 34,
 'N': 35,
 'n': 36,
 'O': 37,
 'o': 38,
 'P': 39,
 'p': 40,
 'Q': 41,
 'q': 42,
 'R': 43,
 'r': 44,
 'S': 45,
 's': 46,
 'T': 47,
 't': 48,
 'U': 49,
 'u': 50,
 'V': 51,
 'v': 52,
 'W': 53,
 'w': 54,
 'X': 55,
 'x': 56,
 'Y': 57,
 'y': 58,
 'Z': 59,
 'z': 60}