In [1]:
import pandas as pd
import torch
import torchaudio
import torchaudio.transforms as T
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import os
import matplotlib.pyplot as plt

In [2]:
df = pd.read_csv('processed_train.csv', encoding='utf-8')

In [3]:
alphabet = sorted(set("".join(df['label'])))
char2idx = {ch: i + 1 for i, ch in enumerate(alphabet)}
idx2char = {i: ch for ch, i in char2idx.items()}

In [4]:
mel_spec = T.MelSpectrogram(sample_rate=8000, n_mels=40)

In [5]:
spectrograms = []
targets = []
input_lengths = []
target_lengths = []

In [6]:
for i, row in df.iterrows():
    waveform = torch.load(row['filepath'])
    spec = mel_spec(waveform)
    spec = spec.squeeze(0)
    
    spectrograms.append(spec)
    input_lengths.append(spec.shape[1])

    label = row['label']
    encoded = torch.tensor([char2idx[c] for c in label], dtype=torch.long)
    
    targets.append(encoded)
    target_lengths.append(len(encoded))

In [7]:
spectrograms[0].shape

torch.Size([40, 321])

In [8]:
torch.save({
    'spectrograms': spectrograms,
    'targets': targets,
    'input_lengths': input_lengths,
    'target_lengths': target_lengths,
    'char2idx': char2idx,
    'idx2char': idx2char
}, 'prepared_data.pt')


In [53]:
data = torch.load('prepared_data.pt')

In [54]:
spectrograms = data['spectrograms']
targets = data['targets']
input_lengths = data['input_lengths']
target_lengths = data['target_lengths']
char2idx = data['char2idx']
idx2char = data['idx2char']

In [4]:
from torch.utils.data import Dataset

class MorseDataset(Dataset):
    def __init__(self, spectrograms, targets, input_lengths, target_lengths):
        self.spectrograms = spectrograms
        self.targets = targets
        self.input_lengths = input_lengths
        self.target_lengths = target_lengths

    def __len__(self):
        return len(self.spectrograms)

    def __getitem__(self, idx):
        return {
            'spectrogram': self.spectrograms[idx],
            'target': self.targets[idx],
            'input_length': self.input_lengths[idx],
            'target_length': self.target_lengths[idx]
        }

## Определение модели

In [3]:
class MorseModel(nn.Module):
    def __init__(self, n_mels, n_classes):
        super().__init__()

        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2))
        )
        
        self.lstm = nn.LSTM(
            input_size=(n_mels // 2) * 32,
            hidden_size=128,
            num_layers=2,
            bidirectional=True,
            batch_first=True
        )

        self.fc = nn.Linear(128 * 2, n_classes + 1) 

    def forward(self, x):
        x = self.cnn(x)  
        b, c, h, t = x.size()

        x = x.permute(0, 3, 1, 2).contiguous()
        x = x.view(b, t, c * h)

        x, _ = self.lstm(x)

        x = self.fc(x)  
        return x.permute(1, 0, 2)  

In [None]:
def collate_fn(batch):
    max_spec_len = max(item['spectrogram'].shape[1] for item in batch)

    specs = torch.stack([
        torch.nn.functional.pad(item['spectrogram'], (0, max_spec_len - item['spectrogram'].shape[1]))
        for item in batch
    ])
    specs = specs.unsqueeze(1) 

    targets = torch.cat([item['target'] for item in batch])

    input_lengths = torch.tensor([item['input_length'] for item in batch], dtype=torch.long)
    target_lengths = torch.tensor([item['target_length'] for item in batch], dtype=torch.long)

    return specs, targets, input_lengths, target_lengths


In [12]:
csv_path = "D:/vs_projects/Data/morse/test.csv"
pt_folder = "D:/vs_projects/Data/morse/morse_wav"


test_df = pd.read_csv(csv_path)
test_df = test_df.rename(columns={'id':'filename'})

In [6]:
def decode_prediction(preds, idx2char):
    preds = preds.permute(1, 0, 2)
    pred_indices = preds.argmax(dim=-1).cpu().numpy()
    results = []

    for seq in pred_indices:
        prev = -1
        chars = []
        for i in seq:
            if i != prev and i != 0:
                chars.append(idx2char.get(i, '?'))
            prev = i
        results.append("".join(chars))

    return results

In [16]:
mel = T.MelSpectrogram(sample_rate=8000, n_mels=40)

## Функция для вычисления метрики

In [31]:
import Levenshtein

def evaluate_levenshtein(model, csv_path, pt_folder, idx2char, sample_rate=8000, n_mels=40, limit=None):
    df = pd.read_csv(csv_path)

    if limit is not None:
        df = df.tail(limit)

    mel = T.MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels)
    model.eval()
    distances = []

    with torch.no_grad():
        for _, row in df.iterrows():
            fname = row['id']
            true_text = row['message']

            pt_path = os.path.join(pt_folder, fname.replace(".opus", ".pt"))
            waveform = torch.load(pt_path)

            if waveform.ndim == 3:
                waveform = waveform.squeeze(0)

            spec = mel(waveform)
            spec = spec.squeeze(0).unsqueeze(0).unsqueeze(0)
            out = model(spec) 
            pred_text = decode_prediction(out, idx2char)[0]

            dist = Levenshtein.distance(pred_text, true_text)
            distances.append(dist)

    mean_distance = sum(distances) / len(distances)
    return mean_distance

In [None]:
data = torch.load("prepared_data.pt")

dataset = MorseDataset(
    data['spectrograms'],
    data['targets'],
    data['input_lengths'],
    data['target_lengths']
)

loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn
)

model = MorseModel(n_mels=40, n_classes=len(char2idx))

## Процесс обучения

In [36]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.utils import clip_grad_norm_
from collections import Counter

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)
ctc_loss = nn.CTCLoss(blank=0)
best_score = float("inf")
wait = 0
patience = 15

num_epochs = 100
for epoch in range(101, 150+1):
    model.train()
    epoch_loss = 0

    for specs, targets, input_lengths, target_lengths in loader:
        input_lengths = input_lengths // 2
        preds = model(specs)
        loss = ctc_loss(preds, targets, input_lengths, target_lengths)

        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(loader)
    scheduler.step(avg_loss)
    current_lr = optimizer.param_groups[0]['lr']

    val_score = evaluate_levenshtein(
        model=model,
        csv_path="D:/vs_projects/Data/morse/train.csv",
        pt_folder="D:/vs_projects/Data/morse/morse_wav",
        idx2char=idx2char,
        limit=300
    )

    sample_texts = []
    model.eval()
    with torch.no_grad():
        for i in range(10):
            waveform = torch.load(f"D:/vs_projects/Data/morse/morse_wav/{29991 + i}.pt")
            if waveform.ndim == 3:
                waveform = waveform.squeeze(0)
            spec = mel(waveform)
            spec = spec.squeeze(0).unsqueeze(0).unsqueeze(0)
            out = model(spec)
            pred = decode_prediction(out, idx2char)[0]
            sample_texts.append(pred)

    avg_pred_len = sum(len(p) for p in sample_texts) / len(sample_texts)
    counter = Counter("".join(sample_texts)).most_common(5)

    print(f"Epoch {epoch:02} | Loss: {avg_loss:.2f} | LR: {current_lr:.5f} | "
          f"Levenshtein: {val_score:.2f} | Pred avg len: {avg_pred_len:.2f}")
    
    if val_score < best_score - 0.01:
        best_score = val_score
        wait = 0
        torch.save(model.state_dict(), "morse_model_epoch{epoch}.pt")
        print(f"best model epoch{epoch} (score={val_score:.2f})")
    else:
        wait += 1
        if wait >= patience:
            print('Stopped')
            break        




Epoch 101 | Loss: 1.62 | LR: 0.00100 | Levenshtein: 3.74 | Pred avg len: 7.60
best model epoch101 (score=3.74)
Epoch 102 | Loss: 1.61 | LR: 0.00100 | Levenshtein: 3.65 | Pred avg len: 7.60
best model epoch102 (score=3.65)
Epoch 103 | Loss: 1.61 | LR: 0.00100 | Levenshtein: 3.74 | Pred avg len: 7.70
Epoch 104 | Loss: 1.60 | LR: 0.00100 | Levenshtein: 3.61 | Pred avg len: 7.80
best model epoch104 (score=3.61)
Epoch 105 | Loss: 1.59 | LR: 0.00100 | Levenshtein: 3.75 | Pred avg len: 7.90
Epoch 106 | Loss: 1.59 | LR: 0.00100 | Levenshtein: 3.80 | Pred avg len: 7.60
Epoch 107 | Loss: 1.58 | LR: 0.00100 | Levenshtein: 3.66 | Pred avg len: 7.60
Epoch 108 | Loss: 1.57 | LR: 0.00100 | Levenshtein: 3.80 | Pred avg len: 7.70
Epoch 109 | Loss: 1.57 | LR: 0.00100 | Levenshtein: 3.58 | Pred avg len: 7.30
best model epoch109 (score=3.58)
Epoch 110 | Loss: 1.56 | LR: 0.00100 | Levenshtein: 3.60 | Pred avg len: 7.50
Epoch 111 | Loss: 1.56 | LR: 0.00100 | Levenshtein: 3.60 | Pred avg len: 7.90
Epoch 112 

In [37]:
torch.save(model.state_dict(), "morse_model.pt")

## Предсказание

In [38]:
from torch.nn.functional import pad

results = []

model.eval()
with torch.no_grad():
    for fname in test_df['filename']:
        pt_path = os.path.join(pt_folder, fname.replace(".opus", ".pt"))

        waveform = torch.load(pt_path)
        if waveform.ndim == 3:
            waveform = waveform.squeeze(0)

        spec = mel(waveform)
        spec = spec.squeeze(0).unsqueeze(0).unsqueeze(0) 

        pred = model(spec)
        decoded = decode_prediction(pred, idx2char)[0]

        results.append((fname, decoded))

In [39]:
submission_df = pd.DataFrame(results, columns=["id", "message"])
submission_df.to_csv("submission.csv", index=False, encoding='utf-8')