In [7]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim

# Reproducibility: seed torch's global RNG before anything random happens, so the
# weight initialisation of the model and the dummy data generated below are the
# same on every Restart & Run All.
# (If numpy or Python's `random` module are used as well, seed them too:
#  np.random.seed(SEED) and random.seed(SEED).)
SEED = 0
torch.manual_seed(SEED)

# 1. Define a simple model
class Net(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Linear.

    With the default ``output_dim=1`` the network emits a single raw logit per
    sample, which is what ``nn.BCEWithLogitsLoss`` expects for binary
    classification. No sigmoid is applied here; the loss applies it internally.

    Args:
        input_dim: Number of input features per sample.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output logits per sample. Defaults to 1, so
            existing calls ``Net(input_dim, hidden_dim)`` behave as before.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int = 1):
        super().__init__()
        # Attribute names are part of the state_dict keys ("linear1.weight", ...),
        # so they must stay stable for saved weights/checkpoints to reload.
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map inputs of shape (..., input_dim) to logits of shape (..., output_dim)."""
        x = self.linear1(x)
        x = self.relu(x)
        return self.linear2(x)
    
# Model / data configuration. Defining the sizes once guarantees that the model
# trained here and the one reloaded below share the same architecture.
INPUT_DIM = 20
HIDDEN_DIM = 10
N_SAMPLES = 100

# 1b. Instantiate the model
model = Net(input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM)

# 2. Define loss and optimizer
criterion = nn.BCEWithLogitsLoss()  # binary classification on raw logits (sigmoid applied inside the loss)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# 3. Prepare dummy data
# Simulate N_SAMPLES samples with INPUT_DIM random features each
X = torch.randn(N_SAMPLES, INPUT_DIM)
# Random 0/1 labels, shaped (N_SAMPLES, 1) and cast to float so they match the
# model output shape and the float targets BCEWithLogitsLoss expects
y = torch.randint(0, 2, (N_SAMPLES, 1)).float()

# 4. Training loop (full batch: the whole dataset is used in every epoch)
model.train()
num_epochs = 5
for epoch in range(1, num_epochs + 1):
    # Forward pass
    outputs = model(X)
    loss = criterion(outputs, y)
    # Backward pass and parameter update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Logging: the reported loss was computed before this epoch's update
    print(f"Epoch [{epoch} / {num_epochs}], Loss: {loss.item():.4f}")

# 5. Save the model weights. Create the target directory first; otherwise
# torch.save fails on a fresh checkout where ../../_model does not exist yet.
MODEL_DIR = Path("../../_model")
MODEL_DIR.mkdir(parents=True, exist_ok=True)
WEIGHTS_PATH = MODEL_DIR / "model_weights.pth"
torch.save(model.state_dict(), WEIGHTS_PATH)
print("Model Weights saved")

# (later or in deployment)
# Load the weights into a fresh instance with the same architecture.
# weights_only=True restricts unpickling to tensors and plain containers, which is
# all a state_dict contains, and avoids executing arbitrary code from the file.
model2 = Net(input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM)
model2.load_state_dict(torch.load(WEIGHTS_PATH, weights_only=True))
model2.eval()

Epoch [1 / 5],
          Loss: 0.7115
Epoch [2 / 5],
          Loss: 0.7091
Epoch [3 / 5],
          Loss: 0.7070
Epoch [4 / 5],
          Loss: 0.7050
Epoch [5 / 5],
          Loss: 0.7031
Model Weights saved


Net(
  (linear1): Linear(in_features=20, out_features=10, bias=True)
  (relu): ReLU()
  (linear2): Linear(in_features=10, out_features=1, bias=True)
)

In [8]:
# Save a resumable training checkpoint. Besides the model weights, it stores the
# optimizer state and some bookkeeping, so training can continue where it stopped.
# To resume:  ckpt = torch.load(path)
#             model.load_state_dict(ckpt['model_state'])
#             optimizer.load_state_dict(ckpt['optimizer_state'])
checkpoint_path = "../../_model/checkpoint.pth"
checkpoint = {
    'epoch': num_epochs,                        # epochs completed by the loop above
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'loss': loss.item(),                        # final epoch's loss, computed before its optimizer.step()
}
torch.save(checkpoint, checkpoint_path)
