# 3DCNN Training Demo

### Training Dataloader

Load the training split (an HDF5 file) from the `datasets` folder and wrap it in a `DataLoader`.

In [26]:
import torch

from torch.utils.data import DataLoader
from data_reader import LigandDataset

# Training split.
# Forward slashes keep the relative path portable (Windows and Linux) and avoid
# the invalid "\p" escape sequence that a backslash creates in a plain string.
train_batch_size = 64
path = "datasets/postera_protease2_pos_neg_train.hdf5"
train_data = LigandDataset(path, parse_features=False)

# Shuffle the training set each epoch so batches are not drawn in file order.
# NOTE(review): the "pos_neg" file name suggests positives and negatives may be
# stored in contiguous runs -- confirm; unshuffled, that gives correlated batches.
# A seeded generator keeps the shuffle order reproducible across runs.
TRAIN_SHUFFLE_SEED = 0
train_loader_generator = torch.Generator().manual_seed(TRAIN_SHUFFLE_SEED)
train_dataloader = DataLoader(
    train_data,
    batch_size=train_batch_size,
    shuffle=True,
    generator=train_loader_generator,
)
print(train_batch_size)

64


### Validation Dataloader

In [28]:
# Validation split. Order does not matter for evaluation, so no shuffling.
# Forward slashes keep the relative path portable and avoid the invalid "\p"
# escape sequence that a backslash creates in a plain string.
val_batch_size = 64
path = "datasets/postera_protease2_pos_neg_val.hdf5"
val_data = LigandDataset(path, parse_features=False)
val_dataloader = DataLoader(val_data, batch_size=val_batch_size, shuffle=False)

### Instantiate models

Create the voxelizer, the Gaussian filter, and the 3D CNN model. The `use_cuda` flag selects whether they run on a CUDA GPU or on the CPU.

In [29]:
import torch

from model import Model_3DCNN
from voxelizer import Voxelizer3D
from gaussian_filter import GaussianFilter

# Use the GPU only when PyTorch can actually see one. A hard-coded True would
# likely fail on CPU-only machines. Replace with `use_cuda = False` to force CPU.
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

voxelizer = Voxelizer3D(use_cuda=use_cuda, verbose=0)
# NOTE(review): channels=19 presumably has to match the number of feature
# channels produced by the voxelizer -- confirm against Voxelizer3D.
gaussian_filter = GaussianFilter(dim=3, channels=19, kernel_size=11, sigma=1, use_cuda=use_cuda)
model = Model_3DCNN(use_cuda=use_cuda, num_classes=2)

### Set Loss Function and Optimizer

In [30]:
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, RMSprop, lr_scheduler

# --- Hyperparameters --------------------------------------------------------
learning_rate = 7e-4  # initial RMSprop learning rate
decay_iter = 100      # StepLR step_size: scheduler.step() calls between decays
decay_rate = 0.95     # StepLR gamma: LR is multiplied by this at each decay

# Classification loss; CrossEntropyLoss takes unnormalized scores plus class indices.
loss_fn = CrossEntropyLoss()

# Optimizer. An alternative setup (currently unused) is:
#   Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08)
optimizer = RMSprop(model.parameters(), lr=learning_rate)

# Step-decay learning-rate schedule wrapped around the optimizer.
scheduler = lr_scheduler.StepLR(
    optimizer,
    step_size=decay_iter,
    gamma=decay_rate,
)

### Train!

After each epoch, the model and optimizer state are saved as a checkpoint file in the `models` folder.

In [31]:
import os

from main_train_validate import train, validate

epochs = 5
checkpoint_dir = "models"
# torch.save does not create missing directories, so make sure it exists.
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")

    # One pass over the training set; `losses` is indexed below, so it is
    # presumably a sequence of per-batch/per-log-step losses -- TODO confirm.
    losses = train(
        train_dataloader,
        voxelizer,
        gaussian_filter,
        model,
        loss_fn,
        optimizer,
        device,
    )

    # Evaluate on the validation split.
    avg_loss, accuracy = validate(
        val_dataloader,
        voxelizer,
        gaussian_filter,
        model,
        loss_fn,
        optimizer,
        device,
    )

    # `train` is never handed the scheduler, so it has to be stepped here or the
    # StepLR decay is never applied. It is stepped once per epoch, after the
    # optimizer updates, so StepLR's step_size counts epochs.
    # NOTE(review): with decay_iter=100 and epochs=5 the LR will not decay yet --
    # lower step_size if per-epoch decay is intended.
    scheduler.step()

    checkpoint_dict = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        # Guard against an empty loss record instead of raising IndexError.
        "loss": losses[-1] if len(losses) > 0 else None,
        "val_loss": avg_loss,
        "val_accuracy": accuracy,
        "epoch": epoch + 1,
    }
    # os.path.join keeps the path portable instead of a Windows-only backslash.
    model_path = os.path.join(checkpoint_dir, f"3DCNN_model_checkpoint{epoch+1}.pth")
    torch.save(checkpoint_dict, model_path)

print("Done!")

Epoch 1
-------------------------------
loss: 6.870573 [    0/19533]


KeyboardInterrupt: 