In [None]:
import json
import os
import torch
from torch_geometric.data import Data, Dataset, Batch
import pandas as pd
import numpy as np
import random
import pickle
from pathlib import Path
import logging
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class MoleculeGraphDataset(Dataset):
    """Dataset of (anchor, positive, negatives) samples for contrastive training.

    Args:
        similar_mol_path: Path to a JSON file mapping each anchor key to a
            collection of keys considered similar to it.
        final_df: Mapping-like object indexed by molecule key that returns the
            per-molecule data (presumably torch_geometric ``Data`` graphs —
            confirm against the code that builds the pickle).
        batch_size: Total number of samples per item: one anchor, one positive
            and ``batch_size - 2`` negatives.
    """

    # A failed ``final_df`` lookup shows up as KeyError/IndexError (missing entry)
    # or TypeError (unhashable key). Only these mean "no data for this molecule";
    # anything else (e.g. KeyboardInterrupt) must propagate.
    _MISSING_ERRORS = (LookupError, TypeError)

    def __init__(self, similar_mol_path, final_df, batch_size = 512):
        super().__init__()
        with open(similar_mol_path, "r", encoding="utf-8") as f:
            self.similar_mols = json.load(f)
        self.mol_list = list(self.similar_mols.keys())
        self.final_df = final_df
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of anchor molecules."""
        return len(self.mol_list)

    def __getitem__(self, idx):
        """Build one contrastive sample for the anchor at position ``idx``.

        Returns:
            Tuple ``(anchor_data, positive_data, negative_data)`` where
            ``negative_data`` is a list with ``batch_size - 2`` entries.

        Raises:
            IndexError: the anchor has no data and there is no next anchor to
                fall back to.
            ValueError: no similar molecule of the anchor has data, or fewer
                than ``batch_size - 2`` negatives with data exist.
        """
        # Resolve the anchor FIRST so the positive is always drawn from the
        # similarity list of the anchor that is actually returned.
        anchor = self.mol_list[idx]
        try:
            anchor_data = self.final_df[anchor]
        except self._MISSING_ERRORS as exc:
            if idx + 1 < len(self.mol_list):
                # Fall back to the next anchor; a failure here propagates as-is.
                anchor = self.mol_list[idx + 1]
                anchor_data = self.final_df[anchor]
            else:
                raise IndexError("Anchor index out of range and fallback failed") from exc

        # Try every similar molecule once, in random order, instead of a fixed
        # number of random draws that can miss a valid candidate.
        positive_candidates = list(self.similar_mols[anchor])
        random.shuffle(positive_candidates)
        for positive in positive_candidates:
            try:
                positive_data = self.final_df[positive]
                break
            except self._MISSING_ERRORS:
                continue
        else:
            raise ValueError(f"Failed to find valid positive for anchor: {anchor}")

        # NOTE(review): only the anchor and the chosen positive are excluded, so
        # other entries of similar_mols[anchor] can still be drawn as negatives.
        needed = self.batch_size - 2
        candidates = [mol for mol in self.mol_list if mol != anchor and mol != positive]
        random.shuffle(candidates)

        negative_data = []
        for neg in candidates:
            if len(negative_data) >= needed:
                break
            try:
                negative_data.append(self.final_df[neg])
            except self._MISSING_ERRORS:
                continue

        if len(negative_data) < needed:
            raise ValueError("Not enough negative samples found")

        return anchor_data, positive_data, negative_data

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Batch
from torch_geometric.nn import global_max_pool

class GNNLayerSimple(nn.Module):
    """Residual layer that updates node features and edge features together.

    Both updates are computed from the features entering the layer and are added
    back onto those features.

    Args:
        node_dim: Width of each node feature vector.
        edge_dim: Width of each edge feature vector.
        hidden_dim: Hidden width of the two MLPs.
    """

    def __init__(self, node_dim, edge_dim, hidden_dim):
        super().__init__()
        # Per-edge message built from [node at edge_index[0], edge features].
        node_layers = [
            nn.Linear(node_dim + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, node_dim),
        ]
        self.node_mlp = nn.Sequential(*node_layers)

        # Per-edge update built from [both endpoint nodes, edge features].
        edge_layers = [
            nn.Linear(2 * node_dim + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, edge_dim),
        ]
        self.edge_mlp = nn.Sequential(*edge_layers)

    def forward(self, h, edge_index, edge_attr):
        """Apply one update step.

        Args:
            h: Node features, indexed along dim 0 by the entries of ``edge_index``.
            edge_index: Two index rows, one entry per edge.
            edge_attr: Edge features, one row per edge.

        Returns:
            Tuple ``(h, edge_attr)`` with the residual updates applied.
        """
        first, second = edge_index
        h_first = h[first]

        messages = self.node_mlp(torch.cat((h_first, edge_attr), dim=1))
        edge_update = self.edge_mlp(torch.cat((h_first, h[second], edge_attr), dim=1))

        # Sum each edge's message into the node named by the first index row.
        scatter_index = first.unsqueeze(1).expand_as(messages)
        aggregated = torch.zeros_like(h).scatter_add(0, scatter_index, messages)

        return h + aggregated, edge_attr + edge_update


class AttentionPooling(nn.Module):
    """Pool a set of row vectors into one vector with a single learned query.

    Args:
        dim: Width of the input rows; the pooled output has the same width.
    """

    def __init__(self, dim):
        super().__init__()
        self.q = nn.Parameter(torch.randn(1, dim))  # learned query, shape (1, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.scale = dim ** -0.5  # scaled dot-product factor

    def forward(self, h):
        """Return the attention-weighted sum of the projected rows of ``h``.

        The softmax runs over the rows of ``h``; the result has one row.
        """
        keys = self.k_proj(h)
        values = self.v_proj(h)

        weights = F.softmax(torch.matmul(self.q, keys.T) * self.scale, dim=-1)
        return torch.matmul(weights, values)

class GNNFingerprintSimple(nn.Module):
    """Graph encoder: stacked GNN layers, attention pooling, projection head.

    Args:
        node_dim: Width of the node features (``data.x``).
        edge_dim: Width of the edge features (``data.edge_attr``).
        hidden_dim: Hidden width used by the layers and the projection head.
        num_layers: Number of stacked ``GNNLayerSimple`` layers.
        out_dim: Width of the returned embedding.
    """

    def __init__(self, node_dim, edge_dim, hidden_dim=256, num_layers=6, out_dim=1024):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.egnn_layers = nn.ModuleList([
            GNNLayerSimple(node_dim, edge_dim, hidden_dim) for _ in range(num_layers)
        ])

        self.projection_head = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, out_dim),
        )

        self.attn_pool = AttentionPooling(node_dim)

    @staticmethod
    def _ensure_two_row_edge_index(graph):
        """Transpose ``graph.edge_index`` in place unless its first dim is 2.

        The graph object passed in is modified. A ``(2, 2)`` tensor is ambiguous
        and is left as it is.
        """
        if graph.edge_index.shape[0] != 2:
            graph.edge_index = graph.edge_index.t()
        return graph

    def forward(self, data):
        """Embed one graph, a list of graphs, or an already collated batch.

        Args:
            data: A single graph object, a list of graph objects (collated here
                with ``Batch.from_data_list``), or an object that already carries
                a ``batch`` assignment vector and ``num_graphs``.

        Returns:
            One embedding row per graph: a single row for a lone graph, otherwise
            ``num_graphs`` rows.
        """
        if isinstance(data, list):
            # edge_index must be normalised before collation.
            data = Batch.from_data_list(
                [self._ensure_two_row_edge_index(graph) for graph in data]
            )
        else:
            self._ensure_two_row_edge_index(data)

        h, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        for layer in self.egnn_layers:
            h, edge_attr = layer(h, edge_index, edge_attr)

        # Decide from the batch vector, not from the input type. A pre-collated
        # batch used to take the single-graph branch, which pooled every graph
        # into one embedding.
        batch = getattr(data, "batch", None)
        if batch is None:
            pooled = self.attn_pool(h)
        else:
            pooled = torch.cat(
                [self.attn_pool(h[batch == i]) for i in range(data.num_graphs)],
                dim=0,
            )

        return self.projection_head(pooled)

In [None]:
def ntxent_loss(anchor_emb, positive_emb, negative_emb, temperature=0.1, margin=0.1):
    """Cross-entropy over temperature-scaled cosine similarities.

    Column 0 of the logits is the anchor/positive similarity and the remaining
    columns are anchor/negative similarities, so the target class is always 0.

    Args:
        anchor_emb: Anchor embeddings, one row per anchor.
        positive_emb: Positive embeddings, matched row-wise with ``anchor_emb``.
        negative_emb: Negative embeddings, broadcast against
            ``anchor_emb.unsqueeze(1)``.
        temperature: Divisor applied to every similarity.
        margin: Accepted for interface compatibility; not used in the computation.

    Returns:
        Scalar loss tensor.
    """
    pos_logit = F.cosine_similarity(anchor_emb, positive_emb, dim=-1).unsqueeze(1)
    neg_logits = F.cosine_similarity(anchor_emb.unsqueeze(1), negative_emb, dim=-1)

    scaled = torch.cat((pos_logit, neg_logits), dim=1) / temperature

    # The positive always sits in column 0.
    targets = scaled.new_zeros(scaled.size(0), dtype=torch.long)
    return F.cross_entropy(scaled, targets)

In [None]:
def training_loop(data, model, model_name, optimizer, device, temperature=0.1, batch_size=512, epochs=100):
    """Train ``model`` with ``ntxent_loss`` and save its weights.

    Each optimisation step draws one random index from ``data``, embeds the
    anchor, the positive and the negatives, and minimises the loss. Progress is
    written to ``logging/<model_name>.log`` and the final ``state_dict`` to
    ``models/<model_name>.pth``.

    Args:
        data: Indexable dataset whose items are ``(anchor, positive, negatives)``;
            anchor and positive need ``.to(device)``, negatives is a list of such
            objects.
        model: Module called on a single graph and on a list of graphs.
        model_name: Base name of the log file and of the saved weights.
        optimizer: Optimiser already bound to ``model``'s parameters.
        device: Device string or ``torch.device``.
        temperature: Temperature forwarded to ``ntxent_loss``.
        batch_size: Used only to derive the number of steps per epoch
            (``len(data) // batch_size``, at least 1).
        epochs: Number of epochs.

    Returns:
        None. Returns early, after printing the error, if the log file cannot be
        set up.

    Raises:
        ValueError: ``data`` is empty.
    """
    n_samples = len(data)
    if n_samples == 0:
        raise ValueError("Cannot train on an empty dataset")

    # Configure logging
    log_dir = "logging"
    log_filepath = os.path.join(log_dir, f"{model_name}.log")
    Path(log_dir).mkdir(parents=True, exist_ok=True)

    print(f"Log file path: {log_filepath}")

    # Use a dedicated logger with its own file handler. logging.basicConfig() does
    # nothing once the root logger has handlers, so a second call in the same
    # kernel would keep writing to the previous model's log file.
    logger = logging.getLogger(f"training_loop.{model_name}")
    try:
        file_handler = logging.FileHandler(log_filepath)
        file_handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))
        logger.setLevel(logging.INFO)
        logger.propagate = False  # keep records out of a previously configured root logger
        logger.addHandler(file_handler)
        logger.info("Logging initialized.")
    except Exception as e:
        print(f"Error while setting up logging: {e}")
        return

    try:
        logger.info("Logging has started...")

        logger.info(f"Training {model_name}")
        logger.info(f"Using device: {device}")
        logger.info(f"Temperature: {temperature}, Batch Size: {batch_size}, Epochs: {epochs}")

        device = torch.device(device)
        model = model.to(device)

        # At least one step per epoch; otherwise a dataset smaller than
        # batch_size gives n_batches == 0 and the average below divides by zero.
        n_batches = max(1, n_samples // batch_size)

        for epoch in range(epochs):
            torch.cuda.empty_cache()
            model.train()

            total_loss = 0.0
            with tqdm(total=n_batches, desc=f"Epoch {epoch:02d}", leave=False) as pbar:
                for _ in range(n_batches):
                    idx = random.randint(0, n_samples - 1)
                    anchor, positive, negative = data[idx]

                    anchor = anchor.to(device)
                    positive = positive.to(device)
                    negative = [neg.to(device) for neg in negative]

                    z_anchor = model(anchor)
                    z_positive = model(positive)
                    z_negative = model(negative)

                    loss = ntxent_loss(z_anchor, z_positive, z_negative, temperature)

                    optimizer.zero_grad(set_to_none=True)
                    loss.backward()
                    optimizer.step()

                    # Read the scalar once; each .item() call forces a device sync.
                    loss_value = loss.item()
                    total_loss += loss_value

                    pbar.set_postfix({"loss": f"{loss_value:6.3f}"})
                    pbar.update(1)

            avg_loss = total_loss / n_batches
            logger.info(f"Epoch {epoch} - Avg Train Loss: {avg_loss:.4f}")
            print(f"Epoch {epoch} - Avg Train Loss: {avg_loss:.4f}")

        model_save_dir = "models"
        Path(model_save_dir).mkdir(parents=True, exist_ok=True)
        model_save_path = os.path.join(model_save_dir, f"{model_name}.pth")
        torch.save(model.state_dict(), model_save_path)

        logger.info(f"Model saved to {model_save_path}")
        print(f"Model saved to {model_save_path}")

        logger.info("Training finished.")
    finally:
        # Release the log file so repeated runs in one kernel neither stack
        # handlers nor leak file descriptors.
        logger.removeHandler(file_handler)
        file_handler.close()

In [None]:
# --- Run configuration -------------------------------------------------------
# Build paths with os.path.join so the notebook also runs on non-Windows hosts;
# the original backslash-separated literals only resolve on Windows.
DATA_DIR = "data"
similar_mol_path = os.path.join(DATA_DIR, "similar_mol.json")

# Seed the generators used for sampling and weight initialisation so a
# Restart & Run All reproduces the same run.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# NOTE: pickle.load can execute arbitrary code; only load files you created
# or otherwise trust.
with open(os.path.join(DATA_DIR, "final_df_dict.pkl"), "rb") as f:
    final_df = pickle.load(f)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = GNNFingerprintSimple(
    node_dim=13, edge_dim=5, hidden_dim=512, num_layers=8, out_dim=2048
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
temperature = .1
batch_size = 512
epochs = 50
data = MoleculeGraphDataset(similar_mol_path, final_df, batch_size=batch_size)

training_loop(data, model, "FINAL_GNN_NE_BIG_GREATER_OUTPUT", optimizer, device, temperature, batch_size, epochs)