In [35]:
import torch
import torch.nn as nn
import torch.optim as optim
from google.colab import drive

In [36]:
# Mount Google Drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [37]:
# Define a simple feedforward neural network (MLP)
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)  # Example input size for MNIST images
        self.fc2 = nn.Linear(128, 10)     # Example output for 10 classes (MNIST)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [44]:
model = SimpleModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Dummy data loader (ensure labels are correctly shaped)
train_loader = torch.utils.data.DataLoader(
    [(torch.randn(28*28), torch.randint(0, 10, ())) for _ in range(100)],  # torch.randint(0, 10, ()) generates a scalar label
    batch_size=32, shuffle=True
)

In [45]:
# Training function (simplified)
def train(model, optimizer, criterion, train_loader, epochs=2):
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs = inputs.view(inputs.size(0), -1)  # Flattening inputs
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_loader)}")

        # Save checkpoint after each epoch
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': running_loss / len(train_loader)
        }
        checkpoint_path = "/content/drive/My Drive/checkpoint.pth"
        save_checkpoint(checkpoint, checkpoint_path)


In [46]:
# Save checkpoint function
def save_checkpoint(checkpoint, checkpoint_path):
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}")

In [47]:
# Load checkpoint function
def load_checkpoint(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded from {checkpoint_path}")
    print(f"Resumed from epoch {epoch}, loss: {loss}")
    return model, optimizer, epoch

In [48]:
train(model, optimizer, criterion, train_loader)

Epoch 1, Loss: 2.2810311913490295
Checkpoint saved to /content/drive/My Drive/checkpoint.pth
Epoch 2, Loss: 1.619733989238739
Checkpoint saved to /content/drive/My Drive/checkpoint.pth


In [49]:
model, optimizer, epoch = load_checkpoint('/content/drive/My Drive/checkpoint.pth')

Checkpoint loaded from /content/drive/My Drive/checkpoint.pth
Resumed from epoch 2, loss: 1.619733989238739


  checkpoint = torch.load(checkpoint_path)
