In [None]:
import yaml
import h5py
import numpy as np
import pickle as pkl

import matplotlib.pyplot as plt

from rdkit import Chem

# Molecular structure elucidation from NMR spectra: Part 2

Now that we have built and pretrained our substructure-to-structure model, in this second part of the exercise we will integrate it into the full multi-task model used for spectra-to-structure predictions. Figure 3 below summarizes the architectural details of the multi-task model (highlighted with the green box) that we will build and train to predict structure from NMR spectra. Note that we will slot our pretrained substructure-to-structure model from the previous part (highlighted with the black box) directly into the "Encoder-Decoder Transformer" part of the spectra-to-structure model.

<div style="text-align: center;">
    <img src="figures/architecture2.png" width="80%"/>
    <figcaption><i>Figure 3: Diagram summarizing the architecture of our multi-task model used for spectra to substructure and spectra to structure predictions</i></figcaption>
</div>

## Training our multi-task model

Just like in part 1, since the training of our multi-task model will take some time, let us first get that started before we dive into the other details. Training our model will be straightforward since we will use the full implementation of the model provided by the NMR2Struct software package. To start training the model:

1. Log in to the NYU Shanghai cluster and navigate to where you would like to perform the training of the model.

```
ssh <user>@hpclogin.shanghai.nyu.edu
```

2. Navigate to the run directory for exercise 2 then submit the training run to the job scheduler:

```
cd nmr2mol/exercise_part2/
sbatch -J multitask-train submit.sh
```

3. You can monitor training progress by downloading the generated tfevents file to your local machine, launching TensorBoard as shown below, and opening the indicated URL (e.g., "http://localhost:6006/") in your web browser. Note that the losses used to optimize the model parameters, which are the quantities visualized in TensorBoard, are standard cross-entropy losses.

```
rsync -avz "<user>@hpclogin.shanghai.nyu.edu:~/nmr2mol/exercise_part2/checkpoints/events.out.*" model2/
conda activate NMR_env
tensorboard --logdir=model2/
```
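
The cross-entropy losses mentioned in step 3 can be illustrated with a few lines of plain Python. This is only a sketch: it assumes a categorical cross-entropy for a single predicted token and a binary cross-entropy averaged over substructure labels, and the probabilities and helper names are made up for illustration rather than taken from NMR2Struct.

```python
import math

def categorical_cross_entropy(probs, target_index):
    """Negative log-probability assigned to the correct token."""
    return -math.log(probs[target_index])

def binary_cross_entropy(probs, labels):
    """Mean binary cross-entropy over independent yes/no labels."""
    terms = [-(y * math.log(p) + (1 - y) * math.log(1 - p))
             for p, y in zip(probs, labels)]
    return sum(terms) / len(terms)

# Toy prediction over a 3-token vocabulary where token 0 is correct
token_loss = categorical_cross_entropy([0.7, 0.2, 0.1], target_index=0)

# Toy prediction for 3 substructure labels (present, absent, present)
label_loss = binary_cross_entropy([0.9, 0.2, 0.6], [1, 0, 1])

print(f"token loss: {token_loss:.4f}")
print(f"label loss: {label_loss:.4f}")
```

In both cases the loss approaches zero as the probability assigned to the correct answer approaches one.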

Now you just need to wait for your job to start; training will then take on the order of an hour. In the meantime, let us understand the details of the model we are training by building our own version of it from scratch using PyTorch.


## Checking the dataset

We will use the same dataset of 146595 molecules used in part 1 but now we will additionally make use of their corresponding $^1$H and $^{13}$C NMR spectra. The following files specify our spectral data:

1. <i>CNMR_shifts.p</i>: The chemical shifts that we use for the centers of the $^{13}$C NMR spectra bins (80 bins from 3 to 232 ppm).
2. <i>HNMR_shifts.p</i>: The chemical shifts used to discretize the $^{1}$H NMR spectrum (28000 points from -2 to 12 ppm).
3. <i>spectra.h5</i>: Contains arrays of shape 28080 for each molecule where the first 28000 values are the normalized intensities of the $^1$H NMR spectrum and the last 80 values are the binned $^{13}$C NMR peaks.

Note that we opt for highly resolved H NMR inputs but coarsely resolved C NMR inputs because the former are richer in information content. More specifically, the coupling between the nuclear spins of neighboring H atoms can split the NMR peaks of those H atoms. Hence, H NMR peak splittings carry valuable information about the neighboring environment around a given H atom. On the other hand, $^{13}$C NMR spectra are typically "proton decoupled", which means that the measurement is made in such a way that the coupling of $^{13}$C nuclear spins with neighboring $^{1}$H nuclear spins is averaged out, so usually no C NMR peak splittings are observed. Also, since $^{13}$C is a relatively rare isotope of C, there will likely not be much $^{13}$C-$^{13}$C coupling either.
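
To put rough numbers on this resolution difference, the short sketch below estimates the grid spacing of both inputs. It assumes uniformly spaced grids spanning the ranges quoted above, which is an assumption made for illustration; the actual grids are the ones stored in <i>HNMR_shifts.p</i> and <i>CNMR_shifts.p</i>.

```python
import numpy as np

# Assumed uniform grids (illustrative only; the real grids live in the .p files)
hnmr_grid = np.linspace(-2.0, 12.0, 28000)  # 1H chemical shift points
cnmr_grid = np.linspace(3.0, 232.0, 80)     # 13C bin centers

h_step = hnmr_grid[1] - hnmr_grid[0]
c_step = cnmr_grid[1] - cnmr_grid[0]

print(f"1H spacing:  {h_step:.5f} ppm per point")
print(f"13C spacing: {c_step:.2f} ppm per bin")
print(f"13C grid is roughly {c_step / h_step:.0f}x coarser")
```

Under this assumption the H grid is fine enough to resolve peak splittings, while each C bin only records whether a peak falls within a window of a few ppm.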

Below we load the data for one of the molecules in our dataset. Plot both the H and C NMR spectra for that molecule.

In [None]:
i = 10000

ismi = np.load('data/smiles.npy')[i].decode('UTF-8')

with open('data/substructures_957.p', 'rb') as fh:
    substruct_list = pkl.load(fh)
with h5py.File('data/substructures.h5', 'r') as hf:
    isubs = hf['substructure_labels'][i]

with open('data/CNMR_shifts.p', 'rb') as fh:
    cnmr_shifts = pkl.load(fh)
with open('data/HNMR_shifts.p', 'rb') as fh:
    hnmr_shifts = pkl.load(fh)
with h5py.File('data/spectra.h5', 'r') as hf:
    ispectra = hf['spectra'][i]
# First 28000 values: 1H intensities; last 80 values: binned 13C peaks
ihnmr = ispectra[:28000]
icnmr = ispectra[28000:]

# Plotting NMR spectra for molecule i
### MODIFY BELOW ###
fig, axs = plt.subplots(2, 1, figsize=(10, 5))
axs[0].plot(hnmr_shifts, ihnmr)
axs[0].set_xlim(-2, 12)
axs[0].set_xlabel('H NMR shift (ppm)')
axs[0].set_ylabel('Normalized Intensity')
axs[1].plot(cnmr_shifts, icnmr, 'o')
axs[1].set_xlim(np.min(cnmr_shifts), np.max(cnmr_shifts))
axs[1].set_xlabel('C NMR shift (ppm)')
axs[1].set_ylabel('Bin Occupancy')
plt.subplots_adjust(hspace=0.3)
plt.show()
### MODIFY ABOVE ###

The following block uses the RDKit software package to visualize the molecule specified by its SMILES string, as we did in part 1.

In [None]:
# The Draw submodule must be imported explicitly before it can be used
from rdkit.Chem import Draw

print(ismi)
imol = Chem.MolFromSmiles(ismi)
Draw.MolToImage(imol)

## Building our spectra to structure model

Since the multi-task model will use the same Transformer and PositionalEncoding classes we developed in part 1, we will simply load their implementations from the NMR2Struct package here.

In [None]:
import math
import warnings
import torch
import torch.nn.functional as f
from typing import Tuple, Callable, Optional, Any
from torch import nn, Tensor

from nmr.models import TransformerModel
from nmr.networks.encoder import PositionalEncoding

Let us start with the part of the model that processes and encodes the input NMR spectra. We will process the H NMR spectrum, specified as an array of normalized NMR intensities at 28000 chemical shift values from -2 to 12 ppm, with 1D convolutional layers interspersed with pooling layers. The binned C NMR spectrum, on the other hand, is a much smaller input array, so we will use a simple embedding scheme like the one we have already used for other components.

Convolutional layers are particularly well suited to image-like data with many pixels, where the relevant information is localized but can appear anywhere in the image (e.g., identifying whether a large image contains a dog). The H NMR spectra we use here are effectively sparse 1D images: motifs such as peak splittings are local in character but can be shifted along the ppm axis depending on the local environment around a specific H atom.
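
This translational property can be checked with a small NumPy sketch. The toy spectrum, the three-point kernel, and the helper name `valid_cross_correlation` are all made up for illustration; the kernels in the actual model are learned during training.

```python
import numpy as np

def valid_cross_correlation(signal, kernel):
    """Slide the kernel over the signal without padding (stride 1)."""
    return np.correlate(signal, kernel, mode="valid")

kernel = np.array([1.0, 2.0, 1.0])  # toy filter
motif = np.array([1.0, 0.0, 1.0])   # toy "doublet" pattern

# The same motif placed at two different positions along the axis
spectrum_a = np.zeros(50)
spectrum_a[10:13] = motif
spectrum_b = np.zeros(50)
spectrum_b[30:33] = motif

resp_a = valid_cross_correlation(spectrum_a, kernel)
resp_b = valid_cross_correlation(spectrum_b, kernel)

# Shifting the input by 20 points shifts the filter response by 20 points
shift = 20
print(np.allclose(resp_a[:-shift], resp_b[shift:]))
```

The filter response to the motif is identical in both spectra apart from its position, which is why a single set of kernels can recognize a splitting pattern wherever it occurs along the ppm axis.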

We will use the function below without modification to first process our raw spectra inputs (both the H and C NMR spectra).
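
The padding-mask logic inside that function can be previewed on a toy input. The sketch below uses NumPy rather than PyTorch, a 6-bin C NMR vector, and a made-up H NMR sequence length of 4; these sizes are assumptions chosen only to keep the output readable.

```python
import numpy as np

# Toy binary C NMR vector: bins 2, 4 and 5 (1-indexed) are occupied
cnmr = np.array([[0, 1, 0, 1, 1, 0]])

# Sorting in descending order moves occupied entries to the front,
# so trailing zeros mark the positions that will hold padding
sorted_desc = -np.sort(-cnmr, axis=-1)
cnmr_pad_mask = sorted_desc == 0

# H NMR positions are never padded, so their mask is all False
toy_h_seq_len = 4  # hypothetical value for this sketch
hnmr_pad_mask = np.zeros((cnmr.shape[0], toy_h_seq_len), dtype=bool)

src_key_pad_mask = np.concatenate([hnmr_pad_mask, cnmr_pad_mask], axis=-1)
print(src_key_pad_mask.astype(int))
# [[0 0 0 0 0 0 0 1 1 1]]
```

The number of unmasked C NMR positions equals the number of occupied bins, which matches a representation in which occupied bin indices come first and zeros fill the remainder.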

In [None]:
def src_fwd_fxn_conv_embedding(src: Tensor,
                               d_model: int,
                               src_embed: nn.Module,
                               src_pad_token: int,
                               pos_encoder: nn.Module) -> Tuple[Tensor, Optional[Tensor]]:
    """Forward processing for the source tensor where the input is an unprocessed spectrum
    This forward function is only to be used with the convolutional source embedding. It has hard-coded values
    to allow for all shapes to line up with the current spectrum representation
    Args:
        src: The unembedded source tensor, in this case representing a spectrum, (batch_size, seq_len)
        d_model: The dimensionality of the model
        src_embed: The source embedding layer
        src_pad_token: The source padding token index. In the case of this forward function, it is 
            hard-coded to 0
        pos_encoder: The positional encoder layer
    """
    assert(src_embed is not None)
    #Only construct cnmr padding mask if using cnmr information
    if src_embed.use_cnmr:
        cnmr_start = src_embed.n_spectral_features
        cnmr_end = cnmr_start + src_embed.n_Cfeatures
        cnmr = src[:, cnmr_start:cnmr_end]
        assert(cnmr.shape[-1] == src_embed.n_Cfeatures)
        sorted_cnmr = torch.sort(cnmr, dim = -1, descending=True).values
        cnmr_key_pad_mask = (sorted_cnmr == 0).bool().to(src.device)
    else:
        cnmr_key_pad_mask = torch.tensor([]).bool().to(src.device)

    if src_embed.use_hnmr:
        src_key_pad_mask = torch.zeros(src.shape[0], 
                                    src_embed.h_spectrum_final_seq_len).bool().to(src.device)
    else:
        src_key_pad_mask = torch.tensor([]).bool().to(src.device)
    
    #Concatenate the padding masks together
    src_key_pad_mask = torch.cat((src_key_pad_mask, cnmr_key_pad_mask), dim = -1)
    
    src_embedded = src_embed(src) * math.sqrt(d_model)
    if src_embed.add_pos_encoder:
        src_embedded = pos_encoder(src_embedded, None)
    return src_embedded, src_key_pad_mask

Now we will construct the part of the model that will perform the embeddings of both the H and C NMR spectra. You are asked in particular to complete the initialization for the ConvolutionalEmbedding class below as well as the _embed_cnmr, _embed_spectra, and forward methods.

For the initialization, please see the specific comments for which built-in PyTorch nn layer should be initialized.

In the _embed_spectra method, you will want to chain together the appropriate layers that were initialized upon the creation of a ConvolutionalEmbedding object, in the order described on the right-hand side of Figure 3 for embedding the input H NMR spectra.
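
As a sanity check on shapes, the sequence length after each stage can be worked out by hand. The sketch below assumes the default kernel and pool sizes of ConvolutionalEmbedding, with stride 1 and no padding for the convolutions and stride equal to the pool size for the pooling layers. The helper `out_len` is a hypothetical name used only here.

```python
def out_len(length: int, kernel: int, stride: int = 1) -> int:
    """Output length of a 1D conv or pool with no padding and dilation 1."""
    return (length - kernel) // stride + 1

length = 28000                         # raw 1H spectrum points
length = out_len(length, kernel=5)     # conv1 -> 27996
length = out_len(length, 12, 12)       # pool1 -> 2333
length = out_len(length, kernel=9)     # conv2 -> 2325
length = out_len(length, 20, 20)       # pool2 -> 116

print(length)  # 116
```

With these defaults, the 28000-point spectrum is reduced to a sequence of 116 positions, each of which carries one feature vector into the transformer.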

In _embed_cnmr, to match the architecture of the model we trained, you will want to transform the inputted C NMR spectrum from its binary binned representation to a zero-padded array that lists in order of index which C NMR bins were occupied (Hint: this is almost the same thing you did for the substructure array inputs in part 1!). You will then want to pass the 1-indexed representation of the C NMR data through an embedding layer.
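
The index transformation described above can be tried out on a toy example first. The following is a NumPy sketch on a made-up 6-bin input, not the PyTorch code for the exercise; the function name and the use of a temporary sentinel value of twice the number of bins are illustrative choices.

```python
import numpy as np

def binary_bins_to_padded_indices(occupancy):
    """Convert binary bin occupancies to sorted 1-indexed bin indices, padded with 0."""
    occupancy = np.asarray(occupancy)
    n_bins = occupancy.shape[-1]
    sentinel = 2 * n_bins  # larger than any valid index, so it sorts last

    # 1-indexed bin number where occupied, 0 elsewhere
    idx = np.arange(1, n_bins + 1) * (occupancy > 0)
    # Temporarily replace the zeros so that sorting pushes them to the end
    idx = np.where(idx == 0, sentinel, idx)
    idx = np.sort(idx, axis=-1)
    # Restore 0 as the padding value
    return np.where(idx == sentinel, 0, idx)

toy = np.array([[0, 1, 0, 0, 1, 1],
                [1, 0, 0, 0, 0, 0]])
padded = binary_bins_to_padded_indices(toy)
print(padded)
# [[2 5 6 0 0 0]
#  [1 0 0 0 0 0]]
```

Because 0 is reserved for padding, an embedding table fed with these indices would need one more row than there are bins.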

In [None]:
class ConvolutionalEmbedding(nn.Module):
    
    def __init__(self, 
                 d_model: int,
                 n_hnmr_features: int = 28000,
                 n_cnmr_features: int = 40,
                 pool_variation: str = 'max',
                 pool_size_1: int = 12,
                 out_channels_1: int = 64,
                 kernel_size_1: int = 5,
                 pool_size_2: int = 20,
                 out_channels_2: int = 128,
                 kernel_size_2: int = 9,
                 add_pos_encode: bool = True,
                 use_hnmr: bool = True,
                 use_cnmr: bool = True):
        """Construct features over the spectrum using the same convolutional heads as 
        the convolutional neural network. The convolutional head involves 
        two 1D convolutions interspersed with max pooling. The channel dimensionalities
        are tunable, but default is out_channels_1 = 64, out_channels_2 = 128.
        Convolution strides are 1, padding is 'valid' (no padding), and the activation
        function is ReLU. 

        Args:
            d_model: Model dimensionality for downstream transformer
            n_hnmr_features: The number of hnmr features, defaults to 28000
            n_cnmr_features: The number of cnmr features, defaults to 40
            pool_variation: The type of pooling to use, either 'max' or 'avg' where
                'max' is max pooling and 'avg' is average pooling, both 1D variants
            pool_size_1/2: Size and stride for the respective max pooling layer
            out_channels_1/2: Number of output channels after the respective convolutional layer
            kernel_size_1/2: Kernel size for the respective convolutional layer
            add_pos_encode: Whether to add a positional encoding to the output of this source 
                embedding.
            use_hnmr: Whether HNMR information is used by the network, defaults to True
            use_cnmr: Whether CNMR information is used by the network, defaults to True

        Notes: 
            Original architectures:
                conv1: Kernel size = 5, Filters (out channels) = 64, in channels = 1
                pool1: Max pool of size 12 with stride 12
                conv2: Kernel size of 9, Filters (out channels) = 128, in channels = 64
                pool2: Max pool of size 20 with stride 20
        """
        super().__init__()
        self.n_spectral_features = n_hnmr_features
        self.n_Cfeatures = n_cnmr_features
        self.d_model = d_model
        
        ### MODIFY BELOW ###

        ### MODIFY ABOVE ###
        
        self.h_spectrum_final_seq_len = self._compute_final_seq_len(
            self.n_spectral_features,
            [(kernel_size_1, pool_size_1, pool_variation), 
             (kernel_size_2, pool_size_2, pool_variation)]
        )
        self.add_pos_encoder = add_pos_encode
        self.use_hnmr = use_hnmr
        self.use_cnmr = use_cnmr
        #Have to use at least one source of spectral information as input
        #   to the model!
        assert self.use_hnmr or self.use_cnmr

        print("Final sequence length after conv embedding:")
        print(self.h_spectrum_final_seq_len)
    
    #From https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
    def _calculate_dim_after_conv(self, 
                                  L_in: int,
                                  kernel: int,
                                  padding: int,
                                  dilation: int,
                                  stride: int) -> int:
        numerator = L_in + (2 * padding) - (dilation * (kernel - 1)) - 1
        return math.floor(
            (numerator/stride) + 1
        )
    #From https://pytorch.org/docs/stable/generated/torch.nn.AvgPool1d.html
    # and https://pytorch.org/docs/stable/generated/torch.nn.MaxPool1d.html
    def _calculate_dim_after_pool(self,
                                  pool_variation: str,
                                  L_in: int,
                                  kernel: int,
                                  padding: int,
                                  dilation: int,
                                  stride: int) -> int:
        if pool_variation == 'max':
            numerator = L_in + (2 * padding) - (dilation * (kernel - 1)) - 1
            return math.floor(
                (numerator/stride) + 1
            )
        elif pool_variation == 'avg':
            numerator = L_in + (2 * padding) - kernel
            return math.floor(
                (numerator/stride) + 1
            )
    
    def _compute_final_seq_len(self,
                               L_in: int,
                               block_args: list[tuple[int]]) -> int:
        '''Computes the final sequence length after a series of convolution + pooling operations
        Args:
            L_in: The initial sequence length
            block_args: A list of tuples, each containing the:
                convolution kernel
                pool kernel
                pooling_variation
        
        This function assumes:
            padding = 0
            dilation = 1
            stride = 1 for conv, stride = pool_size for pool
        '''
        L_final = L_in
        for conv_kernel, pool_kernel, pool_variation in block_args:
            L_final = self._calculate_dim_after_conv(L_final, conv_kernel, 0, 1, 1)
            L_final = self._calculate_dim_after_pool(pool_variation, L_final, pool_kernel, 0, 1, pool_kernel)
        return L_final

    def _separate_spectra_components(self, x: Tensor):
        if len(x.shape) == 2:
            x = torch.unsqueeze(x, 1)
        spectral_x = x[:, :, :self.n_spectral_features]
        cnmr_x = x[:, :, self.n_spectral_features:self.n_spectral_features + self.n_Cfeatures]
        mol_x = x[:, :, self.n_spectral_features + self.n_Cfeatures:]
        return spectral_x, cnmr_x, mol_x
    
    def _embed_cnmr(self, cnmr: Tensor):
        """Embeds the binary tensor into a continuous space
        Convert to 1-indexed indices, pad with 0
        """
        assert(cnmr.shape[-1] == self.n_Cfeatures)
        if cnmr.ndim == 3:
            cnmr = cnmr.squeeze(1)
        padder_idx = self.n_Cfeatures * 2
        ### MODIFY BELOW ###

        ### MODIFY ABOVE ###

    def _embed_spectra(self, spectra: Tensor):
        assert spectra.ndim == 3
        ### MODIFY BELOW ###

        ### MODIFY ABOVE ###

    def forward(self, x):
        spectra, cnmr, mol = self._separate_spectra_components(x)
        ### MODIFY BELOW ###

        ### MODIFY ABOVE ###

For the Transformer Encoder part of the model, which takes the embedded spectral data and outputs the predicted substructure probabilities, most of the code is provided below, since the architecture is similar to the Encoder-Decoder Transformer you already built but uses only the Encoder. The only part you need to complete is the final feedforward layer, defined by the SingleLinear class: a single-layer neural network consisting of a linear layer followed by a sigmoid activation.
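
To see what such an output head computes, here is a NumPy sketch with toy sizes and random weights. It only illustrates the arithmetic of a linear map followed by an element-wise sigmoid; the dimensions and variable names are assumptions, and the exercise itself should use PyTorch layers.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, d_model, d_out = 3, 8, 5  # toy sizes, not the model's

pooled = rng.normal(size=(batch, d_model))  # stand-in for the pooled encoder output
weight = rng.normal(size=(d_model, d_out))
bias = np.zeros(d_out)

logits = pooled @ weight + bias
probs = 1.0 / (1.0 + np.exp(-logits))       # element-wise sigmoid

print(probs.shape)        # (3, 5)
print(probs.sum(axis=1))  # rows need not sum to 1
```

Unlike a softmax, the sigmoid treats every output independently, which suits a multi-label problem where several substructures can be present in the same molecule.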

In [None]:
def src_fwd_fxn_packed_tensor(src: tuple[Tensor],
                              d_model: int,
                              src_embed: nn.Module,
                              src_pad_token: int,
                              pos_encoder: nn.Module) -> Tuple[Tensor, Optional[Tensor]]:
    """ Forward processing for a source tensor that is a tuple which contains 
    the embedded sequence and the padding mask"""
    assert(src_embed is None)
    src_embedded, src_key_pad_mask = src
    return src_embedded, src_key_pad_mask


class SeqPool(nn.Module):
    """Sequence pooling operation introduced in https://arxiv.org/abs/2104.05704"""
    def __init__(self, d_model: int):
        super().__init__()
        self.g = nn.Linear(d_model, 1)
    
    def forward(self, x: Tensor) -> Tensor:
        """
        x: (bsize, seq_len, d_model) -> (bsize, d_model)
        Double check this code: https://github.com/SHI-Labs/Compact-Transformers/blob/main/src/utils/transformers.py
        """
        x_p = f.softmax(self.g(x), dim = 1)
        z = torch.bmm(x_p.transpose(1, 2), x)
        return z.squeeze(1)


class SingleLinear(nn.Module):
    """A single linear layer with a sigmoid activation"""
    def __init__(self, d_model: int, d_out: int):
        ### MODIFY BELOW ###

        ### MODIFY ABOVE ###
    
    def forward(self, x: Tensor) -> Tensor:
        ### MODIFY BELOW ###

        ### MODIFY ABOVE ###


class EncoderNetwork(nn.Module):

    model_id = 'EncoderNet'

    def __init__(self, 
                 src_embed: nn.Module,
                 src_pad_token: int,
                 src_forward_function: Callable[[Tensor, int, Optional[nn.Module], int, nn.Module], tuple[Tensor, Optional[Tensor]]],
                 pooler: nn.Module,
                 pooler_opts: dict, 
                 output_head: nn.Module,
                 output_head_opts: dict,
                 d_model: int = 512,
                 nhead: int = 8,
                 dim_feedforward: int = 2048,
                 d_out: int = 957,
                 source_size: int = 957,
                 dropout: float = 0.1,
                 activation: str = 'relu',
                 layer_norm_eps: float = 1e-05,
                 batch_first: bool = True,
                 enable_norm: bool = True,
                 norm_first: bool = False,
                 bias: bool = True,
                 num_layers: int = 6,
                 enable_nested_tensor: bool = True,
                 device: torch.device = None,
                 dtype: torch.dtype = torch.float
                 ):
        r"""Most parameters are for the PyTorch TransformerEncoderLayer module.
        Args:
            src_embed: The embedding module for the src tensor passed to the model
            src_pad_token: The index used to indicate padding in the source sequence
            src_forward_function: A function that processes the src tensor using the src embedding, src pad token, and positional encoding to generate
                the embedded src and the src_key_pad_mask
            pooler: Module for pooling the output of the transformer. This is typically seen in sequence classification 
                tasks where pooling is applied in the sequence length/time dimension, i.e.:
                (N, T, E) --> (N, E) where N = batch size, T = seq len, and E = feature dimension.
                The pooler should take in a 3D input of shape (N, T, E) and produce an output of shape (N, E)
            pooler_opts: Additional options for the pooling head
            output_head: Module for generating the output from the pooled transformer output. 
                This module should take at minimum a 2D input of shape (N, E) and produce an output of shape (N, d_out)
            output_head_opts: Additional options for the output head
            d_model: The dimensionality of the model
            nhead: The number of heads in the multiheadattention models
            dim_feedforward: The inner dimension of the feedforward network model
            d_out: The dimensionality of the output (e.g. 957 for the number of substructures)
            source_size: The size of the source vocabulary
            batch_first: If True, then all tensors are shaped as (N, T, E), with N being batch size
            enable_norm: If True, then layer normalization is applied as in the original transformer encoder + decoder.
                Default is True
            num_layers: The number of layers to use in the encoder
            enable_nested_tensor: Whether nested tensor operations are enabled inside the transformer. By default enabled,
                improves performance when padding is high
            device: The device for the model
            dtype: The dtype for the model
        """
        super().__init__()
        self.src_embed = src_embed
        self.src_size = source_size
        self.src_fwd_fn = src_forward_function
        self.src_pad_token = src_pad_token
        self.d_model = d_model
        self.d_out = d_out
        self.nhead = nhead
        self.nlayers = num_layers

        self.pooler = pooler(**pooler_opts)
        self.output_head = output_head(**output_head_opts)
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        self.dtype = dtype
        self.device = device

        #Construct the encoder model. Process taken from 
        #   https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=batch_first,
            norm_first=norm_first,
            bias=bias,
            device=device,
            dtype=dtype
        )
        if enable_norm:
            norm = nn.LayerNorm(d_model, eps=layer_norm_eps, bias=bias, device=device, dtype=dtype)
        else:
            norm = None
        self.encoder = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=num_layers, 
            norm=norm,
            enable_nested_tensor=enable_nested_tensor
        )
    
    def _sanitize_forward_args(self, x: tuple[Tensor, tuple[str]]) -> Tensor:
        x, _ = x
        if isinstance(self.src_embed, nn.Embedding):
            x = x.long()
        return x
    
    def forward(self, x: tuple[Tensor, tuple[str]]) -> Tensor:
        src = self._sanitize_forward_args(x)
        src_embedded, src_key_pad_mask = self.src_fwd_fn(src, self.d_model, self.src_embed, self.src_pad_token, self.pos_encoder)
        src_out = self.encoder(src=src_embedded,
                               src_key_padding_mask=src_key_pad_mask)
        src_out = self.pooler(src_out)
        return self.output_head(src_out)
    
    def get_loss(self, 
                 x: tuple[Tensor, tuple[str]],
                 y: tuple[Tensor],
                 loss_fn: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
        pred = self.forward(x)
        y_target, = y
        loss = loss_fn(pred, y_target.to(self.dtype).to(self.device))
        return loss


class EncoderModel(nn.Module):
    
    def __init__(self,
                 src_embed: str,
                 src_embed_options: dict,
                 src_pad_token: int, 
                 src_forward_function: str,
                 pooler: str,
                 pooler_opts: dict,
                 output_head: str,
                 output_head_opts: dict,
                 d_model: int,
                 nhead: int = 8,
                 dim_feedforward: int = 2048,
                 d_out: int = 957,
                 source_size: int = 957,
                 dropout: float = 0.1,
                 activation: str = 'relu',
                 layer_norm_eps: float = 1e-05,
                 batch_first: bool = True,
                 enable_norm: bool = True,
                 norm_first: bool = False,
                 bias: bool = True,
                 num_layers: int = 6,
                 enable_nested_tensor: bool = True,
                 freeze_components: Optional[list] = None,
                 device: torch.device = None,
                 dtype: torch.dtype = torch.float
                 ):
        r"""See documentation for EncoderNetwork for description of parameters. The parameters replaced by
        string values are meant to be names of the components to be fetched using getattr"""
        super().__init__()

        src_embed_layer = None
        pooler = SeqPool
        output_head = SingleLinear

        src_fwd_fn = src_fwd_fxn_packed_tensor
        self.network = EncoderNetwork(
            src_embed=src_embed_layer,
            src_pad_token=src_pad_token,
            src_forward_function=src_fwd_fn,
            pooler=pooler,
            pooler_opts=pooler_opts,
            output_head=output_head,
            output_head_opts=output_head_opts,
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            d_out=d_out,
            source_size=source_size,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=batch_first,
            enable_norm=enable_norm,
            norm_first=norm_first,
            bias=bias,
            num_layers=num_layers,
            enable_nested_tensor=enable_nested_tensor,
            device=device,
            dtype=dtype
        )

        self.initialize_weights()
        self.freeze_components=freeze_components
        self.device=device
        self.dtype=dtype
    
    def initialize_weights(self) -> None:
        '''Initializes network weights'''
        for p in self.network.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def freeze(self) -> None:
        '''Disables gradients for specific network components'''
        if self.freeze_components is not None:
            for component in self.freeze_components:
                if hasattr(self.network, component):
                    for param in getattr(self.network, component).parameters():
                        param.requires_grad = False
    
    def forward(self, x: tuple[Tensor, tuple[str]]) -> Tensor:
        return self.network(x)
    
    def get_loss(self, 
                 x: tuple[Tensor, tuple[str]],
                 y: tuple[Tensor],
                 loss_fn: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
        return self.network.get_loss(x, y, loss_fn)

For the last step, let us connect all the separate parts together to finalize the implementation of the multi-task model. To do that, complete the indicated parts of the initialization and forward methods of the MultiTaskModel class below.
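
One plausible way to read the "connecting linear transformations" is as a pair of independent linear maps that take the shared spectral embedding to the model dimension expected by each submodel. The NumPy sketch below illustrates only the resulting shapes under that assumption; it is not a statement of how NMR2Struct implements the connection, and all sizes and names are toy values.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len = 2, 6
src_embed_dim, structure_dim, substructure_dim = 8, 16, 4  # toy sizes

# Shared embedding produced once from the input spectra
shared = rng.normal(size=(batch, seq_len, src_embed_dim))

# One independent linear map per downstream submodel (assumed design)
w_structure = rng.normal(size=(src_embed_dim, structure_dim))
w_substructure = rng.normal(size=(src_embed_dim, substructure_dim))

structure_input = shared @ w_structure        # (2, 6, 16)
substructure_input = shared @ w_substructure  # (2, 6, 4)

print(structure_input.shape, substructure_input.shape)
```

The sequence length is unchanged by either map; only the feature dimension is adapted, so the same padding mask can accompany both projected sequences.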

In [None]:
class MultiTaskModel(nn.Module):
    """ Model that predicts substructures and structures """

    def __init__(self, 
                 src_embed: str,
                 src_embed_options: dict,
                 structure_model: str,
                 structure_model_args: dict,
                 substructure_model: str,  
                 substructure_model_args: dict,
                 forward_fxn: str, 
                 structure_model_ckpt: str,
                 substructure_model_ckpt: str,
                 device: torch.device = None,
                 dtype: torch.dtype = torch.float):
        """Constructor for multitask model that takes in a src embedding used to produce a structure and substructure prediction"""
        super().__init__()
        self.src_embed = ConvolutionalEmbedding(**src_embed_options)
        self.structure_model = TransformerModel(**structure_model_args)
        self.substructure_model = EncoderModel(**substructure_model_args)
        src_embed_dim = self.src_embed.d_model
        structure_model_dim = self.structure_model.network.d_model
        substructure_model_dim = self.substructure_model.network.d_model
        
        #Connecting linear transformations
        ### MODIFY BELOW ###

        ### MODIFY ABOVE ###

        self.fwd_fn = src_fwd_fxn_conv_embedding
        self.device = device
        self.dtype = dtype
        self.initialize_weights()
        if structure_model_ckpt is not None:
            self._partial_load_weights(self.structure_model, structure_model_ckpt)
        if substructure_model_ckpt is not None:
            self._partial_load_weights(self.substructure_model, substructure_model_ckpt)

    def initialize_weights(self) -> None:
        """initialize network weights"""
        self.structure_model.initialize_weights()
        self.substructure_model.initialize_weights()
    
    def freeze(self) -> None:
        """Disables gradients for specific components of the network"""
        self.structure_model.freeze()
        self.substructure_model.freeze()

    def _partial_load_weights(self, model: nn.Module, ckpt: str) -> None:
        ckpt = torch.load(ckpt, map_location = self.device)['model_state_dict']
        model_state = model.state_dict()
        pretrained_dictionary = {}
        for k, v in ckpt.items():
            if k in model_state:
                if model_state[k].shape == v.shape:
                    pretrained_dictionary[k] = v
                else:
                    warnings.warn(f"Could not load {k}: expected {model_state[k].shape} but got {v.shape}")
            else:
                warnings.warn(f"Could not load {k} because it is not in the model state dictionary")
        print("The following keys are ignored in the model:")
        for k in model_state:
            if k not in pretrained_dictionary:
                print(k)
        model.load_state_dict(pretrained_dictionary, strict=False)
    
    def _sanitize_forward_args(self, x, y):
        inp, _ = x
        structure_targets, substructure_targets = y
        return inp, structure_targets, substructure_targets
    
    def _unpack_to_list(self, x: Tensor, dim: int) -> list[Tensor]:
        """
        Given a tensor of elements, unpacks the tensor into a list of tensors. For example,
        a tensor of shape (N, 2, E) -> [(N, E), (N, E)]
        """
        return [torch.select(x, dim, i) for i in range(x.shape[dim])]

    def forward(self, 
                x: Tuple[Tensor, Tuple], 
                y: Tuple[Tensor, Tensor],
                eval_paths: list[str]) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """
        The argument eval_paths indicates which submodels to evaluate. It can contain the following values:
            'structure': The structure model is evaluated on this forward pass
            'substructure': The substructure model is evaluated on this forward pass
        Note that the forward() function is not used in get_loss. This is because get_loss will always 
        evaluate both submodels. If one wishes to only use one submodel, they should train the models 
        using the convolutional embedding instead. 
        """
        src, struct_targs, substruct_targs = self._sanitize_forward_args(x, y)
        if 'structure' in eval_paths:
            ### MODIFY BELOW ###

            ### MODIFY ABOVE ###
        else:
            structure_output = None
        if 'substructure' in eval_paths:
            ### MODIFY BELOW ###

            ### MODIFY ABOVE ###
        else:
            substructure_output = None
        return structure_output, substructure_output
    
    def get_loss(self,
                 x: Tuple[Tensor, Tuple], 
                 y: Tuple[Tensor], 
                 loss_fn: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
        structure_loss = lambda x, y : loss_fn('structure', x, y)
        substructure_loss = lambda x, y : loss_fn('substructure', x, y)
        src, struct_targs, substruct_targs = self._sanitize_forward_args(x, y)
        src_struct_embedded, src_struct_key_pad_mask = self.fwd_fn(src,
                                                                   self.structure_model.network.d_model, 
                                                                   self.src_embed, 
                                                                   self.structure_model.network.src_pad_token, 
                                                                   self.structure_model.network.pos_encoder)
        src_substruct_embedded, src_substruct_key_pad_mask = self.fwd_fn(src,
                                                                       self.substructure_model.network.d_model, 
                                                                       self.src_embed, 
                                                                       self.substructure_model.network.src_pad_token, 
                                                                       self.substructure_model.network.pos_encoder)
        src_struct_embedded = self.structure_connector(src_struct_embedded)
        src_substruct_embedded = self.substructure_connector(src_substruct_embedded)
        #Loss scaling factors are multiplied within forward() calls
        structure_loss = self.structure_model.get_loss(((src_struct_embedded, src_struct_key_pad_mask), None), 
                                                       self._unpack_to_list(struct_targs, 1), 
                                                       structure_loss)
        substructure_loss = self.substructure_model.get_loss(((src_substruct_embedded, src_substruct_key_pad_mask), None),
                                                              (substruct_targs,), substructure_loss)
        return structure_loss + substructure_loss

## Loading our multi-task model

While you were implementing your model, hopefully your multi-task model has also been training. If the loss over the validation set appears to have plateaued, copy the contents of the corresponding "checkpoints" subfolder into a folder named "model2" in this directory. If training has not yet converged, you can instead use a provided trained model (see "multitask_checkpoint.pt" provided by the NMR2Struct package [here](https://github.com/MarklandGroup/NMR2Struct/tree/main/checkpoints)).

Below we will try to load the trained model using the code you implemented in this notebook, which is a stripped down reimplementation of the relevant code from the NMR2Struct package that we used to train this model. If you receive an error here, you will likely need to debug your implementation of the model in the previous section.

In [None]:
with open('model2/full_inference_config.yaml', 'r') as f:
    listdoc = yaml.safe_load(f)
model_args = listdoc['model']
model_config = model_args['model_args']

model = MultiTaskModel(dtype=torch.float32, device=torch.device('cpu'), **model_config)
ckpt = torch.load('model2/multitask_checkpoint.pt', map_location=torch.device('cpu'))['model_state_dict']
model.load_state_dict(ckpt)
model.eval()
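
If `load_state_dict` raises an error above, it usually helps to see exactly which parameter names or shapes disagree between your model and the checkpoint. The sketch below mirrors the comparison made in `_partial_load_weights`, using plain dictionaries of shapes in place of real tensors so that it runs on its own. The helper name `diff_state_dicts` and the toy parameter names are hypothetical. For a real model, the two mappings could be built with `{k: tuple(v.shape) for k, v in model.state_dict().items()}` and the same expression applied to `ckpt`.

```python
def diff_state_dicts(model_shapes, ckpt_shapes):
    """Compare two {parameter_name: shape} mappings.

    Returns (missing, unexpected, mismatched):
      missing    - names the model expects but the checkpoint lacks
      unexpected - names in the checkpoint that the model does not define
      mismatched - names present in both but with different shapes
    """
    missing = sorted(k for k in model_shapes if k not in ckpt_shapes)
    unexpected = sorted(k for k in ckpt_shapes if k not in model_shapes)
    mismatched = sorted(
        k for k in model_shapes
        if k in ckpt_shapes and tuple(model_shapes[k]) != tuple(ckpt_shapes[k])
    )
    return missing, unexpected, mismatched


# Toy shapes standing in for real state dictionaries (hypothetical names).
model_shapes = {"encoder.weight": (4, 8), "encoder.bias": (4,), "head.weight": (2, 4)}
ckpt_shapes = {"encoder.weight": (4, 8), "encoder.bias": (8,), "extra.weight": (1,)}

missing, unexpected, mismatched = diff_state_dicts(model_shapes, ckpt_shapes)
print("missing:", missing)
print("unexpected:", unexpected)
print("mismatched:", mismatched)
```

A missing or mismatched name typically points to a layer in your implementation whose attribute name or dimensions differ from those used when the checkpoint was trained.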

## Inference with our multi-task model

With the model loaded, let us now use it to try to predict both the substructures and the full structure of a molecule given only the corresponding pair of H and C NMR spectra for the molecule. Below we load the SMILES string, substructure list, and concatenated H and C NMR spectra for one of the molecules in our test set.

In [None]:
i = 10000

ismi = np.load('data/smiles.npy')[i].decode('UTF-8')

with open('data/substructures_957.p', 'rb') as f:
    substruct_list = pkl.load(f)
with h5py.File('data/substructures.h5', 'r') as hf:
    isubs = hf['substructure_labels'][i]
with h5py.File('data/spectra.h5', 'r') as hf:
    ispectra = hf['spectra'][i]

Complete the code block below to pass the spectra loaded above to our multi-task model and predict the first character of the molecule's SMILES string, along with the substructure probabilities for that molecule.

In [None]:
alphabet = np.load('model2/alphabet.npy')
start_token = 22
stop_token = 23

### MODIFY BELOW ###

### MODIFY ABOVE ###

To get a rough picture of how accurate our spectra to substructure model is, plot below the difference between the predicted substructure probabilities and the actual substructure labels for this molecule.

In [None]:
### MODIFY BELOW ###

### MODIFY ABOVE ###

Below, identify and visualize (remember RDKit?) which substructure predictions were false positives, which we will define here as substructures that are not part of the molecule but have a predicted probability greater than 0.2.

In [None]:
### MODIFY BELOW ###

### MODIFY ABOVE ###

Below, identify and visualize which substructure predictions were false negatives, which we will define here as substructures contained by the molecule having a predicted probability of less than 0.4. 

In [None]:
### MODIFY BELOW ###

### MODIFY ABOVE ###
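
The two definitions above use different thresholds (0.2 for false positives, 0.4 for false negatives), which is easy to mix up. The sketch below applies them to a small, made-up set of probabilities and labels. The helper name `split_errors` and the toy values are hypothetical and are not taken from the model or the dataset.

```python
def split_errors(probs, labels, fp_threshold=0.2, fn_threshold=0.4):
    """Return (false_positive_indices, false_negative_indices) for one molecule.

    probs  - predicted probability for each substructure
    labels - 1 if the substructure is present in the molecule, else 0
    """
    # Absent substructures that the model still rates above fp_threshold.
    false_pos = [i for i, (p, y) in enumerate(zip(probs, labels))
                 if y == 0 and p > fp_threshold]
    # Present substructures that the model rates below fn_threshold.
    false_neg = [i for i, (p, y) in enumerate(zip(probs, labels))
                 if y == 1 and p < fn_threshold]
    return false_pos, false_neg


# Made-up values for five substructures (illustration only).
toy_probs = [0.95, 0.30, 0.05, 0.10, 0.70]
toy_labels = [1, 0, 0, 1, 1]

false_pos, false_neg = split_errors(toy_probs, toy_labels)
print("false positives:", false_pos)
print("false negatives:", false_neg)
```

With the real data, the returned indices could be used to look up the corresponding entries of `substruct_list` for visualization, assuming its ordering matches the label vector.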

Complete the generate_structure function below so that it autoregressively predicts a molecular structure given an input array that is the concatenation of the H and C NMR spectra. As you might have guessed, it will look quite similar to what you implemented in part 1.

In [None]:
def get_top_k_sample_batched(k_val: int | float , 
                             character_probabilities: Tensor) -> tuple[Tensor, Tensor]:
    """
    Generates the next character using top-k sampling scheme.

    In top-k sampling, the probability mass is redistributed among the
    top-k next tokens, where k is a hyperparameter. Once redistributed, 
    the next token is sampled from the top-k tokens.
    """
    top_values, top_indices = torch.topk(character_probabilities, k_val, sorted = True)
    #Take the sum of the top probabilities and renormalize
    tot_probs = top_values / torch.sum(top_values, dim = -1).reshape(-1, 1)
    #Sample from the top k probabilities. This represents a multinomial distribution
    row_sums = torch.sum(tot_probs, dim = -1)
    if not torch.allclose(row_sums, torch.ones_like(row_sums)):
        print("Probabilities did not pass allclose check!")
        print(f"Sum of probs is {row_sums}")
    selected_index = torch.multinomial(tot_probs, 1)
    #For gather to work, both tensors have to have the same number of dimensions:
    if len(top_indices.shape) != len(selected_index.shape):
        top_indices = top_indices.reshape(selected_index.shape[0], -1)
    output = torch.gather(top_indices, -1, selected_index)
    output_token_probs = torch.gather(tot_probs, -1, selected_index)
    return output, output_token_probs
    

def generate_structure(model, ispectra, start_token=22, stop_token=23, max_len=74, sample_val=5, max_steps=100):
    ### MODIFY BELOW ###

    ### MODIFY ABOVE ###
    

pred_tokens, pred_token_probs = generate_structure(model, ispectra)
pred_smi = ''
for ichar in alphabet[pred_tokens[0,1:-1]]:
    pred_smi+=ichar
print('Target: ' + ismi)
print('Predicted: ' + pred_smi)
print('Predicted Score: ' + str(np.log(pred_token_probs[0,:-1].detach().numpy()).sum()))
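
To make the renormalization step in `get_top_k_sample_batched` concrete, here is a minimal single-example sketch of top-k sampling in plain Python rather than torch. The function name `top_k_sample` and the toy distribution are hypothetical. Like the batched version, it returns the sampled token index together with its renormalized probability.

```python
import random


def top_k_sample(probs, k, rng):
    """Sample one token index from the k most probable entries of `probs`."""
    # Indices of the k largest probabilities, most probable first.
    top_indices = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    top_mass = sum(probs[i] for i in top_indices)
    # Renormalize so that the kept probabilities sum to one.
    renormalized = [probs[i] / top_mass for i in top_indices]
    # Draw a position within the top-k list, then map it back to a token index.
    position = rng.choices(range(k), weights=renormalized, k=1)[0]
    return top_indices[position], renormalized[position]


# Toy next-token distribution over five tokens (illustration only).
toy_probs = [0.05, 0.40, 0.10, 0.30, 0.15]
rng = random.Random(0)

token, token_prob = top_k_sample(toy_probs, k=2, rng=rng)
print(token, round(token_prob, 3))
```

With `k=2`, only tokens 1 and 3 can ever be drawn, and their probabilities are rescaled from 0.40 and 0.30 to 0.40/0.70 and 0.30/0.70. Smaller values of `k` make generation more conservative, while larger values allow more diverse structures.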

Again just like in part 1, write a routine that samples 10 molecule predictions, saving for each molecule its SMILES string and its log probability.

In [None]:
def generate_k_structures(model, ispectra, num_pred_per_tgt):
    ### MODIFY BELOW ###

    ### MODIFY ABOVE ###


sampled_smis, sampled_smi_scores = generate_k_structures(model, ispectra, 10)
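
Because sampling is stochastic, the same SMILES string may be generated more than once, and the samples are not produced in any particular order. The sketch below shows one possible way to score sequences by summed log probability, drop duplicates, and sort from most to least likely. The helper names and toy values are hypothetical, and de-duplication is an optional choice rather than something the exercise requires.

```python
import math


def sequence_log_prob(token_probs):
    """Log probability of a sequence: the sum of its per-token log probabilities."""
    return sum(math.log(p) for p in token_probs)


def rank_unique(smiles, scores):
    """Keep the best score per SMILES string and sort from highest to lowest score."""
    best = {}
    for smi, score in zip(smiles, scores):
        if smi not in best or score > best[smi]:
            best[smi] = score
    return sorted(best.items(), key=lambda item: item[1], reverse=True)


# Made-up samples and per-token probabilities (illustration only).
toy_smiles = ["CCO", "CCN", "CCO"]
toy_token_probs = [[0.9, 0.8, 0.9], [0.5, 0.5, 0.5], [0.9, 0.7, 0.9]]

toy_scores = [sequence_log_prob(p) for p in toy_token_probs]
ranked = rank_unique(toy_smiles, toy_scores)
for smi, score in ranked:
    print(smi, round(score, 3))
```

Log probabilities are always at most zero, so the candidate with the score closest to zero is the one the model considers most likely.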

Use the code block below to visualize the 10 predicted molecules, ordered by their log probabilities.

In [None]:
from rdkit.Chem import Draw  # Draw is a submodule and is not loaded by "from rdkit import Chem" alone

sampled_mols = [Chem.MolFromSmiles(x) for x in sampled_smis]
Draw.MolsToGridImage(sampled_mols, molsPerRow=5, subImgSize=(200,200), legends=[str(x) for x in sampled_smis])

Now you have a fully working multi-task model that can be used to make spectra-to-substructure and even spectra-to-structure predictions, and all in just an afternoon's work!