# Create PDB frames from structure flipping steps

In [33]:
import os

import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
import torch.optim as optim

import numpy as np
import pandas as pd

from topoly import alexander

from foldingdiff import datasets
from foldingdiff.angles_and_coords import create_new_chain_nerf
from foldingdiff import utils
from foldingdiff import modelling

In [27]:
# Seed torch's and NumPy's global RNGs so later random operations (e.g. the
# random_split of the dataset, which uses torch's default generator) repeat.
torch.manual_seed(42)
np.random.seed(42)

BATCH_SIZE = 512
# Fall back to the CPU when no GPU is present instead of failing later on the
# first `.to(DEVICE)`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
class SimpleCNN(nn.Module):
    """Small 2D-CNN binary classifier over a (1, seq_len, n_features) input.

    Three stages of Conv2d(3x3, padding 2) -> LeakyReLU -> MaxPool2d(2) ->
    Dropout(0.1), then two linear layers and a sigmoid, giving one value in
    [0, 1] per sample (output shape ``(batch, 1)``).

    Args:
        seq_len: input height (rows per sample). Defaults to 128.
        n_features: input width (columns per sample). Defaults to 6.
    """

    def __init__(self, seq_len=128, n_features=6):
        super().__init__()

        # Submodule names and registration order are kept stable: they define
        # the state_dict keys of already-saved models.
        self.conv_layer1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3), padding=(2, 2))
        self.relu1 = nn.LeakyReLU()
        self.max_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout1 = nn.Dropout(0.1)

        self.conv_layer2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=(2, 2))
        self.relu2 = nn.LeakyReLU()
        self.max_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout2 = nn.Dropout(0.1)

        self.conv_layer3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=(2, 2))
        self.relu3 = nn.LeakyReLU()
        self.max_pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout3 = nn.Dropout(0.1)

        self.convs = nn.Sequential(
            self.conv_layer1,
            self.relu1,
            self.max_pool1,
            self.dropout1,
            self.conv_layer2,
            self.relu2,
            self.max_pool2,
            self.dropout2,
            self.conv_layer3,
            self.relu3,
            self.max_pool3,
            self.dropout3,
        )

        # Infer the flattened feature count after the conv stack. One zero
        # sample under no_grad is enough for the shape: no throw-away autograd
        # graph, no large dummy batch, and the global RNG is not consumed.
        with torch.no_grad():
            dummy = torch.zeros(1, 1, seq_len, n_features)
            self._to_linear = self.convs(dummy).reshape(1, -1).shape[1]

        self.fc1 = nn.Linear(self._to_linear, 128)

        self.fc2 = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_representation):
        """Return per-sample probabilities of shape ``(batch, 1)``.

        ``input_representation`` is expected as (batch, 1, seq_len, n_features).
        """
        out = self.convs(input_representation)

        out = out.reshape(out.size(0), -1)

        # NOTE(review): there is no activation between fc1 and fc2, so the two
        # layers compose to a single affine map. Left unchanged because models
        # pickled with this class run this exact forward().
        out = self.fc1(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out

## Prepare dataset

In [6]:
class myDataset(datasets.CathCanonicalAnglesOnlyDataset):
    """Angle dataset that also yields a binary class label per structure.

    The label is 0 when the second '/'-separated component of the filename
    starts with 'u', and 1 otherwise (presumably 'u' marks unknotted
    structures — confirm against the dataset layout).
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.labels = [
            0 if filename.split('/')[1][0] == 'u' else 1
            for filename in self.filenames
        ]

    def __getitem__(self, index):
        """Return ``(item_dict, label)`` with ``item_dict['angles']`` as (1, 128, 6)."""
        item = super().__getitem__(index)
        # Leading singleton axis = the single input channel of SimpleCNN's Conv2d.
        item['angles'] = item['angles'].reshape(1, 128, 6)
        return item, self.labels[index]

In [16]:
# Full (unsplit) dataset. pad=128 must agree with the (1, 128, 6) reshape done
# in myDataset.__getitem__.
dataset = myDataset(
    pdbs='mixed_dataset',
    split=None,
    min_length=40,
    pad=128,
    trim_strategy="randomcrop",
    zero_center=True,
)

In [17]:
# 80 / 10 / 10 train / valid / test split; the test split absorbs the rounding
# remainder, so the three sizes always sum to len(dataset).
n_samples = len(dataset)
train_size = int(0.8 * n_samples)
valid_size = int(0.1 * n_samples)
test_size = n_samples - (train_size + valid_size)

print(f"Test set size: {test_size}")

train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, valid_size, test_size]
)


def _make_loader(subset, shuffle):
    """Wrap ``subset`` in a DataLoader with the notebook's shared settings."""
    return DataLoader(subset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=1)


train_data_loader = _make_loader(train_dataset, shuffle=True)
valid_data_loader = _make_loader(valid_dataset, shuffle=False)
test_data_loader = _make_loader(test_dataset, shuffle=False)

Test set size: 395


### Load model

In [20]:
# NOTE: weights_only=False is required because the whole SimpleCNN module was
# pickled; unpickling can execute arbitrary code, so only load trusted files.
# map_location places the tensors on DEVICE, so a checkpoint saved on a GPU
# can still be loaded on a CPU-only machine.
simple_model = torch.load(
    '../models/CNN_model_mixed_dataset.pt', map_location=DEVICE, weights_only=False
)
# The classifier is only used as a fixed critic whose input is optimised:
# freeze its weights so backpropagation does not also compute (and keep
# accumulating) gradients for every parameter.
simple_model.requires_grad_(False)
simple_model.eval()

SimpleCNN(
  (conv_layer1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
  (relu1): LeakyReLU(negative_slope=0.01)
  (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout1): Dropout(p=0.1, inplace=False)
  (conv_layer2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
  (relu2): LeakyReLU(negative_slope=0.01)
  (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout2): Dropout(p=0.1, inplace=False)
  (conv_layer3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
  (relu3): LeakyReLU(negative_slope=0.01)
  (max_pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout3): Dropout(p=0.1, inplace=False)
  (convs): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
    (1): LeakyReLU(negative_slope=0.01)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode

## Flip the structure and save a PDB frame at every step

In [31]:
# Iterative process to modify the input
def guide_structure(input_structure, prefix, target_label, length, steps, lr, verbose=False):
    """Optimise an input angle tensor so that ``simple_model`` predicts ``target_label``.

    Runs ``steps`` Adam updates on ``input_structure`` itself (the model is not
    updated). Before every step the current structure is written as
    ``f"{prefix}{iteration}.pdb"`` via ``inner_to_structure``, and the final
    structure as ``f"{prefix}{steps}.pdb"``, i.e. ``steps + 1`` frames in total.

    Args:
        input_structure: leaf tensor with ``requires_grad=True``; modified in
            place. Frames are taken from ``input_structure[0][0]``, so a
            (1, 1, L, F) layout is assumed — TODO confirm for other callers.
        prefix: path prefix for the frame files.
        target_label: BCE target with the same shape as the model output.
        length: number of residues kept when writing each frame.
        steps: number of optimisation steps.
        lr: Adam learning rate.
        verbose: print loss and prediction every 100 steps.

    Returns:
        ``input_structure`` (the same tensor object, updated in place).
    """
    loss_fn = nn.BCELoss()
    optimizer = optim.Adam([input_structure], lr=lr)
    # Column indices of the angular features; loop-invariant, so compute once.
    angular_idx = np.where(dataset.feature_is_angular['angles'])[0]

    for iteration in range(steps):
        # Frame k is the structure *before* update k, so frame 0 is the input.
        inner_to_structure(input_structure.detach().to('cpu')[0][0], length, f"{prefix}{iteration}.pdb")

        # Forward pass
        output = simple_model(input_structure)

        # Compute the loss (maximize the prediction towards the target label)
        loss = loss_fn(output, target_label)

        # Differentiate w.r.t. the input only. Unlike loss.backward(), this does
        # not compute or accumulate gradients into the model's parameters, which
        # nothing ever zeroes. Assigning .grad replaces any previous gradient,
        # so no optimizer.zero_grad() is needed.
        (input_grad,) = torch.autograd.grad(loss, input_structure)
        input_structure.grad = input_grad
        optimizer.step()

        with torch.no_grad():
            # Re-wrap the angular features into the -pi..pi range after the
            # update: modify a copy, then write it back with one in-place copy_.
            modified_structure = input_structure.clone()
            for s in modified_structure:
                s[..., angular_idx] = utils.modulo_with_wrapped_range(
                    s[..., angular_idx], range_min=-np.pi, range_max=np.pi
                )
            input_structure.copy_(modified_structure)

        # Print the progress
        if verbose and iteration % 100 == 0:
            print(f"Iteration {iteration}: Loss = {loss.item()}, Prediction = {output.item()}")

    inner_to_structure(input_structure.detach().to('cpu')[0][0], length, f"{prefix}{steps}.pdb")
    return input_structure

In [22]:
def inner_to_structure(sample, length, name, verbose=False, knot=False):
    """Write an angle sample as a PDB file and optionally test it for a knot.

    The dataset's masked means are added back to every row (the dataset is
    built with ``zero_center=True``). The angular columns are re-wrapped into
    the -pi..pi range, and the first ``length`` rows are passed to
    ``create_new_chain_nerf`` together with ``name``.

    Args:
        sample: iterable of per-residue feature rows in the column order
            phi, psi, omega, tau, CA:C:1N, C:1N:1CA.
        length: number of rows to keep (``int()`` is applied, so a 0-d tensor
            works).
        name: output path handed to ``create_new_chain_nerf`` and, when
            ``knot`` is set, to ``alexander``.
        verbose: print the ``alexander`` result.
        knot: run the Alexander-polynomial knot check on the written file.

    Returns:
        With ``knot=True``: False when ``alexander`` reports the unknot
        ('0_1') with probability > 0.5, otherwise True. With ``knot=False``:
        always False.
    """
    # Loop-invariant: fetch the means once instead of once per row.
    means = dataset.get_masked_means()
    retval = [s + means for s in sample]
    # Because shifting may have caused us to go across the circle boundary, re-wrap
    angular_idx = np.where(dataset.feature_is_angular['angles'])[0]
    for s in retval:
        s[..., angular_idx] = utils.modulo_with_wrapped_range(
            s[..., angular_idx], range_min=-np.pi, range_max=np.pi
        )
    df = pd.DataFrame(
        retval, columns=['phi', 'psi', 'omega', 'tau', 'CA:C:1N', 'C:1N:1CA']
    ).astype("float")[:int(length)]

    create_new_chain_nerf(name, df)

    if not knot:
        return False

    knots = alexander(name, tries=70)  # TODO modified from 100
    if verbose:
        print('Knotting status: ', knots)

    # Only a dominant unknot ('0_1' > 0.5) counts as "not knotted".
    return not ('0_1' in knots and knots['0_1'] > 0.5)

In [24]:
def flip_structure(input_sample, length, orig_label, iterations, lr, debug=False,
                   frame_prefix="pdbs/timestep_", out_path='tmp/fliped.pdb'):
    """Push one sample towards the opposite class and save the trajectory as PDB files.

    The sample is optimised with ``guide_structure`` towards target
    ``1 - orig_label``. One frame per step is written under ``frame_prefix``.
    The final structure is written to ``out_path`` and checked for a knot.

    Args:
        input_sample: angles of a single sample (e.g. ``test_features['angles'][i]``,
            shaped (1, 128, 6) by ``myDataset``); a batch axis is added here.
        length: number of residues to keep when writing structures.
        orig_label: original class label, 0 or 1 (Python number or 0-d tensor).
        iterations: number of optimisation steps.
        lr: Adam learning rate.
        debug: print progress, the knot status, and a summary when the knot
            status disagrees with ``orig_label``.
        frame_prefix: path prefix for the per-step frames.
        out_path: path of the final structure file.

    Returns:
        The knot flag returned by ``inner_to_structure(..., knot=True)`` for the
        final structure.
    """
    sample = torch.clone(input_sample).unsqueeze(0).to(DEVICE)
    # Reference prediction only: no autograd graph is needed for it.
    with torch.no_grad():
        orig_prediction = simple_model(sample)

    # Optimise the input itself, not the model.
    sample.requires_grad = True

    # float() accepts both Python numbers and 0-d tensors such as test_labels[i].
    target = torch.tensor([[1.0 - float(orig_label)]], device=DEVICE)

    # Frames and the final structure are written to these paths; make sure the
    # target directories exist (dirname is empty for files in the cwd).
    for path in (frame_prefix, out_path):
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    new_sample = guide_structure(sample, frame_prefix, target, length, iterations, lr, verbose=debug)
    res = inner_to_structure(new_sample.detach().to('cpu')[0][0], length, out_path, verbose=debug, knot=True)

    # Report when the final knot flag disagrees with the original label
    # (presumably label 1 == knotted — confirm).
    if debug and ((res and (orig_label == 0)) or (not res and (orig_label == 1))):
        with torch.no_grad():
            new_prediction = simple_model(new_sample)
        print(f"Prediction of original input = {orig_prediction}")
        print(f"Original label: {orig_label}")
        print(f"Prediction of new input = {new_prediction}")
        print(f"Target label: {target}")

    return res

In [25]:
# First batch of the (unshuffled) test loader: a dict of features plus labels.
test_batch = next(iter(test_data_loader))
test_features, test_labels = test_batch

In [37]:
# Flip sample #2 of the first test batch: 200 Adam steps at lr=0.001, with debug output.
i = 2
flip_structure(
    test_features['angles'][i],
    test_features['lengths'][i],
    test_labels[i],
    iterations=200,
    lr=0.001,
    debug=True,
)

Iteration 0: Loss = 2.8862407207489014, Prediction = 0.055785536766052246
Iteration 100: Loss = 1.0499870777130127, Prediction = 0.34994229674339294
Knotting status:  {'4_1': 0.7428571428571429, '0_1': 0.21428571428571427}
Prediction of original input = tensor([[0.0558]], device='cuda:0', grad_fn=<SigmoidBackward0>)
Original label: 0
Prediction of new input = tensor([[0.6331]], device='cuda:0', grad_fn=<SigmoidBackward0>)
Target label: tensor([[1.]], device='cuda:0')


Movie created from the saved frames of test sample #2.