In [1]:
!pip install transformers --upgrade
!pip install datasets

[0m

In [6]:
import time
import scipy.sparse as sp
from transformers import EsmTokenizer, EsmModel
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim


In [7]:
path="data/"
# Read sequences
sequences = list()
with open(path+"sequences.txt", "r") as f:
    for line in f:
        sequences.append(line[:-1])

# Split data into training and test sets
sequences_train = list()
sequences_test = list()
proteins_test = list()
train_target = list()
with open(path+"graph_labels.txt", "r") as f:
    for i, line in enumerate(f):
        t = line.split(",")
        if len(t[1][:-1]) == 0:
            proteins_test.append(t[0])
            sequences_test.append(sequences[i])
        else:
            sequences_train.append(sequences[i])
            train_target.append(int(t[1][:-1]))

sequences_train = np.array(sequences_train)
train_target = np.array(train_target)
sequences_test = np.array(sequences_test)

In [8]:
# train set
y_train_g = train_target[:4400].copy()
sequences_train_g=sequences_train[:4400].copy()
# validation
y_val_g = train_target[4400:].copy()
sequences_val_g=sequences_train[4400:].copy()


In [5]:
# Initialize device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [7]:
from transformers import EsmForSequenceClassification

model = EsmForSequenceClassification.from_pretrained("facebook/esm2_t30_150M_UR50D", num_labels=18,output_hidden_states=True) #esm2_t6_8M_UR50D

Some weights of the model checkpoint at facebook/esm2_t33_650M_UR50D were not used when initializing EsmForSequenceClassification: ['lm_head.layer_norm.weight', 'esm.contact_head.regression.weight', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.bias', 'esm.contact_head.regression.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight']
- This IS expected if you are initializing EsmForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['classi

In [9]:
from transformers import EsmTokenizer

tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

Downloading:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/125 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

In [10]:

inputs = tokenizer("A", return_tensors="pt")

with torch.no_grad():
    print(model(**inputs).hidden_states[1].shape)


torch.Size([1, 3, 1280])


In [10]:
my_dict_train = {'label': y_train_g,     'text': sequences_train_g}
my_dict_val = {'label': y_val_g,     'text': sequences_val_g}

In [11]:
from datasets import Dataset
dataset_train = Dataset.from_dict(my_dict_train)
dataset_val = Dataset.from_dict(my_dict_val)

In [None]:
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)


tokenized_datasets_train = dataset_train.map(tokenize_function, batched=True)
tokenized_datasets_val = dataset_val.map(tokenize_function, batched=True)



  0%|          | 0/5 [00:00<?, ?ba/s]

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


  0%|          | 0/1 [00:00<?, ?ba/s]

In [None]:
small_train_dataset = tokenized_datasets_train.shuffle(seed=42).select(range(4000))
small_eval_dataset = tokenized_datasets_val.shuffle(seed=42).select(range(400))

In [None]:
import numpy as np
from datasets import load_metric,list_metrics

metric = load_metric("accuracy")

Downloading builder script:   0%|          | 0.00/1.65k [00:00<?, ?B/s]

In [12]:
# list_metrics()

In [None]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

In [20]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch",num_train_epochs=8)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [21]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

In [22]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `EsmForSequenceClassification.forward` and have been ignored: text. If text are not expected by `EsmForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 4000
  Num Epochs = 8
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 4000
  Number of trainable parameters = 7845778


Epoch,Training Loss,Validation Loss


Saving model checkpoint to test_trainer/checkpoint-500
Configuration saved in test_trainer/checkpoint-500/config.json
Model weights saved in test_trainer/checkpoint-500/pytorch_model.bin
The following columns in the evaluation set don't have a corresponding argument in `EsmForSequenceClassification.forward` and have been ignored: text. If text are not expected by `EsmForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 400
  Batch size = 8
  result = getattr(asarray(obj), method)(*args, **kwds)


ValueError: operands could not be broadcast together with shapes (400,18) (7,400,1024,320) 

In [51]:
PATH="test_trainer/checkpoint-5000/pytorch_model.bin"
model.load_state_dict(torch.load(PATH))
model.eval()

EsmForSequenceClassification(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 320, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 320, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0): EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=320, out_features=320, bias=True)
              (key): Linear(in_features=320, out_features=320, bias=True)
              (value): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((320,), eps=1e-05, eleme

In [11]:
soft = nn.Softmax(dim=1)
def tok(sample):
    return tokenizer(sample, return_tensors="pt")

In [19]:
with torch.no_grad():
    # logits = model(**tokenized_datasets_test).logits
    logits=np.array([ soft(model(**tok(s)).logits)[0].tolist() for s in sequences_test])


KeyboardInterrupt: 

In [56]:
import csv
y_hat_proba=logits
with open('Submissions/fakir_submission.csv', 'w') as csvfile:
    writer = csv.writer(csvfile, delimiter=',')
    lst = list()
    for i in range(18):
        lst.append('class'+str(i))
    lst.insert(0, "name")
    writer.writerow(lst)
    for i, protein in enumerate(proteins_test):
        lst = y_hat_proba[i,:].tolist()
        lst.insert(0, protein)
        writer.writerow(lst)