# Jamming Detection and Classification with ResNet-18 and Open Set Recognition

This notebook implements a CNN ResNet-18 model for jamming signal detection and classification with Open Set Recognition (OSR) capability to detect unknown jamming types.

## Dataset Classes:
- **CLEAN**: Clean signal (no jamming)
- **LN**: Linear Noise jamming
- **LWF**: Linear Waveform jamming
- **TICK**: Tick jamming
- **TRI**: Triangle jamming
- **TRIW**: Triangle Wave jamming

The OSR approach will allow the model to identify unknown jamming patterns not seen during training.

## 1. Import Libraries

In [28]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')

Using device: cuda
GPU: Tesla T4


In [29]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## 2. Dataset Loading and Preprocessing

In [32]:
class JammingDataset(Dataset):
    """Custom Dataset for loading jamming signal data from .npy files"""

    def __init__(self, data_dir, transform=None):
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.samples = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        # Get all class directories
        class_dirs = sorted([d for d in self.data_dir.iterdir() if d.is_dir()])

        # Create class mapping
        for idx, class_dir in enumerate(class_dirs):
            self.class_names.append(class_dir.name)
            self.class_to_idx[class_dir.name] = idx

            if class_dir.name == 'CLEAN':
              # Load all .npy files from this class
              count = 0
              for npy_file in sorted(class_dir.glob('*.npy')):
                  self.samples.append(npy_file)
                  self.labels.append(idx)
                  count += 1

              # print(f'Loaded {count} samples from {class_dir.name}')

            else:
              # Iterate over the subfolders, one per jamming power level
              count = 0
              for power_dir in sorted(class_dir.glob('*')):
                  # Load all .npy files from this class
                  for npy_file in sorted(power_dir.glob('*.npy')):
                      self.samples.append(npy_file)
                      self.labels.append(idx)
                      count += 1

              # print(f'Loaded {count} samples from {class_dir.name} with power {class_dir.name}')

        print(f'Loaded {len(self.samples)} samples from {len(self.class_names)} classes')
        print(f'Classes: {self.class_names}')

        # Print class distribution
        unique, counts = np.unique(self.labels, return_counts=True)
        for cls_idx, count in zip(unique, counts):
            print(f'  {self.class_names[cls_idx]}: {count} samples')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Load the .npy file
        data = np.load(self.samples[idx])

        # Convert to tensor and add channel dimension (1, H, W)
        data = torch.FloatTensor(data).unsqueeze(0)

        # Apply normalization
        if self.transform:
            data = self.transform(data)

        label = self.labels[idx]

        return data, label

# Define data directory
data_dir = './drive/MyDrive/N-MON/Processed_Dataset'

# Load the full dataset
full_dataset = JammingDataset(data_dir)

Loaded 46848 samples from 6 classes
Classes: ['CLEAN', 'LN', 'LWF', 'TICK', 'TRI', 'TRIW']
  CLEAN: 2880 samples
  LN: 8800 samples
  LWF: 8768 samples
  TICK: 8800 samples
  TRI: 8800 samples
  TRIW: 8800 samples


In [33]:
# Split dataset: 70% train, 15% validation, 15% test
train_size = int(0.7 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

print(f'\nDataset split:')
print(f'  Training: {len(train_dataset)} samples')
print(f'  Validation: {len(val_dataset)} samples')
print(f'  Testing: {len(test_dataset)} samples')

# Create data loaders
batch_size = 32

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

print(f'\nBatch size: {batch_size}')
print(f'Number of batches: Train={len(train_loader)}, Val={len(val_loader)}, Test={len(test_loader)}')


Dataset split:
  Training: 32793 samples
  Validation: 7027 samples
  Testing: 7028 samples

Batch size: 32
Number of batches: Train=1025, Val=220, Test=220


## 3. ResNet-18 Architecture Implementation

In [34]:
class BasicBlock(nn.Module):
    """Basic residual block for ResNet-18"""
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet18(nn.Module):
    """ResNet-18 architecture adapted for jamming signal classification"""

    def __init__(self, num_classes=6, input_channels=1):
        super(ResNet18, self).__init__()
        self.in_channels = 64

        # Initial convolution layer
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual layers
        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)

        # Global average pooling and fully connected layer
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

        # Initialize weights
        self._initialize_weights()

    def _make_layer(self, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

        layers = []
        layers.append(BasicBlock(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels

        for _ in range(1, blocks):
            layers.append(BasicBlock(out_channels, out_channels))

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, return_features=False):
        # Initial layers
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Residual blocks
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Pooling and flatten
        x = self.avgpool(x)
        features = torch.flatten(x, 1)

        # Classification
        logits = self.fc(features)

        if return_features:
            return logits, features
        return logits

# Create model instance
num_classes = len(full_dataset.class_names)
model = ResNet18(num_classes=num_classes, input_channels=1).to(device)

# Print model summary
print(f'Model: ResNet-18')
print(f'Number of classes: {num_classes}')
print(f'Total parameters: {sum(p.numel() for p in model.parameters()):,}')
print(f'Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}')

Model: ResNet-18
Number of classes: 6
Total parameters: 11,173,318
Trainable parameters: 11,173,318


## 4. Open Set Recognition Implementation

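For Open Set Recognition, we use a simplified approach inspired by **OpenMax**. This method:
1. Computes a Mean Activation Vector (MAV) for each class from the penultimate-layer features of the training samples
2. Fits a Weibull distribution to the largest feature-to-MAV distances of each class and derives a per-class distance threshold
3. During inference, flags a sample as unknown when its distance to the predicted class MAV exceeds that threshold

Unlike full OpenMax, the softmax scores are not recalibrated. Instead, a simple threshold on the maximum softmax probability is applied as a second criterion.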
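
The distance-based part of this idea can be illustrated on synthetic data. The following is a minimal sketch, not the notebook's implementation: the 2-D features, the `toy_*` names, and `flag_unknown` are hypothetical, and an empirical 95th percentile stands in for the Weibull tail fit.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic 2-D "features" for two known classes (illustrative values only).
toy_features = np.vstack([
    rng.normal(loc=0.0, scale=0.5, size=(200, 2)),  # class 0
    rng.normal(loc=5.0, scale=0.5, size=(200, 2)),  # class 1
])
toy_labels = np.array([0] * 200 + [1] * 200)

# Mean activation vector (MAV) and distance threshold per class.
toy_mavs, toy_thresholds = {}, {}
for c in (0, 1):
    class_feats = toy_features[toy_labels == c]
    toy_mavs[c] = class_feats.mean(axis=0)
    dists = np.linalg.norm(class_feats - toy_mavs[c], axis=1)
    # Assumption: empirical 95th percentile instead of a Weibull tail fit.
    toy_thresholds[c] = np.quantile(dists, 0.95)

def flag_unknown(feature, predicted_class):
    """Flag a sample whose feature lies too far from its predicted class MAV."""
    dist = np.linalg.norm(feature - toy_mavs[predicted_class])
    return bool(dist > toy_thresholds[predicted_class])

near = flag_unknown(np.array([0.1, -0.1]), 0)  # close to the class-0 MAV
far = flag_unknown(np.array([2.5, 2.5]), 0)    # between the two clusters
print(near, far)
```

A point between the two clusters is assigned to some class by softmax, but its distance to that class MAV is large, which is what the threshold is meant to catch.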

In [35]:
from scipy.spatial.distance import cdist
from scipy.stats import weibull_min

class OpenSetRecognizer:
    """
    Open Set Recognition using Mean Activation Vectors (MAV) and distance thresholding.
    This is a simplified approach that works well for detecting unknown classes.
    """

    def __init__(self, num_classes, tail_size=20):
        self.num_classes = num_classes
        self.tail_size = tail_size  # Number of largest distances used for the Weibull tail fit
        self.mavs = None  # Mean Activation Vectors for each class
        self.dists = None  # Distance distributions for each class
        self.thresholds = None  # Thresholds for unknown detection

    def compute_mav(self, model, dataloader, device):
        """Compute Mean Activation Vectors for each class"""
        model.eval()

        # Store features for each class
        class_features = {i: [] for i in range(self.num_classes)}

        with torch.no_grad():
            for inputs, labels in tqdm(dataloader, desc='Computing MAVs'):
                inputs = inputs.to(device)
                _, features = model(inputs, return_features=True)

                # Group features by class
                for feat, label in zip(features.cpu().numpy(), labels.numpy()):
                    class_features[label].append(feat)

        # Compute mean activation vector for each class
        self.mavs = {}
        self.dists = {}

        for class_id in range(self.num_classes):
            features = np.array(class_features[class_id])
            self.mavs[class_id] = np.mean(features, axis=0)

            # Compute distances from MAV for threshold calculation
            distances = np.linalg.norm(features - self.mavs[class_id], axis=1)
            self.dists[class_id] = np.sort(distances)

        print('MAVs computed for all classes')

    def fit_weibull(self, alpha=0.95):
        """Fit Weibull distribution and compute thresholds"""
        self.thresholds = {}
        self.weibull_params = {}

        for class_id in range(self.num_classes):
            # Use tail distances for fitting
            tail_dists = self.dists[class_id][-self.tail_size:]

            # Fit Weibull distribution
            shape, loc, scale = weibull_min.fit(tail_dists, floc=0)
            self.weibull_params[class_id] = (shape, loc, scale)

            # Compute threshold at alpha percentile
            self.thresholds[class_id] = weibull_min.ppf(alpha, shape, loc, scale)

        print(f'Weibull distributions fitted with alpha={alpha}')

    def predict(self, model, inputs, device, unknown_threshold=0.5):
        """
        Predict with unknown detection
        Returns: predictions, is_unknown, max_scores
        """
        model.eval()

        with torch.no_grad():
            logits, features = model(inputs.to(device), return_features=True)
            probabilities = F.softmax(logits, dim=1)
            max_probs, predictions = torch.max(probabilities, dim=1)

            # Compute distances to MAVs
            features_np = features.cpu().numpy()
            is_unknown = np.zeros(len(features_np), dtype=bool)

            for i, (feat, pred) in enumerate(zip(features_np, predictions.cpu().numpy())):
                # Distance to predicted class MAV
                dist = np.linalg.norm(feat - self.mavs[pred])

                # Check if distance exceeds threshold or probability is too low
                if dist > self.thresholds[pred] or max_probs[i].item() < unknown_threshold:
                    is_unknown[i] = True

        return predictions.cpu(), is_unknown, max_probs.cpu()

# Initialize Open Set Recognizer
osr = OpenSetRecognizer(num_classes=num_classes, tail_size=20)
print('Open Set Recognizer initialized')

Open Set Recognizer initialized


## 5. Training Configuration and Loss Function

In [36]:
# Training hyperparameters
num_epochs = 50
learning_rate = 0.001
weight_decay = 1e-4

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Learning rate scheduler
# Note: `verbose=True` is omitted because the installed PyTorch version rejects it
# (TypeError: ReduceLROnPlateau.__init__() got an unexpected keyword argument 'verbose').
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                   patience=5)

print(f'Training Configuration:')
print(f'  Epochs: {num_epochs}')
print(f'  Learning rate: {learning_rate}')
print(f'  Weight decay: {weight_decay}')
print(f'  Optimizer: Adam')
print(f'  Loss function: CrossEntropyLoss')
print(f'  LR Scheduler: ReduceLROnPlateau')

## 6. Training Loop

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train for one epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(dataloader, desc='Training')
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Update progress bar
        pbar.set_postfix({'loss': loss.item(), 'acc': 100 * correct / total})

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total

    return epoch_loss, epoch_acc


def validate_epoch(model, dataloader, criterion, device):
    """Validate for one epoch"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc='Validation'):
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Statistics
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total

    return epoch_loss, epoch_acc

print('Training functions defined')

In [None]:
# Training history
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

best_val_acc = 0.0
best_model_path = 'best_resnet18_jamming.pth'

print('Starting training...\n')

for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    print('-' * 60)

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)

    # Validate
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Learning rate scheduling
    scheduler.step(val_loss)

    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'class_names': full_dataset.class_names
        }, best_model_path)
        print(f'✓ Best model saved (Val Acc: {val_acc:.2f}%)')

    print()

print(f'Training completed!')
print(f'Best validation accuracy: {best_val_acc:.2f}%')

## 7. Visualize Training History

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Plot loss
axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
axes[0].plot(history['val_loss'], label='Validation Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Plot accuracy
axes[1].plot(history['train_acc'], label='Train Accuracy', marker='o')
axes[1].plot(history['val_acc'], label='Validation Accuracy', marker='s')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Training and Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 8. Setup Open Set Recognition

Now we'll compute the Mean Activation Vectors (MAVs) and fit the Weibull distributions for OSR.
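
With the location fixed at zero (`floc=0`), the threshold returned by `weibull_min.ppf(alpha, shape, 0, scale)` has a closed form, obtained by inverting the Weibull CDF. The sketch below shows that relationship using only the standard library. The shape and scale values are hypothetical, not parameters fitted in this notebook.

```python
import math

def weibull_quantile(alpha, shape, scale):
    """Quantile of a Weibull distribution with location 0.

    Inverts the CDF F(x) = 1 - exp(-(x / scale) ** shape).
    """
    return scale * (-math.log(1.0 - alpha)) ** (1.0 / shape)

# Hypothetical fitted parameters for one class (not values from this notebook).
example_shape, example_scale = 2.0, 3.0
example_threshold = weibull_quantile(0.95, example_shape, example_scale)

# Plugging the threshold back into the CDF should recover alpha.
recovered_alpha = 1.0 - math.exp(-(example_threshold / example_scale) ** example_shape)
print(f"threshold={example_threshold:.4f}, recovered alpha={recovered_alpha:.4f}")
```

Raising `alpha` moves the threshold further into the tail, so fewer known samples are rejected by the distance criterion.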

In [None]:
# Load the best model
checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f'Loaded best model from epoch {checkpoint["epoch"]+1} with Val Acc: {checkpoint["val_acc"]:.2f}%')

# Compute MAVs using training data
print('\nComputing Mean Activation Vectors (MAVs)...')
osr.compute_mav(model, train_loader, device)

# Fit Weibull distributions
print('\nFitting Weibull distributions...')
osr.fit_weibull(alpha=0.95)

print('\nOpen Set Recognition setup completed!')

## 9. Evaluation on Test Set (Closed Set)

In [None]:
def evaluate_model(model, dataloader, device, class_names):
    """Evaluate model on test set"""
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc='Testing'):
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Compute metrics
    accuracy = accuracy_score(all_labels, all_preds)

    print(f'Test Accuracy: {accuracy * 100:.2f}%\n')
    print('Classification Report:')
    print(classification_report(all_labels, all_preds, target_names=class_names, digits=4))

    return all_preds, all_labels

# Evaluate on test set
test_preds, test_labels = evaluate_model(model, test_loader, device, full_dataset.class_names)

In [None]:
# Plot confusion matrix
cm = confusion_matrix(test_labels, test_preds)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=full_dataset.class_names,
            yticklabels=full_dataset.class_names)
plt.title('Confusion Matrix - Test Set')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.show()

# Print per-class accuracy
print('\nPer-class Accuracy:')
for i, class_name in enumerate(full_dataset.class_names):
    class_mask = test_labels == i
    if class_mask.sum() > 0:
        class_acc = (test_preds[class_mask] == i).sum() / class_mask.sum()
        print(f'  {class_name}: {class_acc * 100:.2f}%')

## 10. Open Set Recognition Evaluation

Now we'll test the OSR capability by evaluating the model with unknown detection.

In [None]:
def evaluate_osr(model, dataloader, osr_model, device, class_names, unknown_threshold=0.5):
    """Evaluate model with Open Set Recognition"""
    all_preds = []
    all_labels = []
    all_is_unknown = []
    all_confidences = []

    for inputs, labels in tqdm(dataloader, desc='OSR Evaluation'):
        preds, is_unknown, confidences = osr_model.predict(model, inputs, device, unknown_threshold)

        all_preds.extend(preds.numpy())
        all_labels.extend(labels.numpy())
        all_is_unknown.extend(is_unknown)
        all_confidences.extend(confidences.numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_is_unknown = np.array(all_is_unknown)
    all_confidences = np.array(all_confidences)

    # Create predictions with unknown class
    final_preds = all_preds.copy()
    final_preds[all_is_unknown] = -1  # -1 for unknown

    # Compute metrics
    known_mask = ~all_is_unknown
    unknown_count = all_is_unknown.sum()

    print(f'Total samples: {len(all_labels)}')
    print(f'Detected as known: {known_mask.sum()} ({known_mask.sum()/len(all_labels)*100:.2f}%)')
    print(f'Detected as unknown: {unknown_count} ({unknown_count/len(all_labels)*100:.2f}%)')

    if known_mask.sum() > 0:
        known_accuracy = (all_preds[known_mask] == all_labels[known_mask]).sum() / known_mask.sum()
        print(f'\nAccuracy on known samples: {known_accuracy * 100:.2f}%')

    return final_preds, all_labels, all_is_unknown, all_confidences

# Evaluate with OSR
print('Evaluating with Open Set Recognition...\n')
osr_preds, osr_labels, is_unknown, confidences = evaluate_osr(
    model, test_loader, osr, device, full_dataset.class_names, unknown_threshold=0.3
)

In [None]:
# Visualize confidence distribution
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Confidence distribution
axes[0].hist(confidences[~is_unknown], bins=50, alpha=0.7, label='Known', color='blue')
axes[0].hist(confidences[is_unknown], bins=50, alpha=0.7, label='Unknown', color='red')
axes[0].set_xlabel('Confidence Score')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Confidence Score Distribution')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Per-class unknown detection rate
unknown_rates = []
for i, class_name in enumerate(full_dataset.class_names):
    class_mask = osr_labels == i
    if class_mask.sum() > 0:
        unknown_rate = is_unknown[class_mask].sum() / class_mask.sum() * 100
        unknown_rates.append(unknown_rate)
    else:
        unknown_rates.append(0)

axes[1].bar(full_dataset.class_names, unknown_rates, color='coral')
axes[1].set_xlabel('Class')
axes[1].set_ylabel('Unknown Detection Rate (%)')
axes[1].set_title('Unknown Detection Rate per Class')
axes[1].tick_params(axis='x', rotation=45)
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print('\nUnknown Detection Rate per Class:')
for class_name, rate in zip(full_dataset.class_names, unknown_rates):
    print(f'  {class_name}: {rate:.2f}%')

## 11. Test with Simulated Unknown Classes

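To demonstrate the OSR bookkeeping, we simulate unknown classes by treating some known classes as unknown during evaluation. Because the model above was trained on all six classes, TRI and TRIW are not truly novel to it. A realistic open-set test would retrain the model with these classes held out.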
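
For reference, precision, recall, and F1 for the "unknown" decision can be computed from two boolean sequences. This is a small standalone sketch with toy data. The function name `unknown_detection_metrics` and the `toy_*` variables are hypothetical and not part of the notebook.

```python
def unknown_detection_metrics(flagged_unknown, truly_unknown):
    """Precision, recall and F1 for the 'unknown' decision.

    flagged_unknown: booleans, True where the detector flagged a sample.
    truly_unknown:   booleans, True where the sample belongs to a held-out class.
    """
    pairs = list(zip(flagged_unknown, truly_unknown))
    tp = sum(1 for f, t in pairs if f and t)        # unknown and flagged
    fp = sum(1 for f, t in pairs if f and not t)    # known but flagged
    fn = sum(1 for f, t in pairs if t and not f)    # unknown but missed

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1

# Toy example: 4 held-out samples followed by 4 known samples.
toy_truth = [True, True, True, True, False, False, False, False]
toy_flagged = [True, True, True, False, True, False, False, False]
toy_precision, toy_recall, toy_f1 = unknown_detection_metrics(toy_flagged, toy_truth)
print(toy_precision, toy_recall, toy_f1)
```

Precision divides by everything that was flagged, while recall divides by everything that is truly unknown, so the two denominators differ.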

In [None]:
# Simulate open set scenario: treat last 2 classes as unknown
# For example, treat TRIW and TRI as unknown classes
known_classes = [0, 1, 2, 3]  # CLEAN, LN, LWF, TICK
unknown_classes = [4, 5]  # TRI, TRIW

print('Simulating Open Set Scenario:')
print(f'Known classes: {[full_dataset.class_names[i] for i in known_classes]}')
print(f'Unknown classes (simulated): {[full_dataset.class_names[i] for i in unknown_classes]}\n')

# Separate test samples
known_samples_mask = np.isin(osr_labels, known_classes)
unknown_samples_mask = np.isin(osr_labels, unknown_classes)

print(f'Known samples in test set: {known_samples_mask.sum()}')
print(f'Unknown samples in test set: {unknown_samples_mask.sum()}\n')

# Evaluate on known samples
known_detected_as_known = (~is_unknown[known_samples_mask]).sum()
known_detected_as_unknown = is_unknown[known_samples_mask].sum()

print('Performance on Known Classes:')
print(f'  Correctly identified as known: {known_detected_as_known} ({known_detected_as_known/known_samples_mask.sum()*100:.2f}%)')
print(f'  Incorrectly identified as unknown: {known_detected_as_unknown} ({known_detected_as_unknown/known_samples_mask.sum()*100:.2f}%)')

# Evaluate on unknown samples
unknown_detected_as_known = (~is_unknown[unknown_samples_mask]).sum()
unknown_detected_as_unknown = is_unknown[unknown_samples_mask].sum()

print('\nPerformance on Unknown Classes (Simulated):')
print(f'  Correctly identified as unknown: {unknown_detected_as_unknown} ({unknown_detected_as_unknown/unknown_samples_mask.sum()*100:.2f}%)')
print(f'  Incorrectly identified as known: {unknown_detected_as_known} ({unknown_detected_as_known/unknown_samples_mask.sum()*100:.2f}%)')

# Overall OSR metrics
print('\nOverall Open Set Recognition Metrics:')
total_correct = known_detected_as_known + unknown_detected_as_unknown
total_samples = len(osr_labels)
osr_accuracy = total_correct / total_samples * 100
print(f'  OSR Accuracy: {osr_accuracy:.2f}%')

# Calculate precision and recall for unknown detection
# Precision = truly unknown samples among all samples flagged as unknown
unknown_precision = 0.0
unknown_recall = 0.0

total_flagged = unknown_detected_as_unknown + known_detected_as_unknown
if total_flagged > 0:
    unknown_precision = unknown_detected_as_unknown / total_flagged
    print(f'  Unknown Precision: {unknown_precision * 100:.2f}%')

if unknown_samples_mask.sum() > 0:
    unknown_recall = unknown_detected_as_unknown / unknown_samples_mask.sum()
    print(f'  Unknown Recall: {unknown_recall * 100:.2f}%')

if unknown_precision > 0 and unknown_recall > 0:
    f1_unknown = 2 * (unknown_precision * unknown_recall) / (unknown_precision + unknown_recall)
    print(f'  Unknown F1-Score: {f1_unknown * 100:.2f}%')

## 12. Inference Function for New Data

In [None]:
def predict_jamming(model, osr_model, data_path, device, class_names, unknown_threshold=0.3):
    """
    Predict jamming type for a new signal sample with unknown detection

    Args:
        model: Trained ResNet-18 model
        osr_model: Open Set Recognizer
        data_path: Path to .npy file containing signal data
        device: torch device
        class_names: List of class names
        unknown_threshold: Threshold for unknown detection

    Returns:
        prediction: Predicted class name or 'UNKNOWN'
        confidence: Confidence score
        is_unknown: Boolean indicating if sample is unknown
    """
    # Load and preprocess data
    data = np.load(data_path)
    data = torch.FloatTensor(data).unsqueeze(0).unsqueeze(0)  # Add batch and channel dims

    # Predict with OSR
    preds, is_unknown, confidences = osr_model.predict(model, data, device, unknown_threshold)

    pred_class = preds[0].item()
    confidence = confidences[0].item()
    unknown_flag = is_unknown[0]

    if unknown_flag:
        prediction = 'UNKNOWN'
    else:
        prediction = class_names[pred_class]

    return prediction, confidence, unknown_flag


# Example: Test prediction on a sample from test set
sample_idx = 0
sample_path = test_dataset.dataset.samples[test_dataset.indices[sample_idx]]
true_label = test_dataset.dataset.labels[test_dataset.indices[sample_idx]]

prediction, confidence, is_unk = predict_jamming(
    model, osr, sample_path, device, full_dataset.class_names, unknown_threshold=0.3
)

print('Single Sample Prediction Example:')
print(f'  Sample: {sample_path.name}')
print(f'  True label: {full_dataset.class_names[true_label]}')
print(f'  Prediction: {prediction}')
print(f'  Confidence: {confidence:.4f}')
print(f'  Detected as unknown: {is_unk}')

## 13. Save Complete Model with OSR

In [None]:
import pickle

# Save complete model with OSR parameters
model_package = {
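    # Move weights to CPU so the package can be unpickled on machines without a GPU
    'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()},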
    'class_names': full_dataset.class_names,
    'osr_mavs': osr.mavs,
    'osr_thresholds': osr.thresholds,
    'osr_weibull_params': osr.weibull_params,
    'num_classes': num_classes,
    'input_shape': (1, 128, 873),
    'history': history
}

# Save the complete package
model_package_path = 'resnet18_jamming_osr_complete.pkl'
with open(model_package_path, 'wb') as f:
    pickle.dump(model_package, f)

print(f'Complete model package saved to: {model_package_path}')
print('\nPackage includes:')
print('  - Model weights')
print('  - Class names')
print('  - OSR MAVs')
print('  - OSR thresholds')
print('  - Weibull parameters')
print('  - Training history')

## 14. Load and Use Saved Model (Example)

In [None]:
def load_model_for_inference(model_package_path, device):
    """
    Load the complete model package for inference

    Args:
        model_package_path: Path to the saved model package
        device: torch device

    Returns:
        model: Loaded ResNet-18 model
        osr: Loaded Open Set Recognizer
        class_names: List of class names
    """
    # Load package
    with open(model_package_path, 'rb') as f:
        package = pickle.load(f)

    # Reconstruct model
    model = ResNet18(num_classes=package['num_classes'], input_channels=1).to(device)
    model.load_state_dict(package['model_state_dict'])
    model.eval()

    # Reconstruct OSR
    osr = OpenSetRecognizer(num_classes=package['num_classes'])
    osr.mavs = package['osr_mavs']
    osr.thresholds = package['osr_thresholds']
    osr.weibull_params = package['osr_weibull_params']

    return model, osr, package['class_names']


# Example: Load and use the model
print('Example of loading saved model:\n')
loaded_model, loaded_osr, loaded_class_names = load_model_for_inference(model_package_path, device)

print('Model loaded successfully!')
print(f'Classes: {loaded_class_names}')
print(f'Ready for inference with Open Set Recognition')

## Summary

This notebook implemented a complete CNN ResNet-18 architecture with Open Set Recognition for jamming signal detection and classification.

### Key Features:
1. **ResNet-18 Architecture**: Deep CNN with residual connections adapted for signal data (128×873)
2. **6 Known Classes**: CLEAN, LN, LWF, TICK, TRI, TRIW
3. **Open Set Recognition**: Using Mean Activation Vectors and Weibull distribution fitting
4. **Unknown Detection**: Capable of identifying jamming patterns not seen during training
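
To see how the 128×873 input shrinks through the network, the spatial size can be traced with the standard convolution output formula. This is a sketch that assumes the input shape recorded in `input_shape` and the kernel, stride, and padding values of the `ResNet18` class above. The helper names `conv_out` and `resnet18_feature_map` are hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution or pooling layer along one axis."""
    return (size + 2 * padding - kernel) // stride + 1

def resnet18_feature_map(height, width):
    """Trace (H, W) through the size-changing layers of the ResNet18 class."""
    shape = (height, width)
    # (kernel, stride, padding) of each layer that can change the spatial size;
    # the remaining 3x3 convolutions use stride 1 and padding 1 and keep it.
    stages = [
        (7, 2, 3),  # conv1
        (3, 2, 1),  # maxpool
        (3, 1, 1),  # layer1, first 3x3 conv
        (3, 2, 1),  # layer2, first 3x3 conv
        (3, 2, 1),  # layer3, first 3x3 conv
        (3, 2, 1),  # layer4, first 3x3 conv
    ]
    for kernel, stride, padding in stages:
        shape = tuple(conv_out(s, kernel, stride, padding) for s in shape)
    return shape

final_map = resnet18_feature_map(128, 873)
print(final_map)  # → (4, 28)
```

The adaptive average pooling layer then reduces this 512-channel map to a single 512-dimensional feature vector, which is what the OSR step uses.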

### Model Capabilities:
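- **Closed-Set Classification**: Classifies the six known signal classes; test accuracy is reported in section 9
- **Open-Set Recognition**: Flags samples whose features lie far from the predicted class MAV, or whose confidence is low, as unknown
- **Confidence Scoring**: Returns the maximum softmax probability with each prediction
- **Single-Sample Inference**: `predict_jamming()` classifies one `.npy` signal sample at a time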

### Usage:
1. Train model on known classes (sections 1-6)
2. Setup OSR with MAVs and Weibull fitting (section 8)
3. Evaluate on test set (sections 9-11)
4. Use `predict_jamming()` for inference on new data
5. Load saved model with `load_model_for_inference()`

### Adjusting Unknown Threshold:
- Lower threshold (e.g., 0.2): Fewer samples detected as unknown, because only predictions whose maximum softmax probability falls below the threshold are flagged (higher specificity)
- Higher threshold (e.g., 0.5): More samples detected as unknown (higher sensitivity)
- Tune based on your application's requirements for false positive vs false negative rates
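
The effect of `unknown_threshold` on the confidence criterion can be checked with a few numbers. In `OpenSetRecognizer.predict`, a sample is flagged when its maximum softmax probability is below the threshold. The sketch below isolates that rule and ignores the MAV distance criterion. The probabilities and the helper `count_flagged` are hypothetical.

```python
def count_flagged(max_probs, unknown_threshold):
    """Count samples whose top softmax probability is below the threshold."""
    return sum(1 for p in max_probs if p < unknown_threshold)

# Hypothetical top-class probabilities for eight samples.
toy_max_probs = [0.15, 0.25, 0.35, 0.45, 0.60, 0.80, 0.95, 0.99]

flagged_by_threshold = {t: count_flagged(toy_max_probs, t) for t in (0.2, 0.3, 0.5)}
print(flagged_by_threshold)  # → {0.2: 1, 0.3: 2, 0.5: 4}
```

The distance thresholds are controlled separately through `alpha` in `fit_weibull`, so both settings need to be tuned together.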