# Generation examples: translation and reverse translation

In [None]:
import torch
from config import get_config
from model import CDGPT
from tokenizer import SentencePieceTokenizer


In [None]:
# Turn off autograd for the whole session: the cells below only load
# checkpoints and call `generate`, so no gradients are needed.
torch.set_grad_enabled(False)

In [None]:
# Build the run configuration first, then point both the configuration and
# the tokenizer at the same tokenizer model file.
cfg = get_config()
tokenizer_path = "checkpoints/tokenizer.model"
cfg.tokenizer.path = tokenizer_path
tokenizer = SentencePieceTokenizer(tokenizer_path)

In [4]:
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

# Take the first three records of the example FASTA file, in file order:
#   1. dna_record               - sequence placed inside the <mRNA> prompt
#   2. protein_record           - sequence placed inside the <protein> prompt
#   3. reverse_translate_record - ground truth for the reverse translation
fasta_file = "example.fasta"
parser = SeqIO.parse(fasta_file, "fasta")
dna_record, protein_record, reverse_translate_record = (
    next(parser),
    next(parser),
    next(parser),
)

## Translation generation

In [None]:
# Load the translation checkpoint onto the CPU, copy its "model" weights into
# a fresh CDGPT instance, then move the model to the inference device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
model_path = "checkpoints/CD-GPT-1b.pth"
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
state = torch.load(model_path, map_location="cpu")
model = CDGPT(cfg)
model.load_state_dict(state["model"], strict=True)
# Half precision + eval mode; kept as the final expression so the notebook
# still displays the returned value.
model.half().to(device).eval()

In [6]:
# Prompt layout: the sequence wrapped in <mRNA> tags, followed by the
# <translate> task tag and the <:> marker.
prompt = f"<mRNA>{str(dna_record.seq)}</mRNA><translate><:>"
x = tokenizer.encode(prompt, eos=False, device=device)
# Generate at most 128 new tokens and decode them in one expression, so
# `output` ends up holding the decoded text.
output = tokenizer.decode(
    model.generate(
        x,
        max_new_tokens=128,
        temperature=0.8,
        top_k=128,
        top_p=0.0,
        stop_ids=(tokenizer.bos, tokenizer.eos, tokenizer.pad),
    ).sequences
)

In [7]:
# Drop the first len(prompt) characters (presumably the echoed prompt --
# confirm against the tokenizer's decode behaviour), then keep only the text
# that precedes the first closing </protein> tag.
# NOTE(review): `output` is rebound to a slice of itself, so running this
# cell twice without regenerating trims it a second time.
output = output[len(prompt):]
translate_res = output.partition("</protein>")[0]

In [None]:
# Compare the generated protein with the translation of the input sequence.
# The ground truth is printed on one line and the generation on the next,
# character by character, with mismatching characters wrapped in ANSI red.
translate_gt = str(dna_record.seq.translate())
print("GROUND TRUTH ABOVE, GENERATION BELOW, MISMATCHES IN \033[91mRED\033[0m")
print(translate_gt)
# zip() stops at the shorter of the two strings, so a generation that is
# shorter than the ground truth (e.g. cut off by max_new_tokens) no longer
# raises IndexError; this matches the length guard used by the
# reverse-translation comparison.
for res_char, gt_char in zip(translate_res, translate_gt):
    if res_char == gt_char:
        print(res_char, end="")
    else:
        print(f"\033[91m{res_char}\033[0m", end="")


## Reverse translation generation

In [None]:
# The reverse-translation checkpoint can be downloaded from Tencent Weiyun Disk.
# It is loaded onto the CPU, its "model" weights are copied into a fresh CDGPT
# instance, and the model is then moved to the inference device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
model_path = "checkpoints/CD-GPT-1b-reverse-translation.pth"
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
state = torch.load(model_path, map_location="cpu")
model = CDGPT(cfg)
model.load_state_dict(state["model"], strict=True)
# Half precision + eval mode; kept as the final expression so the notebook
# still displays the returned value.
model.half().to(device).eval()

In [6]:
# Prompt layout: the sequence wrapped in <protein> tags, followed by the
# <reverse_translate> task tag and the <:> marker.
prompt = f"<protein>{str(protein_record.seq)}</protein><reverse_translate><:>"
x = tokenizer.encode(prompt, eos=False, device=device)
# Generate at most 1024 new tokens and decode them in one expression, so
# `output` ends up holding the decoded text.
output = tokenizer.decode(
    model.generate(
        x,
        max_new_tokens=1024,
        temperature=0.8,
        top_k=128,
        top_p=0.0,
        stop_ids=(tokenizer.bos, tokenizer.eos, tokenizer.pad),
    ).sequences
)

In [7]:
# Drop the first len(prompt) characters (presumably the echoed prompt --
# confirm against the tokenizer's decode behaviour). Then keep the text before
# the first closing </mRNA> tag and, within that, whatever follows the last
# opening <mRNA> tag (the whole text if no opening tag is present).
# NOTE(review): `output` is rebound to a slice of itself, so running this
# cell twice without regenerating trims it a second time.
output = output[len(prompt):]
reverse_translate_res = output.partition("</mRNA>")[0].rpartition("<mRNA>")[2]

In [None]:
# Compare the generated sequence with the reference record. The ground truth
# is printed on one line and the generation on the next, character by
# character, with mismatching characters wrapped in ANSI red.
reverse_translate_gt = str(reverse_translate_record.seq)
print("GROUND TRUTH ABOVE, GENERATION BELOW, MISMATCHES IN \033[91mRED\033[0m")
print(reverse_translate_gt)
# Pairwise walk over both strings; it ends when either one runs out.
for res_base, gt_base in zip(reverse_translate_res, reverse_translate_gt):
    shown = res_base if res_base == gt_base else f"\033[91m{res_base}\033[0m"
    print(shown, end="")