In [None]:
import re
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import torch
from Bio import SeqIO
from mpl_toolkits import mplot3d
from sklearn.decomposition import PCA
from torch.utils import data

from data_generator import Dataset

In [None]:
# --- Input files --------------------------------------------------------------
# Unlabeled sequences.
small_file = "100k_rows.fasta"
# Labeled sequence sets; per the file names, presumably ASTRAL/SCOPe domain
# sequences at two different redundancy levels ("-40-" and "-95-") — confirm.
small_label_file, big_label_file = (
    "astral-scopedom-seqres-gd-sel-gs-bib-40-2.07.fasta",
    "astral-scopedom-seqres-gd-sel-gs-bib-95-2.07.fasta",
)

# --- Sequence settings --------------------------------------------------------
# Length limit / pad width used when loading sequences.
max_seq_len = 300
# Residue alphabet; it omits the ambiguity codes B, J, X and Z, which
# load_data also filters out.
acids = "ACDEFGHIKLMNOPQRSTUVWY-"

# --- Torch data pipeline (not used by the analysis cells below) ---------------
dataset = Dataset(small_file, max_seq_len, acids=acids)
base_generator = data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=16,
)



# Structural-class token: a single lowercase letter followed by a dot
# (e.g. the "a" of "a.1.1.1"). Raw string so "\." is a regex escape rather
# than an (invalid) Python string escape.
regex = re.compile(r'[a-z]\.')


def find_label(description):
    """Return the class letter of the first "<lowercase letter>." token.

    Args:
        description: FASTA header/description text.

    Returns:
        The single letter preceding the first matching dot, e.g. "a".

    Raises:
        ValueError: if the description contains no "<lowercase letter>." token
            (previously this surfaced as an opaque AttributeError on None).
    """
    match = regex.search(description)
    if match is None:
        raise ValueError(f"no class label found in description: {description!r}")
    # group() is e.g. "a."; keep only the letter.
    return match.group()[0]

#def find_labels(header):

    
# Loading the entire input file into memory
def load_data(file_name, size, has_labels=False):
    elem_list = []
    label_list = []
    check_list = []
    for i, elem in enumerate(SeqIO.parse(file_name, "fasta")):
        if (len(elem) < max_seq_len) and not(('X' in elem or 'x' in elem) or ('B' in elem or 'b' in elem) or ('Z' in elem or 'z' in elem) or ('J' in elem or 'j' in elem)):
            elem_list.append(np.array(str(elem.seq).ljust(max_seq_len)))
            check_list.append(elem.seq)
            if has_labels:
                label_list.append(find_label(elem.description))
        if len(elem_list) == size:
            elem_list = np.array(elem_list)
            break
    return elem_list, label_list, check_list

# Number of usable sequences to load from each file.
N_LOAD = 10000

# NOTE: despite the "big" in the name, the unlabeled set is read from small_file.
UL_big_file_elems, _, UL_check_list = load_data(small_file, N_LOAD)
L_big_file_elems, L_big_file_labels, _ = load_data(big_label_file, N_LOAD, has_labels=True)

# Summarize rather than dumping every label.
print(f"unlabeled sequences: {len(UL_big_file_elems)}, labeled sequences: {len(L_big_file_elems)}")
print("label counts:", dict(Counter(L_big_file_labels).most_common()))


In [None]:


def create_dicts(seqs, acids_str):
    """Build inverse character <-> integer code mappings.

    Codes are assigned in order of first appearance while scanning every
    sequence in ``seqs``, so every character that occurs in the data (including
    characters outside ``acids_str``, such as padding) gets a code. Symbols of
    ``acids_str`` that never occur in ``seqs`` are then appended with the next
    free codes, so the whole alphabet is encodable too.

    Args:
        seqs: Iterable of sequences (strings or other iterables of characters).
        acids_str: Alphabet of expected symbols.

    Returns:
        (char_to_int_dict, int_to_char_dict), each the inverse of the other.
    """
    char_to_int_dict = {}
    int_to_char_dict = {}

    def _register(ch):
        # Give a not-yet-seen character the next consecutive code.
        if ch not in char_to_int_dict:
            code = len(char_to_int_dict)
            char_to_int_dict[ch] = code
            int_to_char_dict[code] = ch

    # Full scan, no early exit: stopping once the alphabet is covered could
    # miss non-alphabet characters (e.g. padding) that appear later.
    for seq in seqs:
        for ch in seq:
            _register(ch)
    for ch in acids_str:
        _register(ch)

    return char_to_int_dict, int_to_char_dict



In [None]:
char_to_int_dict, int_to_char_dict = create_dicts(UL_big_file_elems, acids)


def _encode(seq, transform=None):
    """Map every character of ``seq`` to its code in ``char_to_int_dict``.

    If ``transform`` is given it is applied to each character before lookup.
    """
    if transform is None:
        return [char_to_int_dict[ch] for ch in seq]
    return [char_to_int_dict[transform(ch)] for ch in seq]


# Unpadded unlabeled sequences -> ragged list of code lists.
UL_check_elem_int_list = list(map(_encode, UL_check_list))
# Padded unlabeled sequences (all the same length) -> 2-D integer array.
UL_elem_int_list = np.array([_encode(seq) for seq in UL_big_file_elems])

# Padded labeled sequences; each character is capitalized before lookup
# (presumably because the labeled files use lowercase residues — confirm).
L_check_elem_int_list = [_encode(seq, str.capitalize) for seq in L_big_file_elems]


In [None]:
# PCA of the padded, integer-encoded unlabeled sequences.
# NOTE(review): the integer codes are arbitrary category ids, so PCA treats
# them as ordinal values; a one-hot encoding may be more meaningful.
pca = PCA(n_components=3)
low_dim_points = pca.fit_transform(UL_elem_int_list)
x = low_dim_points[:, 0]
y = low_dim_points[:, 1]
z = low_dim_points[:, 2]

fig, ax = plt.subplots()
ax.scatter(x, y)
ax.set(title="Unlabeled sequences: PCA (PC1 vs PC2)", xlabel="PC1", ylabel="PC2")
plt.show()

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
# cmap only takes effect together with c=...; colour the points by PC3.
ax.scatter3D(x, y, z, c=z, cmap='Greens')
ax.set_title("Unlabeled sequences: PCA (3 components)")
ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
ax.set_zlabel("PC3")
plt.show()

In [None]:
# Binary grouping of the labeled sequences: class 'a' -> 0, any other -> 1.
label_grps = [0 if structure == 'a' else 1 for structure in L_big_file_labels]

colors = ['red', 'green']
group_names = ["class a", "other classes"]

pca = PCA(n_components=3)
low_dim_points = pca.fit_transform(L_check_elem_int_list)

x_coordinates = low_dim_points[:, 0]
y_coordinates = low_dim_points[:, 1]
z_coordinates = low_dim_points[:, 2]

# One scatter call per group instead of one per point: far fewer artists
# (much faster) and each group gets a single legend entry.
groups = np.asarray(label_grps)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
for grp, (color, name) in enumerate(zip(colors, group_names)):
    mask = groups == grp
    if mask.any():
        ax.scatter3D(x_coordinates[mask], y_coordinates[mask], z_coordinates[mask],
                     color=color, label=name)
ax.set_title("Labeled sequences: PCA (3 components)")
ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
ax.set_zlabel("PC3")
ax.legend()
plt.show()



In [None]:
# 2-D PCA of the labeled sequences, coloured by the class-'a'-vs-rest groups.
pca = PCA(n_components=2)
low_dim_points = pca.fit_transform(L_check_elem_int_list)

x_coordinates = low_dim_points[:, 0]
y_coordinates = low_dim_points[:, 1]

# Only the two components of *this* fit are used. The old loop also zipped in
# z_coordinates from the 3-D cell and passed it positionally as the marker
# size `s`.
groups = np.asarray(label_grps)
fig, ax = plt.subplots()
for grp, (color, name) in enumerate(zip(colors, ["class a", "other classes"])):
    mask = groups == grp
    if mask.any():
        ax.scatter(x_coordinates[mask], y_coordinates[mask],
                   color=color, marker='.', label=name)
ax.set(title="Labeled sequences: PCA (PC1 vs PC2)", xlabel="PC1", ylabel="PC2")
ax.legend()
plt.show()