In [1]:
import pandas as pd
import torch.cuda
import torch

In [2]:
# --- Reproducibility: seed the CPU generator and the CUDA generator ---
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# --- Hyperparameters ---
# Sequences are tokenized to different lengths and are not padded, so each
# batch holds exactly one sequence.
BATCH_SIZE = 1
LEARNING_RATE = 1e-4

# --- Compute device: use the GPU when one is visible, otherwise the CPU ---
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

In [3]:
# Read the train / validation / test splits (in that order) from their CSV files.
train_df, val_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        "../../data/custom_fragments2/datasets/random_equal_distribution/random_equal_train.csv",
        "../../data/custom_fragments2/datasets/random_equal_distribution/random_equal_val.csv",
        "../../data/custom_fragments2/datasets/random_equal_distribution/random_equal_test.csv",
    )
)

In [4]:
# Row count of each split, one per line: train, validation, test.
print("\n".join(str(len(frame)) for frame in (train_df, val_df, test_df)))

478600
119710
149590


In [5]:
# Preview the first five training rows to check the column layout.
train_df.head(n=5)

Unnamed: 0,source_accession,fragment_type,sequence,is_fragment,fragment_length,source_length
0,A3Q8S8,mixed,MNGEKVKRLSTLAETLPIQVITPESFSLLFEGPKSRRQFIDWGAFH...,True,233,360
1,Q6GLS4,terminal_N,MPFGCVTLGDKKDYNHPTEVSDRYDLGQLIKTEEFCEVFRAKEKSS...,True,191,377
2,P02607,complete,MCDFSEEQTAEFKEAFQLFDRTGDGKILYSQCGDVMRALGQNPTNA...,False,151,151
3,Q4WRR0,complete,MGLTSILIAQVLFLGAANSAVVKRWPCSVSPTGPTDPSVAKNCGYW...,False,299,299
4,A9INS6,complete,MELDIKAIMERLPHRYPMLLVDRVLDIVPGKSIVAIKNVSINEPFF...,False,151,151


In [6]:
# t5 setup
from transformers import T5Tokenizer, T5EncoderModel

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Hugging Face checkpoint shared by the tokenizer and the encoder.
PROT_T5_CHECKPOINT = "Rostlab/prot_t5_xl_uniref50"

# do_lower_case=False keeps the residue letters in upper case.
tokenizer = T5Tokenizer.from_pretrained(PROT_T5_CHECKPOINT, do_lower_case=False)
# Encoder-only variant of the checkpoint.
t5model = T5EncoderModel.from_pretrained(PROT_T5_CHECKPOINT)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [8]:
# dataset
from torch.utils.data import Dataset, DataLoader
import re

In [9]:
class FragDataset(Dataset):
    """Pre-tokenized protein sequences paired with a binary fragment label.

    Every row of ``dataframe`` is tokenized once, eagerly, in ``__init__``.
    Sequences keep their own length (no padding or truncation is applied
    here), so items from different rows generally have different sizes.

    Args:
        dataframe: Frame with a ``sequence`` column (amino-acid string) and an
            ``is_fragment`` column (truthy when the sequence is a fragment).
        tokenizer: Callable tokenizer; ``tokenizer(text, add_special_tokens=True)``
            must return a mapping with ``"input_ids"`` and ``"attention_mask"``.

    Each item is a tuple ``(input_ids, attention_mask, label)`` where
    ``input_ids`` and ``attention_mask`` are 1-D ``torch.long`` tensors and
    ``label`` is a 0-dim ``torch.float`` tensor (1.0 = fragment, 0.0 = complete).
    """

    def __init__(self, dataframe: pd.DataFrame, tokenizer: "T5Tokenizer"):
        self.input_ids = []
        self.masks = []
        self.labels = []
        for row in dataframe.itertuples(index=False):
            # Map the residues U, Z, O and B to X, then separate residues with
            # spaces so each residue becomes its own whitespace-delimited token.
            # NOTE(review): presumably this is the input format ProtT5 expects
            # -- confirm against the model card.
            seq = " ".join(re.sub(r"[UZOB]", "X", row.sequence))
            # Call the tokenizer directly: `encode_plus` is documented as
            # deprecated in favour of `__call__`, which returns the same fields.
            encoded = tokenizer(seq, add_special_tokens=True)
            # Explicit dtypes so an empty id list cannot silently become float32.
            self.input_ids.append(torch.tensor(encoded["input_ids"], dtype=torch.long))
            self.masks.append(torch.tensor(encoded["attention_mask"], dtype=torch.long))
            self.labels.append(torch.tensor(1.0 if row.is_fragment else 0.0, dtype=torch.float))

        self.length = len(self.input_ids)

    def __len__(self):
        """Return the number of tokenized sequences."""
        return self.length

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, label)`` for row ``idx``."""
        return self.input_ids[idx], self.masks[idx], self.labels[idx]

In [10]:
# Tokenize the whole training split up front, then show the first item
# (input ids, attention mask, label) as a sanity check.
train_dataset = FragDataset(dataframe=train_df, tokenizer=tokenizer)
train_dataset[0]

(tensor([19, 17,  5,  9, 14,  6, 14,  8,  4,  7, 11,  4,  3,  9, 11,  4, 13, 12,
         16,  6, 12, 11, 13,  9,  7, 15,  7,  4,  4, 15,  9,  5, 13, 14,  7,  8,
          8, 16, 15, 12, 10, 21,  5,  3, 15, 20, 11, 10, 14,  7, 15, 18, 11,  3,
         21,  3, 17,  6,  8,  8, 12,  4, 14, 20,  8, 17, 16, 19,  4, 14,  7,  9,
         11, 13, 18, 16, 16, 12, 16, 15, 21, 10, 14,  9,  4,  6,  8, 18,  3,  9,
         12,  6, 11,  9, 12,  8, 14,  8, 18,  6,  5,  7,  4, 17,  9,  8,  4, 14,
          5, 12, 12,  9,  9, 15,  4, 13, 16,  6, 10,  6, 14,  6,  7, 15, 11,  8,
          5, 21, 10,  7,  4,  8,  6,  5, 11,  4, 13,  6, 16, 10,  3,  4,  7,  8,
          5, 16,  4, 14,  4,  4,  6, 22,  3,  4,  8, 12,  3, 16,  5, 14,  4,  4,
         14, 16, 16, 12, 10, 14, 17,  7, 12, 18,  4,  6, 10, 10,  4, 13,  7,  9,
          4, 10,  3,  8, 20,  8, 16,  4,  4,  4, 16, 16,  4,  7, 10, 11,  5,  3,
         16,  6, 15,  6, 11,  3, 12,  9, 13,  3,  3, 12, 19, 10,  7,  4, 17, 11,
         13, 13,  6, 14,  6,

In [11]:
# Tokenize the validation split, then the test split, with the same tokenizer.
val_dataset, test_dataset = (
    FragDataset(split_frame, tokenizer) for split_frame in (val_df, test_df)
)

In [12]:
# Item count of each dataset, one per line: train, validation, test.
print(*(len(split) for split in (train_dataset, val_dataset, test_dataset)), sep="\n")

478600
119710
149590


In [13]:
# model
import torch.nn as nn

In [14]:
class SimpleModel(nn.Module):
    """Binary classifier head on top of a frozen T5 encoder.

    The encoder output is mean-pooled over the sequence dimension into one
    vector per sequence, which is fed through an MLP
    (``embed_dim`` -> 512 -> 256 -> 1) that emits a single logit.

    Args:
        t5_model: Encoder module; called as
            ``t5_model(input_ids=..., attention_mask=...)`` and expected to
            return an object exposing ``last_hidden_state``.
        embed_dim: Size of the encoder's hidden vectors (default 1024).
    """

    def __init__(self, t5_model: "T5EncoderModel", embed_dim: int = 1024):
        super().__init__()

        self.t5 = t5_model
        self.net = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def train(self, mode: bool = True):
        """Switch the head's train/eval mode while keeping the encoder in eval mode.

        The encoder is only used for inference (its forward pass runs under
        ``torch.no_grad()``), so any dropout layers inside it must stay
        disabled; otherwise ``model.train()`` would make the embeddings noisy.
        """
        super().train(mode)
        self.t5.eval()
        return self

    def forward(self, input_ids, mask):
        """Return one logit per sequence.

        Args:
            input_ids: Token ids, shape ``(batch, seq_len)``.
            mask: Attention mask, shape ``(batch, seq_len)``; non-zero marks
                real tokens, zero marks padding.

        Returns:
            Tensor of shape ``(batch,)`` holding the raw (pre-sigmoid) logits.
        """
        with torch.no_grad():
            hidden = self.t5(input_ids=input_ids, attention_mask=mask).last_hidden_state
            # Masked mean over the sequence axis: padded positions contribute
            # nothing, and every sequence in the batch gets its own embedding
            # (previously only the first batch element was used).
            weights = mask.unsqueeze(-1).to(hidden.dtype)
            token_counts = weights.sum(dim=1).clamp(min=1.0)  # avoid 0/0 on an all-zero mask
            prot_emb = (hidden * weights).sum(dim=1) / token_counts

        # (batch, 1) -> (batch,), so the output lines up with a (batch,) label tensor.
        return self.net(prot_emb).squeeze(-1)

In [15]:
# Wrap the pretrained encoder with the classifier head, then move the whole
# module to the selected device.
model = SimpleModel(t5model)
model = model.to(DEVICE)

In [16]:
# Smoke test: push the first training item through the model as a batch of one
# and print the raw output.
data1 = train_dataset[0]
sample_ids = data1[0].unsqueeze(0).to(DEVICE)
sample_mask = data1[1].unsqueeze(0).to(DEVICE)
with torch.no_grad():
    print(model(sample_ids, sample_mask))

tensor([0.0532], device='cuda:0')


In [17]:
# Same smoke test, but with the trailing dimension squeezed away before printing.
with torch.no_grad():
    data1 = train_dataset[0]
    # Add a batch dimension to the ids and the mask and move both to the device.
    batch_ids, batch_mask = (part.unsqueeze(0).to(DEVICE) for part in data1[:2])
    logit = model(batch_ids, batch_mask)
    print(logit.squeeze(-1))

tensor(0.0532, device='cuda:0')


In [18]:
# train
# Binary cross-entropy on raw logits (the sigmoid is applied inside the loss).
loss_fn = nn.BCEWithLogitsLoss()
# Only the classifier head is optimized; the T5 encoder's parameters are not
# handed to the optimizer.
optim = torch.optim.Adam(model.net.parameters(), lr=LEARNING_RATE)

# Use the notebook-wide BATCH_SIZE constant rather than a second hard-coded 1,
# so the batch size is configured in a single place.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [26]:
# Reporting cadence, measured in training steps.
LOG_EVERY = 1000          # progress heartbeat
REPORT_EVERY = 10000      # running train metrics + checkpoint
VALIDATE_EVERY = 100000   # full pass over the validation set


def _enter_train_mode(model):
    """Put the classifier head in training mode while the T5 encoder stays in eval mode.

    The encoder is frozen (it is not in the optimizer and runs under
    ``torch.no_grad()``), so any dropout inside it must remain disabled.
    """
    model.train()
    model.t5.eval()


def _evaluate(model, loader, loss_fn, device):
    """Return ``(mean_loss, accuracy)`` of ``model`` over every batch in ``loader``.

    Runs in eval mode without gradients. Loss is averaged per sample and
    accuracy uses a 0.5 threshold on the sigmoid of the logits. The caller is
    responsible for switching the model back to training mode afterwards.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    with torch.no_grad():
        for v_input_ids, v_mask, v_label in loader:
            v_input_ids = v_input_ids.to(device)
            v_mask = v_mask.to(device)
            v_label = v_label.to(device)

            # Use the validation batch's own mask, logits and labels throughout.
            v_output = model(v_input_ids, v_mask)
            v_loss = loss_fn(v_output, v_label)

            batch_items = v_label.numel()
            total_loss += v_loss.item() * batch_items
            v_pred = (torch.sigmoid(v_output) >= 0.5).float()
            total_correct += (v_pred == v_label).sum().item()
            total_seen += batch_items

    denom = max(total_seen, 1)  # guard against an empty loader
    return total_loss / denom, total_correct / denom


_enter_train_mode(model)
train_loss = 0.0
train_correct = 0
train_count = 0

for input_ids, mask, label in train_loader:
    input_ids = input_ids.to(DEVICE)
    mask = mask.to(DEVICE)
    label = label.to(DEVICE)

    optim.zero_grad()
    output = model(input_ids, mask)

    loss = loss_fn(output, label)
    loss.backward()
    optim.step()

    train_loss += loss.item()
    pred = (torch.sigmoid(output) >= 0.5).float()
    train_correct += (pred == label).sum().item()
    train_count += 1

    if train_count % LOG_EVERY == 0:
        print(f"Seq Count: {train_count}")

    if train_count % REPORT_EVERY == 0:
        # Running metrics over the last REPORT_EVERY steps, then reset the window.
        print(f"Train set {train_count}: Loss = {train_loss / REPORT_EVERY:.4f}, "
              f"Accuracry = {train_correct / REPORT_EVERY:.4f}")
        train_loss = 0.0
        train_correct = 0

        # NOTE(review): this state dict also contains the frozen encoder's
        # weights; saving only `model.net.state_dict()` would give much smaller
        # checkpoints -- confirm nothing loads these files with
        # `model.load_state_dict` before changing the format.
        torch.save(model.state_dict(), f"./simple_model_step_{train_count}.pt")

    if train_count % VALIDATE_EVERY == 0:
        val_loss, val_accuracy = _evaluate(model, val_loader, loss_fn, DEVICE)
        print(f"Validation after {train_count} sequences: "
              f"Loss = {val_loss:.4f}, "
              f"Accuracy = {val_accuracy:.4f}")
        _enter_train_mode(model)

Seq Count: 1000


KeyboardInterrupt: 