# Transformers

Note: it's important to use accelerate==0.21.0 and transformers==4.31.0

### Set up imports

In [1]:
import os
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.95'
os.environ['PYTORCH_MPS_LOW_WATERMARK_RATIO'] = '0.05'

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import Trainer, TrainingArguments
import torch
import requests



  from .autonotebook import tqdm as notebook_tqdm


### Get Yarrowia protein sequences

In [2]:
# a function to get amino acid data from Yarrowia lipolytica
def download_sequences():
    # Define the URL for the API query
    uniprot_api_endpoint = 'https://rest.uniprot.org/uniprotkb/stream?format=fasta&query=%28yarrowia+lipolytica%29'
    
    # Send a GET request to the UniProt API
    response = requests.get(uniprot_api_endpoint)
    
    # Check if the request was successful
    if response.status_code == 200:
        # Print or save the sequences
        return response.text  # or save to a file
    else:
        print("Failed to retrieve data:", response.status_code)
        return ''

In [3]:
# Call the function to download sequences
aa_sequences = download_sequences()

len(aa_sequences)

10661518

In [4]:
# split the amino acid data into separate strings for each protein
aa_sequences = aa_sequences.split('>')

len(aa_sequences)

17951

In [5]:
# look at the first protein sequence
aa_sequences[0]

''

In [6]:
# just isolate the first 1000 sequences
aa_sequences = aa_sequences[1:1001]

len(aa_sequences)

1000

In [7]:
# look at the first protein sequence
aa_sequences[0][:200]

'tr|A0A1H6PWB9|A0A1H6PWB9_YARLL Pentafunctional AROM polypeptide OS=Yarrowia lipolytica OX=4952 GN=ARO1 PE=3 SV=1\nMFAEGQIQKVPILGKESIHIGYKMQDHIVSEIVANIKSSTYILVTDTNIEDLGYVESLKT\nKFEAAFAKDGIKSRLLTYTVAPGETS'

In [8]:
len(aa_sequences[0])

1695

### Create lists of amino acid sequences and enzyme names

In [9]:
clean_sequences = []
labels = []

# remove all non amino acid text
for seq in aa_sequences:
    label = seq.split('OS=')[0]

    label_parts = label.split(' ')

    label_parts = label_parts[1:]

    # combine the label parts
    label = ' '.join(label_parts)

    labels.append(label)

    # remove description text
    clean_seq = seq.split('SV=')[1]

    # remove new line characters
    clean_seq = clean_seq.replace('\n', '')

    # remove the first two characters
    clean_seq = clean_seq[2:]

    # shorten to just the first 256 characters
    clean_seq = clean_seq[:256]

    # add the sequence to the list
    clean_sequences.append(clean_seq)


In [10]:
# print first five sequeces
for i in range(5):
    print(labels[i])
    print(clean_sequences[i])
    print()

Pentafunctional AROM polypeptide 
FAEGQIQKVPILGKESIHIGYKMQDHIVSEIVANIKSSTYILVTDTNIEDLGYVESLKTKFEAAFAKDGIKSRLLTYTVAPGETSKSRATKAAIEDWMLSKGCTRDTVILAVGGGVIGDMIGYVAATFMRGVRFVQIPTTLLAMVDSSIGGKTAIDTPLGKNLVGAFWQPVNIFIDTSFLETLPVREFINGMAEVIKTAAFYDAEEFTRLESASEIFLSTIKKRDAKDPRRVDLSPITDTIGRIVLGSARIKAAVV

Fatty acid photodecarboxylase, chloroplastic 
ASITSRASARASCSQANTRAGRVALSGGALLRPARPARSFVPARKQQQGAVRRGGALSARASAVEDIRKVLSDSSSPVAGQKYDYILVGGGTAACVLANRLSADGSKRVLVLEAGPDNTSRDVKIPAAITRLFRSPLDWNLFSELQEQLAERQIYMARGRLLGGSSATNATLYHRGAAGDYDAWGVEGWSSEDVLSWFVQAETNADFGPGAYHGSGGPMRVENPRYTNKQLHTAFFKAAEEVGLTPNSDFNDWSHD

Fatty acyl-CoA reductase 6, chloroplastic 
ATTNVLATSHAFKLNGVSYFSSFPRKPNHYMPRRRLSHTTRRVQTSCFYGETSFEAVTSLVTPKTETSRNSDGIGIVRFLEGKSYLVTGATGFLAKVLIEKLLRESLEIGKIFLLMRSKDQESANKRLYDEIISSDLFKLLKQMHGSSYEAFMKRKLIPVIGDIEEDNLGIKSEIANMISEEIDVIISCGGRTTFDDRYDSALSVNALGPGRLLSFGKGCRKLKLFLHFSTAYVTGKREGTVLETPLCIGENITSD

Fatty acid synthase subunit beta 
DAYSTRPLTLSHGSLEHVLLVPTASFFIASQLQEQFNKILPEPTEGFAADDEPTTPAELVGKFLGYVSS

### Define the tokenizer

In [11]:
# Assuming you are using only standard amino acid letters which GPT2 can handle, otherwise you need to customize it.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})


1

### Define the model

In [12]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.resize_token_embeddings(len(tokenizer))


Embedding(50258, 768)

### Model fine-tuning

In [24]:
from torch.utils.data import Dataset

class ProteinDataset(Dataset):
    def __init__(self, sequences, tokenizer, max_length):
        # Tokenize all sequences
        self.encodings = tokenizer(sequences, max_length=max_length, truncation=True, padding="max_length", return_tensors="pt")

    def __len__(self):
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        # This method should return a dictionary with the keys the model expects
        item = {key: self.encodings[key][idx] for key in self.encodings}
        item['labels'] = item['input_ids'].clone()  # If doing language modeling
        return item

# Create the dataset
max_length = 256  # Max length of sequences
dataset = ProteinDataset(clean_sequences, tokenizer, max_length)


In [25]:
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,              # number of training epochs
    per_device_train_batch_size=4,  # batch size for training
    per_device_eval_batch_size=8,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    # You can also add eval_dataset if you have validation data
)


In [27]:
# Assuming dataset is an instance of your dataset class
sample = dataset[0]


print(sample.keys())  # This will show you all the keys in the dictionary returned by __getitem__
print(sample)  # This might help you to see the actual values if they are not too large


dict_keys(['input_ids', 'attention_mask', 'labels'])
{'input_ids': tensor([ 7708,  7156,    48, 33866,    42,  8859,  4146,    38,    42,  1546,
           40,    39,  3528,    56,    42, 49215, 41473,  3824,  5188,  3824,
         1565, 18694,    50,  2257,    56,  4146,    53, 21016, 46559, 19767,
           43, 31212,    53,  1546,    43, 42176,    42, 15112,  3838,  7708,
           42,    35,    38, 18694, 12562,  3069,  9936,  6849,  2969, 18851,
        18831, 12562,  1404,    42,  3838, 19767,    54,  5805, 18831,    38,
         4177, 35257,  6849,  4146, 10116, 11190, 37094,  3528, 23127,  3528,
           56, 11731,  1404,    37, 13599,    38, 13024,    37,    53,    48,
         4061, 15751,  3069,  2390,    53,  5258,    50,  3528,    38,    42,
         5603,  2389,    51,  6489,    38, 29132, 30976,    38,  8579,    54,
           48,    47,    53,    45,  5064,  2389,  4694,    37,  2538, 14990,
           47,    53, 31688,  2751,  5673,    36, 12861, 42176, 38540,    5

In [28]:
trainer.train()


  0%|          | 0/750 [06:37<?, ?it/s]
  0%|          | 1/750 [00:09<2:00:15,  9.63s/it]

RuntimeError: MPS backend out of memory (MPS allocated: 3.27 GB, other allocations: 542.11 MB, max allowed: 3.79 GB). Tried to allocate 9.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

### Generate new protein outputs