In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn, einsum, Tensor
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import joblib
import numpy as np
from Architectures.decoder_conv import Decoder
from Architectures.encoder_conv import Encoder
from Architectures.autoencoder_conv_pl import Autoencoder

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
data = joblib.load('data/latent_dim_500.joblib')
n, c, width = data['inputs'].shape
new = data['inputs'].view(n, c, int(np.sqrt(width)), int(np.sqrt(width)))
print(new.shape)

In [3]:
model = Unet(
    dim = 4,
    dim_mults = (1, 2),
    channels=1,
    resnet_block_groups=2
).cuda()


In [4]:
diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    timesteps = 200,           # number of steps
).cuda()

In [5]:
data = joblib.load('data/latent_dim.joblib')

In [6]:
n, c, width = data['inputs'].shape
new = data['inputs'].view(n, c, int(np.sqrt(width)), int(np.sqrt(width)))

In [7]:
trainer = Trainer(
    diffusion,
    new.detach(),
    train_batch_size =4,
    train_lr = 8e-5,
    train_num_steps = 100000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = True,                       # turn on mixed precision
    calculate_fid = True ,
    results_folder='Diffusion/83737807a5e63cb4c3ef9f7fcf4a987b2776fb761252d19497c7c2ec5c41da32/',             # whether to calculate fid during training
)
# trainer.train()

In [8]:
import glob
import os
import re

list_of_files = glob.glob('Diffusion/83737807a5e63cb4c3ef9f7fcf4a987b2776fb761252d19497c7c2ec5c41da32/*') # * means all if need specific format then *.csv
latest_file = max(list_of_files, key=os.path.getctime)
print(latest_file)
pattern = r"-(.*?)\."
match = re.search(pattern, latest_file)
if match:
    result = match.group(1)
    print(result, type(result))

Diffusion/83737807a5e63cb4c3ef9f7fcf4a987b2776fb761252d19497c7c2ec5c41da32\model-10.pt
10 <class 'str'>


In [9]:
trainer.load(result)

loading from version 1.5.8


In [10]:
sampled_seq = diffusion.sample(batch_size = 4)
sampled_seq.shape # (4, 32, 128)

sampling loop time step: 100%|██████████| 200/200 [00:06<00:00, 32.12it/s]


torch.Size([4, 1, 32, 32])

In [11]:
sampled_seq = sampled_seq.view(sampled_seq.shape[0], sampled_seq.shape[1], 1024)

In [12]:
autoencoder = joblib.load('[1965, 512, 256]_1024_Tanh_AdamW_MSELoss_0_False_4_1_1')

In [13]:
decoder = autoencoder.decoder

In [14]:
mat = decoder(sampled_seq.detach().cpu())

In [15]:
mat.shape

torch.Size([4, 1965, 22])

In [16]:
from utils import reconstruct_sequence
rand_mat = joblib.load('data/AA_random_matrices_orth.joblib')

In [17]:
[''.join(x) for x in reconstruct_sequence(rand_mat, mat.detach().cpu())]

['TAISAAAGDDAADKAAAAAFETAAACAAAKAACSKRDCGCAFAAIDIICCAAAECMGGCSEGEGIFICEGIVIGAMCGGAGEIEEAESECACCMECESGMIK__CECQH___C_______A_____AGA___QA_FQQK_M_F_QAM_KK_F_AFEFDKKKNAK_KKSEAAPFFDFAEKE_CEAFASDAADADAPADGCYKDADAF_FEFAGDDAWPPAAYREA_W_YNAWRIFFN_QDWHANASNFQDFWRWNKWNFANWAMKN_WNFAFNWDAAWHTMHITHA_H_Y_NTN_WHAVAH_VAHAQHKWHAHHHHHAHAHHHHTH_HMHHKFHHHANMGTDAAMAHMHHMMMMHMDMALTLHMTDMGSMNAGMMSMAAGMHGQXASMLMGMMMMGLGGGHTQGLGMMMIXCMMTMHLGEKCFDMGQQATGDMCQMGLAGTMMEGQLMGITETETFGTHGGXGTAXATTEDKTICLCGMMEEMIIEMTFXEXMXGELIQIIEILFFSGGREHXIMQITTGAIEAEHILCXGEIMEGISFFIQIEDGGEQIMIKMKMEAREEMILEATKGXIDATQKFEMKFEIIAKCFTYISMTHSITIAMESTREKEMIKIIISGAIXYEXPEIGIENHEYYKSPYAFIRITFXKTTIDMNDHYTIYDMHKXDFIFPYAKKEMMAAXDNWFMIFLEMWKRRAEDIPIAYTDPXPCIIMQWEMPDMCCHEIAIEKKVKEIXTETHWSAYEHDEVHDNGEEHIIIEXENHHKEEKFHIHXYXEPSIKIDIPHDFEVHDAKIFXI_HPEIYEIHEDP_IDDATPQYIQHMDX_HDCEEIGXEMTMQHAAPREEDVWKETMACIQM_DPMEKQKXMAVKVMYSVDQAEEVHQIMCXHXYXYAHVEGXAAD_IMYMDEYXAYVXPXEVKTAYMKMEYYEEVKAMHE__HP_PQEH_V_EEDYS_MLHHSYGNPAXAKHTXFVSYYPDTXC_MEEY_EEXMAXIV_APDYPMEE

In [18]:
import ProtParam 

for seq in reconstruct_sequence(rand_mat, mat.detach().cpu()):
    print("No modification: ",ProtParam.get_sequence_half_life(seq))
    # print("With modification: ",ProtParam.get_sequence_half_life([string.replace("_", "") for string in seq]))

No modification:  7.2
No modification:  5.5
No modification:  4.4
No modification:  7.2
