In [None]:
import MDAnalysis as mda
from MDAnalysis.analysis.distances import distance_array
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import seaborn as sns
import pandas as pd

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Both trajectories share one topology file; each Universe couples it with its own XTC.
uA = mda.Universe("data/protein.pdb", "data/5us_protein_only_center.xtc")
uB = mda.Universe("data/protein.pdb", "data/only_prot_center.xtc")

# Report system size and trajectory length for each run. The later np.stack of
# per-frame feature vectors requires both runs to expose the same number of atoms.
for xtc_name, universe in (("5us_protein_only_center.xtc", uA), ("only_prot_center.xtc", uB)):
    print(f"Number of atoms in {xtc_name}:", len(universe.atoms))
    print(f"Number of frames in {xtc_name}:", len(universe.trajectory))

In [None]:
def extract_distances_with_metadata(universe, trajectory_id, num_frames=100):
    """Featurize the leading frames of a trajectory as flattened pairwise distances.

    For each of the first ``num_frames`` frames of ``universe.trajectory`` the
    full atom-atom distance matrix is computed with MDAnalysis'
    ``distance_array``. Only its strict upper triangle (``i < j``, row-major
    order) is kept, so every frame becomes a 1-D vector of length
    ``n_atoms * (n_atoms - 1) // 2``.

    Parameters
    ----------
    universe : MDAnalysis.Universe
        Universe whose trajectory is iterated.
    trajectory_id : int
        Identifier written into every metadata record, both as
        ``'trajectory_id'`` and as ``'label'``.
    num_frames : int or None, default 100
        Number of leading frames to process (used as ``trajectory[:num_frames]``).
        ``None`` processes every frame.

    Returns
    -------
    distances : list of numpy.ndarray
        One 1-D array of upper-triangle distances per processed frame.
    metadata : list of dict
        One record per frame with keys ``'trajectory_id'``, ``'original_frame'``
        (``ts.frame``), ``'sequence_index'`` (0-based position within this call)
        and ``'label'``.
    """
    distances = []
    metadata = []

    # Neither the selection nor the upper-triangle index arrays change between
    # frames, so build them once instead of once per frame. A static AtomGroup's
    # ``positions`` are re-read from the trajectory's current timestep on access.
    atoms = universe.select_atoms("all")
    triu_indices = np.triu_indices(len(atoms), k=1)

    for i, ts in enumerate(universe.trajectory[:num_frames]):
        pos = atoms.positions
        dist = distance_array(pos, pos)

        # Indexing with two 1-D index arrays already yields a new 1-D array,
        # so no extra flatten/copy is needed.
        distances.append(dist[triu_indices])
        metadata.append({
            'trajectory_id': trajectory_id,
            'original_frame': ts.frame,
            'sequence_index': i,
            'label': trajectory_id,
        })

    return distances, metadata

In [None]:
# Featurize the same number of leading frames from each run; the trajectory id
# (0 for A, 1 for B) is stored in the metadata and used as the class label.
n_frames_per_traj = 500
distances_A, metadata_A = extract_distances_with_metadata(uA, 0, n_frames_per_traj)
distances_B, metadata_B = extract_distances_with_metadata(uB, 1, n_frames_per_traj)

# Trajectory A's frames come first, then trajectory B's, in both lists.
all_distances = [*distances_A, *distances_B]
all_metadata = [*metadata_A, *metadata_B]

# One row per frame, one column per atom pair.
data_tensor = torch.tensor(np.stack(all_distances), dtype=torch.float32)
labels = torch.tensor([record['label'] for record in all_metadata], dtype=torch.long)

print(f"Data tensor shape: {data_tensor.shape}")
print(f"Labels shape: {labels.shape}")

In [None]:
class ProteinDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(features, label, index)`` triples.

    The integer index is returned alongside each sample so a shuffled batch can
    be traced back to ``metadata[index]``.

    Parameters
    ----------
    data : indexable
        Per-sample feature vectors; its ``len`` defines the dataset size.
        In this notebook it is the stacked distance tensor.
    labels : indexable
        Per-sample labels, indexed with the same positions as ``data``.
    metadata : list
        Per-sample records; stored on the instance but not returned by
        ``__getitem__``.
    """

    def __init__(self, data, labels, metadata):
        self.data, self.labels, self.metadata = data, labels, metadata

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        features = self.data[idx]
        label = self.labels[idx]
        return features, label, idx

In [None]:
dataset = ProteinDataset(data_tensor, labels, all_metadata)
# A dedicated, seeded generator makes the shuffle order reproducible across
# kernel restarts without depending on the global torch RNG state.
train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

input_dim = data_tensor.shape[1]
print(f"Input dimension: {input_dim}")

# Parameter count of ProteinAutoencoder (input -> hidden -> latent -> hidden ->
# input, every Linear layer with a bias) for the widths this notebook uses:
# hidden width 256 and the encoding_dim=2 passed when the model is built.
# Keep ae_encoding_dim in sync with that call.
ae_hidden_dim = 256
ae_encoding_dim = 2
expected_param_count = (
    (input_dim + 1) * ae_hidden_dim            # encoder Linear(input_dim, hidden)
    + (ae_hidden_dim + 1) * ae_encoding_dim    # encoder Linear(hidden, encoding)
    + (ae_encoding_dim + 1) * ae_hidden_dim    # decoder Linear(encoding, hidden)
    + (ae_hidden_dim + 1) * input_dim          # decoder Linear(hidden, input_dim)
)
print(f"Expected parameter count (encoding_dim={ae_encoding_dim}): {expected_param_count:,}")

class ProteinAutoencoder(nn.Module):
    """Fully connected autoencoder: input -> hidden -> latent -> hidden -> input.

    Parameters
    ----------
    input_dim : int
        Length of each input feature vector (last dimension of ``x``).
    encoding_dim : int, default 32
        Width of the latent bottleneck.
    hidden_dim : int, default 256
        Width of the single hidden layer used by both encoder and decoder.
    """

    def __init__(self, input_dim, encoding_dim=32, hidden_dim=256):
        super().__init__()

        # Encoder: ReLU hidden layer, then a linear projection to the latent
        # space (no activation, so latent coordinates are unconstrained).
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, encoding_dim),
        )

        # Decoder mirrors the encoder; the final layer is linear so the
        # reconstruction is not range-limited by an output activation.
        self.decoder = nn.Sequential(
            nn.Linear(encoding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x):
        """Encode then decode ``x``.

        Returns
        -------
        tuple
            ``(encoded, decoded)``: the latent codes and the reconstruction.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

In [None]:
SEED = 0
# Seed torch's global RNG before building the model so weight initialisation
# (and any DataLoader shuffling that draws from the global RNG) is reproducible.
torch.manual_seed(SEED)

device = torch.device('cpu')
model = ProteinAutoencoder(input_dim, encoding_dim=2).to(device)
# Note: Adam's weight_decay is a coupled L2 penalty added to the gradient
# (torch.optim.AdamW implements decoupled decay instead).
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
criterion = nn.MSELoss()

print(f"Using device: {device}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
NUM_EPOCHS = 50

train_losses = []
encoded_representations = []
batch_indices = []

model.train()
for epoch in range(NUM_EPOCHS):
    # Derived from NUM_EPOCHS so changing the epoch count cannot silently skip
    # collecting the final-epoch latent codes.
    is_last_epoch = epoch == NUM_EPOCHS - 1
    epoch_loss = 0.0
    epoch_samples = 0
    epoch_encoded = []
    epoch_indices = []

    # Running float64 sums for the epoch-level R², so no full copy of the
    # dataset's reconstructions/targets has to be held in memory.
    # R² = 1 - SS_res / SS_tot over every element seen this epoch. The
    # reconstructions come from weights that change during the epoch, so this
    # is a running training-set estimate.
    ss_res = 0.0
    target_sum = 0.0
    target_sq_sum = 0.0
    target_count = 0

    for batch_data, _batch_labels, batch_idx in train_loader:
        batch_data = batch_data.to(device)

        # Forward pass
        encoded, reconstructed = model(batch_data)
        loss = criterion(reconstructed, batch_data)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # MSELoss averages within the batch; weight by batch size so a short
        # final batch does not skew the epoch mean.
        n_in_batch = batch_data.size(0)
        epoch_loss += loss.item() * n_in_batch
        epoch_samples += n_in_batch

        with torch.no_grad():
            targets = batch_data.double()
            ss_res += ((targets - reconstructed.double()) ** 2).sum().item()
            target_sum += targets.sum().item()
            target_sq_sum += (targets ** 2).sum().item()
            target_count += targets.numel()

        # Store encoded representations and indices for visualization
        if is_last_epoch:
            epoch_encoded.append(encoded.detach().cpu())
            epoch_indices.extend(batch_idx.numpy())

    # SS_tot = sum((y - mean)^2) = sum(y^2) - (sum y)^2 / N
    ss_tot = target_sq_sum - target_sum ** 2 / target_count if target_count else 0.0
    # `> 0` rather than `!= 0`: a constant target can leave a tiny negative
    # residue from floating-point cancellation.
    r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0

    avg_loss = epoch_loss / epoch_samples if epoch_samples else float("nan")
    train_losses.append(avg_loss)

    if is_last_epoch and epoch_encoded:
        encoded_representations = torch.cat(epoch_encoded, dim=0)
        batch_indices = epoch_indices

    print(f"Epoch {epoch+1} \t  Average Loss: {avg_loss:.6f} \t  R²: {r2:.6f}")

In [None]:
# Embed every frame with the trained encoder.
latent = model.encoder
latent.to(device)

data = data_tensor.to(device)
latent.eval()
# Pure inference: no_grad avoids building an autograd graph over the whole dataset.
with torch.no_grad():
    data_1 = latent(data)
data_1.shape

In [None]:
# Latent codes as a numpy array (one row per frame, two latent coordinates),
# plus the per-frame trajectory label in the same row order.
latent_2d = data_1.detach().cpu().numpy()
trajectory_labels = np.array([record['label'] for record in all_metadata])

trajectory_0_mask = trajectory_labels == 0
trajectory_1_mask = trajectory_labels == 1

# --- Figure 1: quick two-colour overview -----------------------------------
fig, ax = plt.subplots(figsize=(10, 8))

overview_styles = [
    (trajectory_0_mask, 'blue', 'Trajectory A (5us_protein_only_center.xtc)'),
    (trajectory_1_mask, 'red', 'Trajectory B (only_prot_center.xtc)'),
]
for mask, colour, legend_label in overview_styles:
    ax.scatter(latent_2d[mask, 0], latent_2d[mask, 1],
               c=colour, alpha=0.7, label=legend_label, s=60)

ax.set_xlabel('Latent Dimension 1', fontsize=12)
ax.set_ylabel('Latent Dimension 2', fontsize=12)
ax.set_title('Protein Conformational States in 2D Latent Space', fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# Point counts per trajectory
print(f"Trajectory A points: {np.sum(trajectory_0_mask)}")
print(f"Trajectory B points: {np.sum(trajectory_1_mask)}")
print(f"Total points: {len(latent_2d)}")

fig.tight_layout()
plt.show()

# --- Figure 2: larger, edge-outlined markers --------------------------------
fig, ax = plt.subplots(figsize=(12, 10))

colors = ['#1f77b4', '#ff7f0e']  # matplotlib default blue and orange
trajectory_names = ['5us_protein_only_center.xtc', 'only_prot_center.xtc']

for traj_id in (0, 1):
    mask = trajectory_labels == traj_id
    ax.scatter(latent_2d[mask, 0], latent_2d[mask, 1],
               c=colors[traj_id],
               alpha=0.7,
               label=f'Trajectory {traj_id}: {trajectory_names[traj_id]}',
               s=80,
               edgecolors='black',
               linewidth=0.5)

ax.set_xlabel('Latent Dimension 1', fontsize=14)
ax.set_ylabel('Latent Dimension 2', fontsize=14)
ax.set_title('Protein Conformational Clustering in Autoencoder Latent Space', fontsize=16)
ax.legend(fontsize=12, loc='best')
ax.grid(True, alpha=0.3)

# Optional: label every 5th point with its source frame number (can be cluttered).
# for k in range(0, len(all_metadata), 5):
#     ax.annotate(f"F{all_metadata[k]['original_frame']}",
#                 (latent_2d[k, 0], latent_2d[k, 1]), fontsize=8, alpha=0.6)

fig.tight_layout()
plt.show()

# --- Tabular view of the latent codes --------------------------------------
df_results = pd.DataFrame({
    'latent_dim_1': latent_2d[:, 0],
    'latent_dim_2': latent_2d[:, 1],
    'trajectory_id': trajectory_labels,
    'trajectory_name': [trajectory_names[label] for label in trajectory_labels],
    'original_frame': [record['original_frame'] for record in all_metadata],
    'sequence_index': [record['sequence_index'] for record in all_metadata],
})

print("\nDataFrame summary:")
print(df_results.head())
print("\nTrajectory distribution:")
print(df_results['trajectory_name'].value_counts())

# --- Figure 3: seaborn version with an external legend ---------------------
fig, ax = plt.subplots(figsize=(10, 8))
sns.scatterplot(data=df_results,
                x='latent_dim_1',
                y='latent_dim_2',
                hue='trajectory_name',
                alpha=0.7,
                s=80,
                ax=ax)
ax.set_title('Protein Trajectory Separation in Latent Space')
ax.set_xlabel('Latent Dimension 1')
ax.set_ylabel('Latent Dimension 2')
ax.legend(title='Trajectory', bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()
plt.show()