In [1]:
import numpy as np
import pandas as pd

from sklearn.datasets import load_iris
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from torch.utils.data import DataLoader, Dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tqdm.auto import tqdm

In [2]:
def _select_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_backend = torch.backends.mps
    # MPS is used only when it is both available at runtime and compiled in.
    if mps_backend.is_available() and mps_backend.is_built():
        return torch.device("mps")
    return torch.device("cpu")


device = _select_device()

print(f'Actual device: {device}')

Actual device: mps


In [3]:
# Keep only the feature matrix ('data'); the raw values stay available as
# X_raw instead of being overwritten by the scaled version.
X_raw = load_iris()['data']

# Min-max scale each feature column to [0, 1]. The fitted scaler is kept in
# `min_max` so later cells can reuse it.
min_max = MinMaxScaler()
X = min_max.fit_transform(X_raw)

# Sanity check: fitted on this same data, every column should span [0, 1].
assert np.allclose(X.min(axis=0), 0.0) and np.allclose(X.max(axis=0), 1.0)

X[:10]

array([[0.22222222, 0.625     , 0.06779661, 0.04166667],
       [0.16666667, 0.41666667, 0.06779661, 0.04166667],
       [0.11111111, 0.5       , 0.05084746, 0.04166667],
       [0.08333333, 0.45833333, 0.08474576, 0.04166667],
       [0.19444444, 0.66666667, 0.06779661, 0.04166667],
       [0.30555556, 0.79166667, 0.11864407, 0.125     ],
       [0.08333333, 0.58333333, 0.06779661, 0.08333333],
       [0.19444444, 0.58333333, 0.08474576, 0.04166667],
       [0.02777778, 0.375     , 0.06779661, 0.04166667],
       [0.16666667, 0.45833333, 0.08474576, 0.        ]])

In [4]:
class IrisDataset(Dataset):
    """Map-style ``Dataset`` over an in-memory feature matrix.

    The input is converted once to a ``float32`` tensor stored in
    ``self.data``. Indexing returns ``self.data[idx]``, which is a single
    feature row when the input is 2-D (as with the scaled iris features).
    No labels are returned, only features.
    """

    def __init__(self, data):
        # Convert eagerly so every item served later is already float32.
        features = torch.tensor(data, dtype=torch.float32)
        self.data = features

    def __len__(self):
        """Return the number of entries along the first axis of ``self.data``."""
        n_rows = len(self.data)
        return n_rows

    def __getitem__(self, idx):
        """Return ``self.data[idx]``."""
        sample = self.data[idx]
        return sample

BATCH_SIZE = 32
SEED = 42  # fixes the shuffle order so the batch printed below is reproducible

dataset = IrisDataset(X)
# A dedicated seeded generator makes the shuffling deterministic across
# Restart & Run All without touching torch's global RNG state.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

# Peek at one batch to sanity-check its shape and scaled values.
batch = next(iter(dataloader))
print(f'{batch.shape}\n{batch}')

torch.Size([32, 4])
tensor([[0.2500, 0.6250, 0.0847, 0.0417],
        [0.6667, 0.4167, 0.7119, 0.9167],
        [0.7222, 0.4583, 0.6949, 0.9167],
        [0.4167, 0.2917, 0.6949, 0.7500],
        [0.3056, 0.5833, 0.0847, 0.1250],
        [0.8056, 0.6667, 0.8644, 1.0000],
        [0.4167, 0.8333, 0.0339, 0.0417],
        [0.1944, 0.4167, 0.1017, 0.0417],
        [0.5000, 0.4167, 0.6610, 0.7083],
        [0.3889, 0.3750, 0.5424, 0.5000],
        [0.7778, 0.4167, 0.8305, 0.8333],
        [0.5833, 0.3750, 0.5593, 0.5000],
        [0.3889, 0.7500, 0.1186, 0.0833],
        [0.2222, 0.7083, 0.0847, 0.1250],
        [0.1667, 0.4583, 0.0847, 0.0417],
        [0.6944, 0.4167, 0.7627, 0.8333],
        [0.3333, 0.2500, 0.5763, 0.4583],
        [0.1667, 0.1667, 0.3898, 0.3750],
        [0.1944, 0.6250, 0.0508, 0.0833],
        [0.3889, 0.3333, 0.5254, 0.5000],
        [0.1944, 0.5417, 0.0678, 0.0417],
        [0.6667, 0.4583, 0.7797, 0.9583],
        [0.2222, 0.7500, 0.1525, 0.1250],
        [0.666