# Guide the structure towards a knot during diffusion

In [1]:
import os
from typing import *

import numpy as np
import pandas as pd

from huggingface_hub import snapshot_download
import torch
from torch import nn
import torch.optim as optim
from topoly import alexander

from foldingdiff import sampling
from foldingdiff import modelling
from foldingdiff import datasets
from foldingdiff import utils
from foldingdiff.angles_and_coords import create_new_chain_nerf

## Prepare original training dataset

The dataset (a folder of PDB files whose names start with 'u' for unknotted and 'k' for knotted structures) is needed to sample the initial noise correctly and, later, to reconstruct PDB files from the model's internal representation.

In [2]:
# Folder with the training PDB files (file names prefixed 'u' = unknotted, 'k' = knotted).
DATASET = "mixed_dataset"

In [3]:
class myDataset(datasets.CathCanonicalAnglesOnlyDataset):
    """Canonical-angles dataset that also returns a knot label for each structure.

    The label comes from the first character of each PDB file's base name:
    'u' -> 0 (unknotted), anything else -> 1 (knotted; expected prefix 'k').
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Use the file's base name: indexing `path.split('/')[1]` only works for
        # paths exactly one directory deep (e.g. 'mixed_dataset/u123.pdb') and
        # silently mislabels './mixed_dataset/...' or absolute paths.
        self.labels = [
            0 if os.path.basename(fname)[0] == 'u' else 1
            for fname in self.filenames
        ]

    def __getitem__(self, index):
        """Return ``(sample_dict, label)``; ``sample_dict['angles']`` gains a leading axis of size 1."""
        return_dict = super().__getitem__(index)
        angles = return_dict['angles']
        # (pad, n_features) -> (1, pad, n_features), e.g. (1, 128, 6) for pad=128 with 6 angles.
        return_dict['angles'] = angles.reshape(1, *angles.shape)
        return return_dict, self.labels[index]

In [4]:
# Un-noised angle dataset read from DATASET; every structure is padded to 128 residues.
# zero_center offsets each feature by its mean (see the "Offsetting features" log line).
clean_dataset = myDataset(
    pdbs=DATASET,
    pad=128,
    min_length=40,
    trim_strategy="randomcrop",
    zero_center=True,
    split=None,
)

INFO:root:Found 3937 PDB files in mixed_dataset
INFO:root:Loading cached full dataset from /home/jovyan/Aknots/Diffusion/foldingdiff/foldingdiff/cache_canonical_structures_mixed_dataset_0fc317992ef2176584cce17d8b0cb0b5.pkl
INFO:root:Hash matches between codebase and cached values!
INFO:root:Removing structures shorter than 40 residues excludes 0/3936 --> 3936 sequences
INFO:root:Offsetting features ['phi', 'psi', 'omega', 'tau', 'CA:C:1N', 'C:1N:1CA'] by means [ 1.32475     1.4534563   1.5235982  -1.4041979   0.09982978  3.1017764
  1.9426514   2.0457344   2.1426508 ]
INFO:root:Length of angles: 70-128, mean 105.02540650406505
INFO:root:CATH canonical angles only dataset with ['phi', 'psi', 'omega', 'tau', 'CA:C:1N', 'C:1N:1CA'] (subset idx [3, 4, 5, 6, 7, 8])


Wrap the clean dataset in a noised dataset so that we can sample noise and run the diffusion process.

In [5]:
# Wrap the clean dataset to get noise sampling (`sample_noise`) and the diffusion
# schedule (`alpha_beta_terms["betas"]`) used by my_sample; per the log below the
# schedule is linear with 250 timesteps.
noised_dataset = datasets.NoisedAnglesDataset(clean_dataset)

INFO:root:Getting linear variance schedule with 250 timesteps


## Prepare flipping model

Load the classifier that distinguishes knotted from unknotted structures.

I trained two different architectures: a tiny dense model (SimpleClassifier) and a simple CNN (SimpleCNN).

This part can be replaced with an arbitrary classification model.

In [6]:
class SimpleCNN(nn.Module):
    """Small 2D CNN mapping a (batch, 1, seq_len, n_features) angle tensor to a
    (batch, 1) probability via a final sigmoid.

    Args:
        seq_len: padded sequence length the model is built for (default 128).
        n_features: number of angle features per residue (default 6).
    """

    def __init__(self, seq_len: int = 128, n_features: int = 6):
        super(SimpleCNN, self).__init__()

        self.conv_layer1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3), padding=(2, 2))
        self.relu1 = nn.LeakyReLU()
        self.max_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout1 = nn.Dropout(0.1)

        self.conv_layer2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=(2, 2))
        self.relu2 = nn.LeakyReLU()
        self.max_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout2 = nn.Dropout(0.1)

        self.conv_layer3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=(2, 2))
        self.relu3 = nn.LeakyReLU()
        self.max_pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout3 = nn.Dropout(0.1)

        # The layers are registered both as attributes and inside `convs`; kept
        # that way so parameter names stay compatible with saved models.
        self.convs = nn.Sequential(
            self.conv_layer1,
            self.relu1,
            self.max_pool1,
            self.dropout1,
            self.conv_layer2,
            self.relu2,
            self.max_pool2,
            self.dropout2,
            self.conv_layer3,
            self.relu3,
            self.max_pool3,
            self.dropout3
        )

        # Infer the flattened feature count after the conv stack. A single zero
        # sample is enough to get the shape; no_grad avoids building an autograd
        # graph and zeros do not consume the global RNG.
        with torch.no_grad():
            dummy = torch.zeros(1, 1, seq_len, n_features)
            self._to_linear = self.convs(dummy).reshape(1, -1).shape[1]

        self.fc1 = nn.Linear(self._to_linear, 128)
        self.fc2 = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_representation):
        """Return a (batch, 1) probability for a (batch, 1, seq_len, n_features) input."""
        out = self.convs(input_representation)
        out = out.reshape(out.size(0), -1)

        # NOTE(review): there is no nonlinearity between fc1 and fc2, so they
        # compose to a single affine map; kept unchanged to match trained weights.
        out = self.fc1(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out

In [7]:
class SimpleClassifier(nn.Module):
    """Tiny dense classifier over a (batch, seq_len, n_features) angle tensor.

    The sequence axis is average-pooled to 64 positions, flattened
    (64 * n_features values, which must equal 384, i.e. n_features == 6) and fed
    through a small MLP ending in a sigmoid, giving a (batch, 1) probability.
    """

    def __init__(self):
        super().__init__()

        # Adaptive pooling fixes the sequence axis to 64 whatever the input length.
        self.pooling = nn.AdaptiveAvgPool1d(64)

        hidden = 384
        self.classifier = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, input_representation):
        # Pooling runs over the last axis, so move the sequence axis there:
        # (batch, seq_len, n_features) -> (batch, n_features, seq_len).
        channels_first = input_representation.transpose(1, 2)
        flat = self.pooling(channels_first).reshape(input_representation.size(0), -1)
        return self.classifier(flat)

Choose the model here:

In [8]:
# Load the trained knot classifier. torch.load(..., weights_only=False) unpickles a
# whole nn.Module, which can execute arbitrary code: only load trusted files.
# Unpickling also needs the model's class (here SimpleCNN, per the repr below) to be
# defined in this notebook before this cell runs.
# NOTE(review): the printed repr (padding=(1, 1), apparently only two conv blocks)
# differs from the SimpleCNN definition above. The pickled layers are what run;
# only forward() is taken from the class defined here.
classification_model = torch.load('CNN_model_mixed_dataset.pt', weights_only=False)
# eval() disables dropout; as the cell's last expression it also displays the model.
classification_model.eval()

SimpleCNN(
  (conv_layer1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): LeakyReLU(negative_slope=0.01)
  (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout1): Dropout(p=0.1, inplace=False)
  (conv_layer2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): LeakyReLU(negative_slope=0.01)
  (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout2): Dropout(p=0.1, inplace=False)
  (convs): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.01)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Dropout(p=0.1, inplace=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): LeakyReLU(negative_slope=0.01)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Drop

**Note:**

The CNN and the dense model expect inputs with different dimensions (the CNN input has one extra dimension).

To connect the pipeline correctly, set the flag below according to the model in use.

In [9]:
# True  -> classifier expects (batch, 1, seq_len, n_ft) inputs, as SimpleCNN does.
# False -> classifier expects (batch, seq_len, n_ft) inputs, as SimpleClassifier does.
# Derived from the loaded model so it cannot drift out of sync with it; a custom
# model that takes the dense input shape must set this to False by hand.
CNN_model = not isinstance(classification_model, SimpleClassifier)

## Prepare diffusion model

Load the original FoldingDiff model from the paper to perform the diffusion steps.

In [10]:
# Fetch (or reuse the locally cached) FoldingDiff CATH checkpoint from the Hugging Face Hub,
# then build the diffusion model from that directory.
foldingdiff_ckpt_dir = snapshot_download("wukevin/foldingdiff_cath")
diffusion_model = modelling.BertForDiffusion.from_dir(foldingdiff_ckpt_dir)

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

INFO:root:Auto constructed ft_is_angular: [True, True, True, True, True, True]
INFO:root:Found 1 checkpoints
INFO:root:Loading weights from /home/jovyan/.cache/huggingface/hub/models--wukevin--foldingdiff_cath/snapshots/98d77b1e68468db5ca03cdba1c0a90f2a2a33edc/models/best_by_valid/epoch=1488-step=565820.ckpt
/home/jovyan/my-conda-envs/foldingdiff2/lib/python3.8/site-packages/lightning_fabric/utilities/cloud_io.py:57: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicit

Slightly rewrite the sampling logic so that we can sample a chosen number of timesteps and only a single structure.

In [11]:
def my_sample(
    model: nn.Module,
    noise,
    train_dset: datasets.NoisedAnglesDataset,
    length: int,
    timesteps: int,
    feature_key: str = "angles",
    disable_pbar: bool = False,
    trim_to_length: bool = True,  # Trim padding regions to reduce memory
):
    """Run ``sampling.p_sample_loop`` for one structure and return its last frame.

    Args:
        model: denoising model passed to ``p_sample_loop``.
        noise: starting tensor; presumably (1, pad, n_features) — TODO confirm.
        train_dset: noised dataset supplying ``alpha_beta_terms["betas"]`` and
            ``feature_is_angular``.
        length: true (unpadded) length of the structure.
        timesteps: number of timesteps handed to ``p_sample_loop``.
        feature_key: key into ``train_dset.feature_is_angular``.
        disable_pbar: silence the sampling progress bar.
        trim_to_length: if True, drop padded positions beyond ``length``;
            if False, return the full padded sequence.

    Returns:
        The last entry along the timestep axis, shaped (batch_size, seq_len, n_ft),
        where seq_len equals ``length`` when trimming.
    """
    # NOTE(review): if p_sample_loop iterates from `timesteps - 1` down to 0 (its
    # (timesteps, ...) output suggests so), each call denoises `noise` all the way to
    # t=0 — confirm this is what flip_with_diffusion intends.
    # Produces (timesteps, batch_size, seq_len, n_ft)
    sampled = sampling.p_sample_loop(
        model=model,
        lengths=[length],
        noise=noise,
        timesteps=timesteps,
        betas=train_dset.alpha_beta_terms["betas"],
        is_angle=train_dset.feature_is_angular[feature_key],
        disable_pbar=disable_pbar,
    )

    # Keep only the final frame: (batch_size, seq_len, n_ft)
    final_step = sampled[-1]
    if trim_to_length:
        final_step = final_step[:, :length, :]
    return final_step

## Guide the structure

Logic that takes a structure and guides it towards a knotted state using the classification model.

In [12]:
def guide_structure(input_structure, target_label, steps, lr, verbose=False, model=None):
    """Optimize ``input_structure`` in place so the classifier moves towards ``target_label``.

    After every Adam step the angular features are re-wrapped with
    ``utils.modulo_with_wrapped_range`` into the -pi..pi range.

    Args:
        input_structure: leaf tensor with ``requires_grad=True``; it is the
            parameter being optimized and is modified in place.
        target_label: BCE target, shaped like the classifier output.
        steps: number of optimization steps.
        lr: Adam learning rate.
        verbose: print loss and prediction every 100 steps.
        model: classifier to guide against; defaults to the global
            ``classification_model``.

    Returns:
        ``input_structure`` itself (the same tensor object, updated in place).
    """
    if model is None:
        model = classification_model

    loss_fn = nn.BCELoss()
    optimizer = optim.Adam([input_structure], lr=lr)
    # The angular mask does not change between iterations, so compute it once.
    angular_idx = np.where(clean_dataset.feature_is_angular['angles'])[0]

    for iteration in range(steps):
        optimizer.zero_grad()

        # Forward pass and loss (push the prediction towards the target label)
        output = model(input_structure)
        loss = loss_fn(output, target_label)

        # Backward pass and update the input structure
        loss.backward()
        optimizer.step()

        # Re-wrap the angular features in place; no_grad keeps this edit out of
        # autograd (in-place writes to a leaf that requires grad need it).
        with torch.no_grad():
            input_structure[..., angular_idx] = utils.modulo_with_wrapped_range(
                input_structure[..., angular_idx], range_min=-np.pi, range_max=np.pi
            )

        if verbose and iteration % 100 == 0:
            print(f"Iteration {iteration}: Loss = {loss.item()}, Prediction = {output.item()}")
    return input_structure

In [13]:
def inner_to_structure(sample, length, name, verbose=False, knot=False, tries=70, unknot_threshold=0.5):
    """Write a PDB file from the model's internal angle representation.

    Adds the dataset's masked means back to each row of ``sample``, re-wraps the
    angular features into -pi..pi and writes the first ``length`` residues to
    ``name`` with ``create_new_chain_nerf``. Optionally runs topoly's
    ``alexander`` on the written file.

    Args:
        sample: iterable of per-residue rows of 6 angle features (presumably a
            (seq_len, 6) CPU tensor — confirm against callers).
        length: number of residues to write.
        name: output PDB path.
        verbose: print the ``alexander`` result when ``knot`` is True.
        knot: if True, evaluate the knot type of the written structure.
        tries: value passed as ``tries`` to ``alexander`` (the original TODO
            notes it was lowered from 100).
        unknot_threshold: the structure counts as unknotted when the '0_1'
            probability exceeds this value.

    Returns:
        True if ``knot`` is set and the structure is not classified as the unknot;
        False otherwise, including when no knot check was requested.
    """
    means = clean_dataset.get_masked_means()
    angular_idx = np.where(clean_dataset.feature_is_angular['angles'])[0]

    rows = [row + means for row in sample]
    # Shifting by the means may push angles across the circle boundary, so re-wrap.
    for row in rows:
        row[..., angular_idx] = utils.modulo_with_wrapped_range(
            row[..., angular_idx], range_min=-np.pi, range_max=np.pi
        )
    df = pd.DataFrame(rows, columns=['phi', 'psi', 'omega', 'tau', 'CA:C:1N', 'C:1N:1CA']).astype("float")[:int(length)]

    create_new_chain_nerf(name, df)

    if not knot:
        return False

    knots = alexander(name, tries=tries)
    if verbose:
        print('Knotting status: ', knots)

    # '0_1' is the unknot; anything else dominating means the chain is knotted.
    is_unknot = '0_1' in knots.keys() and knots['0_1'] > unknot_threshold
    return not is_unknot

In [14]:
def flip_structure(input_sample, target_label, iterations, lr, debug=False, device='cuda'):
    """Guide one structure towards ``target_label`` and return the result on CPU.

    Args:
        input_sample: structure tensor without a batch axis; a leading axis of
            size 1 is added before it is fed to the classifier.
        target_label: desired classifier output (e.g. 1.0 for knotted); ints are
            accepted and converted to float.
        iterations: number of guidance steps passed to ``guide_structure``.
        lr: guidance learning rate.
        debug: print the classifier prediction before and after guidance.
        device: device the classifier lives on (default 'cuda').

    Returns:
        The guided structure on CPU, with the added leading axis removed again.
    """
    # detach() makes the sample a fresh leaf we can optimize, whatever its history.
    sample = input_sample.detach().clone().unsqueeze(0).to(device)
    sample.requires_grad_(True)

    # BCELoss needs a floating-point target shaped like the (1, 1) classifier output.
    target = torch.tensor([[target_label]], dtype=torch.float32, device=device)

    # Only pay for the extra forward pass when it is actually printed.
    if debug:
        with torch.no_grad():
            orig_prediction = classification_model(sample).item()

    new_sample = guide_structure(sample, target, iterations, lr, verbose=False)

    if debug:
        with torch.no_grad():
            new_prediction = classification_model(new_sample).item()
        print(f"Prediction of original input = {orig_prediction}")
        print(f"Prediction of new input = {new_prediction}")

    return new_sample.detach().to('cpu')[0]

## Make the flipping loop

Combine the diffusion steps with guiding the structure towards a knot.

In [15]:
def flip_with_diffusion(timesteps, debug=False, target_label=1.0, flip_iterations=50, flip_lr=0.001, out_dir='tmp'):
    """Alternate diffusion sampling with classifier guidance towards ``target_label``.

    For each t from ``timesteps`` down to 1:
    - run ``my_sample`` from the current ``noise`` with ``t`` timesteps;
    - write the result to ``{out_dir}/diffused_{2*(timesteps - t)}.pdb``;
    - guide it with ``flip_structure`` and write the guided structure to the next
      (odd) index.

    The guided structure becomes the starting ``noise`` for the next t.

    Args:
        timesteps: first (largest) t of the loop.
        debug: verbose reconstruction output and a knot check on guided structures.
        target_label: classifier target for guidance (1.0 = knotted).
        flip_iterations: guidance steps per t.
        flip_lr: guidance learning rate.
        out_dir: directory for the PDB files; created if missing.
    """
    # Diffusion runs on the padded sequence; use the dataset's pad length (128 here).
    length = noised_dataset.pad  # TODO make possible different lengths
    os.makedirs(out_dir, exist_ok=True)

    # Sample noise
    noise = noised_dataset.sample_noise(
        torch.zeros((1, noised_dataset.pad, diffusion_model.n_inputs), dtype=torch.float32)
    )

    # do diffusion and then guide towards knot
    for t in range(timesteps, 0, -1):
        step_idx = 2 * (timesteps - t)
        print(f"-------Starting step {t}-------")
        print("--Running diffusion step--")
        s = my_sample(diffusion_model, noise, noised_dataset, length=length, timesteps=t)
        inner_to_structure(s.detach()[0], length, f'{out_dir}/diffused_{step_idx}.pdb', verbose=debug, knot=False)

        print("--Flipping structure--")
        # The CNN gets the (1, length, n_ft) sample (flip_structure adds the batch
        # axis); the dense model gets the single (length, n_ft) structure.
        flip_input = s if CNN_model else s[0]
        flipped = flip_structure(flip_input, target_label, flip_iterations, flip_lr, debug)  # TODO maybe different steps, lr?
        # Either way, the next diffusion call needs a (1, length, n_ft) tensor.
        noise = flipped if CNN_model else flipped.unsqueeze(0)
        inner_to_structure(noise.detach()[0], length, f'{out_dir}/diffused_{step_idx + 1}.pdb', verbose=debug, knot=debug)

In [None]:
# 250 is the default number of timesteps of the FoldingDiff model.
flip_with_diffusion(timesteps=250, debug=True)