In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import datetime

from torch.utils.data import TensorDataset, Dataset, DataLoader, random_split


In [None]:
# Pick the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(torch.cuda.is_available())
# torch.cuda.get_device_name raises on machines without a CUDA device,
# so only query the GPU name when one is actually in use.
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))


In [None]:
def parse_matrix(filename, device=None):
    """Read a whitespace-separated integer matrix from a text file into a tensor.

    Each non-blank line becomes one row. Blank lines, including a trailing
    newline at the end of the file, are skipped so they do not create empty
    rows that make ``torch.tensor`` reject the ragged list.

    Args:
        filename: path to the text file.
        device: target device. Defaults to the notebook-level ``device``
            global (the original behaviour), or the CPU if that global is absent.

    Returns:
        An integer tensor of shape ``(rows, cols)``.

    Raises:
        OSError: if the file cannot be opened ("bad filename" is printed first).
        ValueError: if a token is not an integer or the rows have different lengths.
    """
    if device is None:
        device = globals().get('device', torch.device('cpu'))

    try:
        with open(filename, 'r') as f:
            rows = [[int(elem) for elem in line.split()] for line in f if line.strip()]
    except OSError:
        # Only report "bad filename" for real file-access problems; parse errors
        # propagate with their own message.
        print('bad filename')
        raise

    return torch.tensor(rows, device=device)

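# MDPC 73

Run **either** this cell **or** the MDPC 105 cell below. Both define the same globals (`encoding_matrix`, `syndrome_matrix`, `k`, `n`, `syndrome_len`, `input_size`, `output_size`, `hidden_size`), so whichever cell runs last decides which code the rest of the notebook uses.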

In [None]:
# Generator (G) and parity-check (H) matrix files for the n = 73 MDPC code.
enc_path = 'G-45-73.txt'
syn_path = 'H-73-73-2.txt'

encoding_matrix, syndrome_matrix = parse_matrix(enc_path), parse_matrix(syn_path)
print(encoding_matrix)
print(syndrome_matrix)

# Cast once to float so later matmuls against float message tensors type-check.
encoding_matrix = encoding_matrix.float()
syndrome_matrix = syndrome_matrix.float()

k, n = encoding_matrix.shape             # G is (k, n): k message bits -> n code bits
syndrome_len = syndrome_matrix.shape[0]  # one syndrome bit per row of H

# Network dimensions
output_size = n                          # output: per-position error estimate
input_size = syndrome_len + output_size  # input: syndrome + absolute values of the received word
hidden_size = output_size * 6

# MDPC 105

In [None]:
# Generator (G) and parity-check (H) matrix files for the n = 105 MDPC code.
enc_path = 'G-53-105.txt'
syn_path = 'H-105-105.txt'

# Unlike the 73-bit cell, H is transposed after loading.
# NOTE(review): presumably this file stores H column-wise; confirm.
encoding_matrix, syndrome_matrix = parse_matrix(enc_path), parse_matrix(syn_path).T
print(encoding_matrix)
print(syndrome_matrix)

# Cast once to float so later matmuls against float message tensors type-check.
encoding_matrix = encoding_matrix.float()
syndrome_matrix = syndrome_matrix.float()

k, n = encoding_matrix.shape             # G is (k, n): k message bits -> n code bits
syndrome_len = syndrome_matrix.shape[0]  # one syndrome bit per row of H

# Network dimensions
output_size = n                          # output: per-position error estimate
input_size = syndrome_len + output_size  # input: syndrome + absolute values of the received word
hidden_size = output_size * 6

# Data generators

In [None]:
# n is message_length
def gen_encoded_messages(num_mess, device):
    res = torch.empty(num_mess, n, dtype=torch.float32, device=device)
    for i in range(num_mess):

        input_message = torch.randint(0, 2, (k,), dtype=torch.float32, device=device)
        # encoding
        enc_mess = (torch.matmul(input_message, encoding_matrix) % 2).float()
        # BPSK
        enc_mod_mess = (2 * enc_mess - 1).float()
        
        res[i, :] = enc_mod_mess
    return res


def noisify(messages, device, snr=4):
    """Add white Gaussian noise to BPSK symbols.

    The noise standard deviation is ``sqrt(1 / (2 * (k/n) * 10**(snr/10)))``,
    using the notebook-level code dimensions ``k`` and ``n``. This matches the
    usual Eb/N0 (dB) convention for unit-energy BPSK.
    NOTE(review): confirm that ``snr`` is meant as Eb/N0.

    Args:
        messages: tensor of modulated symbols.
        device: device for the generated noise.
        snr: signal-to-noise ratio in dB.

    Returns:
        ``messages`` plus independent N(0, sigma^2) float32 noise of the same shape.
    """
    snr_linear = 10 ** (snr / 10)
    noise_std = math.sqrt(1 / (2 * k / n * snr_linear))
    return messages + torch.normal(0, noise_std, size=messages.shape, dtype=torch.float32, device=device)


def get_syndromes(messages, parity_matrix=None):
    """Hard-decide received BPSK symbols and compute their binary syndromes.

    Args:
        messages: tensor of shape ``(batch, n)`` with received (noisy) symbols.
        parity_matrix: optional ``(m, n)`` binary parity-check matrix. Defaults
            to the notebook-level ``syndrome_matrix``.

    Returns:
        float tensor of shape ``(batch, m)`` with 0/1 syndrome bits.
    """
    if parity_matrix is None:
        parity_matrix = syndrome_matrix
    # BPSK hard decision: positive symbol -> bit 1 (inverse of the 2*bit - 1 mapping).
    hard_bits = (messages > 0).float()
    # Align the matrix with the batch so CPU batches work even if H lives on the GPU.
    parity_matrix = parity_matrix.to(device=hard_bits.device, dtype=hard_bits.dtype)
    return torch.matmul(hard_bits, parity_matrix.T) % 2


def get_abses(messages):
    """Return the element-wise magnitude (reliability) of the received symbols."""
    return messages.abs()


def get_pure_error(enc_mod_messes, noisy_meses):
    """Mark positions where noise flipped the sign of a BPSK symbol.

    Returns a float32 tensor that is 1.0 where the transmitted and received
    symbols have opposite signs (product < 0) and 0.0 elsewhere, including
    where the received symbol is exactly zero.
    """
    return torch.lt(noisy_meses * enc_mod_messes, 0).to(torch.float32)


def gen_data(data_len, device=torch.device("cpu"), snr=4):
    """Generate one batch of (network input, target) pairs.

    Args:
        data_len: number of samples.
        device: device for all generated tensors.
        snr: channel SNR in dB passed to ``noisify``.

    Returns:
        ``(synabses, perrors)``:
        ``synabses`` is ``cat(syndrome, |received|)`` along dim 1, and
        ``perrors`` holds the 0/1 positions whose sign was flipped by the channel.
    """
    clean = gen_encoded_messages(data_len, device)
    received = noisify(clean, device, snr)

    magnitudes = get_abses(received)
    syndromes = get_syndromes(received)
    labels = get_pure_error(clean, received)

    return torch.cat((syndromes, magnitudes), dim=1), labels



# Infinite dataset class

In [None]:
class InfiniteDataloader(Dataset):
    """Iterable that produces ``limit`` freshly generated samples in batches.

    Each ``next()`` calls ``gen_function(batch_size, device, snr=snr)``. The
    last batch is shortened so exactly ``limit`` samples are produced.
    ``iter()`` rewinds the same object, so it can be looped over once per
    epoch. ``len()`` is the number of batches per pass.

    Note: ``shuffle`` only regenerates ``self.indices`` (a permutation of
    ``range(limit)``). The indices are not used when generating batches.
    """

    def __init__(self, limit, gen_function, device, batch_size=512, shuffle=False, snr=4):
        # What to generate, where, and at which SNR.
        self.gen_function, self.snr = gen_function, snr

        # Sizes are coerced to int so float configs such as 1e7 work.
        self.batch_size = int(batch_size)
        self.limit = int(limit)

        # Iteration state.
        self.current = 0
        self.need_shuffle = shuffle
        self.indices = torch.arange(self.limit)

        self.device = device

    def __len__(self):
        # Ceiling division: a partial final batch still counts as a batch.
        return (self.limit + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        self.current = 0
        self.indices = torch.randperm(self.limit) if self.need_shuffle else self.indices
        return self

    def __next__(self):
        remaining = self.limit - self.current
        if remaining <= 0:
            raise StopIteration

        step = min(self.batch_size, remaining)
        synabses, pure_errors = self.gen_function(step, self.device, snr=self.snr)
        self.current += step
        return synabses, pure_errors



# Custom functions

In [None]:

class SquaredReLU(nn.Module):
    """Clipped squared ReLU: ``min(relu(x), max_value) ** 2``.

    Args:
        max_value: upper clip applied before squaring, so the output never
            exceeds ``max_value ** 2``.
    """

    def __init__(self, max_value=10.0):
        super().__init__()
        self.max_value = max_value

    def forward(self, x):
        return torch.relu(x).clamp(max=self.max_value) ** 2


class ElementaryCrossEntropyLoss(nn.Module):
    """Per-position binary cross-entropy on logits, summed over positions.

    ``output`` holds raw logits of shape ``(batch, n)`` (evaluation applies
    ``torch.sigmoid`` to them). ``target`` holds 0/1 error indicators of the
    same shape. For every position (column) the binary cross-entropy is averaged
    over the batch, and the per-position losses are summed.

    Why not ``F.cross_entropy`` per column: given a 1-D column, PyTorch treats
    it as a single unbatched sample whose classes are the batch rows. That
    computes a softmax *across the batch* instead of an independent 0/1
    decision per bit, and is zero for columns without errors.

    Raises:
        ValueError: if either tensor is not 2-D or their shapes differ.
    """

    def __init__(self):
        super(ElementaryCrossEntropyLoss, self).__init__()

    def forward(self, output, target):
        if output.dim() != 2 or target.dim() != 2:
            raise ValueError(f'Both_should_be_2D_tensors_error. Output: {output.shape}, target: {target.shape}')

        if output.shape != target.shape:
            raise ValueError(f'Shape mismatch. Output: {output.shape}, target: {target.shape}')

        per_element = F.binary_cross_entropy_with_logits(
            output, target.to(output.dtype), reduction='none'
        )
        # mean over the batch per position, then sum over positions
        return per_element.mean(dim=0).sum()

# Model

In [None]:

class NoiseEstimator(nn.Module):
    """Fully connected error-position estimator with input skip connections.

    Every layer after the first receives ``cat((x, h), dim=1)``: the original
    input ``x`` concatenated with the previous layer's activation ``h``. There
    are ``num_layers`` linear layers in total: one input layer,
    ``num_layers - 2`` hidden layers and one output layer. The output layer
    returns raw logits with no final activation.

    Args:
        input_size: width of ``x`` (syndrome length + absolute values).
        hidden_size: width of every hidden activation.
        output_size: number of logits (one per code position).
        num_layers: total number of linear layers.
        activation: activation class; instantiated once and shared by all layers.
        device: device the module is moved to after construction.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=11, activation=nn.ReLU, device=torch.device("cpu")):
        super().__init__()
        skip_width = hidden_size + input_size  # width of cat((x, h)) seen by later layers

        self.input_layer = nn.Linear(input_size, hidden_size)
        hidden = []
        for _ in range(num_layers - 2):
            hidden.append(nn.Linear(skip_width, hidden_size))
        self.hidden_layers = nn.ModuleList(hidden)
        self.output_layer = nn.Linear(skip_width, output_size)
        self.activation = activation()

        # Constructor arguments, kept so checkpoints can rebuild the model.
        self.defaults = dict(
            input_size=input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            activation=activation,
            num_layers=num_layers,
        )

        self.to(device)

    def forward(self, x):
        h = self.activation(self.input_layer(x))
        for layer in self.hidden_layers:
            h = self.activation(layer(torch.cat((x, h), dim=1)))
        return self.output_layer(torch.cat((x, h), dim=1))



# Utilities

In [None]:
import os

def save_checkpoint(
        model, opt, 
        metric_values=None, model_info=None,
        checkpoint_dir='checkpoints', additional_mark=None, add_time=False
    ):
    """Pickle model/optimizer classes, configs and states into a ``.pth`` file.

    The file name encodes the model class, ``num_layers``, activation name,
    code length (``input_size - output_size``), scheduler class, an optional
    mark and an optional timestamp.

    Args:
        model: module exposing a ``defaults`` dict with at least ``input_size``,
            ``output_size``, ``num_layers`` and ``activation``.
        opt: optimizer; its class, ``defaults`` and state are stored.
        metric_values: arbitrary metrics stored verbatim.
        model_info: arbitrary dict stored verbatim. Its optional ``'scheduler'``
            entry is used for the file name. ``None`` (the default) means "no
            scheduler".
        checkpoint_dir: output directory (created if missing).
        additional_mark: optional tag inserted into the file name.
        add_time: append ``YYYYmmdd_HHMMSS`` to the file name.

    Returns:
        Path of the written checkpoint.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint = {
        'model_class' : model.__class__,
        'model_config' : model.defaults,
        'model_state_dict': model.state_dict(),

        'optimizer_class' : opt.__class__,
        'optimizer_config' : opt.defaults,
        'optimizer_state_dict': opt.state_dict(),

        'metric_values' : metric_values,
        'model_info' : model_info,
    }

    config = checkpoint['model_config']
    checkpoint_model_name = type(model).__name__
    model_code_len = config['input_size'] - config['output_size']

    # model_info defaults to None: treat a missing dict or key as "no scheduler"
    # instead of crashing on None['scheduler'].
    scheduler = (model_info or {}).get('scheduler')
    scheduler_name = 'None'
    if scheduler is not None:
        scheduler_name = type(scheduler).__name__ + '_'

    marker = ''
    if additional_mark is not None:
        marker = '_' + additional_mark + '_'

    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') if add_time else ''
    checkpoint_path = os.path.join(
        checkpoint_dir,
        f'{checkpoint_model_name}_'
        f"{str(config['num_layers'])}_"
        f"{str(config['activation']())[:-2]}_"  # "ReLU()" -> "ReLU"
        f"{str(model_code_len)}_"
        f"{scheduler_name}"
        f"{marker}"
        f"{timestamp}.pth"
    )
    torch.save(checkpoint, checkpoint_path)
    print(f"Saved checkpoint at {checkpoint_path}")

    return checkpoint_path


def load_checkpoint(path, device=torch.device("cpu")):
    """Rebuild model and optimizer from a checkpoint written by ``save_checkpoint``.

    Args:
        path: checkpoint file.
        device: device the model is built on; stored tensors are remapped here.

    Returns:
        ``(model, optimizer, model_info, metric_values)``.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the model or optimizer cannot be reconstructed (the
            underlying error is chained as ``__cause__``).
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"No such checkpoint file {path}")

    # SECURITY: weights_only=False unpickles arbitrary objects (classes, the
    # scheduler) and can execute code. Only load checkpoints you created.
    # map_location remaps tensors saved on a GPU onto `device`, so CUDA
    # checkpoints also load on CPU-only machines.
    checkpoint = torch.load(path, map_location=device, weights_only=False)

    try:
        model  = checkpoint['model_class'](**checkpoint['model_config'], device = device)
        model.load_state_dict(checkpoint['model_state_dict'])
    except Exception as e:
        raise ValueError(f'Can\'t load model due to incorrect checkpoint info, bastard. {str(e)}') from e

    try:
        optimizer = checkpoint['optimizer_class'](model.parameters(), **checkpoint['optimizer_config'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except Exception as e:
        raise ValueError('Can\'t load optimizer for you due to incorrect checkpoint optimizer info, bastard') from e

    model_info = checkpoint['model_info']
    metric_values = checkpoint['metric_values']

    print(f"Loaded checkpoint at {path}")
    return model, optimizer, model_info, metric_values


def set_seed(seed=42):
    """Seed every random number generator used in this notebook, for reproducibility.

    Seeds Python's ``random`` (used for plot marker styles), NumPy, and torch
    (CPU and all CUDA devices), and forces deterministic cuDNN kernels.
    """
    import random
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op without CUDA
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

## Custom LR scheduler (linear warmup + cosine annealing)

In [None]:

class warmup_cosine_scheduler:
    """Linear learning-rate warmup followed by cosine annealing, stepped once per epoch.

    For epochs ``< warmup_epochs`` a ``LambdaLR`` scales the base LR by
    ``(epoch + 1) / warmup_epochs``. Afterwards ``CosineAnnealingLR`` decays it
    to ``min_lr`` over the remaining ``num_epochs - warmup_epochs`` epochs.
    Both schedulers wrap the same optimizer.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        num_epochs: total number of training epochs.
        warmup_epochs: number of warmup epochs.
        min_lr: ``eta_min`` of the cosine phase.

    Raises:
        ValueError: if ``num_epochs - warmup_epochs <= 0``.
    """
    def __init__(self, optimizer, num_epochs, warmup_epochs, min_lr):
        import functools

        cosine_t_max = num_epochs - warmup_epochs
        if cosine_t_max <= 0:
            raise ValueError("T_max для CosineAnnealingLR <= 0. Оптимизатор не может работать корректно")

        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.num_epochs = num_epochs
        # Bind this instance's warmup length into the lambda. A partial over the
        # static method (not a lambda) keeps the scheduler picklable; this
        # matters because the scheduler object is stored in checkpoints.
        warmup_lambda = functools.partial(warmup_cosine_scheduler.warmup_lr_scheduler, warmup_epochs=warmup_epochs)
        self.warmup_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_lambda)
        self.cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=cosine_t_max, eta_min=min_lr)

    def step(self, epoch):
        """Advance whichever phase ``epoch`` (0-based, just finished) belongs to."""
        if epoch < self.warmup_epochs:
            self.warmup_scheduler.step()
        else:
            self.cosine_scheduler.step()

    @staticmethod
    def warmup_lr_scheduler(epoch, warmup_epochs=None):
        """LambdaLR multiplier: ``(epoch + 1) / warmup_epochs`` during warmup, then 1.0.

        ``warmup_epochs`` is now passed explicitly. Previously it was read from
        a notebook global, which raised NameError when no such global existed.
        The global is consulted only when the argument is omitted, for old
        call sites.
        """
        if warmup_epochs is None:
            warmup_epochs = globals()['warmup_epochs']
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        else:
            return 1.0

# Train and Eval cycles

In [None]:
from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np
from torchmetrics import Accuracy, Precision, Recall, F1Score, AUROC
# from tqdm.auto import tqdm
if device == "cuda":
    import gc


def eval_model(
        model, 
        eval_config={
            'val_data_size' : 65536,
            'batch_size' : 1024,
            'classification_threshold' : 0.5, 
            'val_snr' : 4, 
            'early_stop_max_errors' : None # 100 errors if so
        }, 
        device=torch.device("cpu"), verbose=True
    ):
    """Evaluate ``model`` on freshly generated noisy codewords.

    Args:
        model: network mapping ``cat(syndrome, |y|)`` batches to per-position logits.
        eval_config: dict with ``val_data_size``, ``batch_size``,
            ``classification_threshold`` (applied to ``sigmoid(logits)``),
            ``val_snr`` and ``early_stop_max_errors``. Evaluation stops once
            more than that many bit decisions are wrong; ``None`` disables the
            early stop. The default dict is only read, never mutated.
        device: device on which data is generated and metrics are accumulated.
        verbose: show a tqdm progress bar.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall``, ``f1score``
        (torchmetrics binary results), ``loss`` (mean criterion value over the
        batches actually evaluated), ``correct_predictions`` and
        ``total_predictions`` (bit-level counts).
    """
    val_loader = InfiniteDataloader(
            limit=eval_config['val_data_size'],
            gen_function=gen_data,
            device=device,
            batch_size=eval_config['batch_size'],
            shuffle=False,
            snr=eval_config['val_snr']
        )

    criterion = ElementaryCrossEntropyLoss()

    metric_evaluators = {
        'accuracy' : Accuracy(task="binary").to(device),
        'precision' : Precision(task="binary").to(device),
        'recall' : Recall(task="binary").to(device),
        'f1score' : F1Score(task="binary").to(device)
    }

    metric_values = { 'accuracy' : 0, 'precision' : 0, 'recall' : 0, 'f1score': 0}

    model.eval()
    with torch.no_grad():
        total_val_loss = 0
        # Batches actually evaluated. This can be fewer than len(val_loader)
        # when the error-count early stop fires.
        num_batches = 0

        for metric in metric_evaluators.values():
            metric.reset()

        val_loader_iterable = val_loader
        if verbose:
            val_loader_iterable = tqdm(val_loader, desc=f"Model [Validation]")

        for syn_abses_batch, perrors_batch in val_loader_iterable:
            outputs = model(syn_abses_batch)
            total_val_loss += criterion(outputs, perrors_batch)
            num_batches += 1

            # Classification threshold on the predicted error probability.
            outputs_probabilities = torch.sigmoid(outputs)
            output_classes = (outputs_probabilities >= eval_config['classification_threshold']).float()

            for metric in metric_evaluators.values():
                metric.update(output_classes, perrors_batch)

            if eval_config['early_stop_max_errors'] is not None:
                # All metrics see the same data, so FP + FN is the total number of wrong bit decisions.
                current_errors = metric_evaluators['precision'].fp + metric_evaluators['recall'].fn
                if current_errors > eval_config['early_stop_max_errors']:
                    break

        # Average over evaluated batches only; dividing by len(val_loader)
        # under-reported the loss after an early stop.
        avg_val_loss = total_val_loss / max(num_batches, 1)

        for metric, evaluator in metric_evaluators.items():
            metric_values[metric] = evaluator.compute()

        accuracy_state = metric_evaluators['accuracy']
        metric_values['loss'] = avg_val_loss
        metric_values['correct_predictions'] = accuracy_state.tp + accuracy_state.tn
        metric_values['total_predictions'] = metric_values['correct_predictions'] + accuracy_state.fp + accuracy_state.fn

        return metric_values


def train_loop(
        model, optimizer, scheduler = None,
        train_config = {
            'num_epochs'               : 64,
            'train_data_size'          : 65536,
            'val_data_size'            : 32768,
            'batch_size'               : 1024,
            'train_snr'                : 4,
            'val_snr'                  : 4,
            'classification_threshold' : 0.5,
            'start_epoch'              : 0,
            'shuffle_data'             : False,
            'early_stop_bad_epochs'    : None,
            'early_stop_max_errors'    : None,
        },
        checkpoint_dir = 'checkpoints',
        checkpoint_mark = None,
        checkpoint_timestamps = True,
        verbose = True,
        device = torch.device("cpu"),
    ):
    """Train ``model`` on generated data and checkpoint whenever validation accuracy improves.

    Each epoch draws ``train_data_size`` fresh samples at ``train_snr``,
    optimizes ``ElementaryCrossEntropyLoss``, steps ``scheduler`` (if any) with
    the epoch index, then validates with ``eval_model``. When validation
    accuracy beats the best so far, a checkpoint is written with
    ``save_checkpoint``. Training stops early after ``early_stop_bad_epochs``
    consecutive epochs without improvement; ``None`` disables the early stop.

    Per-epoch validation losses go to
    ``<checkpoint_dir>/loss_stats/<best checkpoint file stem>_losses.txt``.
    The default ``train_config`` is only read, never mutated.

    Returns:
        dict with ``path`` (best checkpoint path, ``None`` if no epoch ran),
        ``best_val_metrics`` and ``total_epochs``.
    """
    train_loader     = InfiniteDataloader(
        limit        = train_config['train_data_size'],
        gen_function = gen_data,
        device       = device,
        batch_size   = train_config['batch_size'],
        shuffle      = train_config['shuffle_data'],
        snr          = train_config['train_snr']
    )

    if verbose:
        print(f"Using device: {device}")

    criterion = ElementaryCrossEntropyLoss()

    last_val_loss = float('inf')
    # -1 so that the first validated epoch always counts as an improvement.
    best_metric_values = {'accuracy' : -1, 'precision' : -1, 'recall' : -1, 'f1score': -1}

    num_epochs = train_config['num_epochs']
    epoch_no_progress = 0

    best_checkpoint_path = None
    losses = dict()
    last_epoch = None  # index of the last epoch actually run

    for epoch in range(train_config['start_epoch'], num_epochs):
        last_epoch = epoch
        if verbose:
            print(f"Epoch {epoch+1}/{num_epochs}")

        # training
        model.train()
        total_train_loss = 0

        train_loader_iterable = train_loader
        if verbose:
            train_loader_iterable = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]")

        for syn_abses_batch, perrors_batch in train_loader_iterable:
            optimizer.zero_grad()
            outputs = model(syn_abses_batch)
            loss = criterion(outputs, perrors_batch)
            total_train_loss += loss.item()

            loss.backward()
            # Optional: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        # LR scheduler step (once per epoch)
        if scheduler is not None:
            scheduler.step(epoch)

        avg_train_loss = total_train_loss / len(train_loader)
        if verbose:
            print(f"Epoch {epoch+1} Training Loss: {avg_train_loss:.4f}")

        # validation (honours this function's `verbose` flag)
        metric_values = eval_model(
            model, 
            eval_config={
                'val_data_size' : train_config['val_data_size'],
                'batch_size' : train_config['batch_size'],
                'classification_threshold' : train_config['classification_threshold'], 
                'val_snr' : train_config['val_snr'], 
                'early_stop_max_errors' : train_config['early_stop_max_errors'],
            }, 
            device=device, verbose=verbose
        )

        if verbose:
            print(f"Epoch {epoch+1} Results:")
            print((
                f"Loss: {metric_values['loss']:.4f} "
                f"({(metric_values['loss'] / last_val_loss -1)*100:+.4f}% "
                f"{'😊' if metric_values['loss'] - last_val_loss < 0 else '😒'})"
            ))
            print((
                f"Accuracy: {metric_values['accuracy']:.4f} "
                f"({(metric_values['accuracy'] / best_metric_values['accuracy'] -1)*100:+.4f}% " 
                f"{'👍' if metric_values['accuracy'] - best_metric_values['accuracy'] > 0 else '👎'})"
            ))
            print(f"Precision: {metric_values['precision']:.4f}")
            print(f"Recall: {metric_values['recall']:.4f}")
            print(f"F1 Score: {metric_values['f1score']:.4f}")

        losses[epoch] = metric_values['loss']
        last_val_loss = metric_values['loss']

        if metric_values['accuracy'] <= best_metric_values['accuracy']:
            epoch_no_progress += 1
        else:
            for name, value in metric_values.items():
                best_metric_values[name] = value

            best_checkpoint_path = save_checkpoint(
                model, 
                optimizer,
                metric_values = metric_values,
                model_info = {
                    'epoch' : epoch + 1,
                    'train_snr' : train_config['train_snr'],
                    'val_snr' : train_config['val_snr'],
                    'optimal_classification_threshold' : None,
                    'scheduler' : scheduler,
                    'batch_size' : train_config['batch_size']
                },
                checkpoint_dir=checkpoint_dir,
                additional_mark = checkpoint_mark,
                add_time=checkpoint_timestamps, 
            )

            epoch_no_progress = 0

        # Early stopping
        if (train_config['early_stop_bad_epochs'] is not None) and (epoch_no_progress >= train_config['early_stop_bad_epochs']):
            if verbose:
                print(f"Early stopping triggered. {train_config['early_stop_bad_epochs']} epochs of NO progress")
            break

    total_epochs = train_config['start_epoch'] if last_epoch is None else last_epoch + 1

    # Report the best metrics of this run ('loss' is missing if no epoch ran).
    if verbose:
        print(f"\n~~~ Loop Ended ~~~")
        print(f"Best Validation Metrics Achieved:")
        for name in ['accuracy','loss','precision','recall','f1score']:
            print(f"{name}: {best_metric_values.get(name, float('nan')):.4f}")
        print(f"Total epochs: {total_epochs}")
        print("\n\n")

    # `device` may be a torch.device or a string; comparing a torch.device with
    # "cuda" is never true, so normalise it first.
    if torch.device(device).type == "cuda":
        import gc
        gc.collect()              # drop unreachable Python references first...
        torch.cuda.empty_cache()  # ...so their cached blocks can be released
        torch.cuda.synchronize()

    # Save per-epoch validation losses next to the best checkpoint's name.
    if best_checkpoint_path is not None:
        loss_dir = os.path.join(checkpoint_dir, 'loss_stats')
        os.makedirs(loss_dir, exist_ok=True)
        # basename/splitext work on every OS (splitting on '\\' only worked on Windows).
        checkpoint_stem = os.path.splitext(os.path.basename(best_checkpoint_path))[0]
        losses_path = os.path.join(loss_dir, checkpoint_stem + '_losses.txt')
        with open(losses_path, 'w') as loss_file:
            for loss_epoch, epoch_loss in losses.items():
                loss_file.write(f'{loss_epoch};{epoch_loss:.4f}' + '\n')

    train_report = {
        'path'             : best_checkpoint_path,
        'best_val_metrics' : best_metric_values,
        'total_epochs'     : total_epochs
    }

    return train_report
    

# Threshold optimization

In [None]:
from sklearn.metrics import precision_recall_curve

def optimize_threshold(
        model, 
        eval_config={
            'val_data_size' : 32768,
            'batch_size' : 1024,
            'val_snr' : 4, 
        }, 
        device=torch.device("cpu"), verbose=True
    ):
    """Find the sigmoid threshold that maximizes bit-level F1 on generated data.

    Args:
        model: network producing per-position logits.
        eval_config: dict with ``val_data_size``, ``batch_size`` and ``val_snr``
            (read only).
        device: device for data generation and inference.
        verbose: show a progress bar and print the chosen threshold.

    Returns:
        The threshold (on ``sigmoid(logits)``) with the highest F1 score.
    """
    val_loader = InfiniteDataloader(
        limit=eval_config['val_data_size'],
        gen_function=gen_data,
        device=device,
        batch_size=eval_config['batch_size'],
        shuffle=False,
        snr=eval_config['val_snr']
    )

    all_probs = []
    all_trues = []

    model.eval()
    with torch.no_grad():
        val_loader_iterable = val_loader
        if verbose:
            val_loader_iterable = tqdm(val_loader, desc=f"Model [Validation]")

        for syn_abses_batch, perrors_batch in val_loader_iterable:
            outputs = model(syn_abses_batch)

            all_probs.append(torch.sigmoid(outputs).cpu().numpy().ravel())
            all_trues.append(perrors_batch.cpu().numpy().ravel())

    # Batches can have different lengths (a short final batch), so concatenate
    # instead of building a (ragged) 2-D array.
    all_probs = np.concatenate(all_probs)
    all_trues = np.concatenate(all_trues)

    precs, recs, thrs = precision_recall_curve(all_trues, all_probs)
    f1s = 2 * (precs * recs) / (precs + recs + 1e-8)
    # precision/recall have one more point than thresholds (the final
    # recall=0 point has no threshold); only consider points that have one.
    f1s = f1s[:len(thrs)]
    optimal_idx = int(np.argmax(f1s))
    optimal_threshold = thrs[optimal_idx]

    if verbose:
        print(f"Optimal threshold found: {optimal_threshold:.4f}")
        print(f"Max F1-score: {f1s[optimal_idx]:.4f}")

    return optimal_threshold

# Research analysis

In [None]:
import os
from collections import defaultdict

def checkpoints_analyze(checkpoints_dir, criterion, research_dir='metrics', additional_mark='', eval_config=None, metrics_open='w'): # or 'a'
    """Re-evaluate every ``.pth`` checkpoint in a directory and keep the best per ``criterion`` value.

    Each checkpoint is re-evaluated with ``eval_model``. If the fresh accuracy
    beats the stored one, the fresh metrics replace the stored ones. Checkpoints
    are then grouped by ``model_data[criterion]`` (e.g. ``'num_layers'``), and
    the most accurate checkpoint of each group is kept.

    Args:
        checkpoints_dir: directory with checkpoints; results are written to
            ``<checkpoints_dir>/<research_dir>/metrics.txt``.
        criterion: key of model_info / model config / metrics to group by.
        research_dir: sub-directory name for the metrics file.
        additional_mark: string appended to every line of the metrics file.
        eval_config: ``eval_model`` config used for every checkpoint. When
            ``None``, a config is built per checkpoint from its own
            ``model_info`` (threshold and ``val_snr``).
        metrics_open: file mode for the metrics file (``'w'`` or ``'a'``).

    Returns:
        ``(models_sorted_by_accuracy, best_checkpoints)``.
    """
    best_checkpoints = defaultdict(lambda: {'accuracy' : 0 })
    empty_best = {'accuracy' : 0}

    for checkpoint_name in sorted(f for f in os.listdir(checkpoints_dir) if f.endswith('pth')):
        checkpoint_path = os.path.join(checkpoints_dir, checkpoint_name)

        if not os.path.isfile(checkpoint_path):
            continue

        model, _optimizer, model_info, metric_values = load_checkpoint(checkpoint_path, device)

        model_config = model.defaults

        # Build the config per checkpoint so each one uses its own threshold/SNR
        # (previously the first checkpoint's values were reused for all).
        checkpoint_eval_config = eval_config
        if checkpoint_eval_config is None:
            threshold = model_info['optimal_classification_threshold']
            checkpoint_eval_config = {
                'val_data_size' : 65536,
                'batch_size' : 1024,
                # train_loop stores None until a threshold is optimized; fall back to 0.5.
                'classification_threshold' : 0.5 if threshold is None else threshold,
                'val_snr' : model_info['val_snr'],
                'early_stop_max_errors' : None
            }

        optimal_metrics = eval_model(model, checkpoint_eval_config, device)
        print(metric_values['accuracy'], optimal_metrics['accuracy'], metric_values['accuracy'] < optimal_metrics['accuracy'])
        if metric_values['accuracy'] < optimal_metrics['accuracy']:
            metric_values = optimal_metrics
            print('optimal threshold set')

        model_data = {**model_info, ** model_config, **metric_values}
        group_key = model_data[criterion]

        # .get() avoids creating placeholder entries (without 'valuable_data')
        # in the defaultdict for groups that never get a real checkpoint.
        if metric_values['accuracy'] > best_checkpoints.get(group_key, empty_best)['accuracy']:
            best_checkpoints[group_key] = {
                'accuracy' : metric_values['accuracy'],
                'path' : checkpoint_path,
                'model_config' : model_config,
                'valuable_data' : model_data,
            }

    models_sorted_by_accuracy = sorted(
        best_checkpoints.items(),
        key=lambda x: x[1]['accuracy'],
        reverse=True
    )

    print('\n\noverall results', end='\n\n\n')
    for crit, stats in best_checkpoints.items():
        print(f'{criterion} = {crit}:', end=' ')
        for name in ['accuracy', 'precision', 'recall', 'f1score', 'epoch', 'train_snr']:
            print(f"{stats['valuable_data'][name]:.4f}", end=', ')
        print()

    # Write under the analyzed directory (the parameter), not a notebook global.
    research_dir = os.path.join(checkpoints_dir, research_dir)
    os.makedirs(research_dir, exist_ok=True)
    res_path = os.path.join(research_dir, 'metrics.txt')

    print('\n\n\nsorted results', end='\n\n\n')
    with open(res_path, metrics_open) as metr_file:
        for crit, stats in models_sorted_by_accuracy:
            print(f'{criterion} = {crit}:', end=' ')

            metr_file.write(f'{crit};')
            for name in ['accuracy', 'precision', 'recall', 'f1score', 'epoch', 'train_snr']:
                metr_file.write(f"{stats['valuable_data'][name]:.4f};")

                print(f"{stats['valuable_data'][name]:.4f}", end=', ')
            print()
            for name in ['activation', 'scheduler']:
                metr_file.write(f"{str(stats['valuable_data'][name])};")
            metr_file.write(f'{str(additional_mark)}' + '\n')

    return models_sorted_by_accuracy, best_checkpoints

# Research validation

In [None]:
from itertools import product 
import random
import os
import numpy as np

def draw_BERSNR_with_accuracy(
        data, 
        dir='graphs', 
        label='random', 
        filename='random',
        save_format='png', 
        dpi=300,
        fig_size = (12, 9),
        mark_every = 5,
        plot_title = ''
    ):
    """Plot BER (= 1 - accuracy) against SNR for one model on a log scale and save it.

    Args:
        data: ``{snr: metrics}``, where ``metrics['accuracy']`` is a tensor.
        dir: output directory (created if missing).
        label: legend label.
        filename: file name stem; the file is ``<filename>_BERSNR.<save_format>``.
        save_format: any matplotlib format; ``dpi`` is applied for PNG only.
        dpi, fig_size, mark_every, plot_title: plot styling.
    """
    os.makedirs(dir, exist_ok=True)

    graph_path = os.path.join(dir, filename + f'_BERSNR.{save_format}')

    bers = [1 - stats['accuracy'].cpu() for stats in data.values()]
    snrs = list(data.keys())
    plt.figure(figsize=fig_size)

    plt.plot(snrs, bers, marker='o', markevery=mark_every, linestyle='-', color='b', label=label)

    plt.title(plot_title, fontsize=14)
    plt.xlabel("SNR (dB)", fontsize=12)
    plt.ylabel("BER (log scale)", fontsize=12)
    plt.yscale('log', base=10)
    plt.grid(True, which="both", linestyle='--', alpha=0.6)
    plt.legend()

    if save_format.lower() == 'png':
        plt.savefig(graph_path, dpi=dpi, bbox_inches='tight')
    else:
        # Use the requested format; previously every non-PNG file was written as PDF.
        plt.savefig(graph_path, bbox_inches='tight', format=save_format)
    plt.show()


def draw_multiple_BERSNR(
        multimodel_data, 
        dir='graphs',
        filename='random_multimodel', 
        save_format='png', 
        dpi=300,
        fig_size = (12, 9),
        mark_every = 5,
        plot_title = '' # "BER to SNR dependency (multiple models)"
    ):
    """Plot BER (= 1 - accuracy) against SNR for several models on one log-scale figure.

    Args:
        multimodel_data: ``{legend label: {snr: metrics}}`` with a tensor ``metrics['accuracy']``.
        dir: output directory (created if missing).
        filename: file name stem; the file is ``<filename>_BERSNR.<save_format>``.
        save_format: any matplotlib format; ``dpi`` is applied for PNG only.
        dpi, fig_size, mark_every, plot_title: plot styling.
    """
    os.makedirs(dir, exist_ok=True)
    graph_path = os.path.join(dir, f'{filename}_BERSNR.{save_format}')

    plt.figure(figsize=fig_size)

    # Colours and marker/line styles to tell the models apart.
    colors = plt.cm.tab20.colors

    def marker_styles():
        """Yield (marker, linestyle) pairs in random order without repeats, reshuffling once all are used."""
        markers = ['o', 's', '^', 'D', 'v', 'p', '*', 'h', '8', 'X']
        styles = ['-', '--', '-.', ':']
        while True:
            combinations = list(product(markers, styles))
            random.shuffle(combinations)
            yield from combinations

    markerstyle = marker_styles()

    for idx, (model_name, data) in enumerate(multimodel_data.items()):
        bers = [1 - stats['accuracy'].cpu() for stats in data.values()]
        snrs = list(data.keys())

        color = colors[idx % len(colors)]
        marker, style = next(markerstyle)

        plt.plot(
            snrs, bers, 
            marker=marker, 
            linestyle=style, 
            color=color, 
            label=model_name,
            markevery=mark_every
        )

    plt.title(plot_title, fontsize=14)
    plt.xlabel("SNR (dB)", fontsize=12)
    plt.ylabel("BER (log scale)", fontsize=12)
    plt.yscale('log', base=10)
    plt.grid(True, which="both", linestyle='--', alpha=0.6)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    if save_format.lower() == 'png':
        plt.savefig(graph_path, dpi=dpi, bbox_inches='tight')
    else:
        # Use the requested format; previously every non-PNG file was written as PDF.
        plt.savefig(graph_path, bbox_inches='tight', format=save_format)
    plt.show()



def accuracy_to_snr_checkpoint_research(
        checkpoint_path,
        research_config = {
            'dir' : 'SNRs',
            'max_size' : 1e7,
            'batch_size' : 1024,
            'max_errors' : 100,

            'snr_left' : 1,
            'snr_right' : 5.5, # 5.5 for 73, 7.5 for 105
            'snr_step' : 0.1
        },
        graph_config    = {
            'draw?' : True,
            'dir' : 'graphs', 
            'label' : 'random',
            'filename' : 'random',
            'save_format' : 'pdf', 
            'fig_size' : (12, 9),
            'mark_every' : 5,
        },
        mark            = '',
        verbose         = True,
        device          = torch.device("cpu")
    ):
    """Evaluate one checkpoint over an SNR sweep, save the results and optionally plot BER vs SNR.

    For every SNR in ``[snr_left, snr_right]`` (step ``snr_step``), the model is
    evaluated on up to ``max_size`` samples, stopping after ``max_errors`` wrong
    bit decisions. One line per SNR (``snr;metric;metric;...``) is written to
    ``<research_config['dir']>/<checkpoint stem><mark>_valres.txt``.
    The default config dicts are only read, never mutated.

    Returns:
        ``{snr: eval_model metrics}``.
    """
    model, _optimizer, model_info, _saved_metrics = load_checkpoint(checkpoint_path, device=device)
    classification_threshold = model_info['optimal_classification_threshold']
    if classification_threshold is None:
        # train_loop stores None until a threshold is optimized; use its default of 0.5.
        classification_threshold = 0.5

    snr_left = research_config['snr_left']
    snr_right = research_config['snr_right']
    snr_step = research_config['snr_step']
    # np.arange with a fractional step can overshoot snr_right through rounding;
    # count the points explicitly (with a small tolerance) and keep arange's
    # values, left + i * step.
    num_points = int(np.floor((snr_right - snr_left) / snr_step + 1e-9)) + 1
    snr_grid = snr_left + snr_step * np.arange(num_points)

    val_res = dict()

    for cur_snr in snr_grid:
        val_res[cur_snr] = eval_model(
            model,
            eval_config={
                'val_data_size' : research_config['max_size'],
                'batch_size' : research_config['batch_size'],
                'classification_threshold' : classification_threshold,
                'val_snr' : cur_snr,
                'early_stop_max_errors' : research_config['max_errors'] # 100 errors if so
            },
            device = device,
            verbose = verbose,
        )

    # basename/splitext work on every OS (splitting on '\\' only worked on Windows).
    checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
    os.makedirs(research_config['dir'], exist_ok=True)
    res_path = os.path.join(research_config['dir'], checkpoint_name + str(mark) + '_valres.txt')

    with open(res_path, 'w') as res_file:
        for cur_snr, snr_metrics in val_res.items():
            res_file.write(
                str(cur_snr) + ';' + ';'.join([str(val) for val in snr_metrics.values()]) + '\n'
            )

    if graph_config['draw?']:
        draw_BERSNR_with_accuracy(
            val_res,
            dir = graph_config['dir'],
            label = graph_config['label'],
            filename = graph_config['filename'],
            save_format = graph_config['save_format'],
            fig_size = graph_config['fig_size'],
            mark_every = graph_config['mark_every']
        )

    return val_res


def validate_multimodal(
        multimodel_data,
        criterion_name,
        checkpoints_dir='random',
        research_config = None,
        graph_config = None,
        device=torch.device('cpu'),
        verbose=False
    ):
    """Run the SNR sweep for several checkpoints and draw their BER curves on one figure.

    Args:
        multimodel_data: ``{criterion value: {'path': checkpoint path, ...}}``,
            e.g. the ``best_checkpoints`` returned by ``checkpoints_analyze``.
        criterion_name: name of the grouping criterion, used in legend labels
            and as the figure file name.
        checkpoints_dir: base directory; the ``dir`` entries of both configs are
            resolved relative to it.
        research_config, graph_config: optional overrides. They are copied, so
            the caller's dicts are not modified (previously repeated calls kept
            prefixing ``checkpoints_dir`` onto the caller's dict).
        device, verbose: forwarded to the evaluation.
    """
    if graph_config is None:
        graph_config = {
            'draw?' : False,
            'dir' : 'graphs',
            'label' : 'random',
            'filename' : 'random',
            'save_format' : 'pdf',
            'fig_size' : (12, 9),
            'mark_every' : 5,
        }
    else:
        graph_config = dict(graph_config)

    if research_config is None:
        research_config = {
            'dir' : 'SNRs',
            'max_size' : 1e7,
            'batch_size' : 1024,
            'max_errors' : 100,

            'snr_left' : 1,
            'snr_right' : 7, # 5.5 for 73, 7.5 for 105
            'snr_step' : 0.1
        }
    else:
        research_config = dict(research_config)

    multisnr_multiaccuracy = dict()

    research_config['dir'] = os.path.join(checkpoints_dir, research_config['dir'])
    os.makedirs(research_config['dir'], exist_ok=True)
    graph_config['dir'] = os.path.join(checkpoints_dir, graph_config['dir'])
    os.makedirs(graph_config['dir'], exist_ok=True)

    for crit, data in multimodel_data.items():
        multisnr_multiaccuracy[f'{crit} {criterion_name}'] = accuracy_to_snr_checkpoint_research(
            data['path'],
            research_config=research_config,
            graph_config=graph_config,
            device=device,
            verbose=verbose
        )

    draw_multiple_BERSNR(
        multisnr_multiaccuracy,
        dir=graph_config['dir'],
        filename=criterion_name,
        save_format=graph_config['save_format'],
        fig_size=graph_config['fig_size'],
        mark_every=graph_config['mark_every']
    )

    return None


def draw_metric_table(
        init_data, 
        criterion, 
        numeric_column_names = (
            'Accuracy', 'Precision', 'Recall', 'F1 Score'
        ),
        numeric_column_formal_names = (
            'accuracy', 'precision', 'recall', 'f1score'
        ),
        categorial_column_names = (
            'Epochs',
        ),
        categorial_column_formal_names = (
            'epoch',
        ),
        dir = 'tables',
        table_name = 'random', 
        save_format='pdf'
    ):
    """Render a per-model metric table with matplotlib and save it to disk.

    Args:
        init_data: a dict ``name -> metrics`` or any iterable of
            ``(name, metrics)`` pairs. ``metrics`` must be indexable by every
            key in ``numeric_column_formal_names`` and
            ``categorial_column_formal_names``.
        criterion: header of the first column (the swept parameter's name).
        numeric_column_names: displayed headers of the numeric columns.
        numeric_column_formal_names: ``metrics`` keys for the numeric columns;
            values are formatted with 4 decimal places.
        categorial_column_names: displayed headers of the remaining columns.
        categorial_column_formal_names: ``metrics`` keys for those columns;
            values are formatted with ``str``.
        dir: base directory; the file is written into ``<dir>/tables/``.
        table_name: output file stem.
        save_format: matplotlib output format, also used as file extension.
    """
    import matplotlib.pyplot as plt

    header = [criterion, *numeric_column_names, *categorial_column_names]
    rows = [header]

    # Accept any mapping (including dict subclasses) or an iterable of pairs;
    # exact type() checks left the iterator unbound for tuples or generators.
    items = init_data.items() if isinstance(init_data, dict) else init_data

    for name, metrics in items:
        numeric_cells = [f"{metrics[key]:.4f}" for key in numeric_column_formal_names]
        categorial_cells = [f"{metrics[key]}" for key in categorial_column_formal_names]
        rows.append([name, *numeric_cells, *categorial_cells])

    fig, ax = plt.subplots(figsize=(8, 2))
    ax.axis('off')
    # colWidths needs one entry per *column*: matplotlib indexes it by column
    # and raises IndexError if it is shorter than the header (e.g. when the
    # table has fewer rows than columns).
    table = ax.table(cellText=rows, loc='center', cellLoc='center', colWidths=[0.2] * len(header))
    table.auto_set_font_size(False)
    table.set_fontsize(10)

    out_dir = os.path.join(dir, 'tables')
    os.makedirs(out_dir, exist_ok=True)
    save_path = os.path.join(out_dir, f"{table_name}.{save_format}")
    print(save_path)
    # The figure is left open on purpose so the notebook still shows it inline.
    fig.savefig(save_path, bbox_inches='tight', format=save_format)

# MDPC 73 research

### Number of layers

In [None]:
import torch.optim as optim

min_lr = 1e-6 # eta min
initial_lr = 1e-3 #
const_weight_decay = 1e-5

optimistic_num_epochs = 96

checkpoint_dir = 'optimal_layer_search_73'
criterion = 'num_layers'

# Values passed as `num_layers` to NoiseEstimator: 3 .. 16.
search_layers = range(3, 17, 1)

train_config = {
    'num_epochs'               : optimistic_num_epochs,
    'train_data_size'          : 64 * 1024,
    'val_data_size'            : 32 * 1024,
    'batch_size'               : 1024,
    'train_snr'                : 4,
    'val_snr'                  : 4,
    'classification_threshold' : 0.5,
    'start_epoch'              : 0,
    'shuffle_data'             : False,
    'early_stop_bad_epochs'    : 8,
    'early_stop_max_errors'    : None,
}


for num_lay in search_layers:
    # Seed every configuration, not just once before the loop. Otherwise each
    # model's initialization (and presumably any random data generation in
    # train_loop) depends on how much randomness earlier runs consumed, which
    # varies with early stopping, so results would depend on search order.
    set_seed(43)

    print(f"{num_lay}-layered model train started")
    model = NoiseEstimator(input_size, hidden_size, output_size, num_layers=num_lay, device=device)
    optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=const_weight_decay)
    
    report = train_loop(
        model, 
        optimizer, 
        scheduler = None, 
        train_config = train_config,
        checkpoint_dir = checkpoint_dir,
        checkpoint_timestamps = False,
        verbose = True,
        device = device
    )

    print('classification threshold optimization')

    # Reload the checkpoint train_loop reported, attach the tuned threshold
    # to its model_info and save it again into the same directory.
    model, optimizer, model_info, metric_values = load_checkpoint(report['path'], device)
    model_info['optimal_classification_threshold'] = optimize_threshold(model, device=device)
    save_checkpoint(
        model, 
        optimizer, 
        metric_values = metric_values, 
        model_info = model_info,
        checkpoint_dir = checkpoint_dir,
        add_time = False
    )
    
    print(f"{num_lay}-layered model train ended", end='\n\n\n')
    del model
    del optimizer
    del optimizer

In [None]:
checkpoint_dir = 'optimal_layer_search_73'
criterion = 'num_layers'
top_models, all_models = checkpoints_analyze(checkpoints_dir = checkpoint_dir, criterion = criterion)

# Among the first three ranked entries, take the one with the most layers.
# The key compares only layer counts: sorting the raw (layers, data) tuples
# would compare the data dicts on a tie, and indexing [2] failed when fewer
# than three checkpoints exist.
# NOTE(review): the MDPC-105 section picks the *fewest* layers among its top
# three — confirm which rule is intended.
universal = max(top_models[:3], key=lambda entry: entry[0])
best_universal_layer = universal[0]
# Half of the recorded best epoch; used later as the scheduler warm-up length.
best_warmup_epochs = universal[1]['valuable_data']['epoch'] // 2


print('\n\n' + f'best {criterion} is {best_universal_layer} with ~{best_warmup_epochs * 2} epochs!')

In [None]:
# SNR sweep for every layer-search checkpoint, combined into one BER/SNR plot.
validate_multimodal(all_models, 'layer', device=device, checkpoints_dir=checkpoint_dir)


In [None]:
draw_metric_table(
    {entry[0]: entry[1]['valuable_data'] for entry in top_models},
    'Layers',
    table_name='layers_table_73',
    dir=checkpoint_dir,
)

### Activations and scheduler

In [None]:
import torch.optim as optim

min_lr = 1e-6 # eta min
initial_lr = 1e-3 #
const_weight_decay = 1e-5
warmup_epochs = 32  # NOTE(review): unused below; schedulers use best_warmup_epochs from the layer search

checkpoint_dir = 'activation_lrsched_search_73'
criterion = 'activation'

activation_list = [nn.ReLU, nn.SiLU, nn.Tanh, SquaredReLU]
scheduler_list = [None, warmup_cosine_scheduler]

train_config = {
    'num_epochs'               : optimistic_num_epochs,
    'train_data_size'          : 64 * 1024,
    'val_data_size'            : 32 * 1024,
    'batch_size'               : 1024,
    'train_snr'                : 4,
    'val_snr'                  : 4,
    'classification_threshold' : 0.5,
    'start_epoch'              : 0,
    'shuffle_data'             : False,
    'early_stop_bad_epochs'    : 8,
    'early_stop_max_errors'    : None,
}


for activation in activation_list:
    # Display name, e.g. "ReLU" from repr "ReLU()".
    activation_name = str(activation())[:-2]
    for scheduler_object in scheduler_list:
        # Seed every (activation, scheduler) pair so all runs start from the same
        # RNG state instead of depending on what earlier runs consumed.
        set_seed(43)

        print(f"{activation_name}-active model train started")
        
        model = NoiseEstimator(input_size, hidden_size, output_size, 
                               num_layers=best_universal_layer, activation=activation, device=device
                )
        optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=const_weight_decay)

        scheduler = None
        if scheduler_object is not None:
            scheduler = scheduler_object(optimizer, train_config['num_epochs'], best_warmup_epochs, min_lr)
        
        report = train_loop(
            model, 
            optimizer, 
            scheduler = scheduler, 
            train_config = train_config,
            checkpoint_dir = checkpoint_dir,
            checkpoint_timestamps = False,
            verbose = True,
            device = device
        )
    
        print('classification threshold optimization')

        # Reload the reported checkpoint, store the tuned threshold, re-save.
        model, optimizer, model_info, metric_values = load_checkpoint(report['path'], device)
        model_info['optimal_classification_threshold'] = optimize_threshold(model, device=device)
        save_checkpoint(
            model, 
            optimizer, 
            metric_values = metric_values, 
            model_info = model_info,
            checkpoint_dir = checkpoint_dir,
            add_time = False
        )
        
        print(f"{activation_name}-active model train ended", end='\n\n\n')
        del model
        del optimizer

In [None]:
checkpoint_dir = 'activation_lrsched_search_73'
criterion = 'activation'


# checkpoints_analyze returns (top-ranked entries, all entries); unpack it
# like the other analysis cells instead of indexing into the tuple.
top_models, all_models = checkpoints_analyze(checkpoints_dir = checkpoint_dir, criterion = criterion)

print('\n\n\n')

best_model_data = top_models[0]
best_activation = best_model_data[0]
best_activation_name = str(best_activation())[:-2]
act_sched = dict()
act_sched_zip = dict()

# Second stage: per activation, rank its runs by the 'scheduler' criterion.
criterion = 'scheduler'

for activation in activation_list:
    act_name = str(activation())[:-2]
    # Copy this activation's checkpoints into their own sub-directory so the
    # analysis below only sees runs of this activation.
    cur_dir = os.path.join(checkpoint_dir, act_name)
    os.makedirs(cur_dir, exist_ok=True)
    # NOTE(review): re-running this cell copies the checkpoints again; whether
    # that overwrites or duplicates depends on save_checkpoint's naming.
    for checkpoint_name in [f for f in os.listdir(checkpoint_dir) if f.endswith('pth')]:
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)

        # The '_' prefix presumably keeps 'ReLU' from matching '..._SquaredReLU...' files.
        if '_' + act_name not in checkpoint_name:
            continue

        if not os.path.isfile(checkpoint_path):
            continue

        # Bind the loaded optimizer to its own name: unpacking it into `optim`
        # shadowed the torch.optim module for every cell run afterwards.
        model, loaded_optimizer, model_info, metric_values = load_checkpoint(checkpoint_path, device)

        save_checkpoint(model, loaded_optimizer, metric_values = metric_values, model_info=model_info,
                       checkpoint_dir = cur_dir)
    
    
    top_models, all_models = checkpoints_analyze(checkpoints_dir = cur_dir, criterion = criterion, metrics_open='a')

    # Label each run "<activation> +" (scheduler used) or "<activation> -" (none).
    for sch, data in top_models:
        act_sched_zip[f'{act_name} {"+" if sch is not None else "-"}'] = data
    
    best_model_act_scheduler = top_models[0][0]
    act_sched[act_name] = best_model_act_scheduler
    print(f"for activation {act_name} scheduler {best_model_act_scheduler}", end='\n\n\n')

best_scheduler = act_sched[best_activation_name]

print('\n\n' + f'best {criterion} for {best_activation_name} is {best_scheduler}!')



In [None]:
# SNR sweep for every (activation, scheduler) run; curve labels are the act_sched_zip keys.
validate_multimodal(act_sched_zip, '', device=device, checkpoints_dir=checkpoint_dir)

In [None]:

draw_metric_table(
    {label: info['valuable_data'] for label, info in act_sched_zip.items()},
    'Activation',
    table_name='activations_table_73',
    dir=checkpoint_dir,
)

### Train SNR

In [None]:
import torch.optim as optim

min_lr = 1e-6 # eta min
initial_lr = 1e-3
const_weight_decay = 1e-5


checkpoint_dir = 'train_snr_search_73'
criterion = 'train_snr'

train_snr_list = [1, 1.5, 2, 3, 3.4, 3.8, 4, 4.2, 4.6, 5] #

train_config = {
    'num_epochs'               : optimistic_num_epochs * 2,
    'train_data_size'          : 64 * 1024,
    'val_data_size'            : 32 * 1024,
    'batch_size'               : 1024,
    'train_snr'                : 4,  # overwritten per run in the loop below
    'val_snr'                  : 4,
    'classification_threshold' : 0.5,
    'start_epoch'              : 0,
    'shuffle_data'             : False,
    'early_stop_bad_epochs'    : 8,
    'early_stop_max_errors'    : None,
}


for train_snr in train_snr_list:
    # Seed every run so each SNR starts from the same RNG state.
    set_seed(43)

    print(f"{str(train_snr)}-snr driven model train started")
    
    model = NoiseEstimator(input_size, hidden_size, output_size, 
                           num_layers=best_universal_layer, activation=best_activation, device=device
            )
    optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=const_weight_decay)

    scheduler = None
    if best_scheduler is not None:
        # best_scheduler is the 'scheduler' criterion value read back from the
        # checkpoints; its concrete type is not visible here. best_scheduler.__class__(...)
        # only works if it is an instance of the factory. The only non-None
        # candidate in the searched scheduler_list is warmup_cosine_scheduler,
        # so rebuild it with the same arguments the search used.
        # NOTE(review): update this if scheduler_list gains more schedulers.
        scheduler = warmup_cosine_scheduler(optimizer, train_config['num_epochs'], best_warmup_epochs, min_lr)

    train_config['train_snr'] = train_snr
    report = train_loop(
        model, 
        optimizer, 
        scheduler = scheduler, 
        train_config = train_config,
        checkpoint_dir = checkpoint_dir,
        checkpoint_mark = f'{str(train_snr)}',
        checkpoint_timestamps = False,
        verbose = True,
        device = device
    )

    print('classification threshold optimization')

    # Reload the reported checkpoint, store the tuned threshold, re-save.
    model, optimizer, model_info, metric_values = load_checkpoint(report['path'], device)

    model_info['optimal_classification_threshold'] = optimize_threshold(model, device=device)
    save_checkpoint(
        model, 
        optimizer, 
        metric_values = metric_values, 
        model_info = model_info,
        checkpoint_dir = checkpoint_dir,
        additional_mark = f'{str(train_snr)}',
        add_time = False
    )
    
    print(f"{str(train_snr)}-snr driven model train ended", end='\n\n\n')
    del model
    del optimizer


In [None]:
checkpoint_dir, criterion = 'train_snr_search_73', 'train_snr'

# The first entry of the top-ranked list is treated as the winner; its first
# element is the criterion value (the training SNR).
top_models, all_models = checkpoints_analyze(criterion=criterion, checkpoints_dir=checkpoint_dir)

best_model_data = top_models[0]
best_train_snr = best_model_data[0]

print(f'\n\nbest {criterion} is {best_train_snr}!')

In [None]:
# SNR sweep for every train-SNR checkpoint, combined into one BER/SNR plot.
validate_multimodal(all_models, '', device=device, checkpoints_dir=checkpoint_dir)

In [None]:
draw_metric_table(
    {entry[0]: entry[1]['valuable_data'] for entry in top_models},
    'Train SNR',
    table_name='SNRs_table_73',
    dir=checkpoint_dir,
)

# MDPC 105

### Layers 105

In [None]:
import torch.optim as optim 

min_lr = 1e-6 # eta min
initial_lr = 1e-3
const_weight_decay = 1e-5

optimistic_num_epochs = 96

checkpoint_dir = 'optimal_layer_search_105'
criterion = 'num_layers'

# Values passed as `num_layers` to NoiseEstimator: 3 .. 16.
search_layers = range(3, 17, 1)

train_config = {
    'num_epochs'               : optimistic_num_epochs,
    'train_data_size'          : 64 * 1024,
    'val_data_size'            : 32 * 1024,
    'batch_size'               : 1024,
    'train_snr'                : 4,
    'val_snr'                  : 4,
    'classification_threshold' : 0.5,
    'start_epoch'              : 0,
    'shuffle_data'             : False,
    'early_stop_bad_epochs'    : 8,
    'early_stop_max_errors'    : None,
}


for num_lay in search_layers:
    # Seed every configuration, not just once before the loop. Otherwise each
    # model's initialization (and presumably any random data generation in
    # train_loop) depends on how much randomness earlier runs consumed, which
    # varies with early stopping, so results would depend on search order.
    set_seed(43)

    print(f"{num_lay}-layered model train started")
    model = NoiseEstimator(input_size, hidden_size, output_size, num_layers=num_lay, device=device)
    optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=const_weight_decay)
    
    report = train_loop(
        model, 
        optimizer, 
        scheduler = None, 
        train_config = train_config,
        checkpoint_dir = checkpoint_dir,
        checkpoint_timestamps = False,
        verbose = True,
        device = device
    )

    print('classification threshold optimization')

    # Reload the reported checkpoint, store the tuned threshold, re-save.
    model, optimizer, model_info, metric_values = load_checkpoint(report['path'], device)

    model_info['optimal_classification_threshold'] = optimize_threshold(model, device=device)
    save_checkpoint(
        model, 
        optimizer, 
        metric_values = metric_values, 
        model_info = model_info,
        checkpoint_dir = checkpoint_dir,
        add_time = False
    )
    
    print(f"{num_lay}-layered model train ended", end='\n\n\n')
    del model
    del optimizer

In [None]:
# Set explicitly so the cell does not silently reuse whatever a previously
# run cell left in these globals (e.g. when the training cell is skipped).
checkpoint_dir = 'optimal_layer_search_105'
criterion = 'num_layers'
top_models, all_models = checkpoints_analyze(checkpoints_dir = checkpoint_dir, criterion = criterion)

# Among the first three ranked entries, take the one with the fewest layers.
# The key compares only layer counts; sorting the raw (layers, data) tuples
# would compare the data dicts on a tie.
# NOTE(review): the MDPC-73 section picks the *most* layers among its top
# three — confirm which rule is intended.
universal = min(top_models[:3], key=lambda entry: entry[0])
best_universal_layer = universal[0]
# Half of the recorded best epoch; used later as the scheduler warm-up length.
best_warmup_epochs = universal[1]['valuable_data']['epoch'] // 2


print('\n\n' + f'best {criterion} is {best_universal_layer} with ~{best_warmup_epochs * 2} epochs!')

In [None]:
# SNR sweep for every layer-search checkpoint, combined into one BER/SNR plot.
validate_multimodal(all_models, 'layer', device=device, checkpoints_dir=checkpoint_dir)

In [None]:
draw_metric_table(
    {entry[0]: entry[1]['valuable_data'] for entry in top_models},
    'Layers',
    table_name='layers_table_105',
    dir=checkpoint_dir,
)

### Activations and Scheduler 105

In [None]:
import torch.optim as optim

min_lr = 1e-6 # eta min
initial_lr = 1e-3
const_weight_decay = 1e-5
warmup_epochs = 32  # NOTE(review): unused below; schedulers use best_warmup_epochs from the layer search

checkpoint_dir = 'activation_lrsched_search_105'
criterion = 'activation'

activation_list = [nn.ReLU, nn.SiLU, nn.Tanh, SquaredReLU]
scheduler_list = [None, warmup_cosine_scheduler]

train_config = {
    'num_epochs'               : optimistic_num_epochs,
    'train_data_size'          : 64 * 1024,
    'val_data_size'            : 32 * 1024,
    'batch_size'               : 1024,
    'train_snr'                : 4,
    'val_snr'                  : 4,
    'classification_threshold' : 0.5,
    'start_epoch'              : 0,
    'shuffle_data'             : False,
    'early_stop_bad_epochs'    : 8,
    'early_stop_max_errors'    : None,
}


for activation in activation_list:
    # Display name, e.g. "ReLU" from repr "ReLU()".
    activation_name = str(activation())[:-2]
    for scheduler_object in scheduler_list:
        # Seed every (activation, scheduler) pair so all runs start from the same
        # RNG state instead of depending on what earlier runs consumed.
        set_seed(43)

        print(f"{activation_name}-active model train started")
        
        model = NoiseEstimator(input_size, hidden_size, output_size, 
                               num_layers=best_universal_layer, activation=activation, device=device
                )
        optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=const_weight_decay)

        scheduler = None
        if scheduler_object is not None:
            scheduler = scheduler_object(optimizer, train_config['num_epochs'], best_warmup_epochs, min_lr)
        
        report = train_loop(
            model, 
            optimizer, 
            scheduler = scheduler, 
            train_config = train_config,
            checkpoint_dir = checkpoint_dir,
            checkpoint_timestamps = False,
            verbose = True,
            device = device
        )
    
        print('classification threshold optimization')

        # Reload the reported checkpoint, store the tuned threshold, re-save.
        model, optimizer, model_info, metric_values = load_checkpoint(report['path'], device)

        model_info['optimal_classification_threshold'] = optimize_threshold(model, device=device)
        save_checkpoint(
            model, 
            optimizer, 
            metric_values = metric_values, 
            model_info = model_info,
            checkpoint_dir = checkpoint_dir,
            add_time = False
        )
        
        print(f"{activation_name}-active model train ended", end='\n\n\n')
        del model
        del optimizer

In [None]:
# Set explicitly: this cell reassigns `criterion` to 'scheduler' below, so
# relying on the previous cell's value broke re-runs, and skipping the
# training cell would analyze whatever directory was set last.
checkpoint_dir = 'activation_lrsched_search_105'
criterion = 'activation'

# checkpoints_analyze returns (top-ranked entries, all entries); unpack it
# instead of indexing into the tuple.
top_models, all_models = checkpoints_analyze(checkpoints_dir = checkpoint_dir, criterion = criterion)

print('\n\n\n')

best_model_data = top_models[0]
best_activation = best_model_data[0]
best_activation_name = str(best_activation())[:-2]
act_sched = dict()
act_sched_zip = dict()

# Second stage: per activation, rank its runs by the 'scheduler' criterion.
criterion = 'scheduler'

for activation in activation_list:
    act_name = str(activation())[:-2]
    # Copy this activation's checkpoints into their own sub-directory so the
    # analysis below only sees runs of this activation.
    cur_dir = os.path.join(checkpoint_dir, act_name)
    os.makedirs(cur_dir, exist_ok=True)
    # NOTE(review): re-running this cell copies the checkpoints again; whether
    # that overwrites or duplicates depends on save_checkpoint's naming.
    for checkpoint_name in [f for f in os.listdir(checkpoint_dir) if f.endswith('pth')]:
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)

        # The '_' prefix presumably keeps 'ReLU' from matching '..._SquaredReLU...' files.
        if '_' + act_name not in checkpoint_name:
            continue

        if not os.path.isfile(checkpoint_path):
            continue

        # Bind the loaded optimizer to its own name: unpacking it into `optim`
        # shadowed the torch.optim module for every cell run afterwards.
        model, loaded_optimizer, model_info, metric_values = load_checkpoint(checkpoint_path, device)

        save_checkpoint(model, loaded_optimizer, metric_values = metric_values, model_info=model_info,
                       checkpoint_dir = cur_dir)


    top_models, all_models = checkpoints_analyze(checkpoints_dir = cur_dir, criterion = criterion, metrics_open='a')

    # Label each run "<activation> +" (scheduler used) or "<activation> -" (none).
    for sch, data in top_models:
        act_sched_zip[f'{act_name} {"+" if sch is not None else "-"}'] = data
    
    best_model_act_scheduler = top_models[0][0]
    act_sched[act_name] = best_model_act_scheduler
    print(f"for activation {act_name} scheduler {best_model_act_scheduler}", end='\n\n\n')

best_scheduler = act_sched[best_activation_name]

print('\n\n' + f'best {criterion} for {best_activation_name} is {best_scheduler}!')



In [None]:
# SNR sweep for every (activation, scheduler) run; curve labels are the act_sched_zip keys.
validate_multimodal(act_sched_zip, '', device=device, checkpoints_dir=checkpoint_dir)

In [None]:
draw_metric_table(
    {label: info['valuable_data'] for label, info in act_sched_zip.items()},
    'Activation',
    table_name='activations_table_105',
    dir=checkpoint_dir,
)

### Train SNR 105

In [None]:
import torch.optim as optim

min_lr = 1e-6 # eta min
initial_lr = 1e-3
const_weight_decay = 1e-5


checkpoint_dir = 'train_snr_search_105'
criterion = 'train_snr'

train_snr_list = [1, 1.5, 2, 3, 3.4, 3.8, 4, 4.2, 4.6, 5] #

train_config = {
    # NOTE(review): the MDPC-73 train-SNR search uses optimistic_num_epochs * 2 — confirm which is intended.
    'num_epochs'               : optimistic_num_epochs,
    'train_data_size'          : 64 * 1024,
    'val_data_size'            : 32 * 1024,
    'batch_size'               : 1024,
    'train_snr'                : 4,  # overwritten per run in the loop below
    'val_snr'                  : 4,
    'classification_threshold' : 0.5,
    'start_epoch'              : 0,
    'shuffle_data'             : False,
    'early_stop_bad_epochs'    : 8,
    'early_stop_max_errors'    : None,
}


for train_snr in train_snr_list:
    # Seed every run so each SNR starts from the same RNG state.
    set_seed(43)

    print(f"{str(train_snr)}-snr driven model train started")
    
    model = NoiseEstimator(input_size, hidden_size, output_size, 
                           num_layers=best_universal_layer, activation=best_activation, device=device
            )
    optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=const_weight_decay)

    scheduler = None
    if best_scheduler is not None:
        # best_scheduler is the 'scheduler' criterion value read back from the
        # checkpoints; its concrete type is not visible here. Calling it
        # directly only works if it is the factory itself. The only non-None
        # candidate in the searched scheduler_list is warmup_cosine_scheduler,
        # so rebuild it with the same arguments the search used.
        # NOTE(review): update this if scheduler_list gains more schedulers.
        scheduler = warmup_cosine_scheduler(optimizer, train_config['num_epochs'], best_warmup_epochs, min_lr)

    train_config['train_snr'] = train_snr
    report = train_loop(
        model, 
        optimizer, 
        scheduler = scheduler, 
        train_config = train_config,
        checkpoint_dir = checkpoint_dir,
        checkpoint_mark = f'{str(train_snr)}',
        checkpoint_timestamps = False,
        verbose = True,
        device = device
    )

    print('classification threshold optimization')

    # Reload the reported checkpoint, store the tuned threshold, re-save.
    model, optimizer, model_info, metric_values = load_checkpoint(report['path'], device)

    model_info['optimal_classification_threshold'] = optimize_threshold(model, device=device)
    save_checkpoint(
        model, 
        optimizer, 
        metric_values = metric_values, 
        model_info = model_info,
        checkpoint_dir = checkpoint_dir,
        additional_mark = f'{str(train_snr)}',
        add_time = False
    )
    
    print(f"{str(train_snr)}-snr driven model train ended", end='\n\n\n')
    del model
    del optimizer


In [None]:
# Set explicitly so the cell does not silently reuse whatever a previously
# run cell left in these globals (e.g. when the training cell is skipped).
checkpoint_dir = 'train_snr_search_105'
criterion = 'train_snr'
top_models, all_models = checkpoints_analyze(checkpoints_dir = checkpoint_dir, criterion = criterion)

# The first top-ranked entry is treated as the winner; its first element is
# the criterion value (the training SNR).
best_model_data = top_models[0]
best_train_snr = best_model_data[0]

print('\n\n' + f'best {criterion} is {best_train_snr}!')

In [None]:
# SNR sweep for every train-SNR checkpoint, combined into one BER/SNR plot.
validate_multimodal(all_models, '', device=device, checkpoints_dir=checkpoint_dir)

In [None]:
draw_metric_table(
    {entry[0]: entry[1]['valuable_data'] for entry in top_models},
    'Train SNR',
    table_name='SNRs_table_105',
    dir=checkpoint_dir,
)