In [None]:
from transformers import BertForTokenClassification, BertModel
from data_dir import pretrained_3kmer_dir

bertForTokenClassification = BertForTokenClassification.from_pretrained(pretrained_3kmer_dir)

In [4]:
from data_dir import chr24_index_csv, chr24_fasta, labseq_dir, labseq_names
chr_indices = [chr24_index_csv]
chr_fastas = [chr24_fasta]
chr_labseq_path = ["{}/{}".format(labseq_dir, fname) for fname in [labseq_names[-1]]]
print(chr_indices)
print(chr_fastas)
print(chr_labseq_path)

['./data/genome/grch38/exon/NC_000024.10.csv']
['./data/chr/NC_000024.10.fasta']
['./data/genome/labseq/chr24.csv']


In [5]:
from data_dir import chr24_index_csv, chr24_fasta, labseq_dir, labseq_names
from data_preparation import generate_sequence_labelling
chr_indices = [chr24_index_csv]
chr_fastas = [chr24_fasta]
chr_labseq_path = ["{}/{}".format(labseq_dir, fname) for fname in [labseq_names[-1]]]
for src, fasta, target in zip(chr_indices, chr_fastas, chr_labseq_path):
    print("Generating sequential labelling for index {}, from fasta {}, to {}: {}".format(src, fasta, target, generate_sequence_labelling(src, fasta, target, do_expand=True, expand_size=512)))

Processing index ./data/genome/grch38/exon/NC_000024.10.csv, with fasta ./data/chr/NC_000024.10.fasta, to seq. labelling ./data/genome/labseq/chr24.csv, expanding [5431760/57226904]

In [46]:
from transformers import BertTokenizer
from data_dir import pretrained_3kmer_dir
from sequential_labelling import preprocessing, initialize_seq2seq

"""
Initialize model and tokenizer.
"""
tokenizer = BertTokenizer.from_pretrained(pretrained_3kmer_dir)
in_out_dimensions = [768, 512, 512, 10]
model = initialize_seq2seq(pretrained_3kmer_dir, in_out_dimensions)
#print(model)

"""
Create sample data sequential labelling.
"""
from random import randint
from data_preparation import kmer
from sequential_labelling import process_sequence_and_label, create_dataloader
sequences = ['ATGC' * 128, 'TGAC' * 128, 'GATC' * 128, "AGCC" * 128]
labels = [['E' if randint(0, 255) % 2 == 0 else '.' for i in range(len(s))] for s in sequences]

kmer_seq = [' '.join(kmer(sequence, 3)) for sequence in sequences]
kmer_label = [' '.join(kmer(''.join(label), 3)) for label in labels]

arr_input_ids = []
arr_attn_mask = []
arr_label_repr = []
for seq, label in zip(kmer_seq, kmer_label):
    input_ids, attn_mask, label_repr = process_sequence_and_label(seq, label, tokenizer)
    arr_input_ids.append(input_ids)
    arr_attn_mask.append(attn_mask)
    arr_label_repr.append(label_repr)

dataloader = create_dataloader(arr_input_ids, arr_attn_mask, arr_label_repr, batch_size=2)

In [79]:
"""
Play with result.
"""
from tqdm import tqdm
from torch import nn
loss_fn = nn.NLLLoss()
activation_fn = nn.Softmax(dim=2)
linear = nn.Linear(in_features=768, out_features=10)
for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
    input_ids, attn_mask, label_repr = tuple(t for t in batch)
    output = model(input_ids, attn_mask)
    bert_output = model.bert(input_ids, attn_mask)
    bert_output = bert_output[0] # tensor.Size([batch_size, seq_length, dim])
    linear_output = linear(bert_output) # tensor.Size([batch_size, seq_length, dim])
    activation = activation_fn(linear_output)
    print("activation", activation.shape, activation)
    print("label repr", label_repr.shape, label_repr)
    loss = loss_fn(activation, label_repr)



  0%|          | 0/2 [00:03<?, ?it/s]

activation torch.Size([2, 512, 10]) tensor([[[0.0335, 0.1366, 0.2869,  ..., 0.0582, 0.0439, 0.0373],
         [0.0534, 0.0582, 0.1790,  ..., 0.0410, 0.0642, 0.1359],
         [0.1073, 0.1038, 0.1294,  ..., 0.0689, 0.2670, 0.0596],
         ...,
         [0.0502, 0.0610, 0.1907,  ..., 0.0427, 0.0626, 0.1263],
         [0.1091, 0.1058, 0.1277,  ..., 0.0690, 0.2639, 0.0587],
         [0.0335, 0.1366, 0.2869,  ..., 0.0582, 0.0439, 0.0373]],

        [[0.0247, 0.0701, 0.1186,  ..., 0.0540, 0.0300, 0.0743],
         [0.0973, 0.0563, 0.1604,  ..., 0.1201, 0.1110, 0.0659],
         [0.0378, 0.0617, 0.0935,  ..., 0.1416, 0.0607, 0.1822],
         ...,
         [0.0957, 0.0594, 0.1599,  ..., 0.1230, 0.1072, 0.0659],
         [0.0368, 0.0624, 0.0940,  ..., 0.1403, 0.0591, 0.1861],
         [0.0247, 0.0701, 0.1186,  ..., 0.0540, 0.0300, 0.0743]]],
       grad_fn=<SoftmaxBackward0>)
label repr torch.Size([2, 512]) tensor([[0, 1, 2,  ..., 7, 5, 9],
        [0, 7, 5,  ..., 7, 5, 9]])





RuntimeError: Expected target size [2, 10], got [2, 512]

In [None]:
    print('linear output {}'.format(linear_output.shape))
    linear_output_permute = linear_output.permute(0, 2, 1) # tensor.Size([batch_size, dim, seq_length])
    print('linear output permute {}'.format(linear_output_permute.shape))
    print('label repr {}'.format(label_repr.shape))
    activation = activation_fn(linear_output_permute)
    print('activation {}'.format(activation.shape))
    loss = loss_fn(linear_output_permute, label_repr)
    print('loss {}'.format(linear_output_permute))


In [85]:
import torch
loss = nn.CrossEntropyLoss()
activate = nn.Softmax(dim=2)
input = torch.randn(3, 5, 10, requires_grad=True, dtype=torch.float)
input = activate(input)
target = torch.empty(3, 5, 10, dtype=torch.float).random_(5)
output = loss(input, target)
print(input.shape, input)
print(target.shape, target)

torch.Size([3, 5, 10]) tensor([[[0.0761, 0.0309, 0.1782, 0.3150, 0.0041, 0.0667, 0.0346, 0.0324,
          0.1923, 0.0697],
         [0.0491, 0.0709, 0.1599, 0.1308, 0.0236, 0.0327, 0.2852, 0.0152,
          0.0579, 0.1746],
         [0.1077, 0.0813, 0.0841, 0.0973, 0.0452, 0.0425, 0.0752, 0.0313,
          0.1997, 0.2358],
         [0.1600, 0.1957, 0.1299, 0.1179, 0.0709, 0.0377, 0.1201, 0.0344,
          0.0467, 0.0868],
         [0.1917, 0.0216, 0.0058, 0.1390, 0.0640, 0.0325, 0.2813, 0.0407,
          0.0276, 0.1961]],

        [[0.0189, 0.2757, 0.0623, 0.0221, 0.0229, 0.1040, 0.0658, 0.0579,
          0.0653, 0.3052],
         [0.0463, 0.0238, 0.0840, 0.4257, 0.0258, 0.0374, 0.1712, 0.0945,
          0.0349, 0.0564],
         [0.0105, 0.0182, 0.2384, 0.0794, 0.0994, 0.2283, 0.1128, 0.1145,
          0.0499, 0.0485],
         [0.0164, 0.0236, 0.1261, 0.0135, 0.1104, 0.3142, 0.1054, 0.0131,
          0.1640, 0.1134],
         [0.0067, 0.0832, 0.1345, 0.1329, 0.1143, 0.0975, 0.0780, 