In [1]:
%cd ..

import os
import torch
import copy
import numpy as np

from alignment.alignment_utils import load_deep_jscc
from alignment.alignment_model import AlignedDeepJSCC
from alignment.alignment_training import *
from alignment.alignment_validation import *

import pickle

/home/lorenzo/repos/Deep-JSCC-PyTorch


In [2]:
model1_fp = r'alignment/models/autoencoders/snr_0_seed_42.pkl'
model2_fp = r'alignment/models/autoencoders/snr_0_seed_43.pkl'
folder = r'psnr_vs_pilots_5'
os.makedirs(f'alignment/models/plots/{folder}', exist_ok=True)

dataset = "cifar10"
resolution = 96
channel = 'AWGN'
batch_size = 64
num_workers = 4

train_snr = 0
val_snr = 0
times = 10
c = 8
seed = 42

samples_sets = np.unique(np.logspace(0, np.log10(10000), num=100, base=10).astype(int))
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

encoder = copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, "AWGN").encoder)
decoder = copy.deepcopy(load_deep_jscc(model2_fp, val_snr, c, "AWGN").decoder)

train_loader, test_loader = get_data_loaders(dataset, resolution, batch_size, num_workers)
data = load_alignment_dataset(model1_fp, model2_fp, train_snr, train_loader, c, device)

Caching inputs: 100%|██████████| 782/782 [00:02<00:00, 348.51it/s]


# No mismatch - Unaligned

In [None]:
model = AlignedDeepJSCC(encoder, decoder, None, val_snr, "AWGN")
print(f"Unaligned {validation_vectorized(model, test_loader, times):.2f}")

In [None]:
model = AlignedDeepJSCC(encoder, copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, "AWGN").decoder), None, val_snr, "AWGN")
print(f"Aligned {validation_vectorized(model, test_loader, times):.2f}")

# Least Squares

In [None]:
data.flat = True

set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in samples_sets:

    aligner = train_linear_aligner(data, permutation, n_samples, train_snr)

    aligner_fp = f'alignment/models/plots/{folder}/aligner_linear_{n_samples}.pkl'
    with open(aligner_fp, 'wb') as f:
        pickle.dump(aligner, f)

    print(f"Done with {n_samples}.")

Caching inputs: 100%|██████████| 782/782 [00:02<00:00, 326.27it/s]


Done with 1.
Done with 2.
Done with 3.
Done with 4.
Done with 5.
Done with 6.
Done with 7.
Done with 8.
Done with 9.
Done with 10.
Done with 11.
Done with 12.
Done with 13.
Done with 14.
Done with 16.
Done with 17.
Done with 19.
Done with 21.
Done with 23.
Done with 25.
Done with 28.
Done with 31.
Done with 34.
Done with 37.
Done with 41.
Done with 45.
Done with 49.
Done with 54.
Done with 59.
Done with 65.
Done with 72.
Done with 79.
Done with 86.
Done with 95.
Done with 104.
Done with 114.
Done with 126.
Done with 138.
Done with 151.
Done with 166.
Done with 183.
Done with 200.
Done with 220.
Done with 242.
Done with 265.
Done with 291.
Done with 319.
Done with 351.
Done with 385.
Done with 422.
Done with 464.


KeyboardInterrupt: 

In [None]:
set_seed(seed)

for n_samples in samples_sets:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_linear_{n_samples}.pkl'
    with open(aligner_fp, 'rb') as f:
        aligner = pickle.load(f)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    print(f"Linear model, {n_samples} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times, device):.2f}")

Linear model, 1000 samples got a PSNR of 16.83


# Linear Neural

In [None]:
data.flat = False

set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in samples_sets:
    
    aligner, epoch = train_neural_aligner(data, permutation, n_samples, batch_size, resolution, 6, train_snr, device)

    aligner_fp = f'alignment/models/plots/{folder}/aligner_neural_{n_samples}.pkl'
    with open(aligner_fp, 'wb') as f:
        pickle.dump(aligner.to("cpu"), f)

    print(f"Done with {n_samples}. Trained for {epoch} epochs.")

Computing model outputs: 100%|██████████| 782/782 [00:10<00:00, 76.13it/s]


Done with 1000. Trained for 45 epochs.


In [None]:
set_seed(seed)

for n_samples in samples_sets:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_neural_{n_samples}.pkl'
    with open(aligner_fp, 'rb') as f:
        aligner = pickle.load(f)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")

    print(f"Neural model, {n_samples} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times, device):.2f}")

Neural model, 1000 samples got a PSNR of 17.13


# MLP

In [None]:
data.flat = False

set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in samples_sets:
    
    aligner, epoch = train_mlp_aligner(data, permutation, n_samples, batch_size, resolution, 6, train_snr, device)

    aligner_fp = f'alignment/models/plots/{folder}/aligner_mlp_{n_samples}.pkl'
    with open(aligner_fp, 'wb') as f:
        pickle.dump(aligner.to("cpu"), f)

    print(f"Done with {n_samples}. Trained for {epoch} epochs.")

Computing model outputs: 100%|██████████| 782/782 [00:08<00:00, 87.03it/s] 


Done with 10000. Trained for 12 epochs.


In [None]:
set_seed(seed)

for n_samples in samples_sets:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_mlp_{n_samples}.pkl'
    with open(aligner_fp, 'rb') as f:
        aligner = pickle.load(f)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")

    print(f"MLP model, {n_samples} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times, device):.2f}")

MLP model, 10000 samples got a PSNR of 18.16


# Convolutional

In [None]:
data.flat = False

set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in samples_sets:

    aligner, epoch = train_conv_aligner(data, permutation, n_samples, c, batch_size, train_snr, device)

    aligner_fp = f'alignment/models/plots/{folder}/aligner_conv_{n_samples}.pkl'
    with open(aligner_fp, 'wb') as f:
        pickle.dump(aligner.to("cpu"), f)

    print(f"Done with {n_samples}. Trained for {epoch} epochs.")

Done with 1. Trained for 168 epochs.
Done with 2. Trained for 204 epochs.
Done with 3. Trained for 289 epochs.
Done with 4. Trained for 250 epochs.
Done with 5. Trained for 237 epochs.
Done with 6. Trained for 217 epochs.
Done with 7. Trained for 278 epochs.
Done with 8. Trained for 305 epochs.
Done with 9. Trained for 280 epochs.
Done with 10. Trained for 293 epochs.
Done with 11. Trained for 299 epochs.
Done with 12. Trained for 334 epochs.
Done with 13. Trained for 315 epochs.
Done with 14. Trained for 324 epochs.
Done with 16. Trained for 341 epochs.
Done with 17. Trained for 350 epochs.
Done with 19. Trained for 342 epochs.
Done with 21. Trained for 397 epochs.
Done with 23. Trained for 341 epochs.
Done with 25. Trained for 368 epochs.
Done with 28. Trained for 309 epochs.
Done with 31. Trained for 368 epochs.
Done with 34. Trained for 357 epochs.
Done with 37. Trained for 457 epochs.
Done with 41. Trained for 414 epochs.
Done with 45. Trained for 421 epochs.
Done with 49. Trained

In [None]:
set_seed(seed)

for n_samples in samples_sets:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_conv_{n_samples}.pkl'
    with open(aligner_fp, 'rb') as f:
        aligner = pickle.load(f)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")

    print(f"Conv model, {n_samples} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times):.2f}")

Conv model, 1 samples got a PSNR of 19.85
Conv model, 2 samples got a PSNR of 23.71
Conv model, 3 samples got a PSNR of 23.80
Conv model, 4 samples got a PSNR of 24.41
Conv model, 5 samples got a PSNR of 25.30
Conv model, 6 samples got a PSNR of 25.68
Conv model, 7 samples got a PSNR of 26.41
Conv model, 8 samples got a PSNR of 26.97
Conv model, 9 samples got a PSNR of 27.01
Conv model, 10 samples got a PSNR of 27.43
Conv model, 11 samples got a PSNR of 27.43
Conv model, 12 samples got a PSNR of 27.43
Conv model, 13 samples got a PSNR of 27.86
Conv model, 14 samples got a PSNR of 28.20
Conv model, 16 samples got a PSNR of 27.88
Conv model, 17 samples got a PSNR of 28.30
Conv model, 19 samples got a PSNR of 28.42
Conv model, 21 samples got a PSNR of 28.26
Conv model, 23 samples got a PSNR of 28.24
Conv model, 25 samples got a PSNR of 28.13
Conv model, 28 samples got a PSNR of 28.39
Conv model, 31 samples got a PSNR of 28.20
Conv model, 34 samples got a PSNR of 28.22
Conv model, 37 sampl

# Zero-shot

In [None]:
data.flat = True

set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in samples_sets[1:]:

    aligner = train_zeroshot_aligner(data, permutation, n_samples, train_snr, n_samples, device)

    aligner_fp = f'alignment/models/plots/{folder}/aligner_zeroshot_{n_samples}.pkl'
    with open(aligner_fp, 'wb') as f:
        pickle.dump(aligner, f)

    print(f"Done with {n_samples}.")

Computing model outputs: 100%|██████████| 782/782 [00:08<00:00, 91.04it/s] 


Done with 2.
Done with 3.
Done with 4.
Done with 5.
Done with 6.
Done with 7.
Done with 8.
Done with 9.
Done with 10.
Done with 11.
Done with 12.
Done with 13.
Done with 14.
Done with 16.
Done with 17.
Done with 19.
Done with 21.
Done with 23.
Done with 25.
Done with 28.
Done with 31.
Done with 34.
Done with 37.
Done with 41.
Done with 45.
Done with 49.
Done with 54.
Done with 59.
Done with 65.
Done with 72.
Done with 79.
Done with 86.
Done with 95.
Done with 104.
Done with 114.
Done with 126.
Done with 138.
Done with 151.
Done with 166.
Done with 183.
Done with 200.
Done with 220.
Done with 242.
Done with 265.
Done with 291.
Done with 319.
Done with 351.
Done with 385.
Done with 422.
Done with 464.
Done with 509.
Done with 559.
Done with 613.
Done with 673.
Done with 739.
Done with 811.
Done with 890.
Done with 977.
Done with 1072.
Done with 1176.
Done with 1291.
Done with 1417.
Done with 1555.
Done with 1707.
Done with 1873.
Done with 2056.
Done with 2257.
Done with 2477.
Done with 2

In [None]:
set_seed(seed)

for n_samples in samples_sets[1:]:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_zeroshot_{n_samples}.pkl'
    with open(aligner_fp, 'rb') as f:
        aligner = pickle.load(f)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")

    print(f"Zeroshot model, {n_samples} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times, device):.2f}")

Zeroshot model, 2 samples got a PSNR of 9.91
Zeroshot model, 3 samples got a PSNR of 9.86
Zeroshot model, 4 samples got a PSNR of 9.90
Zeroshot model, 5 samples got a PSNR of 10.84
Zeroshot model, 6 samples got a PSNR of 10.95
Zeroshot model, 7 samples got a PSNR of 11.08
Zeroshot model, 8 samples got a PSNR of 11.31
Zeroshot model, 9 samples got a PSNR of 11.21
Zeroshot model, 10 samples got a PSNR of 11.23
Zeroshot model, 11 samples got a PSNR of 11.27
Zeroshot model, 12 samples got a PSNR of 11.27
Zeroshot model, 13 samples got a PSNR of 11.27
Zeroshot model, 14 samples got a PSNR of 11.48
Zeroshot model, 16 samples got a PSNR of 11.70
Zeroshot model, 17 samples got a PSNR of 11.85
Zeroshot model, 19 samples got a PSNR of 11.89
Zeroshot model, 21 samples got a PSNR of 12.10
Zeroshot model, 23 samples got a PSNR of 12.26
Zeroshot model, 25 samples got a PSNR of 12.31
Zeroshot model, 28 samples got a PSNR of 12.36
Zeroshot model, 31 samples got a PSNR of 12.49
Zeroshot model, 34 sampl