In [7]:
import os
from rdkit import Chem
import numpy as np
from sklearn.preprocessing import OneHotEncoder
from PIL import Image
import pandas as pd
from tqdm import tqdm

# Directories
pdb_folder = 'xyz_files'
img_folder = 'images_for_training'
os.makedirs(img_folder, exist_ok=True)

# Header line prefix whose last whitespace-separated token is the regression target.
POLARIZABILITY_TAG = 'REMARK static_polarizability'


def _load_molecule(path):
    """Parse one PDB file with RDKit (hydrogens removed).

    Returns None when RDKit rejects the file; RDKit itself logs the reason
    (e.g. "Explicit valence for atom # ... is greater than permitted").
    """
    return Chem.MolFromPDBFile(path, removeHs=True)


def _atom_features(mol):
    """Return an (n_atoms, 2) array of [atomic number, degree] per atom.

    Degree counts neighbours present in the graph; hydrogens were removed at
    load time, so they are not counted.
    """
    atomic_nums = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
    degrees = [atom.GetDegree() for atom in mol.GetAtoms()]
    return np.array([atomic_nums, degrees]).T


def _read_polarizability(path):
    """Return the value on the first POLARIZABILITY_TAG line of *path*, or None.

    None is returned both when the tag is absent and when its value is not a
    valid float, so one malformed header cannot abort the whole run.
    """
    with open(path, encoding='utf-8') as handle:
        for line in handle:
            if line.startswith(POLARIZABILITY_TAG):
                try:
                    return float(line.split()[-1])
                except ValueError:
                    return None
    return None


def _save_padded_image(matrix, shape, path):
    """Zero-pad a 0/1 *matrix* into the top-left of *shape* and save it as a PNG.

    Values are scaled to 0/255 and stored as 8-bit grayscale; Pillow infers
    mode 'L' from a 2-D uint8 array, so no explicit mode is passed.
    """
    padded = np.zeros(shape)
    padded[:matrix.shape[0], :matrix.shape[1]] = matrix
    Image.fromarray((padded * 255).astype(np.uint8)).save(path)


# Parse every PDB exactly once and keep only what later stages need.
# The original re-parsed each file for every stage, which is costly at this dataset size.
# Sorting makes the CSV row order reproducible.
pdb_files = sorted(name for name in os.listdir(pdb_folder) if name.endswith('.pdb'))
parsed = []   # (file name, adjacency matrix, atom features, polarizability or None)
skipped = []  # files RDKit rejected, or that contain no atoms
for file in tqdm(pdb_files):
    path = os.path.join(pdb_folder, file)
    mol = _load_molecule(path)
    if mol is None or mol.GetNumAtoms() == 0:
        skipped.append(file)
        continue
    parsed.append((file, Chem.GetAdjacencyMatrix(mol), _atom_features(mol),
                   _read_polarizability(path)))

# NOTE(review): earlier runs logged many "Explicit valence ... greater than
# permitted" failures (e.g. carbon with valence 12), which suggests the bonds
# RDKit infers for these files are implausible. Check the xyz -> PDB
# conversion before trusting how many molecules survive. TODO confirm.
if skipped:
    print(f"Skipped {len(skipped)} of {len(pdb_files)} PDB files that RDKit could not parse")
if not parsed:
    raise ValueError(f"No parseable .pdb files with atoms found in {pdb_folder!r}")

# Largest molecule sets the padded image size, so every image has the same height.
max_atoms = max(adj.shape[0] for _, adj, _, _ in parsed)

# Fit the one-hot encoder on every (atomic number, degree) pair in the dataset.
# handle_unknown='ignore' tolerates unseen categories (all-zero columns).
all_features = np.vstack([features for _, _, features, _ in parsed])
encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
encoder.fit(all_features)

# Write the adjacency and one-hot images.
# Label each image with the molecule's polarizability; rows are added only when it is known.
csv_data = []
for file, adj_matrix, features, polarizability in tqdm(parsed):
    adj_filename = f"{file}_adj.png"
    _save_padded_image(adj_matrix, (max_atoms, max_atoms),
                       os.path.join(img_folder, adj_filename))

    binary_matrix = encoder.transform(features)
    binary_filename = f"{file}_binary.png"
    _save_padded_image(binary_matrix, (max_atoms, binary_matrix.shape[1]),
                       os.path.join(img_folder, binary_filename))

    if polarizability is not None:
        csv_data.append([adj_filename, polarizability])
        csv_data.append([binary_filename, polarizability])

# Save CSV
pd.DataFrame(csv_data, columns=['string', 'polarizability']).to_csv(
    os.path.join(img_folder, 'polarizability.csv'), index=False)

  0%|          | 0/57450 [00:00<?, ?it/s][13:46:30] Explicit valence for atom # 1 C, 12, is greater than permitted
[13:46:30] Explicit valence for atom # 2 O, 3, is greater than permitted
[13:46:30] Explicit valence for atom # 4 O, 3, is greater than permitted
[13:46:30] Explicit valence for atom # 11 O, 3, is greater than permitted
[13:46:30] Explicit valence for atom # 2 O, 3, is greater than permitted
[13:46:30] Explicit valence for atom # 0 C, 6, is greater than permitted
[13:46:30] Explicit valence for atom # 1 O, 3, is greater than permitted
[13:46:30] Explicit valence for atom # 1 O, 3, is greater than permitted
[13:46:30] Explicit valence for atom # 27 O, 4, is greater than permitted
[13:46:30] Explicit valence for atom # 1 C, 8, is greater than permitted
[13:46:30] Explicit valence for atom # 21 C, 5, is greater than permitted
[13:46:30] Explicit valence for atom # 3 C, 8, is greater than permitted
[13:46:30] Explicit valence for atom # 23 O, 4, is greater than permitted
[13:4