In [1]:
from transformers import WhisperTokenizer, WhisperFeatureExtractor, WhisperProcessor, WhisperForConditionalGeneration, WhisperConfig
import torch

# Create a vocab json including modulation types and 0~31 modulation indices
vocab = {
    "<|eos|>": 0,
    "<|startoftranscript|>": 1,
    "<|unk|>": 2,
    "<|pad|>": 3,
    "<|cls|>": 4,
}
vocab_len = len(vocab)
added_tokens = ["<|BPSK|>", "<|QPSK|>", "<|8PSK|>", "<|MSK|>", "<|8QAM|>", "<|16QAM|>", "<|32QAM|>", "<|8APSK|>", "<|16APSK|>", "<|32APSK|>", "<|unknownmod|>"]
for symb_wid in torch.linspace(0,1,21):
    added_tokens.append(f"<|{symb_wid:.2f}|>")
for added_token in added_tokens:
    vocab[added_token] = vocab_len
    vocab_len += 1
vocab_len = len(vocab)
for i in range(32):
    ch = chr(i + ord('0'))
    vocab[ch] = vocab_len
    vocab_len += 1

# Write to vocab.json
import json
with open("vocab.json", "w") as f:
    json.dump(vocab, f)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = WhisperTokenizer(vocab_file="vocab.json", 
                             merges_file="merges.txt", 
                             predict_timestamps=True, 
                             additional_special_tokens=added_tokens, 
                             unk_token="<|unk|>", bos_token="<|startoftranscript|>", eos_token="<|eos|>", pad_token="<|pad|>", cls_token="<|cls|>")
tokenizer.encode("<|cls|><|startoftranscript|>13AE<|eos|>")

[1, 4, 1, 38, 40, 54, 58, 0, 0]

In [3]:
feature_extractor = WhisperFeatureExtractor(sampling_rate=16000, feature_size=2)

In [4]:
processor = WhisperProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
model_config = WhisperConfig(
    vocab_size=vocab_len,
    num_mel_bins=2,
    max_source_positions=1024,
    pad_token_id=vocab["<|pad|>"],
    bos_token_id=vocab["<|startoftranscript|>"],
    eos_token_id=vocab["<|eos|>"],
    decoder_start_token_id=vocab["<|startoftranscript|>"],
)
model = WhisperForConditionalGeneration(config=model_config)
model.save_pretrained("whisper-iq")

In [5]:
import numpy as np
inp = processor(np.array([1,3]), return_tensors="pt")

It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


In [6]:
inp.input_features.shape

torch.Size([1, 2, 3000])

In [7]:
model.train()
g = model.generate(torch.rand(1, 2, 300))
tokenizer.batch_decode(g)

['<|startoftranscript|><|startoftranscript|><|startoftranscript|><|cls|><|cls|><|16APSK|><|16APSK|><|16APSK|><|16APSK|><|16APSK|><|16APSK|><|16APSK|><|16APSK|><|16APSK|><|16APSK|><|16APSK|><|16APSK|><|16APSK|><|16APSK|><|16APSK|><|16APSK|>']

In [8]:
from torch.utils.data import Dataset
import glob
import os
import pandas as pd
from torch.nn.utils.rnn import pad_sequence as rnn_utils
from einops import rearrange

symb_type_char_dict = {
    1: "<|BPSK|>",
    2: "<|QPSK|>",
    3: "<|8PSK|>",
    4: "<|MSK|>",
    5: "<|8QAM|>",
    6: "<|16QAM|>",
    7: "<|32QAM|>",
    8: "<|8APSK|>",
    9: "<|16APSK|>",
    10: "<|32APSK|>",
    11: "<|unknownmod|>"
}

class SignalDataset(Dataset):
    def __init__(self, data_path):
        super(SignalDataset, self).__init__()
        # Recursively find all csv files in the data_path
        self.file_list = glob.glob(os.path.join(data_path, '**/*.csv'), recursive=True)
        self.cache = {}  # Dictionary for caching data

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        if index in self.cache:
            return self.cache[index]
        
        data = pd.read_csv(self.file_list[index], header=None, names=['I', 'Q', 'Code Sequence', 'Modulation Type', 'Symbol Width'])
        
        iq_wave = data[['I', 'Q']].values
        symb_seq = data['Code Sequence'].dropna().astype(int).values
        symb_type = data['Modulation Type'].values[0]
        symb_wid = data['Symbol Width'].values[0]

        iq_wave = torch.tensor(iq_wave, dtype=torch.float32)
        iq_wave = rearrange(iq_wave, 't c -> c t')
        # Pad the features to 2048
        iq_wave = torch.nn.functional.pad(iq_wave, (0, 2048 - iq_wave.shape[1]), mode='constant', value=0)

        symb_seq_chars = ''.join(map(chr, symb_seq + ord('0')))
        symb_type_char = symb_type_char_dict[symb_type]
        token_str = f'<|{symb_wid:1.2f}|>{symb_type_char}{symb_seq_chars}'
        target = tokenizer.encode(token_str, return_tensors="pt")
        # Cache processed data
        self.cache[index] = (iq_wave, target)

        return iq_wave, target
    
def _collator_fn(batch):
    input_features = rnn_utils([item[0] for item in batch], batch_first=True)
    labels = rnn_utils([item[1][0] for item in batch], batch_first=True, padding_value=-100)
    return {
        "input_features": input_features,
        "labels": labels,
    }

In [9]:
dataset = SignalDataset("train_data")
dataset[0]

(tensor([[ 0.0987,  0.0446,  0.0442,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0870,  0.1148, -0.0460,  ...,  0.0000,  0.0000,  0.0000]]),
 tensor([[ 1, 34, 13, 50, 47, 40, 38, 41, 50, 37, 43, 42, 43, 47, 40, 49, 42, 49,
          49, 40, 52, 44, 37, 40, 41, 37, 38, 49, 50, 47, 40, 42, 49, 48, 52, 39,
          42, 42, 48, 41, 41, 44, 42, 44, 44, 52, 37, 43, 48, 40, 41, 41, 41, 43,
          50, 48, 43, 50, 39, 50, 48, 49, 37, 41, 45, 43, 42, 44, 38, 48, 42, 41,
          51, 52, 46, 49, 43, 49, 38, 49, 43, 37, 40, 50, 43, 41, 44, 47, 40, 48,
          38, 38, 41, 52, 50, 41, 43, 41, 47, 40, 49, 49, 47, 46, 38,  0]]))

In [None]:
# Finetune Whisper on dataset
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-iq",
    run_name="whisper_finetune",
    learning_rate=1e-4,
    num_train_epochs=20,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    eval_strategy="no",
    save_strategy="steps",
    save_steps=100,
    report_to="none",
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=_collator_fn
)
trainer.train()

  0%|          | 0/3599860 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
  0%|          | 267/3599860 [01:04<212:30:32,  4.71it/s]

KeyboardInterrupt: 