In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
import matplotlib.pyplot as plt

In [2]:
class Net(nn.Module):
  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(9,10)
    self.bn1 = nn.BatchNorm1d(10)
    self.fc2 = nn.Linear(10,5)
    self.bn2 = nn.BatchNorm1d(5)
    self.fc3 = nn.Linear(5,1)
    self.sigmoid = nn.Sigmoid()


  def forward(self, x):
    x = F.elu(self.fc1(x))
    x = self.bn1(x)
    x = F.elu(self.fc2(x))
    x = self.bn2(x)
    x = self.fc3(x)
    x = self.sigmoid(x)
    return x

In [3]:
class WaterDataset(Dataset):
  def __init__(self, csv_path):
    super().__init__()
    df = pd.read_csv(csv_path)
    # Handle missing values, e.g., by imputing with the mean
    df = df.fillna(df.mean())
    self.data = df.to_numpy()
  def __len__(self):
    return len(self.data)
  def __getitem__(self, idx):
    return self.data[idx][:-1], self.data[idx][-1]

In [4]:
dataset = WaterDataset('Dataset/water_potability.csv')

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
feature , labels = next(iter(dataloader))

print(f"Feature : {feature}")
print(f"Labels : {labels}")

Feature : tensor([[6.6283e+00, 1.9887e+02, 1.5911e+04, 7.5179e+00, 3.4202e+02, 4.3792e+02,
         1.5006e+01, 3.8846e+01, 4.4645e+00],
        [3.4104e+00, 2.0740e+02, 4.9075e+04, 5.6674e+00, 3.0198e+02, 3.5152e+02,
         1.5987e+01, 8.6639e+01, 3.7218e+00],
        [6.4651e+00, 2.4127e+02, 4.3959e+04, 7.4202e+00, 3.0602e+02, 5.4460e+02,
         2.0769e+01, 8.9647e+01, 3.7905e+00],
        [8.5185e+00, 1.2826e+02, 3.2018e+04, 6.0587e+00, 4.5844e+02, 5.5414e+02,
         1.5976e+01, 8.7472e+01, 4.1472e+00],
        [6.7077e+00, 1.9929e+02, 1.7435e+04, 6.8099e+00, 3.3378e+02, 4.5473e+02,
         1.4587e+01, 8.1378e+01, 3.8227e+00],
        [7.5397e+00, 2.0196e+02, 2.6716e+04, 5.6374e+00, 3.3378e+02, 5.1635e+02,
         1.4986e+01, 8.3537e+01, 4.2107e+00],
        [8.9037e+00, 1.9422e+02, 1.3319e+04, 5.5669e+00, 3.2342e+02, 3.1735e+02,
         1.7130e+01, 6.1429e+01, 4.3652e+00],
        [2.9455e+00, 1.2675e+02, 1.6829e+04, 1.0597e+01, 3.3378e+02, 4.5623e+02,
         1.0157e+01,

In [5]:
net = Net()

In [6]:
criterion = nn.BCELoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(50):
  for features , labels in dataloader:
    optimizer.zero_grad()
    outputs = net(features.float()) # Convert features to Float
    loss = criterion(
        outputs, labels.view(-1,1).float() # Convert labels to Float
    )
    loss.backward()
    optimizer.step()
  print(f"Epoch : {epoch} , Loss : {loss.item()}")

Epoch : 0 , Loss : 0.6258886456489563
Epoch : 1 , Loss : 0.6525383591651917
Epoch : 2 , Loss : 0.7569179534912109
Epoch : 3 , Loss : 0.7042312622070312
Epoch : 4 , Loss : 0.7197045683860779
Epoch : 5 , Loss : 0.6023091673851013
Epoch : 6 , Loss : 0.728320300579071
Epoch : 7 , Loss : 0.6344703435897827
Epoch : 8 , Loss : 0.6730366349220276
Epoch : 9 , Loss : 0.6339048147201538
Epoch : 10 , Loss : 0.6233407855033875
Epoch : 11 , Loss : 0.7969539761543274
Epoch : 12 , Loss : 0.6594261527061462
Epoch : 13 , Loss : 0.6499658226966858
Epoch : 14 , Loss : 0.6701683402061462
Epoch : 15 , Loss : 0.6838014125823975
Epoch : 16 , Loss : 0.6655528545379639
Epoch : 17 , Loss : 0.6884009838104248
Epoch : 18 , Loss : 0.7163702845573425
Epoch : 19 , Loss : 0.574576199054718
Epoch : 20 , Loss : 0.7971964478492737
Epoch : 21 , Loss : 0.5675024390220642
Epoch : 22 , Loss : 0.5337291359901428
Epoch : 23 , Loss : 0.6426097750663757
Epoch : 24 , Loss : 0.7214091420173645
Epoch : 25 , Loss : 0.645451962947845

In [7]:
from torchmetrics import Accuracy

acc = Accuracy(task='binary')
net.eval()
with torch.no_grad():
  for features, labels in dataloader:
    outputs = net(features.float())
    pred = (outputs > 0.5).float()
    acc(pred, labels.view(-1,1).float())
print(f"Accuracy : {acc.compute()}")


Accuracy : 0.610805869102478
