In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math
import random
import re # Import regex for parsing PDB REMARK line

# Define a vocabulary for atom types and coordinate bins
ATOM_VOCAB = {
    'PAD': 0, 'START': 1, 'END': 2,
    'C': 3, 'O': 4, 'N': 5, 'H': 6,
    'S': 7, 'F': 8, 'Cl': 9, 'Br': 10, 'I': 11 # Add more common atoms as needed
}

REV_ATOM_VOCAB = {v: k for k, v in ATOM_VOCAB.items()}

COORD_MIN = -25.0
COORD_MAX = 25.0
NUM_COORD_BINS = 50 # Increased bins for better coordinate resolution
COORD_BIN_SIZE = (COORD_MAX - COORD_MIN) / NUM_COORD_BINS
COORD_TOKEN_OFFSET = len(ATOM_VOCAB)
TOTAL_VOCAB_SIZE = COORD_TOKEN_OFFSET + (NUM_COORD_BINS * 3) # Max bin index for coords

MAX_ATOMS_IN_DATASET = 200 # This should be determined by the largest molecule in your ~9000 PDBs.

def discretize_coord(coord_val):
    """Discretizes a single coordinate value into a bin index."""
    bin_idx = int((coord_val - COORD_MIN) / COORD_BIN_SIZE)
    return max(0, min(NUM_COORD_BINS - 1, bin_idx))

def parse_pdb_content_and_generate_dataset(pdb_content, num_molecules_to_generate=1000, coord_noise_std=0.1, pol_noise_std=0.5):
    base_atoms_data = []
    base_polarizability = None
    base_connectivity = set() # Store as a set of tuples (atom_idx1, atom_idx2) for uniqueness

    # Parse polarizability from REMARK line
    polarizability_match = re.search(r"REMARK static_polarizability ([\d.]+)", pdb_content)
    if polarizability_match:
        base_polarizability = float(polarizability_match.group(1))
    else:
        print("Warning: 'static_polarizability' not found in PDB REMARK. Using a default value of 0.")
        base_polarizability = 0 # Default if not found

    # Parse HETATM lines to get atom types and coordinates
    atom_pdb_id_to_idx = {} # Map PDB atom ID (1-indexed) to 0-indexed list index
    current_atom_idx = 0
    for line in pdb_content.splitlines():
        if line.startswith("HETATM"):
            try:
                pdb_atom_id = int(line[6:11].strip()) # PDB atom serial number
                atom_type = line[76:78].strip() # Element symbol
                x = float(line[30:38])
                y = float(line[38:46])
                z = float(line[46:54])

                if atom_type not in ATOM_VOCAB:
                    # Add new atom types to vocabulary if encountered
                    print(f"Adding new atom type '{atom_type}' to vocabulary.")
                    ATOM_VOCAB[atom_type] = len(ATOM_VOCAB)
                    REV_ATOM_VOCAB[ATOM_VOCAB[atom_type]] = atom_type
                    # Recompute COORD_TOKEN_OFFSET and TOTAL_VOCAB_SIZE
                    global COORD_TOKEN_OFFSET, TOTAL_VOCAB_SIZE
                    COORD_TOKEN_OFFSET = len(ATOM_VOCAB)
                    TOTAL_VOCAB_SIZE = COORD_TOKEN_OFFSET + (NUM_COORD_BINS * 3)

                base_atoms_data.append({
                    'type': atom_type,
                    'coords': (x, y, z)
                })
                atom_pdb_id_to_idx[pdb_atom_id] = current_atom_idx
                current_atom_idx += 1
            except ValueError as e:
                print(f"Error parsing HETATM line: {line} - {e}")
                continue # Skip malformed lines

    # Parse CONECT lines to get connectivity
    for line in pdb_content.splitlines():
        if line.startswith("CONECT"):
            try:
                # PDB CONECT format: CONECT atom1 atom2 atom3 ...
                # All IDs are 1-indexed
                connected_ids = [int(line[i:i+5].strip()) for i in range(6, len(line), 5) if line[i:i+5].strip()]
                if not connected_ids:
                    continue

                atom1_pdb_id = connected_ids[0]
                if atom1_pdb_id not in atom_pdb_id_to_idx:
                    print(f"Warning: Atom ID {atom1_pdb_id} in CONECT not found in HETATM records. Skipping.")
                    continue
                atom1_idx = atom_pdb_id_to_idx[atom1_pdb_id]

                for atom2_pdb_id in connected_ids[1:]:
                    if atom2_pdb_id not in atom_pdb_id_to_idx:
                        print(f"Warning: Atom ID {atom2_pdb_id} in CONECT not found in HETATM records. Skipping.")
                        continue
                    atom2_idx = atom_pdb_id_to_idx[atom2_pdb_id]
                    # Ensure consistent order for bond pairs (smaller index first)
                    bond = tuple(sorted((atom1_idx, atom2_idx)))
                    if bond[0] != bond[1]: # Avoid self-loops
                        base_connectivity.add(bond)
            except ValueError as e:
                print(f"Error parsing CONECT line: {line} - {e}")
                continue

    if not base_atoms_data:
        raise ValueError("No HETATM records found in the provided PDB content.")

    num_base_atoms = len(base_atoms_data)
    # Ensure MAX_ATOMS_IN_DATASET is large enough
    if num_base_atoms > MAX_ATOMS_IN_DATASET:
        raise ValueError(f"Base molecule has {num_base_atoms} atoms, but MAX_ATOMS_IN_DATASET is {MAX_ATOMS_IN_DATASET}. Please increase MAX_ATOMS_IN_DATASET.")

    # Generate augmented dataset
    data = []
    max_seq_len_overall = 0

    for i in range(num_molecules_to_generate):
        mol_tokens = [ATOM_VOCAB['START']]
        # Add noise to polarizability
        current_polarizability = base_polarizability + random.gauss(0, pol_noise_std)
        current_polarizability = max(1.0, current_polarizability) # Ensure positive

        # Store noisy coordinates to reconstruct connectivity matrix for this augmented molecule
        noisy_coords_for_connectivity = []

        for atom_info in base_atoms_data:
            atom_type = atom_info['type']
            x, y, z = atom_info['coords']

            # Add Gaussian noise to coordinates
            x_noisy = x + random.gauss(0, coord_noise_std)
            y_noisy = y + random.gauss(0, coord_noise_std)
            z_noisy = z + random.gauss(0, coord_noise_std)

            mol_tokens.append(ATOM_VOCAB[atom_type])
            mol_tokens.append(discretize_coord(x_noisy) + COORD_TOKEN_OFFSET)
            mol_tokens.append(discretize_coord(y_noisy) + COORD_TOKEN_OFFSET + NUM_COORD_BINS) # Offset for Y bins
            mol_tokens.append(discretize_coord(z_noisy) + COORD_TOKEN_OFFSET + (NUM_COORD_BINS * 2)) # Offset for Z bins
            noisy_coords_for_connectivity.append((x_noisy, y_noisy, z_noisy))

        mol_tokens.append(ATOM_VOCAB['END'])

        current_seq_len = len(mol_tokens)
        max_seq_len_overall = max(max_seq_len_overall, current_seq_len)

        # Create the target connectivity matrix for this augmented molecule
        # We'll use a flattened upper triangular matrix representation
        # Size of flattened upper triangular matrix for N atoms is N * (N - 1) / 2
        num_possible_bonds = MAX_ATOMS_IN_DATASET * (MAX_ATOMS_IN_DATASET - 1) // 2
        connectivity_target = torch.zeros(num_possible_bonds, dtype=torch.float)

        # Populate the connectivity target based on base_connectivity
        # Only consider bonds between atoms present in the base molecule
        bond_idx = 0
        for i in range(MAX_ATOMS_IN_DATASET):
            for j in range(i + 1, MAX_ATOMS_IN_DATASET):
                if i < num_base_atoms and j < num_base_atoms: # Only for atoms actually in this molecule
                    if (i, j) in base_connectivity:
                        connectivity_target[bond_idx] = 1.0
                bond_idx += 1

        data.append({
            'tokens': torch.tensor(mol_tokens, dtype=torch.long),
            'polarizability': torch.tensor(current_polarizability, dtype=torch.float),
            'connectivity_matrix': connectivity_target # This is the flattened upper triangular matrix
        })

    # After collecting all data, pad them to the maximum sequence length found
    padded_data = []
    for item in data:
        padded_mol_tokens = item['tokens'].tolist() + [ATOM_VOCAB['PAD']] * (max_seq_len_overall - len(item['tokens']))
        padded_data.append({
            'tokens': torch.tensor(padded_mol_tokens, dtype=torch.long),
            'polarizability': item['polarizability'],
            'connectivity_matrix': item['connectivity_matrix']
        })

    return padded_data, max_seq_len_overall, num_base_atoms

# Custom Dataset class
class MoleculeDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]['tokens'], self.data[idx]['polarizability'], self.data[idx]['connectivity_matrix']


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math
import random
import re # Import regex for parsing PDB REMARK line
import os # Import the os module for file system operations

# --- 1. Data Generation and Preprocessing (PDB Parsing and Data Augmentation) ---

# Define a vocabulary for atom types and coordinate bins
ATOM_VOCAB = {
    'PAD': 0, 'START': 1, 'END': 2,
    'C': 3, 'O': 4, 'N': 5, 'H': 6,
    'S': 7, 'F': 8, 'Cl': 9, 'Br': 10, 'I': 11 # Add more common atoms as needed
}
# Reverse mapping for decoding
REV_ATOM_VOCAB = {v: k for k, v in ATOM_VOCAB.items()}

# Coordinate binning parameters
COORD_MIN = -15.0
COORD_MAX = 15.0
NUM_COORD_BINS = 40 # Increased bins for better coordinate resolution
COORD_BIN_SIZE = (COORD_MAX - COORD_MIN) / NUM_COORD_BINS

# These will be updated dynamically based on the parsed data
# Initialize with placeholder values. Actual values will be determined during data loading.
COORD_TOKEN_OFFSET = len(ATOM_VOCAB)
TOTAL_VOCAB_SIZE = COORD_TOKEN_OFFSET + (NUM_COORD_BINS * 3)

def update_vocab_and_offsets(atom_type):
    """Dynamically updates ATOM_VOCAB and recalculates global offsets."""
    global ATOM_VOCAB, REV_ATOM_VOCAB, COORD_TOKEN_OFFSET, TOTAL_VOCAB_SIZE
    if atom_type not in ATOM_VOCAB:
        print(f"Adding new atom type '{atom_type}' to vocabulary.")
        ATOM_VOCAB[atom_type] = len(ATOM_VOCAB)
        REV_ATOM_VOCAB[ATOM_VOCAB[atom_type]] = atom_type
        # Recalculate offsets and total vocab size
        COORD_TOKEN_OFFSET = len(ATOM_VOCAB)
        TOTAL_VOCAB_SIZE = COORD_TOKEN_OFFSET + (NUM_COORD_BINS * 3)


def discretize_coord(coord_val):
    """Discretizes a single coordinate value into a bin index."""
    bin_idx = int((coord_val - COORD_MIN) / COORD_BIN_SIZE)
    return max(0, min(NUM_COORD_BINS - 1, bin_idx))

def parse_pdb_content_and_generate_dataset(pdb_content, num_molecules_to_augment_per_pdb=1, coord_noise_std=0.1, pol_noise_std=0.5):
    """
    Parses a single PDB content string, extracts atom types, coordinates, polarizability,
    and connectivity. Then, generates `num_molecules_to_augment_per_pdb` perturbed versions
    of this molecule.

    Args:
        pdb_content (str): The content of a single PDB file.
        num_molecules_to_augment_per_pdb (int): The number of augmented versions to create from this single PDB.
        coord_noise_std (float): Standard deviation of Gaussian noise added to coordinates.
        pol_noise_std (float): Standard deviation of Gaussian noise added to polarizability.

    Returns:
        tuple: A tuple containing:
            - list: List of dictionaries, each with 'tokens', 'polarizability', and 'connectivity_matrix'.
            - int: Maximum sequence length observed for this molecule's augmentations.
            - int: Number of atoms in the base molecule.
            - float: The base polarizability from the PDB.
    """
    base_atoms_data = []
    base_polarizability = None
    base_connectivity = set() # Store as a set of tuples (atom_idx1, atom_idx2) for uniqueness

    # Parse polarizability from REMARK line
    polarizability_match = re.search(r"REMARK static_polarizability ([\d.]+)", pdb_content)
    if polarizability_match:
        base_polarizability = float(polarizability_match.group(1))
    else:
        # Fallback if polarizability not found in PDB (adjust as per your data)
        # For a full dataset, you might want to exclude files without this REMARK
        print("Warning: 'static_polarizability' not found in PDB REMARK. Using a default value of 10.0.")
        base_polarizability = 10.0

    # Parse HETATM lines to get atom types and coordinates
    atom_pdb_id_to_idx = {} # Map PDB atom ID (1-indexed) to 0-indexed list index
    current_atom_idx = 0
    for line in pdb_content.splitlines():
        if line.startswith("HETATM"):
            try:
                pdb_atom_id = int(line[6:11].strip()) # PDB atom serial number
                atom_type = line[76:78].strip() # Element symbol

                # Dynamically update vocabulary for new atom types
                update_vocab_and_offsets(atom_type)

                x = float(line[30:38])
                y = float(line[38:46])
                z = float(line[46:54])

                base_atoms_data.append({
                    'type': atom_type,
                    'coords': (x, y, z)
                })
                atom_pdb_id_to_idx[pdb_atom_id] = current_atom_idx
                current_atom_idx += 1
            except ValueError as e:
                print(f"Error parsing HETATM line: '{line}' - {e}. Skipping.")
                continue # Skip malformed lines

    # Parse CONECT lines to get connectivity
    for line in pdb_content.splitlines():
        if line.startswith("CONECT"):
            try:
                connected_ids = [int(line[i:i+5].strip()) for i in range(6, len(line), 5) if line[i:i+5].strip()]
                if not connected_ids:
                    continue

                atom1_pdb_id = connected_ids[0]
                if atom1_pdb_id not in atom_pdb_id_to_idx:
                    print(f"Warning: Atom ID {atom1_pdb_id} in CONECT not found in HETATM records. Skipping bond.")
                    continue
                atom1_idx = atom_pdb_id_to_idx[atom1_pdb_id]

                for atom2_pdb_id in connected_ids[1:]:
                    if atom2_pdb_id not in atom_pdb_id_to_idx:
                        print(f"Warning: Atom ID {atom2_pdb_id} in CONECT not found in HETATM records. Skipping bond.")
                        continue
                    atom2_idx = atom_pdb_id_to_idx[atom2_pdb_id]
                    # Ensure consistent order for bond pairs (smaller index first)
                    bond = tuple(sorted((atom1_idx, atom2_idx)))
                    if bond[0] != bond[1]: # Avoid self-loops
                        base_connectivity.add(bond)
            except ValueError as e:
                print(f"Error parsing CONECT line: '{line}' - {e}. Skipping.")
                continue

    if not base_atoms_data:
        raise ValueError("No HETATM records found in the provided PDB content. Cannot generate data.")

    num_base_atoms = len(base_atoms_data)
    # The actual `MAX_ATOMS_IN_DATASET` (overall largest molecule) will be determined
    # in train_model and passed to the LLM. For individual parsing, we just need `num_base_atoms`.

    # Generate augmented dataset
    augmented_data = []
    max_seq_len_for_this_pdb = 0

    for i in range(num_molecules_to_augment_per_pdb):
        mol_tokens = [ATOM_VOCAB['START']]
        # Add noise to polarizability
        current_polarizability = base_polarizability + random.gauss(0, pol_noise_std)
        current_polarizability = max(1.0, current_polarizability) # Ensure positive

        for atom_info in base_atoms_data:
            atom_type = atom_info['type']
            x, y, z = atom_info['coords']

            # Add Gaussian noise to coordinates
            x_noisy = x + random.gauss(0, coord_noise_std)
            y_noisy = y + random.gauss(0, coord_noise_std)
            z_noisy = z + random.gauss(0, coord_noise_std)

            mol_tokens.append(ATOM_VOCAB[atom_type])
            mol_tokens.append(discretize_coord(x_noisy) + COORD_TOKEN_OFFSET)
            mol_tokens.append(discretize_coord(y_noisy) + COORD_TOKEN_OFFSET + NUM_COORD_BINS)
            mol_tokens.append(discretize_coord(z_noisy) + COORD_TOKEN_OFFSET + (NUM_COORD_BINS * 2))

        mol_tokens.append(ATOM_VOCAB['END'])

        current_seq_len = len(mol_tokens)
        max_seq_len_for_this_pdb = max(max_seq_len_for_this_pdb, current_seq_len)

        # For augmented data, bonds remain the same as the base molecule
        # Connectivity matrix will be created later, after determining global MAX_ATOMS_IN_DATASET
        # Store base_connectivity for now.
        augmented_data.append({
            'tokens': torch.tensor(mol_tokens, dtype=torch.long),
            'polarizability': torch.tensor(current_polarizability, dtype=torch.float),
            'base_connectivity': base_connectivity, # Store the original set of bonds
            'num_base_atoms': num_base_atoms # Store the count for this specific molecule
        })

    return augmented_data, max_seq_len_for_this_pdb, num_base_atoms, base_polarizability

# Custom Dataset class
class MoleculeDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]['tokens'], self.data[idx]['polarizability'], self.data[idx]['connectivity_matrix']

# --- 2. Property Embedding ---
class GaussianExpansion(nn.Module):
    """
    Applies Gaussian expansion to a scalar property.
    Transforms a single scalar value into a higher-dimensional vector representation.
    """
    def __init__(self, start=0.0, stop=200.0, num_gaussians=100): # Adjusted default range for polarizability
        super().__init__()
        self.register_buffer('offset', torch.linspace(start, stop, num_gaussians))
        self.register_buffer('widths', torch.tensor((stop - start) / num_gaussians, dtype=torch.float).expand_as(self.offset))

    def forward(self, x):
        """
        x: (batch_size, 1) or (batch_size,) tensor of scalar property values
        Output: (batch_size, num_gaussians)
        """
        if x.dim() == 1:
            x = x.unsqueeze(-1) # Ensure x is (batch_size, 1)
        return torch.exp(-((x - self.offset) ** 2) / (2 * self.widths ** 2))

class PropertyEmbedding(nn.Module):
    """
    Combines Gaussian expansion with an MLP to create a high-dimensional property embedding.
    """
    def __init__(self, num_gaussians, hidden_dim, embedding_dim, prop_min_val, prop_max_val):
        super().__init__()
        # Adjusted GaussianExpansion range
        self.gaussian_expansion = GaussianExpansion(start=prop_min_val, stop=prop_max_val, num_gaussians=num_gaussians)
        self.mlp = nn.Sequential(
            nn.Linear(num_gaussians, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim)
        )

    def forward(self, polarizability):
        """
        polarizability: (batch_size,) tensor of polarizability values
        Output: (batch_size, embedding_dim)
        """
        gaussian_features = self.gaussian_expansion(polarizability)
        property_embed = self.mlp(gaussian_features)
        return property_embed

# --- 3. Transformer-based Generative Model ---
class PositionalEncoding(nn.Module):
    """
    Injects positional information into token embeddings.
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1) # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        x: (seq_len, batch_size, d_model)
        """
        x = x + self.pe[:x.size(0), :]
        return x

class MoleculeGeneratorLLM(nn.Module):
    """
    A simplified Transformer Decoder-only model for molecule generation,
    conditioned on a property embedding, and predicting connectivity.
    """
    def __init__(self, vocab_size, d_model, nhead, num_decoder_layers,
                 dim_feedforward, dropout, max_seq_len,
                 num_gaussians_prop, prop_hidden_dim, prop_embedding_dim,
                 max_atoms_in_dataset_overall, prop_min_val, prop_max_val):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.max_atoms_in_dataset_overall = max_atoms_in_dataset_overall
        # Calculate the size of the flattened upper triangular matrix for connectivity
        self.num_possible_bonds = max_atoms_in_dataset_overall * (max_atoms_in_dataset_overall - 1) // 2

        # Property embedding module
        self.property_embedder = PropertyEmbedding(
            num_gaussians=num_gaussians_prop,
            hidden_dim=prop_hidden_dim,
            embedding_dim=prop_embedding_dim,
            prop_min_val=prop_min_val,
            prop_max_val=prop_max_val
        )

        # Token embedding layer
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)

        # Linear layer to project property embedding to d_model for concatenation/addition
        self.prop_proj = nn.Linear(prop_embedding_dim, d_model)

        # Transformer Decoder
        # Ensure batch_first=True is consistently applied
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers)

        # Output layer for token prediction
        self.token_output_layer = nn.Linear(d_model, vocab_size)

        # New: Output head for connectivity prediction
        self.bond_prediction_head = nn.Sequential(
            nn.Linear(d_model, dim_feedforward), # Use d_model as input since we'll pool
            nn.ReLU(),
            nn.Linear(dim_feedforward, self.num_possible_bonds)
        )

        self.init_weights()

    def init_weights(self):
        """Initializes weights for better training stability."""
        initrange = 0.1
        self.token_embedding.weight.data.uniform_(-initrange, initrange)
        self.token_output_layer.bias.data.zero_()
        self.token_output_layer.weight.data.uniform_(-initrange, initrange)
        # Initialize bond prediction head weights
        for layer in self.bond_prediction_head:
            if isinstance(layer, nn.Linear):
                layer.bias.data.zero_()
                layer.weight.data.uniform_(-initrange, initrange)

    def generate_square_subsequent_mask(self, sz):
        """Generates an upper-triangular matrix of -inf, used for masking future tokens."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src_tokens, polarizability):
        """
        src_tokens: (batch_size, seq_len) - input sequence (e.g., [START_TOKEN, atom1, x1, y1, z1, ...])
        polarizability: (batch_size,) - desired polarizability for each molecule
        """
        batch_size, seq_len = src_tokens.shape

        # 1. Embed property
        prop_embed = self.property_embedder(polarizability) # (batch_size, prop_embedding_dim)

        # 2. Embed input tokens
        token_embed = self.token_embedding(src_tokens) * math.sqrt(self.d_model) # (batch_size, seq_len, d_model)
        # Positional encoding expects (seq_len, batch_size, d_model), so transpose for PE, then transpose back
        token_embed = self.positional_encoding(token_embed.transpose(0, 1)).transpose(0, 1) # (batch_size, seq_len, d_model)

        # TransformerDecoder forward pass
        # tgt should be (batch_size, seq_len, d_model) because batch_first=True
        # memory should be (batch_size, memory_len, d_model) because batch_first=True
        decoder_output = self.transformer_decoder(
            tgt=token_embed, # (batch_size, seq_len, d_model) - NO TRANSPOSE HERE
            memory=self.prop_proj(prop_embed).unsqueeze(1), # (batch_size, 1, d_model) - Corrected memory shape
            tgt_mask=self.generate_square_subsequent_mask(seq_len).to(src_tokens.device), # Mask is (seq_len, seq_len)
            tgt_key_padding_mask=(src_tokens == ATOM_VOCAB['PAD']),
        )
        # decoder_output is already (batch_size, seq_len, d_model) because batch_first=True
        # No need for decoder_output.transpose(0, 1) here

        # Predict tokens
        token_logits = self.token_output_layer(decoder_output) # (batch_size, seq_len, vocab_size)

        # Predict connectivity: Use the mean of the decoder output across the sequence length
        pooled_decoder_output = decoder_output.mean(dim=1) # (batch_size, d_model)
        connectivity_logits = self.bond_prediction_head(pooled_decoder_output) # (batch_size, num_possible_bonds)

        return token_logits, connectivity_logits

    def generate(self, polarizability, max_new_tokens=50, temperature=1.0, top_k=None, top_p=None):
        """
        Generates a new molecule sequence and predicts its connectivity
        given a desired polarizability.
        """
        self.eval() # Set model to evaluation mode
        device = next(self.parameters()).device

        # Prepare polarizability input
        polarizability_tensor = torch.tensor([polarizability], dtype=torch.float, device=device)

        # Initialize sequence with START token
        generated_sequence = [ATOM_VOCAB['START']]
        input_tokens = torch.tensor(generated_sequence, dtype=torch.long, device=device).unsqueeze(0) # (1, 1)

        for _ in range(max_new_tokens):
            if input_tokens.shape[1] >= self.max_seq_len:
                break # Avoid exceeding max sequence length

            with torch.no_grad():
                token_logits, _ = self.forward(input_tokens, polarizability_tensor)

            next_token_logits = token_logits[:, -1, :] # (1, vocab_size)

            # Apply sampling strategies
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature
            if top_k is not None:
                v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
                next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                next_token_logits = next_token_logits.scatter_(1, sorted_indices[sorted_indices_to_remove], float('-Inf'))

            # Sample the next token
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(0) # (1,)

            generated_sequence.append(next_token.item())

            if next_token.item() == ATOM_VOCAB['END']:
                break

            input_tokens = torch.cat([input_tokens, next_token.unsqueeze(0)], dim=1)

        # After generating the full sequence, make a final forward pass to get connectivity prediction
        final_input_tokens = torch.tensor(generated_sequence, dtype=torch.long, device=device).unsqueeze(0)
        with torch.no_grad():
            _, connectivity_logits = self.forward(final_input_tokens, polarizability_tensor)

        predicted_bonds_flat_thresholded = (torch.sigmoid(connectivity_logits) > 0.5).squeeze(0).cpu().numpy()

        # Decode the generated sequence
        decoded_molecule_str = []
        atoms_generated_count = 0
        current_atom_type = None
        current_coords_buffer = [] # Use a buffer to store coordinates temporarily

        for token_id in generated_sequence:
            if token_id == ATOM_VOCAB['START']:
                decoded_molecule_str.append("START")
            elif token_id == ATOM_VOCAB['END']:
                decoded_molecule_str.append("END")
                break # Stop decoding after END token
            elif token_id == ATOM_VOCAB['PAD']:
                continue
            elif token_id < COORD_TOKEN_OFFSET: # It's an atom type
                if current_atom_type is not None: # If previous atom had incomplete coords
                    decoded_molecule_str.append(f"Atom {atoms_generated_count}: {current_atom_type}, Coords: (Incomplete)")
                    atoms_generated_count += 1
                current_atom_type = REV_ATOM_VOCAB.get(token_id, f"UNKNOWN_ATOM_{token_id}")
                current_coords_buffer = [] # Reset buffer for new atom
            elif token_id >= COORD_TOKEN_OFFSET: # It's a coordinate bin
                coord_bin_idx = token_id - COORD_TOKEN_OFFSET
                # Determine which dimension (x, y, z)
                if coord_bin_idx < NUM_COORD_BINS: # X coord
                    current_coords_buffer.append(COORD_MIN + coord_bin_idx * COORD_BIN_SIZE)
                elif coord_bin_idx < NUM_COORD_BINS * 2: # Y coord
                    current_coords_buffer.append(COORD_MIN + (coord_bin_idx - NUM_COORD_BINS) * COORD_BIN_SIZE)
                else: # Z coord
                    current_coords_buffer.append(COORD_MIN + (coord_bin_idx - NUM_COORD_BINS * 2) * COORD_BIN_SIZE)

                if len(current_coords_buffer) == 3 and current_atom_type is not None:
                    decoded_molecule_str.append(f"Atom {atoms_generated_count}: {current_atom_type}, Coords: ({current_coords_buffer[0]:.2f}, {current_coords_buffer[1]:.2f}, {current_coords_buffer[2]:.2f})")
                    atoms_generated_count += 1
                    current_atom_type = None # Reset for next atom
                    current_coords_buffer = []

        # Decode predicted connectivity based on the number of atoms actually generated
        predicted_bonds_list = []
        bond_flat_idx = 0
        # Iterate over possible atom pairs up to the number of atoms generated *or* max_atoms_in_dataset_overall
        # whichever is smaller, to correctly map the flattened matrix.
        for i in range(min(atoms_generated_count, self.max_atoms_in_dataset_overall)):
            for j in range(i + 1, min(atoms_generated_count, self.max_atoms_in_dataset_overall)):
                if bond_flat_idx < len(predicted_bonds_flat_thresholded) and predicted_bonds_flat_thresholded[bond_flat_idx] == 1:
                    predicted_bonds_list.append(f"({i}-{j})")
                bond_flat_idx += 1


        return " ".join(decoded_molecule_str), predicted_bonds_list


# --- 4. Training Configuration and Loop ---
def train_model():
    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Hyperparameters
    NUM_MOLECULES_TO_AUGMENT_PER_PDB = 1 # Set to 1 for now, as you have ~9000 unique PDBs
    BATCH_SIZE = 32
    NUM_EPOCHS = 50
    LEARNING_RATE = 1e-4
    BOND_LOSS_WEIGHT = 0.5 # Weight for the connectivity prediction loss
    COORD_NOISE_STD = 0.1 # Standard deviation for coordinate noise
    POL_NOISE_STD = 0.5 # Standard deviation for polarizability noise

    # Model parameters
    D_MODEL = 256 # Dimension of embeddings and Transformer layers
    NHEAD = 8 # Number of attention heads
    NUM_DECODER_LAYERS = 3 # Number of Transformer decoder layers
    DIM_FEEDFORWARD = 512 # Dimension of the feedforward network in Transformer
    DROPOUT = 0.1

    # Property embedding parameters
    NUM_GAUSSIANS_PROP = 100
    PROP_HIDDEN_DIM = 128
    PROP_EMBEDDING_DIM = D_MODEL # Match property embedding dim to d_model for easier integration

    # --- Data Loading from Folder ---
    pdb_folder = 'xyz_files' # Your folder name
    # Create the folder if it doesn't exist (useful for testing or if you create dummy files)
    os.makedirs(pdb_folder, exist_ok=True)

    all_raw_molecules_data = [] # Stores lists of augmented data from each PDB
    max_seq_len_overall_dataset = 0
    max_atoms_overall_dataset = 0 # This will be the N in N*(N-1)/2 for bond prediction
    min_polarizability_overall = float('inf')
    max_polarizability_overall = float('-inf')
    num_parsed_pdbs = 0

    print(f"Reading PDB files from '{pdb_folder}'...")
    for filename in os.listdir(pdb_folder):
        if filename.startswith('monomer_') and filename.endswith('.pdb'):
            filepath = os.path.join(pdb_folder, filename)
            print(f"Parsing {filepath}...")
            try:
                with open(filepath, 'r') as f:
                    pdb_content = f.read()

                # parse_pdb_content_and_generate_dataset returns:
                # (augmented_data_list, max_seq_len_for_this_pdb, num_base_atoms, base_polarizability)
                augmented_data_from_one_pdb, current_max_seq_len, current_num_atoms, current_polarizability = \
                    parse_pdb_content_and_generate_dataset(
                        pdb_content,
                        num_molecules_to_augment_per_pdb=NUM_MOLECULES_TO_AUGMENT_PER_PDB,
                        coord_noise_std=COORD_NOISE_STD, # Pass the defined constants
                        pol_noise_std=POL_NOISE_STD      # Pass the defined constants
                    )

                all_raw_molecules_data.extend(augmented_data_from_one_pdb)
                max_seq_len_overall_dataset = max(max_seq_len_overall_dataset, current_max_seq_len)
                max_atoms_overall_dataset = max(max_atoms_overall_dataset, current_num_atoms)
                min_polarizability_overall = min(min_polarizability_overall, current_polarizability)
                max_polarizability_overall = max(max_polarizability_overall, current_polarizability)
                num_parsed_pdbs += 1
            except Exception as e:
                print(f"Failed to parse {filepath}: {e}")
                continue

    if num_parsed_pdbs == 0:
        print(f"No PDB files found or successfully parsed in '{pdb_folder}'. Please check the folder and file names.")
        print("Creating a dummy PDB file for demonstration...")
        # Create a dummy PDB file for demonstration if no files are found
        dummy_pdb_content = """REMARK static_polarizability 184.75
HETATM      1  C   MOL     1        -1.000   0.000   0.000  1.00  0.00           C
HETATM      2  O   MOL     1         0.000   1.000   0.000  1.00  0.00           O
HETATM      3  H   MOL     1         0.000   0.000   1.000  1.00  0.00           H
CONECT      1    2
CONECT      2    3
END"""
        dummy_filepath = os.path.join(pdb_folder, "monomer_dummy.pdb")
        with open(dummy_filepath, "w") as f:
            f.write(dummy_pdb_content)
        print(f"Dummy file '{dummy_filepath}' created. Please rerun the script.")
        return # Exit if no real data to process

    print(f"\nFinished parsing {num_parsed_pdbs} PDB files.")
    print(f"Total augmented molecules for training: {len(all_raw_molecules_data)}")
    print(f"Max sequence length observed: {max_seq_len_overall_dataset}")
    print(f"Max atoms in any molecule (for bond matrix sizing): {max_atoms_overall_dataset}")
    print(f"Overall polarizability range: {min_polarizability_overall:.2f} - {max_polarizability_overall:.2f}")

    # Now, process all collected data to pad tokens and create final connectivity matrices
    final_dataset_items = []
    num_possible_bonds_overall = max_atoms_overall_dataset * (max_atoms_overall_dataset - 1) // 2

    for item in all_raw_molecules_data:
        # Pad tokens to the global max_seq_len
        padded_mol_tokens = item['tokens'].tolist() + \
                            [ATOM_VOCAB['PAD']] * (max_seq_len_overall_dataset - len(item['tokens']))

        # Create the fixed-size connectivity matrix for each molecule
        connectivity_target_tensor = torch.zeros(num_possible_bonds_overall, dtype=torch.float)
        base_connectivity = item['base_connectivity']
        num_current_atoms = item['num_base_atoms']

        bond_idx = 0
        for i in range(max_atoms_overall_dataset): # Iterate up to the largest possible molecule size
            for j in range(i + 1, max_atoms_overall_dataset):
                if i < num_current_atoms and j < num_current_atoms: # Only for atoms actually present in this molecule
                    if (i, j) in base_connectivity:
                        connectivity_target_tensor[bond_idx] = 1.0
                bond_idx += 1

        final_dataset_items.append({
            'tokens': torch.tensor(padded_mol_tokens, dtype=torch.long),
            'polarizability': item['polarizability'],
            'connectivity_matrix': connectivity_target_tensor
        })

    dataset = MoleculeDataset(final_dataset_items)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Initialize model
    model = MoleculeGeneratorLLM(
        vocab_size=TOTAL_VOCAB_SIZE, # Use the dynamically updated TOTAL_VOCAB_SIZE
        d_model=D_MODEL,
        nhead=NHEAD,
        num_decoder_layers=NUM_DECODER_LAYERS,
        dim_feedforward=DIM_FEEDFORWARD,
        dropout=DROPOUT,
        max_seq_len=max_seq_len_overall_dataset, # Use the globally determined max seq length
        num_gaussians_prop=NUM_GAUSSIANS_PROP,
        prop_hidden_dim=PROP_HIDDEN_DIM,
        prop_embedding_dim=PROP_EMBEDDING_DIM,
        max_atoms_in_dataset_overall=max_atoms_overall_dataset, # Use the globally determined max atoms
        prop_min_val=min_polarizability_overall - POL_NOISE_STD * 2, # Adjust range based on observed data
        prop_max_val=max_polarizability_overall + POL_NOISE_STD * 2
    ).to(device)

    # Loss functions
    token_criterion = nn.CrossEntropyLoss(ignore_index=ATOM_VOCAB['PAD'])
    connectivity_criterion = nn.BCEWithLogitsLoss()

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Training loop
    print("Starting training...")
    for epoch in range(NUM_EPOCHS):
        model.train()
        total_token_loss = 0
        total_connectivity_loss = 0
        total_overall_loss = 0

        for batch_idx, (tokens, polarizabilities, connectivity_matrices) in enumerate(dataloader):
            tokens, polarizabilities, connectivity_matrices = \
                tokens.to(device), polarizabilities.to(device), connectivity_matrices.to(device)

            input_seq = tokens[:, :-1]
            target_seq = tokens[:, 1:]

            optimizer.zero_grad()
            token_logits, connectivity_logits = model(input_seq, polarizabilities)

            token_loss = token_criterion(token_logits.reshape(-1, token_logits.size(-1)), target_seq.reshape(-1))
            connectivity_loss = connectivity_criterion(connectivity_logits, connectivity_matrices)

            overall_loss = token_loss + (BOND_LOSS_WEIGHT * connectivity_loss)
            overall_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_token_loss += token_loss.item()
            total_connectivity_loss += connectivity_loss.item()
            total_overall_loss += overall_loss.item()

        avg_token_loss = total_token_loss / len(dataloader)
        avg_connectivity_loss = total_connectivity_loss / len(dataloader)
        avg_overall_loss = total_overall_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Overall Loss: {avg_overall_loss:.4f}, "
              f"Token Loss: {avg_token_loss:.4f}, Connectivity Loss: {avg_connectivity_loss:.4f}")

    print("Training complete.")
    torch.save(model, 'gpt2.pt')

train_model()


Using device: cuda
Reading PDB files from 'xyz_files'...
Parsing xyz_files/monomer_6445001.pdb...
Adding new atom type 'CL' to vocabulary.
Parsing xyz_files/monomer_170898663.pdb...
Parsing xyz_files/monomer_12627872.pdb...
Parsing xyz_files/monomer_24229203.pdb...
Adding new atom type 'BR' to vocabulary.
Parsing xyz_files/monomer_91684808.pdb...
Parsing xyz_files/monomer_66794870.pdb...
Parsing xyz_files/monomer_117693409.pdb...
Parsing xyz_files/monomer_118874022.pdb...
Parsing xyz_files/monomer_2250420.pdb...
Parsing xyz_files/monomer_45029402.pdb...
Parsing xyz_files/monomer_46909798.pdb...
Parsing xyz_files/monomer_131733914.pdb...
Parsing xyz_files/monomer_68742092.pdb...
Parsing xyz_files/monomer_54455193.pdb...
Parsing xyz_files/monomer_23132660.pdb...
Parsing xyz_files/monomer_118654865.pdb...
Parsing xyz_files/monomer_91684551.pdb...
Parsing xyz_files/monomer_15330366.pdb...
Parsing xyz_files/monomer_11321221.pdb...
Parsing xyz_files/monomer_733139.pdb...
Parsing xyz_files/mo



Epoch 1/50, Overall Loss: 2.9107, Token Loss: 2.8560, Connectivity Loss: 0.1092
Epoch 2/50, Overall Loss: 2.0736, Token Loss: 2.0711, Connectivity Loss: 0.0051
Epoch 3/50, Overall Loss: 1.8614, Token Loss: 1.8590, Connectivity Loss: 0.0048
Epoch 4/50, Overall Loss: 1.7749, Token Loss: 1.7726, Connectivity Loss: 0.0047
Epoch 5/50, Overall Loss: 1.7278, Token Loss: 1.7255, Connectivity Loss: 0.0045
Epoch 6/50, Overall Loss: 1.6925, Token Loss: 1.6903, Connectivity Loss: 0.0045
Epoch 7/50, Overall Loss: 1.6649, Token Loss: 1.6627, Connectivity Loss: 0.0044
Epoch 8/50, Overall Loss: 1.6448, Token Loss: 1.6426, Connectivity Loss: 0.0043
Epoch 9/50, Overall Loss: 1.6280, Token Loss: 1.6259, Connectivity Loss: 0.0042
Epoch 10/50, Overall Loss: 1.6135, Token Loss: 1.6115, Connectivity Loss: 0.0041
Epoch 11/50, Overall Loss: 1.6017, Token Loss: 1.5997, Connectivity Loss: 0.0041
Epoch 12/50, Overall Loss: 1.5909, Token Loss: 1.5889, Connectivity Loss: 0.0040
Epoch 13/50, Overall Loss: 1.5814, To

In [5]:
import re
import numpy as np

# --- 5. Generation Example ---
model = torch.load('gpt2.pt', weights_only=False)
print("\n--- Generating new molecules ---")
# Example target polarizability values for generation
# Try values within the range of your dataset's polarizabilities
target_polarizabilities = [
    235.9034334693333
]

for target_pol in target_polarizabilities:
    print(f"\nGenerating molecule with target polarizability: {target_pol:.2f}")
    generated_mol_str, predicted_bonds = model.generate(
        polarizability=target_pol,
        max_new_tokens=914,
        temperature=0.8,
        top_k=50
    )
    print(f"Generated Sequence: {generated_mol_str}")
    print(f"Predicted Bonds: {', '.join(predicted_bonds) if predicted_bonds else 'None'}")

def build_connectivity_matrix(generated_sequence, predicted_bonds):
    # Count atoms
    atom_matches = re.findall(r"Atom \d+:", generated_sequence)
    N = len(atom_matches)
    
    # Parse bonds
    bonds = []
    if predicted_bonds:
        if isinstance(predicted_bonds, str):
            bond_strs = predicted_bonds.split(', ')
        else:  # Assume list
            bond_strs = predicted_bonds
        for b in bond_strs:
            match = re.match(r"\((\d+)-(\d+)\)", b)
            if match:
                i, j = int(match.group(1)), int(match.group(2))
                bonds.append((min(i,j), max(i,j)))
    
    # Build matrix
    matrix = np.zeros((N, N), dtype=int)
    for i, j in bonds:
        if i < N and j < N:
            matrix[i, j] = 1
            matrix[j, i] = 1
    
    return matrix

mat = build_connectivity_matrix(generated_mol_str, predicted_bonds)


--- Generating new molecules ---

Generating molecule with target polarizability: 235.90




Generated Sequence: START Atom 0: C, Coords: (-9.75, -7.50, -11.25) Atom 1: C, Coords: (-9.00, -7.50, -10.50) Atom 2: O, Coords: (-9.00, -7.50, -11.25) Atom 3: C, Coords: (-8.25, -7.50, -10.50) Atom 4: O, Coords: (-7.50, -7.50, -9.75) Atom 5: C, Coords: (-8.25, -7.50, -11.25) Atom 6: C, Coords: (-9.75, -7.50, -12.00) Atom 7: C, Coords: (-12.75, -8.25, -10.50) Atom 8: C, Coords: (-12.00, -9.75, -9.75) Atom 9: C, Coords: (-11.25, -10.50, -8.25) Atom 10: C, Coords: (-12.00, -12.00, -8.25) Atom 11: C, Coords: (-12.75, -12.75, -8.25) Atom 12: C, Coords: (-12.75, -13.50, -7.50) Atom 13: BR, Coords: (-12.75, -13.50, -9.00) Atom 14: O, Coords: (-12.75, -14.25, -12.00) Atom 15: C, Coords: (-13.50, -14.25, -12.75) Atom 16: C, Coords: (-13.50, -13.50, -12.75) Atom 17: C, Coords: (-14.25, -13.50, -13.50) Atom 18: C, Coords: (-15.00, -15.00, -15.00) Atom 19: C, Coords: (-14.25, -15.00, -14.25) Atom 20: C, Coords: (-15.00, 14.25, -14.25) Atom 21: C, Coords: (-15.00, 14.25, -12.75) Atom 22: C, Coords

In [12]:
mat[:20, :20]

array([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0,