# Training only with introns and exons

In [1]:
import pickle
import random
import torch
from torch.utils.data import DataLoader, Dataset

import numpy as np
from datasets import Dataset
from transformers import GPT2LMHeadModel, GPT2Tokenizer, TrainingArguments, Trainer

  from .autonotebook import tqdm as notebook_tqdm





In [2]:
# Load the pickled dataset and merge its "train" and "test" partitions into a
# single list; the notebook re-splits the data itself further down.
# NOTE: pickle.load can execute arbitrary code while deserialising, so only
# point this at a trusted file.
with open("./database/col_ac.mod1", "rb") as file:
    # The context manager closes the handle even if unpickling fails.
    data = pickle.load(file)

database = data["train"] + data["test"]

In [3]:
# Gather the "data" payload of every intron and every exon found in the
# database records, dropping duplicates by passing the values through a set.
introns_data = list({
    intron["data"]
    for record in database
    for intron in record["introns"]
})
exons_data = list({
    exon["data"]
    for record in database
    for exon in record["exons"]
})

In [4]:
# Keep the first de-duplicated intron sequence in a global for ad-hoc inspection.
# NOTE(review): no later visible cell reads this `seq` (the later list
# comprehension binds its own `seq`) — confirm whether it is still needed.
seq = introns_data[0]

In [5]:
def _build_example(sequence, label, idx):
    """Return one prompt/completion training record for *sequence*.

    Each character of the sequence is wrapped in brackets (``A`` -> ``[A]``)
    so it matches the bracketed tokens added to the tokenizer; ``label`` is
    stored as the completion and ``idx`` as the record's running index.
    """
    # str.join builds the string in one pass instead of repeated concatenation.
    tokenized_sequence = "".join(f"[{nucl}]" for nucl in sequence)
    return {
        "prompt": f"what is the classification for this sequence? {tokenized_sequence}",
        "completion": label,
        "idx": idx,
    }


# Introns first, then exons, so the indices match the original numbering:
# introns get 0..len(introns_data)-1 and exons continue from there.
labelled_sequences = (
    [(sequence, "[INTRON]") for sequence in introns_data]
    + [(sequence, "[EXON]") for sequence in exons_data]
)
transformers_input = [
    _build_example(sequence, label, idx)
    for idx, (sequence, label) in enumerate(labelled_sequences)
]

# Mix the two classes in place before the train/test split.
random.shuffle(transformers_input)

In [6]:
# Split the shuffled examples: the first 80% become the training set and the
# remainder becomes the test set.
train_proportion = 0.8
dataset_len = len(transformers_input)
crop = int(train_proportion * dataset_len)  # index of the first test example

train, test = transformers_input[:crop], transformers_input[crop:]

In [7]:
# Load the tokenizer and the language-model head from the same checkpoint.
checkpoint = "gpt2"

tokenizer, model = (
    GPT2Tokenizer.from_pretrained(checkpoint),
    GPT2LMHeadModel.from_pretrained(checkpoint),
)

In [None]:
# Extend the vocabulary with one token per nucleotide and one per class label,
# then resize the model's embedding matrix to the new vocabulary size.
nucleotide_tokens = ["[A]", "[C]", "[G]", "[T]"]
label_tokens = ["[EXON]", "[INTRON]"]
special_tokens = nucleotide_tokens + label_tokens
tokenizer.add_tokens(special_tokens)
vocabulary_size = len(tokenizer)
model.resize_token_embeddings(vocabulary_size, mean_resizing=False)

In [9]:
# Token count of every prompt, used to choose a padding/truncation length.
sequence_max_lengths = []
for example in transformers_input:
    prompt_token_ids = tokenizer(example["prompt"])["input_ids"]
    sequence_max_lengths.append(len(prompt_token_ids))

In [10]:
# Use the 95th percentile of the prompt lengths as the fixed sequence length,
# truncated to an integer.
percentile_95 = np.percentile(sequence_max_lengths, 95)
crop_length = int(percentile_95)

In [None]:
# Report the sequence length chosen for padding/truncation.
print("Lenght for the sequences crop: " + str(crop_length))

In [12]:
# Wrap both splits (lists of dicts) in `Dataset` objects.
hf_train, hf_test = (Dataset.from_list(split) for split in (train, test))

In [13]:
# Reuse the end-of-sequence token as the padding token: tokenize_function asks
# the tokenizer for padding="max_length", which needs a pad token to be set.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Display the first training record (prompt / completion / idx) as a sanity check.
hf_train[0]

In [None]:
def tokenize_function(example):
    """Turn prompt/completion pairs into causal-LM training features.

    The prompt and its completion are joined into ONE token sequence of
    exactly ``crop_length`` tokens, because a causal LM head compares the
    logits at position ``i`` with the label at position ``i + 1`` of the same
    sequence. Tokenizing and padding the completion on its own (as was done
    before) puts the class token at label position 0, which the shift
    discards, so the model was only ever trained to emit padding.

    Args:
        example: Mapping with "prompt" and "completion" entries. Both are
            lists of strings when called through ``Dataset.map(batched=True)``
            and plain strings for a single, unbatched example.

    Returns:
        A dict with "input_ids", "attention_mask" and "labels". Each value is
        a list of rows (batched call) or a single row (unbatched call); every
        row has ``crop_length`` entries. Labels are -100 (ignored by the loss)
        on prompt and padding positions and hold the completion token ids
        elsewhere.
    """
    prompts = example["prompt"]
    completions = example["completion"]
    is_single = isinstance(prompts, str)
    if is_single:
        prompts, completions = [prompts], [completions]

    prompt_ids_batch = tokenizer(prompts)["input_ids"]
    completion_ids_batch = tokenizer(completions)["input_ids"]
    pad_id = tokenizer.pad_token_id
    ignore_index = -100  # label value skipped by the cross-entropy loss

    features = {"input_ids": [], "attention_mask": [], "labels": []}
    for prompt_ids, completion_ids in zip(prompt_ids_batch, completion_ids_batch):
        completion_ids = list(completion_ids)[:crop_length]
        # Truncate the prompt, never the completion, so the class token
        # always survives inside the fixed-length window.
        prompt_ids = list(prompt_ids)[: crop_length - len(completion_ids)]

        input_ids = prompt_ids + completion_ids
        labels = [ignore_index] * len(prompt_ids) + completion_ids
        padding = crop_length - len(input_ids)

        features["input_ids"].append(input_ids + [pad_id] * padding)
        features["attention_mask"].append([1] * len(input_ids) + [0] * padding)
        # The pad token is the EOS token here, so padding must be masked out
        # explicitly or the loss would reward predicting EOS everywhere.
        features["labels"].append(labels + [ignore_index] * padding)

    if is_single:
        return {key: rows[0] for key, rows in features.items()}
    return features

# Apply tokenize_function to both splits in batched mode.
tokenized_train, tokenized_test = [
    split.map(tokenize_function, batched=True)
    for split in (hf_train, hf_test)
]

In [16]:
# Training configuration, collected in one mapping and expanded into
# TrainingArguments.
training_settings = {
    "output_dir": "./results",
    "eval_strategy": "epoch",  # evaluate on the test split after every epoch
    # NOTE(review): 0.005 looks high for fine-tuning a pretrained model —
    # confirm this value is intentional.
    "learning_rate": 0.005,
    "num_train_epochs": 20,
    "per_device_train_batch_size": 16,
    "save_steps": 50,
    "save_total_limit": 1,
}
training_args = TrainingArguments(**training_settings)

# Bind the model, the configuration and both tokenized splits together.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    processing_class=tokenizer,
)

In [None]:
# Run fine-tuning with the settings held by `trainer` (20 epochs, evaluating
# once per epoch as configured in `training_args`).
trainer.train()

In [None]:
# Persist the fine-tuned weights and the extended tokenizer side by side so
# they can be reloaded together.
output_dir = "spliceGPT"
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

In [27]:
# Inputs for a manual generation check.
# NOTE(review): `subject` is not read by the visible cells, and this prompt
# does not follow the "what is the classification for this sequence? ..."
# format used for training — confirm both are intentional.
subject, prompt = 15, "Write a story about a dragon who learns to fly."

In [36]:
# Encode the prompt as PyTorch tensors (return_tensors="pt") for model.generate.
tokenized_prompt = tokenizer(prompt, return_tensors="pt")

In [None]:
# Generate a continuation of the prompt with deterministic beam search.
# The inputs are moved to whatever device the model currently lives on, so
# the cell also works on machines without CUDA.
generation_device = model.device
output = model.generate(
    tokenized_prompt["input_ids"].to(generation_device),
    # Pass the mask and pad id explicitly: the pad token is the EOS token,
    # so padding cannot be inferred from the input ids.
    attention_mask=tokenized_prompt["attention_mask"].to(generation_device),
    pad_token_id=tokenizer.pad_token_id,
    max_length=200,
    num_beams=5,             # beam search over 5 hypotheses
    do_sample=False,         # sampling disabled: decoding is deterministic
    no_repeat_ngram_size=2,  # never repeat the same 2-gram
    # temperature / top_k / top_p were dropped: they only take effect when
    # do_sample=True, so they had no effect on this call.
)


In [None]:
# Show the prompt token ids followed by the generated token ids.
for tensor in (tokenized_prompt["input_ids"], output):
    print(tensor)

In [None]:
# Decode the first generated sequence back to text, dropping special tokens
# and surrounding whitespace, then show it.
decoded_text = tokenizer.decode(output[0], skip_special_tokens=True)
generated_sequence = decoded_text.strip()
print("Response:", generated_sequence)

In [None]:
# NOTE(review): `dd` is not defined in any visible cell, so running this line
# raises NameError — it looks like a stray keystroke or scratch cell; confirm
# and remove.
dd