# 3DCNN Training Demo

### Training Dataloader

Here we use the training dataset from the dataset folder.

In [4]:
from torch.utils.data import DataLoader
from data_reader import LigandDataset

train_batch_size = 10
path = "datasets\postera_protease2_pos_neg_train.hdf5"
train_data = LigandDataset(path,parse_features=False)
train_dataloader = DataLoader(train_data, batch_size=train_batch_size,shuffle=False)
print(train_batch_size)

10


### Validation Dataloader

In [5]:
val_batch_size = 10 
path = "datasets\postera_protease2_pos_neg_val.hdf5"
val_data = LigandDataset(path,parse_features=False)
val_dataloader = DataLoader(val_data, batch_size=val_batch_size,shuffle=False)

### Instantiate models

Set model instances for use with a cuda device.

In [7]:
from model import Model_3DCNN 
from voxelizer import Voxelizer3D
from gaussian_filter import GaussianFilter

use_cuda = True

if use_cuda:
    device = "cuda"
else:
    device = "cpu"

voxelizer = Voxelizer3D(use_cuda=use_cuda,verbose=0)
gaussian_filter = GaussianFilter(dim=3, channels=19, kernel_size=11, sigma=1, use_cuda=use_cuda)
model = Model_3DCNN(use_cuda=use_cuda, num_classes=2)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


### Set Loss Function and Optimizer

In [8]:
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

learning_rate = 1e-4
loss_fn = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08)

### Train!

We still need to add the model checkpoints to save the model after each epoch.

In [10]:
from main_train_validate import train, validate

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    
    
    losses = train(
    train_dataloader, 
    voxelizer,
    gaussian_filter,
    model,
    loss_fn, 
    optimizer, 
    device
    )


    avg_loss, accuracy = validate(
    val_dataloader, 
    voxelizer,
    gaussian_filter,
    model,
    loss_fn, 
    optimizer, 
    device
    )


print("Done!")

Epoch 1
-------------------------------
loss: 6.916627 [    0/19533]
loss: 1.939380 [ 1000/19533]
loss: 2.990100 [ 2000/19533]
loss: 0.921849 [ 3000/19533]
loss: 3.471420 [ 4000/19533]
loss: 3.051025 [ 5000/19533]
