In [1]:
%matplotlib inline
import random
import torch
from d2l import torch as d2l

In [None]:
class SyntheticRegressionData(d2l.DataModule):  #@save
    """Randomly generated dataset for a linear regression problem.

    Features are drawn from a standard normal distribution and labels are
    computed as ``X @ w + b`` plus Gaussian noise scaled by ``noise``.

    Args:
        w: 1-D tensor of ground-truth weights; its length sets the number
            of feature columns.
        b: Ground-truth bias added to every label.
        noise: Scale applied to the standard-normal label noise.
        num_train: Number of examples treated as the training split.
        num_val: Number of examples treated as the validation split.
        batch_size: Minibatch size read later by the data loaders.

    Attributes:
        X: Feature tensor of shape ``(num_train + num_val, len(w))``.
        y: Label tensor of shape ``(num_train + num_val, 1)``.
    """
    def __init__(self, w, b, noise=0.01, num_train=1000, num_val=1000, batch_size=32):
        super().__init__()
        # save_hyperparameters is not defined in this class, so it comes from
        # the d2l base-class hierarchy. The loaders below read self.num_train,
        # self.num_val and self.batch_size, which are never assigned here
        # explicitly, so this call is presumably what stores the constructor
        # arguments as attributes.
        # NOTE(review): kept as the first statement after super().__init__(),
        # exactly as in the original; confirm before moving it.
        self.save_hyperparameters()
        total = num_train + num_val
        # One row per example, one column per weight.
        self.X = torch.randn(total, len(w))
        # Per-example label perturbation, scaled by the noise level.
        perturbation = torch.randn(total, 1) * noise
        # Reshape w into a column vector so the product has shape (total, 1).
        self.y = self.X @ w.reshape(-1, 1) + b + perturbation

In [7]:
# Build a dataset with ground-truth weights w = [2, -3.4] and bias b = 4.2;
# every other constructor argument keeps its default value.
data = SyntheticRegressionData(w=torch.tensor([2, -3.4]), b=4.2)
# Show the first example: its feature row and the matching label.
print('features:', data.X[0],'\nlabel:', data.y[0])

features: tensor([-0.0648,  0.2781]) 
label: tensor([3.1115])


In [None]:
def add_to_class(Class):
    """Return a decorator that registers the decorated object on ``Class``.

    The decorated function (or any object with a ``__name__``) is attached to
    ``Class`` as an attribute under its own name, so a method can be added to
    a class after the class has been defined.

    Args:
        Class: The class that receives the new attribute.

    Returns:
        A decorator that takes ``obj``, sets ``Class.<obj.__name__> = obj``
        and returns ``obj`` unchanged.
    """
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
        # Hand the object back so the decorated name stays bound to it;
        # without this return the name would be rebound to None.
        return obj
    return wrapper

@add_to_class(SyntheticRegressionData)
def get_dataloader(self, train):
    """Yield ``(features, labels)`` minibatches from one split of the data.

    Args:
        train: If truthy, iterate over the first ``num_train`` examples in a
            freshly shuffled order; otherwise iterate over the following
            ``num_val`` examples in their stored order.

    Yields:
        Tuples ``(X_batch, y_batch)`` holding at most ``batch_size`` rows
        each; the final batch may be smaller.
    """
    if train:
        order = list(range(self.num_train))
        # Training examples are visited in random order.
        random.shuffle(order)
    else:
        first_val = self.num_train
        order = list(range(first_val, first_val + self.num_val))
    batch_starts = range(0, len(order), self.batch_size)
    for start in batch_starts:
        # Index tensor used to gather the rows of this minibatch.
        rows = torch.tensor(order[start:start + self.batch_size])
        yield self.X[rows], self.y[rows]

In [14]:
# Calling get_dataloader() does not run its body right away; it returns a
# generator object.
# iter(obj) obtains an iterator from an iterable; an iterator is an object
# that hands back its elements one at a time, and next() pulls the first one.
# NOTE(review): train_dataloader() is not defined in this file; presumably it
# comes from d2l.DataModule and delegates to get_dataloader - verify.
X,y = next(iter(data.train_dataloader()))
print('X shape:', X.shape, '\ny shape:', y.shape)


X shape: torch.Size([32, 2]) 
y shape: torch.Size([32, 1])


In [16]:
@d2l.add_to_class(d2l.DataModule)  #@save
def get_tensorloader(self, tensors, train, indices=slice(0, None)):
    """Wrap a selection of rows from ``tensors`` in a PyTorch ``DataLoader``.

    Args:
        tensors: Iterable of tensors; each one is indexed with ``indices``.
        train: Passed to the loader as ``shuffle``, so a truthy value
            enables shuffling.
        indices: Index applied to every tensor; the default slice keeps all
            rows.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected rows that uses
        ``self.batch_size`` as its batch size.
    """
    selected = [tensor[indices] for tensor in tensors]
    dataset = torch.utils.data.TensorDataset(*selected)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=self.batch_size, shuffle=train)
    return loader

@d2l.add_to_class(SyntheticRegressionData)  #@save
def get_dataloader(self, train):
    """Return a loader over the training or validation rows.

    Args:
        train: If truthy, select the first ``num_train`` rows; otherwise
            select every row from index ``num_train`` onward. The flag is
            also forwarded to ``get_tensorloader``.

    Returns:
        Whatever ``self.get_tensorloader`` returns for the chosen rows of
        ``(self.X, self.y)``.
    """
    if train:
        i = slice(0, self.num_train)
    else:
        i = slice(self.num_train, None)
    features_and_labels = (self.X, self.y)
    return self.get_tensorloader(features_and_labels, train, i)



In [18]:
# Pull the first minibatch from the training loader and report the shapes of
# its feature and label tensors.
# NOTE(review): train_dataloader() is not defined in this file; presumably it
# comes from d2l.DataModule and delegates to get_dataloader - verify.
X, y = next(iter(data.train_dataloader()))
print('X shape:', X.shape, '\ny shape:', y.shape)


X shape: torch.Size([32, 2]) 
y shape: torch.Size([32, 1])


In [20]:
# Number of minibatches in the training loader (the recorded output is 32).
# NOTE(review): this needs train_dataloader() to return an object that
# supports len() - presumably the DataLoader built by get_tensorloader; verify.
len(data.train_dataloader())

32

In [21]:
# Walk the training loader and print every minibatch it produces.
# NOTE(review): this prints the whole training split, which makes the cell
# output very long; consider printing only the first batch.
for batch in data.train_dataloader(): 
    print(batch)

[tensor([[-9.6927e-01, -9.2085e-02],
        [-1.0948e+00, -1.9321e+00],
        [-8.5725e-01, -1.3869e+00],
        [-6.1067e-02, -1.5707e+00],
        [-4.0936e-01, -3.5088e-01],
        [ 1.0927e+00, -2.5231e-01],
        [-1.6617e+00,  9.5409e-01],
        [ 1.6252e+00,  1.4049e+00],
        [ 1.0791e+00,  8.3563e-01],
        [ 7.9482e-01, -1.0683e+00],
        [ 3.1972e-05,  1.9269e+00],
        [ 5.7276e-01, -2.8157e-01],
        [-3.5000e-01,  5.1940e-01],
        [-3.4358e-02,  7.1717e-02],
        [-9.1325e-01,  1.1166e+00],
        [ 1.4342e+00, -1.0549e+00],
        [ 8.7228e-01, -8.3789e-01],
        [-5.8734e-01,  5.5088e-01],
        [ 6.2122e-01, -6.2073e-01],
        [ 5.6863e-01,  1.5878e-01],
        [ 8.1768e-01,  1.0783e+00],
        [ 1.3565e-01,  1.3494e+00],
        [ 7.3541e-01,  7.7041e-01],
        [-7.1968e-01, -2.4341e+00],
        [-1.0438e+00,  2.8427e+00],
        [ 2.4295e-01,  1.3337e+00],
        [ 1.0647e-02,  1.9532e+00],
        [ 1.3748e+00, -2.09