# Transformers

Note: it's important to use accelerate==0.21.0 and transformers==4.31.0

### Set up imports

In [1]:
import os
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.95'
os.environ['PYTORCH_MPS_LOW_WATERMARK_RATIO'] = '0.05'

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import Trainer, TrainingArguments
import torch
import requests



  from .autonotebook import tqdm as notebook_tqdm


### Get Yarrowia protein sequences

In [2]:
# a function to get amino acid data from Yarrowia lipolytica
def download_sequences():
    # Define the URL for the API query
    uniprot_api_endpoint = 'https://rest.uniprot.org/uniprotkb/stream?format=fasta&query=%28yarrowia+lipolytica%29'
    
    # Send a GET request to the UniProt API
    response = requests.get(uniprot_api_endpoint)
    
    # Check if the request was successful
    if response.status_code == 200:
        # Print or save the sequences
        return response.text  # or save to a file
    else:
        print("Failed to retrieve data:", response.status_code)
        return ''

In [3]:
# Call the function to download sequences
aa_sequences = download_sequences()

len(aa_sequences)

10661518

In [4]:
# split the amino acid data into separate strings for each protein
aa_sequences = aa_sequences.split('>')

len(aa_sequences)

17951

In [5]:
# inspect the length of the list
len(aa_sequences)

17951

In [6]:
# look at the first protein sequence
aa_sequences[0]

''

In [7]:
# just isolate the first 1000 sequences
aa_sequences = aa_sequences[1:1001]

len(aa_sequences)

1000

In [8]:
# look at the first protein sequence
aa_sequences[0][:200]

'tr|A0A1H6PWB9|A0A1H6PWB9_YARLL Pentafunctional AROM polypeptide OS=Yarrowia lipolytica OX=4952 GN=ARO1 PE=3 SV=1\nMFAEGQIQKVPILGKESIHIGYKMQDHIVSEIVANIKSSTYILVTDTNIEDLGYVESLKT\nKFEAAFAKDGIKSRLLTYTVAPGETS'

In [9]:
len(aa_sequences[1])

785

In [16]:
clean_sequences = []
labels = []

# remove all non amino acid text
for seq in aa_sequences:
    label = seq.split('OS=')[0]

    label_parts = label.split(' ')

    label_parts = label_parts[1:]

    # combine the label parts
    label = ' '.join(label_parts)

    labels.append(label)

    # remove description text
    clean_seq = seq.split('SV=')[1]

    # remove new line characters
    clean_seq = clean_seq.replace('\n', '')

    # remove the first two characters
    clean_seq = clean_seq[2:]

    # shorten to just the first 256 characters
    clean_seq = clean_seq[:256]

    # add the sequence to the list
    clean_sequences.append(clean_seq)


### Define the model

In [17]:
# define a tokenizer and a model
try:
    tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
    model = GPT2LMHeadModel.from_pretrained('distilgpt2')
except Exception as e:
    print(e)


### Model fine-tuning

In [24]:
from torch.nn.functional import pad

# Tokenizing labels and clean sequences
inputs = tokenizer(
    labels,
    return_tensors='pt', 
    truncation=True, 
    padding=True,
    max_length=256
)

input_labels = tokenizer(
    clean_sequences, 
    return_tensors='pt', 
    truncation=True, 
    padding=True,
    max_length=256
)

# Ensuring all tokenized inputs are of the same length
max_len = 256  # Define your max length
input_ids_padded = pad(inputs['input_ids'], (0, max_len - inputs['input_ids'].shape[1]), value=tokenizer.pad_token_id)
labels_padded = pad(input_labels['input_ids'], (0, max_len - input_labels['input_ids'].shape[1]), value=tokenizer.pad_token_id)

# Assign padded labels to inputs for consistency in training
inputs['labels'] = labels_padded


In [19]:
from torch.utils.data import Dataset, DataLoader

class SimpleDataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}

# Assuming 'inputs' is the output from your tokenizer
train_dataset = SimpleDataset(inputs)
eval_dataset = SimpleDataset(inputs)  # Assuming eval and train are the same for the example


In [25]:
print("Input IDs shape:", inputs['input_ids'].shape)
print("Labels shape:", inputs['labels'].shape)

# Ensure they match
assert inputs['input_ids'].shape[0] == inputs['labels'].shape[0], "Mismatch in batch sizes"


Input IDs shape: torch.Size([1000, 41])
Labels shape: torch.Size([1000, 256])


In [26]:
# Set training arguments
training_args = TrainingArguments(
  output_dir = './results',
  num_train_epochs = 3,
  per_device_train_batch_size = 4,
  per_device_eval_batch_size = 4,
  warmup_steps = 100,
  weight_decay = 0.01,
  logging_dir = './logs',
  logging_steps = 10, 
)

# Create a Trainer instance
trainer = Trainer(
  model = model,
  args = training_args,
  train_dataset = train_dataset,
  eval_dataset = eval_dataset
)

# Train the model
trainer.train()

  0%|          | 0/1500 [02:39<?, ?it/s]


ValueError: Expected input batch_size (160) to match target batch_size (720).

### Generate new protein outputs

In [None]:
# Generate outputs

