In [3]:
# 1. Setup und Imports
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt
import wandb
import numpy as np

In [4]:

# 2. Daten generieren (gleiche Spiral-Funktion wie im Original)
def spiral(phi):
    x = (phi + 1) * torch.cos(phi)
    y = phi * torch.sin(phi)
    return torch.cat((x, y), dim=1)

def generate_data(num_data):
    angles = torch.empty((num_data, 1)).uniform_(1, 15)
    data = spiral(angles)
    data += torch.empty((num_data, 2)).normal_(0.0, 0.4)
    labels = torch.zeros((num_data,), dtype=torch.int)
    data[num_data // 2 :, :] *= -1
    labels[num_data // 2 :] = 1
    return data, labels

In [5]:

# 3. Model Definition
class SpiralNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        layers = []
        layer_sizes = [2] + [config.hidden_size] * config.num_layers + [1]
        
        for i in range(len(layer_sizes)-1):
            layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))
            if i < len(layer_sizes)-2:
                if config.activation == "relu":
                    layers.append(nn.ReLU())
                elif config.activation == "tanh":
                    layers.append(nn.Tanh())
                
        self.network = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.network(x)


In [6]:
# 4. Training Function
def train(config=None):
    with wandb.init(config=config):
        config = wandb.config
        
        # Get data
        x_train, y_train = generate_data(4000)
        x_val, y_val = generate_data(1000)
        train_dataset = TensorDataset(x_train, y_train)
        
        # Model & training setup
        model = SpiralNet(config)
        criterion = nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
        
        # Training loop
        for epoch in range(config.epochs):
            train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
            
            for x_batch, y_batch in train_loader:
                optimizer.zero_grad()
                output = model(x_batch).squeeze()
                loss = criterion(output, y_batch.float())
                loss.backward()
                optimizer.step()
            
            # Validation
            with torch.no_grad():
                val_output = model(x_val).squeeze()
                val_loss = criterion(val_output, y_val.float())
                val_preds = (torch.sigmoid(val_output) > 0.5).float()
                accuracy = (val_preds == y_val.float()).mean()
            
            wandb.log({
                "epoch": epoch,
                "train_loss": loss.item(),
                "val_loss": val_loss.item(),
                "val_accuracy": accuracy.item()
            })


In [7]:
# 5. Sweep Configuration
sweep_config = {
    'method': 'random',
    'metric': {'name': 'val_accuracy', 'goal': 'maximize'},
    'parameters': {
        'learning_rate': {'min': 1e-4, 'max': 1e-2},
        'batch_size': {'values': [16, 32, 64, 128]},
        'hidden_size': {'values': [32, 64, 128]},
        'num_layers': {'values': [2, 3, 4]},
        'activation': {'values': ['relu', 'tanh']},
        'epochs': {'value': 20}
    }
}

In [8]:

# 6. Run Sweep
wandb.login()
sweep_id = wandb.sweep(sweep_config, project="spiral-classification")
wandb.agent(sweep_id, train, count=20)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\wusch\_netrc


Create sweep with ID: r9t9p5i8
Sweep URL: https://wandb.ai/wuschelschulz8/spiral-classification/sweeps/r9t9p5i8


wandb: Agent Starting Run: db63tn1b with config:
wandb: 	activation: tanh
wandb: 	batch_size: 64
wandb: 	epochs: 20
wandb: 	hidden_size: 64
wandb: 	learning_rate: 0.009356238634557546
wandb: 	num_layers: 2
wandb: Currently logged in as: wuschelschulz8. Use `wandb login --relogin` to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888925108, max=1.0…

Traceback (most recent call last):
  File "C:\Users\wusch\AppData\Local\Temp\ipykernel_22976\2663321464.py", line 32, in train
    accuracy = (val_preds == y_val.float()).mean()
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Bool


Run db63tn1b errored:
Traceback (most recent call last):
  File "c:\Users\wusch\working_directory\AI-verstehen-Winteraka-2024-25\venv\Lib\site-packages\wandb\agents\pyagent.py", line 306, in _run_job
    self._function()
  File "C:\Users\wusch\AppData\Local\Temp\ipykernel_22976\2663321464.py", line 32, in train
    accuracy = (val_preds == y_val.float()).mean()
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Bool

wandb: ERROR Run db63tn1b errored:
wandb: ERROR Traceback (most recent call last):
wandb: ERROR   File "c:\Users\wusch\working_directory\AI-verstehen-Winteraka-2024-25\venv\Lib\site-packages\wandb\agents\pyagent.py", line 306, in _run_job
wandb: ERROR     self._function()
wandb: ERROR   File "C:\Users\wusch\AppData\Local\Temp\ipykernel_22976\2663321464.py", line 32, in train
wandb: ERROR     accuracy = (val_preds == y_val.float()).mean()
wandb: ERROR      

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888925108, max=1.0…

Traceback (most recent call last):
  File "C:\Users\wusch\AppData\Local\Temp\ipykernel_22976\2663321464.py", line 32, in train
    accuracy = (val_preds == y_val.float()).mean()
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Bool


Run cdpo9lrg errored:
Traceback (most recent call last):
  File "c:\Users\wusch\working_directory\AI-verstehen-Winteraka-2024-25\venv\Lib\site-packages\wandb\agents\pyagent.py", line 306, in _run_job
    self._function()
  File "C:\Users\wusch\AppData\Local\Temp\ipykernel_22976\2663321464.py", line 32, in train
    accuracy = (val_preds == y_val.float()).mean()
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Bool

wandb: ERROR Run cdpo9lrg errored:
wandb: ERROR Traceback (most recent call last):
wandb: ERROR   File "c:\Users\wusch\working_directory\AI-verstehen-Winteraka-2024-25\venv\Lib\site-packages\wandb\agents\pyagent.py", line 306, in _run_job
wandb: ERROR     self._function()
wandb: ERROR   File "C:\Users\wusch\AppData\Local\Temp\ipykernel_22976\2663321464.py", line 32, in train
wandb: ERROR     accuracy = (val_preds == y_val.float()).mean()
wandb: ERROR      

Traceback (most recent call last):
  File "C:\Users\wusch\AppData\Local\Temp\ipykernel_22976\2663321464.py", line 32, in train
    accuracy = (val_preds == y_val.float()).mean()
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Bool


Run om9tw1et errored:
Traceback (most recent call last):
  File "c:\Users\wusch\working_directory\AI-verstehen-Winteraka-2024-25\venv\Lib\site-packages\wandb\agents\pyagent.py", line 306, in _run_job
    self._function()
  File "C:\Users\wusch\AppData\Local\Temp\ipykernel_22976\2663321464.py", line 32, in train
    accuracy = (val_preds == y_val.float()).mean()
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Bool

wandb: ERROR Run om9tw1et errored:
wandb: ERROR Traceback (most recent call last):
wandb: ERROR   File "c:\Users\wusch\working_directory\AI-verstehen-Winteraka-2024-25\venv\Lib\site-packages\wandb\agents\pyagent.py", line 306, in _run_job
wandb: ERROR     self._function()
wandb: ERROR   File "C:\Users\wusch\AppData\Local\Temp\ipykernel_22976\2663321464.py", line 32, in train
wandb: ERROR     accuracy = (val_preds == y_val.float()).mean()
wandb: ERROR      