In [1]:
import ml_collections
import torch
from torch import multiprocessing as mp
from datasets import get_dataset
from torchvision.utils import make_grid, save_image
import utils
import einops
#from torch.uatils._pytree import tree_map
import accelerate
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
import tempfile
from fid_score import calculate_fid_given_paths
from absl import logging
import builtins
import os
import wandb
import libs.autoencoder
import numpy as np


In [2]:
from torch.utils._pytree import tree_map


In [3]:
def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
    _betas = (
        torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
    )
    return _betas.numpy()


def get_skip(alphas, betas):
    N = len(betas) - 1
    skip_alphas = np.ones([N + 1, N + 1], dtype=betas.dtype)
    for s in range(N + 1):
        skip_alphas[s, s + 1:] = alphas[s + 1:].cumprod()
    skip_betas = np.zeros([N + 1, N + 1], dtype=betas.dtype)
    for t in range(N + 1):
        prod = betas[1: t + 1] * skip_alphas[1: t + 1, t]
        skip_betas[:t, t] = (prod[::-1].cumsum())[::-1]
    return skip_alphas, skip_betas


def stp(s, ts: torch.Tensor):  # scalar tensor product
    if isinstance(s, np.ndarray):
        s = torch.from_numpy(s).type_as(ts)
    extra_dims = (1,) * (ts.dim() - 1)
    return s.view(-1, *extra_dims) * ts


def mos(a, start_dim=1):  # mean of square
    return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)


def LSimple(x0, nnet, schedule, **kwargs):
    n, eps, xn = schedule.sample(x0)  # n in {1, ..., 1000}
    eps_pred = nnet(xn, n, **kwargs)
    return mos(eps - eps_pred)


In [4]:
import ml_collections

def d(**kwargs):
    """Helper of creating a config dict."""
    return ml_collections.ConfigDict(initial_dictionary=kwargs)
    
config = ml_collections.ConfigDict()
config.seed = 1234
config.z_shape = (4, 32, 32)
config.config_name = 'test'
config.ckpt_root = '/Users/jihyeonje/unidiffuser/test/'
config.sample_dir = '/Users/jihyeonje/unidiffuser/test/'
config.workdir = '/Users/jihyeonje/unidiffuser/test/'
config.hparams = 'default'

config.autoencoder = d(
    pretrained_path='assets/stable-diffusion/autoencoder_kl.pth',
    scale_factor=0.23010
)
config.train = d(
    n_steps=5,
    batch_size=2,
    log_interval=10,
    eval_interval=5,
    save_interval=5
)

config.optimizer = d(
    name='adamw',
    lr=0.0002,
    weight_decay=0.03,
    betas=(0.9, 0.9)
)

config.lr_scheduler = d(
    name='customized',
    warmup_steps=5000
)

config.nnet = d(
    name='uvit_t2i',
    img_size=32,
    in_chans=4,
    patch_size=2,
    embed_dim=512,
    depth=12,
    num_heads=8,
    mlp_ratio=4,
    qkv_bias=False,
    mlp_time_embed=False,
    clip_dim=768,
    num_clip_token=77
)

config.dataset = d(
    name='ligprot_features',
    path='test/feats',
    cfg=True,
    p_uncond=0.1
)

config.sample = d(
    sample_steps=2,
    n_samples=3,
    mini_batch_size=1,
    cfg=True,
    scale=1.,
    path=''
)


In [5]:

class Schedule(object):  # discrete time
    def __init__(self, _betas):
        r""" _betas[0...999] = betas[1...1000]
             for n>=1, betas[n] is the variance of q(xn|xn-1)
             for n=0,  betas[0]=0
        """

        self._betas = _betas
        self.betas = np.append(0., _betas)
        self.alphas = 1. - self.betas
        self.N = len(_betas)

        assert isinstance(self.betas, np.ndarray) and self.betas[0] == 0
        assert isinstance(self.alphas, np.ndarray) and self.alphas[0] == 1
        assert len(self.betas) == len(self.alphas)

        # skip_alphas[s, t] = alphas[s + 1: t + 1].prod()
        self.skip_alphas, self.skip_betas = get_skip(self.alphas, self.betas)
        self.cum_alphas = self.skip_alphas[0]  # cum_alphas = alphas.cumprod()
        self.cum_betas = self.skip_betas[0]
        self.snr = self.cum_alphas / self.cum_betas

    def tilde_beta(self, s, t):
        return self.skip_betas[s, t] * self.cum_betas[s] / self.cum_betas[t]

    def sample(self, x0):  # sample from q(xn|x0), where n is uniform
        n = np.random.choice(list(range(1, self.N + 1)), (len(x0),))
        eps = torch.randn_like(x0)
        xn = stp(self.cum_alphas[n] ** 0.5, x0) + stp(self.cum_betas[n] ** 0.5, eps)
        return torch.tensor(n, device=x0.device), eps, xn

    def __repr__(self):
        return f'Schedule({self.betas[:10]}..., {self.N})'


In [10]:
prompts = [
    'A green train is coming down the tracks.',
    'A group of skiers are preparing to ski down a mountain.',
    'A small kitchen with a low ceiling.',
    'A group of elephants walking in muddy water.',
    'A living area with a television and a table.',
    'A road with traffic lights, street lights and cars.',
    'A bus driving in a city area with traffic signs.',
    'A bus pulls over to the curb close to an intersection.',
    'A group of people are walking and one is holding an umbrella.',
    'A baseball player taking a swing at an incoming ball.',
    'A city street line with brick buildings and trees.',
    'A close up of a plate of broccoli and sauce.',
]

device = 'cpu'
clip = libs.clip.FrozenCLIPEmbedder()
clip.eval()
clip.to(device)

save_dir = f'test/feats/run_vis'
latent = clip.encode(prompts)
for i in range(len(latent)):
    c = latent[i].detach().cpu().numpy()
    np.save(os.path.join(save_dir, f'{i}.npy'), np.asarray((prompts[i], c), dtype='object'))


AttributeError: module 'libs' has no attribute 'clip'

In [None]:
from libs import clip
prompts = [
    '',
]

device = 'cpu'
clip = clip.FrozenCLIPEmbedder()
clip.eval()
clip.to(device)

save_dir = f'test/feats'
latent = clip.encode(prompts)
print(latent.shape)
c = latent[0].detach().cpu().numpy()
np.save(os.path.join(save_dir, f'empty_context.npy'), c)   

In [19]:

directory = '/Users/jihyeonje/Downloads/PDBBind_processed/'
ligpaths = []
protpaths = []

# iterate over files in
# that directory
for dir in os.listdir(directory):
    if dir !='.DS_Store':
        foldr = os.path.join(directory, dir)
    for i in os.listdir(foldr):
        if i.endswith('.sdf'):
            ligpaths.append(os.path.join(foldr, i))
        elif i.endswith('.pdb'):
            protpaths.append(os.path.join(foldr, i))

In [10]:
import torch
import os
import numpy as np
import libs.autoencoder
import libs.clip
from datasets import LigProtDatabase
import argparse
from tqdm import tqdm

datas = LigProtDatabase(root='assets/datasets/coco/train2014',
                             annFile='assets/datasets/coco/annotations/captions_train2014.json',
                             size=32)
save_dir = f'test/feats/test'

In [53]:
def make_train_valid_dfs(protpaths, ligtxts):
    max_id = len(protpaths)
    image_ids = np.arange(0, max_id)
    np.random.seed(42)
    valid_ids = np.random.choice(
        image_ids, size=int(0.2 * len(image_ids)), replace=False
    )
    train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
    train_prot = list(map(protpaths.__getitem__, train_ids))
    train_ligs = list(map(ligtxts.__getitem__, train_ids))
    valid_prot = list(map(protpaths.__getitem__, valid_ids))
    valid_ligs = list(map(ligtxts.__getitem__, valid_ids))
    return train_prot, train_ligs, valid_prot, valid_ligs


In [56]:
train_prot, train_lig, valid_prot, valid_lig = make_train_valid_dfs(protpaths, ligpaths)

In [64]:
datas = LigProtDatabase(protpaths = train_prot, ligpaths = train_lig)

In [65]:
len(datas)

TypeError: object of type 'LigProtDatabase' has no len()

In [63]:


class LigProtDatabase(Dataset):
    def __init__(self, protpaths, ligpaths, size=None):
        self.protpaths = protpaths
        self.ligpaths = ligpaths
        self.height = self.width = size
    
    def __len__(self):
        return len(self.protpaths)
        
    def __getitem__(self, index):

        p_feats = load_feats_from_pdb(protpaths[index])
        bb_coords = p_feats['bb_coords']
        protein = einops.rearrange(p_feats, 'h w c -> c h w')
        
        suppl = Chem.SDMolSupplier(ligpaths[index], sanitize=False)
        smi = Chem.MolToSmiles(suppl[0])
        x, atom_types, one_hot = mol_extraction(smi)
        ligand = torch.cat(x, torch.unsqueeze(one_hot,0), dim=2)
        #image = np.ones((32,32,3))
        
        #key = self.keys[index]
        #image = self._load_image(key)
        #image = np.array(image).astype(np.uint8)
        #image = center_crop(self.width, self.height, image).astype(np.float32)
        
        #image = (image / 127.5 - 1.0).astype(np.float32)
        #image = einops.rearrange(image, 'h w c -> c h w')

        #anns = self._load_target(key)
        #target = []
        #for ann in anns:
        #    target.append(ann['caption'])

        return protein, ligand


In [None]:
device = "cpu"


for i in range(10):
    x = np.ones((4,32,32))*255
    x = (x / 127.5 - 1.0).astype(np.float32)
    #if len(x.shape) == 3:
    #    x = x[None, ...]
    #x = torch.tensor(x, device=device)
    #moments = autoencoder(x, fn='encode_moments').squeeze(0)
    #moments = moments.detach().cpu().numpy()
    np.save(os.path.join(save_dir, f'{i}.npy'), x)

    #latent = clip.encode(captions)
    for j in range(3):
        c = np.empty([77,768],dtype='f')
        #c = latent[i].detach().cpu().numpy()
        np.save(os.path.join(save_dir, f'{i}_{j}.npy'), c)


In [None]:
train_prot, train_lig, valid_prot, valid_lig = make_train_valid_dfs(protpaths, ligpaths)

In [None]:
class LigProtFeatureDataset(Dataset):
    # the image features are got through sample
    def __init__(self, protpaths, ligpathss):
        self.root = root
        self.num_data, self.n_captions = len(protpaths), len(ligpaths)

    def __len__(self):
        return self.num_data

    def __getitem__(self, index):
        
        z = np.load(os.path.join(self.root, f'{index}.npy'))
        c = np.load(os.path.join(self.root, f'{index}_{k}.npy'))
        return z, c

In [None]:

class LigProtFeatures(DatasetFactory):  # the moments calculated by Stable Diffusion image encoder & the contexts calculated by clip
    def __init__(self, path, cfg=False, p_uncond=None):
        super().__init__()
        print('Prepare dataset...')
        self.train = LigProtFeatureDataset(os.path.join(path, 'train'))
        self.test = LigProtFeatureDataset(os.path.join(path, 'test'))
        #assert len(self.train) == 82783
        #assert len(self.test) == 40504
        print('Prepare dataset ok')

        self.empty_context = np.load(os.path.join(path, 'empty_context.npy'))

        if cfg:  # classifier free guidance
            assert p_uncond is not None
            print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
            self.train = CFGDataset(self.train, p_uncond, self.empty_context)

        # text embedding extracted by clip
        # for visulization in t2i
        self.prompts, self.contexts = [], []
        for f in sorted(os.listdir(os.path.join(path, 'run_vis')), key=lambda x: int(x.split('.')[0])):
            prompt, context = np.load(os.path.join(path, 'run_vis', f), allow_pickle=True)
            self.prompts.append(prompt)
            self.contexts.append(context)
        self.contexts = np.array(self.contexts)

    @property
    def data_shape(self):
        return 4, 32, 32

    @property
    def fid_stat(self):
        return f'assets/fid_stats/fid_stats_mscoco256_val.npz'

In [None]:

class LigProtFeatureDataset(Dataset):
    # the image features are got through sample
    def __init__(self, root):
        self.root = root
        self.num_data, self.n_captions = get_feature_dir_info(root)

    def __len__(self):
        return self.num_data

    def __getitem__(self, index):
        z = np.load(os.path.join(self.root, f'{index}.npy'))
        k = random.randint(0, self.n_captions[index] - 1)
        c = np.load(os.path.join(self.root, f'{index}_{k}.npy'))
        return z, c

In [11]:
device = "cpu"


for i in range(10):
    x = np.ones((4,32,32))*255
    x = (x / 127.5 - 1.0).astype(np.float32)
    #if len(x.shape) == 3:
    #    x = x[None, ...]
    #x = torch.tensor(x, device=device)
    #moments = autoencoder(x, fn='encode_moments').squeeze(0)
    #moments = moments.detach().cpu().numpy()
    np.save(os.path.join(save_dir, f'{i}.npy'), x)

    #latent = clip.encode(captions)
    for j in range(3):
        c = np.empty([77,768],dtype='f')
        #c = latent[i].detach().cpu().numpy()
        np.save(os.path.join(save_dir, f'{i}_{j}.npy'), c)


In [None]:
test_dataset.__dict__

In [None]:
test_dataset_loader.__dict__

In [None]:
dataset = get_dataset(**config.dataset)

test_dataset = dataset.get_split(split='test', labeled=True)  # for sampling
test_dataset_loader = DataLoader(test_dataset, batch_size=config.sample.mini_batch_size, shuffle=True, drop_last=True,
                                    num_workers=1, pin_memory=True, persistent_workers=True)

In [6]:

def train(config):
#    if config.get('benchmark', False):
#        torch.backends.cudnn.benchmark = True
#        torch.backends.cudnn.deterministic = False

    #mp.set_start_method('spawn')
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    logging.info(f'Process {accelerator.process_index} using device: {device}')

    config.mixed_precision = accelerator.mixed_precision
    config = ml_collections.FrozenConfigDict(config)

    assert config.train.batch_size % accelerator.num_processes == 0
    mini_batch_size = config.train.batch_size // accelerator.num_processes

    if accelerator.is_main_process:
        os.makedirs(config.ckpt_root, exist_ok=True)
        os.makedirs(config.sample_dir, exist_ok=True)
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
                   name=config.hparams, job_type='train', mode='offline')
        utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
        logging.info(config)
    else:
        utils.set_logger(log_level='error')
        builtins.print = lambda *args: None
    logging.info(f'Run on {accelerator.num_processes} devices')

    dataset = get_dataset(**config.dataset)
    #assert os.path.exists(dataset.fid_stat)
    train_dataset = dataset.get_split(split='train', labeled=True)
    train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
                                      num_workers=1, pin_memory=True, persistent_workers=True)
    test_dataset = dataset.get_split(split='test', labeled=True)  # for sampling
    test_dataset_loader = DataLoader(test_dataset, batch_size=config.sample.mini_batch_size, shuffle=True, drop_last=True,
                                     num_workers=1, pin_memory=True, persistent_workers=True)

    train_state = utils.initialize_train_state(config, device)
    nnet, nnet_ema, optimizer, train_dataset_loader, test_dataset_loader = accelerator.prepare(
        train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader, test_dataset_loader)
    lr_scheduler = train_state.lr_scheduler
    train_state.resume(config.ckpt_root)

    #autoencoder = libs.autoencoder.get_model(**config.autoencoder)
    #autoencoder.to(device)

    #@ torch.cuda.amp.autocast()
    #def encode(_batch):
    #    return autoencoder.encode(_batch)

    #@ torch.cuda.amp.autocast()
    #def decode(_batch):
    #    return autoencoder.decode(_batch)

    def get_data_generator():
        while True:
            for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
                yield data

    data_generator = get_data_generator()

    def get_context_generator():
        while True:
            for data in test_dataset_loader:
                _, _context = data
                yield _context

    context_generator = get_context_generator()

    _betas = stable_diffusion_beta_schedule()
    _schedule = Schedule(_betas)
    logging.info(f'use {_schedule}')

    def cfg_nnet(x, timesteps, context):
        _cond = nnet_ema(x, timesteps, context=context)
        _empty_context = torch.tensor(dataset.empty_context, device=device)
        _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0))
        _uncond = nnet_ema(x, timesteps, context=_empty_context)
        return _cond + config.sample.scale * (_cond - _uncond)

    def train_step(_batch):
        _metrics = dict()
        optimizer.zero_grad()
        
        _z = _batch[0]
        print(_z.shape)
        loss = LSimple(_z, nnet, _schedule, context=_batch[1])  # currently only support the extracted feature version
        _metrics['loss'] = accelerator.gather(loss.detach()).mean()
        accelerator.backward(loss.mean())
        optimizer.step()
        lr_scheduler.step()
        train_state.ema_update(config.get('ema_rate', 0.9999))
        train_state.step += 1
        return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)

    def dpm_solver_sample(_n_samples, _sample_steps, **kwargs):
        _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
        noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())

        def model_fn(x, t_continuous):
            t = t_continuous * _schedule.N
            return cfg_nnet(x, t, **kwargs)

        dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
        _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / _schedule.N, T=1.)
        return _z

    def eval_step(n_samples, sample_steps):
        logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, algorithm=dpm_solver, '
                     f'mini_batch_size={config.sample.mini_batch_size}')

        def sample_fn(_n_samples):
            _context = next(context_generator)
            assert _context.size(0) == _n_samples
            return dpm_solver_sample(_n_samples, sample_steps, context=_context)

        with tempfile.TemporaryDirectory() as temp_path:
            path = config.sample.path or temp_path
            if accelerator.is_main_process:
                os.makedirs(path, exist_ok=True)
            utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)

            _fid = 0
            if accelerator.is_main_process:
                _fid = calculate_fid_given_paths((dataset.fid_stat, path))
                logging.info(f'step={train_state.step} fid{n_samples}={_fid}')
                with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
                    print(f'step={train_state.step} fid{n_samples}={_fid}', file=f)
                wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
            _fid = torch.tensor(_fid, device=device)
            _fid = accelerator.reduce(_fid, reduction='sum')

        return _fid.item()

    logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')

    step_fid = []
    while train_state.step < config.train.n_steps:
        nnet.train()
        batch = tree_map(lambda x: x.to(device), next(data_generator))
        metrics = train_step(batch)

        nnet.eval()
        if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
            logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
            logging.info(config.workdir)
            wandb.log(metrics, step=train_state.step)

        if accelerator.is_main_process and train_state.step % config.train.eval_interval == 0:
            torch.cuda.empty_cache()
            logging.info('Save a grid of images...')
            #contexts = torch.tensor(dataset.contexts, device=device)[: 2 * 5]
            #samples = dpm_solver_sample(_n_samples=2 * 5, _sample_steps=50, context=contexts)
            #samples = make_grid(dataset.unpreprocess(samples), 5)
            #save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
            #wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

        if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
            torch.cuda.empty_cache()
            logging.info(f'Save and eval checkpoint {train_state.step}...')
            if accelerator.local_process_index == 0:
                train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
            accelerator.wait_for_everyone()
            fid = eval_step(n_samples=1, sample_steps=1)  # calculate fid of the saved checkpoint
            step_fid.append((train_state.step, fid))
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

    logging.info(f'Finish fitting, step={train_state.step}')
    logging.info(f'step_fid: {step_fid}')
    step_best = sorted(step_fid, key=lambda x: x[1])
    print(step_best)
    logging.info(f'step_best: {step_best}')
    ##note
    train_state.load(os.path.join(config.ckpt_root, f'5.ckpt'))
    accelerator.wait_for_everyone()
    #eval_step(n_samples=10000, sample_steps=50)



In [7]:
train(config)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


INFO:absl:autoencoder:
  pretrained_path: assets/stable-diffusion/autoencoder_kl.pth
  scale_factor: 0.2301
ckpt_root: /Users/jihyeonje/unidiffuser/test/
config_name: test
dataset:
  cfg: true
  name: ligprot_features
  p_uncond: 0.1
  path: test/feats
hparams: default
lr_scheduler:
  name: customized
  warmup_steps: 5000
mixed_precision: 'no'
nnet:
  clip_dim: 768
  depth: 12
  embed_dim: 512
  img_size: 32
  in_chans: 4
  mlp_ratio: 4
  mlp_time_embed: false
  name: uvit_t2i
  num_clip_token: 77
  num_heads: 8
  patch_size: 2
  qkv_bias: false
optimizer:
  betas: !!python/tuple
  - 0.9
  - 0.9
  lr: 0.0002
  name: adamw
  weight_decay: 0.03
sample:
  cfg: true
  mini_batch_size: 1
  n_samples: 3
  path: ''
  sample_steps: 2
  scale: 1.0
sample_dir: /Users/jihyeonje/unidiffuser/test/
seed: 1234
train:
  batch_size: 2
  eval_interval: 5
  log_interval: 10
  n_steps: 5
  save_interval: 5
workdir: /Users/jihyeonje/unidiffuser/test/
z_shape: !!python/tuple
- 4
- 32
- 32

INFO:absl:Run on 

Prepare dataset...
Prepare dataset ok
prepare the dataset for classifier free guidance with p_uncond=0.1
attention mode is math


INFO:absl:nnet has 44692644 parameters
  self.snr = self.cum_alphas / self.cum_betas
INFO:absl:use Schedule([0.         0.00085    0.0008547  0.00085941 0.00086413 0.00086887
 0.00087362 0.00087839 0.00088316 0.00088795]..., 1000)
INFO:absl:Start fitting, step=0, mixed_precision=no


epoch:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([2, 4, 32, 32])
orig x
torch.Size([2, 4, 32, 32])
patch
torch.Size([2, 256, 512])
context
torch.Size([2, 77, 768])
torch.Size([2, 77, 512])


KeyboardInterrupt: 

In [None]:
def load_coords_from_pdb(
    pdb,
    atoms,
    method="raw",
    also_bfactors=False,
    normalize_bfactors=True,
):
    """
    Returns array of shape (1, n_res, len(atoms), 3)
    """

    coords = []
    bfactors = []
    if method == "raw":  # Raw numpy implementation, faster than biopdb
        # Indexing into PDB format, allowing XXXX.XXX
        coords_in_pdb = [slice(30, 38), slice(38, 46), slice(46, 54)]
        # Indexing into PDB format, allowing XXX.XX
        bfactor_in_pdb = slice(60, 66)

        with open(pdb, "r") as f:
            resi_prev = 1
            counter = 0
            for l in f:
                l_split = l.rstrip("\n").split()
                if len(l_split) > 0 and l_split[0] == "ATOM" and l_split[2] in atoms:
                    resi = l_split[5]
                    if resi == resi_prev:
                        counter += 1
                    else:
                        counter = 0
                    if counter < len(atoms):
                        xyz = [
                            np.array(l[s].strip()).astype(float) for s in coords_in_pdb
                        ]
                        coords.append(xyz)
                        if also_bfactors:
                            bfactor = np.array(l[bfactor_in_pdb].strip()).astype(float)
                            bfactors.append(bfactor)
                    resi_prev = resi
            coords = torch.Tensor(np.array(coords)).view(1, -1, len(atoms), 3)

    return coords

In [24]:
import protein
import residue_constants

In [26]:
def load_feats_from_pdb(pdb, bb_atoms=["N", "CA", "C", "O"]):
    """
    Load model input features from a PDB file or mmcif file.
    - bb_atoms: list of backbone atom names to load
    - load_atom73: if True, also load atom73 features
    - chain_residx_gap: residue index gap for chain breaks for PDBs with multiple chains
    """
    feats = {}
    protein_obj = protein.read_pdb(pdb)
    bb_idxs = [residue_constants.atom_order[a] for a in bb_atoms]
    bb_coords = torch.from_numpy(protein_obj.atom_positions[:, bb_idxs])
    feats["bb_coords"] = bb_coords.float()
    for k, v in vars(protein_obj).items():
        feats[k] = torch.Tensor(v)
    feats["aatype"] = feats["aatype"].long()
    #if load_atom73:
    #    feats["atom73_coords"], feats["atom73_mask"] = atom37_to_atom73(
    #        feats["atom_positions"], feats["aatype"], return_mask=True
    #    )

    # Handle residue index for PDBs with multiple chains    
    #feats["residue_index"] = add_chain_gap(feats["residue_index"], feats["chain_index"], chain_residx_gap=chain_residx_gap)

    return feats


In [36]:
feats = load_feats_from_pdb(protpaths[1])

In [37]:
feats['atom_positions'].shape

torch.Size([830, 37, 3])

In [43]:
len(feats['aatype'])

830

In [38]:
feats['bb_coords']

torch.Size([830, 4, 3])

In [11]:
from rdkit.Chem import QED
from rdkit.Chem import AllChem
from rdkit import Chem
def mol_extraction(smi):
    atom_encoder = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
    atom_decoder = ['H', 'C', 'N', 'O', 'F'] 
    mol = Chem.MolFromSmiles(smi)

    #start processing 3D coordinates
    updated_mol= Chem.AddHs(mol)
    AllChem.EmbedMolecule(updated_mol)
    AllChem.UFFOptimizeMolecule(updated_mol)
    updated_mol.GetConformer()
    #to get atom type, coordinate, edges
    mblock = Chem.MolToMolBlock(updated_mol)
    parsed = mblock.split("\n")
    stats = parsed[3]
    tot_atoms = int(stats.split(" ")[1])
    one_hot = torch.zeros(tot_atoms, len(atom_decoder))

    coors = parsed[4:tot_atoms+4]
    allcoors = []
    allatoms = []
    
    for j in range(tot_atoms):
        atom_type = coors[j][:-37][-1]
        xyz_coor = [float(x) for x in coors[j][:-37][:-2].split()]
        allcoors.append(xyz_coor)
        atom = atom_encoder[atom_type]
        allatoms.append(atom)
        one_hot[j, atom_encoder[atom_type]] = 1
        
    return allcoors, allatoms, one_hot

In [12]:
def remove_mean(x):
    mean = torch.mean(x, dim=1, keepdim=True)
    x = x - mean
    return x

In [13]:
import torch.nn.functional as F 
from tqdm import tqdm
import numpy as np

smi = "CC1=NCCC(C)O1"
#smi = "[H]C(=O)N([H])C([H])(C([H])([H])[H])C([H])([H])OC([H])([H])C([H])([H])[H]"
id2charge = {'H': 1, 'C': 6, 'N': 7, 'O': 8, 'F': 9}
new_data = {}
x, atom_types, one_hot = mol_extraction(smi)
#at_types = torch.from_numpy(np.asarray(atom_types))[:, None]
# one_hot = torch.eq(at_types, torch.Tensor(atomic_number_list)).int()
#one_hot = F.one_hot(
#                torch.from_numpy(np.asarray(atom_types)),
##                num_classes=5)
#one_hot = one_hot[None,:]
#x = remove_mean(torch.tensor([x]))

n = x.shape[1]

new_data['positions'] = x
new_data['one_hot'] = torch.unsqueeze(one_hot, 0)

c = torch.zeros(1,len(atom_types),1)
h = {'categorical': torch.unsqueeze(one_hot, 0), 'integer': c}

new_data['atom_mask'] = torch.ones((1,n), device='cpu').unsqueeze(2)

edge_mask = torch.ones((n, n), device='cpu')
edge_mask[~torch.eye(edge_mask.shape[0], dtype=torch.bool)] = 0
new_data['edge_mask'] = edge_mask.flatten()

In [14]:
xh = torch.cat([new_data['positions'], h['categorical'], h['integer']], dim=2)


In [15]:
xh.shape

torch.Size([1, 19, 9])

In [8]:
import torch.nn as nn


In [9]:
t = nn.Linear(9,512)

In [18]:
np.ones((32,32,3))

array([[[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        ...,
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]],

       [[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        ...,
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]],

       [[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        ...,
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]],

       ...,

       [[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        ...,
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]],

       [[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        ...,
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]],

       [[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        ...,
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]]])