Importando bibliotecas

In [1]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader


Carregando e separando o dataset

In [2]:
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

features, labels = fetch_california_housing(return_X_y=True)

features = torch.from_numpy(features).float()
labels = torch.from_numpy(labels)[:, None].float()

data = torch.concat([features, labels], dim=1)
train_data, test_data = train_test_split(data, test_size=0.25)


Definindo os parâmetros para a regressão linear com o PyTorch

In [3]:
dim = features.shape[1]

W = torch.empty(dim).normal_(std=1e-5).requires_grad_(True)
b = torch.empty(1).normal_(std=1e-5).requires_grad_(True)



Definindo um DataLoader

In [4]:
batch_size = 256

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)


Definindo o otimizador (SGD)

In [5]:
learning_rate = 1e-7

optimizer = torch.optim.SGD([W, b], lr=learning_rate)


Treinando com os dados de treino

In [6]:
epochs = 500
N = features.shape[0]

for epoch in range(epochs):
    optimizer.zero_grad()
    loss = 0

    for batch in train_dataloader:
        X = batch[:, :dim]
        y = batch[:, dim]

        y_hat = X @ W + b
        loss += torch.sum((y_hat - y) ** 2)

    loss /= N

    loss.backward()
    optimizer.step()

    print(f"{epoch}: {loss.item()}")



0: 4.259960174560547
1: 2.775346517562866
2: 2.4068336486816406
3: 2.313920497894287
4: 2.2890665531158447
5: 2.2810215950012207
6: 2.2771294116973877
7: 2.274268627166748
8: 2.271665573120117
9: 2.2691290378570557
10: 2.266613245010376
11: 2.2641072273254395
12: 2.261605739593506
13: 2.2591097354888916
14: 2.256619453430176
15: 2.2541329860687256
16: 2.2516520023345947
17: 2.249175786972046
18: 2.2467048168182373
19: 2.2442383766174316
20: 2.24177622795105
21: 2.2393202781677246
22: 2.236868143081665
23: 2.2344205379486084
24: 2.2319788932800293
25: 2.229541301727295
26: 2.2271080017089844
27: 2.224679946899414
28: 2.222257375717163
29: 2.219838857650757
30: 2.2174248695373535
31: 2.2150163650512695
32: 2.2126121520996094
33: 2.210212469100952
34: 2.207817316055298
35: 2.205427885055542
36: 2.20304274559021
37: 2.2006618976593018
38: 2.198286294937134
39: 2.1959145069122314
40: 2.1935479640960693
41: 2.1911864280700684
42: 2.188829183578491
43: 2.186476230621338
44: 2.1841280460357666