In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch_support.nn_utils import training_loop
from torch_support.tensor_utils import xy_to_tensordataset

from tqdm import tqdm

In [None]:


# Example model: a small two-layer perceptron (Linear -> ReLU -> Linear) for classification.
class SimpleNet(nn.Module):
    """Two-layer feed-forward classifier that returns unnormalized logits.

    No softmax is applied in ``forward``; the logits are meant to be fed to a
    loss such as ``nn.CrossEntropyLoss`` or to ``argmax`` for prediction.

    Args:
        input_dim: number of input features per sample.
        hidden_dim: width of the hidden layer.
        output_dim: number of output logits (e.g. number of classes).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Keep registration order: layer1 before layer2 fixes both the
        # state_dict key order and the order in which weights are randomly initialized.
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map a batch ``x`` of shape [batch_size, input_dim] to logits of shape [batch_size, output_dim]."""
        hidden = self.relu(self.layer1(x))
        return self.layer2(hidden)

# Example usage of training_loop: train SimpleNet on random synthetic data, then
# run a single inference pass. In a Jupyter kernel __name__ is "__main__", so this runs.
if __name__ == "__main__":
    # Demo configuration: one place for every tunable value, so the data
    # dimensions and the model dimensions cannot drift apart.
    SEED = 42
    N_SAMPLES = 10_000
    N_FEATURES = 20
    N_CLASSES = 3
    HIDDEN_DIM = 50
    EPOCHS = 5
    LEARNING_RATE = 1e-3
    VAL_RATIO = 0.2

    # Seed numpy (synthetic data) AND torch (weight init, test input) so the
    # whole demo is reproducible.
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # 1. Create synthetic data: N_SAMPLES samples with N_FEATURES features each,
    #    labelled uniformly at random into N_CLASSES classes. The labels are drawn
    #    independently of X, so there is no real signal to learn (demo only).
    X = np.random.randn(N_SAMPLES, N_FEATURES).astype(np.float32)
    # int64 class indices are what nn.CrossEntropyLoss expects as targets.
    y = np.random.randint(0, N_CLASSES, size=(N_SAMPLES,)).astype(np.int64)

    # 2. Split into train/validation DataLoaders.
    train_loader, val_loader = xy_to_tensordataset(
        X, y, val_ratio=VAL_RATIO, return_loader=True
    )

    # 3. Initialize the model, loss, and optimizer. The model is moved to the
    #    target device BEFORE the optimizer is constructed, as the PyTorch docs
    #    recommend, and so that the device-resident test input below matches it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleNet(
        input_dim=N_FEATURES, hidden_dim=HIDDEN_DIM, output_dim=N_CLASSES
    ).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # 4. Train the model (training_loop comes from torch_support.nn_utils).
    trained_model = training_loop(
        model=model,
        device=device,
        train_loader=train_loader,
        optimizer=optimizer,
        criterion=criterion,
        epochs=EPOCHS,
        val_loader=val_loader,  # used to report validation loss
    )

    # 5. Inference on one random sample. Put the model that is actually used
    #    for inference into eval mode; NOTE(review): training_loop may return a
    #    different object than `model` (e.g. a best checkpoint) — not visible here.
    trained_model.eval()
    test_input = torch.randn(1, N_FEATURES, device=device)
    with torch.no_grad():
        logits = trained_model(test_input)
        predicted_label = torch.argmax(logits, dim=1)
    print("Example inference on a single test sample:", predicted_label.item())

