# Flip knot <-> unknot with CNN model

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader

In [2]:
import pandas as pd
import numpy as np
import os

from topoly import alexander

from foldingdiff import datasets
from foldingdiff.angles_and_coords import create_new_chain_nerf
from foldingdiff import utils
from foldingdiff import modelling

from GPyOpt.methods import BayesianOptimization

In [3]:
torch.manual_seed(42)
np.random.seed(42)

In [4]:
BATCH_SIZE = 512
DEVICE = "cuda"
ITERATION_STEPS = 500
FLIPPING_LR = 0.001

## Make simple CNN using foldingdiff representation as input

In [5]:
class SimpleCNN(nn.Module):
    def __init__(self):  
        super(SimpleCNN, self).__init__()
        
        self.conv_layer1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3), padding=(2, 2))
        self.relu1 = nn.LeakyReLU()
        self.max_pool1 = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.dropout1 = nn.Dropout(0.1)
        
        self.conv_layer2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=(2, 2))
        self.relu2 = nn.LeakyReLU()
        self.max_pool2 = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.dropout2 = nn.Dropout(0.1)
        
        self.conv_layer3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=(2, 2))
        self.relu3 = nn.LeakyReLU()
        self.max_pool3 = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.dropout3 = nn.Dropout(0.1)
        
        # After convolution and pooling, calculate the number of features for the linear layer
        self._to_linear = None

        self.convs = nn.Sequential(
            self.conv_layer1,
            self.relu1,
            self.max_pool1,
            self.dropout1,
            self.conv_layer2,
            self.relu2,
            self.max_pool2,
            self.dropout2,
            self.conv_layer3,
            self.relu3,
            self.max_pool3,
            self.dropout3
        )
        x = torch.randn(512, 1, 128, 6)
        self._to_linear = self.convs(x).view(x.size(0), -1).shape[1]
                                                                                           
        self.fc1 = nn.Linear(self._to_linear, 128)
        
        self.fc2 = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_representation):
        
        # input is 128 x 6
        out = self.convs(input_representation)
                
        out = out.reshape(out.size(0), -1)
        
        out = self.fc1(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out

## Prepare dataset

In [6]:
class myDataset(datasets.CathCanonicalAnglesOnlyDataset):
    
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.labels = list(map(lambda x: 0 if x.split('/')[1][0] == 'u' else 1, self.filenames))

    def __getitem__(self, index):
        return_dict = super().__getitem__(index)
        return_dict['angles'] = return_dict['angles'].reshape(1, 128, 6)
        label = self.labels[index]
        return return_dict, label

In [7]:
dataset = myDataset(
            pdbs='mixed_dataset', # choose the dataset folder
            split=None,
            pad=128,
            min_length=40,
            trim_strategy="randomcrop",
            zero_center=True
        )

In [8]:
train_size = int(0.8 * len(dataset))  # 80% for training
valid_size = int(0.1 * len(dataset))  # 10% for validation
test_size = len(dataset) - train_size - valid_size # 10% for testing

print(f"Training set size: {train_size}")

train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, valid_size, test_size])


train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)
valid_data_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)
test_data_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)

Training set size: 3148


### Check the dataloader

In [9]:
train_features, train_labels = next(iter(train_data_loader))
print(f"Feature batch shape: {train_features['angles'].size()}")
print(f"Labels batch shape: {train_labels.size()}")

Feature batch shape: torch.Size([512, 1, 128, 6])
Labels batch shape: torch.Size([512])


## Training logic

In [10]:
# Training loop
def train_model(model, train_loader, optimizer, criterion):
    model.train()
    for batch in train_loader:
        inputs, labels = batch
        labels = labels.float().to(DEVICE)
        optimizer.zero_grad()
        
        outputs = model(inputs['angles'].to(DEVICE))
        loss = criterion(outputs.squeeze(), labels)
        loss.backward()
        optimizer.step()

In [11]:
# Evaluation loop
def evaluate_model(model, test_loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for batch in test_loader:
            inputs, labels = batch
            labels = labels.float().to(DEVICE)
            
            outputs = model(inputs['angles'].to(DEVICE))
            loss = criterion(outputs.squeeze(), labels)
            total_loss += loss.item()
            preds = (outputs.squeeze() > 0.5).float()
            correct += (preds == labels).sum().item()
    
    avg_loss = total_loss / len(test_loader)
    accuracy = correct / len(test_loader.dataset)
    return avg_loss, accuracy

## Include Bayesian optimization for hyperparameters

In [None]:
def evaluation_for_optimization(params):
    learning_rate = params[0][0]
    num_epochs = int(params[0][1])
    
    simple_model = SimpleCNN().to(DEVICE)
    
    adam = optim.AdamW(simple_model.parameters(), lr=learning_rate)
    criterion = nn.BCELoss()
    
    print(f"Learning rate: {learning_rate}, Epochs: {num_epochs}")
    
    for epoch in range(num_epochs):
        train_model(simple_model, train_data_loader, adam, criterion)
        if epoch % 10 == 0:
            avg_loss, accuracy = evaluate_model(simple_model, valid_data_loader, criterion)
            print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss}, Accuracy: {accuracy}')
    
    avg_loss, accuracy = evaluate_model(simple_model, valid_data_loader, criterion)

    print(f'Loss: {avg_loss}, Accuracy: {accuracy}\n')
    
    return 1 - accuracy

In [None]:
optimization_space = [
    {'name': 'lr', 'type': 'continuous', 'domain': (5e-6, 1e-2)},
    {'name': 'epochs', 'type': 'continuous', 'domain': (50, 200)}
]

In [None]:
optimizer = BayesianOptimization(
    f=evaluation_for_optimization, domain=optimization_space, model_type='GP',
    acquisition_type='EI', max_iter=10
)

optimizer.run_optimization(max_iter=10)

In [None]:
optimizer.plot_acquisition()

In [None]:
print("Best parameters:")
print(f"Learning rate: {optimizer.X[np.argmin(optimizer.Y), 0]}")
print(f"Number of epochs: {int(optimizer.X[np.argmin(optimizer.Y), 1])}")
print(f"Best objective value: {np.min(optimizer.Y)}")

### Get model with the optimal values

use one of the two following cells to choose parameters for model training

In [None]:
# best parameters obtained through bayesian optimization
learning_rate = optimizer.X[np.argmin(optimizer.Y), 0]
num_epochs = int(optimizer.X[np.argmin(optimizer.Y), 1])

In [12]:
# parameters chosen by hand
learning_rate = 0.006
num_epochs = 30

In [13]:
simple_model = SimpleCNN().to(DEVICE)
    
adam = optim.AdamW(simple_model.parameters(), lr=learning_rate)
criterion = nn.BCELoss()
    
print(f"Learning rate: {learning_rate}, Epochs: {num_epochs}")
    
for epoch in range(num_epochs):
    train_model(simple_model, train_data_loader, adam, criterion)
    avg_loss, accuracy = evaluate_model(simple_model, valid_data_loader, criterion)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss}, Accuracy: {accuracy}')
    
avg_loss, accuracy = evaluate_model(simple_model, valid_data_loader, criterion)

print(f'Loss: {avg_loss}, Accuracy: {accuracy}\n')

Learning rate: 0.006, Epochs: 30
Epoch 1/30, Loss: 0.6804741024971008, Accuracy: 0.6743002544529262
Epoch 2/30, Loss: 0.7103887796401978, Accuracy: 0.5979643765903307
Epoch 3/30, Loss: 0.6391249299049377, Accuracy: 0.6335877862595419
Epoch 4/30, Loss: 0.6105177402496338, Accuracy: 0.648854961832061
Epoch 5/30, Loss: 0.5436211824417114, Accuracy: 0.7073791348600509
Epoch 6/30, Loss: 0.5273293852806091, Accuracy: 0.72264631043257
Epoch 7/30, Loss: 0.5073405504226685, Accuracy: 0.7430025445292621
Epoch 8/30, Loss: 0.4868150055408478, Accuracy: 0.7455470737913485
Epoch 9/30, Loss: 0.4698195159435272, Accuracy: 0.7557251908396947
Epoch 10/30, Loss: 0.4699390232563019, Accuracy: 0.7557251908396947
Epoch 11/30, Loss: 0.5040570497512817, Accuracy: 0.7633587786259542
Epoch 12/30, Loss: 0.43657296895980835, Accuracy: 0.7837150127226463
Epoch 13/30, Loss: 0.4467410147190094, Accuracy: 0.7684478371501272
Epoch 14/30, Loss: 0.38616636395454407, Accuracy: 0.8346055979643766
Epoch 15/30, Loss: 0.3613

### Save model

In [None]:
torch.save(simple_model, 'CNN_model_mixed_dataset2.pt')

### Potentially load already trained model

In [None]:
simple_model = torch.load('CNN_model_mixed_dataset2.pt')

## Adversarial training - flip structures from the test set

In [14]:
simple_model.eval()  # Set the model to evaluation mode

SimpleCNN(
  (conv_layer1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
  (relu1): LeakyReLU(negative_slope=0.01)
  (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout1): Dropout(p=0.1, inplace=False)
  (conv_layer2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
  (relu2): LeakyReLU(negative_slope=0.01)
  (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout2): Dropout(p=0.1, inplace=False)
  (conv_layer3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
  (relu3): LeakyReLU(negative_slope=0.01)
  (max_pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout3): Dropout(p=0.1, inplace=False)
  (convs): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
    (1): LeakyReLU(negative_slope=0.01)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode

In [15]:
# Iterative process to modify the input
def guide_structure(input_structure, target_label, optimizer, loss_fn, steps, verbose=False):
    for iteration in range(steps):
        optimizer.zero_grad()

        # Forward pass
        output = simple_model(input_structure)

        # Compute the loss (maximize the prediction towards the target label)
        loss = loss_fn(output, target_label)

        # Backward pass and update the input structure
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            angular_idx = np.where(dataset.feature_is_angular['angles'])[0]
            # Copy the input structure to avoid in-place modification
            modified_structure = input_structure.clone()
            for s in modified_structure:
                s[..., angular_idx] = utils.modulo_with_wrapped_range(
                    s[..., angular_idx], range_min=-np.pi, range_max=np.pi
                )
            # Update the input structure with the modified one
            input_structure.copy_(modified_structure)
        
        # Print the progress
        if iteration % 100 == 0 and verbose:
            print(f"Iteration {iteration}: Loss = {loss.item()}, Prediction = {output.item()}")
            
    return input_structure

In [25]:
# convert structure from inner representation to pdb file and check knot status

def inner_to_structure(sample, length, index=0, verbose=False):
    retval = [s + dataset.get_masked_means() for s in sample]
    # Because shifting may have caused us to go across the circle boundary, re-wrap
    angular_idx = np.where(dataset.feature_is_angular['angles'])[0]
    for s in retval:
        s[..., angular_idx] = utils.modulo_with_wrapped_range(
                    s[..., angular_idx], range_min=-np.pi, range_max=np.pi
                )
    df = pd.DataFrame(retval, columns=['phi', 'psi', 'omega', 'tau', 'CA:C:1N', 'C:1N:1CA']).astype("float")[:int(length)]

    create_new_chain_nerf(f'tmp/flip_{index}.pdb', df)
    
    knots = alexander(f'tmp/flip_{index}.pdb', tries=70) # modified from 100
    
    if verbose:
        print('Knotting status: ', knots)
    
    if ('0_1' in knots.keys() and knots['0_1'] > 0.5):
        return False
    else:
        return True

In [21]:
# get some samples from test set
test_features, test_labels = next(iter(test_data_loader))

### Try to flip some structures from the dataset

In [32]:
def flip_structure(iterations, lr, num=100, debug=False):
    pos_flipped = 0
    neg_flipped = 0
    neg = 0
    pos = 0
    
    for i in range(0, num):

        length = test_features['lengths'][i]
        
        if debug:
            print(f"\n\n ---- Attempt {i} ----- \n")
            print(f"Original label: {test_labels[i]}")
            inner_to_structure(test_features['angles'][i][0], length, f"original_{i}")
        if test_labels[i] == 1:
            pos += 1
        else:
            neg += 1
        
        sample = torch.clone(test_features['angles'][i]).unsqueeze(0).to(DEVICE)
        orig_prediction = simple_model(sample)
        
        sample.requires_grad = True
        loss_fn = nn.BCELoss()
        optimizer = optim.Adam([sample], lr=lr)
        
        target = torch.tensor([[1.0 - test_labels[i]]], device=DEVICE)
            
        new_sample = guide_structure(sample, target, optimizer, loss_fn, iterations, verbose=debug)
        
        res = inner_to_structure(new_sample.detach().to('cpu')[0][0], length, i, verbose=debug)
        
        if (res and (test_labels[i] == 0)) or (not res and (test_labels[i] == 1)):
            if res:
                neg_flipped += 1
            else:
                pos_flipped += 1
            if debug:
                print(f"Prediction of original input = {orig_prediction}")
                print(f"Prediction of new input = {simple_model(new_sample)}")
                print(f"Target label: {target}")
    return pos_flipped, pos, neg_flipped, neg

In [34]:
pos_flipped, pos, neg_flipped, neg = flip_structure(ITERATION_STEPS, FLIPPING_LR, debug=True)



 ---- Attempt 0 ----- 

Original label: 0
Iteration 0: Loss = 0.8437115550041199, Prediction = 0.4301111698150635
Iteration 100: Loss = 0.18769656121730804, Prediction = 0.828866183757782
Iteration 200: Loss = 0.09230294078588486, Prediction = 0.911828875541687
Iteration 300: Loss = 0.06171214208006859, Prediction = 0.9401534795761108
Iteration 400: Loss = 0.035691507160663605, Prediction = 0.9649379253387451
Knotting status:  {'0_1': 0.9857142857142858}


 ---- Attempt 1 ----- 

Original label: 0
Iteration 0: Loss = 0.46265536546707153, Prediction = 0.6296095848083496
Iteration 100: Loss = 0.09969247132539749, Prediction = 0.9051157236099243
Iteration 200: Loss = 0.04886695742607117, Prediction = 0.9523078203201294
Iteration 300: Loss = 0.0387694425880909, Prediction = 0.9619724750518799
Iteration 400: Loss = 0.023268384858965874, Prediction = 0.9770002365112305
Knotting status:  {'0_1': 0.9714285714285714}


 ---- Attempt 2 ----- 

Original label: 0
Iteration 0: Loss = 4.6789970397

In [20]:
print(f"Flipped {pos_flipped} out of {pos} positives\nFlipped {neg_flipped} out of {neg} negatives")

Flipped 0 out of 0 positives
Flipped 0 out of 1 negatives
