### Initialization

In [None]:
### Initialization block
from pathlib import Path
import numpy as np
import json
import torch
import numpy as np
from tqdm import tqdm
import math
from torch.utils.data import DataLoader, TensorDataset

# Number of leading samples of every recording that are kept and transformed.
STFT_LENGTH = 16 * 1024
# Dataset root; the loaders below read <DATA_DIR>/<snr>/<signal id>/.
DATA_DIR = Path("dataset/")
# Sampling rate of the recordings (presumably in Hz -- confirm against the dataset).
SAMPLE_RATE = 20e6
# Known modulation names; the position in this list is the integer class label.
MODULATIONS = ["QPSK", "BPSK", "8-PSK", "8-QAM", "16-QAM", "GMSK", "2-FSK"]
MODULATION_LABELS = dict(zip(MODULATIONS, range(len(MODULATIONS))))
NUMBER_OF_MODULATIONS = len(MODULATIONS)

def load_data(snr, name, load_metadata_only=False):
    """Load one recording and its metadata from ``DATA_DIR/<snr>/<name>/``.

    Args:
        snr: Name of the SNR folder (converted with ``str``).
        name: Name of the signal folder inside the SNR folder (converted with ``str``).
        load_metadata_only: When true, ``data.dat`` is not read and the returned
            signal is ``None``.

    Returns:
        ``(signal, meta)``: ``signal`` is a 1-D ``complex128`` array read from
        ``data.dat`` (or ``None``); ``meta`` is always a list of metadata
        dicts, even when the JSON file holds a single object.
    """
    sample_dir = DATA_DIR / str(snr) / str(name)

    signal = None
    if not load_metadata_only:
        # np.fromfile accepts a path directly; no separate file handle is needed.
        signal = np.fromfile(sample_dir / "data.dat", dtype=np.complex128)

    with open(sample_dir / "meta-data.json", encoding="utf-8") as f:
        meta = json.load(f)
    # A file describing a single emitter stores a bare object; normalise to a list.
    if isinstance(meta, dict):
        meta = [meta]
    return signal, meta

    
def _get_all_numbered_dirs(root_dir):
    """Return the integer names of the sub-directories of ``root_dir``, sorted ascending.

    Entries that are not directories, or whose name is not an integer (for
    example ``.ipynb_checkpoints`` or ``.DS_Store``), are ignored instead of
    aborting the whole listing with a ``ValueError``.

    Args:
        root_dir: ``pathlib.Path`` of the directory to scan.

    Returns:
        Sorted list of ``int``.
    """
    numbers = []
    for entry in root_dir.iterdir():
        if not entry.is_dir():
            continue
        try:
            numbers.append(int(entry.name))
        except ValueError:
            # Not a numbered folder -- skip it.
            continue
    return sorted(numbers)

def get_signals(snr):
    """Return the sorted integer signal ids found under ``DATA_DIR/<snr>``."""
    snr_dir = Path(DATA_DIR) / str(snr)
    return _get_all_numbered_dirs(snr_dir)


def get_snrs(root_dir=DATA_DIR):
    """Return the sorted integer SNR folder names found directly under ``root_dir``."""
    snr_values = _get_all_numbered_dirs(root_dir)
    return snr_values
        
        
def process_metadata(metadata):
    """Convert raw metadata entries into ``position``/``mod`` records.

    Each entry must provide the keys ``fc``, ``bw`` and ``mod`` (presumably
    centre frequency, bandwidth and modulation name -- confirm against the
    dataset). ``fc`` is shifted by half the sample rate.

    Returns:
        List of dicts ``{"position": (fc + SAMPLE_RATE / 2, bw), "mod": mod}``,
        in the order of ``metadata``.
    """
    half_rate = SAMPLE_RATE / 2
    scaled_metadata = []
    for entry in metadata:
        shifted_centre = half_rate + entry['fc']
        scaled_metadata.append({
            "position": (shifted_centre, entry['bw']),
            "mod": entry["mod"],
        })
    return scaled_metadata


def process_signal(signal):
    """Return the peak-normalised, centred spectrum of the start of ``signal``.

    Only the first ``STFT_LENGTH`` samples are used. They are transformed with
    an FFT, the zero-frequency bin is moved to the middle (``fftshift``) and
    the result is divided by its largest magnitude.

    An all-zero input has a peak of 0; in that case the spectrum is returned
    unscaled (all zeros) instead of being turned into NaNs by a 0/0 division.

    Args:
        signal: 1-D NumPy array of samples.

    Returns:
        Complex NumPy array with at most ``STFT_LENGTH`` bins.
    """
    spectrum = np.fft.fftshift(np.fft.fft(signal[:STFT_LENGTH]))

    peak = np.max(np.abs(spectrum))
    if peak > 0:
        spectrum /= peak
    return spectrum

### Data Loading

In [None]:
# One mask bin per retained FFT bin.
MASK_SIZE = int(STFT_LENGTH)


class WidebandSignalDataset(torch.utils.data.Dataset):
    """In-memory dataset of normalised spectra and their occupancy masks.

    Every ``(snr, signal_id)`` pair is loaded and processed once in the
    constructor; ``__getitem__`` only indexes the cached results.

    Args:
        signal_ids: Sequence of ``(snr, signal_id)`` pairs.
        mask_size: Number of bins of the occupancy mask.
        return_snr: When true, items are ``(signal, masks, snr)`` instead of
            ``(signal, masks)``.
    """

    def __init__(self, signal_ids, mask_size=MASK_SIZE, return_snr=False):
        self.mask_size = mask_size
        self.signal_ids = signal_ids
        self.return_snr = return_snr
        # Eagerly process everything so that training never touches the disk.
        self.loaded_data = [
            self.process_signal(snr, signal_id)
            for snr, signal_id in tqdm(self.signal_ids)
        ]

    def __len__(self):
        return len(self.signal_ids)

    def __getitem__(self, index):
        signal, masks = self.loaded_data[index]
        if self.return_snr:
            snr, _ = self.signal_ids[index]
            return signal, masks, snr  # SNR is needed for per-SNR evaluation
        return signal, masks

    def process_signal(self, snr, signal_id):
        """Load one sample and build its spectrum and binary occupancy mask.

        Returns:
            ``(signal, masks)``: a ``complex64`` tensor with the processed
            spectrum and a float tensor of ``mask_size`` bins that is 1 inside
            every band listed in the metadata and 0 elsewhere.
        """
        signal, metadata = load_data(snr, signal_id)
        scaled_metadata = process_metadata(metadata)
        # Inside the method, `process_signal` resolves to the module-level function.
        signal = torch.from_numpy(process_signal(signal))

        masks = torch.zeros(self.mask_size)
        scale_ratio = self.mask_size / SAMPLE_RATE
        for meta in scaled_metadata:
            f, b = meta['position']
            x1 = math.floor((f - b / 2) * scale_ratio)
            x2 = math.ceil((f + b / 2) * scale_ratio)
            # Clamp to the mask: a negative bound would otherwise be read as a
            # "count from the end" slice index and mark the wrong bins.
            x1 = min(max(x1, 0), self.mask_size)
            x2 = min(max(x2, 0), self.mask_size)
            masks[x1:x2] = 1
        return signal.type(torch.complex64), masks.type(torch.FloatTensor)

# Train / validation / test split: 80 % / 10 % / 10 % of the signals of every SNR.
train, test, validation = [], [], []
for snr in get_snrs():
    signals = get_signals(snr)
    total_signals = len(signals)
    # Split on the 1-based position in the sorted id list rather than on the id
    # value itself: identical when ids are exactly 1..N, and still 80/10/10
    # when ids start at 0 or have gaps.
    for rank, signal in enumerate(signals, start=1):
        if rank <= 0.8 * total_signals:
            train.append((snr, signal))
        elif rank <= 0.9 * total_signals:
            validation.append((snr, signal))
        else:
            test.append((snr, signal))

print("Train", len(train))
print("Validation", len(validation))
print("Test", len(test))

# Each dataset loads and processes all of its samples up front.
train_dataset = WidebandSignalDataset(signal_ids=train)
validation_dataset = WidebandSignalDataset(signal_ids=validation)
test_dataset = WidebandSignalDataset(signal_ids=test)

### Batch Loading

In [None]:
batch_size = 64  # Samples per mini-batch for both loaders

# Only the training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset=validation_dataset, batch_size=batch_size, shuffle=False)

# The printed numbers are sample counts of the two datasets.
for split_name, split_dataset in (("Train", train_dataset), ("Validation", validation_dataset)):
    print(f"{split_name} labels shape:", len(split_dataset))

### Early Stop

In [None]:
import os

class EarlyStopping:
    """Stop training once the validation loss has stopped improving.

    Call the instance once per epoch with the validation loss and the model.
    An epoch counts as an improvement when the loss drops by at least
    ``delta`` below the best loss seen so far; the model weights are then
    snapshotted in ``best_model`` and written to
    ``<save_path>/best_model.pth``. After ``patience`` consecutive epochs
    without improvement ``early_stop`` becomes ``True``.

    Args:
        patience: Number of consecutive non-improving epochs tolerated.
        verbose: Print a message on every checkpoint / counter increment.
        delta: Minimum decrease of the loss that counts as an improvement.
        save_path: Directory for the checkpoint (created if missing).
    """

    def __init__(self, patience=10, verbose=False, delta=0.0001, save_path='./models/CMuSeNet'):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.best_model = None
        self.save_path = save_path
        os.makedirs(save_path, exist_ok=True)

    def __call__(self, val_loss, model):
        # Higher score == lower loss.
        score = -val_loss

        # Written as a positive ">=" test so that a NaN loss (every comparison
        # False) is treated as "no improvement" instead of being checkpointed.
        if self.best_score is None or score >= self.best_score + self.delta:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Snapshot the current weights in memory and on disk."""
        import copy

        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        self.val_loss_min = val_loss
        # state_dict() returns references to the live parameter tensors; without
        # a deep copy the "best" weights would silently follow later training
        # steps and restoring them would be a no-op.
        self.best_model = copy.deepcopy(model.state_dict())
        save_path = os.path.join(self.save_path, 'best_model.pth')
        torch.save(self.best_model, save_path)

### Reshape

In [None]:
import torch.nn as nn
import complexPyTorch.complexLayers as cplx
import torch.nn.functional as F
import torch

def reshape_to_2d(data, height=128, width=128):
    """Fold flat samples into single-channel images.

    Args:
        data: Tensor whose element count is a multiple of ``height * width``.
        height: Image height (default 128).
        width: Image width (default 128).

    Returns:
        Tensor of shape ``[batch, 1, height, width]``. ``reshape`` returns a
        view whenever the memory layout allows it and copies otherwise, so
        non-contiguous inputs are accepted as well.
    """
    return data.reshape(-1, 1, height, width)

### Complex IoU

In [None]:
def calculate_iou(pred, target, threshold=0.5):
    """Mean intersection-over-union between complex predictions and binary masks.

    A bin is predicted "occupied" when the real part OR the imaginary part of
    ``pred`` exceeds ``threshold``. The IoU is computed per row (``dim=1``)
    and averaged over the batch.

    A row whose prediction and target are both empty has a union of 0; it is
    scored 1.0 (exact agreement) instead of producing NaN from 0/0, which
    would otherwise turn the whole batch mean into NaN.

    Args:
        pred: Complex tensor, 2-D (rows are reduced along ``dim=1``).
        target: Tensor of the same shape holding the 0/1 mask.
        threshold: Decision threshold applied to both parts of ``pred``.

    Returns:
        The batch-mean IoU as a Python float.
    """
    predicted = ((pred.real > threshold) | (pred.imag > threshold)).float()

    intersection = (predicted * target).sum(dim=1)
    union = (predicted + target).sum(dim=1) - intersection

    has_area = union > 0
    # Substitute a harmless divisor where the union is empty; those rows are
    # overwritten with 1.0 right after.
    safe_union = torch.where(has_area, union, torch.ones_like(union))
    per_row = torch.where(has_area, intersection / safe_union, torch.ones_like(union))
    return per_row.mean().item()

### Training

In [None]:
import time

def validate_model(model, valid_loader, criterion):
    """Run one pass over the validation loader and report loss, accuracy and IoU.

    The model is switched to eval mode and gradients are disabled. Inputs are
    folded with ``reshape_to_2d`` and moved to the global ``device``.

    Accuracy is the bin-wise percentage of agreement between the mask and the
    prediction, where a bin is predicted positive only when BOTH the real and
    the imaginary output exceed 0.5 (``calculate_iou`` is called with the same
    0.5 threshold).

    The mean IoU, which used to be computed and then discarded, is now printed
    alongside the other two metrics. The return value is unchanged.

    Returns:
        ``(val_loss, accuracy)``: mean loss per batch and accuracy in percent.
    """
    model.eval()
    running_loss = 0.0
    iou_scores = []
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, masks in tqdm(valid_loader, desc="Validating"):
            inputs = reshape_to_2d(inputs).to(device)
            masks = masks.to(device)
            outputs = model(inputs)
            running_loss += criterion(outputs, masks).item()

            # Per-batch IoU
            iou_scores.append(calculate_iou(outputs, masks, threshold=0.5))

            # Bin-wise accuracy
            preds = ((outputs.real > 0.5) & (outputs.imag > 0.5)).float()
            total_correct += (preds == masks).float().sum().item()
            total_samples += masks.numel()

    val_loss = running_loss / len(valid_loader)
    mean_iou = sum(iou_scores) / len(iou_scores)
    accuracy = total_correct / total_samples * 100

    print(f'Validation Loss: {val_loss:.6f}')
    print(f'Validation Accuracy: {accuracy:.2f}%')
    print(f'Validation IoU: {mean_iou:.4f}')

    return val_loss, accuracy

def train_model(
    model,
    train_loader,
    valid_loader,
    criterion,
    initial_lr=0.001,
    lr_steps=(0.00001,),
    num_epochs=50,
    patience=3,
):
    """Train ``model`` through a sequence of learning rates with early stopping.

    For every learning rate in ``lr_steps`` a fresh Adam optimizer and a fresh
    ``EarlyStopping`` tracker are created and the model is trained for at most
    ``num_epochs`` epochs. When a step ends, the weights held by that step's
    tracker are loaded back into the model before the next step starts.

    Args:
        model: Network to train; expected to already live on the global ``device``.
        train_loader: Loader yielding ``(inputs, masks)`` training batches.
        valid_loader: Loader passed to ``validate_model`` after every epoch.
        criterion: Loss called as ``criterion(outputs, masks)``.
        initial_lr: Not used by the training loop; kept so existing calls keep
            working. The learning rates come from ``lr_steps`` only.
        lr_steps: Iterable of learning rates, applied in order.
        num_epochs: Maximum number of epochs per learning rate.
        patience: Early-stopping patience per learning rate.

    Returns:
        ``(model, train_losses, val_losses, val_accuracies, epoch_durations)``;
        the four lists hold one entry per executed epoch across all steps.
    """
    train_losses = []
    val_losses = []
    val_accuracies = []
    epoch_durations = []

    for lr in lr_steps:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        early_stopping = EarlyStopping(patience=patience, verbose=True, delta=0.001)
        print("Current learning rate: ", lr)
        for epoch in range(num_epochs):
            epoch_start_time = time.time()

            model.train()
            running_loss = 0.0
            for inputs, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
                inputs = reshape_to_2d(inputs).to(device)
                masks = masks.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, masks)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            epoch_loss = running_loss / len(train_loader)
            train_losses.append(epoch_loss)
            print(f"Training Loss: {epoch_loss:.6f}")

            val_loss, val_accuracy = validate_model(model, valid_loader, criterion)
            val_losses.append(val_loss)
            val_accuracies.append(val_accuracy)
            early_stopping(val_loss, model)

            # Record the duration before the early-stop check so the epoch that
            # triggers the stop is timed too and all four lists stay aligned.
            epoch_durations.append(time.time() - epoch_start_time)

            if early_stopping.early_stop:
                print("Early stopping triggered")
                break

        if early_stopping.best_model is not None:
            print(f"Loading best model from lr {lr}")
            model.load_state_dict(early_stopping.best_model)

    print("Training completed.")
    print("Epoch durations:", epoch_durations)
    return model, train_losses, val_losses, val_accuracies, epoch_durations

### ResNet-18

In [None]:
import torch
import torch.nn as nn
import complexPyTorch.complexLayers as cplx
from typing import Optional, Callable, Type, Union, List
import torch.nn.functional as F
from torch import Tensor

def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:
    """Build a bias-free complex convolution with a 3x3 kernel and padding 1."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return cplx.ComplexConv2d(in_planes, out_planes, **conv_options)

def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:
    """Build a bias-free complex convolution with a 1x1 kernel."""
    conv_options = dict(kernel_size=1, stride=stride, bias=False)
    return cplx.ComplexConv2d(in_planes, out_planes, **conv_options)

class BasicBlock(nn.Module):
    """Residual block made of two complex 3x3 convolutions.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input before it is added to
            the convolution branch (needed when the shapes differ).
        norm_layer: Callable building the normalisation layer from a channel
            count. Defaults to ``cplx.ComplexBatchNorm2d``. This argument used
            to be accepted but ignored; it is now honoured.
    """

    expansion = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = cplx.ComplexBatchNorm2d
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = cplx.ComplexReLU()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when the main branch changed resolution or width.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 complex convolution stack.

    The last convolution widens the output to ``planes * expansion`` channels;
    the stride is applied by the 3x3 convolution.

    Args:
        inplanes: Number of input channels.
        planes: Width of the two inner convolutions.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input before it is added to
            the convolution branch.
        norm_layer: Callable building the normalisation layer from a channel
            count. Defaults to ``cplx.ComplexBatchNorm2d``. This argument used
            to be accepted but ignored; it is now honoured.
    """

    expansion = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = cplx.ComplexBatchNorm2d
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = norm_layer(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = norm_layer(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = cplx.ComplexReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the shortcut when the main branch changed resolution or width.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class ComplexResNet(nn.Module):
    """ResNet-style network assembled from complexPyTorch layers.

    Layout: a 7x7 stride-2 stem convolution with normalisation and ReLU, four
    residual stages (64, 128, 256 and 512 planes; stages 2-4 use stride 2),
    global average pooling, one complex linear layer with ``num_classes``
    outputs and a complex sigmoid.

    Args:
        block: Residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: Number of blocks in each of the four stages.
        num_classes: Size of the output vector (defaults to ``STFT_LENGTH``).
        zero_init_residual: Accepted but not used by this implementation.
        groups: Stored on the instance only.
        width_per_group: Stored on the instance as ``base_width`` only.
        norm_layer: Factory for normalisation layers; defaults to
            ``cplx.ComplexBatchNorm2d``.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = STFT_LENGTH,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        self._norm_layer = cplx.ComplexBatchNorm2d if norm_layer is None else norm_layer

        # Running channel count, updated by _make_layer as stages are added.
        self.inplanes = 64
        self.dilation = 1
        self.groups = groups
        self.base_width = width_per_group

        # Stem: single-channel input -> 64 feature maps.
        self.conv1 = cplx.ComplexConv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = cplx.ComplexReLU()

        # Residual stages.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = cplx.ComplexLinear(512 * block.expansion, num_classes)
        self.sigmoid = cplx.ComplexSigmoid()

    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1) -> nn.Sequential:
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``."""
        norm_layer = self._norm_layer
        out_planes = planes * block.expansion

        # The shortcut needs a projection whenever the first block changes the
        # resolution or the channel count.
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                norm_layer(out_planes),
            )
        else:
            downsample = None

        stage = [block(self.inplanes, planes, stride, downsample, norm_layer)]
        self.inplanes = out_planes
        stage.extend(
            block(self.inplanes, planes, norm_layer=norm_layer)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*stage)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.relu(self.bn1(self.conv1(x)))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = torch.flatten(self.avgpool(x), 1)
        return self.sigmoid(self.fc(x))

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

def ComplexResNet18():
    """Build a ``ComplexResNet`` with two ``BasicBlock`` units in each of its four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ComplexResNet(BasicBlock, blocks_per_stage)

# Instantiate the network once and print its layer structure as a sanity check.
# The training cells further down create their own fresh instances.
model = ComplexResNet18()
print(model)


### Complex focal Loss

In [None]:
class ComplexFocalLoss(nn.Module):
    """Focal loss applied separately to the real and imaginary parts of the output.

    For each part the element-wise binary cross-entropy ``bce`` against the
    same targets is weighted as ``alpha * (1 - exp(-bce)) ** gamma * bce``.

    Args:
        alpha: Scalar weight of the loss.
        gamma: Focusing exponent.
        reduction: ``'mean'`` returns the average of the two per-part means,
            ``'sum'`` returns the sum of the two per-part sums, and any other
            value returns the element-wise sum of the two loss tensors.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def _focal_term(self, probabilities, targets):
        """Element-wise focal loss of one real-valued probability tensor."""
        bce = F.binary_cross_entropy(probabilities, targets, reduction='none')
        pt = torch.exp(-bce)
        return self.alpha * (1 - pt) ** self.gamma * bce

    def forward(self, inputs, targets):
        real_loss = self._focal_term(inputs.real, targets)
        imag_loss = self._focal_term(inputs.imag, targets)

        if self.reduction == 'mean':
            return (torch.mean(real_loss) + torch.mean(imag_loss)) / 2
        if self.reduction == 'sum':
            return torch.sum(real_loss) + torch.sum(imag_loss)
        return real_loss + imag_loss

# Update the IoU calculation to handle complex values
def calculate_iou(pred, target, threshold=0.5):
    """Mean intersection-over-union between complex predictions and binary masks.

    A bin is predicted "occupied" when the real part OR the imaginary part of
    ``pred`` exceeds ``threshold``. The IoU is computed per row (``dim=1``)
    and averaged over the batch.

    A row whose prediction and target are both empty has a union of 0; it is
    scored 1.0 (exact agreement) instead of producing NaN from 0/0, which
    would otherwise turn the whole batch mean into NaN.

    Args:
        pred: Complex tensor, 2-D (rows are reduced along ``dim=1``).
        target: Tensor of the same shape holding the 0/1 mask.
        threshold: Decision threshold applied to both parts of ``pred``.

    Returns:
        The batch-mean IoU as a Python float.
    """
    predicted = ((pred.real > threshold) | (pred.imag > threshold)).float()

    intersection = (predicted * target).sum(dim=1)
    union = (predicted + target).sum(dim=1) - intersection

    has_area = union > 0
    # Substitute a harmless divisor where the union is empty; those rows are
    # overwritten with 1.0 right after.
    safe_union = torch.where(has_area, union, torch.ones_like(union))
    per_row = torch.where(has_area, intersection / safe_union, torch.ones_like(union))
    return per_row.mean().item()

### Training with complex focal loss

In [None]:
# Initialize and train the CResNet-18 model
# `device` is read here and, as a global, by train_model/validate_model. It is
# defined in this cell so the first training run does not depend on a later
# cell having been executed; the CPU is used when no GPU is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ComplexResNet18().to(device)
criterion = ComplexFocalLoss()

# Train the model and validate it.
# Candidate learning-rate schedule: 0.001, 0.0001, 0.00001, 0.000001
model, train_losses, val_losses, val_accuracies, epoch_durations = train_model(
    model, train_loader, valid_loader, criterion,
    initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3
)
combined_epoch_time = sum(epoch_durations)
print(f"Total time spent in epochs: {combined_epoch_time:.2f} seconds.")

### CVNN RV-BCE and CV-BCE Loss function implementation

In [None]:
# RV BCE Loss Function Definition
class RealValuedBCELoss(nn.Module):
    """Binary cross-entropy computed on the real part of a complex prediction.

    The imaginary part of the input is ignored.

    Args:
        reduction: Passed unchanged to ``F.binary_cross_entropy``.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, inputs, targets):
        return F.binary_cross_entropy(inputs.real, targets, reduction=self.reduction)

    
# CV BCE Loss Function Definition
class ComplexValuedBCELoss(nn.Module):
    """Average of the binary cross-entropies of the real and imaginary parts.

    Both parts of the complex prediction are compared against the same
    targets and the two resulting losses are weighted equally.

    Args:
        reduction: Passed unchanged to ``F.binary_cross_entropy`` for both parts.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, inputs, targets):
        real_loss, imag_loss = (
            F.binary_cross_entropy(part, targets, reduction=self.reduction)
            for part in (inputs.real, inputs.imag)
        )
        # Equal weighting of both parts.
        return (real_loss + imag_loss) / 2

### RV-BCE Training

In [None]:
# Set the criterion for RV BCE
criterion = RealValuedBCELoss()

# Train the ResNet-18 model with RV BCE.
# Use the GPU when one is present and fall back to the CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ComplexResNet18().to(device)
# NOTE: train_model builds its own Adam optimizer for every learning-rate step;
# this optimizer is not passed to it.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Start training with the previously defined train_model function
model, train_losses, val_losses, val_accuracies, epoch_durations = train_model(
    model, train_loader, valid_loader, criterion,
    initial_lr=0.001, lr_steps=[0.001, 0.0001, 0.00001, 0.000001], num_epochs=50, patience=3
)


### CV-BCE Training

In [None]:
# Set the criterion for CV BCE
criterion = ComplexValuedBCELoss()

# Train the ResNet-18 model with CV BCE.
# Use the GPU when one is present and fall back to the CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ComplexResNet18().to(device)
# NOTE: train_model builds its own Adam optimizer for every learning-rate step;
# this optimizer is not passed to it.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Start training with the previously defined train_model function
model, train_losses, val_losses, val_accuracies, epoch_durations = train_model(
    model, train_loader, valid_loader, criterion,
    initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3
)
combined_epoch_time = sum(epoch_durations)
print(f"Total time spent in epochs: {combined_epoch_time:.2f} seconds.")

### Plot training result (Accuracy, loss vs epoch)

In [None]:
import matplotlib.pyplot as plt
import json
import os

# Directory that receives the metrics JSON and the training plot written by the
# two functions below; it is created up front (including parents) if missing.
output_dir = 'cvnn_results/segmentation'
os.makedirs(output_dir, exist_ok=True)

def save_metrics_to_json(train_losses, val_accuracies, epoch_durations, filename):
    """
    Write the training history to a JSON file inside ``output_dir``.

    Args:
        train_losses (list): List of training losses.
        val_accuracies (list): List of validation accuracies.
        epoch_durations (list): List of per-epoch durations.
        filename (str): File name of the JSON file, relative to ``output_dir``.
    """
    target_path = os.path.join(output_dir, filename)
    payload = dict(
        train_losses=train_losses,
        val_accuracies=val_accuracies,
        epoch_durations=epoch_durations,
    )
    with open(target_path, 'w') as f:
        json.dump(payload, f)

def plot_training_metrics(train_losses, val_accuracies, plot_filename):
    """
    Draw training loss and validation accuracy side by side and save the figure.

    A dashed vertical line marks the first epoch whose validation accuracy is
    at least 99 %, if there is one. The figure is written as SVG into
    ``output_dir`` and then shown.

    Args:
        train_losses (list): List of training losses.
        val_accuracies (list): List of validation accuracies.
        plot_filename (str): The file name for saving the plot as SVG.
    """
    epochs = range(1, len(train_losses) + 1)

    plt.figure(figsize=(14, 6))

    # (subplot index, values, curve label, y-axis label, title)
    panels = (
        (1, train_losses, 'Training Loss', 'Loss', 'Training Loss'),
        (2, val_accuracies, 'Validation Accuracy', 'Accuracy (%)', 'Validation Accuracy'),
    )
    for position, values, curve_label, y_label, title in panels:
        plt.subplot(1, 2, position)
        plt.plot(epochs, values, label=curve_label)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()

    # The accuracy panel is still the active subplot here. Epochs are 1-based.
    first_99_epoch = next(
        (number for number, acc in enumerate(val_accuracies, start=1) if acc >= 99),
        None,
    )
    if first_99_epoch is not None:
        plt.axvline(first_99_epoch, color='r', linestyle='--', label=f'99% reached at epoch {first_99_epoch}')

    # Refresh the legend so it includes the marker line when it was drawn.
    plt.legend()
    plt.tight_layout()

    # Save the plot as an SVG file
    plt.savefig(os.path.join(output_dir, plot_filename), format='svg')
    plt.show()

# Persist the history of the most recent training run (whatever train_losses,
# val_accuracies and epoch_durations currently hold) to
# cvnn_results/segmentation/training_metrics.json
save_metrics_to_json(train_losses, val_accuracies, epoch_durations, 'training_metrics.json')

# Plot the same history, mark the first epoch that reaches 99 % validation
# accuracy, and save the figure as SVG in the same directory
plot_training_metrics(train_losses, val_accuracies, 'training_metrics_plot.svg')

### Evaluation 

In [None]:
# Load the pre-trained model for evaluation
import torch

# Evaluate on the GPU when one is present, otherwise on the CPU. map_location
# below makes a checkpoint saved on a GPU loadable on either.
device = "cuda" if torch.cuda.is_available() else "cpu"

model_path = "path/to/the/model" #Please change this to the model path you trained
model = ComplexResNet18().to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
# Switch to inference mode before evaluating.
model.eval()


In [None]:
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
import numpy as np

# IoU thresholds at which recall is reported
iou_thresholds = [0.5, 0.7, 0.9]

# Module-level accumulators (evaluate() further down keeps its own local counters)
snr_results = {}
total_accuracy = 0.0
total_samples = 0
iou_scores = dict.fromkeys(iou_thresholds, 0.0)
recall_counts = dict.fromkeys(iou_thresholds, 0)
BATCH_SIZE = 64
# One loader over train + validation + test, with the SNR of every sample attached
full_dataset = WidebandSignalDataset(signal_ids=train + validation + test, return_snr=True)
full_loader = DataLoader(dataset=full_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

### Bounding Box

In [None]:
import torch
from collections import defaultdict
import time
from tqdm import tqdm
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

def expand_true(array, distance=1):
    """Widen every run of true values by ``distance`` positions on each side.

    Implemented as a 1-D convolution with an all-ones kernel of width
    ``2 * distance + 1``: an output position is true when the window around
    it sums to more than zero.

    Args:
        array: 2-D tensor (one row per sample); it is cast to float and given
            a channel axis for ``conv1d``.
        distance: Number of positions to grow in each direction.

    Returns:
        Boolean tensor with the same shape as ``array``.
    """
    window = 2 * distance + 1
    ones_kernel = torch.ones((1, 1, window), device=array.device)

    with_channel = array.unsqueeze(1).float()
    window_sums = F.conv1d(with_channel, ones_kernel, padding=distance).squeeze(1)

    return window_sums > 0

# Supporting functions for box extraction and matching (adapted from a collaborator's code)
def get_true_groups(tensor, device):
    """Find the runs of consecutive true values in every row of a 2-D tensor.

    Each row is padded with one ``False`` on both sides so that runs touching
    a border are detected as well; run boundaries are the +1 / -1 steps of
    the padded row.

    Args:
        tensor: 2-D tensor, one row per sample.
        device: Device on which the padding values are created.

    Returns:
        One list per row of ``(start, end)`` index pairs. Both indices refer
        to the unpadded row and ``end`` is inclusive (the last true position).
    """
    assert tensor.dim() == 2, 'This function handles 2D tensor only'
    all_groups = []
    for row in tensor:
        edge = torch.tensor([False]).to(device)
        padded = torch.cat([edge, row, edge])
        steps = padded.float().diff()
        # Thanks to the leading pad, a +1 step at k means the run starts at k.
        first_bins = (steps == 1).nonzero(as_tuple=True)[0].tolist()
        # A -1 step at k means position k - 1 was the last true one.
        last_bins = ((steps == -1).nonzero(as_tuple=True)[0] - 1).tolist()
        all_groups.append(list(zip(first_bins, last_bins)))
    return all_groups

def get_target_boxes(metadata, number_of_bins, sample_rate=SAMPLE_RATE):
    """Convert metadata positions into bin ranges and a binary occupancy mask.

    Args:
        metadata: Iterable of dicts whose ``'position'`` entry is a pair
            ``(f, b)``; the band covers ``f - b / 2`` to ``f + b / 2``.
        number_of_bins: Length of the mask.
        sample_rate: Value that maps onto ``number_of_bins`` bins.

    Returns:
        ``(targets, masks)``: ``targets`` is a list of ``(x1, x2)`` bin ranges
        with ``x2`` exclusive, and ``masks`` is a float tensor of
        ``number_of_bins`` zeros with ones inside every range.

    Both bounds are clamped to ``[0, number_of_bins]``. Without the clamp a
    band reaching below bin 0 would produce a negative slice index, which is
    read as "count from the end" and marks the wrong bins.
    """
    scale_ratio = number_of_bins / sample_rate
    targets = []
    masks = torch.zeros(number_of_bins)
    for meta in metadata:
        f, b = meta['position']
        x1 = math.floor((f - b / 2) * scale_ratio)
        x2 = math.ceil((f + b / 2) * scale_ratio)
        x1 = min(max(x1, 0), number_of_bins)
        x2 = min(max(x2, 0), number_of_bins)
        masks[x1:x2] = 1
        targets.append((x1, x2))
    return targets, masks

def get_target_boxes_batch(batch_metadata, number_of_bins, sample_rate=SAMPLE_RATE):
    """Apply ``get_target_boxes`` to every sample of a batch.

    Returns:
        ``(all_targets, all_masks)``: two lists with one entry per sample, in
        the order of ``batch_metadata``.
    """
    per_sample = [
        get_target_boxes(metadata, number_of_bins, sample_rate)
        for metadata in batch_metadata
    ]
    all_targets = [targets for targets, _ in per_sample]
    all_masks = [masks for _, masks in per_sample]
    return all_targets, all_masks

def calculate_iou(box1, box2):
    """Intersection-over-union of two 1-D boxes given as ``(start, end)``.

    The boxes are treated as intervals of length ``end - start``, so a box
    with ``start == end`` has zero length. The denominator is the span from
    the smaller start to the larger end. Returns 0 when that span is 0.

    NOTE(review): this definition shares its name with the mask-based
    ``calculate_iou(pred, target, threshold)`` defined earlier in the file and
    replaces it once this cell has run; ``validate_model`` calls that earlier
    signature.
    """
    overlap = min(box1[1], box2[1]) - max(box1[0], box2[0])
    intersection = overlap if overlap > 0 else 0
    union = max(box1[1], box2[1]) - min(box1[0], box2[0])
    if union == 0:
        return 0
    return intersection / union

def match_targets(targets, preds):
    """Pair target boxes with predicted boxes so that the total IoU is maximal.

    Builds the target-by-prediction IoU matrix and solves the assignment
    problem on it.

    Returns:
        The ``(target_indices, prediction_indices)`` pair returned by
        ``scipy.optimize.linear_sum_assignment``.
    """
    iou_matrix = [
        [calculate_iou(target, pred) for pred in preds]
        for target in targets
    ]
    return linear_sum_assignment(iou_matrix, maximize=True)

def match_targets_batch(batch_targets, batch_preds):
    """Run ``match_targets`` on every (targets, predictions) pair of a batch."""
    return [
        match_targets(targets, preds)
        for targets, preds in zip(batch_targets, batch_preds)
    ]

def calculate_matched_ious(target_boxes, prediction_boxes, matching):
    """Return one IoU per target box according to ``matching``.

    Args:
        target_boxes: List of target boxes.
        prediction_boxes: List of predicted boxes.
        matching: Pair ``(target_indices, prediction_indices)`` of equal length.

    Returns:
        List with one value per target box: the IoU with its matched
        prediction, or 0 for targets that have no match.
    """
    pred_index_of = dict(zip(*matching))
    return [
        calculate_iou(target_box, prediction_boxes[pred_index_of[target_index]])
        if target_index in pred_index_of
        else 0
        for target_index, target_box in enumerate(target_boxes)
    ]

def calculate_matched_iou_mean_batch(batch_target_boxes, batch_pred_boxes, batch_matching):
    """Run ``calculate_matched_ious`` on every sample of a batch.

    Despite the name no mean is taken: the result is a list holding, for
    every sample, the list of per-target IoUs.
    """
    return [
        calculate_matched_ious(target_boxes, pred_boxes, matching)
        for target_boxes, pred_boxes, matching in zip(
            batch_target_boxes, batch_pred_boxes, batch_matching
        )
    ]



In [None]:
from collections import defaultdict
from tqdm import tqdm
def model_predictor(signals):
    """Run the global ``model`` on a batch and return a widened boolean mask.

    The batch is folded with ``reshape_to_2d``, the real part of the output is
    thresholded at 0.5 and every detected run is widened by one bin on each
    side (``expand_true`` with its default distance).

    The forward pass runs under ``torch.no_grad()``: this is inference only,
    so no autograd graph needs to be built and kept alive for every batch.
    """
    signals = reshape_to_2d(signals)
    with torch.no_grad():
        outputs = model(signals)
    return expand_true(outputs.real > 0.5)  # Use real part for thresholding
def evaluate(predictor, data_loader, device="cuda"):
    """Evaluate ``predictor`` on ``data_loader`` and print overall and per-SNR metrics.

    For every sample the prediction is thresholded at 0.5 and compared with
    the mask bin by bin (accuracy). Runs of positive bins are then extracted
    from both with ``get_true_groups``, matched with ``match_targets``, and
    the matched IoUs feed the mean IoU and the recall at every threshold of
    the global ``iou_thresholds``.

    A sample that has target boxes but no predicted box now contributes an
    IoU of 0 for each of its targets. Such samples used to be skipped, which
    removed complete misses from the IoU and recall denominators and inflated
    both metrics. Samples without any target box still contribute to the
    accuracy only.

    Args:
        predictor: Callable mapping a batch of inputs to a batch of outputs.
        data_loader: Loader yielding ``(inputs, masks, snrs)`` batches.
        device: Device the inputs and masks are moved to.

    Returns:
        ``defaultdict`` mapping every SNR to its accumulated counters
        (``iou_sum``, ``iou_count``, ``recall_counts``, ``total_samples``,
        ``correct_pixels``, ``total_pixels``).
    """
    snr_metrics = defaultdict(lambda: {
        "iou_sum": 0.0,
        "iou_count": 0,
        "recall_counts": defaultdict(int),
        "total_samples": defaultdict(int),
        "correct_pixels": 0,
        "total_pixels": 0
    })
    total_iou_sum, total_iou_count = 0.0, 0
    total_correct_pixels, total_total_pixels = 0, 0
    total_recall_counts = defaultdict(int)
    total_samples = defaultdict(int)

    for inputs, masks, snrs_in_batch in tqdm(data_loader, desc="Evaluating"):
        inputs = reshape_to_2d(inputs).to(device)
        masks = masks.to(device)
        # Inference only: do not track gradients, whatever the predictor does.
        with torch.no_grad():
            outputs = predictor(inputs)

        for snr_value, mask, output in zip(snrs_in_batch, masks, outputs):
            snr = snr_value.item()
            per_snr = snr_metrics[snr]

            # Ensure output matches mask shape
            if output.numel() != mask.numel():
                output = output.expand_as(mask) if output.numel() == 1 else output.reshape_as(mask)

            thresholded_output = (output.real >= 0.5).float()

            correct_pixels = (thresholded_output == mask).sum().item()
            total_pixels = mask.numel()
            per_snr["correct_pixels"] += correct_pixels
            per_snr["total_pixels"] += total_pixels
            total_correct_pixels += correct_pixels
            total_total_pixels += total_pixels

            target_boxes = get_true_groups(mask.unsqueeze(0), device=device)[0]
            if not target_boxes:
                # Nothing to detect in this sample: no IoU / recall contribution.
                continue

            pred_boxes = get_true_groups(thresholded_output.unsqueeze(0), device=device)[0]
            if pred_boxes:
                matching = match_targets(target_boxes, pred_boxes)
                matched_ious = calculate_matched_ious(target_boxes, pred_boxes, matching)
            else:
                # Every target was missed.
                matched_ious = [0] * len(target_boxes)

            iou_total = sum(matched_ious)
            per_snr["iou_sum"] += iou_total
            per_snr["iou_count"] += len(matched_ious)
            total_iou_sum += iou_total
            total_iou_count += len(matched_ious)

            for th in iou_thresholds:
                true_positives = sum(1 for iou in matched_ious if iou >= th)
                per_snr["recall_counts"][th] += true_positives
                per_snr["total_samples"][th] += len(target_boxes)
                total_recall_counts[th] += true_positives
                total_samples[th] += len(target_boxes)

    # Calculate overall metrics
    overall_accuracy = (total_correct_pixels / total_total_pixels) * 100 if total_total_pixels > 0 else 0
    overall_iou = total_iou_sum / total_iou_count if total_iou_count > 0 else 0
    overall_recall = {th: total_recall_counts[th] / total_samples[th] if total_samples[th] > 0 else 0 for th in iou_thresholds}

    # Print overall results
    print(f"Overall Accuracy: {overall_accuracy:.2f}%")
    print(f"Overall IoU Score: {overall_iou:.4f}")
    for th in iou_thresholds:
        print(f"Recall at threshold {th}: {overall_recall[th]:.4f}")

    # Print per-SNR results
    for snr, metrics in sorted(snr_metrics.items()):
        snr_accuracy = (metrics["correct_pixels"] / metrics["total_pixels"]) * 100 if metrics["total_pixels"] > 0 else 0
        snr_iou = metrics["iou_sum"] / metrics["iou_count"] if metrics["iou_count"] > 0 else 0
        print(f"SNR: {snr} dB - Accuracy: {snr_accuracy:.2f}%")
        print(f"   IoU: {snr_iou:.4f}")
        for th in iou_thresholds:
            recall = metrics["recall_counts"][th] / metrics["total_samples"][th] if metrics["total_samples"][th] > 0 else 0
            print(f"   Recall at threshold {th}: {recall:.4f}")

    return snr_metrics


In [None]:
# Evaluate the loaded model on the full loader (train + validation + test) and
# keep the per-SNR counters for the plotting cell below.
snr_metrics = evaluate(model_predictor, full_loader, device=device)

### Plot and Save

In [None]:
import json
import matplotlib.pyplot as plt
from pathlib import Path

# Output folder for the evaluation JSON and all figures; created together with
# any missing parent folders.
save_path = Path("CMuSeNet_plots/Synthetic")
save_path.mkdir(parents=True, exist_ok=True)
# File that save_and_plot_results() writes the per-SNR metrics to.
json_file_path = save_path / "evaluation_results.json"

# Save metrics and plot results
def save_and_plot_results(snr_metrics, iou_thresholds):
    """Write the per-SNR metrics to JSON and plot each of them against the SNR.

    Args:
        snr_metrics: Mapping from SNR to the counters produced by ``evaluate``.
        iou_thresholds: IoU thresholds for which recall is reported.

    The JSON goes to ``json_file_path``. Every figure is saved as PNG and SVG
    under ``save_path`` and then shown: IoU, accuracy and one recall curve
    per threshold. A metric whose denominator is 0 is reported as 0.
    """
    snr_values = sorted(snr_metrics.keys())
    per_snr = [snr_metrics[snr] for snr in snr_values]

    iou_scores = [
        m["iou_sum"] / m["iou_count"] if m["iou_count"] > 0 else 0
        for m in per_snr
    ]
    accuracies = [
        (m["correct_pixels"] / m["total_pixels"]) * 100 if m["total_pixels"] > 0 else 0
        for m in per_snr
    ]
    recalls = {
        th: [
            (m["recall_counts"][th] / m["total_samples"][th]) if m["total_samples"][th] > 0 else 0
            for m in per_snr
        ]
        for th in iou_thresholds
    }

    # Save results to JSON
    results = {
        "SNR": snr_values,
        "IoU_Scores": iou_scores,
        "Accuracy": accuracies,
        "Recall": {str(th): recalls[th] for th in iou_thresholds}
    }
    with open(json_file_path, "w") as f:
        json.dump(results, f, indent=4)
    print(f"Results saved to {json_file_path}")

    def plot_curve(values, curve_label, y_label, title, file_stem):
        """Draw one metric against the SNR and save it as PNG and SVG."""
        plt.figure()
        plt.plot(snr_values, values, marker='o', label=curve_label)
        plt.xlabel("SNR (dB)")
        plt.ylabel(y_label)
        plt.title(title)
        plt.grid(True)
        plt.legend()
        plt.savefig(save_path / f"{file_stem}.png")
        plt.savefig(save_path / f"{file_stem}.svg")
        plt.show()

    plot_curve(iou_scores, "IoU Score", "IoU Score", "IoU Score vs. SNR", "IoU_vs_SNR")
    plot_curve(accuracies, "Accuracy", "Accuracy (%)", "Accuracy vs. SNR (Threshold 0.5)", "Accuracy_vs_SNR")
    for th in iou_thresholds:
        plot_curve(recalls[th], f"Recall at {th}", "Recall", f"Recall vs. SNR (Threshold {th})", f"Recall_vs_SNR_{th}")

# Must run after evaluate(): it consumes the snr_metrics returned there and the
# module-level iou_thresholds list, writes the JSON summary and draws the plots.
save_and_plot_results(snr_metrics, iou_thresholds)