<a href="https://colab.research.google.com/github/PeterHJY628/tutorial_notebooks/blob/main/save_load_weight_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
# install PyTorch
!pip install torch torchvision

# Load Google Drive
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
import os

# check if the path exists, if it exists, list all of them
if os.path.exists('/content/drive'):
    print("Google Drive is mounted at '/content/drive'")
    print("Contents of Google Drive root directory:")
    print(os.listdir('/content/drive/My Drive'))  # List Drive list
else:
    print("Google Drive is not mounted!")


Google Drive is mounted at '/content/drive'
Contents of Google Drive root directory:
['未命名文件夹', '无标题演示文稿.gslides', 'Colab Notebooks', 'model_checkpoint.pth']


In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)  # 展平输入
        return self.fc(x)

# 定义保存检查点的函数
def save_checkpoint(model, optimizer, epoch, save_path):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(checkpoint, save_path)
    print(f"Checkpoint saved at epoch {epoch+1} to {save_path}")

# 定义加载检查点的函数
def load_checkpoint(model, optimizer, load_path):
    checkpoint = torch.load(load_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    print(f"Checkpoint loaded. Resuming training from epoch {start_epoch + 1}")
    return start_epoch

# 设置超参数
batch_size = 64
learning_rate = 0.001
epochs = 20
checkpoint_path = "/content/drive/My Drive/model_checkpoint.pth"  # Google Drive 中的保存路径

# 数据集和数据加载器
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# 初始化模型、损失函数和优化器
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 加载检查点（如果存在）
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

try:
    start_epoch = load_checkpoint(model, optimizer, checkpoint_path)
except FileNotFoundError:
    print("No checkpoint found, starting training from scratch.")
    start_epoch = 0

# 训练模型
for epoch in range(start_epoch, epochs):  # 从中断的 epoch 开始
    model.train()
    epoch_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 统计
        epoch_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # 打印每 N 个 batch 的损失
        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Step [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}")

    print(f"Epoch [{epoch+1}/{epochs}] - Loss: {epoch_loss/len(train_loader):.4f}, Accuracy: {100.*correct/total:.2f}%")

    # 每个 epoch 保存检查点
    save_checkpoint(model, optimizer, epoch, checkpoint_path)

print("训练完成！")

  checkpoint = torch.load(load_path)


Checkpoint loaded. Resuming training from epoch 6
Epoch [6/20], Step [0/938], Loss: 0.0042
Epoch [6/20], Step [100/938], Loss: 0.0158
Epoch [6/20], Step [200/938], Loss: 0.0135
Epoch [6/20], Step [300/938], Loss: 0.0357
Epoch [6/20], Step [400/938], Loss: 0.0053
Epoch [6/20], Step [500/938], Loss: 0.0428
Epoch [6/20], Step [600/938], Loss: 0.0189
Epoch [6/20], Step [700/938], Loss: 0.0654
Epoch [6/20], Step [800/938], Loss: 0.0177
Epoch [6/20], Step [900/938], Loss: 0.0275
Epoch [6/20] - Loss: 0.0249, Accuracy: 99.30%
Checkpoint saved at epoch 6 to /content/drive/My Drive/model_checkpoint.pth
Epoch [7/20], Step [0/938], Loss: 0.0130
Epoch [7/20], Step [100/938], Loss: 0.0137
Epoch [7/20], Step [200/938], Loss: 0.0106
Epoch [7/20], Step [300/938], Loss: 0.0062
Epoch [7/20], Step [400/938], Loss: 0.0042
Epoch [7/20], Step [500/938], Loss: 0.0683
Epoch [7/20], Step [600/938], Loss: 0.0018
Epoch [7/20], Step [700/938], Loss: 0.0457
Epoch [7/20], Step [800/938], Loss: 0.0053
Epoch [7/20], S