In [1]:
%cd ..

import os
import torch
import copy
import numpy as np

from alignment.alignment_utils import load_deep_jscc
from alignment.alignment_model import *
from alignment.alignment_model import _LinearAlignment, _MLPAlignment, _ConvolutionalAlignment, _ZeroShotAlignment, _TwoConvAlignment
from alignment.alignment_training import *
from alignment.alignment_validation import *

/home/lorenzo/repos/Deep-JSCC-PyTorch


In [2]:
# Two independently trained Deep-JSCC checkpoints (file names suggest SNR 0 dB, seeds 42 / 43).
model1_fp = r'alignment/models/autoencoders/snr_0_seed_42.pkl'
model2_fp = r'alignment/models/autoencoders/snr_0_seed_43.pkl'
folder = r'psnr_vs_pilots_5'
os.makedirs(f'alignment/models/plots/{folder}', exist_ok=True)

dataset = "cifar10"
resolution = 96
channel = 'AWGN'
batch_size = 64
num_workers = 4

train_snr = 0
val_snr = 0
times = 10
c = 8
seed = 42

# ~100 log-spaced pilot counts in [1, 10000]; astype(int) floors each value and np.unique drops the
# duplicates this creates at the low end, so the sweep has fewer than 100 distinct entries.
samples_sets = np.unique(np.logspace(0, np.log10(10000), num=100, base=10).astype(int))
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

# Mismatched pair: encoder taken from model 1, decoder from model 2. Use the configured `channel`
# instead of repeating the literal so the whole notebook follows a single setting.
encoder = copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, channel).encoder)
decoder = copy.deepcopy(load_deep_jscc(model2_fp, val_snr, c, channel).decoder)

# Seed before building the loaders / caching the alignment dataset. Every training cell later indexes
# `data` through `torch.randperm(len(data))`, so the cached sample order must itself be reproducible.
# NOTE(review): assumes get_data_loaders / load_alignment_dataset draw from RNGs that set_seed controls — confirm.
set_seed(seed)
train_loader, test_loader = get_data_loaders(dataset, resolution, batch_size, num_workers)
data = load_alignment_dataset(model1_fp, model2_fp, train_snr, train_loader, c, device)

Caching inputs: 100%|██████████| 782/782 [00:06<00:00, 129.04it/s]


# Baselines: Unaligned, No Mismatch, and Zero-shot (max size)

In [3]:
# Lower bound: mismatched encoder (model 1) / decoder (model 2) with no aligner in between.
model = AlignedDeepJSCC(encoder, decoder, None, val_snr, channel)
# Seed right before evaluating so the number is reproducible. Assumes validation_vectorized samples its
# channel noise from RNGs that set_seed controls — confirm.
set_seed(seed)
print(f"Unaligned {validation_vectorized(model, test_loader, times, device):.2f}")

Unaligned 5.39


In [4]:
# Upper bound ("no mismatch"): encoder and decoder both come from model 1, with no aligner.
model = AlignedDeepJSCC(
    encoder,
    copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, channel).decoder),
    None,
    val_snr,
    channel,
)
# Seed after model construction and before evaluation, so only the validation RNG draws are pinned.
set_seed(seed)
print(f"Aligned {validation_vectorized(model, test_loader, times, device):.2f}")

Aligned 48.50


In [None]:
data.flat = True  # the zero-shot fitter is fed the flat view of the cached dataset

set_seed(seed)
permutation = torch.randperm(len(data))

# "Max size" zero-shot aligner. Both size arguments are set to resolution**2; the sweep further down
# passes n_samples in those same two positions.
aligner = train_zeroshot_aligner(data, permutation, resolution**2, train_snr, resolution**2, device)
aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

# Evaluate the freshly aligned model. The previous version passed the stale `model` left over from the
# baseline cells, so the reported number was not the zero-shot result at all.
set_seed(seed)
print(f"Zeroshot max size {validation_vectorized(aligned_model, test_loader, times, device):.2f}")

# Least Squares

In [5]:
# Least-squares sweep: fit one linear aligner per pilot count and checkpoint it to disk. The evaluation
# cell below reloads these files by the same path pattern.
data.flat = True  # this solver is fed the flat view of the cached dataset

# Re-seeding before randperm gives every training cell the same ordering of `data`,
# assuming set_seed seeds torch's global RNG.
set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in samples_sets:
    aligner = train_linear_aligner(
        data,
        permutation,
        n_samples,
        train_snr,
    )
    aligner_fp = (
        f'alignment/models/plots/{folder}/aligner_linear_{n_samples}.pth'
    )
    torch.save(aligner.state_dict(), aligner_fp)
    print(f"Done with {n_samples}.")

Done with 1.
Done with 2.
Done with 3.
Done with 4.
Done with 5.
Done with 6.
Done with 7.
Done with 8.
Done with 9.
Done with 10.
Done with 11.
Done with 12.
Done with 13.
Done with 14.
Done with 16.
Done with 17.
Done with 19.
Done with 21.
Done with 23.
Done with 25.
Done with 28.
Done with 31.
Done with 34.
Done with 37.
Done with 41.
Done with 45.
Done with 49.
Done with 54.
Done with 59.
Done with 65.
Done with 72.
Done with 79.
Done with 86.
Done with 95.
Done with 104.
Done with 114.
Done with 126.
Done with 138.
Done with 151.
Done with 166.
Done with 183.
Done with 200.
Done with 220.
Done with 242.
Done with 265.
Done with 291.
Done with 319.
Done with 351.
Done with 385.
Done with 422.
Done with 464.
Done with 509.
Done with 559.
Done with 613.
Done with 673.
Done with 739.
Done with 811.
Done with 890.
Done with 977.
Done with 1072.
Done with 1176.
Done with 1291.
Done with 1417.
Done with 1555.
Done with 1707.
Done with 1873.
Done with 2056.
Done with 2257.
Done with 2477

In [6]:
set_seed(seed)
# A single module instance is reused: each checkpoint's state_dict overwrites its weights in turn.
aligner = _LinearAlignment(resolution**2)

for n_samples in samples_sets:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_linear_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    # Re-seed per point so each pilot count is evaluated from the same RNG state. Otherwise a point's
    # PSNR depends on how many evaluations ran before it in this loop.
    # (Assumes validation_vectorized draws channel noise from RNGs that set_seed controls — confirm.)
    set_seed(seed)
    print(f"Linear model, {n_samples} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times, device):.2f}")

Linear model, 1 samples got a PSNR of 11.91
Linear model, 2 samples got a PSNR of 12.02
Linear model, 3 samples got a PSNR of 12.07
Linear model, 4 samples got a PSNR of 12.22
Linear model, 5 samples got a PSNR of 12.63
Linear model, 6 samples got a PSNR of 12.72
Linear model, 7 samples got a PSNR of 12.79
Linear model, 8 samples got a PSNR of 12.84
Linear model, 9 samples got a PSNR of 13.27
Linear model, 10 samples got a PSNR of 13.25
Linear model, 11 samples got a PSNR of 13.53
Linear model, 12 samples got a PSNR of 13.63
Linear model, 13 samples got a PSNR of 13.68
Linear model, 14 samples got a PSNR of 13.76
Linear model, 16 samples got a PSNR of 13.94
Linear model, 17 samples got a PSNR of 14.00
Linear model, 19 samples got a PSNR of 14.04
Linear model, 21 samples got a PSNR of 14.40
Linear model, 23 samples got a PSNR of 14.56
Linear model, 25 samples got a PSNR of 14.66
Linear model, 28 samples got a PSNR of 14.85
Linear model, 31 samples got a PSNR of 14.99
Linear model, 34 sa

# Linear Neural

In [None]:
# Gradient-trained linear aligner: one model per pilot count. The trainer also returns how many
# epochs it ran, which is logged alongside the checkpoint.
data.flat = False  # this trainer is fed the non-flat view of the cached dataset

set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in samples_sets:
    aligner, epoch = train_neural_aligner(
        data,
        permutation,
        n_samples,
        batch_size,
        resolution,
        6,  # NOTE(review): unnamed magic argument — give it a config-cell constant once its meaning is confirmed
        train_snr,
        device,
    )
    aligner_fp = (
        f'alignment/models/plots/{folder}/aligner_neural_{n_samples}.pth'
    )
    torch.save(aligner.state_dict(), aligner_fp)
    print(f"Done with {n_samples}. Trained for {epoch} epochs.")

Done with 1. Trained for 18 epochs.
Done with 2. Trained for 24 epochs.
Done with 3. Trained for 31 epochs.
Done with 4. Trained for 45 epochs.
Done with 5. Trained for 38 epochs.
Done with 6. Trained for 67 epochs.
Done with 7. Trained for 84 epochs.
Done with 8. Trained for 44 epochs.
Done with 9. Trained for 157 epochs.
Done with 10. Trained for 105 epochs.
Done with 11. Trained for 146 epochs.
Done with 12. Trained for 103 epochs.
Done with 13. Trained for 102 epochs.
Done with 14. Trained for 12 epochs.
Done with 16. Trained for 12 epochs.
Done with 17. Trained for 12 epochs.
Done with 19. Trained for 12 epochs.
Done with 21. Trained for 12 epochs.
Done with 23. Trained for 12 epochs.
Done with 25. Trained for 12 epochs.
Done with 28. Trained for 424 epochs.
Done with 31. Trained for 599 epochs.
Done with 34. Trained for 505 epochs.
Done with 37. Trained for 614 epochs.
Done with 41. Trained for 667 epochs.
Done with 45. Trained for 723 epochs.
Done with 49. Trained for 687 epochs

: 

In [None]:
set_seed(seed)
# Neural-aligner checkpoints are loaded into a _LinearAlignment, so train_neural_aligner presumably
# returns that module type. The recorded run loaded the 10000-sample checkpoint without error.
aligner = _LinearAlignment(resolution**2)

# Only the largest pilot count is evaluated here.
for n_samples in [10000]:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_neural_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    # Re-seed per point so the result does not depend on RNG state left behind by earlier evaluations
    # (assumes validation_vectorized's noise comes from RNGs that set_seed controls).
    set_seed(seed)
    print(f"Neural model, {n_samples} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times, device):.2f}")

Neural model, 10000 samples got a PSNR of 23.24


# MLP

In [8]:
# MLP aligner sweep, checkpointing one model per pilot count.
data.flat = False  # this trainer is fed the non-flat view of the cached dataset

set_seed(seed)
permutation = torch.randperm(len(data))

# NOTE(review): this resumes from index 79 only. The evaluation cell below iterates over the full
# samples_sets, so it relies on checkpoints left on disk by an earlier run for the first 79 pilot
# counts, and will not reproduce under Restart & Run All.
for n_samples in samples_sets[79:]:
    aligner, epoch = train_mlp_aligner(
        data,
        permutation,
        n_samples,
        batch_size,
        resolution,
        6,  # NOTE(review): same unnamed magic argument as the neural sweep — confirm its meaning
        train_snr,
        device,
    )
    aligner_fp = (
        f'alignment/models/plots/{folder}/aligner_mlp_{n_samples}.pth'
    )
    torch.save(aligner.state_dict(), aligner_fp)
    print(f"Done with {n_samples}. Trained for {epoch} epochs.")

Done with 6892. Trained for 63 epochs.
Done with 7564. Trained for 38 epochs.
Done with 8302. Trained for 91 epochs.
Done with 9111. Trained for 46 epochs.
Done with 10000. Trained for 45 epochs.


In [9]:
set_seed(seed)
# The architecture must match what train_mlp_aligner saved, because load_state_dict is strict by default.
aligner = _MLPAlignment(input_dim=resolution**2, hidden_dims=[resolution**2])

for n_samples in samples_sets:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_mlp_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    # Re-seed per point so every pilot count is evaluated from the same RNG state
    # (assumes validation_vectorized's noise comes from RNGs that set_seed controls).
    set_seed(seed)
    print(f"MLP model, {n_samples} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times, device):.2f}")

MLP model, 1 samples got a PSNR of 11.55
MLP model, 2 samples got a PSNR of 12.05
MLP model, 3 samples got a PSNR of 12.53
MLP model, 4 samples got a PSNR of 12.50
MLP model, 5 samples got a PSNR of 12.69
MLP model, 6 samples got a PSNR of 13.42
MLP model, 7 samples got a PSNR of 13.11
MLP model, 8 samples got a PSNR of 13.30
MLP model, 9 samples got a PSNR of 13.63
MLP model, 10 samples got a PSNR of 13.89
MLP model, 11 samples got a PSNR of 13.67
MLP model, 12 samples got a PSNR of 13.82
MLP model, 13 samples got a PSNR of 14.23
MLP model, 14 samples got a PSNR of 14.57
MLP model, 16 samples got a PSNR of 14.82
MLP model, 17 samples got a PSNR of 15.12
MLP model, 19 samples got a PSNR of 15.29
MLP model, 21 samples got a PSNR of 15.21
MLP model, 23 samples got a PSNR of 15.30
MLP model, 25 samples got a PSNR of 15.61
MLP model, 28 samples got a PSNR of 15.55
MLP model, 31 samples got a PSNR of 15.64
MLP model, 34 samples got a PSNR of 15.80
MLP model, 37 samples got a PSNR of 15.93
M

# Convolutional

In [7]:
# Convolutional aligner sweep, checkpointing one model per pilot count.
data.flat = False  # this trainer is fed the non-flat view of the cached dataset

set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in samples_sets:
    aligner, epoch = train_conv_aligner(
        data,
        permutation,
        n_samples,
        c,
        batch_size,
        train_snr,
        device,
    )
    aligner_fp = (
        f'alignment/models/plots/{folder}/aligner_conv_{n_samples}.pth'
    )
    torch.save(aligner.state_dict(), aligner_fp)
    print(f"Done with {n_samples}. Trained for {epoch} epochs.")

Done with 1. Trained for 521 epochs.
Done with 2. Trained for 610 epochs.
Done with 3. Trained for 507 epochs.
Done with 4. Trained for 527 epochs.
Done with 5. Trained for 686 epochs.
Done with 6. Trained for 641 epochs.
Done with 7. Trained for 758 epochs.
Done with 8. Trained for 763 epochs.
Done with 9. Trained for 827 epochs.
Done with 10. Trained for 881 epochs.
Done with 11. Trained for 704 epochs.
Done with 12. Trained for 686 epochs.
Done with 13. Trained for 1173 epochs.
Done with 14. Trained for 829 epochs.
Done with 16. Trained for 1070 epochs.
Done with 17. Trained for 830 epochs.
Done with 19. Trained for 1045 epochs.
Done with 21. Trained for 1029 epochs.
Done with 23. Trained for 875 epochs.
Done with 25. Trained for 944 epochs.
Done with 28. Trained for 815 epochs.
Done with 31. Trained for 1073 epochs.
Done with 34. Trained for 878 epochs.
Done with 37. Trained for 1134 epochs.
Done with 41. Trained for 1242 epochs.
Done with 45. Trained for 1147 epochs.
Done with 49.

In [8]:
set_seed(seed)
# Channel counts and kernel size must match what train_conv_aligner saved (load_state_dict is strict by default).
aligner = _ConvolutionalAlignment(in_channels=2*c, out_channels=2*c, kernel_size=5)

for n_samples in samples_sets:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_conv_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    # Re-seed per point so every pilot count is evaluated from the same RNG state
    # (assumes validation_vectorized's noise comes from RNGs that set_seed controls).
    set_seed(seed)
    print(f"Conv model, {n_samples} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times, device):.2f}")

Conv model, 1 samples got a PSNR of 23.19
Conv model, 2 samples got a PSNR of 33.89
Conv model, 3 samples got a PSNR of 35.18
Conv model, 4 samples got a PSNR of 36.29
Conv model, 5 samples got a PSNR of 37.16
Conv model, 6 samples got a PSNR of 37.53
Conv model, 7 samples got a PSNR of 38.59
Conv model, 8 samples got a PSNR of 38.29
Conv model, 9 samples got a PSNR of 37.73
Conv model, 10 samples got a PSNR of 38.61
Conv model, 11 samples got a PSNR of 38.64
Conv model, 12 samples got a PSNR of 38.68
Conv model, 13 samples got a PSNR of 40.05
Conv model, 14 samples got a PSNR of 39.49
Conv model, 16 samples got a PSNR of 39.78
Conv model, 17 samples got a PSNR of 39.39
Conv model, 19 samples got a PSNR of 39.65
Conv model, 21 samples got a PSNR of 39.40
Conv model, 23 samples got a PSNR of 39.20
Conv model, 25 samples got a PSNR of 39.47
Conv model, 28 samples got a PSNR of 39.25
Conv model, 31 samples got a PSNR of 39.68
Conv model, 34 samples got a PSNR of 39.82
Conv model, 37 sampl

# Two Conv

In [None]:
# Two-layer convolutional aligner sweep, checkpointing one model per pilot count.
data.flat = False  # this trainer is fed the non-flat view of the cached dataset

set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in samples_sets:
    aligner, epoch = train_twoconv_aligner(
        data,
        permutation,
        n_samples,
        c,
        batch_size,
        train_snr,
        device,
    )
    aligner_fp = (
        f'alignment/models/plots/{folder}/aligner_twoconv_{n_samples}.pth'
    )
    torch.save(aligner.state_dict(), aligner_fp)
    print(f"Done with {n_samples}. Trained for {epoch} epochs.")

Done with 10. Trained for 382 epochs.


In [None]:
set_seed(seed)
# The architecture must match what train_twoconv_aligner saved (load_state_dict is strict by default).
aligner = _TwoConvAlignment(in_channels=2*c, hidden_channels=2*c, out_channels=2*c, kernel_size=5)

for n_samples in samples_sets:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_twoconv_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    # Re-seed per point so every pilot count is evaluated from the same RNG state
    # (assumes validation_vectorized's noise comes from RNGs that set_seed controls).
    set_seed(seed)
    # The label previously read "Twpconv" (typo).
    print(f"Twoconv model, {n_samples} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times, device):.2f}")

Twpconv model, 10 samples got a PSNR of 27.64


# Zero-shot

In [9]:
# Zero-shot aligner sweep. The same value, n_samples, is passed both as the pilot count and as the
# second size argument.
data.flat = True  # the zero-shot fitter is fed the flat view of the cached dataset

set_seed(seed)
permutation = torch.randperm(len(data))

# samples_sets[1:] skips the smallest pilot count (1). Presumably the zero-shot fit needs at least
# 2 samples — confirm.
for n_samples in samples_sets[1:]:
    aligner = train_zeroshot_aligner(
        data,
        permutation,
        n_samples,
        train_snr,
        n_samples,
        device,
    )
    aligner_fp = (
        f'alignment/models/plots/{folder}/aligner_zeroshot_{n_samples}.pth'
    )
    torch.save(aligner.state_dict(), aligner_fp)
    print(f"Done with {n_samples}.")

Done with 2.
Done with 3.
Done with 4.
Done with 5.
Done with 6.
Done with 7.
Done with 8.
Done with 9.
Done with 10.
Done with 11.
Done with 12.
Done with 13.
Done with 14.
Done with 16.
Done with 17.
Done with 19.
Done with 21.
Done with 23.
Done with 25.
Done with 28.
Done with 31.
Done with 34.
Done with 37.
Done with 41.
Done with 45.
Done with 49.
Done with 54.
Done with 59.
Done with 65.
Done with 72.
Done with 79.
Done with 86.
Done with 95.
Done with 104.
Done with 114.
Done with 126.
Done with 138.
Done with 151.
Done with 166.
Done with 183.
Done with 200.
Done with 220.
Done with 242.
Done with 265.
Done with 291.
Done with 319.
Done with 351.
Done with 385.
Done with 422.
Done with 464.
Done with 509.
Done with 559.
Done with 613.
Done with 673.
Done with 739.
Done with 811.
Done with 890.
Done with 977.
Done with 1072.
Done with 1176.
Done with 1291.
Done with 1417.
Done with 1555.
Done with 1707.
Done with 1873.
Done with 2056.
Done with 2257.
Done with 2477.
Done with 2

In [10]:
set_seed(seed)

for n_samples in samples_sets[1:]:

    # Zero-filled placeholder tensors, sized per pilot count; load_state_dict below replaces them with
    # the stored values. NOTE(review): G gets a (1, 1) placeholder while the other tensors are sized by
    # n_samples / resolution**2. Presumably _ZeroShotAlignment does not register G with a fixed shape — confirm.
    aligner = _ZeroShotAlignment(
        F_tilde=torch.zeros(n_samples, resolution**2),
        G_tilde=torch.zeros(resolution**2, n_samples),
        G=torch.zeros(1, 1),
        L=torch.zeros(n_samples, n_samples),
        mean=torch.zeros(n_samples, 1)
    )

    aligner_fp = f'alignment/models/plots/{folder}/aligner_zeroshot_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    # Re-seed per point so every pilot count is evaluated from the same RNG state
    # (assumes validation_vectorized's noise comes from RNGs that set_seed controls).
    set_seed(seed)
    print(f"Zeroshot model, {n_samples} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times, device):.2f}")

Zeroshot model, 2 samples got a PSNR of 11.23
Zeroshot model, 3 samples got a PSNR of 11.66
Zeroshot model, 4 samples got a PSNR of 12.17
Zeroshot model, 5 samples got a PSNR of 12.97
Zeroshot model, 6 samples got a PSNR of 13.13
Zeroshot model, 7 samples got a PSNR of 13.40
Zeroshot model, 8 samples got a PSNR of 13.53
Zeroshot model, 9 samples got a PSNR of 14.24
Zeroshot model, 10 samples got a PSNR of 14.29
Zeroshot model, 11 samples got a PSNR of 14.90
Zeroshot model, 12 samples got a PSNR of 15.14
Zeroshot model, 13 samples got a PSNR of 15.20
Zeroshot model, 14 samples got a PSNR of 15.26
Zeroshot model, 16 samples got a PSNR of 15.43
Zeroshot model, 17 samples got a PSNR of 15.52
Zeroshot model, 19 samples got a PSNR of 15.57
Zeroshot model, 21 samples got a PSNR of 15.93
Zeroshot model, 23 samples got a PSNR of 15.88
Zeroshot model, 25 samples got a PSNR of 16.30
Zeroshot model, 28 samples got a PSNR of 16.55
Zeroshot model, 31 samples got a PSNR of 16.73
Zeroshot model, 34 sa