In [None]:
%cd ..

import os
import torch
import copy
import numpy as np

from alignment.alignment_utils import load_deep_jscc
from alignment.alignment_model import *
from alignment.alignment_training import *
from alignment.alignment_validation import *

/home/lorenzo/repos/Deep-JSCC-PyTorch


In [2]:
# --- Experiment configuration -------------------------------------------------
# Checkpoints of the two Deep-JSCC autoencoders. The encoder is taken from the
# first one and the decoder from the second one (see below).
# NOTE(review): file names suggest both were trained at SNR 0 with seeds 42/43 — confirm.
model1_fp = r'alignment/models/autoencoders/snr_0_seed_42.pkl'
model2_fp = r'alignment/models/autoencoders/snr_0_seed_43.pkl'

# Sub-folder of alignment/models/plots/ where every trained aligner is stored.
folder = r'psnr_vs_pilots_5'
os.makedirs(f'alignment/models/plots/{folder}', exist_ok=True)

# Data pipeline settings forwarded to get_data_loaders.
dataset = "cifar10"
resolution = 96
channel = 'AWGN'
batch_size = 64
num_workers = 4

train_snr = 0   # SNR value handed to the aligner training helpers
val_snr = 0     # SNR value handed to the models used for validation
times = 10      # third argument of validation_vectorized — presumably the number of repetitions, TODO confirm
c = 8           # forwarded to load_deep_jscc / load_alignment_dataset — presumably the latent channel count, TODO confirm
seed = 42

# Distinct integer sample budgets, log-spaced between 10**0 and 10**4.
# astype(int) truncates, and np.unique removes the duplicates that truncation creates.
samples_sets = np.unique(np.logspace(0, np.log10(10000), num=100, base=10).astype(int))
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

# Use the `channel` setting defined above rather than repeating the literal, so
# the channel type is configured in a single place.
encoder = copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, channel).encoder)
decoder = copy.deepcopy(load_deep_jscc(model2_fp, val_snr, c, channel).decoder)

train_loader, test_loader = get_data_loaders(dataset, resolution, batch_size, num_workers)
data = load_alignment_dataset(model1_fp, model2_fp, train_snr, train_loader, c, device)

Caching inputs: 100%|██████████| 782/782 [00:04<00:00, 174.71it/s]


# Baselines: unaligned (mismatched encoder/decoder) and no mismatch

In [None]:
# Baseline: encoder from model 1 and decoder from model 2, with no aligner (None)
# in between. Uses the shared `channel` setting instead of a repeated literal.
model = AlignedDeepJSCC(encoder, decoder, None, val_snr, channel)
# `device` is passed explicitly, as in the other evaluation cells of this notebook.
print(f"Unaligned {validation_vectorized(model, test_loader, times, device):.2f}")

In [None]:
# Reference: the decoder is reloaded from the same checkpoint (model1_fp) the
# encoder comes from, so encoder and decoder match and no aligner (None) is used.
# NOTE(review): the printed label says "Aligned" although no aligner is involved — confirm the wording.
model = AlignedDeepJSCC(
    encoder,
    copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, channel).decoder),
    None,
    val_snr,
    channel,
)
# `device` is passed explicitly, as in the other evaluation cells of this notebook.
print(f"Aligned {validation_vectorized(model, test_loader, times, device):.2f}")

# Least Squares

In [4]:
# Fit one linear (least-squares) aligner per sample budget and save its weights.
data.flat = True

# set_seed runs right before the permutation is drawn — presumably so that the
# sample order is reproducible across runs (TODO confirm what set_seed seeds).
set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in samples_sets:
    # Destination file for this budget.
    aligner_fp = f'alignment/models/plots/{folder}/aligner_linear_{n_samples}.pth'

    aligner = train_linear_aligner(data, permutation, n_samples, train_snr)
    torch.save(obj=aligner.state_dict(), f=aligner_fp)

    print(f"Done with {n_samples}.")

Done with 1.
Done with 2.
Done with 3.
Done with 4.
Done with 5.
Done with 6.
Done with 7.
Done with 8.
Done with 9.
Done with 10.
Done with 11.
Done with 12.
Done with 13.
Done with 14.
Done with 16.
Done with 17.
Done with 19.
Done with 21.
Done with 23.
Done with 25.
Done with 28.
Done with 31.
Done with 34.
Done with 37.
Done with 41.
Done with 45.
Done with 49.
Done with 54.
Done with 59.
Done with 65.
Done with 72.
Done with 79.
Done with 86.
Done with 95.
Done with 104.
Done with 114.
Done with 126.
Done with 138.
Done with 151.
Done with 166.
Done with 183.
Done with 200.
Done with 220.
Done with 242.
Done with 265.
Done with 291.
Done with 319.
Done with 351.
Done with 385.
Done with 422.
Done with 464.
Done with 509.
Done with 559.
Done with 613.
Done with 673.
Done with 739.
Done with 811.
Done with 890.
Done with 977.
Done with 1072.
Done with 1176.
Done with 1291.
Done with 1417.
Done with 1555.
Done with 1707.
Done with 1873.
Done with 2056.
Done with 2257.
Done with 2477

In [6]:
# Evaluate every saved linear aligner placed between encoder and decoder.
set_seed(seed)
# Placeholder instance; the actual weights are loaded from disk inside the loop.
aligner = _LinearAlignment(None, None)

for n_samples in samples_sets:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_linear_{n_samples}.pth'
    state = torch.load(aligner_fp, map_location=device)
    aligner.load_state_dict(state)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
    psnr = validation_vectorized(aligned_model, test_loader, times, device)

    print(f"Linear model, {n_samples} samples got a PSNR of {psnr:.2f}")

Linear model, 1 samples got a PSNR of 10.05
Linear model, 2 samples got a PSNR of 10.40
Linear model, 3 samples got a PSNR of 10.76
Linear model, 4 samples got a PSNR of 11.07
Linear model, 5 samples got a PSNR of 11.00
Linear model, 6 samples got a PSNR of 11.16
Linear model, 7 samples got a PSNR of 11.16
Linear model, 8 samples got a PSNR of 11.19
Linear model, 9 samples got a PSNR of 11.15
Linear model, 10 samples got a PSNR of 11.32
Linear model, 11 samples got a PSNR of 11.28
Linear model, 12 samples got a PSNR of 11.45
Linear model, 13 samples got a PSNR of 11.54
Linear model, 14 samples got a PSNR of 11.54
Linear model, 16 samples got a PSNR of 11.61
Linear model, 17 samples got a PSNR of 11.63
Linear model, 19 samples got a PSNR of 11.62
Linear model, 21 samples got a PSNR of 11.71
Linear model, 23 samples got a PSNR of 11.90
Linear model, 25 samples got a PSNR of 11.91
Linear model, 28 samples got a PSNR of 12.06
Linear model, 31 samples got a PSNR of 12.23
Linear model, 34 sa

# Linear Neural

In [None]:
# Train one neural aligner per sample budget and save its weights.
data.flat = False

# set_seed runs right before the permutation is drawn — presumably so that the
# sample order is reproducible across runs (TODO confirm what set_seed seeds).
set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in samples_sets:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_neural_{n_samples}.pth'

    # NOTE(review): the meaning of the literal 6 is not visible here — confirm
    # against train_neural_aligner's signature.
    aligner, epoch = train_neural_aligner(
        data, permutation, n_samples, batch_size, resolution, 6, train_snr, device
    )
    torch.save(obj=aligner.state_dict(), f=aligner_fp)

    print(f"Done with {n_samples}. Trained for {epoch} epochs.")

Done with 1. Trained for 18 epochs.
Done with 2. Trained for 24 epochs.
Done with 3. Trained for 31 epochs.
Done with 4. Trained for 45 epochs.
Done with 5. Trained for 38 epochs.
Done with 6. Trained for 67 epochs.
Done with 7. Trained for 84 epochs.
Done with 8. Trained for 44 epochs.
Done with 9. Trained for 157 epochs.
Done with 10. Trained for 105 epochs.
Done with 11. Trained for 146 epochs.
Done with 12. Trained for 103 epochs.
Done with 13. Trained for 102 epochs.
Done with 14. Trained for 12 epochs.
Done with 16. Trained for 12 epochs.
Done with 17. Trained for 12 epochs.
Done with 19. Trained for 12 epochs.
Done with 21. Trained for 12 epochs.
Done with 23. Trained for 12 epochs.
Done with 25. Trained for 12 epochs.
Done with 28. Trained for 424 epochs.
Done with 31. Trained for 599 epochs.
Done with 34. Trained for 505 epochs.
Done with 37. Trained for 614 epochs.
Done with 41. Trained for 667 epochs.
Done with 45. Trained for 723 epochs.
Done with 49. Trained for 687 epochs

In [None]:
# Evaluate every saved neural aligner placed between encoder and decoder.
set_seed(seed)
# Placeholder instance; the actual weights are loaded from disk inside the loop.
aligner = _LinearAlignment(None, None)

for n_samples in samples_sets:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_neural_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    # Use the shared `channel` setting (as the linear evaluation cell does)
    # instead of repeating the "AWGN" literal.
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    print(f"Neural model, {n_samples} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times, device):.2f}")

Neural model, 1 samples got a PSNR of 11.09
Neural model, 2 samples got a PSNR of 11.03
Neural model, 3 samples got a PSNR of 11.02
Neural model, 4 samples got a PSNR of 11.11
Neural model, 5 samples got a PSNR of 11.10
Neural model, 6 samples got a PSNR of 11.10
Neural model, 7 samples got a PSNR of 11.30
Neural model, 8 samples got a PSNR of 11.24
Neural model, 9 samples got a PSNR of 11.24
Neural model, 10 samples got a PSNR of 11.38
Neural model, 11 samples got a PSNR of 11.42
Neural model, 12 samples got a PSNR of 11.51
Neural model, 13 samples got a PSNR of 11.70
Neural model, 14 samples got a PSNR of 11.77
Neural model, 16 samples got a PSNR of 11.83
Neural model, 17 samples got a PSNR of 11.87
Neural model, 19 samples got a PSNR of 11.95
Neural model, 21 samples got a PSNR of 11.96
Neural model, 23 samples got a PSNR of 11.97
Neural model, 25 samples got a PSNR of 12.01
Neural model, 28 samples got a PSNR of 12.07
Neural model, 31 samples got a PSNR of 12.13
Neural model, 34 sa

# MLP

In [None]:
# Train one MLP aligner per sample budget and save its weights.
data.flat = False

# set_seed runs right before the permutation is drawn — presumably so that the
# sample order is reproducible across runs (TODO confirm what set_seed seeds).
set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in samples_sets:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_mlp_{n_samples}.pth'

    # NOTE(review): the meaning of the literal 6 is not visible here — confirm
    # against train_mlp_aligner's signature.
    aligner, epoch = train_mlp_aligner(
        data, permutation, n_samples, batch_size, resolution, 6, train_snr, device
    )
    torch.save(obj=aligner.state_dict(), f=aligner_fp)

    print(f"Done with {n_samples}. Trained for {epoch} epochs.")

Done with 1. Trained for 43 epochs.
Done with 2. Trained for 33 epochs.
Done with 3. Trained for 65 epochs.
Done with 4. Trained for 69 epochs.
Done with 5. Trained for 42 epochs.
Done with 6. Trained for 121 epochs.
Done with 7. Trained for 125 epochs.
Done with 8. Trained for 64 epochs.
Done with 9. Trained for 43 epochs.
Done with 10. Trained for 74 epochs.
Done with 11. Trained for 86 epochs.
Done with 12. Trained for 96 epochs.
Done with 13. Trained for 39 epochs.
Done with 14. Trained for 103 epochs.
Done with 16. Trained for 51 epochs.
Done with 17. Trained for 155 epochs.
Done with 19. Trained for 50 epochs.
Done with 21. Trained for 68 epochs.
Done with 23. Trained for 120 epochs.
Done with 25. Trained for 149 epochs.
Done with 28. Trained for 109 epochs.
Done with 31. Trained for 100 epochs.
Done with 34. Trained for 149 epochs.
Done with 37. Trained for 173 epochs.
Done with 41. Trained for 89 epochs.
Done with 45. Trained for 121 epochs.
Done with 49. Trained for 144 epochs

In [None]:
# Evaluate every saved MLP aligner placed between encoder and decoder.
set_seed(seed)
# Placeholder instance; the actual weights are loaded from disk inside the loop.
# The call is now properly closed: the missing ")" made the whole cell a SyntaxError.
# NOTE(review): confirm that load_state_dict accepts checkpoints whose shapes
# differ from these placeholder dimensions.
aligner = _MLPAlignment(input_dim=1, hidden_dims=[1])

for n_samples in samples_sets:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_mlp_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    # Use the shared `channel` setting (as the linear evaluation cell does)
    # instead of repeating the "AWGN" literal.
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    print(f"MLP model, {n_samples} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times, device):.2f}")

MLP model, 1 samples got a PSNR of 10.67
MLP model, 2 samples got a PSNR of 10.97
MLP model, 3 samples got a PSNR of 10.43
MLP model, 4 samples got a PSNR of 10.24


KeyboardInterrupt: 

# Convolutional

In [None]:
# Train one convolutional aligner per sample budget and save its weights.
data.flat = False

# set_seed runs right before the permutation is drawn — presumably so that the
# sample order is reproducible across runs (TODO confirm what set_seed seeds).
set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in samples_sets:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_conv_{n_samples}.pth'

    aligner, epoch = train_conv_aligner(
        data, permutation, n_samples, c, batch_size, train_snr, device
    )
    torch.save(obj=aligner.state_dict(), f=aligner_fp)

    print(f"Done with {n_samples}. Trained for {epoch} epochs.")

Done with 1. Trained for 148 epochs.
Done with 2. Trained for 190 epochs.
Done with 3. Trained for 155 epochs.
Done with 4. Trained for 172 epochs.
Done with 5. Trained for 196 epochs.
Done with 6. Trained for 140 epochs.
Done with 7. Trained for 169 epochs.
Done with 8. Trained for 190 epochs.
Done with 9. Trained for 199 epochs.
Done with 10. Trained for 184 epochs.
Done with 11. Trained for 200 epochs.
Done with 12. Trained for 187 epochs.
Done with 13. Trained for 234 epochs.
Done with 14. Trained for 172 epochs.
Done with 16. Trained for 272 epochs.
Done with 17. Trained for 195 epochs.
Done with 19. Trained for 243 epochs.
Done with 21. Trained for 222 epochs.
Done with 23. Trained for 227 epochs.
Done with 25. Trained for 233 epochs.
Done with 28. Trained for 295 epochs.
Done with 31. Trained for 200 epochs.
Done with 34. Trained for 255 epochs.
Done with 37. Trained for 203 epochs.
Done with 41. Trained for 234 epochs.
Done with 45. Trained for 234 epochs.
Done with 49. Trained

KeyboardInterrupt: 

In [None]:
# Evaluate every saved convolutional aligner placed between encoder and decoder.
set_seed(seed)
# Placeholder instance; the actual weights are loaded from disk inside the loop.
# NOTE(review): confirm that load_state_dict accepts checkpoints whose shapes
# differ from these placeholder dimensions.
aligner = _ConvolutionalAlignment(in_channels=1, out_channels=1, kernel_size=3)

for n_samples in samples_sets:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_conv_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    # Use the shared `channel` setting (as the linear evaluation cell does)
    # instead of repeating the "AWGN" literal.
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    print(f"Conv model, {n_samples} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times, device):.2f}")

Conv model, 1 samples got a PSNR of 16.88
Conv model, 2 samples got a PSNR of 25.24
Conv model, 3 samples got a PSNR of 26.49
Conv model, 4 samples got a PSNR of 27.47


KeyboardInterrupt: 

# Zero-shot

In [None]:
# Build one zero-shot aligner per sample budget and save its weights.
data.flat = True

# set_seed runs right before the permutation is drawn — presumably so that the
# sample order is reproducible across runs (TODO confirm what set_seed seeds).
set_seed(seed)
permutation = torch.randperm(len(data))

# The first entry of samples_sets is skipped here.
# NOTE(review): presumably the zero-shot aligner needs more than one sample — confirm.
for n_samples in samples_sets[1:]:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_zeroshot_{n_samples}.pth'

    # n_samples is passed twice (3rd and 5th positional arguments); the role of
    # the 5th argument is not visible here — TODO confirm.
    aligner = train_zeroshot_aligner(
        data, permutation, n_samples, train_snr, n_samples, device
    )
    torch.save(obj=aligner.state_dict(), f=aligner_fp)

    print(f"Done with {n_samples}.")

Computing model outputs: 100%|██████████| 782/782 [00:08<00:00, 91.04it/s] 


Done with 2.
Done with 3.
Done with 4.
Done with 5.
Done with 6.
Done with 7.
Done with 8.
Done with 9.
Done with 10.
Done with 11.
Done with 12.
Done with 13.
Done with 14.
Done with 16.
Done with 17.
Done with 19.
Done with 21.
Done with 23.
Done with 25.
Done with 28.
Done with 31.
Done with 34.
Done with 37.
Done with 41.
Done with 45.
Done with 49.
Done with 54.
Done with 59.
Done with 65.
Done with 72.
Done with 79.
Done with 86.
Done with 95.
Done with 104.
Done with 114.
Done with 126.
Done with 138.
Done with 151.
Done with 166.
Done with 183.
Done with 200.
Done with 220.
Done with 242.
Done with 265.
Done with 291.
Done with 319.
Done with 351.
Done with 385.
Done with 422.
Done with 464.
Done with 509.
Done with 559.
Done with 613.
Done with 673.
Done with 739.
Done with 811.
Done with 890.
Done with 977.
Done with 1072.
Done with 1176.
Done with 1291.
Done with 1417.
Done with 1555.
Done with 1707.
Done with 1873.
Done with 2056.
Done with 2257.
Done with 2477.
Done with 2

In [None]:
# Evaluate every saved zero-shot aligner placed between encoder and decoder.
set_seed(seed)
# Placeholder instance built from 1x1 zero tensors; the actual values are loaded
# from disk inside the loop.
# NOTE(review): confirm that load_state_dict accepts checkpoints whose shapes
# differ from these placeholders.
aligner = _ZeroShotAlignment(
    F_tilde=torch.zeros(1, 1),
    G_tilde=torch.zeros(1, 1),
    G=torch.zeros(1, 1),
    L=torch.zeros(1, 1),
    mean=torch.zeros(1, 1)
)

# Starts from the second entry of samples_sets, matching the budgets for which
# zero-shot aligners were saved by the training cell.
for n_samples in samples_sets[1:]:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_zeroshot_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    # Use the shared `channel` setting (as the linear evaluation cell does)
    # instead of repeating the "AWGN" literal.
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    print(f"Zeroshot model, {n_samples} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times, device):.2f}")

Zeroshot model, 2 samples got a PSNR of 9.91
Zeroshot model, 3 samples got a PSNR of 9.86
Zeroshot model, 4 samples got a PSNR of 9.90
Zeroshot model, 5 samples got a PSNR of 10.84
Zeroshot model, 6 samples got a PSNR of 10.95
Zeroshot model, 7 samples got a PSNR of 11.08
Zeroshot model, 8 samples got a PSNR of 11.31
Zeroshot model, 9 samples got a PSNR of 11.21
Zeroshot model, 10 samples got a PSNR of 11.23
Zeroshot model, 11 samples got a PSNR of 11.27
Zeroshot model, 12 samples got a PSNR of 11.27
Zeroshot model, 13 samples got a PSNR of 11.27
Zeroshot model, 14 samples got a PSNR of 11.48
Zeroshot model, 16 samples got a PSNR of 11.70
Zeroshot model, 17 samples got a PSNR of 11.85
Zeroshot model, 19 samples got a PSNR of 11.89
Zeroshot model, 21 samples got a PSNR of 12.10
Zeroshot model, 23 samples got a PSNR of 12.26
Zeroshot model, 25 samples got a PSNR of 12.31
Zeroshot model, 28 samples got a PSNR of 12.36
Zeroshot model, 31 samples got a PSNR of 12.49
Zeroshot model, 34 sampl