In [36]:
import math
import torch
import pandas as pd
from gensim.models import Word2Vec
import os
from typing import List
from torch.utils.data import random_split

In [90]:
def get_dataframe(path: str) -> pd.DataFrame:
    """Load a CSV dataset into a DataFrame.

    Args:
        path: location of the CSV file; handed unchanged to ``pd.read_csv``.

    Returns:
        The parsed DataFrame, read with pandas' default CSV options.
    """
    frame = pd.read_csv(path)
    return frame

def get_sentence(sequence: str, window: int) -> List[str]:
    """Cut a sequence into consecutive, non-overlapping words of length ``window``.

    The final word is shorter than ``window`` when ``len(sequence)`` is not a
    multiple of it (e.g. ``"UCGUU"`` with window 3 -> ``["UCG", "UU"]``).
    A ``window`` of 0 raises ``ValueError`` (it is used as a ``range`` step).
    """
    words: List[str] = []
    for offset in range(0, len(sequence), window):
        words.append(sequence[offset:offset + window])
    return words

def get_sentences(sequences: List[str], window: int) -> List[List[str]]:
    """Tokenise every sequence with ``get_sentence``.

    Args:
        sequences: iterable of sequence strings (a pandas Series works too).
        window: word length passed through to ``get_sentence``.

    Returns:
        One list of words per input sequence, in input order.
    """
    return [get_sentence(seq, window) for seq in sequences]
        
# train ratio consistent with ratio in data_configs
def get_training_sentences(sentences: List[List[str]], split: int, train_ratio: float = 0.7) -> List[List[str]]:
    """Return the training portion of ``sentences`` for one seeded random split.

    The split is drawn with ``torch.utils.data.random_split`` using a generator
    seeded with ``split``. The same seed and the same number of rows always
    select the same row indices, so parallel corpora of equal length (e.g. the
    miRNA and target columns of one DataFrame) stay row-aligned.

    Args:
        sentences: tokenised sequences, one list of words per row.
        split: RNG seed identifying the split (not an index).
        train_ratio: fraction of rows kept for training; the count is floored.

    Returns:
        Fresh list copies of the selected rows, in the order ``random_split`` drew them.
    """
    train_len = math.floor(train_ratio * len(sentences))
    train_subset, _ = random_split(
        sentences,
        [train_len, len(sentences) - train_len],
        generator=torch.Generator().manual_seed(split),
    )
    # Index the rows directly instead of iterating a batch-size-1 DataLoader, whose
    # default collate_fn wrapped every word in a 1-tuple and needed `t[0]` to unwrap.
    return [list(sentences[row]) for row in train_subset.indices]

def get_trained_model(
    sentences: List[List[str]],
    features_number: int = 16,
    window: int = 3,
    seed: int = 1,
    workers: int = 4,
):
    """Train a gensim ``Word2Vec`` model on tokenised sequences.

    Args:
        sentences: training corpus; each inner list is one tokenised sequence.
        features_number: dimensionality of the learned vectors (``vector_size``).
        window: Word2Vec context window, counted in words. This is independent of
            the word (k-mer) length used when the sentences were built.
        seed: RNG seed forwarded to Word2Vec; 1 matches gensim's own default.
        workers: number of training threads. gensim only promises run-to-run
            reproducibility with a single worker (and a fixed PYTHONHASHSEED);
            the default of 4 keeps the previous behaviour.

    Returns:
        The trained ``Word2Vec`` model. ``min_count=1`` keeps every word, including
        the shorter trailing fragments that fixed-size chunking can produce.
    """
    return Word2Vec(
        sentences=sentences,
        vector_size=features_number,
        window=window,
        min_count=1,
        seed=seed,
        workers=workers,
    )

In [128]:
# Input datasets, keyed by the short name used in saved model filenames.
# Deriving both lists from one mapping keeps data_paths/data_names index-aligned
# (the training loop zips them together).
DATA_DIR = "../../../../../data"
DATASET_FILES = {
    "deepmirtar": "DeepMirTar.csv",
    "miraw": "miRAW.csv",
    "mirtarraw": "MirTarRAW.csv",
}
data_paths = ["{}/{}".format(DATA_DIR, file_name) for file_name in DATASET_FILES.values()]
data_names = list(DATASET_FILES)

# Seeds for random_split, one per train/test split. Presumably these must match the
# split seeds used in data_configs -- TODO confirm. Each seed is also embedded in the
# saved model filename, so a duplicate would silently overwrite an earlier model.
splits = [418, 627, 960, 426, 16, 523, 708, 541, 747, 897,
          714, 515, 127, 657, 662, 284, 595, 852, 734, 136,
          394, 321, 200, 502, 786, 817, 411, 264, 929, 407]
assert len(set(splits)) == len(splits), "duplicate split seeds would overwrite saved models"

# CSV column names holding the two sequence types.
mirna_key = "miRNA Seq (3'-->5')"
target_key = "Target Site"
# Length of each non-overlapping word (k-mer) cut from a sequence.
word_size = 3

In [129]:
# Train and save one miRNA and one target Word2Vec model per (dataset, split seed).
for path, data_name in zip(data_paths, data_names):
    print("path: ", path)
    df = get_dataframe(path)
    # Tokenise each column once per dataset; the per-split work below only subsamples rows.
    sentences_mirna, sentences_target = (
        get_sentences(df[column], word_size) for column in (mirna_key, target_key)
    )
    for i, split in enumerate(splits):
        print("split no: ", i)
        # Same seed for both corpora -> same row indices, so miRNA/target pairs stay aligned.
        training_sentences_mirna = get_training_sentences(sentences_mirna, split)
        training_sentences_target = get_training_sentences(sentences_target, split)
        file_prefix = "word2vec_{}_{}".format(data_name, split)
        model_mirna = get_trained_model(training_sentences_mirna)
        model_mirna.save(file_prefix + "_mirna.model")
        model_target = get_trained_model(training_sentences_target)
        model_target.save(file_prefix + "_target.model")

path:  ../../../../../data/DeepMirTar.csv
split no:  0
split no:  1
split no:  2
split no:  3
split no:  4
split no:  5
split no:  6
split no:  7
split no:  8
split no:  9
split no:  10
split no:  11
split no:  12
split no:  13
split no:  14
split no:  15
split no:  16
split no:  17
split no:  18
split no:  19
split no:  20
split no:  21
split no:  22
split no:  23
split no:  24
split no:  25
split no:  26
split no:  27
split no:  28
split no:  29
path:  ../../../../../data/miRAW.csv
split no:  0
split no:  1
split no:  2
split no:  3
split no:  4
split no:  5
split no:  6
split no:  7
split no:  8
split no:  9
split no:  10
split no:  11
split no:  12
split no:  13
split no:  14
split no:  15
split no:  16
split no:  17
split no:  18
split no:  19
split no:  20
split no:  21
split no:  22
split no:  23
split no:  24
split no:  25
split no:  26
split no:  27
split no:  28
split no:  29
path:  ../../../../../data/MirTarRAW.csv
split no:  0
split no:  1
split no:  2
split no:  3
split no

In [126]:
# Reload one saved model (dataset "deepmirtar", split seed 515, miRNA column),
# presumably as a spot check of the files written by the training loop.
model_loaded = Word2Vec.load("./word2vec_{}_{}_mirna.model".format("deepmirtar", 515))

In [127]:
# L2-normalised embedding of the word "UCG" from the reloaded model.
model_loaded.wv.get_vector(key="UCG", norm=True)

array([-0.28140032, -0.10750585,  0.38771498,  0.11904792,  0.00914918,
       -0.01234087,  0.5486606 , -0.01678914,  0.00997987,  0.29765213,
       -0.11197576, -0.00759577, -0.13498308, -0.46904308, -0.19051005,
       -0.25962675], dtype=float32)

In [89]:
# NOTE(review): `model_mirna` is whichever model the training loop assigned last (hidden
# state). This cell's execution count (89) precedes the loop's (129), so the stored
# output below comes from an earlier run -- re-run after the loop to refresh.
model_mirna.wv.get_vector(key='UCG', norm=True).tolist()

[-0.09394899010658264,
 -0.15069660544395447,
 -0.10315322130918503,
 -0.10049491375684738,
 0.23426876962184906,
 0.6124224066734314,
 0.2541028559207916,
 -0.2321488857269287,
 -0.11454567313194275,
 -0.060300201177597046,
 0.12170539051294327,
 0.1472024768590927,
 -0.45102930068969727,
 -0.31932035088539124,
 0.10256762057542801,
 0.17369303107261658]

In [21]:
# NOTE(review): `df` is whichever DataFrame the training loop loaded last (hidden state);
# execution counts are out of order, so the output may not match a fresh Run All.
df.head(n=5)

Unnamed: 0.1,Unnamed: 0,miRNA ID,miRNA Seq (3'-->5'),mRNA Accession Number,Target Site,label
0,0,hsa-miR-96-5p,UCGUUUUUACACGAUCACGGUUU,ENSG00000095139,AGCAGAGCACTCACACATAAATGGCTGTGTGTGGAATTGC,1
1,1,hsa-miR-25-3p,AGUCUGGCUCUGUUCACGUUAC,ENSG00000009413,GACATTTTTGTTACAAACCTGTGGGCCTGTTGCAATACTT,1
2,2,hsa-miR-181b-5p,UGGGUGGCUGUCGUUACUUACAA,ENSG00000171862,GTGAAGGTCTGAATGAGGGTTTTGATTTTGAATGTTTCAA,1
3,3,hsa-let-7b-5p,UUGGUGUGUUGGAUGAUGGAGU,ENSG00000211460,AGAACATTTTGGTACAGTAAAAACACATCTAACATCTTTG,1
4,4,hsa-miR-29a-3p,AUUGGCUAAAGUCUACCACGAU,ENSG00000100483,CGGATGGAATTCTGGTATTTATAGGCATTGGTGCTAGATG,1


In [26]:
# Tokenise the miRNA column with the configured word size (the same setting the training loop uses).
# NOTE(review): the stored output below prints each sequence before its words, which the
# current get_sentence/get_sentences do not do -- it appears stale; re-run to refresh.
get_sentences(df[mirna_key], word_size)

UCGUUUUUACACGAUCACGGUUU
['UCG', 'UUU', 'UUA', 'CAC', 'GAU', 'CAC', 'GGU', 'UU']
AGUCUGGCUCUGUUCACGUUAC
['AGU', 'CUG', 'GCU', 'CUG', 'UUC', 'ACG', 'UUA', 'C']
UGGGUGGCUGUCGUUACUUACAA
['UGG', 'GUG', 'GCU', 'GUC', 'GUU', 'ACU', 'UAC', 'AA']
UUGGUGUGUUGGAUGAUGGAGU
['UUG', 'GUG', 'UGU', 'UGG', 'AUG', 'AUG', 'GAG', 'U']
AUUGGCUAAAGUCUACCACGAU
['AUU', 'GGC', 'UAA', 'AGU', 'CUA', 'CCA', 'CGA', 'U']
AAGUUUUGUACUUAACGACGAC
['AAG', 'UUU', 'UGU', 'ACU', 'UAA', 'CGA', 'CGA', 'C']
AGUCCUUGACGGAAAGAGAGGU
['AGU', 'CCU', 'UGA', 'CGG', 'AAA', 'GAG', 'AGG', 'U']
UUGGUGUGUUGGAUGAUGGAGU
['UUG', 'GUG', 'UGU', 'UGG', 'AUG', 'AUG', 'GAG', 'U']
AAGUUUUGUACUUAACGACGAC
['AAG', 'UUU', 'UGU', 'ACU', 'UAA', 'CGA', 'CGA', 'C']
UUGGUAUGUUGGAUGAUGGAGU
['UUG', 'GUA', 'UGU', 'UGG', 'AUG', 'AUG', 'GAG', 'U']
AGUCUGGCUCUGUUCACGUUAC
['AGU', 'CUG', 'GCU', 'CUG', 'UUC', 'ACG', 'UUA', 'C']
UGCCAAAAUGGUCUGUCAUAAU
['UGC', 'CAA', 'AAU', 'GGU', 'CUG', 'UCA', 'UAA', 'U']
UCGAUACGGUCGUAGAACGGA
['UCG', 'AUA', 'CGG', 'UCG', 'UAG', 'A