In [1]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

class PowerNormParts(nn.Module):
    """Normalise the power of each of ``parts_count`` equal partitions of a tensor.

    The input (batch first) is split into ``parts_count`` partitions per example and
    each partition is rescaled so that its squared L2 norm equals its size ``dsize``,
    i.e. average power 1 per (real or complex) symbol.

    Args:
        parts_count: Number of partitions per example.
        cplx: If True, each partition holds real/imaginary pairs, so ``dsize`` is half
            the number of real values (power 1 per complex symbol).
        part_last_dim: If True, partitions are taken along the last dimension. If
            False, the flattened per-example tensor is split into contiguous chunks
            (for an NCHW tensor this is a split along dim 1, the channels).
        **kwargs: ``name`` is stored as ``self.name`` (``nn.Module`` does not accept
            it); any other keyword is forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, parts_count, cplx=False, part_last_dim=True, **kwargs):
        name = kwargs.pop("name", None)
        super().__init__(**kwargs)
        self.name = name
        self.pc = parts_count
        self.cplx = cplx
        self.part_last_dim = part_last_dim  # if False, partitions the flattened tensor in order

    def forward(self, inputs):
        shape = inputs.shape
        if self.part_last_dim:
            # Bring the last dimension to the front so that contiguous chunks of the
            # flattened tensor correspond to slices of that dimension.
            inputs = inputs.reshape(shape[0], -1, shape[-1]).permute(0, 2, 1)

        # reshape (not view): after permute the tensor is no longer contiguous.
        flatp = inputs.reshape(shape[0], self.pc, -1)
        dsize = flatp.shape[2] // 2 if self.cplx else flatp.shape[2]

        # vector_norm is real-valued even for complex input; keepdim for broadcasting.
        norm = torch.linalg.vector_norm(flatp, dim=2, keepdim=True)
        # Guard all-zero partitions: 0 / 0 would yield NaN; they now stay zero.
        norm = norm.clamp_min(torch.finfo(norm.dtype).tiny)

        out = math.sqrt(dsize) * flatp / norm
        if self.part_last_dim:
            out = out.reshape(shape[0], shape[-1], -1).permute(0, 2, 1)
        return out.reshape(shape)

In [2]:
class Channel(nn.Module):
    """Additive white Gaussian noise (AWGN) channel with a settable SNR in dB.

    The noise variance is ``10 ** (-snr / 10)``, i.e. the SNR is measured relative
    to unit average signal power. With ``cplx=True`` the real tensor is treated as
    real/imaginary parts of complex symbols, so each real component receives half
    of that variance (std divided by sqrt(2)).

    Args:
        snr: Signal-to-noise ratio in dB.
        cplx: Whether the input carries complex symbols as real/imag pairs.
        **kwargs: ``name`` is stored as ``self.name`` (``nn.Module`` does not accept
            it); any other keyword is forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, snr, cplx=False, **kwargs):
        name = kwargs.pop("name", None)
        super().__init__(**kwargs)
        self.name = name
        self.cplx = cplx
        self.snr = snr
        self.set_noise_std()

    def forward(self, inputs):
        # randn_like matches the input's shape, dtype and device; scale to the target std.
        gnoise = torch.randn_like(inputs) * float(self.noise_std)
        return inputs + gnoise

    def get_snr(self):
        """Return the current SNR in dB."""
        return self.snr

    def set_snr(self, snr):
        """Set the SNR (dB) and recompute the noise standard deviation."""
        self.snr = snr
        self.set_noise_std()

    def set_noise_std(self):
        """Derive ``self.noise_std`` from ``self.snr`` (and ``self.cplx``)."""
        if self.cplx:
            self.noise_std = np.sqrt(10**(-self.snr/10)) / np.sqrt(2)
        else:
            self.noise_std = np.sqrt(10**(-self.snr/10))

In [3]:
def _mean_psnr(y_true, y_pred, max_val=1.0):
    """Mean over the batch (dim 0) of the per-example PSNR in dB."""
    # Per-example MSE over every non-batch dimension.
    dims = tuple(range(1, y_true.dim()))
    mse = torch.mean((y_true - y_pred) ** 2, dim=dims)
    return torch.mean(10.0 * torch.log10(max_val ** 2 / mse))


def PSNR_plotter(x_axis, model, channel, testX, stages_count=1, goal=None):
    """Plot the test-set PSNR of each decoding stage against the channel SNR.

    For every SNR in ``x_axis`` the channel is reconfigured via ``channel.set_snr``,
    ``model(testX)`` is evaluated without gradient tracking, and the mean PSNR
    (``max_val=1.0``) of each stage's prediction is recorded. The channel's original
    SNR is restored afterwards, even if evaluation raises.

    Args:
        x_axis: Sequence of SNR values to sweep (passed to ``channel.set_snr``).
        model: Callable whose output is indexed as ``preds[j]`` for stage ``j``.
        channel: Object exposing ``get_snr()`` and ``set_snr(snr)``.
        testX: Ground-truth batch; presumably pixel values in [0, 1] (max_val=1.0).
        stages_count: Number of stages read from the model output.
        goal: Optional reference curve (same length as ``x_axis``) to overlay.
    """
    sc = stages_count
    psnrs = np.zeros((sc, len(x_axis)))
    pre_snr = channel.get_snr()

    try:
        with torch.no_grad():
            for i, snr in enumerate(x_axis):
                channel.set_snr(snr)
                preds = model(testX)
                for j in range(sc):
                    psnrs[j, i] = _mean_psnr(testX, preds[j]).item()
    finally:
        # Leave the channel as we found it, whatever happened above.
        channel.set_snr(pre_snr)

    if sc == 1:
        plt.plot(x_axis, psnrs[0], label='Model')
    else:
        for j in range(sc):
            plt.plot(x_axis, psnrs[j], label='Stage_' + str(j + 1))

    if goal is not None:
        plt.plot(x_axis, goal, label='Goal')

    plt.xlabel('SNR (dB)')
    plt.ylabel('PSNR (dB)')
    plt.legend(loc='lower right')
    plt.grid()
    plt.show()

In [6]:
import torch
from torchmetrics import Metric

class PSNRMetric(Metric):
    """Running average of the batch-mean Peak Signal-to-Noise Ratio (PSNR).

    Each ``update`` computes the mean per-example PSNR of one batch with
    ``max_val=1.0`` (values presumably in [0, 1]). ``compute`` returns the average of
    those batch means, so every batch is weighted equally regardless of its size.

    Args:
        name (str, optional): Name of the metric. Defaults to "PSNR". Stored as
            ``self.name``; it is not forwarded to ``Metric.__init__``, which does not
            accept it.
    """

    # update() only adds to running sums; no full-state recomputation is needed.
    full_state_update = False
    higher_is_better = True

    def __init__(self, name="PSNR"):
        super().__init__()
        self.name = name
        self.add_state("PSNR_additive", default=torch.zeros(1), dist_reduce_fx="sum")
        self.add_state("counter", default=torch.zeros(1), dist_reduce_fx="sum")

    def update(self, y_true, y_pred):
        """Updates the metric with the given predictions and ground truth labels.

        Args:
            y_true (torch.Tensor): Ground truth, batch first.
            y_pred (torch.Tensor): Predictions, same shape as ``y_true``.
        """
        # Per-example MSE over every non-batch dimension, then PSNR with max_val=1.0.
        dims = tuple(range(1, y_true.dim()))
        mse = torch.mean((y_true - y_pred) ** 2, dim=dims)
        psnr = 10.0 * torch.log10(1.0 / mse)
        # detach: the running sum must not keep every batch's autograd graph alive.
        self.PSNR_additive += psnr.mean().detach()
        self.counter += 1

    def compute(self):
        """Computes the metric value.

        Returns:
            torch.Tensor: The average of the per-batch mean PSNR values (shape (1,)).
        """
        return self.PSNR_additive / self.counter

    def reset(self):
        """Resets the metric state to its initial values.

        Delegates to ``Metric.reset``, which restores the registered states to their
        defaults and also clears torchmetrics' internal bookkeeping (e.g. its cached
        ``compute`` result); zeroing the tensors by hand does not.
        """
        super().reset()

In [7]:
def lr_scheduler(epoch, lr):
    """Piecewise-constant learning-rate schedule.

    Returns 0.001 before epoch 20, 0.0005 for epochs 20-29 and 0.0001 afterwards.
    The current rate is announced on the first epoch of each phase. ``lr`` (the
    previous rate) is accepted for callback compatibility and ignored.
    """
    # First epoch of each phase -> text printed when that phase starts.
    announcements = {0: "0.001", 20: "0.0005", 30: "0.0001"}
    if epoch in announcements:
        print("\nlearning_rate: " + announcements[epoch])

    return 0.001 if epoch < 20 else (0.0005 if epoch < 30 else 0.0001)

In [8]:
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from typing import Any, Callable

class FuncCaller:
    """Training callback that runs ``function(*args, **kwargs)`` every ``period`` epochs.

    The training-loop arguments given to ``__call__`` (loader, model, optimizer,
    loss and extra keywords) are accepted for interface compatibility and ignored.
    """

    def __init__(self, period: int, function: Callable, *args: Any, **kwargs: Any):
        self.period, self.fn = period, function
        # Captured once at construction and replayed on every trigger.
        self.args, self.kwargs = args, kwargs

    def __call__(self, epoch: int, train_loader: DataLoader, model: Any, optimizer: Optimizer, loss: Any, **kwargs: Any) -> None:
        due = epoch % self.period == 0
        if due:
            self.fn(*self.args, **self.kwargs)

In [9]:
import time
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from typing import Any, Callable

class EpochDotter(object):
    def __init__(self, nl_period: int, dot_period: int = 1):
        self.nl_period = nl_period
        self.dot_period = dot_period
        self.tic = None

    def __call__(self, epoch: int, train_loader: DataLoader, model: Any, optimizer: Optimizer, loss: Any, **kwargs: Any) -> None:
        if epoch % self.dot_period == 0:
            print('.', end='')

        if epoch % self.nl_period == 0:
            toc = time.time()
            print(" {} epochs".format(epoch), end='')
            if self.tic is not None:
                print(" - {0:.2f}s/epoch".format((toc - self.tic) / self.nl_period), end='')
            for key in list(kwargs.keys()):
                print(" - {}: {}".format(key, kwargs[key]), end='')
            print("")
            self.tic = time.time()

In [10]:
def create_encoder(out_chs, img_shape=(None, None, 3), name=None):
    """Build the convolutional JSCC encoder as an ``nn.Sequential``.

    Two stride-2 5x5 convolutions downsample height and width by 4 in total, then
    two stride-1 5x5 convolutions follow, the last producing ``out_chs`` channels.
    PReLU activations sit between layers (none after the last one).

    Args:
        out_chs: Number of output (latent) channels.
        img_shape: Keras-style ``(H, W, C)`` input shape; only ``C`` is used, as the
            input channel count (3 when ``img_shape`` or ``C`` is None).
        name: Optional label stored as ``encoder.name``; ``nn.Sequential`` accepts
            no ``name`` keyword, so it is attached as a plain attribute.

    Returns:
        nn.Sequential mapping (N, C, H, W) to (N, out_chs, ceil(H/4), ceil(W/4)).
    """
    in_chs = 3 if img_shape is None or img_shape[-1] is None else img_shape[-1]

    encoder = nn.Sequential()
    encoder.name = name
    # Submodule names are unchanged so existing state_dict keys still load.
    encoder.add_module('input', nn.Conv2d(in_chs, 16, kernel_size=5, stride=2, padding=2))
    encoder.add_module('prelu1', nn.PReLU())
    encoder.add_module('conv1', nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2))
    encoder.add_module('prelu2', nn.PReLU())
    encoder.add_module('conv2', nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2))
    encoder.add_module('prelu3', nn.PReLU())
    encoder.add_module('conv3', nn.Conv2d(32, out_chs, kernel_size=5, stride=1, padding=2))
    return encoder

In [11]:
def create_decoder(input_shape, img_chs=3, name=None):
    """Build the transposed-convolution JSCC decoder as an ``nn.Sequential``.

    Mirrors ``create_encoder``: two stride-1 5x5 transposed convolutions, then two
    stride-2 ones (``output_padding=1`` so each exactly doubles H and W), giving 4x
    upsampling in total. PReLU between layers and a final sigmoid.

    Args:
        input_shape: Keras-style ``(H, W, C)`` latent shape; ``C`` is the number of
            input channels (32 when ``input_shape`` or ``C`` is None).
        img_chs: Number of output image channels.
        name: Optional label stored as ``decoder.name``; ``nn.Sequential`` accepts
            no ``name`` keyword, so it is attached as a plain attribute.

    Returns:
        nn.Sequential mapping (N, C, H, W) to (N, img_chs, 4H, 4W), values in (0, 1).
    """
    in_chs = 32 if input_shape is None or input_shape[-1] is None else input_shape[-1]

    decoder = nn.Sequential()
    decoder.name = name
    # Submodule names and parameter shapes are unchanged so state_dicts still load.
    decoder.add_module('input', nn.ConvTranspose2d(in_chs, 32, kernel_size=5, stride=1, padding=2))
    decoder.add_module('prelu1', nn.PReLU())
    decoder.add_module('convtranspose1', nn.ConvTranspose2d(32, 32, kernel_size=5, stride=1, padding=2))
    decoder.add_module('prelu2', nn.PReLU())
    # stride 2 here and below undoes the encoder's two stride-2 convolutions.
    decoder.add_module('convtranspose2', nn.ConvTranspose2d(32, 16, kernel_size=5, stride=2, padding=2, output_padding=1))
    decoder.add_module('prelu3', nn.PReLU())
    decoder.add_module('convtranspose3', nn.ConvTranspose2d(16, img_chs, kernel_size=5, stride=2, padding=2, output_padding=1))
    decoder.add_module('sigmoid', nn.Sigmoid())
    return decoder

In [None]:
import os

ps = 64   # patch size (Height and Width)
# Bandwidth ratio k/n: the encoder downsamples H and W by 4, and with complex
# symbols each latent pixel carries enc_chs/2 symbols against 4*4*3 source values,
# so k/n = enc_chs / 96 = 40 / 96 = 5 * 1/12.
enc_chs = 5 * 4*2
stages_count = 5
dec_chs = enc_chs // stages_count
SNR = 13   # training channel SNR in dB

# Each stage must own an equal, whole number of latent channels.
assert enc_chs % stages_count == 0, "enc_chs must be divisible by stages_count"

epochs = 500
batch_size = int(32 * (32/ps)**2)
loss_func = nn.MSELoss()

drive_dir = '/content/drive/'
JSCC_dir = os.path.join(drive_dir, 'MyDrive/Colab Stuff/Efficient_successive_img_tr/JSCC_2')
JSCC_channel_name = 'Channel'   # to save the channel with this name, for further usage

train_count = int(50000 * (32/ps)**2)   # number of patches in trainset
test_count = int(10000 * (32/ps)**2)   # number of patches in testset

In [12]:
import torch
import torch.nn as nn


class SuccessiveJSCC(nn.Module):
    """Successive-refinement JSCC model with one decoder per stage.

    A shared encoder, power normalisation and channel produce one noisy latent.
    Stage ``i`` (0-based) decodes its first ``dec_chs * (i + 1)`` channels, so each
    stage sees the symbols of all earlier stages plus its own. ``forward`` returns a
    list with one reconstruction per stage.
    """

    def __init__(self, encoder, powernorm, channel, decoders, dec_chs):
        super().__init__()
        self.encoder = encoder
        self.powernorm = powernorm
        self.channel = channel
        # ModuleList registers every decoder's parameters with this model.
        self.decoders = nn.ModuleList(decoders)
        self.dec_chs = dec_chs

    def forward(self, x):
        latent = self.channel(self.powernorm(self.encoder(x)))
        # Tensors are NCHW, so the stage split is along dim 1 (channels).
        return [
            decoder(latent[:, : self.dec_chs * (i + 1)])
            for i, decoder in enumerate(self.decoders)
        ]


# Create layers
encoder = create_encoder(enc_chs, img_shape=(None, None, 3), name="Encoder")
# part_last_dim=False splits the flattened (C*H*W) tensor into contiguous chunks,
# i.e. into equal channel groups for NCHW tensors: one power-normalised group per stage.
powernorm = PowerNormParts(stages_count, cplx=True, part_last_dim=False, name="PowerNorm")
channel = Channel(SNR, cplx=True, name=JSCC_channel_name)
decoders = [
    create_decoder((None, None, dec_chs * (i + 1)), img_chs=3, name="Decoder_" + str(i + 1))
    for i in range(stages_count)
]

# Construct the model. PyTorch has no Keras-style ``add_loss``: the training loop
# should sum ``loss_func`` over the per-stage outputs returned by ``model``.
model = SuccessiveJSCC(encoder, powernorm, channel, decoders, dec_chs)
raw_model = model  # same network; the second name is kept so code using either still works

# Smoke test on one random patch with pixel values in [0, 1].
model_input = torch.rand(1, 3, ps, ps)
with torch.no_grad():
    outputs = model(model_input)
    losses = [loss_func(output, model_input) for output in outputs]
assert all(output.shape == model_input.shape for output in outputs)

NameError: name 'enc_chs' is not defined