In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Define a simple fully connected network (Linear -> ReLU -> Linear); no convexity constraint is enforced on the weights
class SimpleConvexNet(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    No constraint is placed on the weights, so despite the name the mapping
    is not guaranteed to be convex in its input.

    Args:
        in_features: Number of values per sample once flattened
            (default 784, i.e. 28 * 28).
        hidden_features: Width of the hidden layer (default 128).
        out_features: Number of output values per sample (default 1).
    """

    def __init__(self, in_features: int = 784, hidden_features: int = 128,
                 out_features: int = 1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)  # Fully connected layer
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, out_features)  # Output layer

    def forward(self, x):
        """Flatten each sample and run it through fc1 -> ReLU -> fc2.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened and must total ``in_features``.

        Returns:
            Tensor of shape ``(batch, out_features)``.
        """
        # reshape() behaves like view() for contiguous inputs but falls back
        # to a copy for non-contiguous ones (e.g. transposed batches), where
        # view() would raise a RuntimeError.
        x = x.reshape(x.size(0), -1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

# Initialize network, optimizer, and loss function
model = SimpleConvexNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# MSE against the integer digit label: the single output is trained as a
# scalar regression target rather than as class logits.
criterion = nn.MSELoss()

# Prepare data: MNIST training split; ToTensor yields [0, 1] and
# Normalize((0.5,), (0.5,)) rescales it to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Initialize TensorBoard writer
writer = SummaryWriter('runs/convex_net_example')

# Training loop with TensorBoard weight visualization
num_epochs = 5
log_interval = 100  # Log every `log_interval` batches
batches_per_epoch = len(train_loader)

try:
    for epoch in range(num_epochs):
        for batch_idx, (data, targets) in enumerate(train_loader):
            # Forward pass; targets are reshaped to (batch, 1) to match the output
            outputs = model(data)
            loss = criterion(outputs, targets.float().unsqueeze(1))

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Periodically log weights, gradients, and loss to TensorBoard
            if batch_idx % log_interval == 0:
                # One monotonically increasing step for every logged series.
                # Using `epoch` here would give all histograms logged within
                # the same epoch an identical step value.
                global_step = epoch * batches_per_epoch + batch_idx
                for name, param in model.named_parameters():
                    writer.add_histogram(f'{name}.weights', param, global_step)
                    # Parameters that took no part in the backward pass have no grad
                    if param.grad is not None:
                        writer.add_histogram(f'{name}.gradients', param.grad, global_step)

                # Log the loss
                writer.add_scalar('Training Loss', loss.item(), global_step)
finally:
    # Close the writer even if training is interrupted so pending events are flushed
    writer.close()


: 