In [1]:
import torch
from torch.utils.data import Dataset
import numpy as np

class CustomDataset():
    def __init__(self):
        self.x_data = np.random.rand(5,3)
        self.y_data = np.random.randn(5,1)
    
    def __len__(self):
        return len(self.x_data)
    
    def __getitem__(self, idx):
        x = torch.FloatTensor(self.x_data[idx])
        y = torch.FloatTensor(self.y_data[idx])
        
        return x, y
    
dataset = CustomDataset()

In [3]:
import torch.nn as nn
import torch.nn.functional as F

class MultivariateLinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3,1)
        
    def forward(self, x):
        return self.linear(x)

model = MultivariateLinearRegressionModel()

In [5]:
from torch.utils.data import DataLoader
from torch import optim


dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True, # 항상 true 권장
)

W = torch.zeros(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)

optimizer = optim.SGD([W,b], lr=1e-5)

epochs = 10

for epoch in range(epochs+1):
    for batch_idx, samples in enumerate(dataloader): # minibatch를 위한 for loop
        X_train, y_train = samples
        
        pred = model(X_train)

        loss = F.mse_loss(pred, y_train)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print('Epoch {:4d}/{} Batch {}/{} Loss: {:.6f}'.format(epoch, epochs, batch_idx+1, len(dataloader), loss.item()))


Epoch    0/10 Batch 1/3 Loss: 0.195074
Epoch    0/10 Batch 2/3 Loss: 1.900477
Epoch    0/10 Batch 3/3 Loss: 0.335825
Epoch    1/10 Batch 1/3 Loss: 1.977474
Epoch    1/10 Batch 2/3 Loss: 0.256365
Epoch    1/10 Batch 3/3 Loss: 0.059248
Epoch    2/10 Batch 1/3 Loss: 1.900477
Epoch    2/10 Batch 2/3 Loss: 0.195074
Epoch    2/10 Batch 3/3 Loss: 0.335825
Epoch    3/10 Batch 1/3 Loss: 0.256365
Epoch    3/10 Batch 2/3 Loss: 0.136245
Epoch    3/10 Batch 3/3 Loss: 3.741705
Epoch    4/10 Batch 1/3 Loss: 0.136245
Epoch    4/10 Batch 2/3 Loss: 2.038765
Epoch    4/10 Batch 3/3 Loss: 0.176905
Epoch    5/10 Batch 1/3 Loss: 2.038765
Epoch    5/10 Batch 2/3 Loss: 0.118077
Epoch    5/10 Batch 3/3 Loss: 0.213242
Epoch    6/10 Batch 1/3 Loss: 0.197536
Epoch    6/10 Batch 2/3 Loss: 1.977474
Epoch    6/10 Batch 3/3 Loss: 0.176905
Epoch    7/10 Batch 1/3 Loss: 0.195074
Epoch    7/10 Batch 2/3 Loss: 2.038765
Epoch    7/10 Batch 3/3 Loss: 0.059248
Epoch    8/10 Batch 1/3 Loss: 0.195074
Epoch    8/10 Batch 2/3 L