# Transformers

Note: it's important to use accelerate==0.21.0 and transformers==4.31.0

### Set up imports

In [1]:
import os
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.95'
os.environ['PYTORCH_MPS_LOW_WATERMARK_RATIO'] = '0.05'

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import Trainer, TrainingArguments
import torch
from torch.utils.data import Dataset
from torch.nn import DataParallel
import requests



  from .autonotebook import tqdm as notebook_tqdm


### Get Yarrowia protein sequences

In [2]:
# a function to get amino acid data from Yarrowia lipolytica
def download_sequences():
    # Define the URL for the API query
    uniprot_api_endpoint = 'https://rest.uniprot.org/uniprotkb/stream?format=fasta&query=%28yarrowia+lipolytica%29'
    
    # Send a GET request to the UniProt API
    response = requests.get(uniprot_api_endpoint)
    
    # Check if the request was successful
    if response.status_code == 200:
        # Print or save the sequences
        return response.text  # or save to a file
    else:
        print("Failed to retrieve data:", response.status_code)
        return ''

In [3]:
# Call the function to download sequences
aa_sequences = download_sequences()

len(aa_sequences)

10661518

In [4]:
# split the amino acid data into separate strings for each protein
aa_sequences = aa_sequences.split('>')

len(aa_sequences)

17951

In [5]:
# look at the first protein sequence
aa_sequences[0]

''

In [6]:
# look at the first protein sequence
aa_sequences[0][:200]

''

In [7]:
len(aa_sequences[0])

0

### Create lists of amino acid sequences and enzyme names

In [8]:
clean_sequences = []

# remove all non amino acid text
for seq in aa_sequences:
    label = seq.split('OS=')[0]

    label_parts = label.split(' ')

    label_parts = label_parts[1:]

    # combine the label parts
    label = ' '.join(label_parts)

    if 'SV=' not in seq:
        print('Skipping sequence')
        continue
        
    else:
        clean_seq = seq.split('SV=')[1]

        # remove new line characters
        clean_seq = clean_seq.replace('\n', '')

        # remove the first two characters
        clean_seq = clean_seq[2:]

        # shorten to just the first 256 characters
        # clean_seq = clean_seq[:256]

        # add the sequence to the list
        clean_sequences.append(f'{label}: {clean_seq}')


Skipping sequence


In [9]:
# print first five sequeces
for i in range(5):
    # print(labels[i])
    print(clean_sequences[i])
    print()

Pentafunctional AROM polypeptide : FAEGQIQKVPILGKESIHIGYKMQDHIVSEIVANIKSSTYILVTDTNIEDLGYVESLKTKFEAAFAKDGIKSRLLTYTVAPGETSKSRATKAAIEDWMLSKGCTRDTVILAVGGGVIGDMIGYVAATFMRGVRFVQIPTTLLAMVDSSIGGKTAIDTPLGKNLVGAFWQPVNIFIDTSFLETLPVREFINGMAEVIKTAAFYDAEEFTRLESASEIFLSTIKKRDAKDPRRVDLSPITDTIGRIVLGSARIKAAVVSADEREGGLRNLLNFGHSIGHAYEAILTPYILHGECVAIGMVKEAELSRYLGILSPVAVARLAKCIKAYELPVSLDDATVKARSHGKKCPVDDLLRIMGVDKKNDGSTKKIVILSAIGKTHEQKASSVADKDIRFVLSEEVIVGEAPVGDKKSYTVTPPGSKSISNRAFVLTALGKGPCKLRNLLHSDDTQHMLEAIELLGGASFEWEADGETLLVTGNGGKLTAPAQELYLGNAGTASRFLTTAATLVQKGDKDHVILTGNKRMQERPIGPLVDALRSNGADIAFQNAEGSLPLKIEAGVGLKGGLIEVAATVSSQYVSSLLMCAPYAQTPVTLSLVGGKPISQFYIDMTIAMMADFGVVVTKDETKEHTYHIPQGVYTNPEEYVVESDASSATYPLAYAAMTGHTVTVPNIGSKSLQGDARFAIDVLKAMGCTVEQTATSTTVTGVPNLKAIAVDMEPMTDAFLTACVVAAVSEGTTVITGIANQRVKECNRIEAMRVQLAKYGVVCRELEDGIEVDGISRSDLKTPVSVHSYDDHRVAMSFSLLSSIMAAPVAIEERRCVEKTWPGWWDVLSGVFNVPLEGVTLAKTVSKAESGLSKPSIFIVGMRGAGKTHLGAQAANHLGYEFIDLDQLLEKDLDTTIPQLIADKGWDHFRAEELRLLKQCLNDKSEGYVISCGGGVVETPAARDALQTFKGVGGIVLHVHRPV

In [23]:
# only use the first 5000 sequences
clean_sequences = clean_sequences[:5000]

### Define the tokenizer

In [24]:
# Assuming you are using only standard amino acid letters which GPT2 can handle, otherwise you need to customize it.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})


1

### Define the model

In [39]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.resize_token_embeddings(len(tokenizer))

# model = DataParallel(model)

Embedding(50258, 768)

### Model fine-tuning

In [40]:
class ProteinDataset(Dataset):
    def __init__(self, sequences, tokenizer, max_length):
        # Tokenize all sequences
        self.encodings = tokenizer(sequences, max_length=max_length, truncation=True, padding="max_length", return_tensors="pt")

    def __len__(self):
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        # This method should return a dictionary with the keys the model expects
        item = {key: self.encodings[key][idx] for key in self.encodings}
        item['labels'] = item['input_ids'].clone()
        return item

# Create the dataset
# max_length = 256  # Max length of sequences
# max_length = 1024  # Max length of sequences
max_length = 512  # Max length of sequences
dataset = ProteinDataset(clean_sequences, tokenizer, max_length)


In [46]:
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=1,              # number of training epochs
    # gradient_accumulation_steps=4,  # Accumulates gradients across steps
    per_device_train_batch_size=8,  # batch size for training
    per_device_eval_batch_size=16,   # batch size for evaluation
    warmup_steps=100,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=100,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)


In [47]:
trainer.train()


  1%|          | 15/1250 [01:11<1:37:32,  4.74s/it]
 16%|█▌        | 100/625 [01:52<10:11,  1.16s/it]
 16%|█▌        | 100/625 [01:52<10:11,  1.16s/it]

{'loss': 3.5028, 'learning_rate': 5e-05, 'epoch': 0.16}


 32%|███▏      | 200/625 [03:43<07:44,  1.09s/it]
 32%|███▏      | 200/625 [03:43<07:44,  1.09s/it]

{'loss': 3.0357, 'learning_rate': 4.047619047619048e-05, 'epoch': 0.32}


 48%|████▊     | 300/625 [05:33<05:57,  1.10s/it]
 48%|████▊     | 300/625 [05:33<05:57,  1.10s/it]

{'loss': 2.9694, 'learning_rate': 3.095238095238095e-05, 'epoch': 0.48}


 64%|██████▍   | 400/625 [07:22<04:07,  1.10s/it]
 64%|██████▍   | 400/625 [07:22<04:07,  1.10s/it]

{'loss': 3.0007, 'learning_rate': 2.1428571428571428e-05, 'epoch': 0.64}


 80%|████████  | 500/625 [09:13<02:18,  1.11s/it]
 80%|████████  | 500/625 [09:13<02:18,  1.11s/it]

{'loss': 2.8896, 'learning_rate': 1.1904761904761905e-05, 'epoch': 0.8}


 96%|█████████▌| 600/625 [11:06<00:27,  1.09s/it]
 96%|█████████▌| 600/625 [11:06<00:27,  1.09s/it]

{'loss': 2.9846, 'learning_rate': 2.3809523809523808e-06, 'epoch': 0.96}


100%|██████████| 625/625 [11:34<00:00,  1.10s/it]
100%|██████████| 625/625 [11:34<00:00,  1.11s/it]

{'train_runtime': 694.1215, 'train_samples_per_second': 7.203, 'train_steps_per_second': 0.9, 'train_loss': 3.060752014160156, 'epoch': 1.0}





TrainOutput(global_step=625, training_loss=3.060752014160156, metrics={'train_runtime': 694.1215, 'train_samples_per_second': 7.203, 'train_steps_per_second': 0.9, 'train_loss': 3.060752014160156, 'epoch': 1.0})

### Generate new protein outputs

In [48]:
# Generate protein sequences with attention_mask and pad_token_id set
prompt = 'Pyruvate Dehydrogenase: '
input_ids = tokenizer.encode(prompt, return_tensors='pt')
attention_mask = torch.ones(input_ids.shape, dtype=torch.long)  # Create an attention mask (all ones)

# Ensure that the pad_token_id is set in the tokenizer and model
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))

# switch to cpu rather than gpu
device = torch.device("cpu")
model.to(device)
input_ids = input_ids.to(device)

output = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_length=512,
    num_return_sequences=1,
    temperature=0.7,
    pad_token_id=tokenizer.pad_token_id
)

decoded_output = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output]
for seq in decoded_output:
    print(seq)
    print()


Pyruvate Dehydrogenase: erythylalanine-protein kinase : SLLKLQKQKQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQ



### Define a function for generating amino acid sequences

In [49]:
def generate_aa_sequence(prompt):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long)  # Create an attention mask (all ones)

    output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=512,
        num_return_sequences=1,
        temperature=0.7,
        pad_token_id=tokenizer.pad_token_id
    )

    decoded_output = tokenizer.decode(output[0], skip_special_tokens=True) 
    
    return decoded_output

generate_aa_sequence('Pyruvate Dehydrogenase ')

'Pyruvate Dehydrogenase erythrocyte : KQKQKQKQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQ'

In [50]:
prompts = [
    'Pyruvate Dehydrogenase : ', 
    'ATP Synthase : ', 
    'Cytochrome C Oxidase : ', 
    'NADH Dehydrogenase :', 
    'Succinate Dehydrogenase : '
]

for prompt in prompts:
    print(prompt)
    print(generate_aa_sequence(prompt))
    print()

Pyruvate Dehydrogenase : 
Pyruvate Dehydrogenase : ɛVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVV

ATP Synthase : 
ATP Synthase : ɪ to,-- :-VQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQ