In [2]:
import torch
from transformers import AutoTokenizer, AutoModel, BertConfig
import pandas as pd

In [None]:
config = BertConfig.from_pretrained("zhihan1996/DNABERT-2-117M")
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
model = AutoModel.from_config(config)

In [33]:
# Read in the new sequence data and run head to see the overall data structure
data_in = pd.read_csv("data/sequence-wide.tsv", sep='\t') 
data_in.head

<bound method NDFrame.head of                       genus       species  \
0               Alitibacter   langaaensis   
1               Alitibacter   langaaensis   
2               Roseovarius     maritimus   
3               Roseovarius        roseus   
4           Planosporangium      spinosum   
...                     ...           ...   
27727     Thermoclostridium  stercorarium   
27728           Clostridium      isatidis   
27729         Couchioplanes     caeruleus   
27730             Halomonas     koreensis   
27731  Pseudoflavonifractor   phocaeensis   

                                                sequence   identifier  \
0      ATTGAAGAGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGGCT...  NR_118751.1   
1      ATTGAACGCTGGCGGCAGGCTTAACACATGCAAGTCGAACGGTAAC...  NR_042885.1   
2      CAACTTGAGAGTTTGATCCTGGCTCAGAACGAACGCTGGCGGCAGG...  NR_200035.1   
3      CAACTTGAGAGTTTGATCCTGGCTCAGAACGAACGCTGGCGGCAGG...  NR_200034.1   
4      TTGTTGGAGAGTTTGATCCTGGCTCAGGACGAACGCTGGCGGCGTG...  NR_200

In [34]:
# Check the columns
print("Column 1:")
print(data_in.iloc[:, 0])
print("Column 2:")
print(data_in.iloc[:, 1])
print("Column 3:")
print(data_in.iloc[:, 2])

Column 1:
0                 Alitibacter
1                 Alitibacter
2                 Roseovarius
3                 Roseovarius
4             Planosporangium
                 ...         
27727       Thermoclostridium
27728             Clostridium
27729           Couchioplanes
27730               Halomonas
27731    Pseudoflavonifractor
Name: genus, Length: 27732, dtype: object
Column 2:
0         langaaensis
1         langaaensis
2           maritimus
3              roseus
4            spinosum
             ...     
27727    stercorarium
27728        isatidis
27729       caeruleus
27730       koreensis
27731     phocaeensis
Name: species, Length: 27732, dtype: object
Column 3:
0        ATTGAAGAGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGGCT...
1        ATTGAACGCTGGCGGCAGGCTTAACACATGCAAGTCGAACGGTAAC...
2        CAACTTGAGAGTTTGATCCTGGCTCAGAACGAACGCTGGCGGCAGG...
3        CAACTTGAGAGTTTGATCCTGGCTCAGAACGAACGCTGGCGGCAGG...
4        TTGTTGGAGAGTTTGATCCTGGCTCAGGACGAACGCTGGCGGCGTG...
                 

In [35]:
data_in.sequence

0        ATTGAAGAGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGGCT...
1        ATTGAACGCTGGCGGCAGGCTTAACACATGCAAGTCGAACGGTAAC...
2        CAACTTGAGAGTTTGATCCTGGCTCAGAACGAACGCTGGCGGCAGG...
3        CAACTTGAGAGTTTGATCCTGGCTCAGAACGAACGCTGGCGGCAGG...
4        TTGTTGGAGAGTTTGATCCTGGCTCAGGACGAACGCTGGCGGCGTG...
                               ...                        
27727    TGATCCTGGCTCAGGACGAACGCTGGCGGCGTGCCTAACACATGCA...
27728    GGCGTGCNTAACACATGCAAGTCGAGCGAGGTGATTTCNTTCGGGA...
27729    CGCTGGCGGCGTGCTTAACACATGCAAGTCGAGCGGAAAGGCCCCT...
27730    ACGATGGGAGCTTGCTCCCAGGCGTCGAGCGGCGGACGGGTGAGTA...
27731    AGAGTTTGATCCTGGCTCAGGATGAACGCTGGCGGCGTRCTTAACA...
Name: sequence, Length: 27732, dtype: object

In [32]:
# Convert the new sequences to input tokens, view the first element
inputs = tokenizer(data_in.sequence.to_list())["input_ids"]
print(inputs[0])

[1, 2061, 25, 222, 23, 224, 143, 3411, 403, 247, 53, 150, 527, 2759, 2834, 2734, 724, 873, 81, 118, 2470, 30, 708, 72, 61, 679, 29, 200, 88, 2894, 71, 117, 639, 72, 1478, 315, 137, 2787, 825, 1826, 966, 189, 1235, 45, 229, 4079, 314, 1340, 835, 427, 138, 316, 99, 120, 2139, 76, 36, 987, 75, 315, 8, 0, 41, 199, 0, 0, 5, 778, 460, 632, 59, 100, 72, 26, 0, 1212, 1527, 71, 148, 281, 0, 9, 3558, 238, 92, 635, 59, 111, 556, 2787, 135, 52, 259, 64, 72, 31, 120, 469, 2816, 50, 2638, 166, 29, 135, 1262, 31, 141, 17, 495, 1170, 317, 32, 443, 79, 78, 30, 619, 36, 247, 137, 1517, 2810, 19, 153, 1826, 20, 277, 1080, 332, 159, 15, 583, 458, 61, 783, 18, 486, 17, 540, 29, 200, 14, 183, 22, 236, 168, 37, 282, 3453, 71, 7, 0, 3386, 34, 123, 315, 103, 265, 194, 50, 42, 534, 171, 259, 166, 112, 2394, 200, 106, 59, 118, 1952, 409, 577, 117, 124, 1832, 0, 113, 205, 35, 553, 403, 38, 499, 16, 605, 788, 212, 8, 0, 0, 9, 3382, 169, 194, 233, 368, 38, 2785, 1149, 282, 1435, 66, 101, 39, 386, 8, 0, 9, 10, 846, 

In [None]:
hidden_states = model(inputs)[0] # [1, sequence_length, 768]

# embedding with mean pooling
embedding_mean = torch.mean(hidden_states[0], dim=0)
print(embedding_mean.shape) # expect to be 768

# embedding with max pooling
embedding_max = torch.max(hidden_states[0], dim=0)[0]
print(embedding_max.shape) # expect to be 768