# Flip knot <-> unknot

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader
from huggingface_hub import snapshot_download

import pandas as pd
import numpy as np
import os
from pathlib import Path

from topoly import alexander

from foldingdiff import datasets
from foldingdiff.angles_and_coords import create_new_chain_nerf
from foldingdiff import utils
from foldingdiff import modelling

from GPyOpt.methods import BayesianOptimization

In [2]:
torch.manual_seed(42)

<torch._C.Generator at 0x7f2bf819f290>

In [35]:
MODEL = "wukevin/foldingdiff_cath"
BATCH_SIZE = 512
DEVICE = torch.device("cuda:0")
LR = 5e-4
EPOCHS = 40
ITERATION_STEPS = 500
FLIPPING_LR = 0.001

## Prepare dataset

In [4]:
class myDataset(datasets.CathCanonicalAnglesOnlyDataset):
    
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.labels = list(map(lambda x: 0 if x.split('/')[1][0] == 'u' else 1, self.filenames))

    def __getitem__(self, index):
        return_dict = super().__getitem__(index)
        label = self.labels[index]
        return return_dict, label

In [5]:
dataset = myDataset(
            pdbs='mixed_dataset',
            split=None,
            pad=128,
            min_length=40,
            trim_strategy="randomcrop",
            zero_center=True
        )

In [6]:
train_size = int(0.8 * len(dataset))  # 80% for training
valid_size = int(0.1 * len(dataset))  # 10% for validation
test_size = len(dataset) - train_size - valid_size # 10% for testing

train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, valid_size, test_size])


train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=32)
valid_data_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=32)
test_data_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=32)

### Check the dataloader

In [7]:
train_features, train_labels = next(iter(train_data_loader))
print(f"Feature batch shape: {train_features['angles'].size()}")
print(f"Labels batch shape: {train_labels.size()}")

Feature batch shape: torch.Size([512, 128, 6])
Labels batch shape: torch.Size([512])


## Make simple classifier using foldingdiff representation as input

In [8]:
class SimpleClassifier(nn.Module):
    def __init__(self):  
        super(SimpleClassifier, self).__init__()
        
        self.pooling = nn.AdaptiveAvgPool1d(64)
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(384, 384),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(384, 1),
            nn.Sigmoid()
        )

    def forward(self, input_representation):
        
        # Pool the input
        pooled_output = self.pooling(input_representation.transpose(1, 2))
        
        # Flatten the pooled output
        flattened_output = pooled_output.reshape(pooled_output.size(0), -1)
        
        # Classify
        logits = self.classifier(flattened_output)
        
        return logits

In [9]:
# Training loop
def train_model(model, train_loader, optimizer, criterion):
    model.train()
    for batch in train_loader:
        inputs, labels = batch
        labels = labels.float().to(DEVICE)
        optimizer.zero_grad()
        
        outputs = model(inputs['angles'].to(DEVICE))
        loss = criterion(outputs.squeeze(), labels)
        loss.backward()
        optimizer.step()

In [10]:
# Evaluation loop
def evaluate_model(model, test_loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for batch in test_loader:
            inputs, labels = batch
            labels = labels.float().to(DEVICE)
            
            outputs = model(inputs['angles'].to(DEVICE))
            loss = criterion(outputs.squeeze(), labels)
            total_loss += loss.item()
            preds = (outputs.squeeze() > 0.5).float()
            correct += (preds == labels).sum().item()
    
    avg_loss = total_loss / len(test_loader)
    accuracy = correct / len(test_loader.dataset)
    return avg_loss, accuracy

## Include Bayesian optimization for hyperparameters

In [11]:
def evaluation_for_optimization(params):
    learning_rate = params[0][0]
    num_epochs = int(params[0][1])
    
    simple_model = SimpleClassifier().to(DEVICE)
    
    adam = optim.AdamW(simple_model.parameters(), lr=learning_rate)
    criterion = nn.BCELoss()
    
    print(f"Learning rate: {learning_rate}, Epochs: {num_epochs}")
    
    for epoch in range(num_epochs):
        train_model(simple_model, train_data_loader, adam, criterion)
    
    avg_loss, accuracy = evaluate_model(simple_model, valid_data_loader, criterion)

    print(f'Loss: {avg_loss}, Accuracy: {accuracy}\n')
    
    return 1 - accuracy

In [12]:
optimization_space = [
    {'name': 'lr', 'type': 'continuous', 'domain': (5e-6, 5e-2)},
    {'name': 'epochs', 'type': 'continuous', 'domain': (50, 300)}
]

In [13]:
optimizer = BayesianOptimization(
    f=evaluation_for_optimization, domain=optimization_space, model_type='GP',
    acquisition_type='EI', max_iter=10
)
optimizer.run_optimization(max_iter=10)

Learning rate: 0.03802280052487008, Epochs: 110
Loss: 2.310053586959839, Accuracy: 0.8058252427184466

Learning rate: 0.02443337646252407, Epochs: 112
Loss: 1.4338037967681885, Accuracy: 0.8252427184466019

Learning rate: 0.03145738976359821, Epochs: 221
Loss: 1.3800090551376343, Accuracy: 0.8349514563106796

Learning rate: 0.009474440598891633, Epochs: 251
Loss: 1.1808702945709229, Accuracy: 0.8543689320388349

Learning rate: 0.02039576845329432, Epochs: 91
Loss: 1.4331005811691284, Accuracy: 0.8252427184466019

Learning rate: 0.03368800597087874, Epochs: 250
Loss: 1.3053553104400635, Accuracy: 0.8252427184466019

Learning rate: 0.008944459807988619, Epochs: 251
Loss: 1.1283518075942993, Accuracy: 0.8446601941747572

Learning rate: 0.0203827049580371, Epochs: 250
Loss: 1.2729368209838867, Accuracy: 0.8349514563106796

Learning rate: 5e-06, Epochs: 251
Loss: 0.6698043942451477, Accuracy: 0.6116504854368932

Learning rate: 0.04487936979552982, Epochs: 251
Loss: 1.365583896636963, Accura

In [14]:
print("Best parameters:")
print(f"Learning rate: {optimizer.X[np.argmin(optimizer.Y), 0]}")
print(f"Number of epochs: {int(optimizer.X[np.argmin(optimizer.Y), 1])}")
print(f"Best objective value: {np.min(optimizer.Y)}")

Best parameters:
Learning rate: 0.009474440598891633
Number of epochs: 251
Best objective value: 0.1456310679611651


### Get model with the optimal values

In [16]:
learning_rate = optimizer.X[np.argmin(optimizer.Y), 0]
num_epochs = int(optimizer.X[np.argmin(optimizer.Y), 1])

simple_model = SimpleClassifier().to(DEVICE)
    
adam = optim.AdamW(simple_model.parameters(), lr=learning_rate)
criterion = nn.BCELoss()
    
print(f"Learning rate: {learning_rate}, Epochs: {num_epochs}")
    
for epoch in range(num_epochs):
    train_model(simple_model, train_data_loader, adam, criterion)
    
avg_loss, accuracy = evaluate_model(simple_model, valid_data_loader, criterion)

print(f'Loss: {avg_loss}, Accuracy: {accuracy}\n')

Learning rate: 0.009474440598891633, Epochs: 251
Loss: 1.218287706375122, Accuracy: 0.8543689320388349



## Adversarial training - flip structures from the test set

In [17]:
simple_model.eval()  # Set the model to evaluation mode

SimpleClassifier(
  (pooling): AdaptiveAvgPool1d(output_size=64)
  (classifier): Sequential(
    (0): Linear(in_features=384, out_features=384, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=384, out_features=1, bias=True)
    (4): Sigmoid()
  )
)

In [36]:
# Iterative process to modify the input
def guide_structure(input_structure, target_label, verbose=False):
    for iteration in range(ITERATION_STEPS):
        optimizer.zero_grad()

        # Forward pass
        output = simple_model(input_structure)

        # Compute the loss (maximize the prediction towards the target label)
        loss = loss_fn(output, target_label)

        # Backward pass and update the input structure
        loss.backward()
        optimizer.step()

        # TODO: Apply constraints to keep the structure valid?
        with torch.no_grad():
            angular_idx = np.where(dataset.feature_is_angular['angles'])[0]
            # Copy the input structure to avoid in-place modification
            modified_structure = input_structure.clone()
            for s in modified_structure:
                s[..., angular_idx] = utils.modulo_with_wrapped_range(
                    s[..., angular_idx], range_min=-np.pi, range_max=np.pi
                )
            # Update the input structure with the modified one
            input_structure.copy_(modified_structure)
        
        # Print the progress
        if iteration % 10 == 0 and verbose:
            print(f"Iteration {iteration}: Loss = {loss.item()}, Prediction = {output.item()}")
            
    return input_structure

In [19]:
def inner_to_structure(sample, length, index=0):
    retval = [s + dataset.get_masked_means() for s in sample]
    # Because shifting may have caused us to go across the circle boundary, re-wrap
    angular_idx = np.where(dataset.feature_is_angular['angles'])[0]
    for s in retval:
        s[..., angular_idx] = utils.modulo_with_wrapped_range(
                    s[..., angular_idx], range_min=-np.pi, range_max=np.pi
                )
    df = pd.DataFrame(retval, columns=['phi', 'psi', 'omega', 'tau', 'CA:C:1N', 'C:1N:1CA']).astype("float")[:int(length)]

    create_new_chain_nerf(f'tmp/flip_{index}.pdb', df)
    
    knots = alexander(f'tmp/flip_{index}.pdb', tries=100)
    
    print('Knotting status: ', knots)
    
    if ('0_1' in knots.keys() and knots['0_1'] > 0.5):
        return False
    else:
        return True

In [20]:
test_features, test_labels = next(iter(test_data_loader))

In [37]:
for i in range(0, 3):
    print(f"\n\n ---- Attempt {i} ----- \n")
    
    length = test_features['lengths'][i]
    sample = torch.clone(test_features['angles'][i]).unsqueeze(0).to(DEVICE)
    orig_prediction = simple_model(sample)
    
    inner_to_structure(test_features['angles'][i], length, f"original_{i}" )
    
    sample.requires_grad = True
    loss_fn = nn.BCELoss()
    optimizer = optim.Adam([sample], lr=FLIPPING_LR)
    
    target = torch.tensor([[1.0 - test_labels[i]]], device=DEVICE)
        
    new_sample = guide_structure(sample, target, verbose=True)
    
    res = inner_to_structure(new_sample.detach().to('cpu')[0], length, i)
    
    if (res and (test_labels[i] == 0)) or (not res and (test_labels[i] == 1)):
        if res:
            print("\n----ATTENTION----\n")
        print(f"Prediction of original input = {orig_prediction}")
        print(f"Original label: {test_labels[i]}")
        print(f"Prediction of new input = {simple_model(new_sample)}")
        print(f"Target label: {target}")



 ---- Attempt 0 ----- 

Knotting status:  {'0_1': 0.99}
Iteration 0: Loss = 2.6374919414520264, Prediction = 0.0715404748916626
Iteration 10: Loss = 0.12626273930072784, Prediction = 0.8813832402229309
Iteration 20: Loss = 0.01303718239068985, Prediction = 0.9870474338531494
Iteration 30: Loss = 0.004926578141748905, Prediction = 0.9950855374336243
Iteration 40: Loss = 0.003274909919127822, Prediction = 0.9967304468154907
Iteration 50: Loss = 0.002712288871407509, Prediction = 0.9972913861274719
Iteration 60: Loss = 0.0024566403590142727, Prediction = 0.997546374797821
Iteration 70: Loss = 0.002305122558027506, Prediction = 0.9976975321769714
Iteration 80: Loss = 0.002192096784710884, Prediction = 0.9978103041648865
Iteration 90: Loss = 0.0020945535507053137, Prediction = 0.9979076385498047
Iteration 100: Loss = 0.0020046643912792206, Prediction = 0.9979973435401917
Iteration 110: Loss = 0.001919322181493044, Prediction = 0.9980825185775757
Iteration 120: Loss = 0.0018379281973466277