In [1]:
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import logging
import os
from typing import Optional

import numpy as np
import torch
from torch_geometric.nn import radius_graph
from torch_scatter import scatter, segment_coo

from ocpmodels.common.registry import registry
from ocpmodels.common.utils import (
    compute_neighbors,
    conditional_grad,
    get_max_neighbors_mask,
    get_pbc_distances,
    radius_graph_pbc,
    scatter_det
    
)
from ocpmodels.datasets import LmdbDataset
from ocpmodels.modules.scaling.compat import load_scales_compat

from ocpmodels.models.gemnet_oc.initializers import get_initializer
from ocpmodels.models.gemnet_oc.interaction_indices import (
    get_mixed_triplets,
    get_quadruplets,
    get_triplets,
)
from ocpmodels.models.gemnet_oc.layers.atom_update_block import OutputBlock
from ocpmodels.models.gemnet_oc.layers.base_layers import Dense, ResidualLayer
from ocpmodels.models.gemnet_oc.layers.efficient import BasisEmbedding
from ocpmodels.models.gemnet_oc.layers.embedding_block import AtomEmbedding, EdgeEmbedding
from ocpmodels.models.gemnet_oc.layers.force_scaler import ForceScaler
from ocpmodels.models.gemnet_oc.layers.interaction_block import InteractionBlock
from ocpmodels.models.gemnet_oc.layers.radial_basis import RadialBasis
from ocpmodels.models.gemnet_oc.layers.spherical_basis import CircularBasisLayer, SphericalBasisLayer
from ocpmodels.models.gemnet_oc.utils import (
    get_angle,
    get_edge_id,
    get_inner_idx,
    inner_product_clamped,
    mask_neighbors,
    repeat_blocks,
)
from ocpmodels.trainers import ForcesTrainer
from ocpmodels import models
from ocpmodels.models.gemnet_oc.gemnet_oc import GemNetOC

  from .autonotebook import tqdm as notebook_tqdm
ERROR:root:Invalid setup for SCN. Either the e3nn library or Jd.pt is missing.


# Model

In [2]:
@registry.register_model('Transfer_Gem')
class Transfer_Gem(GemNetOC):
    """
    Arguments
    ---------
    num_atoms (int): Unused argument
    bond_feat_dim (int): Unused argument
    num_targets: int
        Number of prediction targets.

    num_spherical: int
        Controls maximum frequency.
    num_radial: int
        Controls maximum frequency.
    num_blocks: int
        Number of building blocks to be stacked.

    emb_size_atom: int
        Embedding size of the atoms.
    emb_size_edge: int
        Embedding size of the edges.
    emb_size_trip_in: int
        (Down-projected) embedding size of the quadruplet edge embeddings
        before the bilinear layer.
    emb_size_trip_out: int
        (Down-projected) embedding size of the quadruplet edge embeddings
        after the bilinear layer.
    emb_size_quad_in: int
        (Down-projected) embedding size of the quadruplet edge embeddings
        before the bilinear layer.
    emb_size_quad_out: int
        (Down-projected) embedding size of the quadruplet edge embeddings
        after the bilinear layer.
    emb_size_aint_in: int
        Embedding size in the atom interaction before the bilinear layer.
    emb_size_aint_out: int
        Embedding size in the atom interaction after the bilinear layer.
    emb_size_rbf: int
        Embedding size of the radial basis transformation.
    emb_size_cbf: int
        Embedding size of the circular basis transformation (one angle).
    emb_size_sbf: int
        Embedding size of the spherical basis transformation (two angles).

    num_before_skip: int
        Number of residual blocks before the first skip connection.
    num_after_skip: int
        Number of residual blocks after the first skip connection.
    num_concat: int
        Number of residual blocks after the concatenation.
    num_atom: int
        Number of residual blocks in the atom embedding blocks.
    num_output_afteratom: int
        Number of residual blocks in the output blocks
        after adding the atom embedding.
    num_atom_emb_layers: int
        Number of residual blocks for transforming atom embeddings.
    num_global_out_layers: int
        Number of final residual blocks before the output.

    regress_forces: bool
        Whether to predict forces. Default: True
    direct_forces: bool
        If True predict forces based on aggregation of interatomic directions.
        If False predict forces based on negative gradient of energy potential.
    use_pbc: bool
        Whether to use periodic boundary conditions.
    scale_backprop_forces: bool
        Whether to scale up the energy and then scales down the forces
        to prevent NaNs and infs in backpropagated forces.

    cutoff: float
        Embedding cutoff for interatomic connections and embeddings in Angstrom.
    cutoff_qint: float
        Quadruplet interaction cutoff in Angstrom.
        Optional. Uses cutoff per default.
    cutoff_aeaint: float
        Edge-to-atom and atom-to-edge interaction cutoff in Angstrom.
        Optional. Uses cutoff per default.
    cutoff_aint: float
        Atom-to-atom interaction cutoff in Angstrom.
        Optional. Uses maximum of all other cutoffs per default.
    max_neighbors: int
        Maximum number of neighbors for interatomic connections and embeddings.
    max_neighbors_qint: int
        Maximum number of quadruplet interactions per embedding.
        Optional. Uses max_neighbors per default.
    max_neighbors_aeaint: int
        Maximum number of edge-to-atom and atom-to-edge interactions per embedding.
        Optional. Uses max_neighbors per default.
    max_neighbors_aint: int
        Maximum number of atom-to-atom interactions per atom.
        Optional. Uses maximum of all other neighbors per default.
    enforce_max_neighbors_strictly: bool
        When subselected edges based on max_neighbors args, arbitrarily
        select amongst degenerate edges to have exactly the correct number.
    rbf: dict
        Name and hyperparameters of the radial basis function.
    rbf_spherical: dict
        Name and hyperparameters of the radial basis function used as part of the
        circular and spherical bases.
        Optional. Uses rbf per default.
    envelope: dict
        Name and hyperparameters of the envelope function.
    cbf: dict
        Name and hyperparameters of the circular basis function.
    sbf: dict
        Name and hyperparameters of the spherical basis function.
    extensive: bool
        Whether the output should be extensive (proportional to the number of atoms)
    forces_coupled: bool
        If True, enforce that |F_st| = |F_ts|. No effect if direct_forces is False.
    output_init: str
        Initialization method for the final dense layer.
    activation: str
        Name of the activation function.
    scale_file: str
        Path to the pytorch file containing the scaling factors.

    quad_interaction: bool
        Whether to use quadruplet interactions (with dihedral angles)
    atom_edge_interaction: bool
        Whether to use atom-to-edge interactions
    edge_atom_interaction: bool
        Whether to use edge-to-atom interactions
    atom_interaction: bool
        Whether to use atom-to-atom interactions

    scale_basis: bool
        Whether to use a scaling layer in the raw basis function for better
        numerical stability.
    qint_tags: list
        Which atom tags to use quadruplet interactions for.
        0=sub-surface bulk, 1=surface, 2=adsorbate atoms.
    latent: bool
        Decide if output the latent space or not.
    
    """
    def __init__(
        self,
        num_atoms: Optional[int],
        bond_feat_dim: int,
        num_targets: int,
        num_spherical=7,
        num_radial=128,
        num_blocks=4,
        emb_size_atom=256,
        emb_size_edge=512,
        emb_size_trip_in=64,
        emb_size_trip_out=64,
        emb_size_quad_in=32,
        emb_size_quad_out=32,
        emb_size_aint_in=64,
        emb_size_aint_out=64,
        emb_size_rbf=16,
        emb_size_cbf=16,
        emb_size_sbf=32,
        num_before_skip=2,
        num_after_skip=2,
        num_concat=1,
        num_atom=3,
        num_output_afteratom=3,
        num_atom_emb_layers = 0,
        num_global_out_layers = 2,
        regress_forces = True,
        direct_forces = False,
        use_pbc = True,
        scale_backprop_forces = False,
        cutoff = 12.0,
        cutoff_qint = 12.0,
        cutoff_aeaint = 12.0,
        cutoff_aint = 12.0,
        max_neighbors = 30,
        max_neighbors_qint =8,
        max_neighbors_aeaint =20,
        max_neighbors_aint = 1000,
        enforce_max_neighbors_strictly = True,
        rbf = {"name": "gaussian"},
        rbf_spherical = None,
        envelope = {"name": "polynomial", "exponent": 5},
        cbf = {"name": "spherical_harmonics"},
        sbf = {"name": "spherical_harmonics"},
        extensive = True,
        forces_coupled = False,
        output_init = "HeOrthogonal",
        activation = "silu",
        quad_interaction = True,
        atom_edge_interaction = True,
        edge_atom_interaction = True,
        atom_interaction = True,
        scale_basis = False,
        qint_tags = [1, 2],
        num_elements = 83,
        otf_graph = True,
        scale_file = None,
        latent: bool = True,
        **kwargs,  # backwards compatibility with deprecated arguments
    ):
        super().__init__(
                        num_atoms,
                        bond_feat_dim,
                        num_targets,
                        num_spherical=7,
                        num_radial=128,
                        num_blocks=4,
                        emb_size_atom=256,
                        emb_size_edge=512,
                        emb_size_trip_in=64,
                        emb_size_trip_out=64,
                        emb_size_quad_in=32,
                        emb_size_quad_out=32,
                        emb_size_aint_in=64,
                        emb_size_aint_out=64,
                        emb_size_rbf=16,
                        emb_size_cbf=16,
                        emb_size_sbf=32,
                        num_before_skip=2,
                        num_after_skip=2,
                        num_concat=1,
                        num_atom=3,
                        num_output_afteratom=3,
                        num_atom_emb_layers = 0,
                        num_global_out_layers = 2,
                        regress_forces = True,
                        direct_forces = False,
                        use_pbc = True,
                        scale_backprop_forces = False,
                        cutoff = 12.0,
                        cutoff_qint = 12.0,
                        cutoff_aeaint = 12.0,
                        cutoff_aint = 12.0,
                        max_neighbors = 30,
                        max_neighbors_qint =8,
                        max_neighbors_aeaint =20,
                        max_neighbors_aint = 1000,
                        enforce_max_neighbors_strictly = True,
                        rbf = {"name": "gaussian"},
                        rbf_spherical = None,
                        envelope = {"name": "polynomial", "exponent": 5},
                        cbf = {"name": "spherical_harmonics"},
                        sbf = {"name": "spherical_harmonics"},
                        extensive = True,
                        forces_coupled = False,
                        output_init = "HeOrthogonal",
                        activation = "silu",
                        quad_interaction = True,
                        atom_edge_interaction = True,
                        edge_atom_interaction = True,
                        atom_interaction = True,
                        scale_basis = False,
                        qint_tags = [1, 2],
                        num_elements = 83,
                        otf_graph = True,
                        scale_file = None,
                        latent = True,
                        )
        
        self.latent=latent
    
    

    
    def forward(self, data):
        pos = data.pos
        batch = data.batch
        atomic_numbers = data.atomic_numbers.long()
        num_atoms = atomic_numbers.shape[0]
        

        if self.regress_forces and not self.direct_forces:
            pos.requires_grad_(True)

        (
            main_graph,
            a2a_graph,
            a2ee2a_graph,
            qint_graph,
            id_swap,
            trip_idx_e2e,
            trip_idx_a2e,
            trip_idx_e2a,
            quad_idx,
        ) = self.get_graphs_and_indices(data)
        # print('checkpoint1')
        _, idx_t = main_graph["edge_index"]

        (
            basis_rad_raw,
            basis_atom_update,
            basis_output,
            bases_qint,
            bases_e2e,
            bases_a2e,
            bases_e2a,
            basis_a2a_rad,
        ) = self.get_bases(
            main_graph=main_graph,
            a2a_graph=a2a_graph,
            a2ee2a_graph=a2ee2a_graph,
            qint_graph=qint_graph,
            trip_idx_e2e=trip_idx_e2e,
            trip_idx_a2e=trip_idx_a2e,
            trip_idx_e2a=trip_idx_e2a,
            quad_idx=quad_idx,
            num_atoms=num_atoms,
        )
        # print('checkpoint2')

        # Embedding block
        h = self.atom_emb(atomic_numbers)
        # (nAtoms, emb_size_atom)
        m = self.edge_emb(h, basis_rad_raw, main_graph["edge_index"])
        # (nEdges, emb_size_edge)

        x_E, x_F = self.out_blocks[0](h, m, basis_output, idx_t)
        # print(x_E.shape)
        xs_E, xs_F = [x_E], [x_F]
        # (nAtoms, num_targets), (nEdges, num_targets)
        for i in range(self.num_blocks):
            # Interaction block
            h, m = self.int_blocks[i](
                h=h,
                m=m,
                bases_qint=bases_qint,
                bases_e2e=bases_e2e,
                bases_a2e=bases_a2e,
                bases_e2a=bases_e2a,
                basis_a2a_rad=basis_a2a_rad,
                basis_atom_update=basis_atom_update,
                edge_index_main=main_graph["edge_index"],
                a2ee2a_graph=a2ee2a_graph,
                a2a_graph=a2a_graph,
                id_swap=id_swap,
                trip_idx_e2e=trip_idx_e2e,
                trip_idx_a2e=trip_idx_a2e,
                trip_idx_e2a=trip_idx_e2a,
                quad_idx=quad_idx,
            )  # (nAtoms, emb_size_atom), (nEdges, emb_size_edge)

            x_E, x_F = self.out_blocks[i + 1](h, m, basis_output, idx_t)
            # (nAtoms, emb_size_atom), (nEdges, emb_size_edge)
            xs_E.append(x_E)
            xs_F.append(x_F)

        # Implementing attention across pretrained blocks
        E_all = torch.stack(xs_E, dim=0)
#-----------------------------------------------------------------------------------------------------------------------------
        # if self.add_positional_embedding:
        #     E_all = self.MHA_positional_embedding(E_all)

        # if self.attn_type == "base":

        #     alpha = torch.bmm(E_all, torch.transpose(E_all, 1, 2))
        #     alpha = alpha / math.sqrt(E_all.shape[-1])
        #     alpha = self.softmax(alpha)

        #     E_t = torch.bmm(alpha, E_all)
        #     E_t = torch.sum(E_t, dim=0)

        # elif self.attn_type == "multi":

        #     q = self.lin_query_MHA(E_all)
        #     k = self.lin_key_MHA(E_all)
        #     v = self.lin_value_MHA(E_all)

        #     E_t, w = self.MHA(q, k, v)
        #     E_t = torch.sum(E_t, dim=0)

        # if self.attn_type != "base":
        #     E_t = self.out_energy(E_t)
#----------------------------------------------need modified----attempt1-----------------------------------------------------------------
        # if self.freeze:
        #     for i in range(self.after_freeze_numblocks):
        #         h, m = self.int_blocks[i](
        #         h=h,
        #         m=m,
        #         bases_qint=bases_qint,
        #         bases_e2e=bases_e2e,
        #         bases_a2e=bases_a2e,
        #         bases_e2a=bases_e2a,
        #         basis_a2a_rad=basis_a2a_rad,
        #         basis_atom_update=basis_atom_update,
        #         edge_index_main=main_graph["edge_index"],
        #         a2ee2a_graph=a2ee2a_graph,
        #         a2a_graph=a2a_graph,
        #         id_swap=id_swap,
        #         trip_idx_e2e=trip_idx_e2e,
        #         trip_idx_a2e=trip_idx_a2e,
        #         trip_idx_e2a=trip_idx_e2a,
        #         quad_idx=quad_idx,
        #     )  # (nAtoms, emb_size_atom), (nEdges, emb_size_edge)

        #     x_E, x_F = self.out_blocks[i + 1](h, m, basis_output, idx_t)
        #     # (nAtoms, emb_size_atom), (nEdges, emb_size_edge)
        #     xs_E.append(x_E)
        #     xs_F.append(x_F)
        if self.latent:
            # print('checkpoint3')
            return E_all
        else:
#-----------------------------------------------------------------------------------------------------------------------------                
            nMolecules = torch.max(batch) + 1
            if self.extensive:
                E_t = scatter_det(
                    E_t, batch, dim=0, dim_size=nMolecules, reduce="add"
                )  # (nMolecules, num_targets)
            else:
                E_t = scatter_det(
                    E_t, batch, dim=0, dim_size=nMolecules, reduce="mean"
                )  # (nMolecules, num_targets)

            if self.regress_forces:
                if self.direct_forces:
                    if self.forces_coupled:  # enforce F_st = F_ts
                        nEdges = idx_t.shape[0]
                        id_undir = repeat_blocks(
                            main_graph["num_neighbors"] // 2,
                            repeats=2,
                            continuous_indexing=True,
                        )
                        F_st = scatter_det(
                            F_st,
                            id_undir,
                            dim=0,
                            dim_size=int(nEdges / 2),
                            reduce="mean",
                        )  # (nEdges/2, num_targets)
                        F_st = F_st[id_undir]  # (nEdges, num_targets)

                    # map forces in edge directions
                    F_st_vec = F_st[:, :, None] * main_graph["vector"][:, None, :]
                    # (nEdges, num_targets, 3)
                    F_t = scatter_det(
                        F_st_vec,
                        idx_t,
                        dim=0,
                        dim_size=num_atoms,
                        reduce="add",
                    )  # (nAtoms, num_targets, 3)
                else:
                    F_t = self.force_scaler.calc_forces_and_update(E_t, pos)

                E_t = E_t.squeeze(1)  # (num_molecules)
                F_t = F_t.squeeze(1)  # (num_atoms, 3)
                return E_t, F_t
            else:
                E_t = E_t.squeeze(1)  # (num_molecules)
                return E_t




# Generate transformer data

In [3]:
myGemnet=Transfer_Gem(0,0,0,latent='True')




In [4]:
checkpoint_path='params/gemnet_oc_base_oc20_oc22.pt'
pretrained_state_dict = torch.load(checkpoint_path)
new_model_state_dict = myGemnet.state_dict()
filtered_pretrained_state_dict = {k: v for k, v in pretrained_state_dict.items() if k in new_model_state_dict}
new_model_state_dict.update(filtered_pretrained_state_dict)
myGemnet.load_state_dict(new_model_state_dict)

<All keys matched successfully>

In [5]:
import lmdb
import pickle
def generate_lmdb(data, pathname: str):
    """
    atoms_list:: Can be either a list of atoms objects or list of Data objects
    """
    pathname = pathname + '.lmdb' if '.lmdb' not in pathname else pathname
    db = lmdb.open(
        pathname,
        map_size=1099511627 * 3,
        subdir=False,
        meminit=False,
        map_async=True,
    ) 
    txn = db.begin(write=True)        
    length=txn.stat()['entries']        
    txn.put(f"{length}".encode('ascii'), pickle.dumps(data, protocol=0))
    txn.commit()
    db.sync()
    db.close() 

In [6]:
import lmdb
import pickle
def lmdb_add_info(data, pathname: str):
    """
    atoms_list:: Can be either a list of atoms objects or list of Data objects
    """
    pathname = pathname + '.lmdb' if '.lmdb' not in pathname else pathname
    db = lmdb.open(
        pathname,
        map_size=1099511626 * 3,
        subdir=False,
        meminit=False,
        map_async=True,
    ) 
    del data.edge_index
    del data.cell_offsets
    txn = db.begin(write=True)        
    length=txn.stat()['entries']        
    txn.put(f"{length}".encode('ascii'), pickle.dumps(data, protocol=0))
    txn.commit()
    db.sync()
    db.close() 

In [7]:
DEVICE='cuda'

In [8]:
def out_fn(dataloader,model,pathname):

    model.eval()    
    with torch.no_grad():
        for data in dataloader:            
            data=data.to(DEVICE)  
            model=model.to(DEVICE) 
            output=model(data)
            data=data.cpu()
            data.latent=output.cpu()
            del data.edge_index
            del data.cell_offsets
            lmdb_add_info(data,pathname)

In [9]:
dataset=LmdbDataset({"src":"Data/eoh.lmdb"})
dataset[0]

Data(edge_index=[2, 2790], pos=[58, 3], cell=[1, 3, 3], atomic_numbers=[58], natoms=58, fixed=[58], tags=[58], nads=1, y_relaxed=-1.8839389085769653, pos_relaxed=[58, 3], sid=82967, id='0_85', oc22=1, cell_offsets=[2790, 3])

In [10]:
import torch_geometric.data as geom_data
node_data_loader = geom_data.DataLoader(dataset, batch_size=16)
sv_path='Data/eoh_c.lmdb'




In [11]:
out_fn(node_data_loader,myGemnet,sv_path)

MapFullError: mdb_put: MDB_MAP_FULL: Environment mapsize limit reached

# Useless now

In [None]:
task = {
  "dataset": "oc22_lmdb",
  "description": "Regressing to energies and forces for DFT trajectories from OCP",
  "type": "regression",
  "metric": "mae",
  "labels":"potential energy",
}

In [None]:
model={'name':'gemne',
  'num_spherical': 7,
  'num_radial': 128,
  'num_blocks': 4,
  'emb_size_atom': 256,
  'emb_size_edge': 512,
  'emb_size_trip_in': 64,
  'emb_size_trip_out': 64,
  'mb_size_quad_in': 32,
  'emb_size_quad_out': 32,
  'emb_size_aint_in': 64,
  'emb_size_aint_out': 64,
  'emb_size_rbf': 16,
  'emb_size_cbf': 16,
  'emb_size_sbf': 32,
  'num_before_skip': 2,
  'num_after_skip': 2,
  'num_concat': 1,
  'num_atom': 3,
  'num_output_afteratom': 3,
  'cutoff': 12.0,
  'cutoff_qint': 12.0,
  'cutoff_aeaint': 12.0,
  'cutoff_aint': 12.0,
  'max_neighbors': 30,
  'max_neighbors_qint': 8,
  'max_neighbors_aeaint': 20,
  'max_neighbors_aint': 1000,
  'rbf':
      {'name': 'gaussian'},
  'envelope':
      {'name': 'polynomial',
    'exponent': 5},
  'cbf':
      {'name': 'spherical_harmonics'},
  'sbf':
      {'name': 'legendre_outer'},
  'extensive': True,
  'output_init': 'HeOrthogonal',
  'activation': 'silu',
  'regress_forces': True,
  'direct_forces': True,
  'forces_coupled': False,
  'otf_graph': True,
  'quad_interaction': True,
  'atom_edge_interaction': True,
  'edge_atom_interaction': True,
  'atom_interaction': True,
  'num_atom_emb_layers': 2,
  'num_global_out_layers': 2,
  'qint_tags': [1, 2]
  #'latent':True    
}


In [None]:
optimizer = {
  'batch_size': 16,
  'eval_batch_size': 16,
  'load_balancing': 'atoms',
  'eval_every': 5000,
  'num_workers': 2,
  'lr_initial': 5.e-4,
  'optimizer': 'AdamW',
  'optimizer_params': {"amsgrad": True},
  'scheduler': 'ReduceLROnPlateau',
  'mode': min,
  'factor': 0.8,
  'patience': 3,
  'max_epochs': 80,
  'ema_decay': 0.999,
  'clip_grad_norm': 10,
  'weight_decay': 0,  # 2e-6 (TF weight decay) / 1e-4 (lr) = 2e-2
  'loss_energy': 'mae',
  'loss_force': 'atomwisel2',
  'force_coefficient': 1,
  'energy_coefficient': 1,
}

In [None]:
dataset = [
  {'src': 'Data/bulk_val.lmdb'},
  {'src': 'Data/bulk_val.lmdb'} # val set (optional)
]

In [None]:
trainer = ForcesTrainer(
    task=task,
    model=model, # copied for later use, not necessary in practice.
    dataset=dataset,
    optimizer=optimizer,
    identifier="S2EF-example",
    run_dir="./", # directory to save results if is_debug=False. Prediction files are saved here so be careful not to override!
    is_debug=False, # if True, do not save checkpoint, logs, or results
    logger='tensorboard',
    print_every=5,
    seed=0, # random seed to use
     # logger of choice (tensorboard and wandb supported)
    local_rank=0,
    amp=True, # use PyTorch Automatic Mixed Precision (faster training and less memory usage),
)

In [None]:
from ocpmodels.trainers import ForcesTrainer
from ocpmodels.datasets import TrajectoryLmdbDataset
from ocpmodels import models
from ocpmodels.common import logger
from ocpmodels.common.utils import setup_logging
setup_logging()

import numpy as np
import copy
import os


In [None]:
class Rectangle:
    def __init__(self, length, width):
        self.length = length
        self.width = width

    def area(self):
        return self.length * self.width

    def perimeter(self):
        return 2 * self.length + 2 * self.width

# Here we declare that the Square class inherits from the Rectangle class
class Square(Rectangle):
    def __init__(self, length,width,height):        
        super().__init__(length,width)
        self.height=height
    def vol(self,):
        return self.height*2+self.length

In [None]:
ffa=Square(7,6,height=2)
ffa.vol()

In [None]:
train_src='Data/bulk_train.lmdb'
val_src='Data/bulk_train.lmdb'