In [14]:
import logging
import os
from datetime import datetime

import torch

from gudiff_model import PDBDataSet_GraphCon
from gudiff_model.Graph_UNet import GraphUNet
from data_rigid_diffuser.diffuser import FrameDiffNoise
from se3_transformer.model.fiber import Fiber

In [15]:
# Experiment configuration consumed by `Experiment`.
# Key order is significant only for display/logging; values are the tuned defaults.
conf = {
    # batching / graph construction
    'batch_size':       16,
    'topk':             4,
    'stride':           4,
    'KNN':              30,
    # GraphUNet architecture
    'num_heads':        8,
    'channels':         32,
    'channels_div':     4,
    'nodefeats_0':      32,
    'nodefeats_1':      6,
    'num_layers':       1,
    'num_layers_ca':    2,
    'edge_feature_dim': 1,
    'latent_pool_type': 'avg',
    't_size':           12,
    'max_t':            0.2,
    'mult':             2,
    'zero_lin':         True,
    'use_tdeg1':        True,
    # hardware / optimiser
    'cuda':             True,
    'learning rate':    0.0005,
    'weight_decay':     5e-6,
    'device':           'cuda',
    # training loop / checkpointing
    'num_epoch':        100,
    'log_freq':         1000,
    'ckpt_freq':        10000,
    'early_chkpt':      2,
    # data
    'coord_scale':      10.0,
    'dataset_max':      5000,
    'meta_data_path':   '/mnt/h/datasets/bCov_4H/metadata.csv',
    'sample_mode':      'single_length',
}

# TODO: check use_tdeg1

In [16]:
def optimizer_to(optim, device):
    """Move the tensors held in ``optim.state`` (and any attached grads) to ``device``.

    Each value of ``optim.state`` is handled as follows: a bare tensor is moved
    directly; a dict (the usual per-parameter state, e.g. Adam's ``exp_avg``)
    has every tensor value moved; anything else is left untouched. Tensors are
    updated through ``.data`` so the optimizer keeps referencing the same objects.

    NOTE(review): recent PyTorch versions already cast optimizer state to the
    parameters' device inside ``Optimizer.load_state_dict``; this helper may be
    redundant there but is harmless.

    Args:
        optim: A ``torch.optim.Optimizer`` (anything exposing a ``state`` mapping).
        device: Target accepted by ``Tensor.to`` (e.g. ``'cuda'`` or a ``torch.device``).
    """
    def _relocate(tensor):
        tensor.data = tensor.data.to(device)
        grad = tensor._grad
        if grad is not None:
            grad.data = grad.data.to(device)

    for entry in optim.state.values():
        if isinstance(entry, torch.Tensor):
            # Optimizers are not known to keep bare tensors here; handled just in case.
            _relocate(entry)
        elif isinstance(entry, dict):
            for tensor in filter(lambda value: isinstance(value, torch.Tensor), entry.values()):
                _relocate(tensor)

In [18]:
class Experiment:
    """Training experiment wrapper for the GraphUNet model.

    Holds the configuration, a ``FrameDiffNoise`` diffuser, the KNN graph
    builder, the ``GraphUNet`` model, an Adam optimizer and the checkpoint /
    evaluation output locations.
    """

    def __init__(self,
                 conf=None,
                 ckpt_model=None,
                 cur_step=None,
                 cur_epoch=None,
                 name='gu_null',
                 cast_type=torch.float32,
                 ckpt_opt=None):
        """Initialize the experiment.

        Args:
            conf: Configuration dict. Required despite the ``None`` default.
                Optional keys: ``device`` (default ``'cuda'``), ``KNN_radius``,
                ``ckpt_dir`` and ``eval_dir`` (default ``None``; a ``None``
                directory disables saving).
            ckpt_model: Optional model state dict to restore. A leading
                ``'module.'`` on keys is stripped before loading.
            cur_step: Step to resume from; ``None`` means 0.
            cur_epoch: Epoch to resume from; ``None`` means 0.
            name: Experiment name, used in the output directory paths.
            cast_type: dtype handed to the graph builder.
            ckpt_opt: Optional optimizer state dict to restore.

        Raises:
            ValueError: If ``conf`` is ``None``.
            KeyError: If a required configuration key is missing.
        """
        if conf is None:
            raise ValueError('Experiment requires a configuration dict (conf).')

        # basicConfig is a no-op once the root logger already has handlers.
        logging.basicConfig(filename='test.log', level=logging.INFO)
        self._log = logging.getLogger(__name__)

        self.name = name
        self._conf = conf
        # Target device for the model and geometry constants.
        self.device = conf.get('device', 'cuda')
        self.cast_type = cast_type

        self.coord_scale = conf['coord_scale']
        # NOTE(review): Data_Graph is not imported in the visible notebook cells;
        # it must come from a project module -- TODO import it explicitly.
        self.N_CA_dist = (Data_Graph.N_CA_dist / self.coord_scale).to(self.device)
        self.C_CA_dist = (Data_Graph.C_CA_dist / self.coord_scale).to(self.device)

        # Training-loop settings.
        self.num_epoch = conf['num_epoch']
        self.log_freq = conf['log_freq']
        self.ckpt_freq = conf['ckpt_freq']
        self.early_ckpt = conf['early_chkpt']

        # Dataset settings.
        self.meta_data_path = conf['meta_data_path']
        self.sample_mode = conf['sample_mode']
        self.B = conf['batch_size']
        self.limit = conf['dataset_max']

        # Graph properties; KNN_radius is optional (absent from the default conf).
        self.KNN = conf['KNN']
        self.KNN_radius = conf.get('KNN_radius')
        self.stride = conf['stride']

        # GraphUNet params.
        self.channels_start = conf['channels']

        self._diffuser = FrameDiffNoise()
        self._graphmaker = PDBDataSet_GraphCon.Make_KNN_MP_Graphs(mp_stride=self.stride,
                                                                  coord_div=self.coord_scale,
                                                                  cast_type=self.cast_type,
                                                                  channels_start=self.channels_start,
                                                                  ndf1=conf['nodefeats_1'],
                                                                  ndf0=conf['nodefeats_0'],
                                                                  cuda=conf['cuda'])

        self._model = GraphUNet(fiber_start=Fiber({0: 12, 1: 2}),
                                fiber_out=Fiber({1: 2}),
                                batch_size=self.B,
                                num_layers_ca=conf['num_layers_ca'],
                                k=conf['topk'],
                                stride=conf['stride'],
                                max_degree=3,
                                channels_div=conf['channels_div'],
                                num_heads=conf['num_heads'],
                                num_layers=conf['num_layers'],
                                edge_feature_dim=conf['edge_feature_dim'],
                                latent_pool_type=conf['latent_pool_type'],
                                t_size=conf['t_size'],
                                zero_lin=conf['zero_lin'],
                                use_tdeg1=conf['use_tdeg1'],
                                cuda=conf['cuda']).to(self.device)

        self.num_parameters = sum(p.numel() for p in self._model.parameters())
        self._log.info(f'Number of model parameters {self.num_parameters}')

        if ckpt_model is not None:
            self._model.load_state_dict(self._strip_module_prefix(ckpt_model), strict=True)

        self._optimizer = torch.optim.Adam(self._model.parameters(),
                                           lr=conf['learning rate'],
                                           weight_decay=conf['weight_decay'])
        if ckpt_opt is not None:
            self._optimizer.load_state_dict(ckpt_opt)
            optimizer_to(self._optimizer, self.device)

        self._setup_output_dirs(conf)

        self.trained_epochs = 0 if cur_epoch is None else cur_epoch
        self.trained_steps = 0 if cur_step is None else cur_step

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` with a leading ``'module.'`` removed from keys.

        ``nn.DataParallel`` / DDP store the wrapped model as ``.module``, so saved
        keys carry that prefix. Only the leading prefix is removed, so inner names
        such as ``'submodule.'`` stay intact.
        """
        prefix = 'module.'
        return {(key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()}

    def _setup_output_dirs(self, conf):
        """Resolve timestamped checkpoint/evaluation paths from ``conf`` and log them.

        Sets ``self.ckpt_dir`` (``<ckpt_dir>/<name>/<timestamp>``, created on disk,
        or ``None``) and ``self.eval_dir``
        (``<eval_dir>/<name>_<date>/<timestamp>``, not created, or ``os.devnull``).
        Reads ``self.name`` and ``self._log``.
        """
        now = datetime.now()  # single timestamp so both strings agree
        dt_string = now.strftime("%dD_%mM_%YY_%Hh_%Mm_%Ss")
        dt_string_short = now.strftime("%dD_%mM_%YY")

        self.ckpt_dir = conf.get('ckpt_dir')
        if self.ckpt_dir is not None:
            self.ckpt_dir = os.path.join(self.ckpt_dir, self.name, dt_string)
            os.makedirs(self.ckpt_dir, exist_ok=True)
            self._log.info(f'Checkpoints saved to: {self.ckpt_dir}')
        else:
            self._log.info('Checkpoint not being saved.')

        eval_root = conf.get('eval_dir')
        if eval_root is not None:
            self.eval_dir = os.path.join(eval_root, f'{self.name}_{dt_string_short}', dt_string)
            self._log.info(f'Evaluation saved to: {self.eval_dir}')
        else:
            self.eval_dir = os.devnull
            self._log.info('Evaluation will not be saved.')

    @property
    def diffuser(self):
        """The ``FrameDiffNoise`` instance used by this experiment."""
        return self._diffuser

    @property
    def model(self):
        """The ``GraphUNet`` model."""
        return self._model

    @property
    def conf(self):
        """The configuration dict passed at construction."""
        return self._conf

    def create_dataset(self, fake_valid=True):
        """Build the training and validation DataLoaders.

        Args:
            fake_valid: Use the training loader as a stand-in validation loader.
                A real validation split is not implemented, so the training
                loader is returned for validation regardless of this flag.

        Returns:
            ``(train_dL, valid_dL)``; currently the same DataLoader object.
        """
        # NOTE(review): ``self.fdn`` is never assigned in __init__; assumed to be the
        # FrameDiffNoise instance held in ``self._diffuser`` -- confirm.
        fdn = getattr(self, 'fdn', self._diffuser)
        self.dataset = PDBDataSet_GraphCon.smallPDBDataset(fdn, meta_data_path=self.meta_data_path,
                                                          filter_dict=False, maxlen=self.limit)

        # Honour the configured sampling mode rather than a hard-coded one.
        self.train_sample = PDBDataSet_GraphCon.TrainSampler(self.B, self.dataset,
                                                             sample_mode=self.sample_mode)

        train_dL = torch.utils.data.DataLoader(self.dataset, sampler=self.train_sample,
                                               batch_size=self.B, shuffle=False, collate_fn=None)

        # TODO: build a real validation loader; until then fake_valid has no effect.
        valid_dL = train_dL
        return train_dL, valid_dL