In [1]:
import torch
import os
import re
import tempfile
from Bio import AlignIO, SeqIO
from Bio.Align.Applications import ClustalOmegaCommandline
from Bio.Seq import Seq
from evodiff.pretrained import MSA_OA_DM_MAXSUB
import tempfile
from Bio import AlignIO, SeqIO
from Bio.Align.Applications import ClustalOmegaCommandline
from Bio.Seq import Seq
from evodiff.generate_msa import generate_query_oadm_msa_simple
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd


root = '../../data'
device = torch.device("cuda:3")


Due to the on going maintenance burden of keeping command line application
wrappers up to date, we have decided to deprecate and eventually remove these
modules.

We instead now recommend building your command line and invoking it directly
with the subprocess module.


# Collect Sequence of enzymes (.fasta)

In [None]:
# collat_data.ipynb 最後一個cell


## Create MSA (.a3m file)
**conda environment: evodiff** \
since the model only accept .a3m file

In [None]:

directory = os.path.join(root,'uniprot/research/motif')

pro_bar =tqdm( [file for file in os.listdir(os.path.join(directory, 'ec_num')) if file.endswith('.fasta')])
for filename in pro_bar:
    input_file = os.path.join(directory,'ec_num', filename)
    a3m_file = os.path.join(directory,'msa',filename[:-6]+'.a3m')
    fasta_file = os.path.join(directory,'msa',filename[:-6]+'.fasta')
    
    # os.makedirs(os.path.dirname(output_file), exist_ok=True)
        
    # Read in fasta file
    records = list(SeqIO.parse(input_file, 'fasta'))

    if len(records) > 2:
        pro_bar.set_postfix(message=f' Processing {filename}')
        # Check max length of all sequences
        max_length = max(len(record.seq) for record in records)

        # pro_bar = tqdm(records, desc=f'Processing {filename}')
        # Add padding
        for record in records:
            seq_length = len(record.seq)
            if seq_length < max_length:
                padding_length = max_length - seq_length
                padding_seq = Seq('-' * padding_length)  # 將填充字符從空格改為短橫線'-'，這是多序列對齊中更常用的間隔符。
                record.seq += padding_seq
            
            # time.sleep(0.1)

        # Create a temporary file for the padded sequences
        with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.fasta') as temp_fasta:
            SeqIO.write(records, temp_fasta, 'fasta')
            temp_fasta_name = temp_fasta.name

        # Run Clustal Omega
        with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.aln') as temp_aln:
            clustal_omega_cline = ClustalOmegaCommandline(
                cmd='clustalo', 
                infile=temp_fasta_name,
                outfile=temp_aln.name,
                verbose=True,
                force=True
            )
            stdout, stderr = clustal_omega_cline()

        # Read the alignment and write it as fasta
        alignment = AlignIO.read(temp_aln.name, 'fasta')

        AlignIO.write(alignment, fasta_file, 'fasta')


        # Convert to A3M format manually
        with open(a3m_file , 'w') as handle:
            for record in alignment:
                handle.write(f">{record.id}\n")
                sequence = str(record.seq).replace('.', '-')  # Replace '.' with '-'
                handle.write(f"{sequence}\n")        

        # print(f'MSA has been written to {output_file }') 
    else:
        pro_bar.set_postfix(message=f'{filename} contains 1 sequence, nothing to align')

## Conditional generation MSA: EC Number + Organism

In [None]:
from Bio import SeqIO
from evodiff.conditional_generation_msa import 

root = '../../data'
directory = os.path.join(root, 'uniprot/research/motif')

input_files = [file for file in os.listdir(os.path.join(directory, 'ec_num'))]

for input_file in input_files:

    records = list(SeqIO.parse(input_file, 'fasta'))



## Conditional generation MSA: EC Number + Species

## Conditional generation: EC Number + Organism

## Create Motif Logo

In [None]:
%matplotlib inline
plt.ion()

import logomaker as lm

directory = os.path.join(root, 'uniprot/research/msa')
pro_bar =tqdm( [file for file in os.listdir(directory) if file.endswith('.a3m')])
for filename in pro_bar:
    input_file = os.path.join(directory, filename)
    output_file = os.path.join(root, f'uniprot/research/plots/motif_logo/{filename[:-4]}_logo.png')
    csv_file = os.path.join(root, f'uniprot/research/plots/motif_csv/{filename[:-4]}_logo.csv')

    with open(input_file, 'r') as f:
        lines = f.readlines()

    # extract ww domain sequences
    seqs = [seq.strip().upper() for seq in lines if ('#' not in seq) and ('>') not in seq]
    # create counts matrix
    counts_df = lm.alignment_to_matrix(sequences=seqs, to_type='counts', characters_to_ignore='.-X')    
    counts_df.to_csv(csv_file, index=False)


    # filter base on counts
    num_seqs = counts_df.sum(axis=1)
    pos_to_keep = num_seqs > len(seqs)/2
    counts_df = counts_df[pos_to_keep]
    counts_df.reset_index(drop=True, inplace=True)

    # show full ww counts
    plt.figure(figsize=(3000, 3))
    # logo = lm.Logo(counts_df)
    logo = lm.Logo(counts_df)
    
    plt.title(f"Motif Logo: {filename[:-3]}")    
    
    plt.savefig(output_file, dpi=300, bbox_inches='tight')    
    
    pro_bar.set_postfix(message=f'Saved logo for {filename}')


## Generate new MSA 
using EvoDiff to generate new MSA

In [19]:
# from evodiff.generate_msa import generate_msa
from evodiff.generate_msa import *
from evodiff.pretrained import MSA_D3PM_UNIFORM_MAXSUB

checkpoint = MSA_OA_DM_MAXSUB()
# checkpoint = MSA_D3PM_UNIFORM_MAXSUB()
model, collater, tokenizer, scheme = checkpoint

directory = os.path.join(root,'uniprot/research/msa')
all_files =  [file for file in os.listdir(directory) if file.endswith('.a3m')]

for filename in all_files:
    print(f'Processing {filename}')
    msa_file = os.path.join(directory, filename)
    n_sequences=1 # number of sequences in MSA to subsample
    records = list(SeqIO.parse(msa_file, 'a3m'))

    if len(records) > 1:
        # pro_bar.set_postfix(message=f' Processing {filename}')
        # Check max length of all sequences
        max_length = max(len(record.seq) for record in records)

        
    # seq_length=256 # maximum sequence length to subsample
    seq_length=1024 if max_length>1024 else max_length # maximum sequence length to subsample
    selection_type='random' # or 'MaxHamming'; MSA subsampling scheme


    tokeinzed_sample, generated_sequence  = generate_query_oadm_msa_simple(msa_file, model.to(device), tokenizer, n_sequences, seq_length, device=device, selection_type=selection_type,)
    # tokeinzed_sample, generated_sequence  = generate_msa(msa_file, model.to(device), tokenizer, n_sequences, seq_length, device=device, selection_type=selection_type,)
    

    print("New sequence (no gaps, pad tokens)", re.sub('[!-]', '', generated_sequence[0][0],))
    print()

Processing (ec:3.2.1.1)AND(organism_name:Streptomyces strain A3)_reviewed0_2.a3m


100%|██████████| 1024/1024 [00:55<00:00, 18.57it/s]


New sequence (no gaps, pad tokens) AEHTEZEEKLZFKEFZZZFFZZLMFZZKSHSCIFTZEZZZZVZFFFFNEZZPZEZATEFEDAQELHZFLEGEZGSEEKEADYZREQZEGSTEZZPNNFERFZGDZALFFBZQZZTQENZAZSLQIEREDQVFFVYZVBKFEEZBZSZHSSFEEYGZLQVGBKGBMNMZSEGPFGZTPBZFFEYZSGZTBCKLLQNLBETFMZEFFKGTKLKYDFQTKEBEZPBEKBFNNTFCGTTBSEFBBFZQFHMRIFFFFSGQFZBZNYKQVEAFHGMILHZEGMTAHKFKDSHYBFLELTZZTZVNSTFTCZLTFEQZEEFHEZIGBHBYFRTZZKZJFGEVFKBZGZYBEGFZEBREZZKNFZZNGKLYGHZZZBBGKZEGZEBLCZGEZZZLZBFTFFBHBSZFFKZVGNEPECZKZLEZFZDKBZFZZLBBFEBGBSGFLFBFEBSEZLZIVKYTFBDHSKQZDGZCMZGBYZGGZZLBFKIZZGQDFGSVGKEAEEFTYEFVZQBNFAGVZBRZYZNFKMZIRVZWZBQPAEYBAFZYGAONFZFEVVBZEZRZZFEZKLNZBBBFFGQZYKZZFEFZZZLFZLZKKZAEYZACLZQKYZKVGZZZSLKLLZBZFPNFKSBZZZZNZLFZZKZZZZFLRZYHHZGZZFBVFZTKYZTZZFNZFZZZZZHVFFQIFRYKFVKAFSZIFVZZEFZZGFZDZKFZYZZWZKZKKBEEKYRZFRZBKZZGFSZZFFTQZPZEZZNKGGNFZFBFZSFKZZABFFPZLHFBYFFZIQSZZZZFFANZNEHQZRZZZSEHFSZZFTNKZFAEBFEZZEFEZZQGEFZZQFZZZNPAATZTFGZZASTLFZDZBZFFZZMZGFFBHZFEZYEVBZEKKHAABZVZBVZDBNIABTBZQZBENKFSZEALZAGZRZZLZFFZFZRAGLZHEHLZFKLDQBZZCDZGIZFEEGHFBGCYZZZLZNFZFBZKVEZPS

100%|██████████| 214/214 [00:02<00:00, 75.77it/s]


New sequence (no gaps, pad tokens) MNFNZKEZZZVFAFNGZSFLDFZJZGGZZGZLZZZNADKZZEVZHGGEVSEGNLINKIZZSVGEALTANTQZEVBALAINFWHIINZLSZVEZGZSAGLZGKKZFZALZGNAZAKVHZNNVFEREZHZKZTTEFTKIFEKAKVDLKKDVIRAIEZELZZEZVKGKVNHLKFEZEGVKIAKFZFAGTFLAVESEETGQZEZBVEGLZMZZAGDKZ

Processing (ec:3.1.1.3)AND(organism_name:Candida albicans (C. albicans))_reviewed0_48.a3m


100%|██████████| 985/985 [00:47<00:00, 20.77it/s]


New sequence (no gaps, pad tokens) IAFZKZZQGNSHPGABFSABEWYTAEEGELGFKZNLZEZNTZFVLKTNAASDTZEKHMZNZEMZZTEGEZZFLBSYEZKYTVZENFNBEVZVTVTTTZVLLGVBNHBDZCBNFEZFZGDDTRABEZSCTBTESELIKEKZVNKAELCBCGLZZFBBBLBVFGFZPFEATKQFKLTKTYMFKSFVNZZOATYBLABZMBRNQLBSGQFZFABZGTILQZBQVKBLNZAMZFTGIBLAMBTGZLDGZEZDEBFTIZRZHBTDDBSQZREEHDYFLAABAZBLFMNEEBIEFYZNGGDSCVEQZVAKKAZCGBLTZZFIFNZGAAYOTFHYYBZEITBMLLZGVRZZLKNNWFTGMMZEZKTEQNIRVGBQFTELRVNVDZLGBZFEKKNZSANQTZAMGCBDNTZKBKFVZNZFLKBHVPSSKFJBYMVZIAFKGHBAGZKHCNLZZENNNBEVNEYNZKLSNTTEDTZZLLZYVYAZEKADYFBFNZBTZQKZGANPZFESBBVGBZKTZIGTLKRZYNZBLZZFZKOGZFSAZZEZTFRALFTZZZBGIBTHNZSZZKZZEGEFNAZKFEVLZXIEKFKZZLGZTGFSSNBZBZEETLAEBZIBFLAZEIAMYZKGZZKFEIFQYYRASEAZFEFBAAAZBATGHVBZFHFMQPTLRZZZZZFEQZFSBFZFAZNZFYLBFEZNZZVZVHZHNNEBBATEAMBBAFSTGLTZBKRZIZLBZZNNZKEZFTBZKKNZLLZLFZHQZMBLGZZTESFFNZVZBZZLHESBQYBZQICFYLGZSZHARSFEINLYEZELZAZKRZBGTZGZZZHIRKRZZZVBVTZZDBBEKZLDNQFKZBZMLVHKBZZZZZTTBYYFZTZPNOFZBDNKTHFNYZZZAVAEELPFFIZEZZCZFWZZSZRBZNHTELZABFEGZJZBZZIALZBFZETZZYZMLSFFBBFVYZZQFVZLRHVRGZSZHEZFZMZZV

100%|██████████| 212/212 [00:03<00:00, 55.23it/s]


New sequence (no gaps, pad tokens) GAZVKKZFVTELTPJFZZTVZBYZVVEDZKTEERTSTZTZJFETTIZZVFVJSGZNTQLZVFRHTKJTQVHZFZZCZJAKFTEJZDLZZZKJYTAZTDBVAJAJCVRGZJGJVJFVNJZKZVKTKZGNYADZJAVFJZJAQZCTZSBZWZLZDTYZTALJTZQJZZPALZNVBJZSTKEAGKZKTZGEOZKFYJZTHTKZVKAZNLQZJZPZ

Processing (ec:3.1.1.3)AND(organism_name:Candida albicans (C. albicans))_reviewed1_11.a3m


100%|██████████| 604/604 [00:26<00:00, 22.50it/s]


New sequence (no gaps, pad tokens) MFNZZGZZSLBTZEZZZBRKZBFZFATGVZNFFVLCQTVEFAABZLRIDGNTTABBFRZZLZVBAWPKBVGLHLNZCRZVLGZSETTAQNBRTCTFZBEZNZTANNIBVYTFNKZLBBAZAGNBTBLLFNHAKILKPKFBRAENHGHHMTZFBCAYZHLNNDIGNEERZABBIESINZANNEBZZIHAEESBZVBTTNGIZYSAZCQLTZVYTVSZTSFTNGGZTMTETKZFCBRIZLBGLBNSMVZFKZNSTBINBEEETVNMZNKTNFBNSVFTADBTITZZZQZVBENTZVFCENHZFFAZGEZGZBYKITTTBELCKCLBZALASFZEBVTDKEADBZBFZTKALNTAFKZGSLJTDBZGVKVBABNEBELNNBBFBTAZVZEIVRFKPIZTBRJSCNZSNZSZZTABCZZGZTAZKAFKVTTTMNQSKBTHSFGYQTBFFKZZFNICZFBBBNEBLIBTTAIBJZSISPTAZFSBDAEKDZYZKEFLIGZANSZTZKATTVBGIVNEFBRSNBZMILFZFTMFJBZKZLZTZCZZFZZTZZNZKKZZYDVKZGMZZFZPTVDEABZKCYGQBHTZTNIZZKZSJTVZFALMEZVGZVSB

Processing (ec:3.4.21.62)AND(organism_name:Bacillus licheniformis)_reviewed1_2.a3m


100%|██████████| 379/379 [00:14<00:00, 26.25it/s]


New sequence (no gaps, pad tokens) MGNENKNYTFGAAWFZNSBKEKYEADZYGQYVGNEWGSFFZZEFCEKIYAZNZEILLOKBBFELEGZWAEIYHYWDNEEFGAFILVQRWCZCBFVYFHYDFEFEFZAMYQEYCBQFBGBZGINKYANYHYTGYDYGBAAZFIWJECZHYAEGBEYGLFNEIYGINQBAKYNZYGDNEPLNEPIYNGECZFFYQQFAESEEAPAVFWLLESKZQLCFYEVACEAQBEFYKAPFNNTFKGQFBEKLBEYHYMYVHLFYHIKFKGRNYGCEFYAIFEELBYVIFIELPNVZBFQGCBRHYGIYPKAEQYYFESLAIFJYTZGKEKLBZDDZIFAZFLYFHSFYLYPGTFNEKTDHBBFFLYAZEELTYARJMKAFVDFEKFI

Processing (ec:3.2.1.78)AND(organism_name:Lentinula edodes)_reviewed0_3.a3m


100%|██████████| 788/788 [00:53<00:00, 14.63it/s]


New sequence (no gaps, pad tokens) ZMTMFZZPFSVZHINIAKLBMAHAZBTZZBGEERTNKATEFCSFKNZZCFAZHEVGKLTZZZLATGSZNGYNRYEFZHHEZCEATGILZEZZYYKLDBEZOZFZFKAYZEQZZLZGZFZLBKZZZTBTETSBBZRZZFTHAZZZLTKZKZTZFBFHLCZBZMBAYZZKTBLTZEZLMANZCKBNNFETFBVFFZBEOOJZFHYBGBZFEHBZNZFRAGMZNLBENFHMFZBCRFNIZZKRTFFFBZZAKFLBZVABYFKVFMNAZTLZEVKREFECZLLNLBFBZZHGYAHVKZZHZZNBKBBFAZNKFHFZZBAVAHRGYELNSZVAZNBADZSGZQAVLKNCYZOTMEZOKMKYLAMZBTSDGCQPTPNDZZNNFBBFQEZSZALZLZBCZSFELZMAEBEBVYZBKGIGZFANZRZZZQGKZEHFSMBZHQBSNGQZYTPRSDLZGLMERQOFFCTZOOZZTEBZFQELMBGREAOIHZZFSZVKAOZZBVZAHEZLZBARZHEQEVZRZZAZBLZLSGZFZFYFZZAZINBSLRAZZFZZTZGZZZPZZKAPKZFVFSZZBRZZTDFRAVKZGDITTSZZZTNHECFZZZZFQLTZHEFAZZZZBZMGFQATZZBZZZPZAZZGZEKFZZYHHRZYZEAFZKBIKHFSZZAFHBZZTAKZHZZZKZAATKHTZFCAFBZGCEVDIQZCZZKZENTFZCZKZSZFFTWNTBAZZHHZRSHDCYZFAZGZNDZJGEZZKZFHKZYZFKHWZZAHFZFZFZZZNFZYZVZZDLSFSQSLZZZQMAAL

Processing (ec:3.2.1.4)AND(organism_name:Fibrobacter succinogenes S85)_reviewed0_19.a3m


100%|██████████| 1024/1024 [01:30<00:00, 11.31it/s]


New sequence (no gaps, pad tokens) MTZHPZZZAZHITKZGVBWBFTMKAEAZIZBZZZSBZZZBLZHAHBBITHGFZINHDASBTQZZZBVKSAZNZGILLFNZZFMQBKZAFSHZETSFFBZHBKZBZGQBFBTNAGZFAITDFBASZYIZKZHQQGKKGKVTMHBFBFYDZZKBENTLBBEFEMBNZFTKZZZMGELLHKTZZHBSZEGNZGNTNTKEANMGVEAZVTBZBHZTEVTZNMZZYBLKEZNTNQZCETRVZINGZNLKBSKBKLFAZHZZFYKEJLBKKALFKZBZFTVZBFHLFYNNZBNTSZZBNLZGIBTFKEVBFTBZBHZHKEZTZBBTENTFZAZAQGKZSTNGHRFGTEIFELFHBTIEZGLKTNRTBENDGBFBHNHKTBZGSIBKEHTZWYHREETBFTDEAKZSZPFSNKFFIAZOYEZZLQBHBBKHYZABTHTFYYFTLNHLBCTTFKNBFVTFHGBZBSRFDEPZFRIBGFNFZGBFZZQZQCFRNBZEFZATLBHTZBEVRHLBFOZNDEMZINBSZSYZNZAYZBMHZBLZZKSNIZCKZZFTZYNZYGZZLAGTLFFZZEHHEIAZZLZBWZZTBELHHEZZZCEHIPAZLNZTBZVRFNNLRYZFBZFBGEZBEZFASZZZTZRZZONFFAANVFMSZZNFHTZFETZTTZLFTDFBFHRLZSZZEZFZNZZZKRSTZFZVSEIZZFZFKGHSFZGKSZHZHBKDVFZZTZEZEZFHZNITSFZZBYKTZFZFNZHFZHTZZEZZTTAZZMBAFZLVMKTTNREFTZZLFSZSKZDKBZSHKAZMTCDFFFZZQFETFZFHHZZFINZZFZAEZGZZGYGZANRFZKBEFBNZZNTPZCLPGAZBZZETATZBSBKZZZZNLVIKLZNVDFEZZZZVENYSATTVQKKTAZZZZFZEZTHHBRKFZYGLZBRZKYGKCZZZHTCTHZKBSFANBKYRGDFZTCZDZLZNKZHZZFVEEFZIF

100%|██████████| 537/537 [00:26<00:00, 19.98it/s]


New sequence (no gaps, pad tokens) MDBEHKEEBGFEVRFZZBYAYQFLSFEBHQFLLFVLRTATRAQFLLVKLINAZZBLEQVZKVKCYZGNTPHDGJCEAZYKFBEBGEEFNLALNQNBTLKSDMNIFFLDAVFNPEEPEBECPFLFFFGLIFLNBABVZANGLFTEZEEZEVNDDBARFBEBVPENYGFENKPNFDYKLNSNNZAKVHGFVRKENEKIZAHEYLARDBVGIZVKAFCFHVGNIGBYGEQBGLFGGNEKKQENDTHYDRHNKLGNBAIFDCZGFFFZVVKGRFTGELPEGMLAELEKFENLCCBZZATENCEEFBNAKGLDEKGSFFLDFEHNIELKKKEBGDAEFNFBGIQFNVHKVVVDFZLLEKIRHTLEENNNIFVBYGLFFBGEZNFHEQFEENKLKNEADNVELNVBNGBANVALEENFZSPFENHFKKNDBZDFGGLKZKNKALKRHNZLKZKZFBNKFLALGQCNABANCTDTEGPTKLABZAFFFNVFCFRENHEKEFFIZRBVSVLEVEWILEFAVFVNZCTENLTNZBFTLQFNGSFKK

Processing (ec:3.1.1.3)AND(organism_name:Acinetobacter sp.)_reviewed0_22.a3m


100%|██████████| 529/529 [00:24<00:00, 21.53it/s]


New sequence (no gaps, pad tokens) MZZKHZHNATZZANIZFBZNFKGLVCBPYFZZAKBKKBKFKGAGNRKARBZKLZHNOGNHKZOZTHAAYZQSFZVBZZVZZSZAQTLJYKTEINIZGLFROHKLFAZAZESSZYNGPTVKZAKENNNTNHNOGZZAFKAOFVDBNKFYYBNKGFNNLBNMZFZQNCANFZDSBNGLBABZRGEBDCEQBMASZBNFKCZYKBGKEFPZZIZYLEQIBZZYZAFGKVGEZNZNLLEKMZCBFZVGLWWENYPHLOVSALSQSBEBTGLSDEGFGZZGZBQKTBBAZBAFELFJNYGHINFKKLSOLTYZTHGAZFWFSKPGNTZFEBZBIBIELAFBZCBSOHSHZLAKZQTIZBDGVNPLMOZELFAKVEQKZEZZZNSFIZGTFEOFAANFZFBELBAZZZENBZKLZZNEDLLILZZONSBOALQKTZZFRZKFBNZQNFHTZBISVGHTDZHNDIZVLIFEAIZVZMKGZAFNQJGZKGFRSGGTIVZHHGHQQDZGZEKEYEYNZNZFPVTLRSFZNMZZCZKAF

Processing (ec:3.4.21.62)AND(organism_name:Bacillus licheniformis)_reviewed0_3.a3m


100%|██████████| 382/382 [00:09<00:00, 40.63it/s]

New sequence (no gaps, pad tokens) KGKDZRZVTZMLEAKFBSDGZZVNLFILMMSZZZFSZDKLVBZKSKECILLZEFAFZSZIZVVYZIZLZEZCZZVARZKODAZMZZLETZETZKZKOTZZSBTADDEZSATJZIZDBIANZKZTTBKZETTLKQARBTEZDEHKTTTBKZBVZKZKZZZAADQKEZOAVEKAAGTTKZZSGZILZZBIZZASVAVZZVVZAZSBMAZZBZQCLGNZTWKZBAZNZQGZGTNITLTBRFVADIDSLSZBMBVAAGVZZSZZBDZZKEZEDZBAZAVGZZFVZZZVATLASLTGTDZMTZMZSNZBZGBIZKLLZZELTBLAZATCVDAQTZRAZTZKZZSTLZVZZZZDZLZREAVDKEZLMLOZKEZSZEZZNTEVAVKZFZ






In [None]:
print(records[0].seq)
print(records[0].name)
print(records[0].description)
print(records[0].format)



------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------MKLVNIWL---LLLV-VLLCGKKH-LGD--RLEK----------------------KSFEKAPCPGCSHLTLKVEFSSTVVEYEYIVAFNGYFTAKARNS-------------FISS--ALK--SSEVDNWRIIPRNNPSS------------------DYPSD----------FEVIQIKEKQKAGL-LTLEDHPNIKRVTPQRK--------VFRSLKYAES--D----------P--------TVPCNETRWSQK-W---QSSRPL-RRASLSLGSGF--------WHATGRH-SSRRLLRAIPRQVAQTL--------QADVLWQMGYTGANVRVAVFDTGLSEKHPHFKN---V-------KER-------------TNWTN------------------------------ERTLDDGLGHGTFVAGVIASM------------RECQGFAPDAELHIFRVFTNNQ---VSYTSWFLDAFN--YAILKKIDVLNLSIGGPDFMDH-----PFVDKVWELTA--NNVIMVSAIGNDGPLYGTLNNPADQMDVIGVGGID----------------------------------------------------------------------------------------------------------------------------

In [None]:
'''
MZGTZZZZZERZAKYYAZZFEVAWQNZWLHAKEDYVFZHOLZAEEMKKAKEZLBEGKGEIEVAQDDEALRZLEALZTNAEZZKZEYYLAAFYKZQMAELAEZQZTPAEKFVOELFOKEVAZFAZGNFQDFLOEYPFEGOGZVIMAFZAZMKEYEZHGCZFEHKAFZHZKGCKSETLKDKLALEERINDEFKEZQZKGAKLZHZACVAZTZZK

AQTVPYGIPLIKADKVQAQGYKGANVKVGIIDTGIAASHTDLKVVGGASFVSGESYNTDGNGHGTHVAGTVAALDNTTGVLGVAPNVSLYAIKVLNSSGSGTYSAIVSGIEWATQNGLDVINMSLGGPSGSTALKQAVDKAYASGIVVVAAAGNSGSSGSQNTIGYPAKYDSVIAVGAVDSNKNRASFSSVGAELEVMAPGVSVYSTYPSNTYTSLNGTSMASPHVAGAAALILSKYPTLSASQVRNRLSSTATNLGDSFYYGKGLINVEAAAQ---------------------------------------------------------------------------------------------------------
MMRKKSFWLGMLTAFMLVFTMAFSDSASAAQPAKNVEKDYIVGFKSGVKTASVKKDIIKESGGKVDKQFRIINAAKAKLDKEALKEVKNDPDVAYVEEDHVAHALAQTVPYGIPLIKADKVQAQGFKGANVKVAVLDTGIQASHPDLNVVGGASFVAGEAYNTDGNGHGTHVAGTVAALDNTTGVLGVAPSVSLYAVKVLNSSGSGSYSGIVSGIEWATTNGMDVINMSLGGASGSTAMKQAVDNAYAKGVVVVAAAGNSGSSGNTNTIGYPAKYDSVIAVGAVDSNSNRASFSSVGAELEVMAPGAGVYSTYPTNTYATLNGTSMASPHVAGAAALILSKHPNLSASQVRNRLSSTATYLGSSFYYGKGLINVEAAAQ

'''



# TODO: use alphafold2 to predict the structure of enzyme


# 

# Training

In [5]:
import argparse
import json
import os
from datetime import datetime, timedelta
import pathlib
import numpy as np
import torch
from torch.cuda.amp import GradScaler
import torch.multiprocessing as mp
from torch.nn.utils import clip_grad_norm_
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.distributed as dist
from evodiff.collaters import D3PMCollaterMSA
from evodiff.utils import Tokenizer
from evodiff.losses import  D3PMCELoss,  D3PMLVBLossMSA
from evodiff.model import MSATransformerTime
import sys

from types import SimpleNamespace

from torch.utils.data import Dataset


from sequence_models.esm import MSATransformer
from sequence_models.constants import MSA_ALPHABET
from evodiff.data import TRRMSADataset, A3MMSADataset
from sequence_models.collaters import MSAAbsorbingCollater
from sequence_models.samplers import SortishSampler, ApproxBatchSampler
from sequence_models.losses import MaskedCrossEntropyLossMSA
from evodiff.metrics import MaskedAccuracyMSA
from torch.utils.data import Subset
from sequence_models.utils import warmup, transformer_lr




## Parameters

In [68]:
home = str(pathlib.Path.home())

def get_default_args():
    
    args = SimpleNamespace(
        nodes = 1, 
        gpus=1, # number of gpus per node
        nr = 0, # ranking within the nodes
        offset = 3, # Number of GPU devices to skip
        dropout = 0.0,
        weight_Decay = 0.0,
        tie_Weights = 0,
        task = None,
        dataset = os.path.join(root, 'uniprot/research/msa'),
        out_fpath = 'results',
        state_dict = None,
        mask = 'oadm',
        checkpoint_frep = 120,
        weight_save_frep = None,
        log_freq = 1000,
        reweighting_term = 0.001,
        selection_type = 'MaxHamming',
        d_embed = 1536,
        d_hidden = 6144,
        n_layers = 16,        
        n_heads = 16,
        max_batch_size = 3,
        epochs = 100,
        lr = 1e-4,
        bucket_size = 1000,
        max_tokens = 18000000,
        warmup_steps = 15000,
        max_square_tokens = 1000000000000000,
        n_sequences = 64,
        max_seq_len = 512,
        diffusion_timesteps = 500,
        clip = 1.0,
        decay = 'store_true',
    )
    return args

# 獲取默認參數
args = get_default_args()

# data path
data_dir = os.path.join(root,'uniprot/research/msa')

## Tokenizer

In [60]:
tokenizer = Tokenizer()

padding_idx = tokenizer.pad_id  # PROTEIN_ALPHABET.index(PAD)
masking_idx = tokenizer.mask_id
gap_idx = tokenizer.gap_id
print(tokenizer.alphabet)
print('Using {} as padding index'.format(padding_idx))
print('Using {} as masking index'.format(masking_idx))
print('Using {} as gap index'.format(gap_idx))

['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', 'B', 'Z', 'X', 'J', 'O', 'U', '-', '*', '#', '@', '!']
Using 30 as padding index
Using 28 as masking index
Using 26 as gap index


## Calculating lengths, depths, gap_depths

In [61]:
detergent_files = [file for file in os.listdir(data_dir) if file.endswith('.fasta')]

# the length of sequenc in .a3m
detergent_depths = np.array([], dtype=int)
#TODO I'm not sure but it may be the maximum ammount of gap in .a3m
detergent_gap_depths = np.array([], dtype=int)
# the ammount of sequences in .a3m
detergent_lengths = np.array([], dtype=int)

for filename in detergent_files:
    input_file = os.path.join(data_dir, filename)
    records = list(SeqIO.parse(input_file, 'fasta'))
    depth = len(records[0].seq)
    length = len(records)    
    gap_depth = max([record.count('-') for record in records])

    detergent_depths = np.append(detergent_depths,depth)
    detergent_gap_depths = np.append(detergent_gap_depths,gap_depth)
    detergent_lengths = np.append(detergent_lengths,length)
    

np.savez(os.path.join(data_dir,'detergent_depths.npz'), array=detergent_depths)
np.savez(os.path.join(data_dir,'detergent_gap_depths.npz'), array=detergent_gap_depths)
np.savez(os.path.join(data_dir,'detergent_lengths.npz'), array=detergent_lengths)


print(f'detergent_gap_depths: {np.load(os.path.join(data_dir,'detergent_gap_depths.npz'))['array']}')
print(f'detergent_depths: {np.load(os.path.join(data_dir,'detergent_depths.npz'))['array']}')
print(f'detergent_lengths: {np.load(os.path.join(data_dir,'detergent_lengths.npz'))['array']}')

detergent_gap_depths: [   2  178  248    8 1292  814  105  594  139  169  260]
detergent_depths: [ 212  604  788  382 1798  985  379 1061  537  214  529]
detergent_lengths: [ 2 11  3  3  2 48  2 19  3  8 22]


## Define dataset

In [73]:
from sequence_models.constants import PROTEIN_ALPHABET, PAD, GAP
from sequence_models.utils import parse_fasta


class MSADataset(Dataset):
    """Build dataset for A3M data: MSA Absorbing Diffusion model"""

    def __init__(self, selection_type, n_sequences, max_seq_len, data_dir=None, min_depth=None):
        """
        Args:
            selection_type: str,
                MSA selection strategy of random or MaxHamming
            n_sequences: int,
                number of sequences to subsample down to
            max_seq_len: int,
                maximum MSA sequence length
            data_dir: str,
                if you have a specified data directory
        """
        alphabet = PROTEIN_ALPHABET
        self.tokenizer = Tokenizer(alphabet)
        self.alpha = np.array(list(alphabet))
        self.gap_idx = self.tokenizer.alphabet.index(GAP)

        # Get npz_data dir
        if data_dir is not None:
            self.data_dir = data_dir
        else:
            raise FileNotFoundError(data_dir)
        
        [print("Excluding", x) for x in os.listdir(self.data_dir) if x.endswith('.npz')]
        all_files = [x for x in os.listdir(self.data_dir) if x.endswith('.a3m')]
        all_files = sorted(all_files)
        print("unfiltered length", len(all_files))


        ## Filter based on depth (keep > 64 seqs/MSA)
        if not os.path.exists(os.path.join(data_dir,'detergent_lengths.npz')):
            raise Exception(f"Missing detergent_lenths.npz in {data_dir}")
        if not os.path.exists(os.path.join(data_dir,'detergent_depths.npz')):
            #get_msa_depth_openfold(data_dir, sorted(all_files), 'openfold_depths.npz')
            raise Exception(f"Missing detergent_depths.npz in {data_dir}")
        if min_depth is not None: # reindex, filtering out MSAs < min_depth
            _depths = np.load(os.path.join(data_dir,'detergent_depths.npz'))['array']
            depths = pd.DataFrame(_depths, columns=['depth'])
            depths = depths[depths['depth'] >= min_depth]
            keep_idx = depths.index

            _lengths = np.load(os.path.join(data_dir,'detergent_lengths.npz'))['array']
            lengths = np.array(_lengths)[keep_idx]
            all_files = np.array(all_files)[keep_idx]
            print("filter MSA depth > 64", len(all_files))


        # Re-filter based on high gap-contining rows
        if not os.path.exists(os.path.join(data_dir,'detergent_gap_depths.npz')):
            #get_sliced_gap_depth_openfold(data_dir, all_files, 'openfold_gap_depths.npz', max_seq_len=max_seq_len)
            raise Exception(f"Missing detergent_gap_depths.npz in {data_dir}")
        _gap_depths = np.load(os.path.join(data_dir,'detergent_gap_depths.npz'))['array']
        gap_depths = pd.DataFrame(_gap_depths, columns=['gapdepth'])
        gap_depths = gap_depths[gap_depths['gapdepth'] >= min_depth]
        filter_gaps_idx = gap_depths.index
        lengths = np.array(lengths)[filter_gaps_idx]
        all_files = np.array(all_files)[filter_gaps_idx]
        print("filter rows with GAPs > 512", len(all_files))


        self.filenames = all_files  # IDs of samples to include
        self.lengths = lengths # pass to batch sampler
        self.n_sequences = n_sequences
        self.max_seq_len = max_seq_len
        self.selection_type = selection_type

    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx):
        filename = self.filenames[idx]

        def read_files(data_dir, filename):
            """
            inputs:
                data_dir : path to directory with data
                filename: MSA name

            outputs:
                path: path to .a3m file
            """
            if os.path.exists(os.path.join(data_dir, filename)):
                path = os.path.join(data_dir, filename)
            else:
                raise Exception("Missing filepaths")
            return path


        path = read_files(self.data_dir, filename)
        parsed_msa = parse_fasta(path)

        aligned_msa = [[char for char in seq if (char.isupper() or char == '-') and not char == '.'] for seq in parsed_msa]
        aligned_msa = [''.join(seq) for seq in aligned_msa]

        tokenized_msa = [self.tokenizer.tokenizeMSA(seq) for seq in aligned_msa]
        tokenized_msa = np.array([l.tolist() for l in tokenized_msa])
        msa_seq_len = len(tokenized_msa[0])

        if msa_seq_len > self.max_seq_len:
            slice_start = np.random.choice(msa_seq_len - self.max_seq_len + 1)
            seq_len = self.max_seq_len
        else:
            slice_start = 0
            seq_len = msa_seq_len

        # Slice to 512
        sliced_msa_seq = tokenized_msa[:, slice_start: slice_start + self.max_seq_len]
        anchor_seq = sliced_msa_seq[0]  # This is the query sequence in MSA

        # slice out all-gap rows
        sliced_msa = [seq for seq in sliced_msa_seq if (list(set(seq)) != [self.gap_idx])]
        msa_num_seqs = len(sliced_msa)




        if msa_num_seqs < self.n_sequences:
            print("before for len", len(sliced_msa_seq))
            print("msa_num_seqs < self.n_sequences should not be called")
            print("tokenized msa shape", tokenized_msa.shape)
            print("tokenized msa depth", len(tokenized_msa))
            print("sliced msa depth", msa_num_seqs)
            print("used to set slice")
            print("msa_seq_len", msa_seq_len)
            print("self max seq len", self.max_seq_len)
            print(slice_start)
            import pdb; pdb.set_trace()
            output = np.full(shape=(self.n_sequences, seq_len), fill_value=self.tokenizer.pad_id)
            output[:msa_num_seqs] = sliced_msa
            raise Exception("msa num_seqs < self.n_sequences, indicates dataset not filtered properly")
        elif msa_num_seqs > self.n_sequences:
            if self.selection_type == 'random':
                random_idx = np.random.choice(msa_num_seqs - 1, size=self.n_sequences - 1, replace=False) + 1
                anchor_seq = np.expand_dims(anchor_seq, axis=0)
                output = np.concatenate((anchor_seq, np.array(sliced_msa)[random_idx.astype(int)]), axis=0)
            elif self.selection_type == "MaxHamming":
                output = [list(anchor_seq)]
                msa_subset = sliced_msa[1:]
                msa_ind = np.arange(msa_num_seqs)[1:]
                random_ind = np.random.choice(msa_ind)
                random_seq = sliced_msa[random_ind]
                output.append(list(random_seq))
                random_seq = np.expand_dims(random_seq, axis=0)
                msa_subset = np.delete(msa_subset, (random_ind - 1), axis=0)
                m = len(msa_ind) - 1
                distance_matrix = np.ones((self.n_sequences - 2, m))

                for i in range(self.n_sequences - 2):
                    curr_dist = cdist(random_seq, msa_subset, metric='hamming')
                    curr_dist = np.expand_dims(np.array(curr_dist), axis=0)  # shape is now (1,msa_num_seqs)
                    distance_matrix[i] = curr_dist
                    col_min = np.min(distance_matrix, axis=0)  # (1,num_choices)
                    max_ind = np.argmax(col_min)
                    random_ind = max_ind
                    random_seq = msa_subset[random_ind]
                    output.append(list(random_seq))
                    random_seq = np.expand_dims(random_seq, axis=0)
                    msa_subset = np.delete(msa_subset, random_ind, axis=0)
                    distance_matrix = np.delete(distance_matrix, random_ind, axis=1)
        else:
            output = sliced_msa

        output = [''.join(seq) for seq in self.alpha[output]]
        return output

min_depth = args.n_sequences

dataset = MSADataset(args.selection_type, args.n_sequences, args.max_seq_len, data_dir=data_dir, min_depth=min_depth)
train_size = len(dataset)

random_ind = np.random.choice(train_size, size=(train_size - 10000 if train_size>1000 else train_size), replace=False)
print("TRAIN SIZE:", train_size, random_ind)


if args.mask == 'oadm':
    collater = MSAAbsorbingCollater(alphabet=MSA_ALPHABET, num_seqs=2)
    diffusion_timesteps = None # Not input to model

Excluding detergent_depths.npz
Excluding detergent_gap_depths.npz
Excluding detergent_lengths.npz
unfiltered length 11
filter MSA depth > 64 11
filter rows with GAPs > 512 9
TRAIN SIZE: 9 [8 2 6 7 1 0 4 3 5]


In [76]:
_ = torch.manual_seed(1)
np.random.seed(1)
torch.cuda.set_device(args.offset)
device = torch.device('cuda:' + str(args.offset))

selection_type = args.selection_type
min_depth = args.n_sequences # Will filter out sequences smaller than this number
if hasattr(args, 'clip'):
    clip = args.clip
else:
    # 設置一個默認值，或者處理沒有 clip 屬性的情況
    clip = None  # 或其他適當的默認值

ptjob = False

# build datasets, samplers, and loaders

ds_train = Subset(dataset, random_ind)

# if args.dataset == os.path.join(root, 'uniprot/research/msa'):
#metadata = np.load(data_dir + config['dataset'] + '_lengths.npz')['ells']
metadata = np.array(dataset.lengths)
train_idx = ds_train.indices
#print(train_idx)
len_train = metadata[train_idx]

len_train = np.minimum(len_train, args.max_seq_len)

train_sortish_sampler = SortishSampler(len_train, args.bucket_size)
train_sampler = ApproxBatchSampler(train_sortish_sampler, args.max_tokens, args.max_batch_size, len_train,
                                    max_square_tokens=args.max_square_tokens, msa_depth=args.n_sequences)
dl_train = DataLoader(dataset=ds_train,  batch_sampler=train_sampler, collate_fn=collater, num_workers=8)


# if rank == 0: 2024/07/23 still not figure out what it is
val_ind = np.delete(np.arange(train_size), random_ind)
ds_valid = Subset(dataset, val_ind)
valid_idx = ds_valid.indices
len_valid = metadata[valid_idx]
len_valid = np.minimum(len_valid, args.max_seq_len)

valid_sortish_sampler = SortishSampler(len_valid, args.bucket_size, num_replicas=1, rank=0)
valid_sampler = ApproxBatchSampler(valid_sortish_sampler, args.max_tokens, args.max_batch_size, len_valid,
                                    max_square_tokens=args.max_square_tokens, msa_depth=args.n_sequences)

dl_valid = DataLoader(dataset=ds_valid,
                        batch_sampler=valid_sampler,
                        collate_fn=collater,
                        num_workers=8)

# Initiate model
if args.mask == 'oadm':
    model = MSATransformer(args.d_embed, args.d_hidden, args.n_layers, args.n_heads, use_ckpt=True, n_tokens=len(MSA_ALPHABET),
                            padding_idx=padding_idx, mask_idx=masking_idx).cuda()

optimizer = Adam(model.parameters(), lr=args.lr)
if args.decay:
    scheduler = LambdaLR(optimizer, transformer_lr(args.warmup_steps))
else:
    scheduler = LambdaLR(optimizer, warmup(args.warmup_steps))
scaler = GradScaler()

outputs = os.listdir(args.out_fpath)

if len(outputs) > 0:
    last_epoch = -1
    for output in outputs:
        if 'checkpoint' in output:
            epoch = int(output.split('checkpoint')[-1][:-4])
            if epoch > last_epoch:
                args.state_dict = args.out_fpath + output
                last_epoch = epoch
if args.state_dict is not None:
    print('Loading weights from ' + args.state_dict + '...')
    sd = torch.load(args.state_dict, map_location=torch.device('cpu'))
    msd = sd['model_state_dict']
    msd = {k.split('module.')[1]: v for k, v in msd.items()}
    model.load_state_dict(msd)
    optimizer.load_state_dict(sd['optimizer_state_dict'])
    scheduler.load_state_dict(sd['scheduler_state_dict'])
    scaler.load_state_dict(sd['scaler_state_dict']),
    initial_epoch = sd['epoch'] + 1
    total_steps = sd['step']
    total_tokens = sd['tokens']
else:
    initial_epoch = 0
    total_steps = 0
    total_tokens = 0

model = model.to(device)

if args.mask == 'oadm':
    loss_func = MaskedCrossEntropyLossMSA(ignore_index=padding_idx)


accu_func = MaskedAccuracyMSA()
# if rank == 0: 2024/07/23 still not figure out what it is
# with open(args.config_fpath, 'r') as f_from:
#     with open(args.out_fpath + "config.json", "w") as f_to:
#         f_to.write(f_from.read())

## Training loop

In [77]:
def epoch(model, e, split, current_step=0, current_tokens=0):
    start_time = datetime.now()
    if split == 'train':
        loader = dl_train
        t = 'Training:'
    elif split == 'valid':
        loader = dl_valid
        t = 'Validating:'
    else:
        # loader = dl_test
        t = "Testing"
    ardm_losses = []
    nll_losses = []
    accus = []
    ns = []
    num_seqs = []
    chunk_time = datetime.now()
    weight_chunk_time = datetime.now()
    n_seen = 0
    tokens_trained = current_tokens
    if split == 'train':
        n_total = len(ds_train)
    elif split == 'valid':
        n_total = len(ds_valid)
    # else:
    #     n_total = len(ds_test)
    for i, batch in enumerate(loader):
        if split == 'train' and i == 1 and e == initial_epoch and args.state_dict is not None:
            optimizer.load_state_dict(sd['optimizer_state_dict'])
            scheduler.load_state_dict(sd['scheduler_state_dict'])
            scaler.load_state_dict(sd['scaler_state_dict'])
        ardm_loss, nll_loss, new_accu, new_n, new_seqs, new_processed = step(model, batch, split)

        if split == 'train':
            dist.reduce(ardm_loss, 0, op=dist.ReduceOp.SUM)
            dist.reduce(nll_loss, 0, op=dist.ReduceOp.SUM)
            dist.reduce(new_accu, 0, op=dist.ReduceOp.SUM)
            dist.reduce(new_n, 0, op=dist.ReduceOp.SUM)
            dist.reduce(new_seqs, 0, op=dist.ReduceOp.SUM)
        ardm_losses.append(ardm_loss.item())
        nll_losses.append(nll_loss.item())
        accus.append(new_accu.item())
        ns.append(new_n.item())
        num_seqs.append(new_seqs.item())
        n_seen += new_seqs.item()
        total_n = sum(ns)
        total_s = sum(num_seqs)
        rloss_ardm = sum(ardm_losses) / total_n
        rloss_nll = sum(nll_losses) / total_n
        raccu = sum(accus) / total_n

        if split == 'train':
            # writer.add_scalar("Loss/train", rloss, e)
            # writer.add_scalar("Acc/train", raccu, e)
            nsteps = current_step + i + 1
            tokens_trained += new_processed.item()
        else:
            # writer.add_scalar("Loss/valid", rloss, e)
            # writer.add_scalar("Acc/valid", raccu, e)
            nsteps = i
        if rank == 0:
            if ptjob:
                end = '\n'
                start = ''
            else:
                start = ''
                end = '\n'
            print(
                start + '%s Epoch %d of %d Step %d Example %d of %d ardm_loss = %.4f nll_loss = %.4f accu = %.4f'
                % (t, e + 1, epochs, nsteps, n_seen, n_total, rloss_ardm, rloss_nll, raccu),
                end=end)
            print('\n')

        if split == 'train':
            ardm_losses = ardm_losses[-999:]
            nll_losses = nll_losses[-999:]
            accus = accus[-999:]
            ns = ns[-999:]
            num_seqs = num_seqs[-999:]
            if nsteps % args.log_freq == 0:
                if rank == 0:
                    with open(args.out_fpath + 'metrics_train.csv', 'a') as f:
                        #f.write(','.join(
                        #    [str(rloss_ardm), str(rloss_nll), str(raccu), str(int(current_tokens)), str(current_step)]))
                        #f.write('\n')  # Can add for train too
                        f.write(','.join(
                            [str(rloss_ardm), str(rloss_nll), str(raccu), str(int(current_tokens)),
                                str(nsteps), str(e)]))
                        f.write('\n')
            if datetime.now() - chunk_time > timedelta(minutes=args.checkpoint_freq):
                if rank == 0:
                    print('Training complete in ' + str(datetime.now() - chunk_time))
                    with torch.no_grad():
                        if rank == 0:
                            ckpt_fpath = args.out_fpath + 'checkpoint%d.tar' % nsteps
                            torch.save({
                                'step': nsteps,
                                'tokens': tokens_trained,
                                'model_state_dict': model.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict(),
                                'scheduler_state_dict': scheduler.state_dict(),
                                'scaler_state_dict': scaler.state_dict(),
                                'epoch': e,
                                # 'amp_state_dict': amp.state_dict()
                            }, ckpt_fpath)
                            _ = epoch(model, e, split='valid', current_step=nsteps, current_tokens=tokens_trained)
                    chunk_time = datetime.now()
                    weight_chunk_time = datetime.now()
            elif datetime.now() - weight_chunk_time > timedelta(minutes=args.weight_save_freq):
                if rank == 0:
                    print('Saving weights ' + str(datetime.now() - chunk_time))
                    with torch.no_grad():
                        if rank == 0:
                            ckpt_fpath = args.out_fpath + 'checkpoint%d.tar' % nsteps
                            torch.save({
                                'step': nsteps,
                                'tokens': tokens_trained,
                                'model_state_dict': model.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict(),
                                'scheduler_state_dict': scheduler.state_dict(),
                                'scaler_state_dict': scaler.state_dict(),
                                'epoch': e,
                                # 'amp_state_dict': amp.state_dict()
                            }, ckpt_fpath)
                    weight_chunk_time = datetime.now()
    if split == 'valid':
        if rank == 0:
            with open(args.out_fpath + 'metrics.csv', 'a') as f:
                # f.write(','.join(
                #     [str(rloss_ardm), str(rloss_nll), str(raccu), str(int(current_tokens)), str(current_step)]))
                # f.write('\n')  # Can add for train too
                f.write(','.join(
                    [str(rloss_ardm), str(rloss_nll), str(raccu), str(int(current_tokens)),
                        str(current_step), str(e)]))
                f.write('\n')
        print('Validation complete in ' + str(datetime.now() - start_time))
    print('Epoch complete in ' + str(datetime.now() - start_time))
    return i, tokens_trained

def step(model, batch, split):
    if args.mask == 'blosum' or args.mask == 'random':
        src, src_one_hot, timestep, tgt, tgt_one_hot, Q, Q_prod, q = batch
        src_one_hot = src_one_hot.to(device)
        tgt_one_hot = tgt_one_hot.to(device)
        q = q.to(device)
        Q = Q.to(device)
        Q_prod = Q_prod.to(device)
        timestep = timestep.to(device)
    else:
        src, tgt, mask = batch
        mask = mask.to(device)
    src = src.to(device)
    tgt = tgt.to(device)
    input_mask = (src != masking_idx).float()
    nonpad_mask = (src != padding_idx).float()
    if args.mask == 'blosum' or args.mask == 'random':
        n_tokens = nonpad_mask.sum()
    else:
        n_tokens = mask.sum()
    if n_tokens == 0:
        raise ValueError("N TOKENS IN STEP IS 0!!")
    n_processed = input_mask.sum()

    if split == 'train':
        optimizer.zero_grad()

    if args.mask == 'blosum' or args.mask == 'random':
        outputs = model(src, timestep)
        lvb_loss = loss_func1(src_one_hot, q, outputs, tgt, tgt_one_hot, nonpad_mask, timestep, Q, Q_prod)
        ce_loss = loss_func2(outputs, tgt, nonpad_mask)
        lvb_loss = lvb_loss.to(torch.float32)
        ce_loss = ce_loss.to(torch.float32)
        nll_loss = ce_loss * n_tokens
        accu = accu_func(outputs, tgt, nonpad_mask) * n_tokens
        loss = (lvb_loss + _lambda * ce_loss) * n_tokens
    elif args.mask == 'oadm':
        outputs = model(src)
        ce_loss, nll_loss = loss_func(outputs, tgt, mask, nonpad_mask)
        loss = ce_loss
        accu = accu_func(outputs, tgt, mask) * n_tokens

    if split == 'train':
        scaler.scale(loss).backward()
        _ = clip_grad_norm_(model.parameters(), clip)
        scaler.step(optimizer)
        scale = scaler.get_scale()
        scaler.update()
        skip_scheduler = (scale > scaler.get_scale())
        if not skip_scheduler:
            scheduler.step()

    n_seqs = torch.tensor(len(src), device=device)
    return loss, nll_loss, accu, n_tokens, n_seqs, n_processed

n_parameters = sum(p.numel() for p in model.parameters())
if rank == 0:
    print('%d model parameters' % n_parameters)
for e in range(initial_epoch, epochs):
    print("epoch: ", e + 1, rank)
    s, t = epoch(model, e, split='train', current_step=total_steps, current_tokens=total_tokens)
    total_steps += s
    total_tokens += t


NameError: name 'rank' is not defined

In [17]:
# model = MSATransformer(256,128,16,16)

checkpoint = MSA_OA_DM_MAXSUB()
model, collater, tokenizer, scheme = checkpoint


model = model.to(device)


In [18]:
tokenizer = Tokenizer()
src = '-----------MLSGLSISAAHATNAEQVKNSFVYSSYAQTKYPLVFNHGMAGFNRVG-----------TDTLGLDYWYQILPDLARNGGNVWATRVSPFNSTEVRGEQLAQQV---------EEIIAITGKPKVNLIGHSHGGPTIRYVAGIMPEKVASLTTIGAPHKGSPMADVILNVEG---TPLSGLATLVNWFSAAITWAGGLDPTSYPHDSLAGAHSLSTQGSAQFNAQFPMGVPTTSCGEGTYQEKGIYMYSFSGNKALTNPLDPFDIALTGSSLVVDPFG---------------DNDGLVSRCSAKFGKTIRDDYNWNHLDEVNQVLGIRSIFASDPVSVYRQHANRLKLQGL-----------------------------------------------------------------------------------------------------------------------------------------------------------------------'
print(len(src))
encode = torch.tensor([[tokenizer.tokenizeMSA(src)]])

print(encode.shape)
pred = model.forward(encode.to(device))
print(pred.shape)
print(pred)



batchsize, length, depth, tokens = pred.shape

_, p = torch.max(torch.nn.functional.softmax(pred, dim=-1), -1)
print(encode)
print(np.squeeze(p))

print("Encoded input shape:", encode.shape)
print("Prediction shape:", p.shape)


encode_np = encode.cpu().numpy().squeeze()
p_np = p.cpu().numpy().squeeze()
correlation = np.corrcoef(encode_np, p_np)[0, 1]
print(f"Correlation coefficient: {correlation:.4f}")



529
torch.Size([1, 1, 529])
torch.Size([1, 1, 529, 31])




tensor([[[[ 12.9114, -35.5368,  -8.6801,  ...,   8.7947,  31.8852,  34.8907],
          [ 42.2695, -21.8901, -11.5830,  ...,   8.8298,  36.4136,  30.8669],
          [ 26.6929, -29.1680,  -5.0770,  ...,   8.7894,  40.7508,  18.5738],
          ...,
          [ 19.7565, -28.4536, -15.5611,  ...,   8.8328,  33.9661,  28.4636],
          [ 15.6928, -32.2167, -13.9613,  ...,   8.9582,  30.1659,  28.8393],
          [ 10.9344, -43.0428,   1.6987,  ...,   9.0806,  41.4534,  24.4211]]]],
       device='cuda:0', grad_fn=<AddBackward0>)
tensor([[[26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 10,  9, 15,  5,  9, 15,
           7, 15,  0,  0,  6,  0, 16, 11,  0,  3, 13, 17,  8, 11, 15,  4, 17,
          19, 15, 15, 19,  0, 13, 16,  8, 19, 12,  9, 17,  4, 11,  6,  5, 10,
           0,  5,  4, 11, 14, 17,  5, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
          26, 16,  2, 16,  9,  5,  9,  2, 19, 18, 19, 13,  7,  9, 12,  2,  9,
           0, 14, 11,  5,  5, 11, 17, 18,  0, 16, 14, 17, 15, 12,  4, 11, 15