## CSC413 - research project
---

In [None]:
# Mount Google Drive at /content/drive so the project folder stored there is
# reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')
# IPython magic (notebook only): make the project folder the working
# directory, so the relative dataset paths used in later cells resolve there.
%cd /content/drive/MyDrive/DeepPROTACs

In [None]:
# Shell escapes (notebook only): install the packages this notebook relies on.
# torch is imported directly below; torch_geometric and rdkit are presumably
# needed by the project modules imported below - TODO confirm.
! pip install torch
! pip install torch_geometric
! pip install rdkit

In [None]:
import sys
import numpy as np
import torch
import os
import pickle
import logging
from pathlib import  Path
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from protacloader import PROTACSet, collater
from model_Runshi import GraphConv, SmilesNet, ProtacModel
from train_and_test4 import train
from prepare_data import GraphData

In [None]:
# Hyper-parameters and run name read by main_small() and main_large().
BATCH_SIZE = 1
EPOCH = 200
TRAIN_RATE = 0.8  # fraction of the samples used for training
LEARNING_RATE = 0.0005
WEIGHT_DECAY = 0.0001
TRAIN_NAME = "test_test"  # names the log file and the TensorBoard run directory

# logging.basicConfig(filename=...) opens the log file immediately, so the
# directory has to exist before the call; otherwise a fresh Drive folder makes
# this cell fail with FileNotFoundError.
LOG_DIR = Path("/content/drive/MyDrive/DeepPROTACs/log")
LOG_DIR.mkdir(exist_ok=True)

# force=True removes and closes any handlers already attached to the root
# logger, so re-running this cell does not stack duplicate handlers and no
# manual removal loop is needed.
logging.basicConfig(
    filename=str(LOG_DIR / (TRAIN_NAME + ".log")),
    filemode="a",
    level=logging.DEBUG,
    force=True,
)
# NOTE(review): this sets the level of a logger literally named 'RootLogger',
# not of the root logger (whose level basicConfig already set above). Kept
# as-is in case a project module looks this logger up by name - TODO confirm.
logging.getLogger('RootLogger').setLevel(logging.DEBUG)

In [None]:
# Make sure the output folders exist; exist_ok=True keeps re-runs of this
# cell harmless.
for required_dir in (
    '/content/drive/MyDrive/DeepPROTACs/log',
    '/content/drive/MyDrive/DeepPROTACs/model',
):
    Path(required_dir).mkdir(exist_ok=True)

In [None]:
def main_small():
  """Train a ProtacModel on the data stored under ``small_dataset``.

  Loads the four GraphData sets, the SMILES pickle, the name list and the
  label file, wraps them in one PROTACSet and splits it by position: the
  first ``TRAIN_RATE`` fraction of the samples is used for training and the
  remainder for validation. Training is delegated to ``train`` with the
  module-level hyper-parameters; TensorBoard events are written to
  ``runs/<TRAIN_NAME>``.

  Side effects: logs the split sizes through the root logger and, after a
  successful run, detaches every handler from the root logger.
  """
  root = "small_dataset/data"
  ligase_ligand = GraphData("ligase_ligand", root)
  ligase_pocket = GraphData("ligase_pocket", root)
  target_ligand = GraphData("target_ligand", root)
  target_pocket = GraphData("target_pocket", root)
  # NOTE: pickle.load executes arbitrary code from the file; only use it on
  # dataset files you produced or trust.
  with open(os.path.join(target_pocket.processed_dir, "smiles.pkl"),"rb") as f:
    smiles = pickle.load(f)
  with open('small_dataset/name.pkl','rb') as f:
    name_list = pickle.load(f)
  label = torch.load(os.path.join(target_pocket.processed_dir, "label.pt"))

  protac_set = PROTACSet(
      name_list,
      ligase_ligand,
      ligase_pocket,
      target_ligand,
      target_pocket,
      smiles,
      label,
  )

  # Positional split: leading samples train, trailing samples validate.
  data_size = len(protac_set)
  train_size = int(data_size * TRAIN_RATE)
  test_size = data_size - train_size
  logging.info(f"all data: {data_size}")
  logging.info(f"train data: {train_size}")
  logging.info(f"test data: {test_size}")
  train_dataset = torch.utils.data.Subset(protac_set, range(train_size))
  test_dataset = torch.utils.data.Subset(protac_set, range(train_size, data_size))
  trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=collater,drop_last=False, shuffle=True)
  testloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=collater,drop_last=False)

  # One graph encoder per input graph plus the SMILES encoder.
  model = ProtacModel(
      GraphConv(num_embeddings=10),  # ligase ligand
      GraphConv(num_embeddings=5),   # ligase pocket
      GraphConv(num_embeddings=10),  # target ligand
      GraphConv(num_embeddings=5),   # target pocket
      SmilesNet(batch_size=BATCH_SIZE),
  )
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  writer = SummaryWriter(f'runs/{TRAIN_NAME}')
  try:
    train(
        model,
        train_loader=trainloader,
        valid_loader=testloader,
        device=device,
        writer=writer,
        LOSS_NAME=TRAIN_NAME,
        batch_size=BATCH_SIZE,
        epoch=EPOCH,
        lr=LEARNING_RATE,
        weight_decay = WEIGHT_DECAY
    )
  finally:
    # Flush and release the TensorBoard event file even when training is
    # interrupted (e.g. by stopping the cell), so logged events are not lost.
    writer.close()
  for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

def main_large(second_half_start=949):
  """Train a ProtacModel on the data stored under ``large_dataset``.

  Every per-sample sequence (the four GraphData sets, SMILES, names, labels)
  is treated as two consecutive segments, the second one starting at index
  ``second_half_start``. The first ``train_size`` items of each segment go to
  the training set and the rest of each segment to the validation set, where
  ``train_size = int(len(name_list) / 2 * TRAIN_RATE)``.

  Training is delegated to ``train`` with the module-level hyper-parameters;
  TensorBoard events are written to ``runs/<TRAIN_NAME>``.

  Args:
    second_half_start: index at which the second segment begins. Defaults to
      949, the value previously hard-coded here (presumably half the size of
      the shipped dataset, i.e. the boundary between two groups of samples -
      TODO confirm).

  Side effects: logs the split sizes through the root logger and, after a
  successful run, detaches every handler from the root logger.
  """
  root = "large_dataset/data"
  # NOTE: pickle.load executes arbitrary code from the file; only use it on
  # dataset files you produced or trust.
  with open('large_dataset/name.pkl','rb') as f:
    name_list = pickle.load(f)
  # Per-segment training count; float division then truncation, as before.
  train_size = int(len(name_list) / 2 * TRAIN_RATE)

  def split(items):
    """Return the (train, test) parts of ``items``, taken segment by segment."""
    cut = second_half_start
    train_part = items[:train_size] + items[cut:cut + train_size]
    test_part = items[train_size:cut] + items[cut + train_size:]
    return train_part, test_part

  # Each GraphData set is built once and sliced, instead of being rebuilt for
  # every slice. What `+` yields for two slices depends on GraphData's class.
  train_ligase_ligand, test_ligase_ligand = split(GraphData("ligase_ligand", root))
  train_ligase_pocket, test_ligase_pocket = split(GraphData("ligase_pocket", root))
  train_target_ligand, test_target_ligand = split(GraphData("target_ligand", root))
  train_target_pocket, test_target_pocket = split(GraphData("target_pocket", root))

  with open(root+"/processed/smiles.pkl","rb") as f:
    smiles = pickle.load(f)
  train_smiles, test_smiles = split(smiles)

  train_name, test_name = split(name_list)

  label = torch.load(root+"/processed/label.pt")
  # NOTE(review): `+` concatenates lists but adds tensors element-wise; this
  # split is only correct if label.pt holds a list - TODO confirm.
  train_label, test_label = split(label)

  train_set = PROTACSet(
      train_name,
      train_ligase_ligand,
      train_ligase_pocket,
      train_target_ligand,
      train_target_pocket,
      train_smiles,
      train_label,
  )
  valid_set = PROTACSet(
      test_name,
      test_ligase_ligand,
      test_ligase_pocket,
      test_target_ligand,
      test_target_pocket,
      test_smiles,
      test_label,
  )

  data_size = len(train_set) + len(valid_set)
  train_size = len(train_set)
  test_size = len(valid_set)
  logging.info(f"all data: {data_size}")
  logging.info(f"train data: {train_size}")
  logging.info(f"test data: {test_size}")
  trainloader = DataLoader(train_set, batch_size=BATCH_SIZE, collate_fn=collater,drop_last=False, shuffle=True)
  # NOTE(review): the validation loader is shuffled here, unlike in
  # main_small(); kept unchanged - confirm whether that is intended.
  testloader = DataLoader(valid_set, batch_size=BATCH_SIZE, collate_fn=collater,drop_last=False, shuffle=True)

  # One graph encoder per input graph plus the SMILES encoder.
  model = ProtacModel(
      GraphConv(num_embeddings=10),  # ligase ligand
      GraphConv(num_embeddings=5),   # ligase pocket
      GraphConv(num_embeddings=10),  # target ligand
      GraphConv(num_embeddings=5),   # target pocket
      SmilesNet(batch_size=BATCH_SIZE),
  )
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  writer = SummaryWriter(f'runs/{TRAIN_NAME}')
  try:
    train(
        model,
        train_loader=trainloader,
        valid_loader=testloader,
        device=device,
        writer=writer,
        LOSS_NAME=TRAIN_NAME,
        batch_size=BATCH_SIZE,
        epoch=EPOCH,
        lr=LEARNING_RATE,
        weight_decay = WEIGHT_DECAY
    )
  finally:
    # Flush and release the TensorBoard event file even when training is
    # interrupted (e.g. by stopping the cell), so logged events are not lost.
    writer.close()
  for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

In [None]:
# Run exactly one experiment: leave the call you want uncommented and
# comment out the other one.
#main_small()
main_large()



KeyboardInterrupt: ignored