In [1]:
import torch
from transformers import BertModel, BertTokenizer
from Bio import SeqIO

In [4]:
def encode_sequences(sequences, tokenizer, model, pooling_strategy='mean'):
    encoded_sequences = []
    for seq in sequences:
        spaced_sequence = ' '.join(seq)
        inputs = tokenizer(spaced_sequence, padding=True, truncation=True, max_length=50, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        
        if pooling_strategy == 'mean':
            sequence_embedding = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
        elif pooling_strategy == 'max':
            sequence_embedding = torch.max(outputs.last_hidden_state, dim=1)[0].squeeze().numpy()
        elif pooling_strategy == 'cls':
            sequence_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
        elif pooling_strategy == 'sum':
            sequence_embedding = outputs.last_hidden_state.sum(dim=1).squeeze().numpy() 
        encoded_sequences.append(sequence_embedding)
    return encoded_sequences

In [5]:
positive_sequences = [str(record.seq) for record in SeqIO.parse("Put your positive sequence file path here", "fasta")]
negative_sequences = [str(record.seq) for record in SeqIO.parse("Put your negative sequence file path here", "fasta")]

In [6]:
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False)
model = BertModel.from_pretrained("Rostlab/prot_bert_bfd", output_attentions=True)

positive_encoded_mean = encode_sequences(positive_sequences, tokenizer, model, pooling_strategy='mean')
negative_encoded_mean = encode_sequences(negative_sequences, tokenizer, model, pooling_strategy='mean')

positive_encoded_max = encode_sequences(positive_sequences, tokenizer, model, pooling_strategy='max')
negative_encoded_max = encode_sequences(negative_sequences, tokenizer, model, pooling_strategy='max')

positive_encoded_cls = encode_sequences(positive_sequences, tokenizer, model, pooling_strategy='cls')
negative_encoded_cls = encode_sequences(negative_sequences, tokenizer, model, pooling_strategy='cls')

positive_encoded_sum = encode_sequences(positive_sequences, tokenizer, model, pooling_strategy='sum')
negative_encoded_sum = encode_sequences(negative_sequences, tokenizer, model, pooling_strategy='sum')

In [7]:
print(len(negative_sequences))
print(len(positive_sequences))

2311
2290


In [22]:
print(positive_encoded_sum[0])

[[ 0.03357462 -0.00529268  0.09248591 ... -0.10600295 -0.00761915
  -0.04234606]
 [-0.00605776  0.05155679  0.00612719 ... -0.04761077  0.01626376
   0.01347677]
 [ 0.04096205 -0.0606569   0.04992864 ... -0.05733338  0.01520739
   0.06641088]
 ...
 [ 0.13475624 -0.06733894  0.04205138 ...  0.03749225 -0.03350831
   0.05186091]
 [ 0.1554039  -0.03406589  0.01873373 ... -0.04375764 -0.117219
   0.09699213]
 [ 0.03420337 -0.00554545  0.09127251 ... -0.10602523 -0.00653328
  -0.04189716]]


In [9]:
with open("C:\\Users\\abu11\\Desktop\\BERT_Test_Pos_sum.csv", "w") as f:
    for seq in positive_encoded_sum:
        f.write(','.join(map(str, seq)) + '\n')

with open("C:\\Users\\abu11\\Desktop\\BERT_Test_Neg_sum.csv", "w") as f:
    for seq in negative_encoded_sum:
        f.write(','.join(map(str, seq)) + '\n')

In [52]:
with open("C:\\Users\\abu11\\Desktop\\BERT_Test_Pos_cls.csv", "w") as f:
    for seq in positive_encoded_cls:
        f.write(','.join(map(str, seq)) + '\n')

with open("C:\\Users\\abu11\\Desktop\\BERT_Test_Neg_cls.csv", "w") as f:
    for seq in negative_encoded_cls:
        f.write(','.join(map(str, seq)) + '\n')

In [21]:
import pandas as pd
df1 = pd.read_csv('C:\\Users\\abu11\\Desktop\\BERT_Test_Pos_attention.csv', delimiter=';')

df1.head()

Unnamed: 0,[ 0.03357462 -0.00529268 0.09248591 ... -0.10600295 -0.00761915
0,"-0.04234606],[-0.00605776 0.05155679 0.0061..."
1,"0.01347677],[ 0.04096205 -0.0606569 0.0499..."
2,"0.06641088],[ 0.01817202 -0.08332765 0.0611..."
3,"0.10475894],[ 0.02135178 -0.04715138 0.0698..."
4,"0.0268144 ],[ 0.04788483 -0.05477773 0.0821..."


In [12]:
df1.columns = [f'bert_sum{i}' for i in range(1, len(df1.columns) + 1)]

df1['target'] = 1

In [13]:
df2  = pd.read_csv('C:\\Users\\abu11\\Desktop\\BERT_Test_neg_sum.csv')
df2.head()

Unnamed: 0,0.6078501,-0.033819016,0.71843934,1.0952134,-0.05994433,-0.2627576,-0.70856726,-0.9782825,0.31817237,0.10805786,...,0.1961156,-0.024252236,-0.0052929334,0.48606575,1.4828134,-0.5066927,-0.22832696,-0.18453865,-0.59072304,0.6181028
0,0.465138,0.005496,0.264306,0.714771,0.23447,-0.478214,-0.652926,-0.626152,-0.502739,-0.236389,...,-0.059843,-0.063329,-0.141164,-0.352745,0.270981,-0.665013,0.199396,-0.084409,-0.058088,-0.174837
1,-0.127546,-0.489731,1.005401,0.512343,-0.350014,-0.449781,0.284251,-0.688972,0.738067,-0.60122,...,0.683462,0.441021,0.242982,-0.118614,0.994402,-0.825914,0.233739,-0.611015,0.486904,-0.039016
2,-0.142813,0.094975,0.709106,0.198099,0.14209,-1.05139,-0.387143,-0.695557,0.166491,0.286988,...,1.73816,0.904878,-0.774055,1.132149,1.446337,-1.393383,0.877534,-1.019527,-0.069766,0.780519
3,-1.975936,-0.486143,0.98355,1.244938,-2.983091,-0.794031,-0.390465,1.200614,-1.775833,3.329374,...,0.879556,0.970453,-2.028688,-0.910536,4.423152,0.277926,0.035409,-0.031013,-1.360497,0.414964
4,-1.962111,-0.445555,1.321101,2.66171,2.795088,-1.672955,1.078097,-2.972727,0.46092,1.968864,...,3.726685,-0.41002,-0.002791,2.117378,2.444093,-0.078362,-0.16663,-0.841132,-0.237218,0.511932


In [14]:
df2.columns = [f'bert_sum{i}' for i in range(1, len(df2.columns) + 1)]

df2['target'] = 0

In [15]:
concatenated_df = pd.concat([df1, df2], ignore_index=True)

concatenated_df.to_csv("C:\\Users\\abu11\\Desktop\\V_train_bert_sum.csv", index=False)

In [16]:
data  = pd.read_csv('C:\\Users\\abu11\\Desktop\\V_train_bert_sum.csv')
data.head()

Unnamed: 0,bert_sum1,bert_sum2,bert_sum3,bert_sum4,bert_sum5,bert_sum6,bert_sum7,bert_sum8,bert_sum9,bert_sum10,...,bert_sum1016,bert_sum1017,bert_sum1018,bert_sum1019,bert_sum1020,bert_sum1021,bert_sum1022,bert_sum1023,bert_sum1024,target
0,0.664481,-0.83324,0.949385,0.709222,0.796848,-3.534069,-0.662283,-2.118724,2.463265,-2.550574,...,-1.012768,-2.756887,-1.212176,3.672942,1.384616,0.598262,-1.745691,-0.153047,-0.009458,1
1,-0.650823,-0.273809,3.292779,0.909354,-1.355576,-0.737463,-2.499528,-1.206802,1.401714,1.191815,...,-0.820498,-0.699072,0.350797,3.1201,-1.41197,-0.847737,-1.305075,0.850361,0.468024,1
2,1.809837,-0.896181,1.435865,1.259906,-0.813909,-0.72648,0.321494,-0.918681,4.296026,-1.264895,...,0.150113,1.064861,-0.382289,2.081755,-0.581359,-0.791542,-2.098006,-1.009836,-0.814865,1
3,3.211404,-1.225885,1.490204,-0.05799,-1.634694,-0.737129,-1.577563,1.114908,1.029719,-0.75118,...,-2.034052,-1.425488,-0.329976,3.36468,0.276424,-0.409796,-0.636633,-0.109022,0.061589,1
4,0.933543,-0.420312,-0.317944,-0.458164,-0.173155,-0.707047,-0.7273,-0.153099,1.46309,-1.216267,...,-0.039619,-1.674797,0.169426,0.277445,-1.035159,-0.15528,-0.057869,-0.49371,0.381929,1


In [17]:
data.tail()

Unnamed: 0,bert_sum1,bert_sum2,bert_sum3,bert_sum4,bert_sum5,bert_sum6,bert_sum7,bert_sum8,bert_sum9,bert_sum10,...,bert_sum1016,bert_sum1017,bert_sum1018,bert_sum1019,bert_sum1020,bert_sum1021,bert_sum1022,bert_sum1023,bert_sum1024,target
4594,1.585739,-0.36317,1.587506,1.786021,-1.662963,-0.002691,-0.243363,-1.11092,2.573396,-0.492709,...,-0.698032,-0.413882,0.372741,2.623466,-1.940951,-0.118668,-0.964184,-0.879433,-0.748614,0
4595,-0.432391,-1.498025,-1.84751,0.820211,1.968133,0.472677,-1.654621,-1.007794,2.028036,0.777505,...,0.074654,1.927282,1.644241,-0.073696,0.742614,-2.158815,0.699135,-2.074988,-0.20572,0
4596,-0.273337,-0.706252,2.192671,-0.569836,0.478056,-0.448586,0.598712,0.45023,1.990354,1.847855,...,-1.481604,0.397334,-0.141435,3.663385,-1.701829,-2.718751,-0.232069,-0.156445,0.675764,0
4597,0.473447,-0.095857,1.010918,-0.189226,0.053869,-0.046097,-0.618201,-0.468743,1.017627,0.66962,...,-0.561635,0.544258,0.237207,0.976855,-1.063188,0.069729,-0.531466,-0.097044,0.311651,0
4598,1.0591,-0.299554,1.089729,-0.108264,-2.148119,-0.077978,0.293139,1.000435,0.726445,0.166719,...,-0.832273,-0.584437,-0.378896,1.197759,-0.197306,-0.521604,-1.499973,-0.886475,0.608731,0


In [18]:
data.shape

(4599, 1025)