In [1]:
%cd ..

import os
import torch
import copy
import numpy as np
from tqdm.notebook import tqdm


from alignment.alignment_utils import load_deep_jscc
from alignment.alignment_model import *
from alignment.alignment_model import _LinearAlignment, _MLPAlignment, _ConvolutionalAlignment, _ZeroShotAlignment, _TwoConvAlignment
from alignment.alignment_training import *
from alignment.alignment_validation import *

/home/lorenzo/repos/Deep-JSCC-PyTorch


In [2]:
# --- Experiment configuration -------------------------------------------------
# Checkpoints of the two autoencoders being paired: the encoder is taken from
# model 1 and the decoder from model 2 (see the load_deep_jscc calls below).
model1_fp = r'alignment/models/autoencoders/snr_30_seed_42.pkl'
model2_fp = r'alignment/models/autoencoders/snr_30_seed_43.pkl'

# Sub-folder of alignment/models/plots that receives the aligner checkpoints.
folder = r'psnr_vs_pilots'
os.makedirs(f'alignment/models/plots/{folder}', exist_ok=True)
os.makedirs('alignment/logs', exist_ok=True)

dataset = "cifar10"
resolution = 96
channel = 'AWGN'
batch_size = 64
num_workers = 4

snr = 30
seed = 44

# Training and validation use the same SNR in this notebook.
train_snr = snr
val_snr = snr
# Forwarded to validation_vectorized -- presumably the number of repeated
# validation passes; confirm against alignment_validation.
times = 10
# Forwarded to load_deep_jscc and load_alignment_dataset; the convolutional
# aligners are built with 2*c channels.
c = 8

# Pilot budgets: n_points log-spaced values from 1 to 10000, truncated to
# integers by astype(int); np.unique drops the duplicates that truncation creates.
n_points = 20
pilots_sets = np.unique(np.logspace(0, np.log10(10000), num=n_points, base=10).astype(int))
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

# Use the `channel` constant rather than repeating the literal so the channel
# model is configured in exactly one place.
encoder = copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, channel).encoder)
decoder = copy.deepcopy(load_deep_jscc(model2_fp, val_snr, c, channel).decoder)

train_loader, test_loader = get_data_loaders(dataset, resolution, batch_size, num_workers)
data = load_alignment_dataset(model1_fp, model2_fp, train_snr, train_loader, c, device)

Caching inputs: 100%|██████████| 782/782 [00:03<00:00, 212.08it/s]


# Baselines: No mismatch - Unaligned - Zero-shot (max pilots)

In [3]:
log_file = f"alignment/logs/lines_snr_{snr}_seed_{seed}.txt"

# Start from an empty log so that re-running this cell does not append
# duplicate rows (the per-aligner evaluation cells truncate their logs too).
with open(log_file, 'w') as f:
    pass


def _log_baseline(label, evaluated_model):
    """Validate a model, then print and append ``"<label> <value>"`` to ``log_file``.

    Args:
        label: Row name written in front of the value.
        evaluated_model: Model handed to ``validation_vectorized`` together with
            ``test_loader``, ``times`` and ``device``.

    Returns:
        The message that was printed and logged (value formatted with ``:.2f``).
    """
    message = f"{label} {validation_vectorized(evaluated_model, test_loader, times, device):.2f}"
    print(message)
    with open(log_file, 'a') as f:
        f.write(f"{message}\n")
    return message


# unaligned: encoder of model 1 feeding the decoder of model 2, no aligner
model = AlignedDeepJSCC(encoder, decoder, None, val_snr, channel)
result_msg = _log_baseline("unaligned", model)

# aligned: no-mismatch reference, encoder and decoder both come from model 1
model = AlignedDeepJSCC(encoder, copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, channel).decoder), None, val_snr, channel)
result_msg = _log_baseline("aligned", model)

# zeroshot: trained with resolution**2 passed both as n_samples and as the
# fifth positional argument of train_zeroshot_aligner
data.flat = True

# The seed is set only here, after the two validations above, exactly as in
# the original ordering, so the random stream seen by each step is unchanged.
set_seed(seed)
permutation = torch.randperm(len(data))

aligner = train_zeroshot_aligner(data, permutation, resolution**2, train_snr, resolution**2, device)
aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
result_msg = _log_baseline("zeroshot", aligned_model)

unaligned 5.39
aligned 48.50
zeroshot 38.37


# Least Squares

In [None]:
# Least-squares aligners: fit one per pilot budget and checkpoint its weights.
aligner_type = "linear"
data.flat = True

# Seed first, then draw the sample ordering that is handed to every fit below.
set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in tqdm(pilots_sets, desc="Training"):
    # Destination of this budget's checkpoint (read back by the evaluation cell).
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner = train_linear_aligner(data, permutation, n_samples, train_snr)
    torch.save(aligner.state_dict(), aligner_fp)

Training:  74%|███████▎  | 14/19 [00:07<00:02,  1.69it/s]

In [None]:
# Evaluate every least-squares checkpoint and log one PSNR line per pilot budget.
aligner_type = "linear"
aligner = _LinearAlignment(resolution**2)
log_file = f"alignment/logs/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Mode 'w' truncates any previous log; a single handle is kept for the whole
# sweep instead of reopening the file in append mode on every iteration.
with open(log_file, 'w') as f:
    set_seed(seed)

    for n_samples in pilots_sets:
        # Built from aligner_type so the name always matches what the training cell saved.
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        result_msg = f"Linear model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
        print(result_msg)

        f.write(f"{result_msg}\n")
        f.flush()  # keep the on-disk log current while the sweep is still running

Linear model, 1 samples got a PSNR of 12.21
Linear model, 2 samples got a PSNR of 11.83
Linear model, 4 samples got a PSNR of 12.51
Linear model, 6 samples got a PSNR of 13.03
Linear model, 11 samples got a PSNR of 13.55
Linear model, 18 samples got a PSNR of 13.95
Linear model, 29 samples got a PSNR of 14.89
Linear model, 48 samples got a PSNR of 15.48
Linear model, 78 samples got a PSNR of 16.20
Linear model, 127 samples got a PSNR of 17.12
Linear model, 206 samples got a PSNR of 18.03
Linear model, 335 samples got a PSNR of 18.95
Linear model, 545 samples got a PSNR of 20.07
Linear model, 885 samples got a PSNR of 21.32
Linear model, 1438 samples got a PSNR of 22.80
Linear model, 2335 samples got a PSNR of 24.44
Linear model, 3792 samples got a PSNR of 26.33
Linear model, 6158 samples got a PSNR of 28.43
Linear model, 10000 samples got a PSNR of 30.71


# Linear Neural

In [None]:
# Gradient-trained linear aligners: fit one per pilot budget and checkpoint it.
aligner_type = "neural"
data.flat = False

# Seed first, then draw the sample ordering that is handed to every fit below.
set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in tqdm(pilots_sets, desc="Training"):
    # Destination of this budget's checkpoint (read back by the evaluation cell).
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'

    # NOTE(review): the literal 6 is passed positionally; its meaning is defined
    # by train_neural_aligner and is not visible from this notebook.
    # `epoch` is returned alongside the aligner but is not used further here.
    aligner, epoch = train_neural_aligner(data, permutation, n_samples, batch_size, resolution, 6, train_snr, device)
    torch.save(aligner.state_dict(), aligner_fp)

Training:   5%|▌         | 1/19 [00:03<01:06,  3.71s/it]

Done with 1. Trained for 260 epochs.


Training:  11%|█         | 2/19 [00:08<01:16,  4.49s/it]

Done with 2. Trained for 373 epochs.


Training:  16%|█▌        | 3/19 [00:14<01:24,  5.30s/it]

Done with 4. Trained for 458 epochs.


Training:  21%|██        | 4/19 [00:21<01:24,  5.67s/it]

Done with 6. Trained for 474 epochs.


Training:  26%|██▋       | 5/19 [00:33<01:51,  7.98s/it]

Done with 11. Trained for 867 epochs.


Training:  32%|███▏      | 6/19 [00:50<02:22, 10.96s/it]

Done with 18. Trained for 1074 epochs.


Training:  37%|███▋      | 7/19 [01:10<02:48, 14.04s/it]

Done with 29. Trained for 1293 epochs.


Training:  42%|████▏     | 8/19 [01:32<03:02, 16.57s/it]

Done with 48. Trained for 1212 epochs.


Training:  47%|████▋     | 9/19 [01:40<02:19, 13.97s/it]

Done with 78. Trained for 245 epochs.


Training:  53%|█████▎    | 10/19 [01:50<01:55, 12.80s/it]

Done with 127. Trained for 257 epochs.


Training:  58%|█████▊    | 11/19 [02:01<01:37, 12.22s/it]

Done with 206. Trained for 197 epochs.


Training:  63%|██████▎   | 12/19 [02:14<01:25, 12.25s/it]

Done with 335. Trained for 128 epochs.


Training:  68%|██████▊   | 13/19 [02:27<01:14, 12.50s/it]

Done with 545. Trained for 83 epochs.


Training:  74%|███████▎  | 14/19 [02:40<01:03, 12.78s/it]

Done with 885. Trained for 58 epochs.


Training:  79%|███████▉  | 15/19 [02:55<00:53, 13.38s/it]

Done with 1438. Trained for 39 epochs.


Training:  84%|████████▍ | 16/19 [03:17<00:47, 15.89s/it]

Done with 2335. Trained for 35 epochs.


Training:  89%|████████▉ | 17/19 [03:43<00:38, 19.18s/it]

Done with 3792. Trained for 27 epochs.


Training:  95%|█████████▍| 18/19 [04:24<00:25, 25.62s/it]

Done with 6158. Trained for 25 epochs.


Training: 100%|██████████| 19/19 [05:24<00:00, 17.06s/it]

Done with 10000. Trained for 23 epochs.





In [None]:
# Evaluate every gradient-trained linear checkpoint and log one PSNR line per pilot budget.
aligner_type = "neural"
# The "neural" checkpoints are loaded into the same _LinearAlignment module
# class that the least-squares evaluation uses.
aligner = _LinearAlignment(resolution**2)
log_file = f"alignment/logs/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Mode 'w' truncates any previous log; a single handle is kept for the whole
# sweep instead of reopening the file in append mode on every iteration.
with open(log_file, 'w') as f:
    set_seed(seed)

    for n_samples in pilots_sets:
        # Built from aligner_type so the name always matches what the training cell saved.
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

        # Use the shared `channel` setting, as the linear evaluation cell does.
        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        result_msg = f"Neural model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
        print(result_msg)

        f.write(f"{result_msg}\n")
        f.flush()  # keep the on-disk log current while the sweep is still running

Neural model, 1 samples got a PSNR of 11.66
Neural model, 2 samples got a PSNR of 12.32
Neural model, 4 samples got a PSNR of 12.60
Neural model, 6 samples got a PSNR of 13.10
Neural model, 11 samples got a PSNR of 14.53
Neural model, 18 samples got a PSNR of 15.41
Neural model, 29 samples got a PSNR of 16.28
Neural model, 48 samples got a PSNR of 17.20
Neural model, 78 samples got a PSNR of 17.14
Neural model, 127 samples got a PSNR of 18.17
Neural model, 206 samples got a PSNR of 19.47
Neural model, 335 samples got a PSNR of 20.94
Neural model, 545 samples got a PSNR of 22.76
Neural model, 885 samples got a PSNR of 25.18
Neural model, 1438 samples got a PSNR of 27.43
Neural model, 2335 samples got a PSNR of 31.19
Neural model, 3792 samples got a PSNR of 31.81
Neural model, 6158 samples got a PSNR of 33.30
Neural model, 10000 samples got a PSNR of 34.02


# MLP

In [None]:
# MLP aligners: fit one per pilot budget and checkpoint its weights.
aligner_type = "mlp"
data.flat = False

# Seed first, then draw the sample ordering that is handed to every fit below.
set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in tqdm(pilots_sets, desc="Training"):
    # Destination of this budget's checkpoint (read back by the evaluation cell).
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'

    # NOTE(review): the literal 6 is passed positionally; its meaning is defined
    # by train_mlp_aligner and is not visible from this notebook.
    # `epoch` is returned alongside the aligner but is not used further here.
    aligner, epoch = train_mlp_aligner(data, permutation, n_samples, batch_size, resolution, 6, train_snr, device)
    torch.save(aligner.state_dict(), aligner_fp)

Training:   5%|▌         | 1/19 [00:05<01:31,  5.07s/it]

Done with 1. Trained for 155 epochs.


Training:  11%|█         | 2/19 [00:06<00:54,  3.18s/it]

Done with 2. Trained for 21 epochs.


Training:  16%|█▌        | 3/19 [00:12<01:10,  4.43s/it]

Done with 4. Trained for 220 epochs.


Training:  21%|██        | 4/19 [00:14<00:51,  3.43s/it]

Done with 6. Trained for 21 epochs.


Training:  26%|██▋       | 5/19 [00:18<00:47,  3.41s/it]

Done with 11. Trained for 76 epochs.


Training:  32%|███▏      | 6/19 [00:21<00:44,  3.46s/it]

Done with 18. Trained for 75 epochs.


Training:  37%|███▋      | 7/19 [01:07<03:27, 17.29s/it]

Done with 29. Trained for 1573 epochs.


Training:  42%|████▏     | 8/19 [01:26<03:15, 17.80s/it]

Done with 48. Trained for 605 epochs.


Training:  47%|████▋     | 9/19 [01:32<02:22, 14.24s/it]

Done with 78. Trained for 89 epochs.


Training:  53%|█████▎    | 10/19 [01:38<01:45, 11.76s/it]

Done with 127. Trained for 75 epochs.


Training:  58%|█████▊    | 11/19 [02:06<02:11, 16.46s/it]

Done with 206. Trained for 285 epochs.


Training:  63%|██████▎   | 12/19 [02:51<02:56, 25.22s/it]

Done with 335. Trained for 308 epochs.


Training:  68%|██████▊   | 13/19 [03:57<03:45, 37.61s/it]

Done with 545. Trained for 273 epochs.


Training:  74%|███████▎  | 14/19 [05:10<04:01, 48.32s/it]

Done with 885. Trained for 186 epochs.


Training:  79%|███████▉  | 15/19 [06:22<03:41, 55.35s/it]

Done with 1438. Trained for 113 epochs.


Training:  84%|████████▍ | 16/19 [07:48<03:14, 64.79s/it]

Done with 2335. Trained for 88 epochs.


Training:  89%|████████▉ | 17/19 [09:19<02:25, 72.62s/it]

Done with 3792. Trained for 56 epochs.


Training:  95%|█████████▍| 18/19 [11:09<01:23, 83.66s/it]

Done with 6158. Trained for 41 epochs.


Training: 100%|██████████| 19/19 [13:32<00:00, 42.74s/it] 

Done with 10000. Trained for 33 epochs.





In [None]:
# Evaluate every MLP checkpoint and log one PSNR line per pilot budget.
aligner_type = "mlp"
aligner = _MLPAlignment(input_dim=resolution**2, hidden_dims=[resolution**2])
log_file = f"alignment/logs/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Mode 'w' truncates any previous log; a single handle is kept for the whole
# sweep instead of reopening the file in append mode on every iteration.
with open(log_file, 'w') as f:
    set_seed(seed)

    for n_samples in pilots_sets:
        # Built from aligner_type so the name always matches what the training cell saved.
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

        # Use the shared `channel` setting, as the linear evaluation cell does.
        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        result_msg = f"MLP model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
        print(result_msg)

        f.write(f"{result_msg}\n")
        f.flush()  # keep the on-disk log current while the sweep is still running

MLP model, 1 samples got a PSNR of 11.62
MLP model, 2 samples got a PSNR of 10.06
MLP model, 4 samples got a PSNR of 12.39
MLP model, 6 samples got a PSNR of 10.90
MLP model, 11 samples got a PSNR of 13.58
MLP model, 18 samples got a PSNR of 14.22
MLP model, 29 samples got a PSNR of 16.05
MLP model, 48 samples got a PSNR of 16.69
MLP model, 78 samples got a PSNR of 16.01
MLP model, 127 samples got a PSNR of 16.53
MLP model, 206 samples got a PSNR of 20.03
MLP model, 335 samples got a PSNR of 21.68
MLP model, 545 samples got a PSNR of 22.93
MLP model, 885 samples got a PSNR of 24.17
MLP model, 1438 samples got a PSNR of 25.10
MLP model, 2335 samples got a PSNR of 26.11
MLP model, 3792 samples got a PSNR of 26.35
MLP model, 6158 samples got a PSNR of 26.87
MLP model, 10000 samples got a PSNR of 27.04


# Convolutional

In [None]:
# Convolutional aligners: fit one per pilot budget and checkpoint its weights.
aligner_type = "conv"
data.flat = False

# Seed first, then draw the sample ordering that is handed to every fit below.
set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in tqdm(pilots_sets, desc="Training"):
    # Destination of this budget's checkpoint (read back by the evaluation cell).
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'

    # `epoch` is returned alongside the aligner but is not used further here.
    aligner, epoch = train_conv_aligner(data, permutation, n_samples, c, batch_size, train_snr, device)
    torch.save(aligner.state_dict(), aligner_fp)

Training: 100%|██████████| 19/19 [10:05<00:00, 31.87s/it]


In [None]:
# Evaluate every convolutional checkpoint and log one PSNR line per pilot budget.
aligner_type = "conv"
aligner = _ConvolutionalAlignment(in_channels=2*c, out_channels=2*c, kernel_size=5)
log_file = f"alignment/logs/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Mode 'w' truncates any previous log; a single handle is kept for the whole
# sweep instead of reopening the file in append mode on every iteration.
with open(log_file, 'w') as f:
    set_seed(seed)

    for n_samples in pilots_sets:
        # Built from aligner_type so the name always matches what the training cell saved.
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

        # Use the shared `channel` setting, as the linear evaluation cell does.
        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        result_msg = f"Conv model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
        print(result_msg)

        f.write(f"{result_msg}\n")
        f.flush()  # keep the on-disk log current while the sweep is still running

Conv model, 1 samples got a PSNR of 30.90
Conv model, 2 samples got a PSNR of 36.63
Conv model, 4 samples got a PSNR of 38.32
Conv model, 6 samples got a PSNR of 39.79
Conv model, 11 samples got a PSNR of 37.09
Conv model, 18 samples got a PSNR of 39.08
Conv model, 29 samples got a PSNR of 38.18
Conv model, 48 samples got a PSNR of 40.45
Conv model, 78 samples got a PSNR of 40.97
Conv model, 127 samples got a PSNR of 41.11
Conv model, 206 samples got a PSNR of 41.29
Conv model, 335 samples got a PSNR of 41.36
Conv model, 545 samples got a PSNR of 41.37
Conv model, 885 samples got a PSNR of 41.43
Conv model, 1438 samples got a PSNR of 41.38
Conv model, 2335 samples got a PSNR of 41.25
Conv model, 3792 samples got a PSNR of 41.57
Conv model, 6158 samples got a PSNR of 41.54
Conv model, 10000 samples got a PSNR of 41.56


# Two Conv

In [None]:
# Two-layer convolutional aligners: fit one per pilot budget and checkpoint it.
aligner_type = "twoconv"
data.flat = False

# Seed first, then draw the sample ordering that is handed to every fit below.
set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in tqdm(pilots_sets, desc="Training"):
    # Destination of this budget's checkpoint (read back by the evaluation cell).
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'

    # `epoch` is returned alongside the aligner but is not used further here.
    aligner, epoch = train_twoconv_aligner(data, permutation, n_samples, c, batch_size, train_snr, device)
    torch.save(aligner.state_dict(), aligner_fp)

Training: 100%|██████████| 19/19 [13:25<00:00, 42.38s/it] 


In [None]:
# Evaluate every two-layer convolutional checkpoint and log one PSNR line per pilot budget.
aligner_type = "twoconv"
aligner = _TwoConvAlignment(in_channels=2*c, hidden_channels=2*c, out_channels=2*c, kernel_size=5)
log_file = f"alignment/logs/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Mode 'w' truncates any previous log; a single handle is kept for the whole
# sweep instead of reopening the file in append mode on every iteration.
with open(log_file, 'w') as f:
    set_seed(seed)

    for n_samples in pilots_sets:
        # Built from aligner_type so the name always matches what the training cell saved.
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

        # Use the shared `channel` setting, as the linear evaluation cell does.
        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        result_msg = f"Twoconv model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
        print(result_msg)

        f.write(f"{result_msg}\n")
        f.flush()  # keep the on-disk log current while the sweep is still running

Twoconv model, 1 samples got a PSNR of 24.32
Twoconv model, 2 samples got a PSNR of 29.27
Twoconv model, 4 samples got a PSNR of 34.76
Twoconv model, 6 samples got a PSNR of 34.89
Twoconv model, 11 samples got a PSNR of 35.82
Twoconv model, 18 samples got a PSNR of 37.42
Twoconv model, 29 samples got a PSNR of 38.00
Twoconv model, 48 samples got a PSNR of 37.95
Twoconv model, 78 samples got a PSNR of 37.14
Twoconv model, 127 samples got a PSNR of 38.31
Twoconv model, 206 samples got a PSNR of 39.95
Twoconv model, 335 samples got a PSNR of 41.87
Twoconv model, 545 samples got a PSNR of 41.95
Twoconv model, 885 samples got a PSNR of 41.04
Twoconv model, 1438 samples got a PSNR of 43.41
Twoconv model, 2335 samples got a PSNR of 44.01
Twoconv model, 3792 samples got a PSNR of 43.48
Twoconv model, 6158 samples got a PSNR of 43.51
Twoconv model, 10000 samples got a PSNR of 43.77


# Zero-shot

In [None]:
# Zero-shot aligners: build one per pilot budget and checkpoint its state.
aligner_type = "zeroshot"
data.flat = True

# Seed first, then draw the sample ordering that is handed to every fit below.
set_seed(seed)
permutation = torch.randperm(len(data))

# pilots_sets[1:] skips the smallest budget (n_samples == 1).
# NOTE(review): presumably the zero-shot construction needs at least two
# pilots -- confirm against train_zeroshot_aligner.
for n_samples in tqdm(pilots_sets[1:], desc="Training"):
    # Destination of this budget's checkpoint (read back by the evaluation cell).
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'

    # n_samples is passed twice: as the third and as the fifth positional argument.
    aligner = train_zeroshot_aligner(data, permutation, n_samples, train_snr, n_samples, device)
    torch.save(aligner.state_dict(), aligner_fp)

Training: 100%|██████████| 18/18 [01:33<00:00,  5.20s/it]


In [None]:
# Evaluate every zero-shot checkpoint and log one PSNR line per pilot budget.
aligner_type = "zeroshot"
log_file = f"alignment/logs/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Mode 'w' truncates any previous log; a single handle is kept for the whole
# sweep instead of reopening the file in append mode on every iteration.
with open(log_file, 'w') as f:
    set_seed(seed)

    # Same budgets as the zero-shot training cell (the smallest one is skipped).
    for n_samples in pilots_sets[1:]:

        # All-zero tensors sized from n_samples and resolution**2; the stored
        # state is loaded on top of them right below.
        # NOTE(review): presumably only the shapes matter here -- confirm
        # against _ZeroShotAlignment.
        aligner = _ZeroShotAlignment(
            F_tilde=torch.zeros(n_samples, resolution**2),
            G_tilde=torch.zeros(resolution**2, n_samples),
            G=torch.zeros(1, 1),
            L=torch.zeros(n_samples, n_samples),
            mean=torch.zeros(n_samples, 1)
        )

        # Built from aligner_type so the name always matches what the training cell saved.
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

        # Use the shared `channel` setting, as the linear evaluation cell does.
        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        result_msg = f"Zeroshot model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
        print(result_msg)

        f.write(f"{result_msg}\n")
        f.flush()  # keep the on-disk log current while the sweep is still running

Zeroshot model, 2 samples got a PSNR of 12.31
Zeroshot model, 4 samples got a PSNR of 12.63
Zeroshot model, 6 samples got a PSNR of 13.17
Zeroshot model, 11 samples got a PSNR of 14.67
Zeroshot model, 18 samples got a PSNR of 15.53
Zeroshot model, 29 samples got a PSNR of 16.52
Zeroshot model, 48 samples got a PSNR of 17.55
Zeroshot model, 78 samples got a PSNR of 18.67
Zeroshot model, 127 samples got a PSNR of 19.90
Zeroshot model, 206 samples got a PSNR of 21.71
Zeroshot model, 335 samples got a PSNR of 23.73
Zeroshot model, 545 samples got a PSNR of 26.44
Zeroshot model, 885 samples got a PSNR of 29.70
Zeroshot model, 1438 samples got a PSNR of 33.38
Zeroshot model, 2335 samples got a PSNR of 36.36
Zeroshot model, 3792 samples got a PSNR of 37.54
Zeroshot model, 6158 samples got a PSNR of 38.86
Zeroshot model, 10000 samples got a PSNR of 38.43
