In [1]:
import json
import os

import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np

import CnnModel as cnnModel
import OutputBeautifier as ob
import Settings
from AudioPreprocessor import AudioPreprocessor
from DataGenerator import DataGenerator
from ModelTeacher import ModelTeacher, LossHistory
from PredictionGenerator import PredictionGenerator

In [7]:
# Build a fresh model instance through the project's CnnModel factory.
model = cnnModel.create_model()

In [8]:
# ModelTeacher provides the ID loading, splitting, training and evaluation
# helpers used throughout this notebook.
mt = ModelTeacher()

# Load the sample IDs and group them with group_by_song before splitting.
ids = mt.load_IDs()
songs = mt.group_by_song(ids)
# Optional filtering step, currently disabled.
#songs = mt.remove_excess_ids(songs)
# NOTE(review): 42 is presumably the random seed of the split — confirm in
# ModelTeacher.split_songs.
ids_train, ids_test = mt.split_songs(songs, 42)
# One generator per split; gen_test is later passed to both teach_model and
# test_model.
gen_train = DataGenerator(ids_train)
gen_test = DataGenerator(ids_test)

In [9]:
# Sanity check: number of IDs in the training split and in the test split.
print(len(ids_train), len(ids_test))

377688 94872


In [10]:
# LossHistory is handed to teach_model in a list and read back afterwards
# through get_data() (presumably a training callback — confirm in ModelTeacher).
lossHistory = LossHistory()
history = mt.teach_model(
    model,
    gen_train,
    [lossHistory],
    gen_test=gen_test,
    epochs=6,
)

# Quick look at the training loss recorded in the returned history object.
train_loss = history.history['loss']
plt.plot(train_loss, label='train')
plt.title('Loss')
plt.legend()
plt.show()

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


In [11]:
# Persist the metrics gathered during training as JSON in the working directory.
# The files are opened with `with` so the handles are flushed and closed
# deterministically instead of being left to the garbage collector.
historyFile = os.path.realpath("history.json")
with open(historyFile, 'w') as history_fp:
    json.dump(history.history, history_fp)

lossHistoryFile = os.path.realpath("lossHistory.json")
with open(lossHistoryFile, 'w') as loss_history_fp:
    json.dump(lossHistory.get_data(), loss_history_fp)

# Evaluate the trained model on the test generator and report the result.
results = mt.test_model(model, gen_test)
print("test loss, test acc:", results)

test loss, test acc: [5.254683017730713, 0.7527006268501282]


In [19]:
# Display the metrics dict stored on the training result (lists keyed by metric name).
history.history

{'loss': [18.264070510864258,
  18.22051429748535,
  18.165611267089844,
  18.089155197143555,
  17.96691131591797,
  17.720401763916016,
  17.059518814086914,
  14.069555282592773,
  9.99571418762207,
  9.20816421508789,
  8.946211814880371,
  8.783968925476074,
  8.650217056274414,
  8.529457092285156,
  8.534322738647461],
 'avg_acc': [0.06503904610872269,
  0.091145820915699,
  0.13539060950279236,
  0.2068750262260437,
  0.3148828446865082,
  0.45065099000930786,
  0.5633333325386047,
  0.6248569488525391,
  0.6933854818344116,
  0.7113021016120911,
  0.710885763168335,
  0.7121224999427795,
  0.7108855247497559,
  0.7126822471618652,
  0.7096484303474426],
 'val_loss': [18.246492385864258,
  18.197797775268555,
  18.123916625976562,
  18.017066955566406,
  17.841955184936523,
  17.453847885131836,
  16.237289428710938,
  10.450220108032227,
  9.12672233581543,
  8.843110084533691,
  8.43378734588623,
  8.265089988708496,
  8.28780746459961,
  8.162917137145996,
  8.26716804504394

In [12]:
# Overlay the training and validation loss curves on one set of axes.
for metric in ('loss', 'val_loss'):
    plt.plot(history.history[metric], label=metric)
plt.title('loss')
plt.legend()
plt.show()

In [13]:
# Same comparison for accuracy: training curve against validation curve.
for metric in ('avg_acc', 'val_avg_acc'):
    curve = history.history[metric]
    plt.plot(curve, label=metric)
plt.title('acc')
plt.legend()
plt.show()

In [15]:
# Fetch the data collected by the LossHistory object; the following cells plot
# its 'avg_acc' and 'loss' entries.
historyData = lossHistory.get_data()

In [16]:
# Accuracy series as collected by the LossHistory object.
recorded_acc = historyData['avg_acc']
plt.plot(recorded_acc, label='avg_acc')
plt.title('history acc')
plt.legend()
plt.show()

In [17]:
# Loss series as collected by the LossHistory object.
recorded_loss = historyData['loss']
plt.plot(recorded_loss, label='loss')
plt.title('history loss')
plt.legend()
plt.show()

In [18]:
# Persist the trained model's weights via ModelTeacher (the destination is
# chosen inside save_weights and is not visible from this notebook).
mt.save_weights(model)

## Тест

In [2]:
# Rebuild the model and restore weights from Settings.weights_path, resolved
# relative to the current working directory.
model = cnnModel.create_model()
model.load_weights(os.path.join(".", Settings.weights_path))

In [3]:
# AudioPreprocessor and PredictionGenerator come from the notebook's import
# cell, so all dependencies are declared in one place.
a = AudioPreprocessor()

# Preprocess the recording, wrap the result in a PredictionGenerator and run
# the model over it.
audio = a.process_audiofile('./audio_2022-09-15_14-10-56.mp3')
dima_gen = PredictionGenerator(audio)
out = model.predict(dima_gen)



In [4]:
# Post-process the raw predictions with OutputBeautifier (`ob` comes from the
# notebook's import cell).
b_out = ob.beautify_outputs(out)

In [5]:
# Dump every frame of b_out that has at least one of its six rows starting
# with a value other than 1, and remember the timestamp of each such frame.
secs = []
with open("./dima_out.txt", mode='w') as dima:
    for frame_idx, frame in enumerate(b_out):
        if not any(frame[row][0] != 1 for row in range(6)):
            continue

        # Frame index converted to seconds via hop length / sample rate.
        timestamp = frame_idx * Settings.hop_length / Settings.sr_downs
        dima.write(str(timestamp) + 's:\n')
        secs.append(timestamp)

        # One text line per row: 21 values, each followed by a space.
        for row in range(6):
            dima.write(''.join(str(frame[row][col]) + ' ' for col in range(21)))
            dima.write('\n')
        dima.write('\n')

In [7]:
# Swap the first two axes of the preprocessed audio and convert it to dB
# relative to its maximum before plotting.
spec_db = librosa.amplitude_to_db(np.swapaxes(audio, 1, 0), ref=np.max)

fig, ax = plt.subplots()
img = librosa.display.specshow(spec_db, y_axis='log', x_axis='time', ax=ax)
ax.set_title('Power spectrogram')
fig.colorbar(img, ax=ax, format="%+2.0f dB")
# Disabled overlay of the timestamps collected in `secs`.
# NOTE(review): the x pair below looks like it should be [secs[i], secs[i]].
#for i in range(len(secs)):
#    plt.plot([secs[i], secs], [-25, 9300], color='cyan')
plt.show()