In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torchmetrics

In [2]:
from pathlib import Path
# Relative directory path for log output; where it is consumed is not visible
# in this section (presumably a logger/trainer in a later cell — TODO confirm).
log_dir = Path('logs/')

In [3]:
from torch.utils.data import Dataset, DataLoader

class ModulatorDataset(Dataset):
    """Enumerates every ``bit_count``-bit symbol as a one-hot vector.

    The dataset holds ``2 ** bit_count`` items. Item ``idx`` is a float tensor
    of that same length containing a single 1 at position ``idx`` and zeros
    everywhere else.

    Args:
        bit_count: number of bits per symbol.
    """

    def __init__(self, bit_count):
        super().__init__()
        self.bit_count = bit_count

    def _symbol_count(self):
        # Number of distinct values representable with ``bit_count`` bits.
        return 2 ** self.bit_count

    def __len__(self):
        """Return the number of distinct symbols."""
        return self._symbol_count()

    def __getitem__(self, idx):
        """Return the one-hot vector whose hot position is ``idx``."""
        one_hot = torch.zeros(self._symbol_count())
        one_hot[idx] = 1
        return one_hot

# Bits per transmitted symbol; every distinct bit pattern is its own class.
bit_count = 4
class_count = 2 ** bit_count

# The dataset has exactly ``class_count`` items, so one shuffled batch of size
# ``class_count`` contains every symbol once per epoch.
dataset = ModulatorDataset(bit_count)
dataloader = DataLoader(
    dataset,
    batch_size=class_count,
    shuffle=True,
)

In [None]:
class EntropyNormalization(nn.Module):
    """Rescales a batch so that its mean per-sample energy equals 1.

    Despite the class name, the quantity being normalised is energy: the sum
    of squares along dim 1, averaged over the remaining entries. The whole
    batch is divided by the square root of that average, so a single common
    scale factor is applied to every sample.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` divided by the square root of its mean energy."""
        per_sample_energy = torch.sum(x ** 2, dim=1)
        scale = torch.sqrt(per_sample_energy.mean())
        return x / scale

class AwgnNoise(nn.Module):
    """Adds zero-mean Gaussian noise with a fixed standard deviation.

    The standard deviation is derived from ``snr`` as
    ``sqrt(1 / (2 * 10 ** (snr / 10)))``, i.e. ``snr`` is converted from
    decibels to a linear ratio. (The factor 2 presumably assumes unit signal
    energy split over two real dimensions — confirm against the intended
    channel model.)

    Args:
        snr: signal-to-noise ratio in dB (scalar).
    """

    def __init__(self, snr):
        super().__init__()
        # Stored as a plain Python float so it combines cleanly with tensors
        # of any dtype/device.
        self.sigma = float(np.sqrt(1 / (2 * 10 ** (snr / 10))))

    def forward(self, x):
        """Return ``x`` plus independent N(0, sigma^2) noise per element."""
        # A plain nn.Module has no ``device`` attribute, so the noise is drawn
        # with randn_like, which matches the shape, dtype and device of ``x``.
        return x + torch.randn_like(x) * self.sigma


class ModulatorAutoencoder(pl.LightningModule):
    """Autoencoder that maps one-hot symbols through a noisy channel and back.

    Pipeline: encoder (MLP + energy normalisation) -> additive Gaussian noise
    -> decoder (MLP + softmax over classes).

    Args:
        class_count: length of the one-hot input / number of output classes.
        encoding_shape: width of the encoder output fed to the noise layer.
        snr: signal-to-noise ratio forwarded to ``AwgnNoise``.
    """

    def __init__(self, class_count, encoding_shape, snr):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(class_count, 4 * class_count),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(4 * class_count, encoding_shape),
            EntropyNormalization()
        )
        self.noise = AwgnNoise(snr)
        self.decoder = nn.Sequential(
            nn.Linear(encoding_shape, 4 * class_count),
            nn.ReLU(),
            nn.Linear(4 * class_count, class_count),
            # Explicit dim: the implicit-dim form is deprecated and ambiguous
            # for inputs that are not 2-D.
            nn.Softmax(dim=-1)
        )
        self.loss_function = nn.CrossEntropyLoss()
        # The attribute name is kept for compatibility, but the metric object
        # is an accuracy metric; training_step converts it to an error rate.
        # NOTE(review): newer torchmetrics releases require a ``task``
        # argument here — confirm against the installed version.
        self.symbol_error_rate = torchmetrics.Accuracy()

    def _logits(self, x):
        """Run encoder, channel noise and every decoder layer except softmax."""
        noised = self.noise(self.encoder(x))
        # Slicing a Sequential reuses the same layer objects; no parameters
        # are copied or re-registered.
        return self.decoder[:-1](noised)

    def forward(self, x):
        """Return per-class probabilities for a batch of one-hot symbols."""
        return self.decoder[-1](self._logits(x))

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss and log ``loss`` and ``ser``.

        Args:
            batch: one-hot encoded symbols, one row per sample.
            batch_idx: index of the batch (unused).

        Returns:
            The scalar training loss.
        """
        logits = self._logits(batch)
        prediction = logits.argmax(-1)
        true_classes = batch.argmax(-1)
        # CrossEntropyLoss applies log-softmax itself, so it must receive raw
        # logits; feeding it softmax outputs flattens the gradients.
        loss = self.loss_function(logits, true_classes)
        accuracy = self.symbol_error_rate(prediction, true_classes)
        # Symbol error rate is the complement of accuracy.
        ser = 1.0 - accuracy
        self.log('ser', ser, on_epoch=True, on_step=False, prog_bar=True)
        self.log('loss', loss, on_epoch=True, on_step=False)
        return loss

    def configure_optimizers(self):
        """Return Adam plus a ReduceLROnPlateau scheduler driven by ``loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.005)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.3, patience=1000, threshold=0.005, min_lr=1e-5
        )
        # Lightning only picks the scheduler up under the 'lr_scheduler' key;
        # a top-level 'scheduler' key is not recognised and is never stepped.
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'loss',
            },
        }