In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes

import torch
from torch import nn, optim, utils, Tensor
from torch.utils.data import Dataset, DataLoader, random_split
import lightning as l

In [4]:
class ANN(nn.Module):
    """Small fully connected regressor: in_features -> hidden1 -> hidden2 -> 1.

    Each hidden layer is followed by a ReLU and dropout. The defaults reproduce
    the original fixed architecture (10 -> 64 -> 128 -> 1, dropout 0.2); 10 is
    the number of feature columns returned by scikit-learn's ``load_diabetes``.

    Args:
        in_features: Number of input features per sample.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        dropout: Dropout probability applied after each hidden activation.
    """

    def __init__(self, in_features=10, hidden1=64, hidden2=128, dropout=0.2):
        super().__init__()
        # Attribute names are unchanged so existing state_dict keys stay valid.
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map a float tensor of shape (..., in_features) to predictions of shape (..., 1)."""
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.dropout(self.relu(self.fc2(x)))
        return self.fc3(x)
    
class DiabetesDataset(Dataset):
    """Map-style dataset over a scikit-learn style bunch exposing ``.data`` / ``.target``.

    Features and targets are converted once to float32 so that the default
    DataLoader collate yields float32 tensors, matching the float32 weights of
    ``nn.Linear``; float64 inputs (what ``load_diabetes`` returns) would fail
    with a dtype-mismatch error. A 1-D target is reshaped to (n_samples, 1) so
    a collated target batch is (batch, 1) and lines up with a (batch, 1) model
    output instead of being broadcast by ``MSELoss``.

    Args:
        data: Object with array-like ``data`` (one row per sample) and
            ``target`` attributes, e.g. the result of ``load_diabetes()``.

    Raises:
        ValueError: If ``data.data`` and ``data.target`` have different lengths.
    """

    def __init__(self, data):
        self.data = data
        self.X = np.asarray(data.data, dtype=np.float32)
        y = np.asarray(data.target, dtype=np.float32)
        if len(self.X) != len(y):
            raise ValueError(
                f"data has {len(self.X)} rows but target has {len(y)} rows"
            )
        # (n,) -> (n, 1); targets that are already 2-D keep their shape.
        self.y = y[:, None] if y.ndim == 1 else y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return the (features, target) float32 arrays for sample ``idx``."""
        return self.X[idx], self.y[idx]
    
class DiabetesDataModule(l.LightningDataModule):
    """LightningDataModule serving a reproducible train/val/test split of the diabetes data.

    Args:
        batch_size: Batch size used by all three dataloaders.
        lengths: Split sizes (or fractions summing to 1) passed to
            ``random_split``; the default 300/100/42 covers all 442 rows of
            ``load_diabetes()``.
        seed: Seed for the split's random generator, so the split is the same
            on every run and every time ``setup`` is called.
    """

    def __init__(self, batch_size=32, lengths=(300, 100, 42), seed=42):
        super().__init__()
        self.batch_size = batch_size
        self.lengths = list(lengths)
        self.seed = seed
        self.data = None
        self.dataset = None
        self.train = self.val = self.test = None

    def prepare_data(self):
        """No download step is needed: ``load_diabetes`` ships with scikit-learn.

        Lightning runs this hook on a single process only, so it must not
        assign state; loading and splitting happen in ``setup``.
        """

    def setup(self, stage=None):
        """Load the data and split it once; later stages (e.g. test after fit) reuse that split.

        Re-splitting for every stage would reshuffle rows and could move
        training rows into the test set.
        """
        if self.train is not None:
            return
        self.data = load_diabetes()
        self.dataset = DiabetesDataset(self.data)
        generator = torch.Generator().manual_seed(self.seed)
        self.train, self.val, self.test = random_split(
            self.dataset, self.lengths, generator=generator
        )

    def train_dataloader(self):
        # Reshuffle training samples every epoch; eval loaders keep a fixed order.
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)
    
class DiabetesRegressor(l.LightningModule):
    """LightningModule that trains ``ANN`` to regress the diabetes target with MSE loss.

    Args:
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, lr=1e-3):
        super().__init__()
        # Records ``lr`` in ``self.hparams`` and in checkpoints.
        self.save_hyperparameters()
        self.model = ANN()
        self.loss = nn.MSELoss()
        self.lr = lr

    def forward(self, x):
        """Predict targets for a batch, flattening every non-batch dimension of ``x``."""
        return self.model(x.reshape(x.size(0), -1).float())

    def _shared_step(self, batch, stage):
        """Compute the MSE loss of one ``(x, y)`` batch and log it as ``{stage}_loss``."""
        x, y = batch
        y_hat = self(x)
        # Align the target's dtype and shape with the prediction: a (batch,)
        # target against a (batch, 1) prediction would be broadcast to
        # (batch, batch) by MSELoss and silently give the wrong loss.
        y = y.to(y_hat.dtype).reshape_as(y_hat)
        loss = self.loss(y_hat, y)
        self.log(f"{stage}_loss", loss, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        return optim.Adam(self.model.parameters(), lr=self.lr)
    
# Seed Python, NumPy and torch RNGs so weight initialisation and dropout are
# reproducible under Restart & Run All.
SEED = 42
l.seed_everything(SEED)
model = DiabetesRegressor()
    

