In [2]:
import sys
sys.path.append("../../Chapter_3")

import pandas as pd
from transformers import AutoTokenizer, AutoModel

import torch
from torch.utils.data import WeightedRandomSampler

from tqdm.notebook import tqdm

from sklearn.preprocessing import KBinsDiscretizer
from sklearn.model_selection import train_test_split

import fastai
from fastai.basics import Learner, DataLoaders
from fastai.vision.all import *
from fastai.callback.all import *

import seaborn as sbn
import matplotlib.pyplot as plt
import numpy as np

from Utils.negative_sampling import IterableProteinEmbedding, magnitude
from Utils.misc import filter_sequences_by_len_from_fasta
from Utils.encoders import SkipGramEmbedder

import warnings
warnings.filterwarnings('ignore')

# RNAInter HIV-1 data

In [3]:
df = pd.read_csv("../../Chapter_3/Data/TrainingData/PositiveSamples/RNAInter_small_interactions_current.csv")

hiv = 'Human immunodeficiency virus 1'

hiv_prot_interaction = df['Species'] == hiv
hiv_rna_interaction = df['Also_species'] == hiv

hiv_frame = df[hiv_prot_interaction | hiv_rna_interaction]
hiv_frame

print(f"""RNAInter HIV report
---
Total interactions: {len(hiv_frame)}
Unique proteins: {len(set(hiv_frame['proteins'].unique()))}
Unique RNAs: {len(set(hiv_frame['rnas'].unique()))}
HIV-1 proteins: {set(hiv_frame[hiv_frame['Also_species'] == hiv]['Protein_name'])}
HIV-1 RNAs: {set(hiv_frame[hiv_frame['Species'] == hiv]['RNA_Name'])}""")

RNAInter HIV repport
---
Total interactions: 14
Unique proteins: 1
Unique RNAs: 14
HIV-1 proteins: {'tat'}
HIV-1 RNAs: {'hiv1-miR-TAR-3p', 'hiv1-miR-TAR-5p'}


In [6]:
hiv_frame

Unnamed: 0.1,Unnamed: 0,Interaction_ID,RNA_Name,HGNC/Entrex/Ensembl_ID,RNA_type,Species,Protein_name,NCBI_GeneID,Protein_Type,Also_species,Confidence_score,proteins,UniprotID,rnas
12145,8923212,RP33979063,rno-let-7b-5p,MIMAT0000775,miRNA,Rattus norvegicus,tat,155871.0,protein,Human immunodeficiency virus 1,0.9768,MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE,P04608,UGAGGUAGUAGGUUGUGUGGUU
12146,8923213,RP33979064,rno-let-7c-5p,MIMAT0000776,miRNA,Rattus norvegicus,tat,155871.0,protein,Human immunodeficiency virus 1,0.9768,MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE,P04608,UGAGGUAGUAGGUUGUAUGGUU
12147,8923214,RP33979065,rno-let-7e-5p,MIMAT0000777,miRNA,Rattus norvegicus,tat,155871.0,protein,Human immunodeficiency virus 1,0.9768,MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE,P04608,UGAGGUAGGAGGUUGUAUAGUU
12148,8923215,RP33979066,rno-let-7f-5p,MIMAT0000778,miRNA,Rattus norvegicus,tat,155871.0,protein,Human immunodeficiency virus 1,0.9768,MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE,P04608,UGAGGUAGUAGAUUGUAUAGUU
12149,8923216,RP33979067,rno-miR-25-3p,MIMAT0000795,miRNA,Rattus norvegicus,tat,155871.0,protein,Human immunodeficiency virus 1,0.9768,MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE,P04608,CAUUGCACUUGUCUCGGUCUGA
12150,8923217,RP33979068,rno-miR-99a-5p,MIMAT0000820,miRNA,Rattus norvegicus,tat,155871.0,protein,Human immunodeficiency virus 1,0.9768,MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE,P04608,AACCCGUAGAUCCGAUCUUGUG
12151,8923218,RP33979069,rno-miR-100-5p,MIMAT0000822,miRNA,Rattus norvegicus,tat,155871.0,protein,Human immunodeficiency virus 1,0.9768,MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE,P04608,AACCCGUAGAUCCGAACUUGUG
12152,8923219,RP33979070,rno-miR-128-3p,MIMAT0000834,miRNA,Rattus norvegicus,tat,155871.0,protein,Human immunodeficiency virus 1,0.9768,MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE,P04608,UCACAGUGAACCGGUCUCUUU
12153,8923220,RP33979072,rno-miR-214-3p,MIMAT0000885,miRNA,Rattus norvegicus,tat,155871.0,protein,Human immunodeficiency virus 1,0.9768,MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE,P04608,ACAGCAGGCACAGACAGGCAG
12154,8923221,RP33979073,rno-miR-298-5p,MIMAT0000900,miRNA,Rattus norvegicus,tat,155871.0,protein,Human immunodeficiency virus 1,0.9768,MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE,P04608,GGCAGAGGAGGGCUGUUCUUCCC


In [3]:
hiv_frame_seq_interactions = list(zip(hiv_frame['proteins'], hiv_frame['rnas'],[1]*len(hiv_frame)))
hiv_frame_seq_interactions

[('MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE',
  'UGAGGUAGUAGGUUGUGUGGUU',
  1),
 ('MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE',
  'UGAGGUAGUAGGUUGUAUGGUU',
  1),
 ('MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE',
  'UGAGGUAGGAGGUUGUAUAGUU',
  1),
 ('MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE',
  'UGAGGUAGUAGAUUGUAUAGUU',
  1),
 ('MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE',
  'CAUUGCACUUGUCUCGGUCUGA',
  1),
 ('MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE',
  'AACCCGUAGAUCCGAUCUUGUG',
  1),
 ('MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE',
  'AACCCGUAGAUCCGAACUUGUG',
  1),
 ('MEPVDPRLEPWKHPGSQPKTACTNCYCKKCCFHCQVCFITKALGISYGRKKRRQRRRAHQNSQTHQASLSKQPTSQPRGDPTGPKE',
  'UCACAGUGAACCGGUCUCUUU',

# Klase Tat-miRNA data

In [7]:
from Bio import SeqIO

with open("../Data/Klase_Tat_interaction_miRNA_names.yaml") as handle:
    interactions = yaml.load(handle, Loader = yaml.FullLoader)
    
with open("/data/Chapter_3/Data/RawData/mirbase.fa") as handle:
    name2seq = {record.name : str(record.seq) for record in SeqIO.parse(handle, 'fasta') if 'hsa' in record.name}
    
with open("../Data/RawData/Klase_translated_tat_fragment.fasta") as handle:
    # Has spacer and flag tag, removing them for refinement data
    klase_tat = str(next(SeqIO.parse(handle, 'fasta')).seq)[:101]

In [224]:
klase_tat

'M E P V D P R L E P W K H P G S Q P K T A C T N C Y C K K C C F H C Q V C F I T K A L G I S Y G R K K R R Q R R R P P Q G S Q T H Q V S L S K Q P T S Q S R G D P T G P K E S K K K V E R E T E T D P F D'

In [9]:
import re

def process_id(mirna_id : str, name2seq : dict) -> str:
    """
    Jank conversions from listed miRNAs to their updated names.
    
    """

    # Just an odd ID typo
    if mirna_id == "hsa-miR-103-a":
        return "hsa-mir-103a-1"
    
    # http://www.mirbase.org/cgi-bin/mirna_entry.pl?acc=MI0001727
    # hsa-miR-453 has been merged with  hsa-mir-323b and should be used instead
    if mirna_id == 'hsa-miR-453':
        return "hsa-mir-323b"
    
    # See whether ID just works as is
    replaced_id = mirna_id.replace("miR","mir")
    
    try:
        name2seq[mirna_id]
        return mirna_id
    except:
        pass
    
    try:
        name2seq[replaced_id]
        return replaced_id
    except:
        pass
    
    # See whether adding "a" as end fixes ID
    try:
        formatted_id = f"{mirna_id}a"
        name2seq[formatted_id]
        return formatted_id
    except:
        pass
    
    try:
        formatted_id = f"{replaced_id}a"
        name2seq[formatted_id]
        return formatted_id
    except:
        pass
    
    # See whether adding "-1" at end fixes ID
    try:
        formatted_id = f"{mirna_id}-1"
        name2seq[formatted_id]
        return formatted_id
    except:
        pass
    
    try:
        formatted_id = f"{replaced_id}-1"
        name2seq[formatted_id]
        return formatted_id
    except:
        pass
    
    # See whether adding "-1" at end fixes ID
    try:
        formatted_id = f"{mirna_id}-1"
        name2seq[formatted_id]
        return formatted_id
    except:
        pass
    
    try:
        formatted_id = f"{replaced_id}-1"
        name2seq[formatted_id]
        return formatted_id
    except:
        pass
    
    # See whether adding "a-1" fixes ID
    try:
        formatted_id = f"{mirna_id}a-1"
        name2seq[formatted_id]
        return formatted_id
    except:
        pass
    
    try:
        formatted_id = f"{replaced_id}a-1"
        name2seq[formatted_id]
        return formatted_id
    except:
        pass

    # See whether adding "a" to number fixes ID
    try:
        formatted_id = re.sub(r"-(\d+)-", r"-\1a-", mirna_id)
        name2seq[formatted_id]
        return formatted_id
    except:
        pass
    
    try:
        formatted_id = re.sub(r"-(\d+)-", r"-\1a-", replaced_id)
        name2seq[formatted_id]
        return formatted_id
    except:
        pass
    
    # See whether adding "-1" to number fixes ID
    try:
        formatted_id = re.sub(r"-(\d+)-", r"-\1-1-", mirna_id)
        name2seq[formatted_id]
        return formatted_id
    except:
        pass
    
    try:
        formatted_id = re.sub(r"-(\d+)-", r"-\1-1-", replaced_id)
        name2seq[formatted_id]
        return formatted_id
    except:
        pass
    
#         try:
#             formatted_id = f"hsa-{mirna_id}"
#             name2seq[formatted_id]
#             return f"hsa-{mirna_id}"
#         except:
#             split_id = mirna_id.split('-')
#             split_id[1] = f"{split_id[1]}a"
#             formatted_id = f"hsa-{'-'.join(split_id)}"
#             return formatted_id

NOTE: Some listed miRNAs do not exist. These include:
* hsa-miR-220{a,b,c} (http://www.mirbase.org/cgi-bin/mirna_entry.pl?acc=MI0000297)

In [12]:
# Keep true if experimenting with ID conversions
# If False, removes all None name : seq values. This is to remove miRNAs that no longer exist in miRBase. 
# If True, keeps all None values and displays which IDs need to be further modified to return sequence
testing = False

tat_interactions = {name : name2seq.get(process_id(name, name2seq), None) for name in interactions['Interaction']}

print(f"Original Interacting Tat-miRNA count: {len(tat_interactions)}")


if not testing:
    # Filter out long miRNAs and non-existant miRNAs
    tat_interactions = {key : value for key,value in tat_interactions.items() if value}
    tat_interactions = {key : value for key,value in tat_interactions.items() if len(value) <= 100}

print(f"Filtered for existing and <= 100 Interacting Tat-miRNA count: {len(tat_interactions)}")

    
if testing:
    for key,value in tat_interactions.items():
        if not value:
            print(f"{key}: {value}")

Original Interacting Tat-miRNA count: 68
Filtered for existing and <= 100 Interacting Tat-miRNA count: 57


In [13]:
# Keep true if experimenting with ID conversions
# If False, removes all None name : seq values. This is to remove miRNAs that no longer exist in miRBase. 
# If True, keeps all None values and displays which IDs need to be further modified to return sequence
testing = False

tat_noninteractions = {name : name2seq.get(process_id(name, name2seq), None) for name in interactions['NonInteraction']}

print(f"Original Non-interacting miRNA count: {len(tat_noninteractions)}")

if not testing: 
    # Filter out long miRNAs and non-existant miRNAs
    tat_noninteractions = {key : value for key,value in tat_noninteractions.items() if value}
    tat_noninteractions = {key : value for key,value in tat_noninteractions.items() if len(value) <= 100}

print(f"Filtered for existing and <= 100 Non-interacting miRNA count: {len(tat_noninteractions)}")
    
if testing:
    for key,value in tat_noninteractions.items():
        if not value:
            print(f"{key}: {value}")

Original Non-interacting miRNA count: 314
Filtered for existing and <= 100 Non-interacting miRNA count: 284


In [14]:
# Klase Tat is 111, but need to reduce truncate final residue to bring to 110
# This is the only time that i need to do this and it's only a single residue
# truncation. 
klase_seq_interactions = list(zip([klase_tat[:-1]]*len(tat_interactions), tat_interactions.values(), [1]*len(tat_interactions)))
klase_seq_noninteractions = list(zip([klase_tat[:-1]]*len(tat_noninteractions), tat_noninteractions.values(), [0]*len(tat_noninteractions)))

# Generate HIV-1 reduced negative samples

All tats are wild type, so only need to incorperate miRNA data when sampling

In [24]:
hiv_interactions = hiv_frame_seq_interactions + klase_seq_interactions
positive_interaction_count = len(hiv_interactions)

print(f"There are: {positive_interaction_count} positive interactions")

There are: 71 positive interactions


In [105]:
_, rnas, _ = zip(*klase_seq_noninteractions)
unique_mirnas = list(set(rnas))
len(unique_mirnas)

280

In [106]:
rna_embedder = SkipGramEmbedder("/data/Chapter_3/SeqEmbedders/GensimWord2Vec/RNA2Vec_1024_hidden.model", reduce = True)
encoded_mirnas = rna_embedder(unique_mirnas)
encoded_mirnas

tensor([[-0.0819, -0.4581,  0.0760,  ...,  0.2146, -0.0318,  0.1334],
        [-0.2269, -0.4215, -0.3855,  ...,  0.0557,  0.2574, -0.2805],
        [-0.0234, -0.2672,  0.0008,  ..., -0.0528, -0.3317,  0.1246],
        ...,
        [-0.1364, -0.1354, -0.0173,  ...,  0.1199, -0.0333, -0.1199],
        [ 0.0610, -0.1866,  0.0520,  ..., -0.2541, -0.0556, -0.2120],
        [-0.0716, -0.3849, -0.0318,  ...,  0.1413,  0.0740,  0.3057]])

In [225]:
len(encoded_mirnas)

71

In [111]:
from hdbscan import HDBSCAN

# Playing around until I get roughly 71 clusters
# 5 gets the best trade-off between cluster count and unique RNAs
clusterer = hdbscan.HDBSCAN(min_cluster_size=2)
clusters = clusterer.fit_predict(encoded_mirnas)

#np.save(SWISS_CLUSTERS_PATH, clusters)

In [113]:
from math import ceil

TOTAL_REDUCED_SAMPLES = positive_interaction_count

cluster_count      = len(set(clusters))
sample_per_cluster = ceil(TOTAL_REDUCED_SAMPLES / cluster_count)
anticiapted_total_samples = cluster_count * sample_per_cluster

print(f"TOTAL SAMPLES: {TOTAL_REDUCED_SAMPLES}")
print(f"Total clusters: {cluster_count}")
print(f"Samples per cluster: {sample_per_cluster}")
print(f"Anticiapted smaple count: {anticiapted_total_samples}")

TOTAL SAMPLES: 71
Total clusters: 54
Samples per cluster: 2
Anticiapted smaple count: 108


In [115]:
for i in pd.Series(clusters).value_counts().iteritems():
    print(i)

(-1, 135)
(33, 9)
(36, 8)
(39, 7)
(11, 5)
(24, 5)
(51, 4)
(40, 4)
(49, 3)
(42, 3)
(18, 3)
(41, 3)
(44, 3)
(1, 3)
(5, 3)
(34, 3)
(2, 3)
(30, 3)
(7, 3)
(28, 2)
(27, 2)
(50, 2)
(19, 2)
(4, 2)
(43, 2)
(38, 2)
(48, 2)
(25, 2)
(32, 2)
(12, 2)
(37, 2)
(29, 2)
(45, 2)
(22, 2)
(47, 2)
(16, 2)
(21, 2)
(52, 2)
(46, 2)
(9, 2)
(0, 2)
(8, 2)
(35, 2)
(14, 2)
(13, 2)
(26, 2)
(3, 2)
(17, 2)
(20, 2)
(23, 2)
(6, 2)
(31, 2)
(10, 2)
(15, 2)


In [116]:
import numpy as np
def isolate_data_by_cluster(data : np.array, clusters : np.array, cluster_id : int) -> np.array:
    return data[clusters == cluster_id,:]

In [117]:
def isolate_data_by_clusters(data : np.array, clusters : np.array) -> list:
    return [isolate_data_by_cluster(data, clusters, cluster) for cluster in set(clusters)]

In [129]:
encoded_clusters = isolate_data_by_clusters(encoded_mirnas, clusters)
encoded_data = [cluster[:sample_per_cluster] for cluster in encoded_clusters]
reduced_encoded_mirnas = torch.cat(encoded_data)[:TOTAL_REDUCED_SAMPLES]
reduced_encoded_mirnas.shape

torch.Size([71, 1024])

In [132]:
klase_tat = " ".join(klase_tat)

In [133]:
model     = AutoModel.from_pretrained("Rostlab/prot_bert_bfd")
tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case = False)

tat_embedder = IterableProteinEmbedding([klase_tat], tokenizer, model, chunksize = 3, max_len = 110, cuda = True)

with torch.no_grad():
    encoded_klase_tat = torch.cat(list(tat_embedder))

  0%|          | 0/1 [00:00<?, ?it/s]

In [142]:
encoded_tats = torch.tile(encoded_klase_tat,(TOTAL_REDUCED_SAMPLES,1))

In [226]:
negative_tat_interactions = [encoded_tats.cpu(), encoded_mirnas, torch.tensor([0]*TOTAL_REDUCED_SAMPLES)]
negative_tat_interactions

[tensor([[-0.0354,  0.0086, -0.0037,  ..., -0.0222,  0.0086, -0.0599],
         [-0.0354,  0.0086, -0.0037,  ..., -0.0222,  0.0086, -0.0599],
         [-0.0354,  0.0086, -0.0037,  ..., -0.0222,  0.0086, -0.0599],
         ...,
         [-0.0354,  0.0086, -0.0037,  ..., -0.0222,  0.0086, -0.0599],
         [-0.0354,  0.0086, -0.0037,  ..., -0.0222,  0.0086, -0.0599],
         [-0.0354,  0.0086, -0.0037,  ..., -0.0222,  0.0086, -0.0599]]),
 tensor([[-0.0798, -0.0949, -0.0973,  ..., -0.1659,  0.0539, -0.3381],
         [ 0.6910,  0.2231,  1.8914,  ..., -1.2562,  0.1066, -0.2320],
         [-0.1549, -0.1900, -0.3548,  ...,  0.1658,  0.1127, -0.0458],
         ...,
         [-0.2056, -0.6489,  0.3222,  ...,  0.7906,  0.4120,  0.8115],
         [ 0.2071, -0.2519,  0.5972,  ..., -0.2751, -0.0802, -0.1256],
         [-0.0286, -0.3779,  0.4139,  ..., -0.0795,  0.3838,  0.0880]]),
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0,

In [227]:
[len(i) for i in negative_tat_interactions]

[71, 71, 71]

# Generare HIV-1 positive data

In [188]:
######################################################################
# NOTE: 86 exclusively interacts with samples, please deal with later
######################################################################

interacting_proteins, interacting_mirnas, _ = zip(*klase_seq_interactions)

interacting_proteins += tuple(hiv_frame['proteins'])
interacting_mirnas += tuple(hiv_frame['rnas'])

In [189]:
proteins = [" ".join(list(seq)) for seq in interacting_proteins]

In [190]:
model     = AutoModel.from_pretrained("Rostlab/prot_bert_bfd")
tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case = False)

tat_embedder = IterableProteinEmbedding(proteins, tokenizer, model, chunksize = 3, max_len = 110, cuda = True)

with torch.no_grad():
    encoded_tat = torch.cat(list(tat_embedder))

  0%|          | 0/24 [00:00<?, ?it/s]

In [194]:
rna_embedder = SkipGramEmbedder("/data/Chapter_3/SeqEmbedders/GensimWord2Vec/RNA2Vec_1024_hidden.model", reduce = True)
encoded_mirnas = rna_embedder(interacting_mirnas)
encoded_mirnas

tensor([[-0.0798, -0.0949, -0.0973,  ..., -0.1659,  0.0539, -0.3381],
        [ 0.6910,  0.2231,  1.8914,  ..., -1.2562,  0.1066, -0.2320],
        [-0.1549, -0.1900, -0.3548,  ...,  0.1658,  0.1127, -0.0458],
        ...,
        [-0.2056, -0.6489,  0.3222,  ...,  0.7906,  0.4120,  0.8115],
        [ 0.2071, -0.2519,  0.5972,  ..., -0.2751, -0.0802, -0.1256],
        [-0.0286, -0.3779,  0.4139,  ..., -0.0795,  0.3838,  0.0880]])

In [197]:
encoded_mirnas

tensor([[-0.0798, -0.0949, -0.0973,  ..., -0.1659,  0.0539, -0.3381],
        [ 0.6910,  0.2231,  1.8914,  ..., -1.2562,  0.1066, -0.2320],
        [-0.1549, -0.1900, -0.3548,  ...,  0.1658,  0.1127, -0.0458],
        ...,
        [-0.2056, -0.6489,  0.3222,  ...,  0.7906,  0.4120,  0.8115],
        [ 0.2071, -0.2519,  0.5972,  ..., -0.2751, -0.0802, -0.1256],
        [-0.0286, -0.3779,  0.4139,  ..., -0.0795,  0.3838,  0.0880]])

In [198]:
interacts = torch.tensor([1]*len(encoded_mirnas))
interacts

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [199]:
encoded_interaction_data = [encoded_tat.cpu(), encoded_mirnas, interacts]
encoded_interaction_data

[tensor([[-0.0355,  0.0101, -0.0067,  ..., -0.0212,  0.0080, -0.0615],
         [-0.0355,  0.0101, -0.0067,  ..., -0.0212,  0.0080, -0.0615],
         [-0.0355,  0.0101, -0.0067,  ..., -0.0212,  0.0080, -0.0615],
         ...,
         [-0.0318,  0.0129, -0.0052,  ..., -0.0309,  0.0040, -0.0528],
         [-0.0318,  0.0129, -0.0052,  ..., -0.0309,  0.0040, -0.0528],
         [-0.0318,  0.0129, -0.0052,  ..., -0.0309,  0.0040, -0.0528]]),
 tensor([[-0.0798, -0.0949, -0.0973,  ..., -0.1659,  0.0539, -0.3381],
         [ 0.6910,  0.2231,  1.8914,  ..., -1.2562,  0.1066, -0.2320],
         [-0.1549, -0.1900, -0.3548,  ...,  0.1658,  0.1127, -0.0458],
         ...,
         [-0.2056, -0.6489,  0.3222,  ...,  0.7906,  0.4120,  0.8115],
         [ 0.2071, -0.2519,  0.5972,  ..., -0.2751, -0.0802, -0.1256],
         [-0.0286, -0.3779,  0.4139,  ..., -0.0795,  0.3838,  0.0880]]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1,

In [228]:
[len(i) for i in encoded_interaction_data]

[71, 71, 71]

# Generate HIV-1 training data 

In [232]:
# negative_tat_interactions,negative_tat_interactions]
encoded_hiv_training_data = [torch.cat([encoded_interaction_data[i], negative_tat_interactions[i]]) for i in range(3)]

print(f"""Total HIV-1 training data: {len(encoded_hiv_training_data[0])}
Total positive samples: {sum(encoded_hiv_training_data[-1])}
Total negative samples: {len(encoded_hiv_training_data[-1]) - sum(encoded_hiv_training_data[-1])}""")

Total HIV-1 training data: 142
Total positive samples: 71
Total negative samples: 71


In [233]:
#encoded_hiv_training_data
[len(i) for i in encoded_hiv_training_data]

[142, 142, 142]

In [234]:
# No 2D version for HIV yet
with open("../Data/TrainingData/FullEmbeddedHIVTrainingData.pickle",'wb') as infile:
    pickle.dump(encoded_hiv_training_data, infile)