 Implement a simple linear regression model with one input and one output. For this, you can define a
 model class that inherits from torch.nn.Module, and within this class define your model’s architecture
 in the __init__ method and how it processes input tensors in the forward method. The model should
 consist of one linear layer (torch.nn.Linear). Then, generate some random data and labels, perform
 a forward pass of the data through the model, calculate the mean squared error loss, perform
 backpropagation using the backward method, and finally update the model weights with the help of
the SGD optimizer.
 After this training epoch, save its weights. Then create a new model of the same architecture, load
 the saved weights into it, and apply it to some data.
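The notebook cells that follow go beyond this brief: they train an 8-layer MLP with Adam. For comparison, here is a minimal sketch of the task exactly as stated. It uses one `nn.Linear(1, 1)` layer, MSE loss, and a single full-batch SGD step on random data, then saves the weights, loads them into a fresh model and runs a prediction. The class name `SimpleLinearRegression`, the file name `linear_1d.pth`, `lr=0.1` and the synthetic target `y = 2x + 1` are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.optim as optim


class SimpleLinearRegression(nn.Module):
    """y = w * x + b: exactly one linear layer, one input, one output."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)


torch.manual_seed(0)
# Random inputs and labels (illustrative target y = 2x + 1 plus small noise)
x = torch.randn(100, 1)
y = 2 * x + 1 + 0.1 * torch.randn(100, 1)

model = SimpleLinearRegression()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# One training epoch over the full batch: forward, loss, backward, SGD update
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()

# Save only the learned parameters (the state_dict)
torch.save(model.state_dict(), "linear_1d.pth")

# Fresh model with the same architecture, weights loaded from disk
new_model = SimpleLinearRegression()
new_model.load_state_dict(torch.load("linear_1d.pth", weights_only=True))
new_model.eval()
with torch.no_grad():
    preds = new_model(torch.tensor([[0.0], [1.0]]))
print(preds)
```

A single step will not reach the true `w = 2, b = 1`; repeating the forward/backward/step block in a loop is what turns this into real training.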

In [34]:
import torch
import torch.nn as nn
import torch.optim as optim

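# 1. Define the regression model
# Note: despite its name, this is an 8-layer MLP with ReLU activations, not the single
# nn.Linear layer the task asks for; the depth and nonlinearity let it fit a nonlinear target.
class LinearRegressionModel(nn.Module):
    def __init__(self, layers=8, hidden_dim=1024):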
        super().__init__()
        self.layers = nn.ModuleList()
        input_dim = 1
        for _ in range(layers):
            self.layers.append(nn.Linear(input_dim, hidden_dim))
            self.layers.append(nn.ReLU())  # ReLU activation
            input_dim = hidden_dim
        self.fc_out = nn.Linear(hidden_dim, 1)  # output layer

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        x = self.fc_out(x)
        return x
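To make the gap with "simple linear regression" concrete, the sketch below counts the trainable parameters of the class above with its defaults (`layers=8, hidden_dim=1024`) and compares them with a single `nn.Linear(1, 1)`. The class body is copied from the cell above so the snippet runs on its own; `count_params` is a hypothetical helper, not a PyTorch function.

```python
import torch.nn as nn


class LinearRegressionModel(nn.Module):
    # Same architecture as the notebook cell: `layers` Linear+ReLU blocks, then a 1-unit output layer
    def __init__(self, layers=8, hidden_dim=1024):
        super().__init__()
        self.layers = nn.ModuleList()
        input_dim = 1
        for _ in range(layers):
            self.layers.append(nn.Linear(input_dim, hidden_dim))
            self.layers.append(nn.ReLU())
            input_dim = hidden_dim
        self.fc_out = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.fc_out(x)


def count_params(module):
    # Total number of scalar weights and biases that the optimizer will update
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


mlp_params = count_params(LinearRegressionModel())
linear_params = count_params(nn.Linear(1, 1))
print(mlp_params, linear_params)  # 7350273 2
```

The total breaks down as (1·1024 + 1024) + 7·(1024·1024 + 1024) + (1024 + 1) = 2048 + 7,347,200 + 1025 = 7,350,273, against just `w` and `b` for a true linear regression.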


In [35]:
# 2. Generate data
torch.manual_seed(0)
x = torch.linspace(-3, 3, 1024).unsqueeze(1)
# Target function: y = 0.5 * x^3 - x^2 + 2 * sin(x) + 3
y = 0.5 * x**3 - x**2 + 2 * torch.sin(x) + 3 + 0.2 * torch.randn(1024, 1)  # add a little Gaussian noise (std 0.2)
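Because the target contains x³, x² and sin(x) terms, a model that is truly linear in x cannot fit it well. One quick check is to solve ordinary least squares for y ≈ w·x + b in closed form with `torch.linalg.lstsq` and compare the resulting mean squared error with the noise variance (0.2² = 0.04). The sketch regenerates the same data as the cell above.

```python
import torch

torch.manual_seed(0)
x = torch.linspace(-3, 3, 1024).unsqueeze(1)
y = 0.5 * x**3 - x**2 + 2 * torch.sin(x) + 3 + 0.2 * torch.randn(1024, 1)

# Design matrix [x, 1] so the least-squares solution is (w, b) for y ≈ w*x + b
A = torch.cat([x, torch.ones_like(x)], dim=1)
solution = torch.linalg.lstsq(A, y).solution  # shape (2, 1)
w, b = solution[0].item(), solution[1].item()

# Mean squared error of the best possible straight line on this data
line_mse = torch.mean((A @ solution - y) ** 2).item()
print(f"best line: y = {w:.3f} x + {b:.3f}, MSE = {line_mse:.3f}")
```

The best straight line leaves an error far above the 0.04 noise floor, which is why the notebook swaps in a deep nonlinear network (its final training loss in the log below is about 0.05).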

In [36]:
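# 3. Instantiate the model, loss function and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LinearRegressionModel(layers=8).to(device)
criterion = nn.MSELoss()
# Note: this uses Adam, whereas the task statement asks for SGD (optim.SGD accepts the same two arguments)
optimizer = optim.Adam(model.parameters(), lr=0.0001)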
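For reference, `nn.MSELoss()` with its default `reduction='mean'` is the average of the squared residuals, and a plain `optim.SGD` step (no momentum, no weight decay) applies w ← w − lr·∂L/∂w to every parameter. The sketch below checks both identities on a tiny `nn.Linear(1, 1)`; the data and `lr=0.1` are arbitrary choices.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(1, 1)
lr = 0.1
optimizer = optim.SGD(model.parameters(), lr=lr)
x = torch.randn(16, 1)
y = 2 * x + 1

pred = model(x)
loss = nn.MSELoss()(pred, y)
# MSELoss with reduction='mean' equals the mean of squared differences
manual_loss = ((pred - y) ** 2).mean()

optimizer.zero_grad()
loss.backward()
# Expected parameters after one vanilla SGD step, computed before stepping
expected = [p.detach().clone() - lr * p.grad for p in model.parameters()]
optimizer.step()
updated = [p.detach() for p in model.parameters()]
```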

In [37]:
# 4. Training (with mini-batches and multiple epochs)
from tqdm import tqdm

batch_size = 64
epochs = 100
dataset = torch.utils.data.TensorDataset(x, y)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    tbar = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")
    for batch_x, batch_y in tbar:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        optimizer.zero_grad()
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * batch_x.size(0)
        tbar.set_postfix(batch_loss=loss.item())
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/len(dataset):.4f}")

Epoch 1/100: 100%|██████████| 16/16 [00:00<00:00, 35.89it/s, batch_loss=26.8]


Epoch 1/100, Loss: 40.6309


Epoch 2/100: 100%|██████████| 16/16 [00:00<00:00, 175.73it/s, batch_loss=2.51]


Epoch 2/100, Loss: 8.8158


Epoch 3/100: 100%|██████████| 16/16 [00:00<00:00, 168.73it/s, batch_loss=1.85]


Epoch 3/100, Loss: 2.8921


Epoch 4/100: 100%|██████████| 16/16 [00:00<00:00, 175.47it/s, batch_loss=0.949]


Epoch 4/100, Loss: 1.4398


Epoch 5/100: 100%|██████████| 16/16 [00:00<00:00, 176.15it/s, batch_loss=0.955]


Epoch 5/100, Loss: 0.7976


Epoch 6/100: 100%|██████████| 16/16 [00:00<00:00, 174.75it/s, batch_loss=0.367]


Epoch 6/100, Loss: 0.6069


Epoch 7/100: 100%|██████████| 16/16 [00:00<00:00, 178.96it/s, batch_loss=0.587]


Epoch 7/100, Loss: 0.5109


Epoch 8/100: 100%|██████████| 16/16 [00:00<00:00, 172.28it/s, batch_loss=0.336]


Epoch 8/100, Loss: 0.3649


Epoch 9/100: 100%|██████████| 16/16 [00:00<00:00, 168.61it/s, batch_loss=0.171]


Epoch 9/100, Loss: 0.2723


Epoch 10/100: 100%|██████████| 16/16 [00:00<00:00, 176.09it/s, batch_loss=0.126]


Epoch 10/100, Loss: 0.2056


Epoch 11/100: 100%|██████████| 16/16 [00:00<00:00, 183.89it/s, batch_loss=0.332]


Epoch 11/100, Loss: 0.1925


Epoch 12/100: 100%|██████████| 16/16 [00:00<00:00, 182.56it/s, batch_loss=0.0966]


Epoch 12/100, Loss: 0.1690


Epoch 13/100: 100%|██████████| 16/16 [00:00<00:00, 186.74it/s, batch_loss=0.104]


Epoch 13/100, Loss: 0.1565


Epoch 14/100: 100%|██████████| 16/16 [00:00<00:00, 185.33it/s, batch_loss=0.0986]


Epoch 14/100, Loss: 0.1194


Epoch 15/100: 100%|██████████| 16/16 [00:00<00:00, 179.44it/s, batch_loss=0.108]


Epoch 15/100, Loss: 0.0923


Epoch 16/100: 100%|██████████| 16/16 [00:00<00:00, 175.01it/s, batch_loss=0.0634]


Epoch 16/100, Loss: 0.0966


Epoch 17/100: 100%|██████████| 16/16 [00:00<00:00, 173.92it/s, batch_loss=0.0906]


Epoch 17/100, Loss: 0.0870


Epoch 18/100: 100%|██████████| 16/16 [00:00<00:00, 183.43it/s, batch_loss=0.144]


Epoch 18/100, Loss: 0.0859


Epoch 19/100: 100%|██████████| 16/16 [00:00<00:00, 189.26it/s, batch_loss=0.0705]


Epoch 19/100, Loss: 0.0967


Epoch 20/100: 100%|██████████| 16/16 [00:00<00:00, 184.55it/s, batch_loss=0.111]


Epoch 20/100, Loss: 0.0786


Epoch 21/100: 100%|██████████| 16/16 [00:00<00:00, 184.97it/s, batch_loss=0.0687]


Epoch 21/100, Loss: 0.0837


Epoch 22/100: 100%|██████████| 16/16 [00:00<00:00, 180.75it/s, batch_loss=0.0477]


Epoch 22/100, Loss: 0.0697


Epoch 23/100: 100%|██████████| 16/16 [00:00<00:00, 179.27it/s, batch_loss=0.0958]


Epoch 23/100, Loss: 0.0825


Epoch 24/100: 100%|██████████| 16/16 [00:00<00:00, 177.89it/s, batch_loss=0.0611]


Epoch 24/100, Loss: 0.0785


Epoch 25/100: 100%|██████████| 16/16 [00:00<00:00, 186.87it/s, batch_loss=0.0448]


Epoch 25/100, Loss: 0.0744


Epoch 26/100: 100%|██████████| 16/16 [00:00<00:00, 183.66it/s, batch_loss=0.0923]


Epoch 26/100, Loss: 0.0854


Epoch 27/100: 100%|██████████| 16/16 [00:00<00:00, 187.92it/s, batch_loss=0.0895]


Epoch 27/100, Loss: 0.1120


Epoch 28/100: 100%|██████████| 16/16 [00:00<00:00, 181.84it/s, batch_loss=0.216]


Epoch 28/100, Loss: 0.1011


Epoch 29/100: 100%|██████████| 16/16 [00:00<00:00, 183.10it/s, batch_loss=0.0677]


Epoch 29/100, Loss: 0.0839


Epoch 30/100: 100%|██████████| 16/16 [00:00<00:00, 169.61it/s, batch_loss=0.0855]


Epoch 30/100, Loss: 0.0736


Epoch 31/100: 100%|██████████| 16/16 [00:00<00:00, 175.32it/s, batch_loss=0.0702]


Epoch 31/100, Loss: 0.0859


Epoch 32/100: 100%|██████████| 16/16 [00:00<00:00, 182.44it/s, batch_loss=0.0801]


Epoch 32/100, Loss: 0.1160


Epoch 33/100: 100%|██████████| 16/16 [00:00<00:00, 190.52it/s, batch_loss=0.0751]


Epoch 33/100, Loss: 0.1044


Epoch 34/100: 100%|██████████| 16/16 [00:00<00:00, 185.73it/s, batch_loss=0.0469]


Epoch 34/100, Loss: 0.0776


Epoch 35/100: 100%|██████████| 16/16 [00:00<00:00, 188.66it/s, batch_loss=0.0609]


Epoch 35/100, Loss: 0.0591


Epoch 36/100: 100%|██████████| 16/16 [00:00<00:00, 179.33it/s, batch_loss=0.0845]


Epoch 36/100, Loss: 0.0688


Epoch 37/100: 100%|██████████| 16/16 [00:00<00:00, 167.76it/s, batch_loss=0.0527]


Epoch 37/100, Loss: 0.0645


Epoch 38/100: 100%|██████████| 16/16 [00:00<00:00, 180.37it/s, batch_loss=0.057]


Epoch 38/100, Loss: 0.0656


Epoch 39/100: 100%|██████████| 16/16 [00:00<00:00, 187.67it/s, batch_loss=0.107]


Epoch 39/100, Loss: 0.0595


Epoch 40/100: 100%|██████████| 16/16 [00:00<00:00, 187.24it/s, batch_loss=0.0714]


Epoch 40/100, Loss: 0.0697


Epoch 41/100: 100%|██████████| 16/16 [00:00<00:00, 186.92it/s, batch_loss=0.0678]


Epoch 41/100, Loss: 0.0741


Epoch 42/100: 100%|██████████| 16/16 [00:00<00:00, 187.92it/s, batch_loss=0.0974]


Epoch 42/100, Loss: 0.1246


Epoch 43/100: 100%|██████████| 16/16 [00:00<00:00, 172.56it/s, batch_loss=0.0611]


Epoch 43/100, Loss: 0.0691


Epoch 44/100: 100%|██████████| 16/16 [00:00<00:00, 166.70it/s, batch_loss=0.0432]


Epoch 44/100, Loss: 0.0562


Epoch 45/100: 100%|██████████| 16/16 [00:00<00:00, 175.00it/s, batch_loss=0.0724]


Epoch 45/100, Loss: 0.0729


Epoch 46/100: 100%|██████████| 16/16 [00:00<00:00, 173.21it/s, batch_loss=0.0817]


Epoch 46/100, Loss: 0.1233


Epoch 47/100: 100%|██████████| 16/16 [00:00<00:00, 177.14it/s, batch_loss=0.0903]


Epoch 47/100, Loss: 0.1316


Epoch 48/100: 100%|██████████| 16/16 [00:00<00:00, 174.10it/s, batch_loss=0.0449]


Epoch 48/100, Loss: 0.0605


Epoch 49/100: 100%|██████████| 16/16 [00:00<00:00, 171.08it/s, batch_loss=0.0717]


Epoch 49/100, Loss: 0.0538


Epoch 50/100: 100%|██████████| 16/16 [00:00<00:00, 174.66it/s, batch_loss=0.112]


Epoch 50/100, Loss: 0.0824


Epoch 51/100: 100%|██████████| 16/16 [00:00<00:00, 174.42it/s, batch_loss=0.079]


Epoch 51/100, Loss: 0.0712


Epoch 52/100: 100%|██████████| 16/16 [00:00<00:00, 178.07it/s, batch_loss=0.0748]


Epoch 52/100, Loss: 0.0705


Epoch 53/100: 100%|██████████| 16/16 [00:00<00:00, 174.40it/s, batch_loss=0.0636]


Epoch 53/100, Loss: 0.0634


Epoch 54/100: 100%|██████████| 16/16 [00:00<00:00, 177.14it/s, batch_loss=0.0939]


Epoch 54/100, Loss: 0.0613


Epoch 55/100: 100%|██████████| 16/16 [00:00<00:00, 173.32it/s, batch_loss=0.0672]


Epoch 55/100, Loss: 0.0784


Epoch 56/100: 100%|██████████| 16/16 [00:00<00:00, 171.95it/s, batch_loss=0.107]


Epoch 56/100, Loss: 0.0712


Epoch 57/100: 100%|██████████| 16/16 [00:00<00:00, 172.53it/s, batch_loss=0.0441]


Epoch 57/100, Loss: 0.0719


Epoch 58/100: 100%|██████████| 16/16 [00:00<00:00, 175.17it/s, batch_loss=0.0436]


Epoch 58/100, Loss: 0.0627


Epoch 59/100: 100%|██████████| 16/16 [00:00<00:00, 185.65it/s, batch_loss=0.0658]


Epoch 59/100, Loss: 0.0544


Epoch 60/100: 100%|██████████| 16/16 [00:00<00:00, 188.27it/s, batch_loss=0.045]


Epoch 60/100, Loss: 0.0539


Epoch 61/100: 100%|██████████| 16/16 [00:00<00:00, 193.06it/s, batch_loss=0.0531]


Epoch 61/100, Loss: 0.0621


Epoch 62/100: 100%|██████████| 16/16 [00:00<00:00, 188.14it/s, batch_loss=0.042]


Epoch 62/100, Loss: 0.0556


Epoch 63/100: 100%|██████████| 16/16 [00:00<00:00, 184.91it/s, batch_loss=0.0702]


Epoch 63/100, Loss: 0.1000


Epoch 64/100: 100%|██████████| 16/16 [00:00<00:00, 171.66it/s, batch_loss=0.0943]


Epoch 64/100, Loss: 0.0844


Epoch 65/100: 100%|██████████| 16/16 [00:00<00:00, 180.66it/s, batch_loss=0.0663]


Epoch 65/100, Loss: 0.0525


Epoch 66/100: 100%|██████████| 16/16 [00:00<00:00, 193.27it/s, batch_loss=0.0583]


Epoch 66/100, Loss: 0.0594


Epoch 67/100: 100%|██████████| 16/16 [00:00<00:00, 192.61it/s, batch_loss=0.0766]


Epoch 67/100, Loss: 0.0704


Epoch 68/100: 100%|██████████| 16/16 [00:00<00:00, 190.14it/s, batch_loss=0.0614]


Epoch 68/100, Loss: 0.0689


Epoch 69/100: 100%|██████████| 16/16 [00:00<00:00, 185.98it/s, batch_loss=0.0488]


Epoch 69/100, Loss: 0.0526


Epoch 70/100: 100%|██████████| 16/16 [00:00<00:00, 185.18it/s, batch_loss=0.0339]


Epoch 70/100, Loss: 0.0520


Epoch 71/100: 100%|██████████| 16/16 [00:00<00:00, 177.81it/s, batch_loss=0.085]


Epoch 71/100, Loss: 0.0726


Epoch 72/100: 100%|██████████| 16/16 [00:00<00:00, 169.46it/s, batch_loss=0.0711]


Epoch 72/100, Loss: 0.1129


Epoch 73/100: 100%|██████████| 16/16 [00:00<00:00, 175.15it/s, batch_loss=0.097]


Epoch 73/100, Loss: 0.1099


Epoch 74/100: 100%|██████████| 16/16 [00:00<00:00, 175.27it/s, batch_loss=0.059]


Epoch 74/100, Loss: 0.0652


Epoch 75/100: 100%|██████████| 16/16 [00:00<00:00, 174.86it/s, batch_loss=0.118]


Epoch 75/100, Loss: 0.0732


Epoch 76/100: 100%|██████████| 16/16 [00:00<00:00, 175.01it/s, batch_loss=0.147]


Epoch 76/100, Loss: 0.0762


Epoch 77/100: 100%|██████████| 16/16 [00:00<00:00, 173.83it/s, batch_loss=0.0903]


Epoch 77/100, Loss: 0.0720


Epoch 78/100: 100%|██████████| 16/16 [00:00<00:00, 169.80it/s, batch_loss=0.046]


Epoch 78/100, Loss: 0.0667


Epoch 79/100: 100%|██████████| 16/16 [00:00<00:00, 170.54it/s, batch_loss=0.0708]


Epoch 79/100, Loss: 0.0541


Epoch 80/100: 100%|██████████| 16/16 [00:00<00:00, 174.37it/s, batch_loss=0.0567]


Epoch 80/100, Loss: 0.0562


Epoch 81/100: 100%|██████████| 16/16 [00:00<00:00, 178.05it/s, batch_loss=0.0335]


Epoch 81/100, Loss: 0.0524


Epoch 82/100: 100%|██████████| 16/16 [00:00<00:00, 176.65it/s, batch_loss=0.0619]


Epoch 82/100, Loss: 0.0471


Epoch 83/100: 100%|██████████| 16/16 [00:00<00:00, 168.87it/s, batch_loss=0.0513]


Epoch 83/100, Loss: 0.0515


Epoch 84/100: 100%|██████████| 16/16 [00:00<00:00, 173.08it/s, batch_loss=0.0489]


Epoch 84/100, Loss: 0.0500


Epoch 85/100: 100%|██████████| 16/16 [00:00<00:00, 172.09it/s, batch_loss=0.0675]


Epoch 85/100, Loss: 0.0651


Epoch 86/100: 100%|██████████| 16/16 [00:00<00:00, 184.49it/s, batch_loss=0.0736]


Epoch 86/100, Loss: 0.1072


Epoch 87/100: 100%|██████████| 16/16 [00:00<00:00, 184.80it/s, batch_loss=0.0466]


Epoch 87/100, Loss: 0.0650


Epoch 88/100: 100%|██████████| 16/16 [00:00<00:00, 186.50it/s, batch_loss=0.0613]


Epoch 88/100, Loss: 0.0694


Epoch 89/100: 100%|██████████| 16/16 [00:00<00:00, 184.76it/s, batch_loss=0.0592]


Epoch 89/100, Loss: 0.0537


Epoch 90/100: 100%|██████████| 16/16 [00:00<00:00, 182.20it/s, batch_loss=0.0423]


Epoch 90/100, Loss: 0.0490


Epoch 91/100: 100%|██████████| 16/16 [00:00<00:00, 175.73it/s, batch_loss=0.129]


Epoch 91/100, Loss: 0.0770


Epoch 92/100: 100%|██████████| 16/16 [00:00<00:00, 171.44it/s, batch_loss=0.0684]


Epoch 92/100, Loss: 0.0647


Epoch 93/100: 100%|██████████| 16/16 [00:00<00:00, 175.33it/s, batch_loss=0.0373]


Epoch 93/100, Loss: 0.0496


Epoch 94/100: 100%|██████████| 16/16 [00:00<00:00, 175.40it/s, batch_loss=0.0393]


Epoch 94/100, Loss: 0.0471


Epoch 95/100: 100%|██████████| 16/16 [00:00<00:00, 176.25it/s, batch_loss=0.0601]


Epoch 95/100, Loss: 0.0528


Epoch 96/100: 100%|██████████| 16/16 [00:00<00:00, 174.96it/s, batch_loss=0.0679]


Epoch 96/100, Loss: 0.0550


Epoch 97/100: 100%|██████████| 16/16 [00:00<00:00, 169.84it/s, batch_loss=0.0606]


Epoch 97/100, Loss: 0.0663


Epoch 98/100: 100%|██████████| 16/16 [00:00<00:00, 170.86it/s, batch_loss=0.0517]


Epoch 98/100, Loss: 0.0690


Epoch 99/100: 100%|██████████| 16/16 [00:00<00:00, 170.98it/s, batch_loss=0.0347]


Epoch 99/100, Loss: 0.0535


Epoch 100/100: 100%|██████████| 16/16 [00:00<00:00, 173.32it/s, batch_loss=0.0397]

Epoch 100/100, Loss: 0.0519
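The per-epoch loss printed above is a running average taken while the weights are still changing within the epoch. To measure the loss of the final weights, one option is a separate pass in `eval()` mode under `torch.no_grad()`. The helper name `evaluate` below is hypothetical, not part of PyTorch, and the toy data only demonstrates that it matches a full-batch computation.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, criterion, device):
    """Hypothetical helper: sample-weighted mean loss over `loader` with frozen weights."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            # criterion returns the batch mean, so re-weight by batch size
            total += criterion(model(batch_x), batch_y).item() * batch_x.size(0)
            count += batch_x.size(0)
    return total / count


# Toy usage: batch_size=64 over 100 samples leaves an uneven last batch of 36
torch.manual_seed(0)
x = torch.randn(100, 1)
y = 3 * x - 1
model = nn.Linear(1, 1)
loader = DataLoader(TensorDataset(x, y), batch_size=64, shuffle=False)
eval_loss = evaluate(model, loader, nn.MSELoss(), torch.device("cpu"))
with torch.no_grad():
    full_loss = nn.MSELoss()(model(x), y).item()
print(eval_loss, full_loss)  # should agree up to float rounding
```

In the notebook this would be called as `evaluate(model, loader, criterion, device)` after the loop, followed by `model.train()` if training continues.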





In [38]:
# 5. Save the trained weights (the state_dict only, not the whole model object)
torch.save(model.state_dict(), "simple_linear_regression.pth")
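Saving only `model.state_dict()` is enough for inference, as this cell does. If training is to be resumed later, a common pattern is to save a dictionary that also holds the optimizer state (Adam keeps per-parameter moment estimates) and the epoch counter. The key names, the file name `checkpoint.pth` and the small `nn.Linear(1, 1)` stand-in are illustrative choices, not a required format.

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(1, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# One step so Adam has internal state (step count and moment estimates) to save
loss = model(torch.randn(8, 1)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Keys are arbitrary; torch.save serializes a dict of tensors and plain containers
torch.save(
    {"epoch": 100,
     "model_state_dict": model.state_dict(),
     "optimizer_state_dict": optimizer.state_dict()},
    "checkpoint.pth",
)

# Resume: rebuild the objects first, then restore their states
ckpt = torch.load("checkpoint.pth", weights_only=True)
resumed_model = nn.Linear(1, 1)
resumed_model.load_state_dict(ckpt["model_state_dict"])
resumed_opt = optim.Adam(resumed_model.parameters(), lr=1e-4)
resumed_opt.load_state_dict(ckpt["optimizer_state_dict"])
start_epoch = ckpt["epoch"] + 1
```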

In [39]:
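# 6. Build a fresh model with the same architecture and load the saved weights into it
new_model = LinearRegressionModel(layers=8, hidden_dim=1024).to(device)
# weights_only=True restricts unpickling to tensors and plain containers;
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
new_model.load_state_dict(torch.load("simple_linear_regression.pth", map_location=device, weights_only=True))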

<All keys matched successfully>
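`<All keys matched successfully>` confirms that parameter names and shapes line up. A stricter check is that the restored model computes exactly the same outputs. The sketch below does this on a small stand-in network (an arbitrary example architecture), saving to an in-memory `io.BytesIO` buffer instead of a file.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)
original = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 1))

buffer = io.BytesIO()
torch.save(original.state_dict(), buffer)
buffer.seek(0)  # rewind before reading back

# Same architecture, different random initialization until the weights are loaded
restored = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 1))
result = restored.load_state_dict(torch.load(buffer, weights_only=True))

same_tensors = all(
    torch.equal(a, b)
    for a, b in zip(original.state_dict().values(), restored.state_dict().values())
)
x = torch.linspace(-3, 3, 5).unsqueeze(1)
with torch.no_grad():
    same_outputs = torch.equal(original(x), restored(x))
print(result, same_tensors, same_outputs)
```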

In [None]:
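# 7. Use the new model for prediction
new_model.eval()  # inference mode (good practice, although this model has no dropout or batch norm)
with torch.no_grad():
    # Predict on a few sample inputs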
    test_x = torch.tensor([[0.5], [-1.0], [2.0]]).to(device)
    pred = new_model(test_x)
    print("Predictions for test_x:", pred.cpu().numpy())
    # Compute the true (noise-free) function values
    true_y = 0.5 * test_x.cpu()**3 - test_x.cpu()**2 + 2 * torch.sin(test_x.cpu()) + 3
    print("True values for test_x:", true_y.numpy())

Predictions for test_x: [[ 3.823445  ]
 [-0.13981417]
 [ 4.787151  ]]
True values for test_x: [[ 3.771351  ]
 [-0.18294191]
 [ 4.818595  ]]
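To put these numbers in perspective, the absolute errors can be computed directly from the two printed arrays (values copied from the output above). Over these three points the mean absolute error is about 0.042, and the largest single error (at x = 0.5) is about 0.052.

```python
import torch

# Values copied from the printed output above
pred = torch.tensor([3.823445, -0.13981417, 4.787151])
true = torch.tensor([3.771351, -0.18294191, 4.818595])

abs_err = (pred - true).abs()
mae = abs_err.mean().item()
print(abs_err, round(mae, 4))
```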
