In [1]:
import numpy as np
import  matplotlib.pyplot as plt
import torch
import torchaudio
from tqdm.notebook import trange
import os
import time
import datetime

In [None]:
# Show the installed PyTorch version string as this cell's output.
torch.__version__

## Global parameters

In [3]:
# Run identifier; it is used as the output directory name when the model is
# saved. It is formatted without spaces or colons (which str(datetime.now())
# contains) so that it is a valid directory name on every platform.
# Microseconds are kept so that two runs started in the same second differ.
identifier = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
# Stored in the saved checkpoint metadata.
# NOTE(review): the contact split in the visible cells thresholds on the mean
# loudness, not on this value -- confirm whether it is still needed.
contact_threshold = 0.005
# Downsampling ratio: audio is resampled from frequency RESAMPLE_FACTOR to 1.
RESAMPLE_FACTOR = 24
# Number of (downsampled) samples in each chunk handed to the model.
PART_SIZE = 512

## Load dataset

In [4]:
def genFiles(idx):
    """Return the seven recording paths that belong to stone number ``idx``.

    The stone index is zero-padded to two digits and combined with each
    recording variant name, in a fixed order, to form a ``.wav`` path.
    """
    variants = (
        "long_light",
        "long_loud",
        "long_medium",
        "short_light",
        "short_loud",
        "short_medium",
        "wiggle",
    )
    paths = []
    for variant in variants:
        paths.append(f"./roughness_dataset/opus/stereo/{idx:02}_{variant}.wav")
    return paths

In [None]:
# Stone indices used for the rough class, and the subset held out for testing.
rough_stones = [0,11,13,15,17,19,1,8,20,22]
rough_stones_test = [20,22]

# Remove the held-out stones so that rough_stones only lists training stones.
for held_out in rough_stones_test:
    rough_stones.remove(held_out)
print("rough_training:",rough_stones)
print("rough_test:",rough_stones_test)

# Collect every recording path of every training stone, stone by stone.
TRAIN_ROUGH_FILES = []
for stone_idx in rough_stones:
    TRAIN_ROUGH_FILES += genFiles(stone_idx)

# torchaudio.load returns (waveform, sample_rate); keep the first two rows of
# each waveform and join all recordings end to end along dimension 1.
rough_wav = torch.cat(
    tuple(torchaudio.load(path)[0][0:2] for path in TRAIN_ROUGH_FILES),
    dim=1,
)
# Downsample from frequency RESAMPLE_FACTOR to frequency 1.
rough_wav = torchaudio.functional.resample(rough_wav, RESAMPLE_FACTOR, 1)
(rough_wav.shape, rough_wav.dtype)

In [None]:
# Stone indices used for the smooth class, and the subset held out for testing.
smooth_stones = [2,3,4,5,6,7,9,10,12,14,16,18,21]
smooth_stones_test = [21]

# Remove the held-out stones so that smooth_stones only lists training stones.
for held_out in smooth_stones_test:
    smooth_stones.remove(held_out)
print("smooth_training:",smooth_stones)
print("smooth_test:",smooth_stones_test)

# Collect every recording path of every training stone, stone by stone.
TRAIN_SMOOTH_FILES = []
for stone_idx in smooth_stones:
    TRAIN_SMOOTH_FILES += genFiles(stone_idx)

# torchaudio.load returns (waveform, sample_rate); keep the first two rows of
# each waveform and join all recordings end to end along dimension 1.
smooth_wav = torch.cat(
    tuple(torchaudio.load(path)[0][0:2] for path in TRAIN_SMOOTH_FILES),
    dim=1,
)
# Downsample from frequency RESAMPLE_FACTOR to frequency 1.
smooth_wav = torchaudio.functional.resample(smooth_wav, RESAMPLE_FACTOR, 1)
(smooth_wav.shape, smooth_wav.dtype)

## Cut into chunks

In [None]:
def cut_stepped(x, part_size=None, step=1):
    """Cut ``x`` into overlapping fixed-length windows along dimension 1.

    Window ``i`` of the result equals ``x[:, i*step : i*step + part_size]``.
    The windows are produced with ``Tensor.unfold`` instead of a Python loop
    over every start offset, which gives the same values far more quickly.

    Args:
        x: tensor with at least two dimensions; windows slide over dimension 1.
        part_size: window length. Defaults to the notebook-level ``PART_SIZE``.
        step: distance between the starts of consecutive windows. The default
            of 1 yields one window per possible start offset.

    Returns:
        A contiguous tensor whose dimension 0 indexes the windows and whose
        remaining dimensions are those of ``x`` with dimension 1 replaced by
        ``part_size``; for 2-D input the shape is
        ``(num_windows, x.shape[0], part_size)``.

    Raises:
        ValueError: if ``part_size`` or ``step`` is not positive, or if ``x``
            is shorter than one window along dimension 1.
    """
    if part_size is None:
        part_size = PART_SIZE
    if part_size <= 0 or step <= 0:
        raise ValueError(
            f"part_size and step must be positive, got part_size={part_size}, step={step}"
        )
    if x.shape[1] < part_size:
        raise ValueError(
            f"input has {x.shape[1]} samples along dim 1, "
            f"fewer than one window of {part_size}"
        )

    # unfold removes nothing: it turns dim 1 into the window index and appends
    # the within-window position as a new last dimension.
    windows = x.unfold(1, part_size, step)
    # Put the within-window position back where dim 1 was, then bring the
    # window index to the front.
    windows = windows.movedim(-1, 2).movedim(1, 0)
    # Materialise the result so it owns its memory, as a stacked copy would.
    return windows.contiguous()

# Slice both downsampled recordings into fixed-length windows.
rough_parts, smooth_parts = cut_stepped(rough_wav), cut_stepped(smooth_wav)

# Window counts per class and how many more rough windows there are than smooth.
rough_chunks = len(rough_parts)
smooth_chunks = len(smooth_parts)
imbalance_chunks = rough_chunks - smooth_chunks
print("rough_chunks:",rough_chunks,"smooth_chunks:",smooth_chunks,"imbalance_chunks:",imbalance_chunks)


## Identify contact

In [None]:
# Loudness of a window: mean absolute amplitude of index 0 along dimension 1.
rough_loudness = rough_parts[:,0].abs().mean(dim=1)
smooth_loudness = smooth_parts[:,0].abs().mean(dim=1)

# Each class is split at its own average loudness.
rough_cutoff = rough_loudness.mean()
smooth_cutoff = smooth_loudness.mean()

# Windows strictly louder than the cutoff count as "loud"; for the smooth
# recordings the remaining windows (at or below the cutoff) are kept as "silent".
rough_loud_parts = rough_parts[rough_loudness > rough_cutoff]
smooth_loud_parts = smooth_parts[smooth_loudness > smooth_cutoff]
smooth_silent_parts = smooth_parts[smooth_loudness <= smooth_cutoff]

rough_loudness_mean = rough_cutoff.item()
smooth_loudness_mean = smooth_cutoff.item()
print("rough_loudness:",rough_loudness_mean,"smooth_loudness:",smooth_loudness_mean)

rough_loud_parts_num = len(rough_loud_parts)
smooth_loud_parts_num = len(smooth_loud_parts)
smooth_silent_parts_num = len(smooth_silent_parts)
print("rough_loud_parts_num:",rough_loud_parts_num,"smooth_loud_parts_num:",smooth_loud_parts_num,"smooth_silent_parts_num:",smooth_silent_parts_num)

## Train model

In [9]:
   
class Model(torch.nn.Module):
    """Three-class classifier over the magnitude spectra of two audio channels.

    The network takes the real-FFT magnitudes of channel 0 and channel 1 of each
    chunk, concatenates them, and passes them through a stack of fully connected
    layers with ten residual 256-wide layers in the middle. The output is a
    log-probability vector over three classes.

    Args:
        part_size: number of samples per chunk the linear layers are sized for.
            Defaults to the notebook-level ``PART_SIZE``.
        resample_factor: original frequency passed to the built-in resampler
            (the target frequency is 1). Defaults to the notebook-level
            ``RESAMPLE_FACTOR``.
    """

    def __init__(self, part_size=None, resample_factor=None):
        super().__init__()
        if part_size is None:
            part_size = PART_SIZE
        if resample_factor is None:
            resample_factor = RESAMPLE_FACTOR

        # Kept on the module so forward() compares against the configured chunk
        # length rather than a hard-coded number. It is a plain int attribute,
        # so it does not appear in state_dict().
        self.part_size = part_size

        self.resampler = torchaudio.transforms.Resample(resample_factor, 1, dtype=torch.float32)

        # A real FFT of part_size samples has part_size//2 + 1 bins; the input
        # is the spectra of two channels side by side. (A single-channel
        # variant would use an input width of part_size//2 + 1 instead.)
        bins = part_size // 2 + 1
        self.layer1 = torch.nn.Linear(bins + bins, 1024)
        self.layer2 = torch.nn.Linear(1024, 512)
        self.layer3 = torch.nn.Linear(512, 256)

        self.layer_resnet = torch.nn.ModuleList([torch.nn.Linear(256, 256) for i in range(10)])

        self.layer4 = torch.nn.Linear(256, 128)
        self.layer5 = torch.nn.Linear(128, 64)
        self.linear = torch.nn.Linear(64, 3)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of chunks.

        Args:
            x: tensor indexed as ``x[:, channel]`` with samples on the last
                dimension and at least two channels. If the last dimension is
                not ``part_size`` long the input is first passed through the
                resampler -- this assumes the resampled length then equals
                ``part_size``.

        Returns:
            Tensor of log-softmax outputs with 3 entries on the last dimension.
        """
        # Anything that is not already chunk-sized is treated as not yet resampled.
        if x.shape[-1] != self.part_size:
            x = self.resampler(x)

        # Magnitude spectrum of channel 0: L2 norm over the (real, imag) pair.
        fft_data = torch.fft.rfft(x[:,0])
        fft_as_real = torch.view_as_real(fft_data)
        fft = fft_as_real.norm(p=2, dim=-1)

        # Magnitude spectrum of channel 1, computed the same way.
        air_fft_data = torch.fft.rfft(x[:,1])
        air_fft_as_real = torch.view_as_real(air_fft_data)
        air_fft = air_fft_as_real.norm(p=2, dim=-1)

        x = torch.nn.functional.relu(self.layer1(torch.cat([fft, air_fft], dim=1)))
        x = torch.nn.functional.relu(self.layer2(x))
        x = torch.nn.functional.relu(self.layer3(x))

        # Residual stack: each layer adds its rectified output to its input.
        for layer in self.layer_resnet:
            x = x + torch.nn.functional.relu(layer(x))

        x = torch.nn.functional.relu(self.layer4(x))
        x = torch.nn.functional.relu(self.layer5(x))
        x = self.linear(x)
        x = torch.nn.functional.log_softmax(x, dim=-1)
        return x

In [None]:
start = time.time()

# Train on the GPU when one is available, otherwise fall back to the CPU
# instead of failing on an unconditional .cuda() call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model().to(device)

# Class indices: 0 = smooth/silent, 1 = smooth/loud, 2 = rough/loud.
data = torch.cat([smooth_silent_parts, smooth_loud_parts, rough_loud_parts], dim=0)
labels = torch.cat([torch.zeros(smooth_silent_parts.shape[0], dtype=torch.long),
    torch.ones(smooth_loud_parts.shape[0], dtype=torch.long),
    2*torch.ones(rough_loud_parts.shape[0], dtype=torch.long)], dim=0)
# Per-class weights for the NLL loss, in class-index order.
class_weights = torch.tensor([1.0, 0.05, 0.05], device=device)

dataset = torch.utils.data.TensorDataset(data, labels)
loader = torch.utils.data.DataLoader(dataset,
    shuffle=True,
    # Pinned host memory only helps (and is only supported) for CUDA transfers.
    pin_memory=(device.type == "cuda"),
    batch_size=6000,
    num_workers=6
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# One loss value per optimizer step, for plotting afterwards.
all_losses = []

for epoch in trange(5):
    for batch_data, batch_label in loader:
        optimizer.zero_grad()

        # Move the batch to the training device once.
        batch_data = batch_data.to(device, non_blocking=True)
        batch_label = batch_label.to(device, non_blocking=True)
        # Augmentation: add Gaussian noise with standard deviation 0.005.
        batch_data = batch_data + 0.005 * torch.randn_like(batch_data)

        loss = torch.nn.functional.nll_loss(model(batch_data), batch_label, weight=class_weights)
        loss.backward()
        optimizer.step()

        all_losses.append(loss.item())

end = time.time()
training_time = end-start
print("Trained for",training_time,"seconds")

In [None]:
# Loss recorded after every optimizer step during training.
fig, ax = plt.subplots()
ax.plot(all_losses)
ax.set(title="Training loss", xlabel="Optimizer step (batch)", ylabel="Weighted NLL loss");

### Save Model

In [13]:
# Create the run directory. An already existing directory is accepted; any
# other failure (permissions, invalid name, ...) is raised instead of being
# silently swallowed and resurfacing later as a confusing save error.
os.makedirs(identifier, exist_ok=True)

# Checkpoint: model weights plus the settings and statistics of this run.
# The weights are saved on whatever device the model is on at this point, so
# loading on a machine without that device needs torch.load(..., map_location=...).
torch.save({
    'model': model.state_dict(),
    'RESAMPLE_FACTOR': RESAMPLE_FACTOR,
    'PART_SIZE': PART_SIZE,
    'contact_threshold': contact_threshold,
    'rough_stones': rough_stones,
    'rough_stones_test': rough_stones_test,
    'smooth_stones': smooth_stones,
    'smooth_stones_test': smooth_stones_test,
    'rough_chunks': rough_chunks,
    'smooth_chunks': smooth_chunks,
    'imbalance_chunks': imbalance_chunks,
    'training_time': training_time,
    'rough_loudness_mean': rough_loudness_mean,
    'smooth_loudness_mean': smooth_loudness_mean,
}, os.path.join(".", identifier, "model.pt"))

# Also export a TorchScript version. model.cpu() moves the module's parameters
# to the CPU in place, so `model` itself lives on the CPU after this cell.
scriptModel = torch.jit.script(model.cpu())
scriptModel.save(os.path.join(".", identifier, "model_script.pt"))