In [1]:
import pickle
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
file = open("./datasets/col_ac.mod1", "rb")
data = pickle.load(file)

data = data["train"] + data["test"]

In [3]:
checkpoint = "models/splicingGPT"

tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
model = GPT2LMHeadModel.from_pretrained(checkpoint)

device = torch.device("cuda")
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50263, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50263, bias=False)
)

In [4]:
def predict(model, tokenizer, sequence, device="cuda"):
	model.eval()
	input_text = f"sequence: {sequence}\nawnser: "
	input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

	with torch.no_grad():
		outputs = model.generate(
			input_ids,
			max_new_tokens=1,
			repetition_penalty=2.0,
			top_k=50,
			top_p=0.9,
			pad_token_id=tokenizer.eos_token_id,
		)
		
		generated_token_ids = outputs[0, input_ids.size(-1):]
		new_token = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()
		return new_token

In [5]:
def nucl2tokens(sequence):
  return "".join([f"[{nucl}]" for nucl in sequence])

def predict_sequence(seq):
  intron_comb = []
  gt_pos = 0
  ag_pos = 0

  intron_comb = []
  gt_pos = seq.find("GT")
  while(gt_pos != -1):
    ag_pos = seq.find("AG", gt_pos+2)
    prediction = "[EXON]"
    while(ag_pos != -1):
      current_seq = seq[gt_pos:ag_pos+2]

      tokenized_sequence = nucl2tokens(current_seq)

      prediction = predict(model, tokenizer, tokenized_sequence)
      
      if prediction == "[INTRON]":
        intron_comb.append(current_seq)
        gt_pos = seq.find("GT", ag_pos+1)
        ag_pos = -1
      else:
        ag_pos = seq.find("AG", ag_pos+1)

    if prediction != "[INTRON]":
      gt_pos = seq.find("GT", gt_pos+1)

  for intron in intron_comb:
    seq = seq.replace(intron, "")

  return seq

In [7]:
total = 0
hits = 0

for sequence in data:
  answer = predict_sequence(sequence['complete_sequence'])
  total += 1

  if(sequence['response_sequence'] == answer):
    hits += 1
  
  if (total % 10 == 0):
    print("Hits:"+str(hits),"\nTotal:"+str(total))
    print("Result:"+str((hits/total)*100))    


print("Hits:"+str(hits),"\nTotal:"+str(total))
print("Result:"+str((hits/total)*100))

Hits:0 
Total:10
Result:0.0
Hits:0 
Total:20
Result:0.0
Hits:0 
Total:30
Result:0.0
Hits:0 
Total:40
Result:0.0
Hits:0 
Total:50
Result:0.0
Hits:0 
Total:60
Result:0.0
Hits:0 
Total:70
Result:0.0
Hits:0 
Total:80
Result:0.0
Hits:0 
Total:90
Result:0.0
Hits:0 
Total:100
Result:0.0
Hits:0 
Total:110
Result:0.0
Hits:0 
Total:120
Result:0.0
Hits:0 
Total:130
Result:0.0
Hits:0 
Total:140
Result:0.0
Hits:0 
Total:150
Result:0.0
Hits:0 
Total:160
Result:0.0
Hits:0 
Total:170
Result:0.0
Hits:0 
Total:180
Result:0.0
Hits:0 
Total:190
Result:0.0
Hits:0 
Total:200
Result:0.0
Hits:0 
Total:210
Result:0.0
Hits:0 
Total:220
Result:0.0
Hits:0 
Total:230
Result:0.0
Hits:0 
Total:240
Result:0.0
Hits:0 
Total:250
Result:0.0
Hits:0 
Total:260
Result:0.0
Hits:0 
Total:270
Result:0.0
Hits:0 
Total:280
Result:0.0
Hits:0 
Total:290
Result:0.0
Hits:0 
Total:300
Result:0.0
Hits:0 
Total:310
Result:0.0
Hits:0 
Total:320
Result:0.0
Hits:0 
Total:330
Result:0.0
Hits:0 
Total:340
Result:0.0
Hits:0 
Total:350
Resul