In [3]:
import ml_collections
import torch
from torch import multiprocessing as mp
from datasets import get_dataset
from torchvision.utils import make_grid, save_image
import utils
import einops
#from torch.uatils._pytree import tree_map
import accelerate
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
import tempfile
from fid_score import calculate_fid_given_paths
from absl import logging
import builtins
import os
import wandb
import libs.autoencoder
import numpy as np


In [4]:
from torch.utils._pytree import tree_map


In [5]:
import ml_collections

def d(**kwargs):
    """Helper of creating a config dict."""
    return ml_collections.ConfigDict(initial_dictionary=kwargs)
    
config = ml_collections.ConfigDict()
config.seed = 1234
config.z_shape = (3, 1296, 4)
config.config_name = 'test'
config.ckpt_root = '/Users/jihyeonje/unidiffuser/test/'
config.sample_dir = '/Users/jihyeonje/unidiffuser/test/'
config.workdir = '/Users/jihyeonje/unidiffuser/test/'
config.hparams = 'default'

config.autoencoder = d(
    pretrained_path='assets/stable-diffusion/autoencoder_kl.pth',
    scale_factor=0.23010
)
config.train = d(
    n_steps=50,
    batch_size=2,
    log_interval=10,
    eval_interval=5,
    save_interval=5
)

config.optimizer = d(
    name='adamw',
    lr=0.0002,
    weight_decay=0.03,
    betas=(0.9, 0.9)
)

config.lr_scheduler = d(
    name='customized',
    warmup_steps=5000
)

config.nnet = d(
    name='uvit_t2i',
    img_size=1296,
    in_chans=3,
    patch_size=4,
    embed_dim=100,
    depth=12,
    num_heads=10,
    mlp_ratio=4,
    qkv_bias=False,
    mlp_time_embed=False,
    clip_dim=19,
    num_clip_token=77
)

config.dataset = d(
    name='ligprot_features',
    path='test/feats',
    cfg=False,
    p_uncond=0.1
)

config.sample = d(
    sample_steps=2,
    n_samples=5,
    mini_batch_size=1,
    cfg=False,
    scale=1.,
    path='/Users/jihyeonje/unidiffuser/test/res/'
)


In [6]:
def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
    _betas = (torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2)
    return _betas.numpy()


def get_skip(alphas, betas):
    N = len(betas) - 1
    skip_alphas = np.ones([N + 1, N + 1], dtype=betas.dtype)
    for s in range(N + 1):
        skip_alphas[s, s + 1:] = alphas[s + 1:].cumprod()
    skip_betas = np.zeros([N + 1, N + 1], dtype=betas.dtype)
    for t in range(N + 1):
        prod = betas[1: t + 1] * skip_alphas[1: t + 1, t]
        skip_betas[:t, t] = (prod[::-1].cumsum())[::-1]
    return skip_alphas, skip_betas


def stp(s, ts: torch.Tensor):  # scalar tensor product
    if isinstance(s, np.ndarray):
        s = torch.from_numpy(s).type_as(ts)
    extra_dims = (1,) * (ts.dim() - 1)
    return s.view(-1, *extra_dims) * ts


def mos(a, start_dim=1):  # mean of square
    return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)

def LSimple(x0, y0, nnet, schedule, **kwargs):
    n, eps_x, eps_y, xn, yn = schedule.sample(x0, y0)  # n in {1, ..., 1000}
    #n= timestep
    eps_pred_prot, eps_pred_lig = nnet(xn, n, n, yn)

    mos_p = mos(eps_x - eps_pred_prot)
    mos_l = mos(eps_y - eps_pred_lig)
    return mos_p + mos_l


In [7]:

class Schedule(object):  # discrete time
    def __init__(self, _betas):
        r""" _betas[0...999] = betas[1...1000]
             for n>=1, betas[n] is the variance of q(xn|xn-1)
             for n=0,  betas[0]=0
        """

        self._betas = _betas
        self.betas = np.append(0., _betas)
        self.alphas = 1. - self.betas
        self.N = len(_betas)

        assert isinstance(self.betas, np.ndarray) and self.betas[0] == 0
        assert isinstance(self.alphas, np.ndarray) and self.alphas[0] == 1
        assert len(self.betas) == len(self.alphas)

        # skip_alphas[s, t] = alphas[s + 1: t + 1].prod()
        self.skip_alphas, self.skip_betas = get_skip(self.alphas, self.betas)
        self.cum_alphas = self.skip_alphas[0]  # cum_alphas = alphas.cumprod()
        self.cum_betas = self.skip_betas[0]
        self.snr = self.cum_alphas / self.cum_betas

    def tilde_beta(self, s, t):
        return self.skip_betas[s, t] * self.cum_betas[s] / self.cum_betas[t]

    def sample(self, x0, y0):  # 
        n = np.random.choice(list(range(1, self.N + 1)), (len(x0),))
        eps_x = torch.randn_like(x0)
        eps_y = torch.randn_like(y0)
        xn = stp(self.cum_alphas[n] ** 0.5, x0) + stp(self.cum_betas[n] ** 0.5, eps_x)
        yn = stp(self.cum_alphas[n] ** 0.5, y0) + stp(self.cum_betas[n] ** 0.5, eps_y)
        return torch.tensor(n, device=x0.device), eps_x, eps_y, xn, yn

    def __repr__(self):
        return f'Schedule({self.betas[:10]}..., {self.N})'


In [8]:
def combine_joint(z, text):
    z = einops.rearrange(z, 'B C H W -> B (C H W)')
    text = einops.rearrange(text, 'B L D -> B (L D)')
    return torch.concat([z, text], dim=-1)

def split_joint(x):
    C, H, W = 3, 1296, 4
    z_dim = C * H * W
    z, text = x.split([z_dim, 100 * 19], dim=1)
    z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
    text = einops.rearrange(text, 'B (L D) -> B L D', L=100, D=19)
    return z, text

In [9]:
def train(config):
#    if config.get('benchmark', False):
#        torch.backends.cudnn.benchmark = True
#        torch.backends.cudnn.deterministic = False

    #mp.set_start_method('spawn')
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    logging.info(f'Process {accelerator.process_index} using device: {device}')

    config.mixed_precision = accelerator.mixed_precision
    config = ml_collections.FrozenConfigDict(config)

    assert config.train.batch_size % accelerator.num_processes == 0
    mini_batch_size = config.train.batch_size // accelerator.num_processes

    if accelerator.is_main_process:
        os.makedirs(config.ckpt_root, exist_ok=True)
        os.makedirs(config.sample_dir, exist_ok=True)
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
                   name=config.hparams, job_type='train', mode='offline')
        utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
        logging.info(config)
    else:
        utils.set_logger(log_level='error')
        builtins.print = lambda *args: None
    logging.info(f'Run on {accelerator.num_processes} devices')

    dataset = get_dataset(**config.dataset)
    #assert os.path.exists(dataset.fid_stat)
    train_dataset = dataset.get_split(split='train', labeled=True)
    train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
                                      num_workers=1, pin_memory=True, persistent_workers=True)
    test_dataset = dataset.get_split(split='test', labeled=True)  # for sampling
    test_dataset_loader = DataLoader(test_dataset, batch_size=config.sample.mini_batch_size, shuffle=True, drop_last=True,
                                     num_workers=1, pin_memory=True, persistent_workers=True)

    train_state = utils.initialize_train_state(config, device)
    nnet, nnet_ema, optimizer, train_dataset_loader, test_dataset_loader = accelerator.prepare(
        train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader, test_dataset_loader)
    lr_scheduler = train_state.lr_scheduler
    train_state.resume(config.ckpt_root)


    #autoencoder = libs.autoencoder.get_model(**config.autoencoder)
    #autoencoder.to(device)

    #@ torch.cuda.amp.autocast()
    #def encode(_batch):
    #    return autoencoder.encode(_batch)

    #@ torch.cuda.amp.autocast()
    #def decode(_batch):
    #    return autoencoder.decode(_batch)

    def get_data_generator():
        while True:
            for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
                yield data

    data_generator = get_data_generator()

    def get_context_generator():
        while True:
            for data in test_dataset_loader:
                yield data[0], data[1]

    context_generator = get_context_generator()
    
    _betas = stable_diffusion_beta_schedule()
    _schedule = Schedule(_betas)
    logging.info(f'use {_schedule}')

    def cfg_nnet(x, timesteps, context):
        _cond = nnet_ema(x, timesteps, context=context)
        _empty_context = torch.tensor(dataset.empty_context, device=device)
        _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0))
        _uncond = nnet_ema(x, timesteps, context=_empty_context)
        return _cond + config.sample.scale * (_cond - _uncond)

    def joint_nnet(x, timesteps):
        z, text = split_joint(x)
        z_out, text_out = nnet(z, t_prot=timesteps, t_lig=timesteps, context = text)
        if len(z_out.shape)==3:
            z_out = torch.unsqueeze(z_out, 0)
            text_out = torch.unsqueeze(text_out, 0)
        x_out = combine_joint(z_out, text_out)

        return x_out


    def train_step(_batch):
        _metrics = dict()
        optimizer.zero_grad()
        
        loss = LSimple(_batch[0], _batch[1], nnet, _schedule)  # currently only support the extracted feature version
        _metrics['loss'] = accelerator.gather(loss.detach()).mean()
        accelerator.backward(loss.mean())
        optimizer.step()
        lr_scheduler.step()
        train_state.ema_update(config.get('ema_rate', 0.9999))
        train_state.step += 1
        return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)
    

    def sample_fn(_n_samples, sample_steps):
        _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
        _t_init = torch.randn(_n_samples, *(100,19), device=device)
        _x_init = combine_joint(_z_init, _t_init)
        noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())

        def model_fn(x, t_continuous):
            t = t_continuous * _schedule.N
            return joint_nnet(_x_init, t)

        dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
        _z = dpm_solver.sample(_x_init, steps=sample_steps, eps=1. / _schedule.N, T=1.)
        prot, lig = split_joint(_z)
        return prot, lig

    
    def eval_step(n_samples, sample_steps):
        logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, algorithm=dpm_solver, '
                     f'mini_batch_size={config.sample.mini_batch_size}')

        _z, _text = sample_fn(n_samples, sample_steps)
        with tempfile.TemporaryDirectory() as temp_path:
            path = config.sample.path or temp_path
            if accelerator.is_main_process:
                os.makedirs(path, exist_ok=True)
            utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, sample_steps, dataset.unpreprocess)

            score = 0
            if accelerator.is_main_process:
                score = calculate_fid_given_paths((path))
                logging.info(f'step={train_state.step} eval{n_samples}={score}')
                with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
                    print(f'step={train_state.step} score{n_samples}={score}', file=f)
                wandb.log({f'score{n_samples}': score}, step=train_state.step)
            score = torch.tensor(score, device=device)
            #_fid = accelerator.reduce(_fid, reduction='sum')

        return score

    logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')

    step_score = []
    while train_state.step < config.train.n_steps:
        nnet.train()
        batch = tree_map(lambda x: x.to(device), next(data_generator))
        metrics = train_step(batch)

        nnet.eval()
        if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
            logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
            logging.info(config.workdir)
            wandb.log(metrics, step=train_state.step)

        if accelerator.is_main_process and train_state.step % config.train.eval_interval == 0:
            torch.cuda.empty_cache()
            #contexts = torch.tensor(dataset.contexts, device=device)[: 2 * 5]
            #samples = dpm_solver_sample(_n_samples=2 * 5, _sample_steps=50, context=contexts)
            #samples = make_grid(dataset.unpreprocess(samples), 5)
            #save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
            #wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
            torch.cuda.empty_cache()
        #accelerator.wait_for_everyone()

        if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
            torch.cuda.empty_cache()
            logging.info(f'Save and eval checkpoint {train_state.step}...')
            if accelerator.local_process_index == 0:
                train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
            accelerator.wait_for_everyone()
            score = eval_step(n_samples=5, sample_steps=5)  # calculate fid of the saved checkpoint
            step_score.append((train_state.step, score))
            torch.cuda.empty_cache()
        #accelerator.wait_for_everyone()

    logging.info(f'Finish fitting, step={train_state.step}')
    logging.info(f'step_score: {step_score}')
    step_best = sorted(step_score, key=lambda x: x[1])[0][0]
    logging.info(f'step_best: {step_best}')
    train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
    del metrics
    #accelerator.wait_for_everyone()
    eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps)





In [10]:
train(config)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


INFO:absl:autoencoder:
  pretrained_path: assets/stable-diffusion/autoencoder_kl.pth
  scale_factor: 0.2301
ckpt_root: /Users/jihyeonje/unidiffuser/test/
config_name: test
dataset:
  cfg: false
  name: ligprot_features
  p_uncond: 0.1
  path: test/feats
hparams: default
lr_scheduler:
  name: customized
  warmup_steps: 5000
mixed_precision: 'no'
nnet:
  clip_dim: 19
  depth: 12
  embed_dim: 100
  img_size: 1296
  in_chans: 3
  mlp_ratio: 4
  mlp_time_embed: false
  name: uvit_t2i
  num_clip_token: 77
  num_heads: 10
  patch_size: 4
  qkv_bias: false
optimizer:
  betas: !!python/tuple
  - 0.9
  - 0.9
  lr: 0.0002
  name: adamw
  weight_decay: 0.03
sample:
  cfg: false
  mini_batch_size: 1
  n_samples: 5
  path: /Users/jihyeonje/unidiffuser/test/res/
  sample_steps: 2
  scale: 1.0
sample_dir: /Users/jihyeonje/unidiffuser/test/
seed: 1234
train:
  batch_size: 2
  eval_interval: 5
  log_interval: 10
  n_steps: 50
  save_interval: 5
workdir: /Users/jihyeonje/unidiffuser/test/
z_shape: !!pyth

Prepare dataset...
Prepare dataset ok
attention mode is math


INFO:absl:nnet has 1752151 parameters
  self.snr = self.cum_alphas / self.cum_betas
INFO:absl:use Schedule([0.         0.00085    0.0008547  0.00085941 0.00086413 0.00086887
 0.00087362 0.00087839 0.00088316 0.00088795]..., 1000)
INFO:absl:Start fitting, step=0, mixed_precision=no


epoch:   0%|          | 0/17 [00:00<?, ?it/s]

test/feats/train/5.npy
test/feats/train/28.npy
test/feats/train/11.npy
test/feats/train/23.npy
test/feats/train/2.npy
test/feats/train/22.npy
test/feats/train/15.npy
test/feats/train/24.npy
test/feats/train/27.npy
test/feats/train/33.npy
test/feats/train/9.npy
test/feats/train/0.npy
test/feats/train/8.npy
test/feats/train/29.npy
test/feats/train/10.npy
test/feats/train/4.npy


INFO:absl:Save and eval checkpoint 5...
INFO:absl:eval_step: n_samples=5, sample_steps=5, algorithm=dpm_solver, mini_batch_size=1


torch.Size([1, 3, 1296, 4])




torch.Size([1, 3, 1296, 4])




torch.Size([1, 3, 1296, 4])




torch.Size([1, 3, 1296, 4])


sample2dir: 100%|██████████| 5/5 [00:03<00:00,  1.64it/s]

torch.Size([1, 3, 1296, 4])



INFO:absl:step=5 eval5=773.5457477517582


test/feats/train/18.npy
test/feats/train/17.npy
test/feats/train/25.npy
test/feats/train/26.npy
test/feats/train/21.npy
test/feats/train/3.npy
test/feats/train/14.npy
test/feats/train/32.npy
test/feats/train/30.npy
test/feats/train/6.npy


INFO:absl:{'step': '10', 'lr': '4e-07', 'loss': '2.09428'}
INFO:absl:/Users/jihyeonje/unidiffuser/test/
INFO:absl:Save and eval checkpoint 10...
INFO:absl:eval_step: n_samples=5, sample_steps=5, algorithm=dpm_solver, mini_batch_size=1


torch.Size([1, 3, 1296, 4])




torch.Size([1, 3, 1296, 4])




torch.Size([1, 3, 1296, 4])




torch.Size([1, 3, 1296, 4])


sample2dir: 100%|██████████| 5/5 [00:02<00:00,  1.76it/s]

torch.Size([1, 3, 1296, 4])



INFO:absl:step=10 eval5=614.08916347545


test/feats/train/12.npy
test/feats/train/20.npy
test/feats/train/13.npy
test/feats/train/1.npy
test/feats/train/34.npy
test/feats/train/31.npy
test/feats/train/16.npy
test/feats/train/7.npy


INFO:absl:Save and eval checkpoint 15...
INFO:absl:eval_step: n_samples=5, sample_steps=5, algorithm=dpm_solver, mini_batch_size=1


torch.Size([1, 3, 1296, 4])




torch.Size([1, 3, 1296, 4])




torch.Size([1, 3, 1296, 4])




torch.Size([1, 3, 1296, 4])


sample2dir: 100%|██████████| 5/5 [00:02<00:00,  1.67it/s]

torch.Size([1, 3, 1296, 4])



INFO:absl:step=15 eval5=692.6620186652698


epoch:   0%|          | 0/17 [00:00<?, ?it/s]

test/feats/train/0.npy
test/feats/train/7.npy
test/feats/train/31.npy
test/feats/train/29.npy
test/feats/train/4.npy
test/feats/train/21.npy
test/feats/train/10.npy
test/feats/train/17.npy
test/feats/train/8.npy
test/feats/train/5.npy
test/feats/train/25.npy
test/feats/train/11.npy


INFO:absl:{'step': '20', 'lr': '8e-07', 'loss': '2.05817'}
INFO:absl:/Users/jihyeonje/unidiffuser/test/
INFO:absl:Save and eval checkpoint 20...
INFO:absl:eval_step: n_samples=5, sample_steps=5, algorithm=dpm_solver, mini_batch_size=1


torch.Size([1, 3, 1296, 4])




torch.Size([1, 3, 1296, 4])




torch.Size([1, 3, 1296, 4])




torch.Size([1, 3, 1296, 4])


sample2dir: 100%|██████████| 5/5 [00:03<00:00,  1.64it/s]

torch.Size([1, 3, 1296, 4])



INFO:absl:step=20 eval5=694.740091513724


test/feats/train/16.npy
test/feats/train/30.npy
test/feats/train/28.npy
test/feats/train/6.npy
test/feats/train/2.npy
test/feats/train/3.npy
test/feats/train/27.npy
test/feats/train/22.npy
test/feats/train/26.npy
test/feats/train/18.npy


INFO:absl:Save and eval checkpoint 25...
INFO:absl:eval_step: n_samples=5, sample_steps=5, algorithm=dpm_solver, mini_batch_size=1


torch.Size([1, 3, 1296, 4])




torch.Size([1, 3, 1296, 4])




torch.Size([1, 3, 1296, 4])




torch.Size([1, 3, 1296, 4])


sample2dir: 100%|██████████| 5/5 [00:03<00:00,  1.63it/s]

torch.Size([1, 3, 1296, 4])



INFO:absl:step=25 eval5=774.7367134776658


test/feats/train/9.npy
test/feats/train/1.npy
test/feats/train/13.npy
test/feats/train/32.npy
test/feats/train/24.npy
test/feats/train/15.npy


KeyboardInterrupt: 

In [None]:
import glob
gts = '/Users/jihyeonje/unidiffuser/test/res/*_gt.npy'
preds = '/Users/jihyeonje/unidiffuser/test/res/*_pred.npy'
idx = len(glob.glob(gts))
loss = torch.nn.MSELoss()
for id in range(idx):
    target = torch.tensor(np.load(f'/Users/jihyeonje/unidiffuser/test/res/{id}_gt.npy'))
    input = torch.tensor(np.load(f'/Users/jihyeonje/unidiffuser/test/res/{id}_pred.npy'))
    ls = loss(input, target)

In [13]:
import numpy as np

In [20]:
lig = np.load(f'/Users/jihyeonje/unidiffuser/test/res/2_lig.npy')

In [14]:
prot = np.load(f'/Users/jihyeonje/unidiffuser/test/res/4_prot.npy')

In [21]:
xyz = lig[:,:3]
one_hot = lig[:,3:]
one_hot = np.argmax(one_hot, axis=1)

In [16]:
from datasets import atom_encoder, atom_decoder

In [17]:
from chem_utils import get_bond_order, allowed_bonds, draw_mol, BasicMolecularMetrics

In [22]:
one_hot

array([ 6, 11, 14,  7,  5, 10,  6,  4, 14, 11,  1,  5,  3,  6,  6,  6,  3,
       13,  4,  9,  1, 12,  6,  2,  9,  8,  5,  2,  4, 10, 12,  5,  5, 11,
       13,  8, 14,  7, 15, 12, 15,  2,  1, 15,  8, 13,  8, 15,  7,  3, 15,
        8,  4,  9,  7,  7,  3,  6,  8, 13,  8,  4,  8,  7,  7, 15,  3, 14,
       10,  2, 14, 13, 10,  5, 10,  9,  6,  3,  9,  0,  1,  1, 11, 10, 11,
       11,  7,  0,  9,  1,  0,  7,  9,  6,  8,  4,  7,  5,  6,  9])

In [19]:
one_hot

array([10, 10,  5,  2,  4, 15,  8,  5,  1,  2, 10,  9,  0,  1,  7,  0, 13,
        7, 11,  7,  7,  4,  9, 12,  3,  5,  0,  7,  2, 12, 10,  6,  5,  0,
       11,  0,  2,  8, 11,  0,  9,  9,  6, 15,  1, 14,  2,  6,  4,  5,  9,
        8,  3,  5,  7,  9,  8,  0,  2,  4, 13,  7,  4, 13,  9,  5,  6,  1,
        8, 13,  3,  5,  3,  8, 11,  9, 11,  7, 11, 14,  2, 14, 11,  7, 11,
        1,  0, 15,  9,  8,  7,  3,  4, 12, 12,  1,  6, 12, 10, 14])

In [23]:
draw_mol(xyz, one_hot)

In [16]:
metrics = BasicMolecularMetrics(atom_decoder)
rdkit_metrics = metrics.evaluate([(xyz, one_hot)])
val, unique = rdkit_metrics
stable, n_bonds, ratio = check_stability(xyz, one_hot)
lig_score = val + ratio
lig_score

1.01

In [12]:
seed = 1234 #@param {type:"number"}
steps = 50 #@param {type:"slider", min:0, max:100, step:1}
cfg_scale = 8 #@param {type:"slider", min:0, max:10, step:0.1}
n_samples = 2 #@param {type:"number"}
nrow = 2 #@param {type:"number"}
data_type = 1
output_path = '/Users/jihyeonje/unidiffuser/test'
device = 'cpu'


##sampling

In [13]:
import ml_collections
import torch
import random
import utils
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
from absl import logging
import einops
import libs.autoencoder
import libs.clip
from torchvision.utils import save_image, make_grid
import torchvision.transforms as standard_transforms
import numpy as np
import clip
from PIL import Image

In [14]:
from libs.uvit_t2i import UViT

In [15]:
nnet = UViT(
    img_size=1296,
    in_chans=3,
    patch_size=4,
    embed_dim=100,
    depth=12,
    num_heads=10,
    mlp_ratio=4,
    qkv_bias=False,
    mlp_time_embed=False,
    clip_dim=19,
    num_clip_token=77
)


nnet.to(device)
nnet.load_state_dict(torch.load('/Users/jihyeonje/unidiffuser/test/5.ckpt/nnet.pth', map_location='cpu'))
nnet.eval()

UViT(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 100, kernel_size=(4, 4), stride=(4, 4))
  )
  (time_embed): Identity()
  (context_embed): Linear(in_features=19, out_features=100, bias=True)
  (text_embed): Linear(in_features=19, out_features=100, bias=True)
  (text_out): Linear(in_features=100, out_features=19, bias=True)
  (in_blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=100, out_features=300, bias=False)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=100, out_features=100, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=100, out_features=400, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=400, out_features=100, bias=True)
        (drop): Dro

In [17]:
_betas = stable_diffusion_beta_schedule()
N = len(_betas)

In [None]:
_schedule = Schedule(_betas)

def cfg_nnet(x, timesteps, context):
    _cond = nnet_ema(x, timesteps, context=context)
    _empty_context = torch.tensor(dataset.empty_context, device=device)
    _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0))
    _uncond = nnet_ema(x, timesteps, context=_empty_context)
    return _cond + config.sample.scale * (_cond - _uncond)

def joint_nnet(x, timesteps):
    z, text = split_joint(x)
    z_out, text_out = nnet(z, t_prot=timesteps, t_lig=timesteps, context = text)
    if len(z_out.shape)==3:
        z_out = torch.unsqueeze(z_out, 0)
        text_out = torch.unsqueeze(text_out, 0)
    x_out = combine_joint(z_out, text_out)

    return x_out


def train_step(_batch):
    _metrics = dict()
    optimizer.zero_grad()
    
    loss = LSimple(_batch[0], _batch[1], nnet, _schedule)  # currently only support the extracted feature version
    _metrics['loss'] = accelerator.gather(loss.detach()).mean()
    accelerator.backward(loss.mean())
    optimizer.step()
    lr_scheduler.step()
    train_state.ema_update(config.get('ema_rate', 0.9999))
    train_state.step += 1
    return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)


def sample_fn(_n_samples, sample_steps):
    _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
    _t_init = torch.randn(_n_samples, *(100,19), device=device)
    _x_init = combine_joint(_z_init, _t_init)
    noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())

    def model_fn(x, t_continuous):
        t = t_continuous * _schedule.N
        return joint_nnet(_x_init, t)

    dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
    _z = dpm_solver.sample(_x_init, steps=sample_steps, eps=1. / _schedule.N, T=1.)
    prot, lig = split_joint(_z)
    return prot, lig





In [21]:
_z, _text = sample_fn(_n_samples=5, sample_steps=5)

In [23]:
_z.shape

torch.Size([5, 3, 1296, 4])

In [30]:
xyz = _text[3][:,:3]
one_hot = _text[3][:,3:]
one_hot = np.argmax(one_hot, axis=1)

In [35]:
one_hot = _text[3][:,3:]
np.argmax(one_hot, axis=1)

tensor([ 0, 11,  7,  8,  1, 11,  7,  6,  4, 14, 14, 10,  1,  1, 11,  8, 11,  3,
         9,  1,  6, 11,  8, 12,  1,  8,  4,  7,  9, 11,  9,  7,  5,  7, 13, 14,
         8,  1,  7, 15, 15,  4, 12,  0,  1, 15,  8,  2, 11, 10, 14,  4,  8,  2,
         0,  9,  2,  8, 13,  9,  4, 15,  7,  6,  2, 14, 15, 14, 15,  6,  2, 14,
        14,  5,  6,  9,  2, 14,  7, 12, 10, 11,  8,  7,  3,  4, 13,  3, 12,  6,
         5,  2, 15,  5,  5, 10,  9,  1,  2,  6])