In [53]:
import json
import numpy as np
import neural_network as nn
import torch
from torch.utils.data import DataLoader, TensorDataset

In [7]:
with open('drug_smiles_tfidf_vectors.json', 'r') as file:
    data = json.load(file)

In [9]:
#data[2]['tfidf_vector']
train_data = [data[i]['tfidf_vector'] for i in range(1, len(data))] #drug tfidf representations as dummy data

In [83]:
batch_size = 16
shuffle = True

X_train_tensor = torch.FloatTensor(train_data)
Y_train_tensor = torch.FloatTensor([np.random.rand() for i in train_data]) # random true values
train_loader = DataLoader(TensorDataset(X_train_tensor, Y_train_tensor), batch_size=batch_size, shuffle=shuffle)

In [85]:
# Create model
model = nn.ModularFCNN(
    input_size=2048,           # Number of input features
    hidden_layers=[512, 256, 128, 64],  # Hidden layer sizes
    dropout_rate=0.3,          
    activation='relu',         
    batch_norm=True            
)

In [87]:
# Create trainer
trainer = nn.NeuralNetworkTrainer(
    model=model,
    optimizer_name='adam',
    loss_fn_name='mse',
    learning_rate=0.001
)

In [91]:
# Train
trainer.fit(
    train_loader=train_loader,
    epochs=1000,
    verbose=True
)

Epoch 10/1000 - Train Loss: 0.1080
Epoch 20/1000 - Train Loss: 0.0965
Epoch 30/1000 - Train Loss: 0.0913
Epoch 40/1000 - Train Loss: 0.0885
Epoch 50/1000 - Train Loss: 0.0873
Epoch 60/1000 - Train Loss: 0.0829
Epoch 70/1000 - Train Loss: 0.0843
Epoch 80/1000 - Train Loss: 0.0833
Epoch 90/1000 - Train Loss: 0.0818
Epoch 100/1000 - Train Loss: 0.0818
Epoch 110/1000 - Train Loss: 0.0834
Epoch 120/1000 - Train Loss: 0.0825
Epoch 130/1000 - Train Loss: 0.0837
Epoch 140/1000 - Train Loss: 0.0819
Epoch 150/1000 - Train Loss: 0.0816
Epoch 160/1000 - Train Loss: 0.0809
Epoch 170/1000 - Train Loss: 0.0815
Epoch 180/1000 - Train Loss: 0.0836
Epoch 190/1000 - Train Loss: 0.0811
Epoch 200/1000 - Train Loss: 0.0827


KeyboardInterrupt: 

In [97]:
# Predict
predictions = trainer.predict(X_train_tensor)

In [99]:
predictions

array([[0.49191466],
       [0.48773098],
       [0.4878771 ],
       [0.46975085],
       [0.4816925 ],
       [0.46914953],
       [0.4826815 ],
       [0.48597533],
       [0.4693111 ],
       [0.48563135],
       [0.4790277 ],
       [0.49795258],
       [0.4925807 ],
       [0.49344778],
       [0.48497778],
       [0.50619054],
       [0.48560232],
       [0.49279332],
       [0.4867705 ],
       [0.47101355],
       [0.45739925],
       [0.4984839 ],
       [0.478286  ],
       [0.47848138],
       [0.4466008 ],
       [0.47837734],
       [0.4915256 ],
       [0.4876163 ],
       [0.47153813],
       [0.48255107],
       [0.48599628],
       [0.48688743],
       [0.4841503 ],
       [0.49440444],
       [0.4881089 ],
       [0.4772086 ],
       [0.4983378 ],
       [0.48036712],
       [0.48847434],
       [0.4745007 ],
       [0.4993873 ],
       [0.47155353],
       [0.48179883],
       [0.4953664 ],
       [0.47040984],
       [0.4878303 ],
       [0.497089  ],
       [0.501