In [5]:
%cd ..

import os
import torch
import copy
import numpy as np

from alignment.alignment_utils import load_deep_jscc
from alignment.alignment_model import *
from alignment.alignment_model import _LinearAlignment, _MLPAlignment, _ConvolutionalAlignment, _ZeroShotAlignment
from alignment.alignment_training import *
from alignment.alignment_validation import *

/home/lorenzo/repos


In [None]:
# --- Experiment configuration ---
# Checkpoints of the two autoencoders being paired. The file names suggest
# they differ only in training seed (42 vs 43) — TODO confirm.
model1_fp = r'alignment/models/autoencoders/snr_30_seed_42.pkl'
model2_fp = r'alignment/models/autoencoders/snr_30_seed_43.pkl'

# Aligner checkpoints are saved to / loaded from alignment/models/plots/<folder>.
folder = r'psnr_vs_pilots_5'
os.makedirs(f'alignment/models/plots/{folder}', exist_ok=True)

# Data pipeline settings (forwarded to get_data_loaders).
dataset = "cifar10"
resolution = 96
channel = 'AWGN'
batch_size = 64
num_workers = 4

# train_snr is handed to dataset construction and aligner training;
# val_snr is handed to model loading and to AlignedDeepJSCC.
train_snr = 30
val_snr = 30
times = 10  # forwarded to validation_vectorized
c = 8       # forwarded to load_deep_jscc and load_alignment_dataset
seed = 42

# Log-spaced sample counts from 1 to 10000. The integer cast produces duplicates
# at the low end; np.unique removes them and returns the values sorted.
samples_sets = np.unique(np.logspace(0, np.log10(10000), num=100, base=10).astype(int))
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

# Encoder comes from model 1, decoder from model 2. The channel type is taken
# from `channel` rather than a repeated literal, so that changing the setting
# above actually reaches the loaded models.
encoder = copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, channel).encoder)
decoder = copy.deepcopy(load_deep_jscc(model2_fp, val_snr, c, channel).decoder)

train_loader, test_loader = get_data_loaders(dataset, resolution, batch_size, num_workers)
data = load_alignment_dataset(model1_fp, model2_fp, train_snr, train_loader, c, device)

FileNotFoundError: [Errno 2] No such file or directory: 'alignment/models/autoencoders/snr_0_seed_42.pkl'

# No mismatch - Unaligned

In [None]:
# Baseline: encoder and decoder taken from the two different checkpoints,
# with no aligner in between (None). The channel type comes from the shared
# `channel` setting instead of a hard-coded literal.
model = AlignedDeepJSCC(encoder, decoder, None, val_snr, channel)
unaligned_psnr = validation_vectorized(model, test_loader, times, device)
print(f"Unaligned {unaligned_psnr:.2f}")

Unaligned 10.80


In [None]:
# Reference point: pair the encoder with the decoder loaded from the same
# checkpoint (model1_fp), again with no aligner. The channel type comes from
# the shared `channel` setting instead of a hard-coded literal.
matched_decoder = copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, channel).decoder)
model = AlignedDeepJSCC(encoder, matched_decoder, None, val_snr, channel)
matched_psnr = validation_vectorized(model, test_loader, times, device)
print(f"Aligned {matched_psnr:.2f}")

Aligned 37.70


# Least Squares

In [None]:
# NOTE(review): flat=True presumably makes the dataset yield flattened
# samples — confirm in the alignment dataset class.
data.flat = True

# Seed first so the permutation is reproducible; the same permutation is
# passed to every training call below.
set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in samples_sets:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_linear_{n_samples}.pth'

    # Fit the aligner for this sample count and store its weights on disk.
    aligner = train_linear_aligner(data, permutation, n_samples, train_snr)
    torch.save(aligner.state_dict(), aligner_fp)

    print(f"Done with {n_samples}.")

Done with 1.
Done with 2.
Done with 3.
Done with 4.
Done with 5.
Done with 6.
Done with 7.
Done with 8.
Done with 9.
Done with 10.
Done with 11.
Done with 12.
Done with 13.
Done with 14.
Done with 16.
Done with 17.
Done with 19.
Done with 21.
Done with 23.
Done with 25.
Done with 28.
Done with 31.
Done with 34.
Done with 37.
Done with 41.
Done with 45.
Done with 49.
Done with 54.
Done with 59.
Done with 65.
Done with 72.
Done with 79.
Done with 86.
Done with 95.
Done with 104.
Done with 114.
Done with 126.
Done with 138.
Done with 151.
Done with 166.
Done with 183.
Done with 200.
Done with 220.
Done with 242.
Done with 265.
Done with 291.
Done with 319.
Done with 351.
Done with 385.
Done with 422.
Done with 464.
Done with 509.
Done with 559.
Done with 613.
Done with 673.
Done with 739.
Done with 811.
Done with 890.
Done with 977.
Done with 1072.
Done with 1176.
Done with 1291.
Done with 1417.
Done with 1555.
Done with 1707.
Done with 1873.
Done with 2056.
Done with 2257.
Done with 2477

In [None]:
set_seed(seed)

# A single module instance is reused; each iteration replaces its weights
# with the checkpoint saved for that sample count.
aligner = _LinearAlignment(resolution**2)

for n_samples in samples_sets:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_linear_{n_samples}.pth'
    aligner_state = torch.load(aligner_fp, map_location=device)
    aligner.load_state_dict(aligner_state)

    # Insert the aligner between the mismatched encoder and decoder and validate.
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
    psnr_value = validation_vectorized(aligned_model, test_loader, times, device)

    print(f"Linear model, {n_samples} samples got a PSNR of {psnr_value:.2f}")

Linear model, 1 samples got a PSNR of 10.05
Linear model, 2 samples got a PSNR of 10.40
Linear model, 3 samples got a PSNR of 10.76
Linear model, 4 samples got a PSNR of 11.07
Linear model, 5 samples got a PSNR of 11.00
Linear model, 6 samples got a PSNR of 11.16
Linear model, 7 samples got a PSNR of 11.16
Linear model, 8 samples got a PSNR of 11.19
Linear model, 9 samples got a PSNR of 11.15
Linear model, 10 samples got a PSNR of 11.32
Linear model, 11 samples got a PSNR of 11.28
Linear model, 12 samples got a PSNR of 11.45
Linear model, 13 samples got a PSNR of 11.54
Linear model, 14 samples got a PSNR of 11.54
Linear model, 16 samples got a PSNR of 11.61
Linear model, 17 samples got a PSNR of 11.63
Linear model, 19 samples got a PSNR of 11.62
Linear model, 21 samples got a PSNR of 11.71
Linear model, 23 samples got a PSNR of 11.90
Linear model, 25 samples got a PSNR of 11.91
Linear model, 28 samples got a PSNR of 12.06
Linear model, 31 samples got a PSNR of 12.23
Linear model, 34 sa

# Linear Neural

In [None]:
# NOTE(review): flat=False presumably keeps samples in their unflattened
# layout — confirm in the alignment dataset class.
data.flat = False

# Seed first so the permutation is reproducible; the same permutation is
# passed to every training call below.
set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in samples_sets:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_neural_{n_samples}.pth'

    # NOTE(review): the literal 6 is an undocumented positional argument —
    # check train_neural_aligner's signature for its meaning.
    aligner, epoch = train_neural_aligner(
        data,
        permutation,
        n_samples,
        batch_size,
        resolution,
        6,
        train_snr,
        device,
    )
    torch.save(aligner.state_dict(), aligner_fp)

    print(f"Done with {n_samples}. Trained for {epoch} epochs.")

Done with 1. Trained for 18 epochs.
Done with 2. Trained for 24 epochs.
Done with 3. Trained for 31 epochs.
Done with 4. Trained for 45 epochs.
Done with 5. Trained for 38 epochs.
Done with 6. Trained for 67 epochs.
Done with 7. Trained for 84 epochs.
Done with 8. Trained for 44 epochs.
Done with 9. Trained for 157 epochs.
Done with 10. Trained for 105 epochs.
Done with 11. Trained for 146 epochs.
Done with 12. Trained for 103 epochs.
Done with 13. Trained for 102 epochs.
Done with 14. Trained for 12 epochs.
Done with 16. Trained for 12 epochs.
Done with 17. Trained for 12 epochs.
Done with 19. Trained for 12 epochs.
Done with 21. Trained for 12 epochs.
Done with 23. Trained for 12 epochs.
Done with 25. Trained for 12 epochs.
Done with 28. Trained for 424 epochs.
Done with 31. Trained for 599 epochs.
Done with 34. Trained for 505 epochs.
Done with 37. Trained for 614 epochs.
Done with 41. Trained for 667 epochs.
Done with 45. Trained for 723 epochs.
Done with 49. Trained for 687 epochs

: 

In [None]:
set_seed(seed)

# The neural checkpoints are loaded into the same module class as the
# least-squares ones (_LinearAlignment).
aligner = _LinearAlignment(resolution**2)

# NOTE(review): only the 10000-sample checkpoint is evaluated here, whereas
# the other evaluation cells sweep samples_sets — confirm this is intended.
for n_samples in [10000]:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_neural_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    # Channel type comes from the shared `channel` setting rather than a
    # hard-coded literal, so a config change applies to this cell too.
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
    psnr_value = validation_vectorized(aligned_model, test_loader, times, device)

    print(f"Neural model, {n_samples} samples got a PSNR of {psnr_value:.2f}")

Neural model, 10000 samples got a PSNR of 23.24


# MLP

In [None]:
# NOTE(review): flat=False presumably keeps samples in their unflattened
# layout — confirm in the alignment dataset class.
data.flat = False

# Seed first so the permutation is reproducible; the same permutation is
# passed to every training call below.
set_seed(seed)
permutation = torch.randperm(len(data))

# Only the entries of samples_sets from index 76 onward are trained in this
# cell. NOTE(review): presumably a resume point after an earlier partial run —
# checkpoints for the smaller sample counts must already exist on disk for the
# evaluation cell to load them.
for n_samples in samples_sets[76:]:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_mlp_{n_samples}.pth'

    # NOTE(review): the literal 6 is an undocumented positional argument —
    # check train_mlp_aligner's signature for its meaning.
    aligner, epoch = train_mlp_aligner(
        data,
        permutation,
        n_samples,
        batch_size,
        resolution,
        6,
        train_snr,
        device,
    )
    torch.save(aligner.state_dict(), aligner_fp)

    print(f"Done with {n_samples}. Trained for {epoch} epochs.")

Done with 5214. Trained for 67 epochs.
Done with 5722. Trained for 59 epochs.
Done with 6280. Trained for 42 epochs.
Done with 6892. Trained for 45 epochs.
Done with 7564. Trained for 36 epochs.
Done with 8302. Trained for 34 epochs.
Done with 9111. Trained for 32 epochs.
Done with 10000. Trained for 31 epochs.


In [None]:
set_seed(seed)

# A single MLP instance is reused; each iteration replaces its weights with
# the checkpoint saved for that sample count.
aligner = _MLPAlignment(input_dim=resolution**2, hidden_dims=[resolution**2])

for n_samples in samples_sets:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_mlp_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    # Channel type comes from the shared `channel` setting rather than a
    # hard-coded literal, so a config change applies to this cell too.
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
    psnr_value = validation_vectorized(aligned_model, test_loader, times, device)

    print(f"MLP model, {n_samples} samples got a PSNR of {psnr_value:.2f}")

MLP model, 1 samples got a PSNR of 10.16
MLP model, 2 samples got a PSNR of 10.04
MLP model, 3 samples got a PSNR of 10.80
MLP model, 4 samples got a PSNR of 10.92
MLP model, 5 samples got a PSNR of 11.07
MLP model, 6 samples got a PSNR of 10.94
MLP model, 7 samples got a PSNR of 11.02
MLP model, 8 samples got a PSNR of 10.94
MLP model, 9 samples got a PSNR of 10.76
MLP model, 10 samples got a PSNR of 10.80
MLP model, 11 samples got a PSNR of 10.85
MLP model, 12 samples got a PSNR of 11.15
MLP model, 13 samples got a PSNR of 11.55
MLP model, 14 samples got a PSNR of 11.52
MLP model, 16 samples got a PSNR of 11.50
MLP model, 17 samples got a PSNR of 11.38
MLP model, 19 samples got a PSNR of 11.09
MLP model, 21 samples got a PSNR of 11.34
MLP model, 23 samples got a PSNR of 11.79
MLP model, 25 samples got a PSNR of 11.41
MLP model, 28 samples got a PSNR of 11.59
MLP model, 31 samples got a PSNR of 11.70
MLP model, 34 samples got a PSNR of 11.25
MLP model, 37 samples got a PSNR of 11.64
M

# Convolutional

In [None]:
# NOTE(review): flat=False presumably keeps samples in their unflattened
# layout — confirm in the alignment dataset class.
data.flat = False

# Seed first so the permutation is reproducible; the same permutation is
# passed to every training call below.
set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in samples_sets:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_conv_{n_samples}.pth'

    # Train the convolutional aligner for this sample count and store its weights.
    aligner, epoch = train_conv_aligner(
        data,
        permutation,
        n_samples,
        c,
        batch_size,
        train_snr,
        device,
    )
    torch.save(aligner.state_dict(), aligner_fp)

    print(f"Done with {n_samples}. Trained for {epoch} epochs.")

Done with 1. Trained for 112 epochs.
Done with 2. Trained for 177 epochs.
Done with 3. Trained for 181 epochs.
Done with 4. Trained for 135 epochs.
Done with 5. Trained for 198 epochs.
Done with 6. Trained for 169 epochs.
Done with 7. Trained for 152 epochs.
Done with 8. Trained for 244 epochs.
Done with 9. Trained for 184 epochs.
Done with 10. Trained for 216 epochs.
Done with 11. Trained for 230 epochs.
Done with 12. Trained for 257 epochs.
Done with 13. Trained for 208 epochs.
Done with 14. Trained for 255 epochs.
Done with 16. Trained for 200 epochs.
Done with 17. Trained for 202 epochs.
Done with 19. Trained for 258 epochs.
Done with 21. Trained for 178 epochs.
Done with 23. Trained for 228 epochs.
Done with 25. Trained for 247 epochs.
Done with 28. Trained for 225 epochs.
Done with 31. Trained for 247 epochs.
Done with 34. Trained for 244 epochs.
Done with 37. Trained for 276 epochs.
Done with 41. Trained for 273 epochs.
Done with 45. Trained for 289 epochs.
Done with 49. Trained

In [None]:
set_seed(seed)

# A single convolutional aligner (2*c input and output channels, kernel size 5)
# is reused; each iteration replaces its weights with the saved checkpoint.
aligner = _ConvolutionalAlignment(in_channels=2*c, out_channels=2*c, kernel_size=5)

for n_samples in samples_sets:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_conv_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    # Channel type comes from the shared `channel` setting rather than a
    # hard-coded literal, so a config change applies to this cell too.
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
    psnr_value = validation_vectorized(aligned_model, test_loader, times, device)

    print(f"Conv model, {n_samples} samples got a PSNR of {psnr_value:.2f}")

Conv model, 1 samples got a PSNR of 21.97
Conv model, 2 samples got a PSNR of 24.12
Conv model, 3 samples got a PSNR of 25.64
Conv model, 4 samples got a PSNR of 24.60
Conv model, 5 samples got a PSNR of 26.45
Conv model, 6 samples got a PSNR of 27.25
Conv model, 7 samples got a PSNR of 27.38
Conv model, 8 samples got a PSNR of 28.57
Conv model, 9 samples got a PSNR of 28.17
Conv model, 10 samples got a PSNR of 28.47
Conv model, 11 samples got a PSNR of 28.65
Conv model, 12 samples got a PSNR of 28.84
Conv model, 13 samples got a PSNR of 28.75
Conv model, 14 samples got a PSNR of 29.07
Conv model, 16 samples got a PSNR of 28.96
Conv model, 17 samples got a PSNR of 29.29
Conv model, 19 samples got a PSNR of 29.78
Conv model, 21 samples got a PSNR of 29.21
Conv model, 23 samples got a PSNR of 29.90
Conv model, 25 samples got a PSNR of 30.09
Conv model, 28 samples got a PSNR of 30.15
Conv model, 31 samples got a PSNR of 30.24
Conv model, 34 samples got a PSNR of 30.11
Conv model, 37 sampl

# Zero-shot

In [None]:
# NOTE(review): flat=True presumably makes the dataset yield flattened
# samples — confirm in the alignment dataset class.
data.flat = True

# Seed first so the permutation is reproducible; the same permutation is
# passed to every training call below.
set_seed(seed)
permutation = torch.randperm(len(data))

# The first entry of samples_sets is skipped in this sweep.
for n_samples in samples_sets[1:]:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_zeroshot_{n_samples}.pth'

    # NOTE(review): n_samples is passed twice (3rd and 5th positional
    # arguments) — check train_zeroshot_aligner's signature for what each means.
    aligner = train_zeroshot_aligner(
        data,
        permutation,
        n_samples,
        train_snr,
        n_samples,
        device,
    )
    torch.save(aligner.state_dict(), aligner_fp)

    print(f"Done with {n_samples}.")

Done with 2.
Done with 3.
Done with 4.
Done with 5.
Done with 6.
Done with 7.
Done with 8.
Done with 9.
Done with 10.
Done with 11.
Done with 12.
Done with 13.
Done with 14.
Done with 16.
Done with 17.
Done with 19.
Done with 21.
Done with 23.
Done with 25.
Done with 28.
Done with 31.
Done with 34.
Done with 37.
Done with 41.
Done with 45.
Done with 49.
Done with 54.
Done with 59.
Done with 65.
Done with 72.
Done with 79.
Done with 86.
Done with 95.
Done with 104.
Done with 114.
Done with 126.
Done with 138.
Done with 151.
Done with 166.
Done with 183.
Done with 200.
Done with 220.
Done with 242.
Done with 265.
Done with 291.
Done with 319.
Done with 351.
Done with 385.
Done with 422.
Done with 464.
Done with 509.
Done with 559.
Done with 613.
Done with 673.
Done with 739.
Done with 811.
Done with 890.
Done with 977.
Done with 1072.
Done with 1176.
Done with 1291.
Done with 1417.
Done with 1555.
Done with 1707.
Done with 1873.
Done with 2056.
Done with 2257.
Done with 2477.
Done with 2

In [None]:
set_seed(seed)

# The first entry of samples_sets is skipped, as in the zero-shot training cell.
for n_samples in samples_sets[1:]:

    # A fresh aligner is built per sample count because several tensor shapes
    # depend on n_samples. The zeros are placeholders; presumably
    # load_state_dict below overwrites them with the saved values — TODO confirm
    # how _ZeroShotAlignment registers these tensors.
    aligner = _ZeroShotAlignment(
        F_tilde=torch.zeros(n_samples, resolution**2),
        G_tilde=torch.zeros(resolution**2, n_samples),
        G=torch.zeros(1, 1),
        L=torch.zeros(n_samples, n_samples),
        mean=torch.zeros(n_samples, 1)
    )

    aligner_fp = f'alignment/models/plots/{folder}/aligner_zeroshot_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    # Channel type comes from the shared `channel` setting rather than a
    # hard-coded literal, so a config change applies to this cell too.
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
    psnr_value = validation_vectorized(aligned_model, test_loader, times, device)

    print(f"Zeroshot model, {n_samples} samples got a PSNR of {psnr_value:.2f}")

Zeroshot model, 2 samples got a PSNR of 10.70
Zeroshot model, 3 samples got a PSNR of 10.24
Zeroshot model, 4 samples got a PSNR of 10.34
Zeroshot model, 5 samples got a PSNR of 10.31
Zeroshot model, 6 samples got a PSNR of 10.00
Zeroshot model, 7 samples got a PSNR of 10.47
Zeroshot model, 8 samples got a PSNR of 9.88
Zeroshot model, 9 samples got a PSNR of 10.10
Zeroshot model, 10 samples got a PSNR of 10.75
Zeroshot model, 11 samples got a PSNR of 10.29
Zeroshot model, 12 samples got a PSNR of 10.44
Zeroshot model, 13 samples got a PSNR of 10.42
Zeroshot model, 14 samples got a PSNR of 10.32
Zeroshot model, 16 samples got a PSNR of 10.53
Zeroshot model, 17 samples got a PSNR of 10.74
Zeroshot model, 19 samples got a PSNR of 10.48
Zeroshot model, 21 samples got a PSNR of 10.79
Zeroshot model, 23 samples got a PSNR of 10.57
Zeroshot model, 25 samples got a PSNR of 11.02
Zeroshot model, 28 samples got a PSNR of 10.22
Zeroshot model, 31 samples got a PSNR of 10.90
Zeroshot model, 34 sam