In [1]:
import pandas as pd
import random

from classes.RebuildSeqs_GPT import RebuildSeqsGPT

In [2]:
# Load the train/test splits. keep_default_na=False keeps strings such as "NA"
# as literal text instead of converting them to NaN.
df_train = pd.read_csv("datasets/RebuildSeqs_20k_small.csv", keep_default_na=False)
df_test = pd.read_csv("datasets/RebuildSeqs_2k_small.csv", keep_default_na=False)

# Columns are taken by position: 0 -> raw sequence, 1 -> "builded" (the marked
# target printed later as "Target"), 2 -> organism name.
train_sequence, train_builded, train_organism = (
    df_train.iloc[:, col].tolist() for col in range(3)
)
test_sequence, test_builded, test_organism = (
    df_test.iloc[:, col].tolist() for col in range(3)
)

In [4]:
# Wrap the pre-trained ExInSeqs GPT-2 checkpoint in the project's training helper.
model = RebuildSeqsGPT(
    checkpoint="models/ExInSeqs-GPT2-001",
    device="cuda",
    seed=1234,
    notification=True,
    logs_dir="logs",
    alias="gpt2",
    log_level="info",
)

Started models/ExInSeqs-GPT2-001 model


In [5]:
# Register the training split with the wrapper.
# NOTE(review): the idx=0 prompt shown below is ~300 bracketed nucleotides long;
# if sequence_len is a token limit, training examples may be truncated — confirm.
train_data = {
    "sequence": train_sequence,
    "builded": train_builded,
    "organism": train_organism,
}
model.add_train_data(train_data, batch_size=8, sequence_len=256)

In [6]:
# Register the held-out split with the same batching settings as training.
test_data = {
    "sequence": test_sequence,
    "builded": test_builded,
    "organism": test_organism,
}
model.add_test_data(test_data, batch_size=8, sequence_len=256)

In [7]:
# Fine-tune for 5 epochs (2500 steps each, per the progress bars below).
# The captured bars show only their initial 0% state, not the final progress.
model.train(
    lr=5e-5,
    epochs=5,
    save_at_end=False,
    save_freq=5,
)

Training Epoch 1/5:   0%|          | 0/2500 [00:00<?, ?it/s]

Training Epoch 2/5:   0%|          | 0/2500 [00:00<?, ?it/s]

Training Epoch 3/5:   0%|          | 0/2500 [00:00<?, ?it/s]

Training Epoch 4/5:   0%|          | 0/2500 [00:00<?, ?it/s]

Training Epoch 5/5:   0%|          | 0/2500 [00:00<?, ?it/s]

In [8]:
def process_sequence(sequence: str) -> str:
    """Render a nucleotide string in the bracketed prompt format.

    Each character is upper-cased and wrapped in square brackets, e.g.
    ``"acg"`` becomes ``"[A][C][G]"``. An empty input yields ``""``.

    Args:
        sequence: Iterable of single-character nucleotides (typically a str).

    Returns:
        The concatenated bracketed representation.
    """
    return "".join(f"[{nucl.upper()}]" for nucl in sequence)

In [9]:
# Build the generation prompt for a single training example and show it.
idx = 0
prompt = (
    f"Sequence:{process_sequence(train_sequence[idx])}"
    f"\nOrganism:{train_organism[idx]}"
    "\nMarked Sequence:"
)
print(prompt)

Sequence:[A][T][C][A][T][A][C][C][T][G][A][T][G][G][A][A][T][A][A][A][T][T][G][C][T][T][T][T][T][A][G][A][A][A][A][T][T][T][C][A][T][A][T][T][T][A][G][T][A][T][T][C][C][T][A][C][T][A][C][C][T][G][G][T][G][T][T][G][C][A][T][C][T][A][G][A][G][C][A][G][C][T][G][C][A][G][C][A][C][C][G][G][A][A][A][A][T][G][T][T][A][A][T][A][A][T][C][C][A][T][T][T][T][C][C][T][T][C][A][G][A][T][A][T][T][T][G][C][T][G][G][C][T][A][A][T][G][A][T][G][C][A][A][C][A][A][T][T][A][A][T][G][T][T][G][A][A][T][T][A][A][C][G][C][T][T][T][G][T][T][T][G][A][A][T][A][C][C][A][T][C][T][C][C][G][A][A][T][T][T][T][T][T][A][T][G][A][A][A][T][C][T][A][A][T][G][A][T][T][T][A][A][A][T][A][A][A][T][T][T][T][C][T][T][C][T][T][A][A][A][C][A][G][C][T][A][C][A][A][A][T][A][T][T][A][T][T][T][G][A][C][A][G][A][T][A][C][C][T][T][T][A][T][G][A][C][T][A][G][C][A][T][T][A][C][C][A][C][C][A][G][C][C][C][A][G][C][C][A][C][C]
Organism:Rotaria magnacalcarata
Marked Sequence:


In [10]:
# Tokenize the prompt into a PyTorch tensor of input ids.
# NOTE(review): truncation removes tokens from the end by default, which would
# drop the trailing "Marked Sequence:" cue for long inputs; also prompt length
# plus max_new_tokens in the generate call must fit the model's context — confirm.
sequence = model.tokenizer.encode(
    prompt,
    add_special_tokens=True,
    truncation=True,
    max_length=1024,
    padding=True,
    return_tensors="pt",
)

In [11]:
# Move the input ids onto the same device as the model weights. Using the
# model's own device avoids repeating the "cuda" string from the constructor,
# so this stays correct if the model is placed elsewhere.
sequence = sequence.to(model.model.device)

In [12]:
# Attention mask: 1 for real tokens, 0 where the id equals EOS (EOS is also
# passed as pad_token_id to generate). A single vectorized comparison over the
# (1, seq_len) id tensor replaces the Python loop over rows; `.long()` gives the
# 0/1 integer mask format Hugging Face models document. The next cell's
# `[0].unsqueeze(0)` still yields the same (1, seq_len) shape.
attention_mask = (sequence != model.tokenizer.eos_token_id).long()

In [13]:
# Take the first row's mask and restore a leading batch dimension -> (1, seq_len).
attention_mask = attention_mask[0][None, ...]

In [46]:
# Sample a continuation. Sampling is stochastic, so results differ between runs.
# repetition_penalty=1 applies no penalty (kept for explicitness).
generation_kwargs = dict(
    repetition_penalty=1,
    temperature=1.2,
    top_k=5,
    do_sample=True,
    max_new_tokens=512,
    pad_token_id=model.tokenizer.eos_token_id,
)
pred = model.model.generate(
    input_ids=sequence,
    attention_mask=attention_mask,
    **generation_kwargs,
)

In [47]:
# Compare the sampled output against the ground-truth marked sequence.
print(f"Prompt: {prompt}\n")
print(f"Target: {train_builded[idx]}\n")
# For a decoder-only model, `generate` returns prompt + continuation. Keep the
# full decode in `result`, but print only the newly generated tokens so the
# "Result" line is directly comparable to "Target".
result = model.tokenizer.decode(pred[0])
completion = model.tokenizer.decode(pred[0, sequence.shape[-1]:])
print(f"Result: {completion}\n")

Prompt: Sequence:[A][T][C][A][T][A][C][C][T][G][A][T][G][G][A][A][T][A][A][A][T][T][G][C][T][T][T][T][T][A][G][A][A][A][A][T][T][T][C][A][T][A][T][T][T][A][G][T][A][T][T][C][C][T][A][C][T][A][C][C][T][G][G][T][G][T][T][G][C][A][T][C][T][A][G][A][G][C][A][G][C][T][G][C][A][G][C][A][C][C][G][G][A][A][A][A][T][G][T][T][A][A][T][A][A][T][C][C][A][T][T][T][T][C][C][T][T][C][A][G][A][T][A][T][T][T][G][C][T][G][G][C][T][A][A][T][G][A][T][G][C][A][A][C][A][A][T][T][A][A][T][G][T][T][G][A][A][T][T][A][A][C][G][C][T][T][T][G][T][T][T][G][A][A][T][A][C][C][A][T][C][T][C][C][G][A][A][T][T][T][T][T][T][A][T][G][A][A][A][T][C][T][A][A][T][G][A][T][T][T][A][A][A][T][A][A][A][T][T][T][T][C][T][T][C][T][T][A][A][A][C][A][G][C][T][A][C][A][A][A][T][A][T][T][A][T][T][T][G][A][C][A][G][A][T][A][C][C][T][T][T][A][T][G][A][C][T][A][G][C][A][T][T][A][C][C][A][C][C][A][G][C][C][C][A][G][C][C][A][C][C]
Organism:Rotaria magnacalcarata
Marked Sequence:

Target: (exon)ATCATACCTGATGGAATAAATTGCTTTTTAGAAAATTTCATATTT

In [48]:
# Persist the fine-tuned model. Training ran with save_at_end=False, so this
# explicit call saves the current weights (the output below reports the path).
model.save_checkpoint()

Model Successful Saved at models\gpt2_4
