In [52]:
from gudiff_model import PDBDataSet_GraphCon
from gudiff_model.Graph_UNet import GraphUNet
from data_rigid_diffuser.diffuser import FrameDiffNoise
from se3_transformer.model.fiber import Fiber
import torch
import os
import logging
from datetime import datetime
from collections import defaultdict
import time
import tree
from se3_transformer.model.FAPE_Loss import FAPE_loss, Qs2Rs, normQ
from torch import einsum
import numpy as np
import se3_diffuse.utils as du
import copy
import util.pdb_writer 

In [53]:
conf = {'batch_size'  : 16,
              'topk'  : 4,
            'stride'  : 4,
                'KNN' : 30,
          'num_heads' : 8,
           'channels' : 32,
       'channels_div' : 4,
        'nodefeats_0': 32,
        'nodefeats_1':  6,
         'num_layers' : 1,
     'num_layers_ca'  : 2,
   'edge_feature_dim' : 1,
  'latent_pool_type'  : 'avg',
            't_size'  : 12,
             'max_t'  : 0.2,
               'mult' : 2,
           'zero_lin' : True,
          'use_tdeg1' : True,
                'cuda': True,
      'learning rate' : 0.0005,
       'weight_decay' :  5e-6,
        'device'      : 'cuda',
        'num_epoch'   : 100,
        'log_freq'    : 1000,
        'ckpt_freq'   : 10000,
        'early_chkpt' : 2,
        'coord_scale' : 10.0,
        'dataset_max' : 5000,
        'meta_data_path' : '/mnt/h/datasets/bCov_4H/metadata.csv',
        'sample_mode' : 'single_length',
           'ckpt_dir' : 'GUN_checkpoints/',
              'eval_dir' : 'Eval_Direc/',
       }

#check use_tdeg1

In [54]:
def optimizer_to(optim, device):
    for param in optim.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)

In [72]:
class Experiment:

    def __init__(self,
                 conf,
                 ckpt_model=None,
                 cur_step=None,
                 cur_epoch=None,
                 name='gu_null',
                 cast_type=torch.float32,
                 ckpt_opt=None):
        """Initialize experiment.
        Args:
            exp_cfg: Experiment configuration.
        """
#         with open(config_path, 'r') as file:
#             config = yaml.safe_load(file)
#         conf = Struct(config)
        #figure out logging
        logging.basicConfig(filename='test.log', level=logging.INFO)
        self._log = logging.getLogger(__name__)
        

        self.name=name
        self._conf = conf
        if conf['cuda']:
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        
        self.coord_scale = conf['coord_scale']
        self.N_CA_dist = (PDBDataSet_GraphCon.N_CA_dist/self.coord_scale).to(self.device)
        self.C_CA_dist = (PDBDataSet_GraphCon.C_CA_dist/self.coord_scale).to(self.device)
        self.cast_type = cast_type
        
        self.num_epoch = conf['num_epoch']
        self.log_freq = conf['log_freq']
        self.ckpt_freq = conf['ckpt_freq']
        self.early_ckpt = conf['early_chkpt']
        
        
        self.meta_data_path = conf['meta_data_path']
        self.sample_mode = conf['sample_mode']
        self.B = conf['batch_size']
        self.limit = conf['dataset_max']
        
        #graph properties
        self.KNN = conf['KNN']
        self.stride = conf['stride']
        
        #gudiff params
        self.channels_start = conf['channels']
        
        
        self._diffuser = FrameDiffNoise()
        self._graphmaker =  PDBDataSet_GraphCon.Make_KNN_MP_Graphs(mp_stride = self.stride, 
                                                           coord_div = self.coord_scale, 
                                                           cast_type = self.cast_type, 
                                                           channels_start = self.channels_start,
                                                           ndf1= conf['nodefeats_1'], 
                                                           ndf0= conf['nodefeats_0'],
                                                           cuda=conf['cuda']) #cuda is bool True, mod at some point
        #single_t dataset, for testing
        # sd = smallPDBDataset(fdn , meta_data_path = '/mnt/h/datasets/bCov_4H/metadata.csv', 
        #                      filter_dict=False, maxlen=1000, input_t=0.05)
        


        
        self._model = GraphUNet(fiber_start = Fiber({0:12, 1:2}),
                                fiber_out = Fiber({1:2}),
                                batch_size = self.B, 
                                num_layers_ca = conf['num_layers_ca'],
                                k = conf['topk'],
                                stride = conf['stride'],
                                max_degree = 3,
                                channels_div =  conf['channels_div'],
                                num_heads = conf['num_heads'],
                                num_layers = conf['num_layers'],
                                edge_feature_dim = conf['edge_feature_dim'],
                                latent_pool_type = conf['latent_pool_type'],
                                t_size = conf['t_size'],
                                zero_lin = conf['zero_lin'],
                                use_tdeg1 = conf['use_tdeg1'],
                                cuda = conf['cuda']).to(self.device) #cuda is bool True, mod at some point

        

        
        num_parameters = sum(p.numel() for p in self._model.parameters())
        self.num_parameters = num_parameters
        self._log.info(f'Number of model parameters {num_parameters}')
#         self._optimizer = EMA(0.980)
#         for name, param in self._model.named_parameters():
#             if param.requires_grad:
#                 self._optimizer.register(name, param.data)

        if ckpt_model is not None:
            ckpt_model = {k.replace('module.', ''):v for k,v in ckpt_model.items()}
            self._model.load_state_dict(ckpt_model, strict=True)
        
        
        self._optimizer = torch.optim.Adam( self._model.parameters(),
                                                       lr=conf['learning rate'],
                                                       weight_decay=conf['weight_decay'])
        if ckpt_opt is not None:
            self._optimizer.load_state_dict(ckpt_opt)
            optimizer_to(self._optimizer, self.device)
        
        
        dt_string = datetime.now().strftime("%dD_%mM_%YY_%Hh_%Mm_%Ss")
        dt_string_short = datetime.now().strftime("%dD_%mM_%YY")
        self.ckpt_dir =  conf['ckpt_dir']
        self.eval_dir = conf['eval_dir']
        eval_name = f'{self.name}_{dt_string_short}'
        if self.ckpt_dir is not None:
            # Set-up checkpoint location
            ckpt_dir = os.path.join(
                 self.ckpt_dir,
                 self.name,
                 dt_string)
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir, exist_ok=True)
            self.ckpt_dir = ckpt_dir
            self._log.info(f'Checkpoints saved to: {ckpt_dir}')
        else:  
            self._log.info('Checkpoint not being saved.')
            
        if self.eval_dir is not None :
            self.eval_dir = os.path.join(
                self.eval_dir,
                eval_name,
                dt_string)
            self.eval_dir = self.eval_dir
            self._log.info(f'Evaluation saved to: {self.eval_dir}')
        else:
            self.eval_dir = os.devnull
            self._log.info(f'Evaluation will not be saved.')
    #         self._aux_data_history = deque(maxlen=100)
    
        if cur_epoch is None:
            self.trained_epochs = 0
        else:
            self.trained_epochs = cur_epoch
            
        if cur_step is None:
            self.trained_steps = 0
        else:
            self.trained_steps = cur_step
            
    @property
    def diffuser(self):
        return self._diffuser

    @property
    def model(self):
        return self._model

    @property
    def conf(self):
        return self._conf
    
    def create_dataset(self, fake_valid=True):
        
        
        self.dataset = PDBDataSet_GraphCon.smallPDBDataset( self._diffuser , meta_data_path = self.meta_data_path, 
                             filter_dict=False, maxlen=self.limit)
        
        self.train_sample = PDBDataSet_GraphCon.TrainSampler(self.B, self.dataset, sample_mode='single_length')
        
        train_dL = torch.utils.data.DataLoader(self.dataset, sampler=self.train_sample,
                                                     batch_size=self.B, shuffle=False, collate_fn=None)
        
        if fake_valid:
            valid_dL = train_dL
        else:
            valid_dL = train_dL
            #not implemented yet
        
        return train_dL, valid_dL
    #unchecked
    def start_training(self, return_logs=False):


        self._model = self._model.to(self.device)
        print(f"Using device: {self.device}")

        self._model.train()
        (train_loader, valid_loader) = self.create_dataset()

        logs = []
        print('number of epochs', self.num_epoch)
        for epoch in range(self.trained_epochs, self.num_epoch+self.trained_epochs):
            print('epoch', epoch)
            print('mem_used',torch.cuda.memory_allocated('cuda:0'))
            epoch_log = self.train_epoch(
                train_loader,
                valid_loader,
                epoch=epoch,
                return_logs=return_logs
            )
            if return_logs:
                logs.append(epoch_log)

        self._log.info('Done')
        return logs
    #unchecked
    def train_epoch(self, train_loader, valid_loader,epoch=0, return_logs=False):
        
        log_lossses = defaultdict(list)
    
        global_logs = []
        log_time = time.time()
        step_time = time.time()
        losskeeper = []
        for train_feats in train_loader:
            
            #train_feats = tree.map_structure(lambda x: x.to(device), train_feats)
            loss, aux_data = self.update_fn(train_feats)
#             for k,v in aux_data.items():
#                 log_lossses[k].append(np.array(v))
            log_lossses['loss'].append(loss.to('cpu').numpy())
            
            self.trained_steps += 1

            
            
            # Logging to terminal
            if self.trained_steps == 1 or self.trained_steps % self.log_freq == 0:
                elapsed_time = time.time() - log_time
                log_time = time.time()
                step_per_sec = self.log_freq / elapsed_time
                rolling_losses = tree.map_structure(np.mean, log_lossses)
                loss_log = ' '.join([
                    f'{k}={v[0]:.4f}'
                    for k,v in rolling_losses.items() if 'batch' not in k
                ])
                
                self._log.info(
                    f'[{self.trained_steps}]: {loss_log}, steps/sec={step_per_sec:.5f}')
                log_lossses = defaultdict(list)
                
                print(f'[{self.trained_steps}]: {loss_log}, steps/sec={step_per_sec:.5f}')
                print(np.mean(losskeeper[-1000:]))

            # Take checkpoint
            
            if self.ckpt_dir is not None and (
                    (self.trained_steps % self.ckpt_freq) == 0
                    or (self.early_ckpt and self.trained_steps == 2)
                ):
                ckpt_path = os.path.join(
                    self.ckpt_dir, f'step_{self.trained_steps}.pth')
                du.write_checkpoint(
                    ckpt_path,
                    copy.deepcopy(self.model.state_dict()),
                    self._conf,
                    copy.deepcopy(self._optimizer.state_dict()),
                    self.trained_epochs,
                    self.trained_steps,
                    logger=self._log,
                    use_torch=True
                )
                

                # Run evaluation
                self._log.info(f'Running evaluation of {ckpt_path}')
                start_time = time.time()
                eval_dir = os.path.join(self.eval_dir, f'step_{self.trained_steps}')
                print('eval',eval_dir)
                os.makedirs(eval_dir, exist_ok=True)
                ckpt_metrics = self.eval_fn(valid_loader,eval_dir,epoch=epoch)
                eval_time = time.time() - start_time
                self._log.info(f'Finished evaluation in {eval_time:.2f}s')
            else:
                ckpt_metrics = None
                eval_time = None


            if torch.isnan(loss):                
                raise Exception(f'NaN encountered')
                
    def update_fn(self, data):
        """Updates the state using some data and returns metrics."""
        self._optimizer.zero_grad()
        
        batch_feats= tree.map_structure(
                        lambda x: x.to(self.device), data)
        noised_dict =   {'CA': batch_feats['CA_noised'] ,
                         'N_CA': batch_feats['N_CA_noised'].unsqueeze(-2) ,
                         'C_CA': batch_feats['C_CA_noised'].unsqueeze(-2)  }
        
        
        loss, aux_data = self.loss_fn(batch_feats, noised_dict)
        loss.backward()
        self._optimizer.step()
        loss_out = loss.detach().cpu()
        return loss_out , aux_data
    
    
    def generate_tbatch(self, index_in, input_t):
        batch_list = []
        for i,t in enumerate(input_t):
            batch_list.append(self.dataset.get_specific_t(index_in[i], input_t[i]))

        batch_feats = {}
        for k in batch_list[0].keys():
            batch_feats[k] = torch.stack([batch_list[i][k] for i in range(len(batch_list))])
            
        return batch_feats
    
    def eval_model(self, batch_feats, noised_dict, t_val=None):
    
        L = batch_feats['CA'].shape[1]
        B = batch_feats['CA'].shape[0]
        CA_t  = batch_feats['CA']
        NC_t = CA_t +  batch_feats['N_CA']
        CC_t = CA_t +  batch_feats['C_CA']
        true =  torch.cat((NC_t,CA_t,CC_t),dim=2).reshape(B,L,3,3)

        CA_n  = batch_feats['CA_noised'].reshape(B, L, 3)
        NC_n = CA_n + batch_feats['N_CA_noised'].reshape(B, L, 3)
        CC_n = CA_n + batch_feats['C_CA_noised'].reshape(B, L, 3)
        noise_xyz =  torch.cat((NC_n,CA_n,CC_n),dim=2).reshape(B,L,3,3)

        x = self._graphmaker.prep_for_network(noised_dict)
        
        with torch.no_grad():
            out = self._model(x, batch_feats['t'])
            CA_p = out['1'][:,0,:].reshape(B, L, 3) + CA_n #translation of Calpha
            Qs = out['1'][:,1,:] # rotation of frame
            Qs = Qs.unsqueeze(1).repeat((1,2,1))
            Qs = torch.cat((torch.ones((B*L,2,1),device=Qs.device),Qs),dim=-1).reshape(B,L,2,4)
            Qs = normQ(Qs)
            Rs = Qs2Rs(Qs)
            N_C_to_Rot = torch.cat((noised_dict['N_CA'].reshape(B, L, 3),
                                    noised_dict['C_CA'].reshape(B, L, 3)),dim=2).reshape(B,L,2,1,3)

            rot_vecs = einsum('bnkij,bnkhj->bnki',Rs, N_C_to_Rot)
            NC_p = CA_p + rot_vecs[:,:,0,:]*self.N_CA_dist 
            CC_p = CA_p + rot_vecs[:,:,1,:]*self.C_CA_dist 

            pred = torch.cat((NC_p,CA_p,CC_p),dim=2).reshape(B,L,3,3)

            tloss, loss = FAPE_loss(pred.unsqueeze(0), true, batch_feats['score_scale'])
            
            eval_dict = {'true'  : true.to('cpu').numpy()*self.coord_scale,
                    'noise' : noise_xyz.to('cpu').numpy()*self.coord_scale,
                    'pred'  : pred.to('cpu').numpy()*self.coord_scale,
                    'loss'  : tloss.to('cpu').numpy()}
            
        return eval_dict
    
    def eval_fn(self, valid_loader, eval_dir, epoch=0, input_t=None, max_cycles=10):
        
        train_feats = next(iter(valid_loader))

        if input_t is None:
            #visualize_T
            vis_t = np.array([0.01,0.05,0.1,0.2,0.3,0.5,0.8,1.0])
            vis_t = vis_t[None,...].repeat(int(np.ceil(self.B/len(vis_t))),axis=0).flatten()[:self.B]
        elif type(input_t) == float:
            vis_t = np.ones((self.B,))*input_t
        else:
            vis_t = input_t
            
            
        index_in = np.random.choice(np.arange(len(self.dataset)), size=len(vis_t))
        batch_feats = self.generate_tbatch( index_in,vis_t)

        batch_feats= tree.map_structure(
                        lambda x: x.to(self.device), batch_feats)
        noised_dict =   {'CA': batch_feats['CA_noised'] ,
                         'N_CA': batch_feats['N_CA_noised'].unsqueeze(-2) ,
                         'C_CA': batch_feats['C_CA_noised'].unsqueeze(-2)  }


        eval_dict = self.eval_model(batch_feats,noised_dict)
        util.pdb_writer.dump_tnp(eval_dict['true'], 
                                      eval_dict['noise'], 
                                      eval_dict['pred'], vis_t, e=epoch, 
                                      numOut=len(vis_t), outdir=eval_dir)
        losskeeper = []
        eval_steps = 0


        for i,train_feats in enumerate(valid_loader):
            
            batch_feats= tree.map_structure(
                lambda x: x.to(self.device),train_feats)
            noised_dict =   {'CA': batch_feats['CA_noised'] ,
                             'N_CA': batch_feats['N_CA_noised'].unsqueeze(-2) ,
                             'C_CA': batch_feats['C_CA_noised'].unsqueeze(-2)  }

            eval_dict = self.eval_model(batch_feats, noised_dict)
            eval_steps += 1
            losskeeper.append(eval_dict['loss'])   

            if i>max_cycles:
                break
        print('eval_loss',np.mean(losskeeper[-1000:]),len(losskeeper))

    
    def loss_fn(self, batch_feats, noised_dict, t_val=None):
        
        L = batch_feats['CA'].shape[1]
        B = batch_feats['CA'].shape[0]
        CA_t  = batch_feats['CA']
        NC_t = CA_t +  batch_feats['N_CA']
        CC_t = CA_t +  batch_feats['C_CA']
        true =  torch.cat((NC_t,CA_t,CC_t),dim=2).reshape(B,L,3,3)

        CA_n  = batch_feats['CA_noised'].reshape(B, L, 3)
        NC_n = CA_n + batch_feats['N_CA_noised'].reshape(B, L, 3)
        CC_n = CA_n + batch_feats['C_CA_noised'].reshape(B, L, 3)
        noise_xyz =  torch.cat((NC_n,CA_n,CC_n),dim=2).reshape(B,L,3,3)

        x = self._graphmaker.prep_for_network(noised_dict)
        out = self._model(x, batch_feats['t'])
        CA_p = out['1'][:,0,:].reshape(B, L, 3) + CA_n #translation of Calpha
        Qs = out['1'][:,1,:] # rotation of frame
        Qs = Qs.unsqueeze(1).repeat((1,2,1))
        Qs = torch.cat((torch.ones((B*L,2,1),device=Qs.device),Qs),dim=-1).reshape(B,L,2,4)
        Qs = normQ(Qs)
        Rs = Qs2Rs(Qs)
        N_C_to_Rot = torch.cat((noised_dict['N_CA'].reshape(B, L, 3),
                                noised_dict['C_CA'].reshape(B, L, 3)),dim=2).reshape(B,L,2,1,3)

        rot_vecs = einsum('bnkij,bnkhj->bnki',Rs, N_C_to_Rot)
        NC_p = CA_p + rot_vecs[:,:,0,:]*self.N_CA_dist 
        CC_p = CA_p + rot_vecs[:,:,1,:]*self.C_CA_dist 

        pred = torch.cat((NC_p,CA_p,CC_p),dim=2).reshape(B,L,3,3)

        tloss, loss = FAPE_loss(pred.unsqueeze(0), true, batch_feats['score_scale'])

        return tloss, loss #final_loss, aux_loss

In [73]:
exp = Experiment(conf)

In [74]:
tl, vl = exp.create_dataset()

In [75]:
exp.train_epoch(tl, vl,epoch=0, return_logs=False)

[1]: loss=0.2667, steps/sec=638.73152
nan


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


eval Eval_Direc/gu_null_09D_08M_2024Y/09D_08M_2024Y_00h_53m_42s/step_2




eval_loss 0.24677116 12


KeyboardInterrupt: 

In [None]:
def eval_fn(self, valid_loader, eval_dir, epoch=0, input_t=None, max_cycles=10):
        
    train_feats = next(iter(valid_loader))

    if input_t is None:
        #visualize_T
        vis_t = np.array([0.01,0.05,0.1,0.2,0.3,0.5,0.8,1.0])
        vis_t = vis_t[None,...].repeat(int(np.ceil(self.B/len(vis_t))),axis=0).flatten()[:self.B]
    elif type(input_t) == float:
        vis_t = np.ones((self.B,))*input_t
    else:
        vis_t = input_t

    noised_dict = self.fnd.forward(train_feats, t_vec=vis_t)
    batched_t = to_cuda( noised_dict['t_vec'] )

    eval_dict = self.eval_model(noised_dict, batched_t)
    util.pdb_writer.dump_tnp(eval_dict['true'], 
                                  eval_dict['noise_xyz'], 
                                  eval_dict['pred'], vis_t, e=epoch, 
                                  numOut=len(vis_t), outdir=eval_dir)
    losskeeper = []
    eval_steps = 0


    for i,train_feats in enumerate(valid_loader):
        noised_dict = self.fnd.forward(train_feats)
        batched_t = to_cuda( noised_dict['t_vec'] )
        eval_dict = self.eval_model(noised_dict, batched_t)
        eval_steps += 1
        losskeeper.append(eval_dict['loss'])   

        if i>max_cycles:
            break

    print(f'[{eval_steps}]: {loss_log}')
    print('eval_loss',np.mean(losskeeper[-1000:]),len(losskeeper))
        

In [None]:
            eval_dict = {'true'  : true.to('cpu').numpy()*self.coords_scale,
                    'noise' : noise_xyz.to('cpu').numpy()*self.coords_scale,
                    'pred'  : pred.to('cpu').numpy()*self.coords_scale,
                    'loss'  : tloss.to('cpu').numpy()}

In [None]:
    def eval_model(self, noised_dict, batched_t):
        
        def convert_pV_to_points(dict_in, L, key_in='bb_firstp', return_indiv=False):
            """Concatenates to xyz from Calpha+atom vectors"""
            CA_fp  = dict_in[key_in]['CA'].reshape(self.B, L, 3).to(self.device)
            NC_fp = CA_fp + dict_in[key_in]['N_CA'].reshape(self.B, L, 3).to(self.device)
            CC_fp = CA_fp + dict_in[key_in]['C_CA'].reshape(self.B, L, 3).to(self.device)
            fp =  torch.cat((NC_fp,CA_fp,CC_fp),dim=2).reshape(self.B,L,3,3)
            if return_indiv:
                return fp, CA_fp, NC_fp, CC_fp
            return fp
    
        CA_t  = noised_dict['bb_shifted']['CA'].reshape(self.B, self.L, 3).to(self.device)
        NC_t = CA_t + noised_dict['bb_shifted']['N_CA'].reshape(self.B, self.L, 3).to(self.device)*self.N_CA_dist
        CC_t = CA_t + noised_dict['bb_shifted']['C_CA'].reshape(self.B, self.L, 3).to(self.device)*self.C_CA_dist
        true =  torch.cat((NC_t,CA_t,CC_t),dim=2).reshape(self.B,self.L,3,3)

        CA_n  = noised_dict['bb_noised']['CA'].reshape(self.B, self.L, 3).to(self.device)
        NC_n = CA_n + noised_dict['bb_noised']['N_CA'].reshape(self.B, self.L, 3).to(self.device)*self.N_CA_dist
        CC_n = CA_n + noised_dict['bb_noised']['C_CA'].reshape(self.B, self.L, 3).to(self.device)*self.C_CA_dist
        noise_xyz =  torch.cat((NC_n,CA_n,CC_n),dim=2).reshape(self.B, self.L,3,3)
        
        feat_dict = self.mkg.prep_for_network(noised_dict) #prepares graphs
        with torch.no_grad():
            out = self._model(feat_dict, batched_t)
            CA_p = out['1'][:,0,:].reshape(self.B, self.L, 3) + CA_n #translation of Calpha
            Qs = out['1'][:,1,:] # rotation
            Qs = Qs.unsqueeze(1).repeat((1,2,1))
            Qs = torch.cat((torch.ones((self.B*self.L,2,1),device=Qs.device),Qs),dim=-1).reshape(self.B,self.L,2,4)
            Qs = normQ(Qs)
            Rs = Qs2Rs(Qs)
            N_C_to_Rot = torch.cat((noised_dict['bb_noised']['N_CA'].reshape(self.B, self.L, 3).to(self.device),
                                    noised_dict['bb_noised']['C_CA'].reshape(self.B, self.L, 3).to(self.device)),
                                   dim=2).reshape(self.B,self.L,2,1,3)


            rot_vecs = einsum('bnkij,bnkhj->bnki',Rs, N_C_to_Rot)
            NC_p = CA_p + rot_vecs[:,:,0,:].to(self.device)*self.N_CA_dist
            CC_p = CA_p + rot_vecs[:,:,1,:].reshape(self.B, self.L, 3).to(self.device)*self.C_CA_dist

            pred = torch.cat((NC_p,CA_p,CC_p),dim=2).reshape(self.B,self.L,3,3)
            
            fp = convert_pV_to_points(noised_dict, 1, key_in='bb_firstp')
            lp = convert_pV_to_points(noised_dict,1,  key_in='bb_lastp')                                             
            real_mask = noised_dict['real_nodes_mask'].to(self.device)
            score_scales = noised_dict['score_scales'].to(self.device)

            nf_pred = out['0']
            real_nodes_pred = torch.round(nf_pred).clamp(0,1)
            real_nodes_pred_mask = (real_nodes_pred.squeeze().sum(-1)>self.real_threshold).reshape(self.B,self.L)
            
            lr, lr_d = FAPE_loss_real(pred, true, score_scales, real_mask,  d_clamp=10.0, d_clamp_inter=30.0,
                           A=10.0, gamma=1.0, eps=1e-6)

            ln, ln_d = FAPE_loss_null(pred, fp, lp, real_mask, score_scales,  d_clamp=10.0,
                               d_clamp_inter=30.0, A=10.0, gamma=1.0, eps=1e-6)

            ln = ln*self.score_weights['3D_null']
            lr = lr*self.score_weights['3D_real']

            structure_loss = lr + ln

            #score for node feats determining whether node is real or fake
            nf_pred = out['0']

            nf_feat_dim = noised_dict['real_nodes_noise'].shape[-1]
            nf_true = torch.ones(noised_dict['real_nodes_mask'].shape+(nf_feat_dim,) + (1,),
                                 dtype=torch.float,device = self.device)

            nf_real_mask_mult = real_mask.unsqueeze(-1).unsqueeze(-1).to(self.device)
            nf_true = nf_true*nf_real_mask_mult

            nf_pred = nf_pred.reshape(self.B,-1,nf_feat_dim)
            pred_nf_loss = torch.sum(torch.square(nf_true.squeeze()-nf_pred),dim=-1)

            ss_scales = to_cuda(noised_dict['score_scales'])[:,None,None]
            pnfloss = (torch.sum((pred_nf_loss*ss_scales))/(self.L*self.nf_dim))*self.score_weights['nf_real']

            final_loss = structure_loss + pnfloss
            val_loss = {'pnf_loss': pnfloss.detach().cpu(),
                        'structure_loss':structure_loss.detach().cpu(),
                        'structure_null':ln.detach().cpu(),
                        'structure_real':lr.detach().cpu()}

        real_nodes_true_mask = noised_dict['real_nodes_mask']
        
        del nf_real_mask_mult
        del ss_scales
        del real_mask
        del score_scales
        del batched_t
        del structure_loss
        del pnfloss
        del nf_pred
        del nf_true
        del N_C_to_Rot
        del CA_n, NC_n, CC_n
        
        for k,v in out.items():
            del v
        for k,v in feat_dict.items():
            del v
        
        #needs to be rolled to N-terminal = 0 for pymol output, add coord_scale
        return true.detach().cpu(), noise_xyz.detach().cpu(), pred.detach().cpu() , real_nodes_pred_mask.detach().cpu(), real_nodes_true_mask.detach().cpu(), val_loss
    
    def eval_fn(self, valid_loader, eval_dir, epoch=0, input_t=None, max_cycles=10):
        
        train_feats = next(iter(valid_loader))
        
        if input_t is None:
            #visualize_T
            vis_t = np.array([0.01,0.05,0.1,0.2,0.3,0.5,0.8,1.0])
            vis_t = vis_t[None,...].repeat(int(np.ceil(self.B/len(vis_t))),axis=0).flatten()[:self.B]
        elif type(input_t) == float:
            vis_t = np.ones((self.B,))*input_t
        else:
            vis_t = input_t

        noised_dict = self.fnd.forward(train_feats, t_vec=vis_t)
        batched_t = to_cuda( noised_dict['t_vec'] )

        true, noise_xyz, pred, real_nodes_pred_mask, real_nodes_true_mask, val_loss = self.eval_model(noised_dict, batched_t)
        util.pdb_writer.dump_tnp_null(true, noise_xyz, pred, vis_t, e=epoch, 
                      numOut=len(vis_t), real_mask=real_nodes_true_mask, 
                      pred_mask=real_nodes_pred_mask.detach().cpu(), outdir=eval_dir)
        
        log_lossses = defaultdict(list)
        losskeeper = []
        eval_steps = 0
        
        
        for i,train_feats in enumerate(valid_loader):
            noised_dict = self.fnd.forward(train_feats)
            batched_t = to_cuda( noised_dict['t_vec'] )
            true, noise_xyz, pred, real_nodes_pred_mask, real_nodes_true_mask, val_loss = self.eval_model(noised_dict, batched_t)
            eval_steps += 1
            lsum = 0 
            for k,v in val_loss.items():
                log_lossses[k].append(np.array(v))
                lsum+=v.detach().cpu()
            losskeeper.append(lsum)   
            
            del true
            del noise_xyz
            del pred
            del real_nodes_pred_mask
            for k,v in val_loss.items():
                del v
            del batched_t
            if i>max_cycles:
                break
            
            
        # Logging to terminal
        rolling_losses = tree.map_structure(np.mean, log_lossses)
        loss_log = ' '.join([
            f'{k}={v[0]:.4f}'
            for k,v in rolling_losses.items() if 'batch' not in k
        ])

        log_lossses = defaultdict(list)
        print(f'[{eval_steps}]: {loss_log}')
        print('eval_loss',np.mean(losskeeper[-1000:]),len(losskeeper))
        
        for k,v in noised_dict.items():
            del v