In [1]:
%cd ..

import os
import torch
import copy
import numpy as np
from tqdm.notebook import tqdm


from alignment.alignment_utils import load_deep_jscc
from alignment.alignment_model import *
from alignment.alignment_model import _LinearAlignment, _MLPAlignment, _ConvolutionalAlignment, _ZeroShotAlignment, _TwoConvAlignment
from alignment.alignment_training import *
from alignment.alignment_validation import *

/home/lorenzo/repos/Deep-JSCC-PyTorch


In [2]:
# Experiment configuration: sweep the channel SNR for every aligner family, repeated over several seeds.
snr_ae = 30  # SNR tag of the two autoencoder checkpoints loaded below
n_samples = 10000  # sample budget passed to the aligner trainers
resolution = 96

snrs = [-20, -10, 0, 10, 20, 30]  # channel SNRs swept by every training/evaluation cell
seeds = [42, 43, 44, 45, 46]

model1_fp = f'alignment/models/autoencoders/snr_{snr_ae}_seed_42.pkl'  # encoder side of the mismatch
model2_fp = f'alignment/models/autoencoders/snr_{snr_ae}_seed_43.pkl'  # decoder side of the mismatch
folder = 'psnr_vs_snr'
os.makedirs(f'alignment/models/plots/{folder}', exist_ok=True)

dataset = "cifar10"
channel = 'AWGN'
batch_size = 64
num_workers = 4

logs_folder = f'alignment/logs_mismatch_{resolution}'
os.makedirs(logs_folder, exist_ok=True)

times = 10  # forwarded to validation_vectorized
c = 8  # forwarded to load_deep_jscc / load_alignment_dataset

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

train_loader, test_loader = get_data_loaders(dataset, resolution, batch_size, num_workers)
data = load_alignment_dataset(model1_fp, model2_fp, snr_ae, train_loader, c, device)

# Mismatched pair used by the evaluation cells: encoder from model1_fp, decoder from model2_fp.
# Defined here so those cells do not depend on variables leaking out of a later loop.
encoder = copy.deepcopy(load_deep_jscc(model1_fp, snr_ae, c, channel).encoder)
decoder = copy.deepcopy(load_deep_jscc(model2_fp, snr_ae, c, channel).decoder)

Caching inputs: 100%|██████████| 782/782 [00:04<00:00, 177.46it/s]


# No mismatch - Unaligned - Zeroshot max

In [3]:
# Baselines for each (channel SNR, seed), one log file per pair:
#   unaligned - encoder from model1_fp feeding the decoder from model2_fp, no aligner
#   aligned   - encoder and decoder both from model1_fp (matched pair), no aligner
#   zeroshot  - the mismatched pair bridged by a zero-shot aligner
data.flat = True  # the zero-shot trainer below uses data.flat = True; loop-invariant

# The mismatched pair depends only on snr_ae, so load it once instead of on every iteration.
encoder = copy.deepcopy(load_deep_jscc(model1_fp, snr_ae, c, channel).encoder)
decoder = copy.deepcopy(load_deep_jscc(model2_fp, snr_ae, c, channel).decoder)


def _evaluate_and_log(label, model, log_file):
    """Run validation_vectorized on `model`, print "<label> <psnr>" and append that line to `log_file`."""
    result_msg = f"{label} {validation_vectorized(model, test_loader, times, device):.2f}"
    print(result_msg)
    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")


for snr in snrs:
    # The matched decoder depends on snr but not on seed, so load it once per SNR.
    # NOTE(review): loaded with the channel `snr`, whereas encoder/decoder use `snr_ae` — confirm intended.
    matched_decoder = copy.deepcopy(load_deep_jscc(model1_fp, snr, c, channel).decoder)

    for seed in seeds:
        log_file = f"{logs_folder}/lines_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt"
        # Truncate first: every result below is appended, so a re-run would otherwise duplicate lines.
        with open(log_file, 'w') as f:
            pass

        _evaluate_and_log("unaligned", AlignedDeepJSCC(encoder, decoder, None, snr, channel), log_file)
        _evaluate_and_log("aligned", AlignedDeepJSCC(encoder, matched_decoder, None, snr, channel), log_file)

        # "Zeroshot max": both sample-count arguments are resolution**2.
        set_seed(seed)
        permutation = torch.randperm(len(data))
        aligner = train_zeroshot_aligner(data, permutation, resolution**2, snr, resolution**2, device)
        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, channel)
        _evaluate_and_log("zeroshot", aligned_model, log_file)

unaligned 5.23
aligned 7.28
zeroshot 5.85
unaligned 5.23
aligned 7.28
zeroshot 5.86
unaligned 5.23
aligned 7.28
zeroshot 5.86
unaligned 5.23
aligned 7.28
zeroshot 5.86
unaligned 5.23
aligned 7.28
zeroshot 5.87
unaligned 5.55
aligned 14.39
zeroshot 8.66
unaligned 5.55
aligned 14.39
zeroshot 8.72
unaligned 5.55
aligned 14.39
zeroshot 8.69
unaligned 5.55
aligned 14.39
zeroshot 8.67
unaligned 5.55
aligned 14.39
zeroshot 8.70
unaligned 5.47
aligned 25.21
zeroshot 15.90
unaligned 5.47
aligned 25.21
zeroshot 16.02
unaligned 5.47
aligned 25.21
zeroshot 15.95
unaligned 5.47
aligned 25.21
zeroshot 15.91
unaligned 5.47
aligned 25.21
zeroshot 15.97
unaligned 5.40
aligned 35.10
zeroshot 25.29
unaligned 5.40
aligned 35.10
zeroshot 25.42
unaligned 5.40
aligned 35.10
zeroshot 25.34
unaligned 5.40
aligned 35.10
zeroshot 25.30
unaligned 5.40
aligned 35.10
zeroshot 25.36
unaligned 5.39
aligned 43.75
zeroshot 33.85
unaligned 5.39
aligned 43.75
zeroshot 33.96
unaligned 5.39
aligned 43.75
zeroshot 33.90
una

# Least Squares

In [4]:
# Fit one least-squares linear aligner per (channel SNR, seed) and save its state_dict
# where the evaluation cell will look for it.
aligner_type = "linear"
data.flat = True  # set explicitly: other cells toggle data.flat

for snr in snrs:
    for seed in seeds:
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'

        # Seeding right before randperm makes the permutation handed to the trainer reproducible.
        set_seed(seed)
        permutation = torch.randperm(len(data))
        aligner = train_linear_aligner(data, permutation, n_samples, snr)
        torch.save(aligner.state_dict(), aligner_fp)

        print(f"Done with SNR {snr} SEED {seed}.")

Done with SNR -20 SEED 42.
Done with SNR -20 SEED 43.
Done with SNR -20 SEED 44.
Done with SNR -20 SEED 45.
Done with SNR -20 SEED 46.
Done with SNR -10 SEED 42.
Done with SNR -10 SEED 43.
Done with SNR -10 SEED 44.
Done with SNR -10 SEED 45.
Done with SNR -10 SEED 46.
Done with SNR 0 SEED 42.
Done with SNR 0 SEED 43.
Done with SNR 0 SEED 44.
Done with SNR 0 SEED 45.
Done with SNR 0 SEED 46.
Done with SNR 10 SEED 42.
Done with SNR 10 SEED 43.
Done with SNR 10 SEED 44.
Done with SNR 10 SEED 45.
Done with SNR 10 SEED 46.
Done with SNR 20 SEED 42.
Done with SNR 20 SEED 43.
Done with SNR 20 SEED 44.
Done with SNR 20 SEED 45.
Done with SNR 20 SEED 46.
Done with SNR 30 SEED 42.
Done with SNR 30 SEED 43.
Done with SNR 30 SEED 44.
Done with SNR 30 SEED 45.
Done with SNR 30 SEED 46.


In [5]:
# Score every saved least-squares aligner on the test set; one log file per (channel SNR, seed).
# Relies on `encoder`, `decoder` and `channel` defined by earlier cells.
aligner_type = "linear"
aligner = _LinearAlignment(resolution**2)  # single instance; weights are overwritten for each run

for snr in snrs:
    for seed in seeds:
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'
        log_file = f"{logs_folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt"
        open(log_file, 'w').close()  # start each log empty so re-runs do not accumulate

        set_seed(seed)
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))
        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, channel)
        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        result_msg = f"Linear model, AE {snr_ae} SNR {snr} SEED {seed} got a PSNR of {psnr_result:.2f}"
        print(result_msg)
        with open(log_file, 'a') as f:
            f.write(result_msg + "\n")

Linear model, AE 30 SNR -20 SEED 42 got a PSNR of 9.89
Linear model, AE 30 SNR -20 SEED 43 got a PSNR of 9.89
Linear model, AE 30 SNR -20 SEED 44 got a PSNR of 9.89
Linear model, AE 30 SNR -20 SEED 45 got a PSNR of 9.89
Linear model, AE 30 SNR -20 SEED 46 got a PSNR of 9.89
Linear model, AE 30 SNR -10 SEED 42 got a PSNR of 17.60
Linear model, AE 30 SNR -10 SEED 43 got a PSNR of 17.59
Linear model, AE 30 SNR -10 SEED 44 got a PSNR of 17.60
Linear model, AE 30 SNR -10 SEED 45 got a PSNR of 17.60
Linear model, AE 30 SNR -10 SEED 46 got a PSNR of 17.59
Linear model, AE 30 SNR 0 SEED 42 got a PSNR of 26.51
Linear model, AE 30 SNR 0 SEED 43 got a PSNR of 26.51
Linear model, AE 30 SNR 0 SEED 44 got a PSNR of 26.51
Linear model, AE 30 SNR 0 SEED 45 got a PSNR of 26.51
Linear model, AE 30 SNR 0 SEED 46 got a PSNR of 26.51
Linear model, AE 30 SNR 10 SEED 42 got a PSNR of 34.91
Linear model, AE 30 SNR 10 SEED 43 got a PSNR of 34.91
Linear model, AE 30 SNR 10 SEED 44 got a PSNR of 34.91
Linear mod

# Linear Neural

In [6]:
# Train the "neural" aligner with train_neural_aligner for every (channel SNR, seed) pair
# and save each state_dict for the evaluation cell.
aligner_type = "neural"
data.flat = False  # set explicitly: other cells toggle data.flat

for snr in snrs:
    for seed in seeds:
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'

        set_seed(seed)
        permutation = torch.randperm(len(data))
        # NOTE(review): the positional `6` is an unexplained constant — give it a name in the config cell.
        # `epoch` is returned by the trainer but not used here.
        aligner, epoch = train_neural_aligner(data, permutation, n_samples, batch_size, resolution, 6, snr, device)
        torch.save(aligner.state_dict(), aligner_fp)

        print(f"Done with SNR {snr} SEED {seed}.")

Done with SNR -20 SEED 42.
Done with SNR -20 SEED 43.
Done with SNR -20 SEED 44.
Done with SNR -20 SEED 45.
Done with SNR -20 SEED 46.
Done with SNR -10 SEED 42.
Done with SNR -10 SEED 43.
Done with SNR -10 SEED 44.
Done with SNR -10 SEED 45.
Done with SNR -10 SEED 46.


KeyboardInterrupt: 

In [None]:
# Evaluate each saved neural aligner on the test set, one log file per (channel SNR, seed).
# Relies on `encoder`, `decoder` and `channel` defined by earlier cells.
aligner_type = "neural"
# NOTE(review): the neural weights are loaded into a _LinearAlignment; this assumes train_neural_aligner
# returns that class — confirm, otherwise load_state_dict (strict) will raise on mismatched keys.
aligner = _LinearAlignment(resolution**2)

for snr in snrs:
    for seed in seeds:

        log_file = f"{logs_folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt"

        # Truncate first so a re-run never mixes stale results with new ones.
        with open(log_file, 'w') as f:
            pass

        set_seed(seed)

        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, channel)

        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        # The prefix must name the aligner actually evaluated, or these lines read like linear results.
        result_msg = f"Neural model, AE {snr_ae} SNR {snr} SEED {seed} got a PSNR of {psnr_result:.2f}"
        print(result_msg)

        with open(log_file, 'a') as f:
            f.write(f"{result_msg}\n")

Neural model, 1 samples got a PSNR of 12.60


Neural model, 2 samples got a PSNR of 12.78
Neural model, 4 samples got a PSNR of 12.06
Neural model, 6 samples got a PSNR of 12.69
Neural model, 11 samples got a PSNR of 12.92
Neural model, 18 samples got a PSNR of 13.27
Neural model, 29 samples got a PSNR of 13.52
Neural model, 48 samples got a PSNR of 13.96
Neural model, 78 samples got a PSNR of 14.55
Neural model, 127 samples got a PSNR of 15.00
Neural model, 206 samples got a PSNR of 15.41
Neural model, 335 samples got a PSNR of 15.98
Neural model, 545 samples got a PSNR of 16.68
Neural model, 885 samples got a PSNR of 18.00
Neural model, 1438 samples got a PSNR of 18.58
Neural model, 2335 samples got a PSNR of 20.09
Neural model, 3792 samples got a PSNR of 21.40
Neural model, 6158 samples got a PSNR of 23.84
Neural model, 10000 samples got a PSNR of 25.22


# MLP

In [None]:
# Train an MLP aligner per (channel SNR, seed) with train_mlp_aligner and save its state_dict.
aligner_type = "mlp"
data.flat = False  # set explicitly: other cells toggle data.flat

for snr in snrs:
    for seed in seeds:
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'

        set_seed(seed)
        permutation = torch.randperm(len(data))
        # NOTE(review): same unexplained positional `6` as the neural trainer — name it.
        aligner, epoch = train_mlp_aligner(data, permutation, n_samples, batch_size, resolution, 6, snr, device)
        torch.save(aligner.state_dict(), aligner_fp)

        print(f"Done with SNR {snr} SEED {seed}.")

Training: 100%|██████████| 1/1 [02:20<00:00, 140.62s/it]


In [None]:
# Evaluate each saved MLP aligner on the test set, one log file per (channel SNR, seed).
# Relies on `encoder`, `decoder` and `channel` defined by earlier cells.
aligner_type = "mlp"
# Architecture must match the trained aligners for load_state_dict to succeed.
aligner = _MLPAlignment(input_dim=resolution**2, hidden_dims=[resolution**2])

for snr in snrs:
    for seed in seeds:

        log_file = f"{logs_folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt"

        # Truncate first so a re-run never mixes stale results with new ones.
        with open(log_file, 'w') as f:
            pass

        set_seed(seed)

        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, channel)

        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        # The prefix must name the aligner actually evaluated, or these lines read like linear results.
        result_msg = f"MLP model, AE {snr_ae} SNR {snr} SEED {seed} got a PSNR of {psnr_result:.2f}"
        print(result_msg)

        with open(log_file, 'a') as f:
            f.write(f"{result_msg}\n")

MLP model, 10000 samples got a PSNR of 27.07


# Convolutional

In [7]:
# Train a convolutional aligner per (channel SNR, seed) with train_conv_aligner and save its state_dict.
aligner_type = "conv"
data.flat = False  # set explicitly: other cells toggle data.flat

for snr in snrs:
    for seed in seeds:
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'

        set_seed(seed)
        permutation = torch.randperm(len(data))
        aligner, epoch = train_conv_aligner(data, permutation, n_samples, c, batch_size, snr, device)
        torch.save(aligner.state_dict(), aligner_fp)

        print(f"Done with SNR {snr} SEED {seed}.")

Done with SNR -20 SEED 42.
Done with SNR -20 SEED 43.
Done with SNR -20 SEED 44.
Done with SNR -20 SEED 45.
Done with SNR -20 SEED 46.
Done with SNR -10 SEED 42.
Done with SNR -10 SEED 43.
Done with SNR -10 SEED 44.
Done with SNR -10 SEED 45.
Done with SNR -10 SEED 46.
Done with SNR 0 SEED 42.
Done with SNR 0 SEED 43.
Done with SNR 0 SEED 44.
Done with SNR 0 SEED 45.
Done with SNR 0 SEED 46.
Done with SNR 10 SEED 42.
Done with SNR 10 SEED 43.
Done with SNR 10 SEED 44.
Done with SNR 10 SEED 45.
Done with SNR 10 SEED 46.
Done with SNR 20 SEED 42.
Done with SNR 20 SEED 43.
Done with SNR 20 SEED 44.
Done with SNR 20 SEED 45.
Done with SNR 20 SEED 46.
Done with SNR 30 SEED 42.
Done with SNR 30 SEED 43.
Done with SNR 30 SEED 44.
Done with SNR 30 SEED 45.
Done with SNR 30 SEED 46.


In [8]:
# Evaluate each saved convolutional aligner on the test set, one log file per (channel SNR, seed).
# Relies on `encoder`, `decoder` and `channel` defined by earlier cells.
aligner_type = "conv"
# Architecture must match the trained aligners for load_state_dict to succeed.
aligner = _ConvolutionalAlignment(in_channels=2*c, out_channels=2*c, kernel_size=5)

for snr in snrs:
    for seed in seeds:

        log_file = f"{logs_folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt"

        # Truncate first so a re-run never mixes stale results with new ones.
        with open(log_file, 'w') as f:
            pass

        set_seed(seed)

        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, channel)

        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        # The prefix must name the aligner actually evaluated, or these lines read like linear results.
        result_msg = f"Conv model, AE {snr_ae} SNR {snr} SEED {seed} got a PSNR of {psnr_result:.2f}"
        print(result_msg)

        with open(log_file, 'a') as f:
            f.write(f"{result_msg}\n")

Linear model, AE 30 SNR -20 SEED 42 got a PSNR of 14.74
Linear model, AE 30 SNR -20 SEED 43 got a PSNR of 14.75
Linear model, AE 30 SNR -20 SEED 44 got a PSNR of 14.76
Linear model, AE 30 SNR -20 SEED 45 got a PSNR of 14.75
Linear model, AE 30 SNR -20 SEED 46 got a PSNR of 14.75
Linear model, AE 30 SNR -10 SEED 42 got a PSNR of 19.56
Linear model, AE 30 SNR -10 SEED 43 got a PSNR of 19.56
Linear model, AE 30 SNR -10 SEED 44 got a PSNR of 19.57
Linear model, AE 30 SNR -10 SEED 45 got a PSNR of 19.57
Linear model, AE 30 SNR -10 SEED 46 got a PSNR of 19.56
Linear model, AE 30 SNR 0 SEED 42 got a PSNR of 26.58
Linear model, AE 30 SNR 0 SEED 43 got a PSNR of 26.60
Linear model, AE 30 SNR 0 SEED 44 got a PSNR of 26.60
Linear model, AE 30 SNR 0 SEED 45 got a PSNR of 26.59
Linear model, AE 30 SNR 0 SEED 46 got a PSNR of 26.63
Linear model, AE 30 SNR 10 SEED 42 got a PSNR of 34.44
Linear model, AE 30 SNR 10 SEED 43 got a PSNR of 34.43
Linear model, AE 30 SNR 10 SEED 44 got a PSNR of 34.45
Linea

# Two Conv

In [None]:
# Train a two-layer convolutional aligner per (channel SNR, seed) with train_twoconv_aligner
# and save its state_dict.
aligner_type = "twoconv"
data.flat = False  # set explicitly: other cells toggle data.flat

for snr in snrs:
    for seed in seeds:
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'

        set_seed(seed)
        permutation = torch.randperm(len(data))
        aligner, epoch = train_twoconv_aligner(data, permutation, n_samples, c, batch_size, snr, device)
        torch.save(aligner.state_dict(), aligner_fp)

        print(f"Done with SNR {snr} SEED {seed}.")

Training: 100%|██████████| 19/19 [16:56<00:00, 53.52s/it] 


In [None]:
# Evaluate each saved two-conv aligner on the test set, one log file per (channel SNR, seed).
# Relies on `encoder`, `decoder` and `channel` defined by earlier cells.
aligner_type = "twoconv"
# Architecture must match the trained aligners for load_state_dict to succeed.
aligner = _TwoConvAlignment(in_channels=2*c, hidden_channels=2*c, out_channels=2*c, kernel_size=5)

for snr in snrs:
    for seed in seeds:

        log_file = f"{logs_folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt"

        # Truncate first so a re-run never mixes stale results with new ones.
        with open(log_file, 'w') as f:
            pass

        set_seed(seed)

        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, channel)

        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        # The prefix must name the aligner actually evaluated, or these lines read like linear results.
        result_msg = f"Twoconv model, AE {snr_ae} SNR {snr} SEED {seed} got a PSNR of {psnr_result:.2f}"
        print(result_msg)

        with open(log_file, 'a') as f:
            f.write(f"{result_msg}\n")

Twoconv model, 1 samples got a PSNR of 16.13
Twoconv model, 2 samples got a PSNR of 24.65
Twoconv model, 4 samples got a PSNR of 30.05
Twoconv model, 6 samples got a PSNR of 31.60
Twoconv model, 11 samples got a PSNR of 32.94
Twoconv model, 18 samples got a PSNR of 33.66
Twoconv model, 29 samples got a PSNR of 35.74
Twoconv model, 48 samples got a PSNR of 36.88
Twoconv model, 78 samples got a PSNR of 37.24
Twoconv model, 127 samples got a PSNR of 37.55
Twoconv model, 206 samples got a PSNR of 37.68
Twoconv model, 335 samples got a PSNR of 37.92
Twoconv model, 545 samples got a PSNR of 37.86
Twoconv model, 885 samples got a PSNR of 38.43
Twoconv model, 1438 samples got a PSNR of 38.45
Twoconv model, 2335 samples got a PSNR of 38.72
Twoconv model, 3792 samples got a PSNR of 38.53
Twoconv model, 6158 samples got a PSNR of 38.11
Twoconv model, 10000 samples got a PSNR of 38.74


# Zero-shot

In [None]:
# Build a zero-shot aligner per (channel SNR, seed) with train_zeroshot_aligner and save its state_dict.
aligner_type = "zeroshot"
data.flat = True  # set explicitly: other cells toggle data.flat

for snr in snrs:
    for seed in seeds:
        set_seed(seed)
        permutation = torch.randperm(len(data))

        # Use the loop's channel SNR; `train_snr` was never defined in this notebook (NameError).
        aligner = train_zeroshot_aligner(data, permutation, n_samples, snr, n_samples, device)

        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'
        torch.save(aligner.state_dict(), aligner_fp)

        print(f"Done with SNR {snr} SEED {seed}.")

Training: 100%|██████████| 18/18 [01:31<00:00,  5.10s/it]


In [None]:
# Evaluate each saved zero-shot aligner on the test set, one log file per (channel SNR, seed).
# Relies on `encoder`, `decoder` and `channel` defined by earlier cells.
aligner_type = "zeroshot"
# Zero-filled placeholders that load_state_dict overwrites. Their shapes must equal the saved tensors'
# shapes, which this cell assumes follow from n_samples and resolution**2 — TODO confirm, in particular G.
# Note: L alone is n_samples x n_samples float32 (about 400 MB for n_samples = 10000).
aligner = _ZeroShotAlignment(
    F_tilde=torch.zeros(n_samples, resolution**2),
    G_tilde=torch.zeros(resolution**2, n_samples),
    G=torch.zeros(1, 1),
    L=torch.zeros(n_samples, n_samples),
    mean=torch.zeros(n_samples, 1)
)

for snr in snrs:
    for seed in seeds:

        log_file = f"{logs_folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt"

        # Truncate first so a re-run never mixes stale results with new ones.
        with open(log_file, 'w') as f:
            pass

        set_seed(seed)

        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, channel)

        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        # The prefix must name the aligner actually evaluated, or these lines read like linear results.
        result_msg = f"Zeroshot model, AE {snr_ae} SNR {snr} SEED {seed} got a PSNR of {psnr_result:.2f}"
        print(result_msg)

        with open(log_file, 'a') as f:
            f.write(f"{result_msg}\n")

Zeroshot model, 2 samples got a PSNR of 11.28
Zeroshot model, 4 samples got a PSNR of 11.55
Zeroshot model, 6 samples got a PSNR of 11.77
Zeroshot model, 11 samples got a PSNR of 11.90
Zeroshot model, 18 samples got a PSNR of 12.72
Zeroshot model, 29 samples got a PSNR of 13.34
Zeroshot model, 48 samples got a PSNR of 13.29
Zeroshot model, 78 samples got a PSNR of 14.30
Zeroshot model, 127 samples got a PSNR of 15.11
Zeroshot model, 206 samples got a PSNR of 15.08
Zeroshot model, 335 samples got a PSNR of 16.24
Zeroshot model, 545 samples got a PSNR of 15.94
Zeroshot model, 885 samples got a PSNR of 16.95
Zeroshot model, 1438 samples got a PSNR of 19.43
Zeroshot model, 2335 samples got a PSNR of 18.31
Zeroshot model, 3792 samples got a PSNR of 22.53
Zeroshot model, 6158 samples got a PSNR of 26.78
Zeroshot model, 10000 samples got a PSNR of 28.33
