In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Setting random seed for reproducibility
torch.manual_seed(42)

# 1. Defining a simple dataset
X = torch.randn(100, 10)  # 100 samples, 10 features
y = torch.randint(0, 2, (100,))  # Binary classification

dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# 2. Simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(10, 2)  # 10 input features, 2 output classes

    def forward(self, x):
        return self.fc(x)

model = SimpleNet()

# 3. Loss function and optimizer for model
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 4. Training loop with checkpointing
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for inputs, targets in dataloader:
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)

    # Save checkpoint after each epoch (Checkpoint can also be saved after a number of training step))
    checkpoint_path = f'checkpoint_epoch{epoch}.pth'
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
    }, checkpoint_path)

    print(f"[Epoch {epoch}] Avg Loss: {avg_loss:.4f} | Checkpoint saved to {checkpoint_path}")

[Epoch 0] Avg Loss: 0.7209 | Checkpoint saved to checkpoint_epoch0.pth
[Epoch 1] Avg Loss: 0.7538 | Checkpoint saved to checkpoint_epoch1.pth
[Epoch 2] Avg Loss: 0.7330 | Checkpoint saved to checkpoint_epoch2.pth
[Epoch 3] Avg Loss: 0.7613 | Checkpoint saved to checkpoint_epoch3.pth
[Epoch 4] Avg Loss: 0.7328 | Checkpoint saved to checkpoint_epoch4.pth


In [22]:
#Loading a checkpoint
checkpoint = torch.load('checkpoint_epoch3.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
loss = checkpoint['loss']

print(f"Resumed from epoch {start_epoch}, previous loss: {loss}")

Resumed from epoch 4, previous loss: 0.7612517731530326


In [23]:
import os

# Delete checkpoint

if os.path.exists(checkpoint_path):
    os.remove(checkpoint_path)
    print(f"Checkpoint {checkpoint_path} deleted.")
else:
    print("Checkpoint file not found.")

Checkpoint checkpoint_epoch4.pth deleted.
