In [1]:
!pip install -r https://raw.githubusercontent.com/google-ai-edge/ai-edge-torch/main/requirements.txt
!pip install ai-edge-torch-nightly

Looking in links: https://download.pytorch.org/whl/torch/, https://download.pytorch.org/whl/torchvision/, https://download.pytorch.org/whl/torchaudio/
Collecting ai-edge-torch-nightly
  Using cached ai_edge_torch_nightly-0.3.0.dev20240928-py3-none-any.whl.metadata (1.9 kB)
Using cached ai_edge_torch_nightly-0.3.0.dev20240928-py3-none-any.whl (291 kB)
Installing collected packages: ai-edge-torch-nightly
Successfully installed ai-edge-torch-nightly-0.3.0.dev20240928


In [2]:
# download file and save here: https://github.com/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/generative/quantize/quant_recipes.py

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
class M5(nn.Module):
    def __init__(self, n_input=1, n_output=35, stride=16, n_channel=32):
        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.pool1 = nn.MaxPool1d(4)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.pool2 = nn.MaxPool1d(4)
        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.pool3 = nn.MaxPool1d(4)
        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.pool4 = nn.MaxPool1d(4)
        self.fc1 = nn.Linear(2 * n_channel, n_output)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.pool3(x)
        x = self.conv4(x)
        x = F.relu(x)
        x = self.pool4(x)
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        return F.log_softmax(x, dim=2)

# load model from pt
model = torch.load('model.pt')


  model = torch.load('model.pt')


In [20]:
import ai_edge_torch
sample_inputs = (torch.rand((256, 1, 8000)),)
edge_model = ai_edge_torch.convert(model.eval(), sample_inputs)

In [21]:
import sys
sys.path.append('/content')
import quant_recipes
quant_config = quant_recipes.full_int8_dynamic_recipe()
edge_model_compressed = ai_edge_torch.convert(
    model.eval(), sample_inputs, quant_config=quant_config
)

In [22]:
edge_model_compressed.export('model.tflite')

In [5]:
from torchaudio.datasets import SPEECHCOMMANDS
import os


class SubsetSC(SPEECHCOMMANDS):
    def __init__(self, subset: str = None, include_labels=None, silence_ratio=0.1, unknown_ratio=0.1):
        super().__init__("./", download=True)

        def load_list(filename):
            filepath = os.path.join(self._path, filename)
            with open(filepath) as fileobj:
                return [os.path.normpath(os.path.join(self._path, line.strip())) for line in fileobj]

        if subset == "validation":
            self._walker = load_list("validation_list.txt")
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
        elif subset == "training":
            excludes = load_list("validation_list.txt") + load_list("testing_list.txt")
            excludes = set(excludes)
            self._walker = [w for w in self._walker if w not in excludes]

        # 如果指定了 include_labels, 过滤 _walker 中的词
        if include_labels is not None:
            self._walker = [w for w in self._walker if self._get_label(w) in include_labels]

    def _get_label(self, filepath):
        # 根据文件路径推断出标签。一般文件路径中包含了标签名
        return os.path.normpath(filepath).split(os.sep)[-2]


# 你只想要的几个词，例如 "yes", "no", "up", "down"
selected_words = ["yes", "no"]
# selected_words = ["down", "go", "left", "no", "off", "on", "right", "stop", "up", "yes"]

# 创建仅包含选定词的训练和测试集
train_set = SubsetSC("training", include_labels=selected_words)
test_set = SubsetSC("testing", include_labels=selected_words)

waveform, sample_rate, label, speaker_id, utterance_number = train_set[0]

In [6]:
labels = sorted(list(set(datapoint[2] for datapoint in train_set)))

def label_to_index(word):
    # Return the position of the word in labels
    return torch.tensor(labels.index(word))


def index_to_label(index):
    # Return the word corresponding to the index in labels
    # This is the inverse of label_to_index
    return labels[index]

def pad_sequence(batch):
    # Make all tensor in a batch the same length by padding with zeros
    batch = [item.t() for item in batch]
    batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.)
    return batch.permute(0, 2, 1)


def collate_fn(batch):

    # A data tuple has the form:
    # waveform, sample_rate, label, speaker_id, utterance_number

    tensors, targets = [], []

    # Gather in lists, and encode labels as indices
    for waveform, _, label, *_ in batch:
        tensors += [waveform]
        targets += [label_to_index(label)]

    # Group the list of tensors into a batched tensor
    tensors = pad_sequence(tensors)
    targets = torch.stack(targets)

    return tensors, targets


batch_size = 256

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device == "cuda":
    num_workers = 1
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

pbar_update = 1 / (len(train_loader) + len(test_loader))

def number_of_correct(pred, target):
    # count number of correct predictions
    return pred.squeeze().eq(target).sum().item()


def get_likely_index(tensor):
    # find most likely label index for each element in the batch
    # print(tensor)
    tensor = torch.tensor(tensor)
    return tensor.argmax(dim=-1)

new_sample_rate = 8000
transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=new_sample_rate)

In [14]:
from tqdm import tqdm
n_epoch = 30

with tqdm(total=n_epoch) as pbar:
  pass

def test(model, epoch):
    # model.eval()
    correct = 0
    for data, target in test_loader:

        data = data.to(device)
        target = target.to(device)

        # apply transform and model on whole batch directly on device
        data = transform(data)
        try:
          # Code that might raise an error
          output = model(data)
          pred = get_likely_index(output)
          correct += number_of_correct(pred, target)

          # update progress bar
          pbar.update(pbar_update)
        except Exception as e:
          # Code to handle the error or simply pass
          print(f"Error: {e}")
          continue  # This will skip the error and continue execution

    print(f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n")

  0%|          | 0/30 [00:00<?, ?it/s]


In [8]:
test(model, 1)

  tensor = torch.tensor(tensor)



Test Epoch: 1	Accuracy: 244/824 (30%)


Test Epoch: 1	Accuracy: 494/824 (60%)


Test Epoch: 1	Accuracy: 743/824 (90%)


Test Epoch: 1	Accuracy: 797/824 (97%)



In [15]:
import ai_edge_torch
test(edge_model_compressed, 1)

Error: Cannot set tensor: Tensor is unallocated. Try calling allocate_tensors() first

Test Epoch: 1	Accuracy: 744/824 (90%)



In [23]:
# !apt-get update  # <-- run this if you get install errors
!apt-get -qq install xxd
!echo "Exporting model. Model size (in bytes):"
!stat --printf="%s" model.tflite
!xxd -i model.tflite > model.txt # xxd is just used to create a hex dump from model file

Exporting model. Model size (in bytes):
32728