In [144]:
import numpy as np
import pandas as pd

In [146]:
df = pd.read_csv("LP_PDBBind.csv")

In [148]:
df.head()

Unnamed: 0.1,Unnamed: 0,header,smiles,category,seq,resolution,date,type,new_split,CL1,CL2,CL3,remove_for_balancing_val,kd/ki,value,covalent
0,6r8o,isomerase,CSc1ccccc1[C@H]1CCCN1C(=O)CNC(=O)NCc1ccc2c(c1)...,refined,GNPLVYLDVDANGKPLGRVVLELKADVVPKTAENFRALCTGEKGFG...,1.36,2019-11-27,isomerase,test,True,True,True,False,Kd=0.006uM,8.22,False
1,3fh7,hydrolase/hydrolase inhibitor,O=C([O-])CCC[N@H+]1CCC[C@H]1COc1ccc(Oc2ccc(Cl)...,refined,VDTCSLASPASVCRTKHLHLRCSVDFTRRTLTGTAALTVQSQEDNL...,2.05,2010-01-05,hydrolase,test,True,True,True,False,Kd=25nM,7.6,False
2,4b7r,hydrolase,CCC(CC)O[C@@H]1C[C@H](C(=O)[O-])C[C@H]([NH3+])...,refined,VKLAGNSSLCPVSGWAIYSKDNSVRIGSKGDVFVIREPFISCSPLE...,1.9,2012-10-03,hydrolase,,True,True,True,False,Ki=0.23nM,9.64,False
3,3qfd,immune system,CC[C@H](C)[C@H](NC(=O)CNC(=O)[C@H](C)NC(=O)[C@...,refined,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,1.68,2011-09-28,other,train,False,False,False,False,Kd=68uM,4.17,False
4,3fvn,membrane protein,[NH3+][C@@H](C[C@]1(C(=O)[O-])C[C@H]2OCC[C@@H]...,refined,ANRTLIVTTILEEPYVMYRKSDKPLYGNDRFEGYCLDLLKELSNIL...,1.5,2010-01-19,membrane,val,True,True,True,False,Ki=169nM,6.77,False


In [152]:
min(df['value']), max(df['value'])

(0.4, 15.22)

In [156]:
len(np.unique(df['Unnamed: 0']))

19443

In [154]:
len(df)

19443

In [46]:
train_df = df[['smiles', 'value']]

In [50]:
train_df.head()

Unnamed: 0,smiles,value
0,CSc1ccccc1[C@H]1CCCN1C(=O)CNC(=O)NCc1ccc2c(c1)...,8.22
1,O=C([O-])CCC[N@H+]1CCC[C@H]1COc1ccc(Oc2ccc(Cl)...,7.6
2,CCC(CC)O[C@@H]1C[C@H](C(=O)[O-])C[C@H]([NH3+])...,9.64
3,CC[C@H](C)[C@H](NC(=O)CNC(=O)[C@H](C)NC(=O)[C@...,4.17
4,[NH3+][C@@H](C[C@]1(C(=O)[O-])C[C@H]2OCC[C@@H]...,6.77


In [56]:
sum(train_df['smiles'].isnull())

322

In [58]:
sum(train_df['value'].isnull())

0

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class ProteinEncoder3DCNN(nn.Module):
    def __init__(self, in_channels=4, base_filters=16, embedding_dim=128):
        super(ProteinEncoder3DCNN, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, base_filters, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(base_filters)
        self.conv2 = nn.Conv3d(base_filters, base_filters*2, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(base_filters*2)
        self.conv3 = nn.Conv3d(base_filters*2, base_filters*4, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm3d(base_filters*4)
        self.pool = nn.MaxPool3d(2)
        self.fc = nn.Linear((base_filters*4) * 4 * 4 * 4, embedding_dim)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x 

class LigandGNN(nn.Module):
    def __init__(self, in_dim, hidden_dim=64, embedding_dim=128, num_layers=3):
        super(LigandGNN, self).__init__()
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(in_dim, hidden_dim))
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
        self.convs.append(GCNConv(hidden_dim, embedding_dim))
        self.pool = global_mean_pool

    def forward(self, x, edge_index, batch):
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
        x = self.convs[-1](x, edge_index)
        x = self.pool(x, batch)
        return x  

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        mean = self.mean_head(x)
        log_std = self.log_std_head(x).clamp(-20, 2)
        std = log_std.exp()
        return mean, std

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.q_head = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        q = self.q_head(x)
        return q

class DockingAgent(nn.Module):
    def __init__(self, prot_channels, ligand_node_dim, action_dim,
                 prot_embed=128, ligand_embed=128, hidden_dim=256):
        super(DockingAgent, self).__init__()
        self.protein_encoder = ProteinEncoder3DCNN(in_channels=prot_channels,
                                                  embedding_dim=prot_embed)
        self.ligand_encoder = LigandGNN(in_dim=ligand_node_dim,
                                        embedding_dim=ligand_embed)
        state_dim = prot_embed + ligand_embed + action_dim  
        self.actor = Actor(state_dim, action_dim, hidden_dim)
        self.critic = Critic(state_dim, action_dim, hidden_dim)

    def encode(self, prot_tensor, ligand_data, pose):
        h_p = self.protein_encoder(prot_tensor)
        h_l = self.ligand_encoder(ligand_data.x, ligand_data.edge_index, ligand_data.batch)
        state = torch.cat([h_p, h_l, pose], dim=-1)
        return state

    def act(self, state):
        mean, std = self.actor(state)
        dist = torch.distributions.Normal(mean, std)
        action = dist.rsample()
        return action.clamp(-1, 1), dist.log_prob(action).sum(-1)

    def evaluate(self, prot_tensor, ligand_data, pose, action):
        state = self.encode(prot_tensor, ligand_data, pose)
        q_value = self.critic(state, action)
        return q_value

In [48]:
agent = DockingAgent(
    prot_channels=4,
    ligand_node_dim=30,
    action_dim=3 
)
agent

DockingAgent(
  (protein_encoder): ProteinEncoder3DCNN(
    (conv1): Conv3d(4, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn2): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn3): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pool): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (fc): Linear(in_features=4096, out_features=128, bias=True)
  )
  (ligand_encoder): LigandGNN(
    (convs): ModuleList(
      (0): GCNConv(30, 64)
      (1): GCNConv(64, 64)
      (2): GCNConv(64, 128)
    )
  )
  (actor): Actor(
    (fc1): Linear(in_features=259, out_features=256, bias=True)
    (fc2): Linear(in_fe

In [54]:
import torch
import torch.optim as optim
from collections import deque
import random

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Replay buffer (очень упрощённый, без приоритетов)
class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*batch)
        return (torch.stack(state),
                torch.stack(action),
                torch.tensor(reward, dtype=torch.float32, device=device),
                torch.stack(next_state),
                torch.tensor(done, dtype=torch.float32, device=device))
    
    def __len__(self):
        return len(self.buffer)

# Гиперпараметры
batch_size = 64
gamma = 0.99
lr = 3e-4
num_training_steps = 1000

# Инициализация агента и оптимизаторов
agent = DockingAgent(prot_channels=4, ligand_node_dim=30, action_dim=3).to(device)
optimizer_actor = optim.Adam(agent.actor.parameters(), lr=lr)
optimizer_critic = optim.Adam(agent.critic.parameters(), lr=lr)

replay_buffer = ReplayBuffer()

def env_step(action):
    # TODO
    # Заглушка — возвращаем случайные данные
    prot_tensor = torch.randn(1, 4, 32, 32, 32).to(device)
    ligand_data = type('LigandData', (), {})()
    ligand_data.x = torch.randn(100, 30).to(device)
    ligand_data.edge_index = torch.randint(0, 100, (2, 300)).to(device)
    ligand_data.batch = torch.zeros(100, dtype=torch.long).to(device)
    pose = torch.randn(1, 3).to(device)
    reward = random.random()
    done = random.random() < 0.05
    return prot_tensor, ligand_data, pose, reward, done

# Основной цикл обучения
for step in range(num_training_steps):
    # В реальном случае: получаем состояние из окружения
    prot_tensor, ligand_data, pose, reward, done = env_step(None)
    
    # Кодируем состояние
    state = agent.encode(prot_tensor, ligand_data, pose)
    
    # Выбираем действие по текущей политике
    action, log_prob = agent.act(state)
    
    # Выполняем действие, получаем новое состояние и награду
    prot_tensor_next, ligand_data_next, pose_next, reward, done = env_step(action)
    state_next = agent.encode(prot_tensor_next, ligand_data_next, pose_next)
    
    # Сохраняем опыт в replay buffer
    replay_buffer.push(state.detach(), action.detach(), reward, state_next.detach(), done)
    
    if len(replay_buffer) < batch_size:
        continue  # Ждём, пока соберётся достаточно данных
    
    # Выбираем случайный батч
    state_batch, action_batch, reward_batch, next_state_batch, done_batch = replay_buffer.sample(batch_size)
    
    # Критик: вычисляем целевые Q-значения
    with torch.no_grad():
        next_action, next_log_prob = agent.act(next_state_batch)
        target_q = agent.critic(next_state_batch, next_action)
        target_value = reward_batch.unsqueeze(1) + gamma * (1 - done_batch.unsqueeze(1)) * (target_q - next_log_prob.unsqueeze(1))
    
    # Q-функция для текущих состояний и действий
    current_q = agent.critic(state_batch, action_batch)
    
    # Потеря критика (MSE)
    loss_critic = F.mse_loss(current_q, target_value)
    
    optimizer_critic.zero_grad()
    loss_critic.backward()
    optimizer_critic.step()
    
    # Актор: максимизируем ожидаемый Q с учетом энтропии (SAC)
    mean, std = agent.actor(state_batch)
    dist = torch.distributions.Normal(mean, std)
    sampled_action = dist.rsample()
    log_prob_action = dist.log_prob(sampled_action).sum(-1, keepdim=True)
    
    q_val = agent.critic(state_batch, sampled_action)
    
    loss_actor = (log_prob_action - q_val).mean()  # Минимизируем -(Q - log_prob)
    
    optimizer_actor.zero_grad()
    loss_actor.backward()
    optimizer_actor.step()
    
    if step % 100 == 0:

        print(f'Step {step}: loss_critic={loss_critic.item():.4f}, loss_actor={loss_actor.item():.4f}')


  loss_critic = F.mse_loss(current_q, target_value)


Step 100: loss_critic=118.5984, loss_actor=-27.4303
Step 200: loss_critic=2586.0942, loss_actor=-1553.3281
Step 300: loss_critic=1403.2402, loss_actor=-7675.3965
Step 400: loss_critic=1306.0812, loss_actor=-19304.5605
Step 500: loss_critic=529.2733, loss_actor=-45399.3828
Step 600: loss_critic=926.5152, loss_actor=-85250.5234
Step 700: loss_critic=1034.1207, loss_actor=-123226.5391
Step 800: loss_critic=898.5427, loss_actor=-235127.5781
Step 900: loss_critic=1082.6539, loss_actor=-312994.0000
