# TransGNN_DTA Drug-Target Affinity Prediction Model User Guide  
> A drug-target affinity prediction model based on Transformer and GNN, supporting multi-dataset training and prediction.  
> Author: Quietpeng ([GitHub](https://github.com/Quietpeng)). Related papers have been published; see the project README for citation details.  

## I. Environment Setup and Project Initialization  
### 1.1 Clone the Project Repository  

In [None]:
# Clone the repository (skip if already cloned)  
!git clone https://github.com/Quietpeng/TransGNN_DTA.git  
%cd TransGNN_DTA  

### 1.2 Create and Activate a Virtual Environment (Recommended Best Practice)  


In [None]:
# Create a Python virtual environment (based on Python 3.12)  
!python -m venv transgnn_env  
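# Activate the environment (Linux/macOS). Run this in a terminal before starting Jupyter:  
# inside a notebook, each `!` command runs in its own subshell, so the activation does not persist.  
!source transgnn_env/bin/activate  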
# For Windows: .\transgnn_env\Scripts\activate  
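
Because activation inside a notebook cell does not carry over to later cells, it can help to confirm which interpreter the kernel is actually using. The following is a minimal standard-library sketch; the function name `in_virtualenv` is illustrative and not part of the project, and the check only covers `venv`-style environments.

```python
import sys

def in_virtualenv() -> bool:
    """Return True when the interpreter runs inside a venv-style environment."""
    # Inside a venv, sys.prefix points at the environment directory,
    # while sys.base_prefix still points at the base installation.
    return sys.prefix != getattr(sys, "base_prefix", sys.prefix)

print("Interpreter:", sys.executable)
print("Inside a virtual environment:", in_virtualenv())
```

If this reports `False`, start Jupyter from a terminal in which the environment is already activated.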

### 1.3 Install Dependencies 

In [None]:
# Install project dependencies (install PyTorch and CUDA in advance based on hardware configuration)  
!pip install -r requirements.txt  
# Optional: install screen for background session management (for remote training monitoring)  
!sudo apt update && sudo apt install screen -y  
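
After installation, a quick check that the key packages can be imported avoids a failure several minutes into training. This is a sketch only: the package list below is an assumption, and the authoritative list is the project's `requirements.txt`.

```python
import importlib.util

def missing_packages(names):
    """Return the top-level package names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]

# Assumed package names for illustration; take the real ones from requirements.txt.
required = ["torch", "numpy"]
missing = missing_packages(required)
print("Missing packages:", missing if missing else "none")
```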

## II. Model Training Workflow  
### 2.1 Configuration Instructions  
**Command-Line Configuration Information**  
| Parameter Name       | Type    | Default Value   | Description                                                                 |  
|----------------------|---------|-----------------|-----------------------------------------------------------------------------|  
| `b`                  | int     | 32              | Training batch size, adjust according to GPU memory (e.g., 64 for 32GB GPU) |  
| `epochs`             | int     | 200             | Maximum number of training epochs, used with early stopping                |  
| `dataset`            | str     | `raw_davis`     | Dataset selection, supports `raw_davis`/`raw_kiba`/`benchmark_davis`, etc. |  
| `lr`                 | float   | 5e-4            | Initial learning rate, dynamically adjusted with AdamW optimizer and scheduler |  
| `model_config`       | str     | `config.json`   | Path to model configuration file, containing key parameters like embedding dimension and maximum sequence length |  
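
To make the table concrete, here is a hypothetical `argparse` sketch that mirrors the parameters and defaults listed above. It is not taken from `train_reg.py`; in particular, the flag spellings (`-b` / `--batchsize`) are assumptions, so check the script itself for the real names.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    """Build an illustrative parser whose defaults follow the table above."""
    parser = argparse.ArgumentParser(description="Illustrative TransGNN_DTA training options")
    # Flag spellings are assumptions; the defaults come from the parameter table.
    parser.add_argument("-b", "--batchsize", type=int, default=32, help="training batch size")
    parser.add_argument("--epochs", type=int, default=200, help="maximum number of epochs")
    parser.add_argument("--dataset", type=str, default="raw_davis", help="dataset name")
    parser.add_argument("--lr", type=float, default=5e-4, help="initial learning rate")
    parser.add_argument("--model_config", type=str, default="config.json", help="model config path")
    return parser

# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(["--dataset", "raw_kiba", "--batchsize", "128", "--lr", "1e-4"])
print(vars(args))
```

Options that are not passed on the command line fall back to the defaults from the table.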

**Hyperparameter Configuration File Path**: `config.json` (adjust `drug_max_seq` and `target_max_seq` according to the dataset in advance)  
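
The two sequence-length keys can be edited by hand or with a small script. The sketch below assumes that `drug_max_seq` and `target_max_seq` are top-level keys in `config.json`; the helper name `update_max_seq` and the values in the example call are placeholders, not recommended settings.

```python
import json
from pathlib import Path

def update_max_seq(config_path, drug_max_seq, target_max_seq):
    """Rewrite the two sequence-length keys in a JSON config and return the new config."""
    path = Path(config_path)
    config = json.loads(path.read_text(encoding="utf-8"))
    # Assumption: both keys live at the top level of the config file.
    config["drug_max_seq"] = int(drug_max_seq)
    config["target_max_seq"] = int(target_max_seq)
    path.write_text(json.dumps(config, indent=2), encoding="utf-8")
    return config

# Example call (placeholder values):
# update_max_seq("config.json", drug_max_seq=100, target_max_seq=1000)
```

All other keys in the file are written back unchanged.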

### 2.2 Quick Start Training (Recommended for Background Execution)  

In [None]:
# Method 1: Start with default parameters (output results to result.log)  
!python train_reg.py > result.log 2>&1  

In [None]:
# Method 2: Start with specified parameters (example: raw_kiba dataset, batch size 128, learning rate 1e-4)  
!python train_reg.py --dataset raw_kiba --batchsize 128 --lr 1e-4 > result.log 2>&1  

**Notes**:  
- A GPU is recommended for training (e.g., a 32GB vGPU); laptops may crash due to insufficient resources.  
- Use the `screen` tool for background execution:  

In [None]:
!screen -S transgnn_train  # Create a new session  
  # After starting the training command, press Ctrl+A, then D, to detach from the session  
!screen -r transgnn_train   # Resume the session 
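
`screen` needs an interactive terminal, which a notebook cell does not provide. As an alternative sketch, training can be started from Python as a separate process with its output redirected to a log file. The helper name `launch_in_background` is illustrative; `start_new_session` is POSIX-only and is intended to keep the process from sharing the notebook's session.

```python
import subprocess
import sys

def launch_in_background(command, log_path):
    """Start `command` as a separate process and send stdout/stderr to `log_path`."""
    with open(log_path, "w") as log_file:
        process = subprocess.Popen(
            command,
            stdout=log_file,
            stderr=subprocess.STDOUT,   # merge errors into the same log
            start_new_session=True,     # run in its own session on POSIX systems
        )
    # The child holds its own copy of the file descriptor, so closing ours is safe.
    return process

# Example (not executed here):
# proc = launch_in_background([sys.executable, "train_reg.py", "--dataset", "raw_kiba"], "result.log")
# print("Training PID:", proc.pid)
```

Progress can then be followed with `tail -f result.log` from a terminal.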

### 2.3 Training Monitoring and Visualization  
**Visualization Address**: http://localhost:6006 (local) or `http://<server public IP>:6006` (port 6006 must be open in the server firewall)  

In [None]:
# Start TensorBoard monitoring (default port 6006, install screen in advance)  
!screen -dmS tensorboard bash -c 'tensorboard --logdir=log --host=0.0.0.0'  
# For remote access, configure port forwarding: ssh -L 6006:localhost:6006 user@server_ip  
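
If the TensorBoard page does not load, it is useful to check whether anything is listening on the port before looking at firewall rules. The following standard-library sketch tests whether a TCP connection can be opened; the function name `is_port_open` is illustrative.

```python
import socket

def is_port_open(host: str, port: int, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and unreachable hosts.
        return False

print("TensorBoard reachable on localhost:6006:", is_port_open("localhost", 6006))
```

Run it on the server with `localhost` to confirm TensorBoard started, then from the client with the server address to confirm the port is reachable.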

## III. Model Prediction Workflow  
### 3.1 Load Pretrained Model and Configuration  

In [None]:
import torch  
import json  
from preprocess import drug_encoder, target_encoder  
from double_towers import TransGNNModel  

# Example data (drug SMILES and protein sequence)  
DRUG_EXAMPLE = "CC1=C(C=C(C=C1)NC2=NC=CC(=N2)N(C)C3=CC4=NN(C(=C4C=C3)C)C)S(=O)(=O)Cl"  
PROTEIN_EXAMPLE = "MRGARGAWDFLCVLLLLLRVQTGSSQPSVSPGEPSPPSIHPGKSDLIVRVGDEIRLLCTDP"  

# Paths to configuration file and model  
MODEL_CONFIG_PATH = "config.json"  
MODEL_PATH = "./models/DAVIS_bestCI_model_reg1.pth"  # Replace with the actual path of the trained model  

### 3.2 Model Initialization and Device Configuration  

In [None]:
# Load model configuration  
with open(MODEL_CONFIG_PATH, 'r') as f:  
    model_config = json.load(f)  
# Initialize the model  
model = TransGNNModel(model_config)  
# Check for available GPU  
use_cuda = torch.cuda.is_available()  
device = torch.device('cuda:0' if use_cuda else 'cpu')  
model = model.to(device)  

# Load trained model weights  
try:  
    checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=True)  
    # Adjust the input dimension of the first decoder layer to match the checkpoint  
    model.decoder[0] = torch.nn.Linear(checkpoint['decoder.0.weight'].shape[1], model.decoder[0].out_features).to(device)  
    model.load_state_dict(checkpoint)  
    print("Successfully loaded model state from checkpoint.")  
except Exception as e:  
    print(f"Error loading checkpoint: {e}")  
    raise  # Stop here: predictions from a model without loaded weights are meaningless  

model.eval()  

### 3.3 Data Preprocessing and Prediction

In [None]:
# Sequence encoding (returns feature vectors and masks)  
d_out, mask_d_out = drug_encoder(DRUG_EXAMPLE)  
t_out, mask_t_out = target_encoder(PROTEIN_EXAMPLE)  

# Convert to tensors and move to device  
d_tensor = torch.LongTensor(d_out).unsqueeze(0).to(device)  
mask_d_tensor = torch.LongTensor(mask_d_out).unsqueeze(0).to(device)  
t_tensor = torch.LongTensor(t_out).unsqueeze(0).to(device)  
mask_t_tensor = torch.LongTensor(mask_t_out).unsqueeze(0).to(device)  

# Perform prediction  
with torch.no_grad():  
    prediction = model(d_tensor, t_tensor, mask_d_tensor, mask_t_tensor).cpu().numpy()  

print(f"Predicted Affinity Value: {prediction[0][0]:.4f}")  # Output formatted to 4 decimal places  
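
The encoders return an index sequence together with a mask, which is why `drug_max_seq` and `target_max_seq` must suit the dataset. The project's actual encoding lives in `preprocess.py` and may differ (for example in tokenization or mask convention); the following is only a hypothetical character-level sketch of truncating or padding a sequence to a fixed length. The helper `pad_with_mask`, the vocabulary, and the short SMILES string are all illustrative.

```python
def pad_with_mask(token_ids, max_len, pad_id=0):
    """Truncate or right-pad `token_ids` to `max_len`; the mask is 1 for real tokens, 0 for padding."""
    kept = list(token_ids)[:max_len]
    n_pad = max_len - len(kept)
    padded = kept + [pad_id] * n_pad
    mask = [1] * len(kept) + [0] * n_pad
    return padded, mask

# Hypothetical vocabulary built from the example string; index 0 is reserved for padding.
smiles = "CC(=O)O"
vocab = {ch: i + 1 for i, ch in enumerate(sorted(set(smiles)))}
ids = [vocab[ch] for ch in smiles]

padded, mask = pad_with_mask(ids, max_len=10)
print(padded)
print(mask)
```

Sequences longer than the maximum length are cut off, so a limit that is too small silently discards part of the input.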

## IV. Advanced Features  
### 4.1 Early Stopping and Email Notification  
- **Early Stopping Configuration**: Set `early_stopping_patience` in `config.json` (default: 20 epochs)  
- **Email Notification**: Enable the email configuration to be notified automatically when training completes or fails. The example uses QQ Mail's SMTP server.  
  ```json  
  "email": {  
    "enabled": true,  
    "sender_email": "your_email@example.com",  
    "sender_password": "authorization_code",  
    "receiver_email": "recipient@example.com",  
    "smtp_server": "smtp.qq.com",  
    "smtp_port": 465  
  }  
  ```  
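
For readers unfamiliar with patience-based early stopping, the idea can be sketched as follows. This is not the project's implementation: the class name, the `min_delta` parameter, and the choice of a lower-is-better metric are assumptions, and the project may monitor a different metric.

```python
class EarlyStopping:
    """Signal a stop when the metric has not improved for `patience` consecutive epochs."""

    def __init__(self, patience=20, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.bad_epochs = 0

    def step(self, value):
        """Record one epoch's metric (lower is better); return True when training should stop."""
        if self.best is None or value < self.best - self.min_delta:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Simulated validation losses with a short patience so the stop is easy to follow.
stopper = EarlyStopping(patience=3)
losses = [0.90, 0.80, 0.81, 0.82, 0.83, 0.70]
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        print(f"Stopping at epoch {epoch}; best validation loss {stopper.best:.2f}")
        break
```

With a patience of 3, the run stops after three epochs without improvement, so the later, better value is never reached. A larger patience such as the default of 20 trades longer training for less risk of stopping too early.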

### 4.2 Multi-GPU Training Support  
To train on multiple GPUs in a single machine, wrap the model in `train_reg.py` after it has been created:  
```python  
# Wrap the model with DataParallel when more than one GPU is visible (example)  
if torch.cuda.device_count() > 1:  
    model = torch.nn.DataParallel(model)  
```  

## V. Citation Suggestions  
If you use this model in your research, please cite the related paper (see the project README for citation details).  