In [1]:
import torch
import torch.nn as nn
from einops import rearrange, repeat, reduce
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from tqdm.notebook import tqdm
import random
import time

In [2]:
# 1d convolution neural network for signal classification
class SignalNet(nn.Module):
    def __init__(self, input_length, out_classes) -> None:
        super().__init__()
        self.input_length = input_length
        self.out_classes = out_classes
        self.layer = nn.Sequential(
            # 2 x 138
            nn.Conv1d(kernel_size=3, in_channels=2, out_channels=16),
            # 32 x 136
            nn.ReLU(),
            nn.BatchNorm1d(16),
            nn.AdaptiveAvgPool1d(34),
            # 32 x 34
            nn.Conv1d(kernel_size=3, in_channels=16, out_channels=16),
            # 32 x 32
            nn.ReLU(),
            nn.BatchNorm1d(16),
            nn.AdaptiveAvgPool1d(8),
            # 32 x 8
            nn.Flatten(),
            nn.Linear(16 * 8, 32),
            nn.ReLU(),
            nn.Linear(32, out_classes),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        return self.layer(x)
    
points = 138
sample_rate = 30_000

In [3]:
import serial
import pyvisa
global com
global device

def open_com(com_port):
    global com
    if globals().get("com") is not None:
        com.close()
    com = serial.Serial(
        port=com_port,
        baudrate=115200,
        bytesize=8,
        timeout=2,
        parity=serial.PARITY_NONE,
        stopbits=serial.STOPBITS_ONE,
    )

def open_device():
    global device
    rm = pyvisa.ResourceManager()
    devices = rm.list_resources()
    for d in devices:
        if d.startswith("USB0"):
            device = rm.open_resource(d)
            return
    raise Exception("No device found")

def popcount(x):
    return bin(x).count("1")

def collect_waveform(channels, points, sample_rate):
    com.write(b"\x01")
    # 1 byte for channel flag
    # 4 bytes for sample points
    # 4 bytes for sample rate
    com.write(int(channels).to_bytes(1, "little"))
    com.write(int(points).to_bytes(4, "little"))
    com.write(int(sample_rate).to_bytes(4, "little"))
    com.write(b"\xff\xff\xff")
    com.flush()
    head = com.read_until(b"\xff\xff\xff\xff")
    if head != b"\xff\xff\xff\xff":
        com.close()
        raise ValueError("Read timeout, invalid header")
    channel_count = popcount(channels)
    buf = com.read(channel_count * points * 2)
    arr = np.frombuffer(buf, dtype=np.int16)
    arr = rearrange(arr, "(n ch) -> ch n", ch=channel_count)
    return arr

def set_frequency(freq):
    com.write(b"\x02")
    com.write(int(freq).to_bytes(4, "little"))
    com.write(b"\xff\xff\xff")
    com.flush()

In [29]:
NONE = 0
FM = 1
AM = 2


def rand_not_in_range(center):
    while True:
        r = random.randint(30_000_000, 100_000_000)
        if abs(r - center) > 500_000:
            return r


def collect_fm_data(count_mod):
    global device
    # for each mod, generate 3 positive samples and 3 negative samples
    out_waveform = torch.zeros(count_mod * 6, 2, points)
    out_label = torch.zeros(count_mod * 6, dtype=torch.long)
    # (count_mod), range from 30MHz to 100MHz, uniform distribution
    baseband_freq = torch.randint(20_000_000, 100_000_000, (count_mod,))
    # (count_mod), range from 300Hz to 3kHz, uniform distribution
    mod_freq = torch.randint(300, 3000, (count_mod,))
    # (count_mod), range from 5kHz to 100kHz, uniform distribution
    mod_devi = torch.randint(5_000, 100_000, (count_mod,))
    idx = 0
    device.write("C2:MDWV FM")
    for i in tqdm(range(count_mod)):
        device.write(f"C2:BSWV FRQ,{baseband_freq[i]}")
        device.write(f"C2:MDWV FM,FRQ,{mod_freq[i]}")
        device.write(f"C2:MDWV FM,DEVI,{mod_devi[i]}")
        time.sleep(1)
        out_label[idx : idx + 6] = torch.LongTensor([FM, FM, FM, FM, NONE, NONE])
        scan_freq = [
            baseband_freq[i],
            baseband_freq[i] + random.randint(-100_000, +100_000),
            baseband_freq[i] + random.randint(-100_000, +100_000),
            baseband_freq[i] + 10_700_000 * 2,
            rand_not_in_range(baseband_freq[i]),
            rand_not_in_range(baseband_freq[i]),
        ]
        for j in range(6):
            set_frequency(scan_freq[j] - 10_700_000)
            time.sleep(0.01)
            tmp = torch.from_numpy(
                collect_waveform(0b11, points, sample_rate).astype(np.float32)
            )
            tmp = tmp * 2.5 / 32768
            out_waveform[idx + j, 0] = tmp[0]
            out_waveform[idx + j, 1] = tmp[1]
            out_waveform[idx + j, 0] -= torch.mean(out_waveform[idx + j, 0])
            out_waveform[idx + j, 1] -= torch.mean(out_waveform[idx + j, 1])
        idx += 6
    return out_waveform, out_label


def collect_am_data(count_mod):
    global device
    # for each mod, generate 3 positive samples and 3 negative samples
    out_waveform = torch.zeros(count_mod * 6, 2, points)
    out_label = torch.zeros(count_mod * 6, dtype=torch.long)
    # (count_mod), range from 30MHz to 100MHz, uniform distribution
    baseband_freq = torch.randint(20_000_000, 100_000_000, (count_mod,))
    # (count_mod), range from 200Hz to 10kHz, uniform distribution
    mod_freq = torch.randint(200, 10_000, (count_mod,))
    # (count_mod), range from 20 to 100, uniform distribution
    mod_depth = torch.randint(30, 100, (count_mod,))
    idx = 0
    device.write("C2:MDWV AM")
    for i in tqdm(range(count_mod)):
        device.write(f"C2:BSWV FRQ,{baseband_freq[i]}")
        device.write(f"C2:MDWV AM,FRQ,{mod_freq[i]}")
        device.write(f"C2:MDWV AM,DEPTH,{mod_depth[i]}")
        time.sleep(1)
        out_label[idx : idx + 6] = torch.LongTensor([AM, AM, AM, AM, NONE, NONE])
        scan_freq = [
            baseband_freq[i],
            baseband_freq[i] + random.randint(-100_000, +100_000),
            baseband_freq[i] + random.randint(-100_000, +100_000),
            baseband_freq[i] + 10_700_000 * 2,
            rand_not_in_range(baseband_freq[i]),
            rand_not_in_range(baseband_freq[i]),
        ]
        for j in range(6):
            set_frequency(scan_freq[j] - 10_700_000)
            time.sleep(0.01)
            tmp = torch.from_numpy(
                collect_waveform(0b0011, points, sample_rate).astype(np.float32)
            )
            tmp = tmp * 2.5 / 32768
            out_waveform[idx + j, 0] = tmp[0]
            out_waveform[idx + j, 1] = tmp[1]
            out_waveform[idx + j, 0] -= torch.mean(out_waveform[idx + j, 0])
            out_waveform[idx + j, 1] -= torch.mean(out_waveform[idx + j, 1])
        idx += 6
    return out_waveform, out_label

In [30]:
open_com("COM8")
open_device()
fm_waveform, fm_label = collect_fm_data(200)
am_waveform, am_label = collect_am_data(200)
waveform = torch.cat([fm_waveform, am_waveform], dim=0)
label = torch.cat([fm_label, am_label], dim=0)
torch.save(waveform, "waveform.pt")
torch.save(label, "label.pt")

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [3]:
def train():
    model = SignalNet(138, 3)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = torch.nn.CrossEntropyLoss()
    waveform = torch.load("waveform.pt")
    label = torch.load("label.pt")
    dataset = torch.utils.data.TensorDataset(waveform, label)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    # show plot of accuracy
    for epoch in tqdm(range(10)):
        correct_count = 0
        for i, (waveform, label) in enumerate(train_loader):
            pred = model(waveform)
            loss = loss_fn(pred, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pred = torch.argmax(pred, dim=1)
            correct_count += torch.sum(pred == label)
            # print(f"Epoch {epoch}, batch {i}, loss {loss.item()}")
        print(f"Epoch {epoch}, accuracy {correct_count / len(dataset)}")
        torch.save(model.state_dict(), f"model_{epoch}.pt")

In [7]:
train()

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0, accuracy 0.715833306312561
Epoch 1, accuracy 0.9208333492279053
Epoch 2, accuracy 0.9479166865348816
Epoch 3, accuracy 0.9312499761581421
Epoch 4, accuracy 0.9487500190734863
Epoch 5, accuracy 0.95333331823349
Epoch 6, accuracy 0.9558333158493042
Epoch 7, accuracy 0.9583333134651184
Epoch 8, accuracy 0.9541666507720947
Epoch 9, accuracy 0.95333331823349


In [8]:
model = SignalNet(138, 3)
model.load_state_dict(torch.load("model_9.pt"))
model.eval()
waveform = torch.load("waveform.pt")
label = torch.load("label.pt")
dataset = torch.utils.data.TensorDataset(waveform, label)
test_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
# show confusion matrix
confusion_matrix = torch.zeros(3, 3)
for i, (waveform, label) in enumerate(test_loader):
    pred = model(waveform)
    pred = torch.argmax(pred, dim=1)
    for j in range(len(label)):
        confusion_matrix[label[j], pred[j]] += 1
print(confusion_matrix)

tensor([[778.,  20.,   2.],
        [  3., 791.,   6.],
        [ 45.,  44., 711.]])


In [68]:
# take live data and predict
set_frequency(50_000_000)
data = collect_waveform(0b0011, points, sample_rate)
data = torch.from_numpy(data.astype(np.float32))
data = data * 2.5 / 32768
data[0] -= torch.mean(data[0])
data[1] -= torch.mean(data[1])
data = data.unsqueeze(0)
pred = model(data)
pred = torch.argmax(pred, dim=1)
if pred == 0:
    print("NONE")
elif pred == 1:
    print("FM")
elif pred == 2:
    print("AM")

AM
