In [None]:
# Run in Kaggle to fetch the private project repository and make it importable.

from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()

# The personal access token lives in Kaggle secrets, never in the notebook itself.
GITHUB_TOKEN = user_secrets.get_secret("GITHUB_MORSE_TOKEN")
USER = "SwedishSquid"
REPO_NAME = 'KC25_morse'
CLONE_URL = f"https://{USER}:{GITHUB_TOKEN}@github.com/{USER}/{REPO_NAME}.git"
get_ipython().system(f"git clone {CLONE_URL}")

# `git clone` records the URL it was given (token included) as the `origin`
# remote inside <repo>/.git/config. Replace it with a credential-free URL so
# the token is not left on disk in the working directory.
get_ipython().system(f"git -C {REPO_NAME} remote set-url origin https://github.com/{USER}/{REPO_NAME}.git")

import sys
REPO_SRC_DIR = "/kaggle/working/KC25_morse/src"
# Guard against stacking duplicate entries when the cell is re-run.
if REPO_SRC_DIR not in sys.path:
    sys.path.append(REPO_SRC_DIR)

import morse

Cloning into 'KC25_morse'...
remote: Enumerating objects: 114, done.[K
remote: Counting objects: 100% (114/114), done.[K
remote: Compressing objects: 100% (81/81), done.[K
remote: Total 114 (delta 64), reused 79 (delta 29), pack-reused 0 (from 0)[K
Receiving objects: 100% (114/114), 16.83 MiB | 33.39 MiB/s, done.
Resolving deltas: 100% (64/64), done.


In [None]:
!pip install Levenshtein
!pip install MorseCodePy

Collecting Levenshtein
  Downloading levenshtein-0.27.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.6 kB)
Collecting rapidfuzz<4.0.0,>=3.9.0 (from Levenshtein)
  Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading levenshtein-0.27.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (161 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m161.7/161.7 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m42.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, Levenshtein
Successfully installed Levenshtein-0.27.1 rapidfuzz-3.13.0
Collecting MorseCodePy
  Downloading morsecodepy-4.1.tar.gz (9.5 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting req

In [None]:
import os

import wandb
from kaggle_secrets import UserSecretsClient

# The Weights & Biases API key is read from Kaggle secrets.
secret_value_0 = UserSecretsClient().get_secret('WANDB_API_KEY')

os.environ.update({
    # wandb picks its credentials up from the environment
    "WANDB_API_KEY": secret_value_0,
    # let there be no noise
    "WANDB_SILENT": "true",
})

# Keyword arguments shared by every wandb.init call in this notebook.
common_wandb_kvals = {'project': 'KC25', 'entity': 'fishwere'}

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import Levenshtein
import time
import torchaudio
import librosa

from morse.models import CNNResidualBlock, TransformerResidualBlock, PoolingTransition, CNNTransformer, CTCHead
from morse.models import MySomething
from morse.models import SimpleCNN
from morse.my_datasets import ListDataset, load_tensors, filenames_to_torch
from morse.samplers import LongCTCSampler
# from morse.augmentations import rotation_transform, volume_signal_transform
from morse.augmentations import make_volume_signal_transform, make_compose_transform, make_noise_signal_transform, make_runtime_rotation_transform, make_runtime_mel_bounded_noise_transform
from morse.text_helpers import Vectorizer, encode_to_morse, decode_from_morse

from morse.my_datasets import generate_dataset, read_dataset_from_files

In [None]:
# Locations of the label CSVs and of the audio files inside the Kaggle input dataset.
labels_dir = '/kaggle/input/kc25-dataset-copy'
audio_dir = '/kaggle/input/kc25-dataset-copy/morse_dataset/morse_dataset'

# When True, later cells train for only a few epochs and log under a throwaway run name.
dev_flag = False

labels_root = Path(labels_dir)
full_train_df = pd.read_csv(labels_root / 'train.csv')
test_df = pd.read_csv(labels_root / 'test.csv')
full_train_df.head()

Unnamed: 0,id,message
0,1.opus,03ЩУЫЛПИГХ
1,2.opus,ЪЛТ0ДС6А3Г
2,3.opus,5ЭКЫБЗХЯН
3,4.opus,ЖЫЦОИ68КФ
4,5.opus,32Ю7МЫ ЗЛ


# Real data: train/validation split

In [None]:
from sklearn.model_selection import train_test_split

# Fixed-seed shuffle split of row positions: one sixth goes to validation.
row_positions = np.arange(full_train_df.shape[0])
train_index, val_index = train_test_split(
    row_positions,
    test_size=1/6,
    shuffle=True,
    random_state=42,
)

# The validation set is read first, then the training set.
val_rows = full_train_df.iloc[val_index]
real_val_set = read_dataset_from_files(
    audio_dir,
    filenames=val_rows['id'],
    labels=list(val_rows['message']),
)
print(len(real_val_set))

train_rows = full_train_df.iloc[train_index]
real_train_set = read_dataset_from_files(
    audio_dir,
    filenames=train_rows['id'],
    labels=list(train_rows['message']),
)
print(len(real_train_set))

100%|██████████| 5000/5000 [03:23<00:00, 24.59it/s]

5000





# Helpers: vocabulary and text vectorization

In [None]:
# Value passed as pad_value when batches of texts are vectorized.
pad_value = 0

# Vocabulary: every distinct character occurring in the training messages, sorted.
index_to_letter = sorted(set(''.join(full_train_df['message'])))
print(index_to_letter)

# Inverse lookup: character -> its position in index_to_letter.
letter_to_index = {char: position for position, char in enumerate(index_to_letter)}
dictionary_size = len(index_to_letter)
print(dictionary_size)
print(letter_to_index)

vectorizer = Vectorizer(letter_to_index, index_to_letter)
# Quick sanity check of the text -> index mapping.
print(vectorizer.text_transform('ПРИВЕТ #'))


def batch_text_transform(texts):
    """Vectorize a batch of texts and shift every index up by one.

    Delegates to ``vectorizer.batch_text_transform`` with the notebook-level
    ``pad_value`` and returns ``(shifted_indices, lengths)``.

    The +1 shift applies to every element of the vectorizer output, padding
    included. Presumably this keeps index 0 free for the CTC blank, since the
    loss is built with ``nn.CTCLoss()`` defaults and the metric code subtracts
    1 before decoding.
    """
    encoded, lengths = vectorizer.batch_text_transform(texts, pad_value=pad_value)
    shifted = encoded + 1
    return shifted, lengths

[' ', '#', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'А', 'Б', 'В', 'Г', 'Д', 'Е', 'Ж', 'З', 'И', 'Й', 'К', 'Л', 'М', 'Н', 'О', 'П', 'Р', 'С', 'Т', 'У', 'Ф', 'Х', 'Ц', 'Ч', 'Ш', 'Щ', 'Ъ', 'Ы', 'Ь', 'Э', 'Ю', 'Я']
44
{' ': 0, '#': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, 'А': 12, 'Б': 13, 'В': 14, 'Г': 15, 'Д': 16, 'Е': 17, 'Ж': 18, 'З': 19, 'И': 20, 'Й': 21, 'К': 22, 'Л': 23, 'М': 24, 'Н': 25, 'О': 26, 'П': 27, 'Р': 28, 'С': 29, 'Т': 30, 'У': 31, 'Ф': 32, 'Х': 33, 'Ц': 34, 'Ч': 35, 'Ш': 36, 'Щ': 37, 'Ъ': 38, 'Ы': 39, 'Ь': 40, 'Э': 41, 'Ю': 42, 'Я': 43}
tensor([27, 28, 20, 14, 17, 30,  0,  1])


In [None]:
# Use CUDA device ordinal 0 when a GPU is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = 'cpu'
device

0

In [None]:
def calculate_target_metric(valset, model, start_index=700, n_samples=500, beam_size=10):
    """Return the mean Levenshtein distance between decoded and reference texts.

    Each selected item of ``valset`` is run through ``model`` one at a time,
    decoded with ``LongCTCSampler.sample`` and compared with its label.

    Args:
        valset: indexable, sized dataset whose items unpack as
            ``(features, label)``; the label is passed to
            ``Levenshtein.distance`` together with the decoded text.
        model: network called as ``model(features[None])``. It is switched to
            eval mode and left in eval mode afterwards.
        start_index: position of the first evaluated item. The default keeps
            the window that used to be hard-coded (items 700..1199).
        n_samples: maximum number of items to evaluate. The window is clamped
            to the end of ``valset`` instead of raising ``IndexError``.
        beam_size: beam width forwarded to ``LongCTCSampler.sample``.

    Returns:
        Mean edit distance over the evaluated items (result of ``np.mean``).

    Raises:
        ValueError: if the requested window contains no items.
    """
    stop_index = min(start_index + n_samples, len(valset))
    sample_indices = range(start_index, stop_index)
    if len(sample_indices) == 0:
        raise ValueError(
            f'no samples to evaluate: start_index={start_index}, '
            f'n_samples={n_samples}, dataset size={len(valset)}'
        )

    model.eval()
    distance_buffer = []
    with torch.no_grad():
        for i in tqdm(sample_indices):
            features, labels = valset[i]
            features = features.to(device)
            # Drop only the singleton batch dimension added by features[None];
            # a bare squeeze() would also drop any other size-1 dimension.
            outs = model(features[None]).squeeze(0).to('cpu')
            # assumes dim 0 is now the class axis (the training cell treats the
            # model output as batch-first with time last) — TODO confirm
            probs = F.softmax(outs, dim=0)
            seqs, likelihood = LongCTCSampler.sample(probs, beam_size=beam_size)
            # Undo the +1 index shift applied to targets before decoding to text.
            decoded_message = vectorizer.from_tensor(torch.tensor(seqs) - 1)
            distance_buffer.append(Levenshtein.distance(decoded_message, labels))
    return np.mean(distance_buffer)

# model

In [None]:
# --- run bookkeeping ---
checkpoint_period = 10  # a checkpoint is written every `checkpoint_period` epochs
group = 'RealTune'

# Dev mode shortens training and logs under a throwaway run name.
if dev_flag:
    n_epochs = 3
    run_name = 'testrun'
else:
    n_epochs = 30
    run_name = 'SimpleCNN_baseline'

# --- optimisation hyper-parameters ---
batch_size = 128
lr = 3e-4
step_gamma = 0.33
dropout = 0.165

# --- architecture hyper-parameters ---
n_pools = 4
n_blocks_before_pool = 3
pooling_overlap = True

# Hyper-parameters recorded with the wandb run.
config = {
    'n_epochs': n_epochs,
    'batch_size': batch_size,

    'lr': lr,
    'step_gamma': step_gamma,
    'dropout': dropout,

    'n_pools': n_pools,
    'n_blocks_before_pool': n_blocks_before_pool,
    'pooling_overlap': pooling_overlap,
}

# One output more than the vocabulary size; the extra class is presumably the
# CTC blank (targets are shifted by +1 elsewhere in the notebook).
model_kwargs = dict(
    d_input=64,
    d_model=64,
    d_inner=64,
    d_output=dictionary_size + 1,
    n_pools=n_pools,
    n_blocks_before_pool=n_blocks_before_pool,
    pooling_overlap=pooling_overlap,
    dropout=dropout,
)
model = SimpleCNN(**model_kwargs).to(device)

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=step_gamma)
ctc_loss = nn.CTCLoss()

train_loader = torch.utils.data.DataLoader(real_train_set, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(real_val_set, batch_size=batch_size, shuffle=False)


def _batch_ctc_loss(features, labels):
    """Run the model on one batch and return its CTC loss.

    Shared by the training and validation passes. The caller decides whether
    gradients are tracked and which mode the model is in.
    """
    features = features.to(device)
    targets, target_lengths = batch_text_transform(labels)
    targets = targets.to(device)
    target_lengths = target_lengths.to(torch.int32).to(device)
    # Reorder (d0, d1, d2) -> (d2, d0, d1); presumably
    # (batch, classes, time) -> (time, batch, classes) — TODO confirm.
    log_probs = F.log_softmax(model(features).permute(2, 0, 1), dim=2)
    n_steps, n_items = log_probs.shape[0], log_probs.shape[1]
    # Every item in the batch is given the full output length.
    input_lengths = torch.full(size=(n_items,), fill_value=n_steps, dtype=torch.int32).to(device)
    return ctc_loss(log_probs, targets, input_lengths, target_lengths)


with wandb.init(**common_wandb_kvals, group=group, config=config, name=run_name) as run:
    for epoch in range(n_epochs):
        # Training pass over the whole training loader.
        model.train()
        train_losses = []
        for features, labels in tqdm(train_loader):
            loss = _batch_ctc_loss(features, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_losses.append(loss.detach())
        scheduler.step()

        # Validation pass: same loss, no gradient tracking.
        model.eval()
        val_losses = []
        with torch.no_grad():
            for features, labels in tqdm(val_loader):
                val_losses.append(_batch_ctc_loss(features, labels).detach())

        # Per-epoch means of the batch losses. The learning rate is read after
        # scheduler.step(), so it is the value the next epoch will use.
        epoch_metrics = {
            'fake_train_loss': torch.stack(train_losses).mean().item(),
            'fake_val_loss': torch.stack(val_losses).mean().item(),
            'lr': scheduler.get_last_lr()[0],
        }
        wandb.log(epoch_metrics)

        if (epoch + 1) % checkpoint_period == 0:
            torch.save(model.state_dict(), f'{run_name}_{epoch+1}ep.pt')
            print('saved model')

    # After the last epoch: decode part of the validation set and log the edit distance.
    print('calculating target metric')
    target_metric = calculate_target_metric(real_val_set, model)
    wandb.log({'Levenshtein_distance': target_metric})

dataset creation


100%|██████████| 235/235 [00:12<00:00, 18.64it/s]
100%|██████████| 40/40 [00:01<00:00, 35.58it/s]
100%|██████████| 40/40 [00:00<00:00, 42.93it/s]


8.210727262496949


100%|██████████| 235/235 [00:11<00:00, 20.32it/s]
100%|██████████| 40/40 [00:01<00:00, 37.08it/s]
100%|██████████| 40/40 [00:00<00:00, 43.46it/s]


5.960347661972046


100%|██████████| 235/235 [00:11<00:00, 20.29it/s]
100%|██████████| 40/40 [00:00<00:00, 45.36it/s]
100%|██████████| 40/40 [00:00<00:00, 43.93it/s]


4.314725569605827


100%|██████████| 235/235 [00:11<00:00, 20.33it/s]
100%|██████████| 40/40 [00:00<00:00, 42.85it/s]
100%|██████████| 40/40 [00:01<00:00, 39.34it/s]


3.166647103631496


100%|██████████| 235/235 [00:11<00:00, 20.30it/s]
100%|██████████| 40/40 [00:00<00:00, 44.70it/s]
100%|██████████| 40/40 [00:00<00:00, 45.28it/s]


2.3393262873518466


100%|██████████| 235/235 [00:11<00:00, 20.35it/s]
100%|██████████| 40/40 [00:00<00:00, 44.32it/s]
100%|██████████| 40/40 [00:00<00:00, 44.41it/s]


1.7614910895923375


100%|██████████| 235/235 [00:11<00:00, 20.36it/s]
100%|██████████| 40/40 [00:00<00:00, 43.97it/s]
100%|██████████| 40/40 [00:00<00:00, 40.89it/s]


1.3566136237690567


100%|██████████| 235/235 [00:11<00:00, 20.32it/s]
100%|██████████| 40/40 [00:00<00:00, 43.71it/s]
100%|██████████| 40/40 [00:00<00:00, 44.56it/s]


1.062861314783176


100%|██████████| 235/235 [00:11<00:00, 20.32it/s]
100%|██████████| 40/40 [00:00<00:00, 44.40it/s]
100%|██████████| 40/40 [00:00<00:00, 44.12it/s]


0.8549210071154197


100%|██████████| 235/235 [00:11<00:00, 20.34it/s]
100%|██████████| 40/40 [00:00<00:00, 42.66it/s]
100%|██████████| 40/40 [00:00<00:00, 42.57it/s]


0.7104847687196167


100%|██████████| 235/235 [00:11<00:00, 20.36it/s]
100%|██████████| 40/40 [00:00<00:00, 45.10it/s]
100%|██████████| 40/40 [00:00<00:00, 41.54it/s]


0.6127140134968363


100%|██████████| 235/235 [00:11<00:00, 20.32it/s]
100%|██████████| 40/40 [00:00<00:00, 45.78it/s]
100%|██████████| 40/40 [00:00<00:00, 42.29it/s]


0.550962653485795


100%|██████████| 235/235 [00:11<00:00, 20.34it/s]
100%|██████████| 40/40 [00:00<00:00, 40.14it/s]
100%|██████████| 40/40 [00:00<00:00, 45.06it/s]


0.5014898992672583


100%|██████████| 235/235 [00:11<00:00, 20.35it/s]
100%|██████████| 40/40 [00:00<00:00, 44.64it/s]
100%|██████████| 40/40 [00:00<00:00, 44.54it/s]


0.471555912690015


100%|██████████| 235/235 [00:11<00:00, 20.33it/s]
100%|██████████| 40/40 [00:00<00:00, 43.56it/s]
100%|██████████| 40/40 [00:00<00:00, 44.80it/s]


0.4472293873008648


100%|██████████| 235/235 [00:11<00:00, 20.32it/s]
100%|██████████| 40/40 [00:00<00:00, 43.32it/s]
100%|██████████| 40/40 [00:00<00:00, 42.01it/s]


0.4405026479128591


100%|██████████| 235/235 [00:11<00:00, 20.33it/s]
100%|██████████| 40/40 [00:00<00:00, 43.70it/s]
100%|██████████| 40/40 [00:00<00:00, 42.22it/s]


0.42619568114424883


100%|██████████| 235/235 [00:11<00:00, 20.34it/s]
100%|██████████| 40/40 [00:00<00:00, 40.74it/s]
100%|██████████| 40/40 [00:00<00:00, 44.50it/s]


0.40954400006395175


100%|██████████| 235/235 [00:11<00:00, 20.40it/s]
100%|██████████| 40/40 [00:00<00:00, 45.39it/s]
100%|██████████| 40/40 [00:00<00:00, 45.65it/s]


0.4030538651587875


100%|██████████| 235/235 [00:11<00:00, 20.42it/s]
100%|██████████| 40/40 [00:00<00:00, 42.25it/s]
100%|██████████| 40/40 [00:00<00:00, 44.02it/s]


0.3940638469924149


100%|██████████| 235/235 [00:11<00:00, 20.46it/s]
100%|██████████| 40/40 [00:00<00:00, 45.91it/s]
100%|██████████| 40/40 [00:00<00:00, 46.01it/s]


0.39061361820387164


100%|██████████| 235/235 [00:11<00:00, 20.49it/s]
100%|██████████| 40/40 [00:00<00:00, 45.38it/s]
100%|██████████| 40/40 [00:00<00:00, 46.12it/s]


0.3870326359035683


100%|██████████| 235/235 [00:11<00:00, 20.41it/s]
100%|██████████| 40/40 [00:00<00:00, 44.87it/s]
100%|██████████| 40/40 [00:00<00:00, 43.69it/s]


0.3865367288662471


100%|██████████| 235/235 [00:11<00:00, 20.26it/s]
100%|██████████| 40/40 [00:00<00:00, 40.32it/s]
100%|██████████| 40/40 [00:00<00:00, 43.21it/s]


0.38745373693023466


100%|██████████| 235/235 [00:11<00:00, 20.19it/s]
100%|██████████| 40/40 [00:00<00:00, 44.29it/s]
100%|██████████| 40/40 [00:00<00:00, 43.86it/s]


0.3847221567714218


100%|██████████| 235/235 [00:11<00:00, 20.31it/s]
100%|██████████| 40/40 [00:00<00:00, 44.48it/s]
100%|██████████| 40/40 [00:00<00:00, 45.52it/s]


0.3824182961522343


100%|██████████| 235/235 [00:11<00:00, 20.24it/s]
100%|██████████| 40/40 [00:00<00:00, 41.61it/s]
100%|██████████| 40/40 [00:00<00:00, 44.06it/s]


0.3830247819321515


100%|██████████| 235/235 [00:11<00:00, 20.29it/s]
100%|██████████| 40/40 [00:00<00:00, 43.88it/s]
100%|██████████| 40/40 [00:00<00:00, 43.81it/s]


0.3843936204838766


100%|██████████| 235/235 [00:11<00:00, 20.26it/s]
100%|██████████| 40/40 [00:00<00:00, 45.17it/s]
100%|██████████| 40/40 [00:00<00:00, 44.69it/s]


0.3847901260802755


100%|██████████| 235/235 [00:11<00:00, 20.32it/s]
100%|██████████| 40/40 [00:00<00:00, 45.85it/s]
100%|██████████| 40/40 [00:00<00:00, 44.72it/s]


0.3828575755322701


In [None]:
# Persist the weights of the model as it stands after the training cell.
final_weights = model.state_dict()
torch.save(final_weights, f'{run_name}_final.pt')