In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np

In [None]:
class SignalCountClassifier(nn.Module):
    """Fully connected classifier that predicts how many signals a spectrum holds.

    Architecture: input_dim -> 512 -> ReLU -> 256 -> ReLU -> num_classes.
    The default ``num_classes=3`` corresponds to signal counts 2, 3 and 4.

    ``forward`` returns unnormalized logits; pair it with ``nn.CrossEntropyLoss``
    (which applies log-softmax internally) rather than adding a softmax here.
    """

    def __init__(self, input_dim=1024, num_classes=3):
        super().__init__()
        hidden_widths = (512, 256)
        stack = []
        fan_in = input_dim
        for width in hidden_widths:
            stack.extend([nn.Linear(fan_in, width), nn.ReLU()])
            fan_in = width
        # Final projection to class logits (no activation).
        stack.append(nn.Linear(fan_in, num_classes))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        logits = self.model(x)
        return logits


In [None]:
class SignalCountDataset(Dataset):
    """Dataset of spectra labelled with the number of signals they contain.

    Each item is a ``(spectrum, label)`` pair of tensors:

    * ``spectrum`` — the first ``input_dim`` bins of ``fft_data[idx]``
      (zero-padded if the spectrum is shorter), compressed with ``log1p`` and
      standardized to zero mean / unit variance per sample. Returned as float32.
    * ``label`` — ``signal_counts[idx] - min_count`` as int64, so with the
      defaults the counts 2, 3, 4 map to class indices 0, 1, 2.

    Args:
        fft_data: indexable collection of 1-D spectra, aligned with
            ``signal_counts``. Presumably non-negative FFT magnitudes — note that
            ``log1p`` yields NaN for values below -1. TODO confirm with data source.
        signal_counts: indexable collection of integer signal counts.
        input_dim: number of spectrum bins fed to the model.
        min_count: the signal count that maps to class 0.
        num_classes: number of classes; every count must lie in
            ``[min_count, min_count + num_classes)``.

    Raises:
        ValueError: if the two inputs differ in length or a count falls outside
            the valid class range (which would otherwise surface later as an
            opaque out-of-range target error in ``CrossEntropyLoss``).
    """

    def __init__(self, fft_data, signal_counts, input_dim=1024, min_count=2, num_classes=3):
        if len(fft_data) != len(signal_counts):
            raise ValueError(
                f"fft_data has {len(fft_data)} items but signal_counts has {len(signal_counts)}"
            )
        counts = np.asarray(signal_counts)
        max_count = min_count + num_classes - 1
        if counts.size and (counts.min() < min_count or counts.max() > max_count):
            raise ValueError(
                f"signal counts must be in [{min_count}, {max_count}], "
                f"got range [{counts.min()}, {counts.max()}]"
            )
        self.fft_data = fft_data
        self.signal_counts = signal_counts
        self.input_dim = input_dim
        self.min_count = min_count
        self.num_classes = num_classes

    def __len__(self):
        return len(self.fft_data)

    def __getitem__(self, idx):
        spectrum = np.asarray(self.fft_data[idx])[:self.input_dim]
        if len(spectrum) < self.input_dim:
            # Zero-pad (no energy) so every sample has the fixed length the
            # model's first Linear layer and default batch collation require.
            spectrum = np.pad(spectrum, (0, self.input_dim - len(spectrum)))
        spectrum = np.log1p(spectrum)
        # Per-sample standardization; epsilon guards against a constant spectrum.
        spectrum = (spectrum - np.mean(spectrum)) / (np.std(spectrum) + 1e-6)
        label = int(self.signal_counts[idx]) - self.min_count
        return torch.tensor(spectrum, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

In [None]:
def train_model(model, dataloader, optimizer, epochs=5):
    """Train ``model`` with cross-entropy loss and print per-epoch metrics.

    Args:
        model: classifier whose forward pass returns raw logits of shape
            (batch, num_classes) — the input expected by ``nn.CrossEntropyLoss``.
        dataloader: iterable of ``(x, y)`` batches, ``y`` holding integer class
            indices.
        optimizer: optimizer already bound to ``model``'s parameters.
        epochs: number of full passes over ``dataloader``.

    Returns:
        List of ``(mean_loss, accuracy)`` tuples, one per epoch. Both are
        per-sample averages over the epoch; ``nan`` when an epoch saw no samples.
    """
    criterion = nn.CrossEntropyLoss()
    model.train()
    history = []
    for epoch in range(epochs):
        total_loss, correct, total = 0.0, 0, 0
        for x, y in dataloader:
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
            batch_size = y.size(0)
            # The criterion returns the batch *mean*; weight it by batch size so a
            # short final batch does not skew the epoch average (keeps the loss
            # consistent with the per-sample accuracy below).
            total_loss += loss.item() * batch_size
            correct += (output.argmax(dim=1) == y).sum().item()
            total += batch_size
        epoch_loss = total_loss / total if total else float("nan")
        epoch_acc = correct / total if total else float("nan")
        print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}")
        history.append((epoch_loss, epoch_acc))
    return history


In [None]:
if __name__ == "__main__":
    # Seed every RNG in play (data simulation, weight init, DataLoader shuffling)
    # so Restart Kernel -> Run All reproduces the same numbers.
    SEED = 0
    rng = np.random.default_rng(SEED)
    torch.manual_seed(SEED)

    # Simulate data: N random non-negative "magnitude spectra" with random labels.
    N = 50
    fft_data = [np.abs(rng.standard_normal(2048)) for _ in range(N)]  # reduced FFT size
    signal_counts = rng.choice([2, 3, 4], size=N)

    # Dataloader built from the simulated spectra above. `fft_magnitude_shifted`
    # is not defined anywhere in this notebook (NameError on a fresh kernel);
    # to train on real spectra, define them explicitly in an earlier cell and
    # pass them here instead.
    dataset = SignalCountDataset(fft_data, signal_counts)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Model
    model = SignalCountClassifier()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Train
    train_model(model, loader, optimizer)