In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class SplicingSitesDataset(Dataset):
	def __init__(self, sequences, labels, tokenizer, max_length):
		self.sequences = sequences
		self.labels = labels
		self.tokenizer = tokenizer
		self.max_length = max_length
	
	def __len__(self):
		return len(self.sequences)
	
	def __getitem__(self, idx):
		prompt = self.sequences[idx]
		label = self.labels[idx]

		input_text = f"sequence: {prompt} awnser: "
		output_text = f"{label}"
		input_ids = self.tokenizer.encode(input_text, truncation=True, max_length=self.max_length, add_special_tokens=True, padding=True)
		label_ids = self.tokenizer.encode(output_text, truncation=True, max_length=self.max_length, add_special_tokens=False)

		input_ids += label_ids
		labels = [-100] * len(input_ids[:-len(label_ids)]) + label_ids

		return torch.tensor(input_ids), torch.tensor(labels)

In [3]:
def collate_fn(batch):
	input_ids, labels = zip(*batch)
	max_len = max(len(ids) for ids in input_ids)
	input_ids_padded = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
	labels_padded = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
	return input_ids_padded, labels_padded

In [4]:
def train(model, dataloader, optimizer, epochs=3, device="cuda"):
	model.train()
	for epoch in range(epochs):
		total_loss = 0
		for batch in dataloader:
			input_ids, labels = [b.to(device) for b in batch]
			outputs = model(input_ids=input_ids, labels=labels)
			loss = outputs.loss
			loss.backward()
			optimizer.step()
			optimizer.zero_grad()
			total_loss += loss.item()
		print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataloader)}")

In [5]:
def predict(model, tokenizer, sequence, device="cuda"):
	model.eval()
	input_text = f"sequence: {sequence} awnser: "
	input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

	with torch.no_grad():
		outputs = model.generate(
			input_ids,
			max_new_tokens=10,
			repetition_penalty=2.0,
			top_k=50,
			top_p=0.9,
			pad_token_id=tokenizer.eos_token_id,
		)
		
		completion = tokenizer.decode(outputs[0], skip_special_tokens=True)
		return completion.replace(input_text, "").strip()

In [6]:
file = open("./database/col_ac.mod1", "rb")
data = pickle.load(file)

database = data["train"] + data["test"]

In [7]:
introns_data = []
exons_data = []

for sequence in database:
	introns = sequence["introns"]
	exons = sequence["exons"]

	for intron in introns:
		introns_data.append(intron["data"])

	for exon in exons:
		exons_data.append(exon["data"])

introns_data = list(set(introns_data))
exons_data = list(set(exons_data))

In [8]:
sequences = introns_data + exons_data
labels = ["intron" for _ in range(len(introns_data))] + ["exon" for _ in range(len(exons_data))]

In [9]:
import random

all_data = list(zip(sequences, labels))

random.shuffle(all_data)

sequences_shuffled, labels_shuffled = zip(*all_data)

split_index = int(0.9 * len(sequences_shuffled))

train_x = sequences_shuffled[:split_index]
train_y = labels_shuffled[:split_index]

test_x = sequences_shuffled[split_index:]
test_y = labels_shuffled[split_index:]


In [10]:
import numpy as np

max_length = int(np.percentile([len(seq) for seq in sequences], 95))
print(max_length)

112


In [None]:
checkpoint = "gpt2"

In [11]:
tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
model = GPT2LMHeadModel.from_pretrained(checkpoint)

tokenizer.pad_token = tokenizer.eos_token

In [None]:
special_tokens = ["[A]", "[C]", "[G]", "[T]", "[EXON]", "[INTRON]"]

In [12]:
dataset = SplicingSitesDataset(train_x, train_y, tokenizer, max_length=max_length)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = AdamW(model.parameters(), lr=0.0005)



In [14]:
train(model, dataloader, optimizer, epochs=10)

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Epoch 1/10, Loss: 0.6825599449872971
Epoch 2/10, Loss: 0.07910846535116434
Epoch 3/10, Loss: 0.03698986181756481
Epoch 4/10, Loss: 0.025260103847831488
Epoch 5/10, Loss: 0.017396068065427244
Epoch 6/10, Loss: 0.01097581772133708
Epoch 7/10, Loss: 0.006646627892478136
Epoch 8/10, Loss: 0.007252144568774384
Epoch 9/10, Loss: 0.008844278592941918
Epoch 10/10, Loss: 0.01107653695216868


In [23]:
total = len(test_x)

hits = 0
for i in range(len(test_x)):
	subject = test_x[i]
	label = test_y[i]

	prediction = predict(model, tokenizer, subject)

	if label == "exon":
		if prediction[:4] == "exon":
			hits+=1
	elif label == "intron":
		if prediction[:6] == "intron":
			hits+=1

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [24]:
print(hits/total)

1.0
