In [None]:
import os
import sys
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.notebook import trange
from tqdm.notebook import tqdm

In [None]:
# git for functions loading and work path finding
import git

# Open the git repository that contains the current directory (parent
# directories are searched too) and take its working tree as the project root.
repo = git.Repo('.', search_parent_directories=True)
work_path = Path(repo.working_tree_dir)
# Append the project root to sys.path once, so that modules addressed relative
# to it (the `function.*` imports in this notebook) can be imported.
if str(work_path) not in sys.path:
    sys.path.append(str(work_path))

In [None]:
# package for model training
# logging
from comet_ml import Experiment

# tokenizer
from sklearn.preprocessing import LabelEncoder

# pytorch
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

# transformer encoder
# https://github.com/fkodom/transformer-from-scratch
# https://medium.com/the-dl/transformers-from-scratch-in-pytorch-8777e346ca51
from function.dlcode.atten import TransformerEncoder as TransformerEncoderScratch

# cosine lr rate
# https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts

# mocov3
# code from Facebook's Github licensed by CC BY-NC 4.0 with slightly modified
# https://github.com/facebookresearch/moco-v3
import function.dlcode.moco_builder as moco_builder

In [None]:
# comet.ml logging switch.
# The guard is true only when __name__ is '__main__' and no `__file__` global
# exists (presumably: when executed as notebook cells rather than as a script).
if __name__ == '__main__' and '__file__' not in globals():
    # Set to True to send hyper-parameters and metrics to comet.ml.
    logging = False
    if logging:
        # comet api key used for logging. It is read from the environment so the
        # secret never has to be written into the notebook; when the variable is
        # unset this evaluates to "" exactly as before.
        api_key = os.environ.get("COMET_API_KEY", "")

# dataset

In [None]:
class SeqProcess():
    '''
    Convert sequence strings into fixed-length integer token tensors.

    Every sequence is repeated/truncated to ``repeat_padding_target_length``
    characters, prefixed with the CLS character ``'O'`` and encoded with a
    ``LabelEncoder`` fitted on the 21 characters ``'OACDEFGHIKLMNPQRSTVWY'``.
    '''

    def __init__(self, repeat_padding_target_length=512):
        '''
        Args:
            repeat_padding_target_length: number of characters every sequence
                is brought to before the CLS character is added.
        '''
        le = LabelEncoder()
        le.fit(list('OACDEFGHIKLMNPQRSTVWY'))
        self.encoder = le
        self.repeat_padding_target_length = repeat_padding_target_length

    def seq_process_pipe(self, seq_list):
        '''
        Pad, prefix and encode every sequence in ``seq_list``.

        Args:
            seq_list: iterable of non-empty sequence strings.

        Returns:
            A tensor with one row per input sequence; each row holds
            ``repeat_padding_target_length + 1`` encoded characters
            (the CLS character first).

        Raises:
            ValueError: if one of the sequences is empty (see ``repeat_string``).
        '''
        encoded_seq_list = []
        for seq in seq_list:
            # repeat / truncate to the common length
            padded = self.repeat_string(
                seq, target_length=self.repeat_padding_target_length)
            # add cls token 'O', then encode character by character
            encoded = self.encoder.transform(list('O' + padded))
            encoded_seq_list.append(encoded)

        # all rows have the same length, so they stack into one 2-D array
        return torch.tensor(np.stack(encoded_seq_list))

    def repeat_string(self, a_string, target_length=256):
        '''
        Repeat ``a_string`` until it is ``target_length`` characters long,
        cutting the last repetition (or the string itself) where needed.

        https://www.kite.com/python/answers/how-to-repeat-a-string-in-python

        Raises:
            ValueError: if ``a_string`` is empty; an empty string cannot be
                repeated to any length (this used to surface as a
                ZeroDivisionError from the repeat-count computation).
        '''
        if not a_string:
            raise ValueError(
                'cannot repeat an empty sequence to length {}'.format(target_length))
        # one repetition more than strictly needed, then cut to size
        number_of_repeats = target_length // len(a_string) + 1
        return (a_string * number_of_repeats)[:target_length]

In [None]:
class SeqDataset(Dataset):
    '''
    Dataset with one item per unique ``homology_id``.

    Each item pairs the rows of a homology group flagged ``is_human_seq == 1``
    (the query ``q``) with one row sampled from the whole group (the key ``k``),
    where sampling is weighted by the ``q_prob_pure`` column.
    '''

    def __init__(self, frame=None, processor=None):
        '''
        Args:
            frame: DataFrame with the columns ``homology_id``, ``is_human_seq``,
                ``frag_seq`` and ``q_prob_pure``. When omitted, the notebook
                global ``df`` is used.
            processor: object providing ``seq_process_pipe(list_of_str)``.
                When omitted, the notebook global ``seqprocess`` is used.
        '''
        # The globals are only looked up when the argument is missing, so the
        # original zero-argument call ``SeqDataset()`` keeps working while the
        # class can also be built from explicit inputs.
        self.df = df if frame is None else frame
        self.seqprocess = seqprocess if processor is None else processor
        self.homology_ids = self.df['homology_id'].unique()

    def __len__(self):
        '''Number of unique homology ids.'''
        return len(self.homology_ids)

    def __getitem__(self, index):
        '''
        Build the query/key pair for the ``index``-th homology id.

        Returns:
            ``{"q": {"token", "seq"}, "k": {"token", "seq"}}`` where ``seq`` is
            a list of raw strings and ``token`` is what
            ``seqprocess.seq_process_pipe`` returns for that list.
        '''
        homology_id = self.homology_ids[index]

        # rows of this homology group; selected once and reused for q and k
        group = self.df[self.df['homology_id'] == homology_id]

        # get human q
        # presumably exactly one human row exists per group; collate_fn relies
        # on every item having the same token shape - verify against the data
        q = group[group["is_human_seq"] == 1]
        q_seq = q['frag_seq'].tolist()
        q_seq_tokened = self.seqprocess.seq_process_pipe(q_seq)

        # get k as positive sample: one row of the group, drawn with
        # probability proportional to its q_prob_pure value
        q_as_k_seq = group.sample(1, replace=True, weights='q_prob_pure')['frag_seq'].tolist()
        q_as_k_seq_tokened = self.seqprocess.seq_process_pipe(q_as_k_seq)  # max_length + encode

        return {
            "q": {
                'token': q_seq_tokened,
                'seq': q_seq,
            },
            "k": {
                'token': q_as_k_seq_tokened,
                'seq': q_as_k_seq,
            }
        }

    def collate_fn(self, data):
        '''
        Merge a list of ``__getitem__`` results into one batch.

        Token tensors are stacked along a new first dimension and dimension 1
        is dropped when it has size 1; the raw sequences are returned as tuples
        with one entry per item.
        '''
        q_seq_tokened, q_seq, k_seq_tokened, k_seq, = \
        zip(*[(s['q']['token'], s['q']['seq'],
               s['k']['token'], s['k']['seq'],
               ) for s in data])
        q_seq_tokened = torch.stack(q_seq_tokened).squeeze(dim=1)
        k_seq_tokened = torch.stack(k_seq_tokened).squeeze(dim=1)

        return {
            "q": {
                'token': q_seq_tokened,
                'seq': q_seq,
            },
            "k": {
                'token': k_seq_tokened,
                'seq': k_seq,
            }
        }

# network

In [None]:
class AttenTorchScratch(nn.Module):
    '''
    Token embedding followed by ``TransformerEncoderScratch``; the encoder
    output at position 0 is passed through ``head`` and returned.

    The constructor takes no arguments: all hyper-parameters are read from the
    notebook-level variables ``embed_dim``, ``atten_mlp_ratio``, ``depth``,
    ``num_heads`` and ``seq_length`` at instantiation time.
    '''

    def __init__(self):
        super().__init__()

        # size of the embedding table (the encoder in SeqProcess is fitted on
        # 21 characters)
        self.num_tokens = 21
        self.seq_length_with_cls = seq_length + 1
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.atten_mlp_ratio = atten_mlp_ratio

        # feed-forward width is a multiple of the model width
        feedforward_dim = self.embed_dim * self.atten_mlp_ratio

        self.embed = nn.Embedding(self.num_tokens, self.embed_dim)
        self.atten = TransformerEncoderScratch(
            num_layers=self.depth,
            dim_model=self.embed_dim,
            num_heads=self.num_heads,
            dim_feedforward=feedforward_dim,
            dropout=0.1,
        )
        # no projection on top of the encoder output
        self.head = nn.Identity()

    def forward(self, x, return_all_atten=False):
        '''
        Args:
            x: integer token tensor; assumed (batch, sequence) with the CLS
                token at position 0 - TODO confirm against the caller.
            return_all_atten: forwarded unchanged to the encoder.

        Returns:
            ``(features, all_atten)``: the head output for position 0 and the
            second value returned by the encoder.
        '''
        embedded = self.embed(x)
        encoded, all_atten = self.atten(embedded, return_all_atten)
        # keep only the first position of every sequence
        features = self.head(encoded[:, 0])
        return features, all_atten

# training param

In [None]:
# log: the short git commit id (first 7 hex characters) identifies the
# experiment records
checkpoint_id = repo.commit().hexsha[:7]
memo = "temp_0.2_length_512_newdataset"

# loading dataset
# please change the path to point at the file downloaded from OSF: https://osf.io/jk29b/
if __name__ == '__main__' and '__file__' not in globals():
    dataset_path = work_path / "1_prepare_training_data" / '1-3_vsl2_omaseq_with_prob.pkl'
    # path relative to the repository root, recorded with the hyper-parameters
    relative_dataset_path = str(dataset_path.relative_to(work_path))
    # NOTE(review): unpickling can execute arbitrary code; only load a file
    # obtained from a trusted source.
    df = pd.read_pickle(dataset_path)

    # q_prob_pure is NaN when there is only one frag sequence in the homolog,
    # in which case the contrastive pretext task cannot be performed
    df = df[df['q_prob_pure'].notnull()].reset_index(drop=True)

# dataset + dataloader
batch_size = 50
seq_length = 512  # seq_length + CLS = 513

# BaseEncoder transformer (read as globals by AttenTorchScratch.__init__)
depth = 6  # number of encoder layers
num_heads = 8
embed_dim = 128
atten_mlp_ratio = 4  # feed-forward width = embed_dim * atten_mlp_ratio
moco_mlp_dim = 128  # passed to MoCo as mlp_dim

# MoCo based
m = 0.999  # momentum
nce_temp = 0.2  # passed to MoCo as T

# training param
max_epochs = 400
lr = 1e-3  # optimizer lr and scheduler max_lr; the scheduler min_lr is lr / 100
warm_up_step = 10  # scheduler warmup_steps; the scheduler is stepped once per epoch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# logging

In [None]:
# for log: create the comet.ml experiment and record the hyper-parameters.
# Nothing happens unless the interactive guard holds and `logging` is enabled.
if __name__ == '__main__' and '__file__' not in globals() and logging:
    hyper_params = dict(
        atten_mlp_ratio=atten_mlp_ratio,
        batch_size=batch_size,
        seq_length=seq_length,
        embed_dim=embed_dim,
        moco_mlp_dim=moco_mlp_dim,
        depth=depth,
        num_heads=num_heads,
        moco_momentum=m,
        max_epochs=max_epochs,
        learning_rate=lr,
        warm_up_step=warm_up_step,
        nce_temp=nce_temp,
        use_dataset=relative_dataset_path,
        checkpoint_id=checkpoint_id,
    )

    # experiment name: <commit id>-<memo>-mocov3
    exp_name = '{}-{}-mocov3'.format(checkpoint_id, memo)

    # automatic metric/parameter capture is switched off; everything is
    # logged explicitly by this notebook
    experiment = Experiment(
        api_key=api_key,
        project_name="mocov3",
        log_code=True,
        auto_metric_logging=False,
        auto_param_logging=False,
    )
    experiment.set_name(exp_name)
    experiment.log_parameters(hyper_params)

# training

In [None]:
# load dataset
seqprocess = SeqProcess(repeat_padding_target_length=seq_length)
if __name__ == '__main__' and '__file__' not in globals():
    seqdataset = SeqDataset()
    # Set aside 42 items and train on the rest. The training size is derived
    # from the dataset instead of the former hard-coded total (28892), so the
    # split no longer fails when the dataset has a different number of items.
    holdout_size = 42
    train_dataset, _ = random_split(
        seqdataset, [len(seqdataset) - holdout_size, holdout_size])
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  collate_fn=seqdataset.collate_fn,
                                  # os.cpu_count() may return None; fall back to
                                  # loading in the main process in that case
                                  num_workers=os.cpu_count() or 0)

In [None]:
# training only (no validation pass)
if __name__ == '__main__' and '__file__' not in globals():

    # the checkpoint is overwritten at the end of every epoch
    model_save_path = str(work_path / "trained_weight.pt")
    moco = moco_builder.MoCo(base_encoder=AttenTorchScratch,
                             dim=embed_dim,
                             mlp_dim=moco_mlp_dim,
                             T=nce_temp).to(device)

    optimizer = optim.AdamW(moco.parameters(), lr=lr)
    scheduler = CosineAnnealingWarmupRestarts(optimizer,
                                              first_cycle_steps=30,
                                              cycle_mult=1.0,
                                              max_lr=lr,
                                              min_lr=lr / 100,
                                              warmup_steps=warm_up_step,
                                              gamma=1.0)

    #training
    for current_epoch in trange(max_epochs):
        # per-batch losses of this epoch, kept on the CPU
        batch_losses = []
        moco.train()
        for batch in tqdm(train_dataloader, desc="train", leave=False):

            q = batch['q']['token'].to(device)
            k = batch['k']['token'].to(device)

            # the MoCo module returns the loss directly; m is the momentum
            loss = moco(q, k, m)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_losses.append(loss.detach().cpu())

        # Learning rate the optimizer used during this epoch. It is read before
        # the scheduler advances; reading it afterwards would record the next
        # epoch's rate under the current epoch number.
        epoch_lr = optimizer.param_groups[0]['lr']

        # the scheduler is stepped once per epoch
        scheduler.step()

        training_loss = torch.stack(batch_losses).mean().numpy()

        if logging:
            experiment.log_metric(name="training_loss",
                                  value=training_loss,
                                  epoch=current_epoch)
            experiment.log_metric(name="learning_rate",
                                  value=round(epoch_lr, 12),
                                  epoch=current_epoch)

        torch.save(moco.state_dict(), model_save_path)

In [None]:
# close comet ml: finish the experiment, but only when one was started
# (interactive guard holds and `logging` is enabled)
if __name__ == '__main__' and '__file__' not in globals() and logging:
    experiment.end()