<a href="https://colab.research.google.com/github/PeterHJY628/MyOwnExample/blob/main/save_load_weight_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# install PyTorch
!pip install torch torchvision

# Load Google Drive
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
import os

# check if the path exists, if it exists, list all of them
if os.path.exists('/content/drive'):
    print("Google Drive is mounted at '/content/drive'")
    print("Contents of Google Drive root directory:")
    print(os.listdir('/content/drive/My Drive'))  # List Drive list
else:
    print("Google Drive is not mounted!")


Google Drive is mounted at '/content/drive'
Contents of Google Drive root directory:
['无标题演示文稿.gslides', 'Colab Notebooks', 'MultiTask.csv', '通过 Chrome 保存', 'Share_Weights']


In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)  # 展平输入
        return self.fc(x)

# save checkpoint
def save_checkpoint(model, optimizer, epoch, save_path):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(checkpoint, save_path)
    print(f"Checkpoint saved at epoch {epoch+1} to {save_path}")

# load checkpoint
def load_checkpoint(model, optimizer, load_path):
    checkpoint = torch.load(load_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    print(f"Checkpoint loaded. Resuming training from epoch {start_epoch + 1}")
    return start_epoch

# set hyperparameters
batch_size = 64
learning_rate = 0.001
epochs = 20
checkpoint_path = "/content/drive/My Drive/Share_Weights/test1_checkpoint.pth"  # Google Drive 中的保存路径

# dataset and dataloader
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# initialize model
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# try to load checkpoints
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

try:
    start_epoch = load_checkpoint(model, optimizer, checkpoint_path)
except FileNotFoundError:
    print("No checkpoint found, starting training from scratch.")
    start_epoch = 0

# train
for epoch in range(start_epoch, epochs):  # start from interrupt epoch
    model.train()
    epoch_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # forward
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # statistc
        epoch_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # print loss for N epoches
        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Step [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}")

    print(f"Epoch [{epoch+1}/{epochs}] - Loss: {epoch_loss/len(train_loader):.4f}, Accuracy: {100.*correct/total:.2f}%")

    # save checkpoint of each epoch
    save_checkpoint(model, optimizer, epoch, checkpoint_path)

print("Finish trainning！")

  checkpoint = torch.load(load_path)


Checkpoint loaded. Resuming training from epoch 18
Epoch [18/20], Step [0/938], Loss: 0.0003
Epoch [18/20], Step [100/938], Loss: 0.0002
Epoch [18/20], Step [200/938], Loss: 0.0008
Epoch [18/20], Step [300/938], Loss: 0.0020
Epoch [18/20], Step [400/938], Loss: 0.0009
Epoch [18/20], Step [500/938], Loss: 0.0004
Epoch [18/20], Step [600/938], Loss: 0.0126
Epoch [18/20], Step [700/938], Loss: 0.0058
Epoch [18/20], Step [800/938], Loss: 0.0084
Epoch [18/20], Step [900/938], Loss: 0.0230
Epoch [18/20] - Loss: 0.0046, Accuracy: 99.88%
Checkpoint saved at epoch 18 to /content/drive/My Drive/Share_Weights/test1_checkpoint.pth
Epoch [19/20], Step [0/938], Loss: 0.0027
Epoch [19/20], Step [100/938], Loss: 0.0015
Epoch [19/20], Step [200/938], Loss: 0.0022
Epoch [19/20], Step [300/938], Loss: 0.0099
Epoch [19/20], Step [400/938], Loss: 0.0032
Epoch [19/20], Step [500/938], Loss: 0.0008
Epoch [19/20], Step [600/938], Loss: 0.0004
Epoch [19/20], Step [700/938], Loss: 0.0006
Epoch [19/20], Step [80

KeyboardInterrupt: 