In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from rdkit import Chem
from rdkit.Chem import rdmolops
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

# Pick the compute device once for the whole script: GPU when CUDA is usable, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Function to generate the bond-order-weighted adjacency matrix with diagonal = 1
def generate_adjacency_matrix(mol):
    """Return a symmetric ``(num_atoms, num_atoms)`` float32 bond-order matrix.

    Single, double and triple bonds contribute 1.0, 2.0 and 3.0, aromatic bonds
    contribute 1.5; any other RDKit bond type leaves its entries at 0. The
    diagonal is then set to 1.0 (self-loops).

    Args:
        mol: RDKit ``Mol``.

    Returns:
        ``np.ndarray`` of dtype float32.
    """
    bond_weights = {
        Chem.rdchem.BondType.SINGLE: 1.0,
        Chem.rdchem.BondType.DOUBLE: 2.0,
        Chem.rdchem.BondType.AROMATIC: 1.5,
        Chem.rdchem.BondType.TRIPLE: 3.0,
    }
    size = mol.GetNumAtoms()
    matrix = np.zeros((size, size), dtype=np.float32)

    for bond in mol.GetBonds():
        weight = bond_weights.get(bond.GetBondType())
        if weight is None:
            # Unlisted bond types are left unconnected.
            continue
        a, b = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        matrix[a, b] = matrix[b, a] = weight

    np.fill_diagonal(matrix, 1.0)  # Add self-loops
    return matrix

# Function to generate the atom feature matrix (61 features per atom)
def generate_feature_matrix(mol):
    """Build a per-atom feature matrix for an RDKit molecule.

    Each row concatenates, in order:
    element one-hot (13), total hydrogen count one-hot 0-4 (5), aromatic flag (1),
    SP/SP2/SP3 hybridization one-hot (3), formal charge one-hot -3..+3 (7),
    valence-electron one-hot 1-8 (8), degree one-hot 0-4 (5),
    implicit valence one-hot 0-8 (9), has-CIP-label flag (1), heteroatom flag (1),
    in-ring flag (1), bond-count one-hot 1-6 (6), terminal-atom flag (1)
    -> 61 columns. Values outside a one-hot's range (or elements not listed in
    ``atom_types``) encode as all zeros for that group.

    Args:
        mol: RDKit ``Mol``.

    Returns:
        ``torch.float32`` tensor of shape ``(mol.GetNumAtoms(), 61)``. A molecule
        with no atoms yields an empty ``(0, 61)`` tensor rather than a 1-D one.
    """
    # Also imported at file level; kept local as in the original.
    import torch
    from rdkit import Chem

    num_features = 61  # sum of the per-group widths listed in the docstring

    atom_types = ['As', 'B', 'Br', 'C', 'Cl', 'F', 'I', 'N', 'O', 'P', 'S', 'Se', 'Si']
    feature_matrix = []

    valence_electrons = {
        'As': 5, 'B': 3, 'Br': 7, 'C': 4, 'Cl': 7, 'F': 7, 'I': 7,
        'N': 5, 'O': 6, 'P': 5, 'S': 6, 'Se': 6, 'Si': 4
    }

    for atom in mol.GetAtoms():
        features = []

        # Atom type (13)
        atom_symbol = atom.GetSymbol()
        features.extend([1 if atom_symbol == t else 0 for t in atom_types])

        # Total hydrogen count (GetTotalNumHs), one-hot 0-4 (5)
        num_hydrogens = atom.GetTotalNumHs()
        features.extend([1 if num_hydrogens == i else 0 for i in range(5)])

        # Aromaticity (1)
        features.append(1 if atom.GetIsAromatic() else 0)

        # Hybridization (3); any other hybridization encodes as all zeros
        hybridization = atom.GetHybridization()
        features.extend([
            1 if hybridization == Chem.rdchem.HybridizationType.SP else 0,
            1 if hybridization == Chem.rdchem.HybridizationType.SP2 else 0,
            1 if hybridization == Chem.rdchem.HybridizationType.SP3 else 0
        ])

        # Formal charge, one-hot -3 to +3 (7)
        formal_charge = atom.GetFormalCharge()
        features.extend([1 if formal_charge == i else 0 for i in range(-3, 4)])

        # Valence electrons from the lookup table, one-hot 1-8 (8); unknown -> 0
        valence = valence_electrons.get(atom_symbol, 0)
        features.extend([1 if valence == i else 0 for i in range(1, 9)])

        # Atom degree, one-hot 0-4 (5)
        degree = atom.GetDegree()
        features.extend([1 if degree == i else 0 for i in range(5)])

        # Implicit valence, one-hot 0-8 (9)
        implicit_valence = atom.GetImplicitValence()
        features.extend([1 if implicit_valence == i else 0 for i in range(9)])

        # Chirality: 1 when RDKit assigned a CIP label (1)
        features.append(1 if atom.HasProp('_CIPCode') else 0)

        # Is heteroatom: anything other than H or C (1)
        features.append(1 if atom.GetAtomicNum() not in (1, 6) else 0)

        # Is in ring (1)
        features.append(1 if atom.IsInRing() else 0)

        # Total number of bonds, one-hot 1-6 (6)
        num_bonds = len(atom.GetBonds())
        features.extend([1 if num_bonds == i else 0 for i in range(1, 7)])

        # Is terminal atom (1)
        features.append(1 if degree == 1 else 0)

        feature_matrix.append(features)

    if not feature_matrix:
        # torch.tensor([]) would be 1-D; keep the 2-D (0, 61) contract so callers
        # can slice/pad an empty molecule like any other.
        return torch.zeros((0, num_features), dtype=torch.float32)
    return torch.tensor(feature_matrix, dtype=torch.float32)


# MolecularDataset class
class MolecularDataset(Dataset):
    """Molecules from a CSV as padded, normalized graphs with one regression target.

    Every item is padded to ``max_atoms`` rows so that items have a fixed shape
    and can be stacked by ``DataLoader``'s default collation.

    Args:
        csv_file: Path to a CSV containing a ``SMILES`` column and ``target_col``.
        target_col: Name of the column with the regression target.
        max_atoms: Padding size. Fetching a molecule with more atoms raises
            ``ValueError``.
    """

    def __init__(self, csv_file, target_col, max_atoms=460):
        data = pd.read_csv(csv_file)

        # Drop rows where SMILES is NaN or not a string
        data = data.dropna(subset=['SMILES'])
        data = data[data['SMILES'].apply(lambda x: isinstance(x, str))]

        self.smiles = []
        self.targets = []
        self.max_atoms = max_atoms  # Fixed maximum number of atoms

        for smiles, target in zip(data['SMILES'], data[target_col]):
            mol = Chem.MolFromSmiles(smiles)
            # Keep only SMILES RDKit can parse into at least one atom: a zero-atom
            # molecule cannot be padded and would make mean pooling divide by zero.
            if mol is not None and mol.GetNumAtoms() > 0:
                self.smiles.append(smiles)
                self.targets.append(target)

        self.targets = torch.tensor(self.targets, dtype=torch.float32)

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        """Return ``(adj_normalized, features, target, num_atoms)`` moved to ``device``.

        ``adj_normalized`` is ``(max_atoms, max_atoms)``, ``features`` is
        ``(max_atoms, feature_dim)``; rows past ``num_atoms`` are padding.
        """
        smiles = self.smiles[idx]
        mol = Chem.MolFromSmiles(smiles)

        adj = generate_adjacency_matrix(mol)
        features = generate_feature_matrix(mol)
        num_atoms = adj.shape[0]
        if num_atoms > self.max_atoms:
            raise ValueError(
                f"Molecule at index {idx} has {num_atoms} atoms, more than "
                f"max_atoms={self.max_atoms}: {smiles}"
            )

        # Pad adjacency matrix to (max_atoms, max_atoms)
        padded_adj = np.zeros((self.max_atoms, self.max_atoms), dtype=np.float32)
        padded_adj[:num_atoms, :num_atoms] = adj

        # Add self-loops and apply symmetric normalization D^-1/2 A_hat D^-1/2.
        # NOTE(review): generate_adjacency_matrix already puts 1.0 on the diagonal,
        # so real atoms end up with self-loop weight 2 here (padded rows get 1).
        # Kept unchanged to preserve the model's inputs; confirm this is intended.
        adj_hat = padded_adj + np.eye(self.max_atoms, dtype=np.float32)
        degree = adj_hat.sum(axis=1)
        d_inv_sqrt = np.zeros_like(degree)
        # Explicit ``out`` so entries with zero degree are 0, not uninitialized memory.
        np.power(degree, -0.5, out=d_inv_sqrt, where=degree != 0)
        # Same values as diag(d) @ A_hat @ diag(d), without two dense O(n^3) matmuls.
        adj_normalized = d_inv_sqrt[:, None] * adj_hat * d_inv_sqrt[None, :]

        # Pad feature matrix to (max_atoms, feature_dim)
        padded_features = np.zeros((self.max_atoms, features.shape[1]), dtype=np.float32)
        padded_features[:num_atoms, :] = features

        target = self.targets[idx]
        return (
            torch.tensor(adj_normalized, dtype=torch.float32).to(device),
            torch.tensor(padded_features, dtype=torch.float32).to(device),
            target.to(device),
            torch.tensor(num_atoms, dtype=torch.int64).to(device),
        )

# GCN Layer (ReLU)
#class GCNLayer(nn.Module):
   # def __init__(self, input_dim, output_dim):
   #     super(GCNLayer, self).__init__()
   #     self.linear = nn.Linear(input_dim, output_dim)

   # def forward(self, A_normalized, H):
        # Apply the linear transformation
   #     HW = self.linear(H)
        
   #     # Compute the new output with the added adjacency matrix term
   #     output = torch.matmul(A_normalized, HW)
        
        # Apply ReLU activation function
   #     return torch.relu(output)

# GCN Layer (LeakyReLU)
class GCNLayer(nn.Module):
    """Graph convolution computing ``LeakyReLU(A_norm @ (H W + b))``.

    The bias is a separate trainable ``(1, output_dim)`` parameter that is added
    to every node's projected features *before* neighbourhood aggregation.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Projection without its own bias; the layer-level bias parameter follows.
        self.linear = nn.Linear(input_dim, output_dim, bias=False)
        self.bias = nn.Parameter(torch.zeros(1, output_dim))
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.01)

    def forward(self, A_normalized, H):
        """Project ``H``, add the bias, aggregate with ``A_normalized``, activate."""
        projected = self.linear(H) + self.bias
        aggregated = A_normalized @ projected
        return self.leaky_relu(aggregated)

class MLP(nn.Module):
    """Two-layer perceptron computing ``fc2(ReLU(fc1(x)))``."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

# GCN Model
class GCN(nn.Module):
    """Four stacked ``GCNLayer``s followed by masked mean pooling over atoms.

    Args:
        input_dim: Per-atom input feature width.
        hidden_dim: Width of every GCN layer and of the returned embedding.
    """

    def __init__(self, input_dim, hidden_dim):
        super(GCN, self).__init__()
        self.gcn1 = GCNLayer(input_dim, hidden_dim)
        self.gcn2 = GCNLayer(hidden_dim, hidden_dim)
        self.gcn3 = GCNLayer(hidden_dim, hidden_dim)
        self.gcn4 = GCNLayer(hidden_dim, hidden_dim)

    def forward(self, A_normalized, X, num_atoms):
        """Return one ``(batch, hidden_dim)`` embedding per graph.

        Assumes ``A_normalized`` is ``(batch, max_atoms, max_atoms)``, ``X`` is
        ``(batch, max_atoms, input_dim)`` and ``num_atoms`` is ``(batch,)``, as
        produced by ``MolecularDataset`` + ``DataLoader``.
        """
        H = X
        for layer in (self.gcn1, self.gcn2, self.gcn3, self.gcn4):
            H = layer(A_normalized, H)

        # Pooling over only valid atoms. Padded rows are not zero after the layers
        # (each GCNLayer adds a bias to every row), so they must be masked out.
        positions = torch.arange(H.size(1), device=H.device)
        mask = positions.unsqueeze(0) < num_atoms.unsqueeze(1)  # (batch, max_atoms)
        summed = (H * mask.unsqueeze(2)).sum(dim=1)
        # clamp(min=1): an empty graph yields a zero embedding instead of NaN.
        counts = num_atoms.clamp(min=1).unsqueeze(1).float()
        return summed / counts

def weighted_mse_loss(y_pred, y_true, low_energy_weight=10.0, *, threshold=0.5):
    """
    Custom weighted MSE loss that applies higher weights to the low-energy region.

    Samples whose target is below ``threshold`` get weight ``low_energy_weight``.
    All other samples get weight 1.0. The result is the mean of
    ``weight * (y_pred - y_true) ** 2``.

    Args:
        y_pred: Predictions. If its shape differs from ``y_true`` but it has the
            same number of elements (e.g. ``(B, 1)`` vs ``(B,)``), it is reshaped
            to ``y_true``'s shape. This prevents an accidental ``(B, B)`` broadcast.
        y_true: Targets.
        low_energy_weight: Weight for targets below ``threshold``.
        threshold: Cut-off for the up-weighted region (default 0.5, the value
            that was previously hard-coded).

    Returns:
        Scalar tensor holding the weighted mean squared error.
    """
    if y_pred.shape != y_true.shape and y_pred.numel() == y_true.numel():
        y_pred = y_pred.reshape(y_true.shape)
    # Step weighting (not inversely proportional): low-energy targets count more.
    weights = torch.where(y_true < threshold, low_energy_weight, 1.0)
    loss = weights * (y_pred - y_true) ** 2
    return loss.mean()
    
# def weighted_mse_loss(y_pred, y_true, low_energy_weight=10.0):
   #  weights = torch.where(y_true < 0.8, low_energy_weight, 1.0)
   #  mse_loss = weights * (y_pred - y_true) ** 2
   #  negative_penalty = torch.sum(torch.relu(-y_pred))  # Penalize negative predictions
   #  return mse_loss.mean() + 0.1 * negative_penalty  # Add penalty term with weight


# Main Training and Validation Function
def train_and_validate(gcn_model, mlp_model, train_loader, val_loader, optimizer, epochs,
                       low_energy_weight=10.0, *, scheduler=None):
    """Jointly train the GCN encoder and MLP head and record per-epoch losses.

    Args:
        gcn_model: Encoder called as ``gcn_model(A, X, num_atoms)``.
        mlp_model: Head mapping the encoder's embeddings to predictions.
        train_loader: Iterable of ``(A, X, y, num_atoms)`` training batches.
        val_loader: Iterable of ``(A, X, y, num_atoms)`` validation batches.
        optimizer: Optimizer over the parameters of both models.
        epochs: Number of passes over ``train_loader``.
        low_energy_weight: Forwarded to ``weighted_mse_loss``.
        scheduler: Optional scheduler stepped once per epoch with the mean
            validation loss (e.g. ``ReduceLROnPlateau``). When omitted, a
            module-level ``scheduler`` is used if one exists. This keeps scripts
            that define ``scheduler`` globally working. Otherwise no scheduling
            is done (previously this raised ``NameError``).

    Returns:
        ``(train_losses, val_losses)``: per-epoch means of the per-batch
        weighted MSE.
    """
    if scheduler is None:
        scheduler = globals().get("scheduler")

    train_mse_list = []
    val_mse_list = []

    for epoch in range(epochs):
        gcn_model.train()
        mlp_model.train()
        train_loss = 0.0

        for A, X, y, num_atoms in train_loader:
            A, X, y, num_atoms = A.to(device), X.to(device), y.to(device), num_atoms.to(device)
            optimizer.zero_grad()

            # GCN embedding -> MLP prediction
            embeddings = gcn_model(A, X, num_atoms)
            y_pred = mlp_model(embeddings)

            # Loss and backpropagation
            loss = weighted_mse_loss(y_pred.squeeze(), y, low_energy_weight)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        train_mse_list.append(avg_train_loss)

        # Validation phase
        gcn_model.eval()
        mlp_model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for A, X, y, num_atoms in val_loader:
                A, X, y, num_atoms = A.to(device), X.to(device), y.to(device), num_atoms.to(device)
                embeddings = gcn_model(A, X, num_atoms)
                y_pred = mlp_model(embeddings)
                loss = weighted_mse_loss(y_pred.squeeze(), y, low_energy_weight)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        val_mse_list.append(avg_val_loss)
        if scheduler is not None:
            scheduler.step(avg_val_loss)  # Update scheduler

        print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}, "
              f"Val Loss: {avg_val_loss:.4f}")

    return train_mse_list, val_mse_list


# Test set evaluation
def evaluate_test_set(model, test_loader):
    """Run ``model`` in eval mode over ``test_loader`` without gradients.

    ``model`` is called as ``model(A, X, num_atoms)``. Targets and predictions
    from every batch are flattened and collected on the CPU.

    Returns:
        ``(y_true, y_pred)`` as 1-D NumPy arrays.
    """
    model.eval()
    targets = []
    outputs = []

    with torch.no_grad():
        for batch in test_loader:
            A, X, y, num_atoms = (tensor.to(device) for tensor in batch)
            preds = model(A, X, num_atoms).squeeze()
            # Move to CPU for NumPy conversion
            targets.extend(y.view(-1).cpu().numpy())
            outputs.extend(preds.view(-1).cpu().numpy())

    return np.array(targets), np.array(outputs)

# Main


class GCNRegressor(nn.Module):
    """Chain a GCN encoder and an MLP head into one ``(A, X, num_atoms)`` model.

    ``evaluate_test_set`` calls ``model(A, X, num_atoms)`` on a single module,
    and saving this wrapper's ``state_dict`` stores both sub-networks (keys are
    prefixed ``gcn.`` and ``mlp.``).
    """

    def __init__(self, gcn_model, mlp_model):
        super().__init__()
        self.gcn = gcn_model
        self.mlp = mlp_model

    def forward(self, A_normalized, X, num_atoms):
        return self.mlp(self.gcn(A_normalized, X, num_atoms))


if __name__ == "__main__":
    # Hyperparameters
    input_dim = 61  # Feature size
    hidden_dim_gcn = 100
    hidden_dim_mlp = 300
    output_dim = 1
    batch_size = 15
    epochs = 1
    learning_rate = 0.0001
    low_energy_weight = 1.0

    # Dataset
    train_dataset = MolecularDataset("/home/george/TADF/gcn/training_set.csv", target_col="ST_split")
    val_dataset = MolecularDataset("/home/george/TADF/gcn/validation_set.csv", target_col="ST_split")
    test_dataset = MolecularDataset("/home/george/TADF/gcn/testing_set.csv", target_col="ST_split")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Model, Optimizer, and Scheduler
    gcn_model = GCN(input_dim, hidden_dim_gcn).to(device)
    mlp_model = MLP(hidden_dim_gcn, hidden_dim_mlp, output_dim).to(device)
    optimizer = optim.Adam(list(gcn_model.parameters()) + list(mlp_model.parameters()), lr=learning_rate)
    # train_and_validate picks up this module-level ``scheduler``.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

    # Train and validate
    train_mse, val_mse = train_and_validate(gcn_model, mlp_model, train_loader, val_loader, optimizer, epochs, low_energy_weight)
    epoch_numbers = list(range(1, epochs + 1))

    # Save MSE values to CSV
    mse_df = pd.DataFrame({
        "Epoch": epoch_numbers,
        "Train_MSE": train_mse,
        "Validation_MSE": val_mse,
    })
    mse_csv_path = "mse_train_validation_gcn_nn.csv"
    mse_df.to_csv(mse_csv_path, index=False)
    print(f"MSE values saved to '{mse_csv_path}'")

    # One module for the full pipeline (previously ``model`` was never defined).
    model = GCNRegressor(gcn_model, mlp_model)

    # Test set evaluation and R-squared
    y_true, y_pred = evaluate_test_set(model, test_loader)
    r2 = r2_score(y_true, y_pred)
    print(f"Test R^2: {r2:.4f}")

    # Save the trained model (GCN + MLP weights)
    torch.save(model.state_dict(), "gcn_nn.pth")
    print("Model saved to 'gcn_nn.pth'")

    # Plot train/validation MSE curves
    plt.figure(figsize=(10, 5))
    plt.plot(epoch_numbers, train_mse, label="Train MSE")
    plt.plot(epoch_numbers, val_mse, label="Validation MSE")
    plt.xlabel("Epochs")
    plt.ylabel("Mean Squared Error (eV$^{2}$)")
    plt.title("Train and Validation MSE Over Epochs")
    plt.legend()
    plt.savefig("train_val_mse_gcn_nn.png")
    plt.close()
    print("Train and Validation MSE plot saved as 'train_val_mse_gcn_nn.png'")

    # Scatter plot for test set predictions
    lo, hi = float(np.min(y_true)), float(np.max(y_true))
    plt.figure(figsize=(10, 5))
    plt.scatter(y_true, y_pred, alpha=0.7, label="Predicted vs Actual")
    plt.plot([lo, hi], [lo, hi], color="red", linestyle="--", label="Diagonal")
    plt.xlabel("Actual Values (eV)")
    plt.ylabel("Predicted Values (eV)")
    plt.title("Testing Set Predictions")
    plt.text(0.05, 0.95, f"$R^2$: {r2:.2f}", transform=plt.gca().transAxes, fontsize=12, verticalalignment='top')
    plt.savefig("test_predictions_gcn_nn.png")
    plt.close()
    print("Test set predictions plot saved as 'test_predictions_gcn_nn.png'")
    # The former second CSV save referenced an undefined ``mse_list``; the
    # per-epoch train/validation losses are already in mse_train_validation_gcn_nn.csv.
