[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mitiau/DNABERT-Z/blob/main/ZDNA-prediction.ipynb)

# Install dependencies and define helper functions

In [1]:
!pip install transformers
!pip install biopython

Collecting biopython
  Downloading biopython-1.83-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: biopython
Successfully installed biopython-1.83


In [2]:
import torch
from torch import nn
import transformers
from transformers import BertTokenizer, BertForTokenClassification
import numpy as np
from Bio import SeqIO
from io import StringIO, BytesIO
from google.colab import drive, files
from tqdm import tqdm
import pickle
import scipy
from scipy import ndimage

In [3]:
def seq2kmer(seq, k):
    """Return every length-``k`` window of ``seq``, sliding one position at a time.

    Args:
        seq: Any sliceable sequence (a string in this notebook).
        k: Window size.

    Returns:
        A list of ``len(seq) + 1 - k`` slices; empty when ``seq`` is shorter
        than ``k``.
    """
    n_windows = len(seq) + 1 - k
    kmers = []
    for start in range(n_windows):
        kmers.append(seq[start:start + k])
    return kmers

def split_seq(seq, length = 512, pad = 16):
    """Split ``seq`` into windows of at most ``length`` items that overlap by ``pad``.

    Consecutive windows start ``length - pad`` items apart, so each window
    begins with the last ``pad`` items of the previous one. Splitting stops as
    soon as a window reaches the end of ``seq``; every window after the first
    therefore contributes at least one item beyond the overlap.

    Args:
        seq: Any sliceable sequence (a list of k-mers in this notebook).
        length: Maximum number of items per window.
        pad: Number of items shared by two neighbouring windows.

    Returns:
        A list of slices of ``seq``; empty when ``seq`` is empty.

    Raises:
        ValueError: If ``length`` is not greater than ``pad`` (the stride
            between windows would not be positive).
    """
    if length <= pad:
        raise ValueError(
            "length (%r) must be greater than pad (%r)" % (length, pad))
    res = []
    for st in range(0, len(seq), length - pad):
        # Honour the `length` argument instead of a hard-coded window size.
        end = min(st + length, len(seq))
        res.append(seq[st:end])
        if end == len(seq):
            # The window already covers the tail. A further window would lie
            # entirely inside this one's overlap region and would corrupt the
            # result when the pieces are stitched back together.
            break
    return res

def stitch_np_seq(np_seqs, pad = 16):
    """Concatenate overlapping 1-D arrays into one array, dropping duplicated overlap.

    Each piece is expected to start with the items that end the previous
    piece. In the overlap the values of the *later* piece are kept.

    Args:
        np_seqs: Iterable of 1-D arrays, in order.
        pad: Number of items shared by two neighbouring pieces.

    Returns:
        A 1-D array. An empty float64 array is returned for empty input, and
        the result dtype is promoted against float64.
    """
    pieces = [np.asarray(piece) for piece in np_seqs]
    # Seeding with an empty float64 array keeps the dtype promotion the
    # accumulate-and-concatenate formulation had.
    kept = [np.array([])]
    last_index = len(pieces) - 1
    for i, piece in enumerate(pieces):
        if i < last_index:
            # A following piece shorter than `pad` can only overlap by its own
            # length; trimming a full `pad` would discard real data.
            overlap = min(pad, len(pieces[i + 1]))
            # Explicit end index: `piece[:-overlap]` would be empty for overlap == 0.
            piece = piece[:max(len(piece) - overlap, 0)]
        kept.append(piece)
    # Single concatenation instead of re-copying the accumulator per piece.
    return np.concatenate(kept)

# Select model and parameters

In [4]:
# Name of the pretrained model to use; the next cell maps it to a download id.
# NOTE: the name `model` is rebound to the loaded network further down, so
# re-run this cell before re-running the model selection cell.
model = 'HG kouzine' #@param ["HG chipseq", "HG kouzine", "MM chipseq", "MM kouzine"]
# Score cutoff: positions whose predicted score is above this value are grouped
# into candidate regions in the prediction cell.
model_confidence_threshold = 0.5 #@param {type:"number"}
# Length filter: the prediction cell reports only regions with strictly more
# consecutive positions than this value.
minimum_sequence_length = 10 #@param {type:"integer"}

In [5]:
# Google Drive file id of the weights for each selectable model name.
# 'MM chipseq' is the name offered by the parameter form; 'MM curax' is kept
# as an alias so that previously used values continue to work.
MODEL_IDS = {
    'HG chipseq': '1VAsp8I904y_J0PUhAQqpSlCn1IqfG0FB',
    'HG kouzine': '1dAeAt5Gu2cadwDhbc7OnenUgDLHlUvkx',
    'MM chipseq': '1W6GEgHNoitlB-xXJbLJ_jDW4BF35W1Sd',
    'MM curax': '1W6GEgHNoitlB-xXJbLJ_jDW4BF35W1Sd',
    'MM kouzine': '1dXpQFmheClKXIEoqcZ7kgCwx6hzVCv3H',
}

if not isinstance(model, str):
    # `model` is rebound to the loaded network later in the notebook; without
    # this check a re-run would silently keep a stale `model_id`.
    raise TypeError(
        "`model` is no longer a model name; re-run the parameter cell first")
if model not in MODEL_IDS:
    # Fail here rather than with a NameError in the download cell.
    raise ValueError(
        "Unknown model %r; expected one of %s" % (model, sorted(MODEL_IDS)))
model_id = MODEL_IDS[model]


In [6]:
!gdown $model_id
!gdown 10sF8Ywktd96HqAL0CwvlZZUUGj05CGk5
!gdown 16bT7HDv71aRwyh3gBUbKwign1mtyLD2d
!gdown 1EE9goZ2JRSD8UTx501q71lGCk-CK3kqG
!gdown 1gZZdtAoDnDiLQqjQfGyuwt268Pe5sXW0


!mkdir 6-new-12w-0
!mv pytorch_model.bin 6-new-12w-0/
!mv config.json 6-new-12w-0/
!mv special_tokens_map.json 6-new-12w-0/
!mv tokenizer_config.json 6-new-12w-0/
!mv vocab.txt 6-new-12w-0/

Downloading...
From (original): https://drive.google.com/uc?id=1dAeAt5Gu2cadwDhbc7OnenUgDLHlUvkx
From (redirected): https://drive.google.com/uc?id=1dAeAt5Gu2cadwDhbc7OnenUgDLHlUvkx&confirm=t&uuid=31c672f1-c6b2-4a96-9fcb-2e9d659b8aa4
To: /content/pytorch_model.bin
100% 354M/354M [00:03<00:00, 110MB/s] 
Downloading...
From: https://drive.google.com/uc?id=10sF8Ywktd96HqAL0CwvlZZUUGj05CGk5
To: /content/config.json
100% 634/634 [00:00<00:00, 2.49MB/s]
Downloading...
From: https://drive.google.com/uc?id=16bT7HDv71aRwyh3gBUbKwign1mtyLD2d
To: /content/special_tokens_map.json
100% 112/112 [00:00<00:00, 358kB/s]
Downloading...
From: https://drive.google.com/uc?id=1EE9goZ2JRSD8UTx501q71lGCk-CK3kqG
To: /content/tokenizer_config.json
100% 40.0/40.0 [00:00<00:00, 158kB/s]
Downloading...
From: https://drive.google.com/uc?id=1gZZdtAoDnDiLQqjQfGyuwt268Pe5sXW0
To: /content/vocab.txt
100% 28.7k/28.7k [00:00<00:00, 19.3MB/s]


In [7]:
# Load the tokenizer and the token-classification model from the downloaded files.
tokenizer = BertTokenizer.from_pretrained('6-new-12w-0/')
model = BertForTokenClassification.from_pretrained('6-new-12w-0/')
# Use the GPU when one is available, otherwise fall back to the CPU instead of
# failing on an unconditional `.cuda()`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Inference only: make evaluation mode (dropout disabled) explicit.
# The call returns the model, so the cell still displays its structure.
model.to(device).eval()

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(4101, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

# Upload fasta files for prediction

In [8]:
# Ask the user for files and report the name and size of each one received.
uploaded = files.upload()
for fn, payload in uploaded.items():
    message = 'User uploaded file "{name}" with length {length} bytes'.format(
        name=fn, length=len(payload))
    print(message)

Saving cutted_genome (2).fna to cutted_genome (2).fna
User uploaded file "cutted_genome (2).fna" with length 6966709 bytes


# Predict and save raw outputs

In [9]:
# For every uploaded FASTA file: score each record with the model, report the
# regions whose score is above the confidence threshold, pickle the raw
# per-position scores, and finally write all reported lines to a text file.
out = []
# Send inputs to whichever device the model lives on rather than assuming CUDA.
device = next(model.parameters()).device
for key in uploaded.keys():
    print(key)
    out.append(key)
    result_dict = {}
    fasta_text = uploaded[key].decode('UTF-8')
    for seq_record in SeqIO.parse(StringIO(fasta_text), 'fasta'):
        # Positions below are indices into the list of overlapping 6-mers.
        kmer_seq = seq2kmer(str(seq_record.seq).upper(), 6)
        seq_pieces = split_seq(kmer_seq)
        print(seq_record.name)
        out.append(seq_record.name)
        preds = []
        with torch.no_grad():
            for seq_piece in tqdm(seq_pieces):
                input_ids = torch.LongTensor(
                    tokenizer.encode(' '.join(seq_piece), add_special_tokens=False))
                # Index 0 is the logits; the last element is only the logits
                # when the model returns nothing else (e.g. hidden states).
                logits = model(input_ids.to(device).unsqueeze(0))[0]
                # Probability of class 1 for every token of the single batch item.
                probs = torch.softmax(logits, dim=-1)[0, :, 1]
                preds.append(probs.cpu().numpy())
        result_dict[seq_record.name] = stitch_np_seq(preds)

        # Group consecutive above-threshold positions into labelled regions.
        labeled = scipy.ndimage.label(
            result_dict[seq_record.name] > model_confidence_threshold)[0]
        header = '  start     end'
        print(header)
        out.append(header)
        # find_objects yields one slice per region, avoiding a full-array scan
        # for every label.
        for region in scipy.ndimage.find_objects(labeled):
            first = region[0].start
            last = region[0].stop - 1
            candidate_length = last - first + 1
            # NOTE(review): strict '>' means a region of exactly
            # minimum_sequence_length positions is not reported - confirm intent.
            if candidate_length > minimum_sequence_length:
                # One string for screen and file, so the file keeps the column
                # separator that the printed output has.
                line = '{:8}'.format(first) + ' ' + '{:8}'.format(last)
                print(line)
                out.append(line)

    with open(key + '.preds.pkl', "wb") as fh:
        pickle.dump(result_dict, fh)
    print()

with open('text_predictions.txt', "w") as fh:
    for item in out:
        fh.write("%s\n" % item)


cutted_genome (2).fna
CAJNDS010000402.1


100%|██████████| 1935/1935 [01:18<00:00, 24.59it/s]


  start     end
     836      857
     878      907
    1870     1888
    2317     2335
    2964     2975
    4087     4102
    4576     4596
    5239     5257
    7100     7112
    8599     8614
   10475    10490
   11156    11170
   13895    13906
   14327    14342
   15734    15748
   16401    16416
   17114    17129
   17402    17413
   31511    31524
   32632    32643
   33965    33977
   34170    34184
   35355    35368
   36811    36829
   39327    39342
   39867    39882
   41817    41833
   42493    42507
   43126    43138
   43249    43261
   46358    46376
   48340    48357
   50711    50723
   51301    51320
   53164    53181
   53744    53759
   54932    54948
   56508    56523
   59413    59434
   60301    60313
   61120    61137
   61569    61581
   62546    62561
   62565    62602
   63144    63159
   63723    63734
   66148    66162
   66635    66651
   68911    68926
   71224    71236
   71308    71324
   71579    71592
   73943    73957
   74260    74273
   74579    

100%|██████████| 280/280 [00:11<00:00, 24.14it/s]


  start     end
    4998     5012
    6418     6435
    9897     9919
   10170    10191
   10868    10880
   13091    13105
   18430    18449
   20149    20164
   20501    20514
   20671    20696
   20795    20808
   24306    24342
   24388    24399
   24551    24565
   25011    25026
   25131    25148
   25629    25644
   25919    25933
   26683    26701
   27408    27422
   28370    28383
   35272    35284
   37298    37312
   38673    38687
   40441    40454
   41743    41766
   43610    43624
   43801    43818
   44865    44879
   49215    49225
   53484    53498
   53716    53728
   55715    55729
   58244    58255
   60421    60438
   62976    62991
   72026    72042
   72842    72860
   77115    77131
   78848    78862
   82687    82699
   83452    83467
   83804    83818
   83946    83959
   87218    87252
   88452    88471
   88865    88884
   88927    88943
   94324    94340
   95549    95562
   95695    95711
   96891    96906
   96935    96953
   97583    97598
   98970    

100%|██████████| 1533/1533 [01:01<00:00, 24.97it/s]


  start     end
    2148     2168
    2813     2829
    7392     7410
    8728     8741
   10810    10832
   10854    10870
   14312    14323
   14694    14711
   16899    16918
   19212    19224
   23862    23876
   27395    27410
   28746    28763
   30455    30468
   41331    41344
   43990    44006
   44755    44774
   45430    45451
   45984    46002
   46802    46814
   47616    47628
   49084    49094
   52633    52648
   53183    53195
   53214    53228
   57049    57064
   60607    60626
   60685    60700
   61859    61869
   62966    62981
   63333    63346
   64045    64059
   65665    65678
   66637    66657
   67159    67173
   69280    69295
   70943    70963
   71076    71113
   71426    71451
   75756    75772
   76357    76375
   77257    77274
   77347    77364
   77486    77499
   77542    77564
   77957    77975
   78624    78640
   79358    79372
   79882    79894
   80290    80305
   80413    80429
   80848    80861
   80883    80899
   81829    81839
   81896    

100%|██████████| 1326/1326 [00:53<00:00, 24.63it/s]


  start     end
    1181     1192
    2881     2893
    3443     3456
    6620     6637
    7712     7729
    9556     9574
   11117    11138
   11586    11603
   18271    18289
   18543    18558
   22656    22676
   29072    29094
   30050    30060
   33385    33402
   37548    37567
   38480    38496
   41679    41708
   41765    41781
   45559    45576
   46398    46410
   46983    46994
   54393    54408
   57380    57397
   57850    57860
   61986    61998
   62054    62078
   63470    63487
   64610    64624
   65278    65290
   65298    65313
   72599    72615
   73291    73308
   80824    80840
   81775    81808
   81813    81839
   83257    83268
   84090    84105
   84245    84260
   84937    84947
   98245    98260
  102968   102980
  104466   104480
  109100   109114
  110819   110832
  111116   111131
  122209   122225
  123910   123925
  124218   124228
  124238   124270
  128048   128062
  128733   128748
  129243   129259
  131508   131521
  135775   135789
  136897   1

100%|██████████| 1086/1086 [00:43<00:00, 24.90it/s]


  start     end
    3584     3599
    4674     4691
    6627     6642
    8614     8625
    9574     9590
    9983     9999
   10181    10198
   10881    10899
   11502    11518
   11640    11655
   12613    12630
   13119    13136
   13219    13264
   14065    14083
   15472    15486
   15615    15635
   15712    15731
   15860    15871
   15929    15950
   16413    16427
   16822    16841
   16985    17002
   17276    17289
   17294    17311
   17327    17343
   17463    17480
   17572    17588
   17923    17942
   17949    17967
   18062    18087
   19613    19627
   19963    19976
   20004    20020
   21236    21252
   21505    21520
   21906    21918
   23347    23363
   24147    24162
   25769    25814
   25831    25846
   26538    26550
   26894    26905
   27300    27316
   27755    27767
   27778    27794
   28436    28454
   29353    29367
   31287    31304
   32957    32972
   33473    33493
   33631    33649
   34077    34091
   34861    34873
   36037    36052
   37205    

100%|██████████| 4575/4575 [03:03<00:00, 24.98it/s]


  start     end
    1640     1652
    2106     2120
    3422     3433
    4344     4361
    4922     4938
    5190     5207
    6151     6165
   10322    10342
   10367    10397
   10405    10415
   11863    11881
   11905    11919
   14057    14072
   15637    15647
   15935    15950
   17518    17534
   18483    18500
   19047    19061
   19445    19459
   20776    20793
   21435    21454
   23334    23352
   23430    23442
   23459    23474
   24184    24198
   27645    27676
   28339    28352
   28517    28533
   31834    31854
   31951    31965
   32208    32225
   32663    32678
   32831    32845
   33338    33355
   34034    34045
   34436    34452
   34958    34971
   35420    35436
   37677    37688
   38809    38828
   40357    40370
   40661    40671
   42216    42227
   43353    43375
   45534    45545
   45881    45893
   47551    47567
   47736    47751
   50328    50343
   51055    51070
   53060    53071
   54046    54060
   54821    54835
   55860    55873
   59609    

100%|██████████| 1000/1000 [00:40<00:00, 24.57it/s]


  start     end
      20       34
     910      924
    2180     2192
    2285     2300
    3355     3369
    3499     3517
    8509     8526
    9641     9657
    9867     9881
    9897     9919
   11849    11864
   16525    16538
   18020    18046
   18535    18546
   18693    18707
   20684    20699
   21701    21718
   22276    22289
   23625    23637
   24336    24349
   27572    27590
   28199    28214
   29787    29802
   30204    30218
   31530    31552
   31598    31612
   31839    31851
   32078    32093
   37453    37466
   37697    37713
   37738    37750
   39614    39628
   43210    43230
   44436    44455
   45742    45756
   47265    47280
   48113    48125
   52605    52619
   53300    53335
   54540    54551
   56922    56932
   57517    57529
   59004    59017
   59699    59712
   61315    61333
   62007    62020
   62166    62179
   65721    65733
   65819    65831
   67337    67354
   68275    68289
   69508    69530
   70885    70902
   71805    71821
   75952    

100%|██████████| 940/940 [00:37<00:00, 25.06it/s]


  start     end
     162      178
     806      824
    2536     2550
    4129     4149
    4904     4919
    5013     5023
    6235     6249
    6907     6918
    9824     9845
    9862     9876
   10047    10067
   10221    10234
   10283    10293
   10522    10534
   12533    12550
   12962    12972
   13718    13733
   18628    18644
   19401    19417
   19474    19486
   20897    20910
   22028    22040
   22979    22996
   23973    23988
   26477    26492
   28983    28998
   31153    31164
   32790    32808
   35089    35101
   35457    35469
   36163    36175
   37271    37292
   38442    38463
   38981    38996
   40745    40758
   42037    42055
   42141    42159
   43231    43248
   43347    43359
   44093    44110
   45069    45085
   46285    46298
   47400    47412
   48843    48867
   50902    50917
   50965    50983
   53507    53528
   53874    53887
   54373    54389
   56675    56696
   57042    57054
   57558    57573
   57615    57632
   59664    59680
   60600    

100%|██████████| 653/653 [00:26<00:00, 24.99it/s]


  start     end
    3165     3183
    3924     3941
    6911     6933
    9760     9771
    9953     9968
   10789    10803
   12505    12515
   18447    18461
   20808    20819
   23516    23530
   23987    24000
   24075    24092
   25561    25575
   25692    25705
   25836    25850
   26219    26233
   27228    27240
   28166    28180
   32266    32280
   32669    32681
   33796    33810
   35905    35917
   36479    36495
   38262    38281
   39339    39352
   40297    40310
   42168    42185
   42829    42844
   43458    43480
   43971    43988
   45838    45856
   46395    46409
   47392    47406
   51165    51178
   51937    51952
   52116    52127
   52274    52290
   52595    52609
   53008    53027
   54710    54725
   54771    54784
   57469    57483
   60182    60213
   62669    62683
   64663    64675
   65592    65606
   69760    69774
   73567    73585
   77440    77455
   77627    77641
   78009    78026
   80707    80719
   81455    81473
   85673    85693
   85863    

100%|██████████| 548/548 [00:22<00:00, 24.44it/s]


  start     end
    3125     3139
    4010     4026
    7212     7227
    9943     9960
   10499    10511
   10927    10947
   11557    11577
   14140    14158
   14485    14499
   15125    15142
   17825    17840
   19009    19024
   22395    22409
   23530    23546
   25176    25193
   25808    25828
   26891    26907
   27829    27845
   29277    29287
   30768    30779
   35602    35615
   36660    36671
   39232    39246
   39924    39942
   40181    40195
   40749    40762
   42006    42021
   42780    42802
   44167    44198
   45207    45219
   46641    46656
   47582    47599
   49310    49323
   51003    51020
   52338    52352
   55414    55438
   56370    56389
   61438    61454
   62687    62700
   67283    67298
   68513    68527
   69091    69108
   69139    69152
   69792    69810
   69858    69874
   70306    70324
   70773    70785
   70805    70818
   72326    72344
   72701    72719
   75446    75464
   76761    76775
   80818    80829
   82656    82670
   83621    

# Download text file with predictions

In [10]:
# Send the text report written by the prediction cell to the user's browser.
files.download('text_predictions.txt')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

# Download raw prediction files (pickled dictionaries of NumPy arrays)

In [11]:

# for key in uploaded.keys():
#     files.download(key + '.preds.pkl')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [12]:
# !grep -P '(G{3,5}[ATGC]{1,7}){3,}G{3,5}' cutted_genome.fna

grep: cutted_genome.fna: No such file or directory
