In [1]:
import random
import torch
import numpy as np

In [2]:
# Seed every RNG the notebook touches so Restart Kernel -> Run All reproduces
# the same data shuffles and weight initialisations.
SEED = 0

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Explicit CUDA seed for the current device; PyTorch silently ignores this
# call when CUDA is unavailable, so it is safe on CPU-only machines.
torch.cuda.manual_seed(SEED)

In [5]:
from esol_utils import load_esol_data

# Unpack the train/valid/test feature and target splits plus the fitted
# `scaler` (what it was fit on is not visible here — confirm in esol_utils).
X_train, X_valid, X_test, y_train, y_valid, y_test, scaler = load_esol_data()

Number of generated molecular descriptors: 217
Number of molecular descriptors without invalid values: 217


In [10]:
import os
import torch
import wandb
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import WandbLogger

In [11]:
class NeuralNetwork(pl.LightningModule):
    """Two-hidden-layer MLP regressor that owns its train/valid/test datasets.

    The network maps an ``input_dim``-wide feature vector to one scalar
    prediction and is trained with mean-squared-error loss using Adam. The
    datasets are passed in directly and wrapped in ``DataLoader``s by the
    ``*_dataloader`` hooks, so the module can be handed to a Lightning
    ``Trainer`` without a separate datamodule.

    Args:
        input_dim: Number of input features per sample.
        hdden_dim: Width of both hidden layers. (The misspelling is kept so
            existing keyword callers continue to work.)
        train_data: Dataset served by ``train_dataloader`` (shuffled).
        val_data: Dataset served by ``val_dataloader``.
        test_data: Dataset served by ``test_dataloader``.
        batch_size: Mini-batch size for all three loaders. NOTE(review): the
            default 254 may have been meant as 256 — confirm before relying on it.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, input_dim, hdden_dim, train_data, val_data, test_data, batch_size=254, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.batch_size = batch_size

        self.model = nn.Sequential(
            nn.Linear(input_dim, hdden_dim),
            nn.ReLU(),
            nn.Linear(hdden_dim, hdden_dim),
            nn.ReLU(),
            nn.Linear(hdden_dim, 1),
        )

    def forward(self, x):
        """Return one prediction per sample as a flat 1-D tensor."""
        return self.model(x).flatten()

    def _mse(self, batch):
        """Mean-squared error between flattened predictions and targets.

        Predictions come from ``forward`` (already flat), and targets are
        flattened too. The loss is then element-wise whether targets arrive as
        (batch,) or (batch, 1). Feeding the raw (batch, 1) network output
        against (batch,) targets would broadcast to a (batch, batch) matrix
        and silently compute the wrong loss.
        """
        x, y = batch
        preds = self(x)
        # Match dtypes in case targets come in as float64 (e.g. from numpy);
        # this is a no-op when they already match.
        target = y.flatten().type_as(preds)
        return F.mse_loss(preds, target)

    def training_step(self, batch, batch_idx):
        """Log and return the training MSE for one batch."""
        loss = self._mse(batch)
        self.log("Train loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation MSE for one batch."""
        loss = self._mse(batch)
        self.log("Valid MSE", loss)

    def test_step(self, batch, batch_idx):
        """Log the test MSE for one batch."""
        loss = self._mse(batch)
        self.log("Test MSE", loss)

    def configure_optimizers(self):
        """Adam over all module parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.lr
        )
        return optimizer

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Loader over the validation dataset, in fixed order."""
        # Was `self.valid_data`, an attribute that is never set -> AttributeError.
        return DataLoader(self.val_data, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        """Loader over the test dataset, in fixed order."""
        return DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False)