### Initialization

In [None]:
from pathlib import Path
import numpy as np
from scipy.signal import welch
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import math
import json

# Constants
START_INDEX = 10  # leading samples dropped from every loaded capture
SIGNAL_LENGTH = 16 * 1024  # samples kept per capture before the FFT
SAMPLE_RATE = 20e6  # used to map fc/bw metadata onto FFT bins (presumably Hz - confirm)
MASK_SIZE = 16 * 1024  # Mask size for segmentation

# Functions for Signal Processing
def load_real_data(sample_path):
    """
    Read a raw capture from disk as a 1-D array of complex64 samples.
    """
    with open(sample_path, "rb") as handle:
        return np.fromfile(handle, dtype=np.complex64)

def load_data(signal_id):
    """
    Read one capture together with its sidecar JSON metadata.

    The metadata is looked up next to the capture, under the same name with a
    ``.json`` suffix. Returns ``(samples, metadata, metadata_path)`` where the
    first ``START_INDEX`` samples have been dropped from ``samples``.

    Raises FileNotFoundError when the metadata file does not exist.
    """
    samples = load_real_data(signal_id)
    metadata_file = signal_id.with_suffix(".json")
    if not metadata_file.exists():
        raise FileNotFoundError(f"Metadata file {metadata_file} not found for signal {signal_id}")
    with open(metadata_file, "r") as f:
        metadata = json.load(f)
    return samples[START_INDEX:], metadata, metadata_file

def apply_psd(signal, Fs, NFFT):
    """
    Estimate the two-sided power spectral density with Welch's method.

    Both outputs are passed through ``fftshift``. Note the return order:
    ``(psd, freqs)``.
    """
    frequencies, density = welch(signal, fs=Fs, nfft=NFFT, return_onesided=False)
    return np.fft.fftshift(density), np.fft.fftshift(frequencies)

def calculate_fft(signal):
    """
    Return the centred, peak-normalised spectrum of the first SIGNAL_LENGTH samples.

    The result is a single complex array: the FFT of ``signal[:SIGNAL_LENGTH]``
    passed through ``fftshift`` and divided by its largest magnitude. When the
    spectrum is identically zero it is returned unscaled, so an all-zero input
    yields zeros rather than NaNs.
    """
    spectrum = np.fft.fftshift(np.fft.fft(signal[:SIGNAL_LENGTH]))
    peak = np.max(np.abs(spectrum))
    # Dividing by a zero peak would fill the whole spectrum with NaN.
    if peak > 0:
        spectrum /= peak
    return spectrum

### Data Loading

In [None]:
# Dataset Class
class WidebandSignalDataset(Dataset):
    """
    In-memory dataset of normalised spectra and their occupancy masks.

    Every capture is loaded and preprocessed once in ``__init__``. Items are
    ``(spectrum, mask)`` or, with ``return_snrs=True``, ``(spectrum, mask, snr)``.
    """

    def __init__(self, signal_ids, mask_size=1024 * 16, return_snrs=False):
        """
        Initialize the dataset with signal IDs and the specified mask size.

        Args:
            signal_ids: paths of the ``.dat`` captures to load.
            mask_size: number of bins in each target mask.
            return_snrs: when True, every item also carries a per-capture SNR value.
        """
        self.mask_size = mask_size
        self.return_snrs = return_snrs
        self.signal_ids = signal_ids
        self.loaded_data = [self.process_signal(signal_id) for signal_id in tqdm(self.signal_ids)]

    def __len__(self):
        return len(self.loaded_data)

    def __getitem__(self, index):
        return self.loaded_data[index]

    @staticmethod
    def _capture_snr(metadata):
        """Average the per-emitter "snr" entries of one capture (0.0 when none are present)."""
        # NOTE(review): the metadata schema is not visible here; a numeric per-emitter
        # "snr" field is assumed and 0.0 is used when it is absent - confirm.
        values = [entry["snr"] for entry in metadata if "snr" in entry]
        return float(sum(values) / len(values)) if values else 0.0

    def process_signal(self, signal_id):
        """
        Load one capture and build its model input and target mask.

        Returns a complex64 tensor of shape ``(1, SIGNAL_LENGTH)`` holding the
        centred, peak-normalised FFT, and a float32 mask of length
        ``self.mask_size`` set to 1 over every emitter's bin range.
        """
        signal, metadata, _ = load_data(signal_id)

        # Ensure signal length matches SIGNAL_LENGTH (zero-pad or truncate).
        if len(signal) < SIGNAL_LENGTH:
            signal = np.pad(signal, (0, SIGNAL_LENGTH - len(signal)), mode='constant')
        elif len(signal) > SIGNAL_LENGTH:
            signal = signal[:SIGNAL_LENGTH]

        # Centred FFT, normalised by its peak magnitude. A zero peak (all-zero
        # capture) is left unscaled to avoid a NaN-filled input.
        spectrum = np.fft.fftshift(np.fft.fft(signal))
        peak = np.max(np.abs(spectrum))
        if peak > 0:
            spectrum /= peak
        complex_signal = torch.from_numpy(spectrum).type(torch.complex64).unsqueeze(0)  # Add channel dimension

        # process_metadata already returns positions as FFT-bin indices in
        # [0, SIGNAL_LENGTH], so the only remaining scaling is bins -> mask cells.
        masks = torch.zeros(self.mask_size, dtype=torch.float32)
        scale_ratio = self.mask_size / SIGNAL_LENGTH
        for meta in process_metadata(metadata):
            f1, f2 = meta["position"]
            # Clamp so emitters reaching past the band edge neither wrap around
            # (negative index) nor get silently dropped.
            x1 = min(max(int(math.floor(f1 * scale_ratio)), 0), self.mask_size)
            x2 = min(max(int(math.ceil(f2 * scale_ratio)), 0), self.mask_size)
            masks[x1:x2] = 1

        if self.return_snrs:
            return complex_signal, masks, self._capture_snr(metadata)
        return complex_signal, masks



def process_metadata(metadata):
    """
    Convert each emitter's centre frequency and bandwidth into FFT-bin extents.

    For every entry the ``"position"`` tuple holds the floor of the lower edge
    and the ceiling of the upper edge, both expressed as indices into a centred
    spectrum of SIGNAL_LENGTH bins. ``"snr"`` and ``"esn0"`` are placeholders.
    """
    scaled_metadata = []
    for entry in metadata:
        lower_bin = math.floor((SAMPLE_RATE / 2 + entry["fc"] - entry["bw"] / 2) * SIGNAL_LENGTH / SAMPLE_RATE)
        upper_bin = math.ceil((SAMPLE_RATE / 2 + entry["fc"] + entry["bw"] / 2) * SIGNAL_LENGTH / SAMPLE_RATE)
        scaled_metadata.append(
            {
                "position": (lower_bin, upper_bin),
                "snr": 1,  # Placeholder value
                "bw": entry["bw"],
                "num": len(metadata),
                "esn0": 1,  # Placeholder value
            }
        )
    return scaled_metadata

# Dataset Splitting and Initialization
NEW_DATA_DIR = Path("/data/bigred/ofh/0")


def get_real_signals(freq_directory):
    """
    Return every ``*.dat`` file below ``freq_directory`` as a sorted list.

    ``rglob`` yields files in an arbitrary, filesystem-dependent order; sorting
    makes the positional train/validation/test split reproducible between runs.
    """
    return sorted(freq_directory.rglob("*.dat"))

signal_dirs = get_real_signals(NEW_DATA_DIR)
total_signals = len(signal_dirs)

# 80 % train / 10 % validation / 10 % test, taken in listing order.
train_split = int(0.80 * total_signals)
validation_split = int(0.90 * total_signals)

train = signal_dirs[:train_split]
validation = signal_dirs[train_split:validation_split]
test = signal_dirs[validation_split:]

print(f"Train set size: {len(train)}")
print(f"Validation set size: {len(validation)}")
print(f"Test set size: {len(test)}")

In [None]:
# Data Loaders
BATCH_SIZE = 64

# Each dataset loads and preprocesses all of its captures when constructed.
train_dataset = WidebandSignalDataset(train)
validation_dataset = WidebandSignalDataset(validation)
test_dataset = WidebandSignalDataset(test)

In [None]:
# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

### CV-ResNet-18

In [None]:
import torch
import torch.nn as nn
import complexPyTorch.complexLayers as cplx
from typing import Optional, Callable, Type, Union, List
import torch.nn.functional as F
from torch import Tensor

def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:
    """Build a bias-free complex 3x3 convolution with padding 1."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return cplx.ComplexConv2d(in_planes, out_planes, **conv_options)

def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:
    """Build a bias-free complex 1x1 (pointwise) convolution."""
    conv_options = dict(kernel_size=1, stride=stride, bias=False)
    return cplx.ComplexConv2d(in_planes, out_planes, **conv_options)

class BasicBlock(nn.Module):
    """
    Two-convolution residual block built from complex-valued layers.

    Args:
        inplanes: input channel count.
        planes: output channel count of both 3x3 convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input so the shortcut matches
            the main branch before the addition.
        norm_layer: factory for the normalisation layers; defaults to
            ``cplx.ComplexBatchNorm2d`` when None.
    """

    expansion = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super(BasicBlock, self).__init__()
        # Honour the caller-supplied normalisation instead of silently ignoring it.
        if norm_layer is None:
            norm_layer = cplx.ComplexBatchNorm2d
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = cplx.ComplexReLU()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when the main branch changed resolution or width.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    """
    1x1 -> 3x3 -> 1x1 residual block built from complex-valued layers.

    The last convolution widens the output to ``planes * expansion`` channels.

    Args:
        inplanes: input channel count.
        planes: channel count of the first two convolutions.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input so the shortcut matches
            the main branch before the addition.
        norm_layer: factory for the normalisation layers; defaults to
            ``cplx.ComplexBatchNorm2d`` when None.
    """

    expansion = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super(Bottleneck, self).__init__()
        # Honour the caller-supplied normalisation instead of silently ignoring it.
        if norm_layer is None:
            norm_layer = cplx.ComplexBatchNorm2d
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = norm_layer(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = norm_layer(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = cplx.ComplexReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the shortcut when the main branch changed resolution or width.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class ComplexResNet(nn.Module):
    """
    ResNet-style network assembled from complex-valued layers.

    A 7x7 stride-2 stem is followed by four residual stages, global average
    pooling, a complex linear layer with ``num_classes`` outputs and a complex
    sigmoid. ``zero_init_residual`` is accepted but not used; ``groups`` and
    ``width_per_group`` are only stored on the instance.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = SIGNAL_LENGTH,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        self._norm_layer = cplx.ComplexBatchNorm2d if norm_layer is None else norm_layer

        self.inplanes = 64
        self.dilation = 1
        self.groups = groups
        self.base_width = width_per_group

        # Stem: single-channel complex input -> 64 feature maps.
        self.conv1 = cplx.ComplexConv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = cplx.ComplexReLU()

        # Residual stages; every stage after the first halves the resolution.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = cplx.ComplexLinear(512 * block.expansion, num_classes)
        self.sigmoid = cplx.ComplexSigmoid()

    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1) -> nn.Sequential:
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``."""
        out_channels = planes * block.expansion

        # A projection shortcut is needed whenever the first block changes the
        # resolution or the channel count.
        shortcut = None
        if stride != 1 or self.inplanes != out_channels:
            shortcut = nn.Sequential(
                conv1x1(self.inplanes, out_channels, stride),
                self._norm_layer(out_channels),
            )

        stage = [block(self.inplanes, planes, stride, shortcut, self._norm_layer)]
        self.inplanes = out_channels
        stage.extend(
            block(self.inplanes, planes, norm_layer=self._norm_layer)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*stage)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.relu(self.bn1(self.conv1(x)))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = torch.flatten(self.avgpool(x), 1)
        return self.sigmoid(self.fc(x))

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

def ComplexResNet18():
    """Build the 18-layer variant: BasicBlock with two blocks in each of the four stages."""
    return ComplexResNet(block=BasicBlock, layers=[2, 2, 2, 2])

# Create the model instance
model = ComplexResNet18()
print(model)


### Early Stop

In [None]:
import os

class EarlyStopping:
    """
    Track the validation loss and keep a checkpoint of the best model seen.

    A call counts as an improvement when the loss dropped by at least ``delta``
    below the best loss so far. After ``patience`` consecutive calls without an
    improvement ``early_stop`` becomes True.

    Attributes:
        best_model: independent copy of the best model's ``state_dict`` (None
            until the first checkpoint).
        val_loss_min: validation loss of that checkpoint.
        counter: calls since the last improvement.
    """

    def __init__(self, patience=10, verbose=False, delta=0.0001, save_path='./path/to/model/save'):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.best_model = None
        self.save_path = save_path
        os.makedirs(save_path, exist_ok=True)

    def __call__(self, val_loss, model):
        score = -val_loss
        # NaN never compares equal to itself; a NaN loss must not be recorded
        # as the best score, otherwise every later comparison is False.
        is_nan = score != score
        improved = (not is_nan) and (
            self.best_score is None or score >= self.best_score + self.delta
        )

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Snapshot the model weights in memory and write them to ``best_model.pth``."""
        import copy

        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        self.val_loss_min = val_loss
        # state_dict() returns references to the live parameters; without a deep
        # copy, later optimizer steps would overwrite the stored "best" weights.
        self.best_model = copy.deepcopy(model.state_dict())
        save_path = os.path.join(self.save_path, 'best_model.pth')
        torch.save(self.best_model, save_path)

### Focal loss and reshape

In [None]:
class ComplexFocalLoss(nn.Module):
    """
    Focal loss applied separately to the real and imaginary parts of the input.

    Both parts are treated as probabilities and compared against the same
    targets. With ``reduction='mean'`` the two means are averaged, with
    ``'sum'`` the two sums are added, and any other value returns the
    element-wise sum of both terms.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def _focal_term(self, probabilities, targets):
        """Element-wise focal loss of one real-valued component."""
        bce = F.binary_cross_entropy(probabilities, targets, reduction='none')
        pt = torch.exp(-bce)
        return self.alpha * (1 - pt) ** self.gamma * bce

    def forward(self, inputs, targets):
        real_term = self._focal_term(inputs.real, targets)
        imag_term = self._focal_term(inputs.imag, targets)

        if self.reduction == 'mean':
            return (torch.mean(real_term) + torch.mean(imag_term)) / 2
        if self.reduction == 'sum':
            return torch.sum(real_term) + torch.sum(imag_term)
        return real_term + imag_term

# Update the IoU calculation to handle complex values
def calculate_iou(pred, target, threshold=0.5):
    """
    Mean intersection-over-union between thresholded complex predictions and masks.

    A bin counts as predicted when either the real or the imaginary part of
    ``pred`` exceeds ``threshold``. IoU is computed per row (``dim=1``) and
    averaged. A row whose prediction and target are both empty has no union;
    it is scored 1.0 (perfect agreement) instead of producing NaN.

    NOTE: a later cell of this notebook defines another ``calculate_iou(box1,
    box2)``; once that cell has run, this function is no longer reachable
    under this name.
    """
    combined_pred = ((pred.real > threshold) | (pred.imag > threshold)).float()

    intersection = (combined_pred * target).sum(dim=1)
    union = (combined_pred + target).sum(dim=1) - intersection

    # Replace empty unions before dividing so the mean cannot become NaN.
    has_union = union > 0
    safe_union = torch.where(has_union, union, torch.ones_like(union))
    per_row = torch.where(has_union, intersection / safe_union, torch.ones_like(union))
    return per_row.mean().item()
def reshape_to_2d(data):
    """View a batch of flat spectra as single-channel 128 x 128 images."""
    image_shape = (-1, 1, 128, 128)  # [batch, channels, height, width]
    return data.view(*image_shape)

### BCE Loss

In [None]:
# CV BCE Loss Function Definition
class ComplexValuedBCELoss(nn.Module):
    """
    Binary cross-entropy averaged over the real and imaginary parts of the input.

    Each part is compared against the same targets with the configured
    ``reduction``; the two results are then averaged with equal weight.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, inputs, targets):
        real_loss, imag_loss = (
            F.binary_cross_entropy(component, targets, reduction=self.reduction)
            for component in (inputs.real, inputs.imag)
        )
        # Equal weighting of both components (adjust here if necessary).
        return (real_loss + imag_loss) / 2

### Training from scratch

In [None]:
import time
# Fall back to the CPU when no CUDA device is available instead of failing on .to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"
def validate_model(model, valid_loader, criterion):
    """
    Evaluate the model on the validation loader.

    Prints the mean batch loss, the per-bin accuracy of the thresholded real
    part of the output, and the mean per-batch IoU.

    Returns:
        ``(val_loss, accuracy)`` - mean loss per batch and accuracy in percent.

    Raises:
        ValueError: if ``valid_loader`` yields no batches.
    """
    if len(valid_loader) == 0:
        raise ValueError("valid_loader is empty; cannot compute validation metrics")

    model.eval()
    running_loss = 0.0
    iou_scores = []
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, masks in tqdm(valid_loader, desc="Validating"):
            inputs = reshape_to_2d(inputs).to(device)
            masks = masks.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, masks)
            running_loss += loss.item()

            # Calculate IoU
            iou = calculate_iou(outputs, masks, threshold=0.5)
            iou_scores.append(iou)

            # Calculate accuracy from the real part only
            preds = (outputs.real > 0.5).float()
            correct = (preds == masks).float().sum()
            total_correct += correct.item()
            total_samples += masks.numel()

    val_loss = running_loss / len(valid_loader)
    mean_iou = sum(iou_scores) / len(iou_scores)
    accuracy = total_correct / total_samples * 100

    print(f'Validation Loss: {val_loss:.6f}')
    print(f'Validation Accuracy: {accuracy:.2f}%')
    # The IoU was previously computed and then discarded; report it.
    print(f'Validation Mean IoU: {mean_iou:.4f}')

    return val_loss, accuracy

def train_model(model, train_loader, valid_loader, criterion, initial_lr=0.001, lr_steps=(0.0001,), num_epochs=50, patience=5):
    """
    Train the model in one stage per learning rate, with early stopping per stage.

    Each entry of ``lr_steps`` gets a fresh Adam optimizer and a fresh
    ``EarlyStopping`` tracker and runs for at most ``num_epochs`` epochs. At
    the end of a stage the best weights recorded by the tracker are loaded
    back into the model.

    Args:
        initial_lr: not used; kept so existing callers keep working.
        lr_steps: learning rates to run, in order.
        num_epochs: maximum epochs per learning-rate stage.
        patience: early-stopping patience per stage.

    Returns:
        ``(model, train_losses, val_losses, val_accuracies, epoch_durations)``
        with one list entry per completed epoch across all stages.
    """
    train_losses = []
    val_losses = []
    val_accuracies = []
    epoch_durations = []

    for lr in lr_steps:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        early_stopping = EarlyStopping(patience=patience, verbose=True, delta=0.001)
        print("Current learning rate: ", lr)
        for epoch in range(num_epochs):
            epoch_start_time = time.time()

            model.train()
            running_loss = 0.0
            for inputs, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
                inputs = reshape_to_2d(inputs).to(device)
                masks = masks.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, masks)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            epoch_loss = running_loss / len(train_loader)
            train_losses.append(epoch_loss)
            print(f"Training Loss: {epoch_loss:.6f}")
            val_loss, val_accuracy = validate_model(model, valid_loader, criterion)
            val_losses.append(val_loss)
            val_accuracies.append(val_accuracy)
            early_stopping(val_loss, model)

            # Record the duration before the early-stop check so the epoch that
            # triggers the stop is counted in the timing as well.
            epoch_durations.append(time.time() - epoch_start_time)

            if early_stopping.early_stop:
                print("Early stopping triggered")
                break

        if early_stopping.best_model is not None:
            print(f"Loading best model from lr {lr}")
            model.load_state_dict(early_stopping.best_model)

    print("Training completed.")
    print("Epoch durations:", epoch_durations)
    return model, train_losses, val_losses, val_accuracies, epoch_durations

In [None]:
# Initialize and train the ResNet-18 model
model = ComplexResNet18().to(device)
criterion = ComplexFocalLoss()

# Two stages: lr 0.001 first, then 0.0001, each with its own early stopping.
model, train_losses, val_losses, val_accuracies, epoch_durations = train_model(
    model,
    train_loader,
    valid_loader,
    criterion,
    initial_lr=0.001,
    lr_steps=[0.001, 0.0001],
    num_epochs=50,
    patience=3,
)
combined_epoch_time = sum(epoch_durations)
print(f"Total time spent in epochs: {combined_epoch_time:.2f} seconds.")

### Transfer Learning: Load Pretrained Model

In [None]:
# Path to the pre-trained model weights
pretrained_model_path = "path/to/model/save.pth" #Change this model to trained model
# Fall back to the CPU when no CUDA device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize the model architecture
model = ComplexResNet18().to(device)

# Load the pre-trained weights; map_location lets a checkpoint saved on a GPU
# be restored on whichever device was selected above.
checkpoint = torch.load(pretrained_model_path, map_location=device)
# strict=False: keys that do not match the architecture are skipped silently.
model.load_state_dict(checkpoint, strict=False)

# Set all layers as trainable (if needed)
for param in model.parameters():
    param.requires_grad = True

In [None]:
# Define a new criterion and optimizer for fine-tuning
# You may select between Focal Loss or BCE as your criterion
#criterion = ComplexValuedBCELoss()  # alternative: plain BCE instead of focal loss
criterion = ComplexFocalLoss()
# Optimizer over the trainable parameters. It is not passed to train_model,
# which builds its own Adam optimizer for every entry of lr_steps.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=0.001,
)

# Train the model (fine-tuning)
model, train_losses, val_losses, val_accuracies, epoch_durations = train_model(
    model,
    train_loader,
    valid_loader,
    criterion,
    initial_lr=0.001,
    lr_steps=[0.001, 0.0001],
    num_epochs=50,
    patience=3,
)
combined_epoch_time = sum(epoch_durations)
print(f"Total time spent in epochs: {combined_epoch_time:.2f} seconds.")

### Plot results and save the figures and JSON

In [None]:
import os
import json
import matplotlib.pyplot as plt

# Define save directory (created if it doesn't exist)
save_dir = 'CMuSeNet_results/segmentation'
os.makedirs(save_dir, exist_ok=True)


def _plot_training_curve(values, label, color, title, ylabel, filenames):
    """Plot one per-epoch curve, save it under every name in ``filenames``, then show it."""
    plt.figure()
    plt.plot(range(1, len(values) + 1), values, label=label, color=color)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    for filename in filenames:
        plt.savefig(os.path.join(save_dir, filename))
    plt.show()


# Training loss, saved as PNG and SVG
_plot_training_curve(
    train_losses, 'Training Loss', 'blue', 'Training Loss', 'Loss',
    ('training_loss.png', 'training_loss.svg'),
)

# Validation accuracy, saved as PNG and SVG
_plot_training_curve(
    val_accuracies, 'Validation Accuracy', 'green', 'Validation Accuracy', 'Accuracy',
    ('validation_accuracy.png', 'validation_accuracy.svg'),
)

# Save the actual data to a JSON file
results = {
    "train_losses": train_losses,
    "val_accuracies": val_accuracies
}
with open(os.path.join(save_dir, 'training_validation_results.json'), 'w') as f:
    json.dump(results, f)


### BIG-RED Evaluation (Over entire dataset)

In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
# Create a DataLoader for the entire dataset
BATCH_SIZE = 64  # Adjust based on available memory
entire_dataset = WidebandSignalDataset(signal_dirs)  # Use all signals
entire_loader = DataLoader(dataset=entire_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Path to the pre-trained model weights
pretrained_model_path = "path/to/model/pretrained"
# Fall back to the CPU when no CUDA device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize the model architecture
model = ComplexResNet18().to(device)

# Load the pre-trained weights
checkpoint = torch.load(pretrained_model_path, map_location=device)
# strict=False: keys that do not match the architecture are skipped silently.
model.load_state_dict(checkpoint, strict=False)
model.eval()

# Function to evaluate accuracy
def evaluate_accuracy(model, data_loader):
    """
    Compute the per-bin accuracy of the model over ``data_loader``.

    A bin is predicted positive when the real part of the output exceeds 0.5.
    Prints and returns the accuracy in percent; an empty loader yields 0
    instead of a division by zero.
    """
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, masks in tqdm(data_loader, desc="Evaluating on Entire Dataset"):
            inputs = reshape_to_2d(inputs).to(device)
            masks = masks.to(device)

            outputs = model(inputs)
            preds = (outputs.real > 0.5).float()

            correct = (preds == masks).float().sum()
            total_correct += correct.item()
            total_samples += masks.numel()

    # Same empty-input convention as evaluate(): report 0 when nothing was scored.
    accuracy = total_correct / total_samples * 100 if total_samples > 0 else 0
    print(f"Overall Accuracy on Entire Dataset: {accuracy:.2f}%")
    return accuracy

# Run the evaluation
overall_accuracy = evaluate_accuracy(model, entire_loader)

### Function definitions

In [None]:
import torch
from tqdm import tqdm
import numpy as np
from collections import defaultdict
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from torch.utils.data import ConcatDataset

In [None]:
# Load the pre-trained model for evaluation
# Fall back to the CPU when no CUDA device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "path/to/model/save.pth"
# ComplexResNet18 is the architecture defined in this notebook; the previously
# referenced resnet18_1D is not defined anywhere in it.
model = ComplexResNet18().to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()


In [None]:
# Rebuild all three splits with SNR labels and chain them into one dataset.
full_dataset = ConcatDataset(
    [
        WidebandSignalDataset(signal_ids=split, return_snrs=True)
        for split in (train, validation, test)
    ]
)

In [None]:
full_loader = DataLoader(dataset=full_dataset, batch_size=64, shuffle=False)

In [None]:
def expand_true(array, distance=1):
    """
    Widen every run of True values by ``distance`` positions on both sides.

    ``array`` is a 2-D (batch, length) tensor; the result is a boolean tensor
    of the same shape.
    """
    window = distance * 2 + 1
    box_kernel = torch.ones((1, 1, window), device=array.device)
    # Summing over a sliding window: any True inside the window gives a count > 0.
    counts = F.conv1d(array.unsqueeze(1).float(), box_kernel, padding=distance)
    return counts.squeeze(1) > 0

def get_true_groups(tensor, device):
    """
    Find the runs of set values in every row of a 2-D tensor.

    Returns one list per row containing ``(start, end)`` index pairs, where
    ``end`` is the last index that still belongs to the run (inclusive).
    """
    assert tensor.dim() == 2, 'This function handles 2D tensor only'
    all_groups = []
    for row in tensor:
        # Pad with False on both sides so runs touching an edge still produce
        # a rising and a falling transition.
        edge = torch.tensor([False]).to(device)
        steps = torch.cat([edge, row, edge]).float().diff()
        starts = (steps == 1).nonzero(as_tuple=True)[0]
        ends = (steps == -1).nonzero(as_tuple=True)[0] - 1
        all_groups.append(list(zip(starts.tolist(), ends.tolist())))
    return all_groups

def calculate_iou(box1, box2):
    """
    Intersection-over-union of two 1-D index ranges with inclusive end points.

    Each box is ``(start, end)`` as produced by ``get_true_groups``, where
    ``end`` is the last index inside the range, so a range covers
    ``end - start + 1`` bins. Returns 0 when both ranges are empty.

    NOTE: this definition replaces the earlier ``calculate_iou(pred, target,
    threshold)`` defined in this notebook once this cell has run.
    """
    start1, end1 = box1[0], box1[1]
    start2, end2 = box2[0], box2[1]

    # +1 because the end index is part of the range; without it two identical
    # single-bin ranges would score 0.
    length1 = max(0, end1 - start1 + 1)
    length2 = max(0, end2 - start2 + 1)
    intersection = max(0, min(end1, end2) - max(start1, start2) + 1)
    union = length1 + length2 - intersection
    return intersection / union if union > 0 else 0

def match_targets(targets, preds):
    """
    Assign predicted boxes to target boxes so that the total IoU is maximal.

    Returns ``(row_ind, col_ind)`` from ``linear_sum_assignment``: target
    indices and the prediction index matched to each of them.
    """
    overlap = np.array([[calculate_iou(target, pred) for pred in preds] for target in targets])
    # linear_sum_assignment minimises cost, so negate the IoU to maximise it.
    row_ind, col_ind = linear_sum_assignment(-overlap)
    return row_ind, col_ind

def calculate_matched_ious(target_boxes, prediction_boxes, matching):
    """
    Return one IoU per target box, using the given target-to-prediction matching.

    ``matching`` is a pair of index sequences (targets, predictions). Targets
    without a usable match keep an IoU of 0.
    """
    assigned = dict(zip(*matching))
    prediction_count = len(prediction_boxes)

    ious = []
    for target_index, target_box in enumerate(target_boxes):
        pred_index = assigned.get(target_index)
        if pred_index is None or pred_index >= prediction_count:
            ious.append(0)
        else:
            ious.append(calculate_iou(target_box, prediction_boxes[pred_index]))
    return ious


In [None]:
def evaluate(predictor, data_loader, device="cuda"):
    """
    Score a predictor on per-bin accuracy, matched-box IoU and recall, overall and per SNR.

    Args:
        predictor: callable mapping a batch of inputs to per-bin predictions
            that are thresholded at 0.5.
        data_loader: yields ``(inputs, masks)`` or ``(inputs, masks, snrs)``;
            without SNRs every capture is filed under SNR 0.
        device: device the inputs and masks are moved to.

    Target runs are matched one-to-one to predicted runs by maximal IoU.
    Recall at a threshold is the share of target runs whose matched IoU
    reaches it. Captures without any target run do not contribute to IoU or
    recall; captures with targets but no prediction count every target as
    missed (IoU 0).

    Returns:
        The per-SNR accumulator dict (keys are SNRs rounded to integers).
    """
    iou_thresholds = [0.5, 0.7, 0.9]
    snr_metrics = defaultdict(lambda: {
        "iou_sum": 0.0,
        "iou_count": 0,
        "recall_counts": defaultdict(int),
        "total_samples": defaultdict(int),
        "correct_pixels": 0,
        "total_pixels": 0
    })
    total_iou_sum, total_iou_count = 0.0, 0
    total_correct_pixels, total_total_pixels = 0, 0
    total_recall_counts = defaultdict(int)
    total_samples = defaultdict(int)

    for batch in tqdm(data_loader, desc="Evaluating"):
        if len(batch) == 3:
            inputs, masks, snrs_in_batch = batch
        else:
            inputs, masks = batch
            snrs_in_batch = [0] * len(inputs)  # Default SNR if not provided

        inputs = inputs.to(device)
        masks = masks.to(device)
        # Evaluation needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            outputs = predictor(inputs)

        for i in range(len(inputs)):
            mask = masks[i]
            output = outputs[i]

            # Resize output to match mask shape if necessary
            if output.numel() != mask.numel():
                output = output.expand_as(mask) if output.numel() == 1 else output.reshape_as(mask)

            thresholded_output = (output >= 0.5).float()

            correct_pixels = (thresholded_output == mask).sum().item()
            total_pixels = mask.numel()
            total_correct_pixels += correct_pixels
            total_total_pixels += total_pixels

            # Get SNR value and round it to the nearest integer
            snr = snrs_in_batch[i]
            if isinstance(snr, torch.Tensor):
                snr = snr.item()
            snr = int(round(snr))

            snr_metrics[snr]["correct_pixels"] += correct_pixels
            snr_metrics[snr]["total_pixels"] += total_pixels

            target_boxes = get_true_groups(mask.unsqueeze(0), device=device)[0]
            pred_boxes = get_true_groups(thresholded_output.unsqueeze(0), device=device)[0]
            if not target_boxes:
                continue
            if pred_boxes:
                matching = match_targets(target_boxes, pred_boxes)
                matched_ious = calculate_matched_ious(target_boxes, pred_boxes, matching)
            else:
                # Nothing was predicted: every target is a miss. Skipping the
                # capture here would leave the misses out of the recall
                # denominator and overstate recall and IoU.
                matched_ious = [0] * len(target_boxes)

            snr_metrics[snr]["iou_sum"] += sum(matched_ious)
            snr_metrics[snr]["iou_count"] += len(matched_ious)
            total_iou_sum += sum(matched_ious)
            total_iou_count += len(matched_ious)

            for th in iou_thresholds:
                true_positives = sum(1 for iou in matched_ious if iou >= th)
                snr_metrics[snr]["recall_counts"][th] += true_positives
                snr_metrics[snr]["total_samples"][th] += len(target_boxes)
                total_recall_counts[th] += true_positives
                total_samples[th] += len(target_boxes)

    # Calculate overall metrics
    overall_accuracy = (total_correct_pixels / total_total_pixels) * 100 if total_total_pixels > 0 else 0
    overall_iou = total_iou_sum / total_iou_count if total_iou_count > 0 else 0
    overall_recall = {
        th: total_recall_counts[th] / total_samples[th] if total_samples[th] > 0 else 0
        for th in iou_thresholds
    }

    # Print overall results
    print(f"Overall Accuracy: {overall_accuracy:.2f}%")
    print(f"Overall IoU Score: {overall_iou:.4f}")
    for th in iou_thresholds:
        print(f"Recall at threshold {th}: {overall_recall[th]:.4f}")

    # Print per-SNR results
    for snr in sorted(snr_metrics.keys()):
        metrics = snr_metrics[snr]
        snr_accuracy = (metrics["correct_pixels"] / metrics["total_pixels"]) * 100 if metrics["total_pixels"] > 0 else 0
        snr_iou = metrics["iou_sum"] / metrics["iou_count"] if metrics["iou_count"] > 0 else 0
        print(f"SNR: {snr} dB - Accuracy: {snr_accuracy:.2f}%")
        print(f"   IoU: {snr_iou:.4f}")
        for th in iou_thresholds:
            recall = metrics["recall_counts"][th] / metrics["total_samples"][th] if metrics["total_samples"][th] > 0 else 0
            print(f"   Recall at threshold {th}: {recall:.4f}")

    return snr_metrics


def model_predictor(signals):
    """
    Run the loaded global ``model`` on a batch and return widened boolean predictions.

    The batch is reshaped with ``reshape_to_2d`` like at every other call site
    of the model in this notebook. For a complex-valued output the real part
    is thresholded at 0.5 (complex tensors cannot be compared with ``>``);
    the resulting mask is then widened with ``expand_true``.
    """
    with torch.no_grad():
        scores = model(reshape_to_2d(signals))
    if torch.is_complex(scores):
        scores = scores.real
    return expand_true(scores > 0.5)


In [None]:
# Run evaluation on the full dataset
snr_metrics = evaluate(predictor=model_predictor, data_loader=full_loader, device=device)

### Save and plot

In [None]:
import os
import json
import matplotlib.pyplot as plt

def save_results_and_plot(snr_metrics, save_path):
    """
    Write per-SNR evaluation results to JSON and plot Accuracy, IoU and Recall vs. SNR.

    Every plot uses an x-axis range of -9 dB to 12 dB, is saved as PNG and SVG,
    shown, and then closed.

    Args:
        snr_metrics (dict): The evaluation results obtained from the evaluate function.
        save_path (str): The directory path where results and plots will be saved.

    Outputs:
        - evaluation_results.json
        - accuracy_vs_snr.png and .svg
        - iou_vs_snr.png and .svg
        - recall_vs_snr.png and .svg
    """
    os.makedirs(save_path, exist_ok=True)

    def finish_figure(title, ylabel, png_name, svg_name):
        """Decorate the current figure, save it in both formats, show and close it."""
        plt.title(title)
        plt.xlabel('SNR (dB)')
        plt.ylabel(ylabel)
        plt.grid(True)
        plt.legend()
        plt.xlim(-9, 12)
        plt.savefig(os.path.join(save_path, png_name), format='png', bbox_inches='tight')
        plt.savefig(os.path.join(save_path, svg_name), format='svg', bbox_inches='tight')
        plt.show()
        plt.close()

    snr_list = sorted(snr_metrics.keys())
    accuracy_list = []
    iou_list = []
    recall_series = {0.5: [], 0.7: [], 0.9: []}
    json_data = {}

    # Derive the per-SNR figures once; they feed both the JSON file and the plots.
    for snr in snr_list:
        metrics = snr_metrics[snr]

        pixel_total = metrics["total_pixels"]
        snr_accuracy = (metrics["correct_pixels"] / pixel_total) * 100 if pixel_total > 0 else 0

        iou_count = metrics["iou_count"]
        snr_iou = metrics["iou_sum"] / iou_count if iou_count > 0 else 0

        recalls = {}
        for th in (0.5, 0.7, 0.9):
            target_total = metrics["total_samples"][th]
            recalls[th] = metrics["recall_counts"][th] / target_total if target_total > 0 else 0
            recall_series[th].append(recalls[th])

        accuracy_list.append(snr_accuracy)
        iou_list.append(snr_iou)

        json_data[snr] = {
            "accuracy": snr_accuracy,
            "iou": snr_iou,
            "recall": {
                "0.5": recalls[0.5],
                "0.7": recalls[0.7],
                "0.9": recalls[0.9],
            }
        }

    json_file_path = os.path.join(save_path, 'evaluation_results.json')
    with open(json_file_path, 'w') as json_file:
        json.dump(json_data, json_file, indent=4)

    # Accuracy vs. SNR
    plt.figure(figsize=(10, 6))
    plt.plot(snr_list, accuracy_list, marker='o', label='Accuracy')
    finish_figure('Accuracy vs. SNR', 'Accuracy (%)', 'accuracy_vs_snr.png', 'accuracy_vs_snr.svg')

    # IoU vs. SNR
    plt.figure(figsize=(10, 6))
    plt.plot(snr_list, iou_list, marker='o', color='orange', label='IoU')
    finish_figure('IoU vs. SNR', 'IoU', 'iou_vs_snr.png', 'iou_vs_snr.svg')

    # Recall at the three IoU thresholds vs. SNR
    plt.figure(figsize=(10, 6))
    plt.plot(snr_list, recall_series[0.5], marker='o', label='Recall @ IoU 0.5')
    plt.plot(snr_list, recall_series[0.7], marker='s', label='Recall @ IoU 0.7')
    plt.plot(snr_list, recall_series[0.9], marker='^', label='Recall @ IoU 0.9')
    finish_figure('Recall at Different IoU Thresholds vs. SNR', 'Recall', 'recall_vs_snr.png', 'recall_vs_snr.svg')


In [None]:
# snr_metrics is the output of the evaluate function
# Directory that receives the JSON file and the plots
save_path = 'CMuSeNet_BIGRED_results'

# Write the results and generate the figures
save_results_and_plot(snr_metrics, save_path=save_path)
