# Train Model #

To train the CS prediction model, you need to provide a configuration file in YAML format. This file defines the dataset, the model to be used, and various hyperparameters. It also specifies the complete training setup, including the optimizer, learning rate, scheduler, loss function, metrics, and more.

Make sure to update the config file with the recommended settings from the previous step, where the training dataset was created.

Because this configuration file requires many details, this tutorial uses a predefined one that already includes all the necessary information, including the dataset configuration.

In [None]:
def interactive_train():
    """Display widgets that launch ``geqtrain-train`` from the notebook.

    Shows:
      * a dropdown listing every file found (recursively) under ``config``;
      * a device dropdown with ``cpu`` plus one ``cuda:<i>`` entry per GPU
        reported by ``torch.cuda.device_count()``;
      * a "Run Training" button and an output area.

    Clicking the button runs ``geqtrain-train <config> -d <device>`` as a
    subprocess and streams its combined stdout/stderr into the output area
    line by line. A non-zero exit status is reported at the end.
    """
    import os
    import subprocess

    import ipywidgets as widgets
    import torch
    from IPython.display import display

    # Every file below ./config. Sorted so the dropdown order is stable:
    # os.walk makes no ordering guarantee.
    options = sorted(
        os.path.join(root, f)
        for root, _, files in os.walk('config', topdown=True)
        for f in files
    )
    config_dropdown_w = widgets.Dropdown(
        options=options,
        description='Training config file in yaml format',
        disabled=False,
    )

    # 'cpu' is always offered; add one 'cuda:<i>' entry per visible GPU.
    device_options = ['cpu']
    if torch.cuda.is_available():
        device_options.extend(f'cuda:{i}' for i in range(torch.cuda.device_count()))

    device_dropdown_w = widgets.Dropdown(
        options=device_options,
        value='cpu',
        description='Device',
        disabled=False,
    )

    run_button = widgets.Button(description="Run Training")
    output_area = widgets.Output()

    def run_script(button):
        """Run ``geqtrain-train`` and stream its output into ``output_area``."""
        script_name = "geqtrain-train"
        output_area.clear_output()

        config = config_dropdown_w.value
        if config is None:
            # Empty dropdown: no file was found under ./config.
            with output_area:
                print("No training config file found under 'config/'.")
            return

        try:
            with subprocess.Popen(
                [script_name, config, "-d", device_dropdown_w.value],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,  # merge stderr so errors are shown too
                text=True,                 # decode output to str
                bufsize=1,                 # line-buffered: show lines as they arrive
            ) as proc:
                for line in proc.stdout:
                    with output_area:
                        print(line, end='')
        except Exception as e:
            # UI boundary: e.g. the script is not on PATH.
            with output_area:
                print(f"An error occurred: {e}")
            return

        # Leaving the Popen context waits for the process, so returncode is set.
        if proc.returncode != 0:
            with output_area:
                print(f"{script_name} exited with status {proc.returncode}")

    run_button.on_click(run_script)

    display(config_dropdown_w, device_dropdown_w, run_button, output_area)

interactive_train()

# Run Inference #

In [3]:
def interactive_run_backmapping():
    """Display widgets that run ``geqtrain-evaluate`` (inference) from the notebook.

    Widgets shown:
      * model source: a training folder (a leaf directory under ``results``,
        ignoring ``processed_datasets`` subdirectories) or a deployed model path;
      * test dataset source: a test config (files under ``config``, skipping
        ``training`` subdirectories) or a topology/trajectory/selection triple;
      * device (``cpu`` plus one ``cuda:<i>`` per visible GPU) and batch size.

    Clicking "Run Inference" with a test config runs::

        geqtrain-evaluate -tc <config> -d <device> -bs <batch size> [-td <folder> | -m <model>]

    where only the currently selected model source is forwarded. The combined
    stdout/stderr is streamed into the output area.
    """
    import os
    import subprocess

    import ipywidgets as widgets
    import torch
    from IPython.display import display

    from csnet.training.dataset import get_structure

    # Radio-button labels, reused by the callbacks below.
    model_from_training = 'Use model from training'
    model_deployed = 'Use deployed model'
    dataset_from_yaml = 'From YAML'
    dataset_from_topology = 'From topology'

    model_selection_w = widgets.Box(
        [
            widgets.Label(value='Select model:'),
            widgets.RadioButtons(
                options=[model_from_training, model_deployed],
                value=model_from_training,
                layout={'width': 'max-content'},
            ),
        ]
    )

    # Training folders: leaf directories under ./results. 'processed_datasets'
    # subdirectories are pruned first, so a folder containing only that one
    # still counts as a leaf. '-' means "no training folder selected".
    exclude = {'processed_datasets'}
    train_dirs = []
    for root, dirs, _ in os.walk('results', topdown=True):
        dirs[:] = [d for d in dirs if d not in exclude]  # in-place prune (topdown)
        if not dirs:
            train_dirs.append(root)
    traindir_options = ['-'] + sorted(train_dirs)

    traindir_dropdown_w = widgets.Box(
        [
            widgets.Label(value='Training folder:'),
            widgets.Dropdown(
                options=traindir_options,
                value='-',
            ),
        ],
        layout=widgets.Layout(display='block'),
    )

    # Path to a deployed model; hidden until 'Use deployed model' is chosen.
    model_w = widgets.Box(
        [
            widgets.Label(value='Deployed model:'),
            widgets.Text(
                value='',
                placeholder='path/to/model',
            ),
        ],
        layout=widgets.Layout(display='none'),
    )

    def on_model_selection_change(change):
        """Show the input for the chosen model source and hide the other one."""
        from_training = change['new'] == model_from_training
        traindir_dropdown_w.layout.display = 'block' if from_training else 'none'
        model_w.layout.display = 'none' if from_training else 'block'

    model_selection_w.children[1].observe(on_model_selection_change, names='value')

    input_dataset_w = widgets.Box(
        [
            widgets.Label(value='Test dataset:'),
            widgets.RadioButtons(
                options=[dataset_from_yaml, dataset_from_topology],
                value=dataset_from_yaml,
                layout={'width': 'max-content'},  # keep long labels on one line
            ),
        ],
    )

    # Topology-based inputs, hidden until 'From topology' is chosen.
    topology_w = widgets.Text(
        value='',
        placeholder='E.g. pdb, gro, tpr...',
        description='Topology input file',
        disabled=False,
        layout=widgets.Layout(display='none'),
    )

    trajectory_w = widgets.Text(
        value='',
        placeholder='E.g. trr, xtc...',
        description='Trajectory input file',
        disabled=False,
        layout=widgets.Layout(display='none'),
    )

    selection_w = widgets.Text(
        value='all',
        placeholder='E.g. all, protein, resname POPC, ...',
        description='Atoms selection',
        disabled=False,
        layout=widgets.Layout(display='none'),
    )

    # Test configs: every file under ./config, skipping 'training' subdirectories.
    exclude = {'training'}
    config_options = []
    for root, dirs, files in os.walk('config', topdown=True):
        dirs[:] = [d for d in dirs if d not in exclude]
        config_options.extend(os.path.join(root, f) for f in files)
    config_options.sort()  # os.walk order is not guaranteed

    config_w = widgets.Box(
        [
            widgets.Label(value='Config:'),
            widgets.Dropdown(
                options=config_options,
                layout={'width': 'max-content'},  # keep long paths on one line
            ),
        ],
        layout=widgets.Layout(display='block'),
    )

    # 'cpu' is always offered; add one 'cuda:<i>' entry per visible GPU.
    device_options = ['cpu']
    if torch.cuda.is_available():
        device_options.extend(f'cuda:{i}' for i in range(torch.cuda.device_count()))

    device_dropdown_w = widgets.Dropdown(
        options=device_options,
        value='cpu',
        description='Device',
        disabled=False,
    )

    # Passed verbatim to geqtrain-evaluate as '-bs'.
    batch_size_w = widgets.Text(
        value='1',
        placeholder='E.g. 16',
        description='Batch Size',
        disabled=False,
    )

    # NOTE(review): created but neither displayed nor passed to
    # geqtrain-evaluate; wire it up or remove it.
    batch_max_atoms_w = widgets.Text(
        value='10000',
        placeholder='E.g. 10000',
        description='Max atoms per chunk',
        disabled=False
    )

    run_button = widgets.Button(description="Run Inference")
    output_area = widgets.Output()

    def on_input_dataset_change(change):
        """Show either the topology inputs or the config dropdown, per the choice."""
        from_topology = change['new'] == dataset_from_topology
        for w in (topology_w, trajectory_w, selection_w):
            w.layout.display = 'block' if from_topology else 'none'
        config_w.layout.display = 'none' if from_topology else 'block'

    input_dataset_w.children[1].observe(on_input_dataset_change, names='value')

    def run_script(button):
        """Run ``geqtrain-evaluate`` and stream its output into ``output_area``."""
        script_name = "geqtrain-evaluate"
        output_area.clear_output()

        if input_dataset_w.children[1].value == dataset_from_topology:
            # The command built below only accepts a test config ('-tc'); this
            # notebook cannot forward a dataset built from a topology. Load and
            # show it so the inputs can be checked, then stop. Otherwise the
            # command would be launched without a config.
            with output_area:
                try:
                    dataset, _ = get_structure(
                        topology=topology_w.value,
                        trajectories=[trajectory_w.value] if trajectory_w.value else [],
                        selection=selection_w.value or None,
                    )
                except Exception as e:
                    print(f"An error occurred: {e}")
                    return
                print(dataset)
                print(
                    "Running inference from a topology is not supported yet; "
                    "select 'From YAML' and choose a test config."
                )
            return

        config = config_w.children[1].value
        if config is None:
            # Empty dropdown: no eligible file was found under ./config.
            with output_area:
                print("No test config file found under 'config/'.")
            return

        args = [
            script_name,
            "-tc", config,
            "-d", device_dropdown_w.value,
            "-bs", batch_size_w.value,
        ]
        # Forward only the model source that is currently selected, so a stale
        # value left in the hidden widget is not passed as well.
        if model_selection_w.children[1].value == model_from_training:
            traindir = traindir_dropdown_w.children[1].value
            if traindir != '-':
                args.extend(["-td", traindir])
        else:
            model_path = model_w.children[1].value
            if model_path:
                args.extend(["-m", model_path])

        try:
            with subprocess.Popen(
                args=args,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,  # merge stderr so errors are shown too
                text=True,                 # decode output to str
                bufsize=1,                 # line-buffered: show lines as they arrive
            ) as proc:
                for line in proc.stdout:
                    with output_area:
                        print(line, end='')
        except Exception as e:
            # UI boundary: e.g. the script is not on PATH.
            with output_area:
                print(f"An error occurred: {e}")
            return

        # Leaving the Popen context waits for the process, so returncode is set.
        if proc.returncode != 0:
            with output_area:
                print(f"{script_name} exited with status {proc.returncode}")

    run_button.on_click(run_script)

    display(
        model_selection_w,
        traindir_dropdown_w,
        model_w,
        input_dataset_w,
        topology_w,
        trajectory_w,
        selection_w,
        config_w,
        device_dropdown_w,
        batch_size_w,
        run_button,
        output_area,
    )

interactive_run_backmapping()

Box(children=(Label(value='Select model:'), RadioButtons(layout=Layout(width='max-content'), options=('Use mod…

Box(children=(Label(value='Training folder:'), Dropdown(options=('-', 'results/SHIFTX2/production', 'results/S…

Box(children=(Label(value='Deployed model:'), Text(value='', placeholder='path/to/model')), layout=Layout(disp…

Box(children=(Label(value='Test dataset:'), RadioButtons(layout=Layout(width='max-content'), options=('From YA…

Text(value='', description='Topology input file', layout=Layout(display='none'), placeholder='E.g. pdb, gro, t…

Text(value='', description='Trajectory input file', layout=Layout(display='none'), placeholder='E.g. trr, xtc.…

Text(value='all', description='Atoms selection', layout=Layout(display='none'), placeholder='E.g. all, protein…

Box(children=(Label(value='Config:'), Dropdown(layout=Layout(width='max-content'), options=('config/testing/df…

Dropdown(description='Device', options=('cpu', 'cuda:0', 'cuda:1', 'cuda:2', 'cuda:3'), value='cpu')

Text(value='1', description='Batch Size', placeholder='E.g. 16')

Button(description='Run Inference', style=ButtonStyle())

Output()