In [1]:
%cd ..

import os
import torch
import copy
import numpy as np
from tqdm.notebook import tqdm


from alignment.alignment_utils import load_deep_jscc
from alignment.alignment_model import *
from alignment.alignment_model import _LinearAlignment, _MLPAlignment, _ConvolutionalAlignment, _ZeroShotAlignment, _TwoConvAlignment
from alignment.alignment_training import *
from alignment.alignment_validation import *

/home/lorenzo/repos/Deep-JSCC-PyTorch


In [None]:
# Experiment configuration: SNR, seeds, data loaders, and the pilot-budget grid.
snr = 30
seed = 42  # passed to set_seed(...) before every torch.randperm / validation below
resolution = 32

# The encoder is taken from the seed-42 autoencoder and the decoder from the
# seed-43 one, so the two halves are mismatched unless an aligner is inserted.
model1_fp = f'alignment/models/autoencoders/snr_{snr}_seed_42.pkl'
model2_fp = f'alignment/models/autoencoders/snr_{snr}_seed_43.pkl'
folder = 'psnr_vs_pilots'  # sub-folder holding one aligner checkpoint per budget
os.makedirs(f'alignment/models/plots/{folder}', exist_ok=True)

dataset = "cifar10"
channel = 'AWGN'
batch_size = 64
num_workers = 4

# All result logs of this notebook should go here (the directory is created now).
logs_folder = f'alignment/logs_{resolution}'
os.makedirs(logs_folder, exist_ok=True)

train_snr = snr
val_snr = snr
times = 10  # forwarded to validation_vectorized (presumably #repetitions — confirm)
c = 8       # forwarded to load_deep_jscc; conv aligners use 2*c channels

# Pilot budgets: n_points log-spaced values in [1, max_pilots], truncated to
# int and de-duplicated, so fewer than n_points budgets may remain.
n_points = 20
max_pilots = 10000
pilots_sets = np.unique(np.logspace(0, np.log10(max_pilots), num=n_points, base=10).astype(int))
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

encoder = copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, channel).encoder)
decoder = copy.deepcopy(load_deep_jscc(model2_fp, val_snr, c, channel).decoder)

train_loader, test_loader = get_data_loaders(dataset, resolution, batch_size, num_workers)
data = load_alignment_dataset(model1_fp, model2_fp, train_snr, train_loader, c, device)

Caching inputs: 100%|██████████| 782/782 [00:02<00:00, 286.01it/s]


# No mismatch - Unaligned - Zeroshot max

In [None]:
log_file = f"{logs_folder}/lines_snr_{snr}_seed_{seed}.txt"


def _log_baseline(label, model, path):
    """Validate `model` on the test set, print '<label> <psnr>' and append it to `path`."""
    psnr = validation_vectorized(model, test_loader, times, device)
    result_msg = f"{label} {psnr:.2f}"
    print(result_msg)
    with open(path, 'a') as f:
        f.write(f"{result_msg}\n")


# Mismatched encoder (seed 42) / decoder (seed 43) with no aligner.
model = AlignedDeepJSCC(encoder, decoder, None, val_snr, channel)
_log_baseline("unaligned", model, log_file)

# Encoder paired with its own seed-42 decoder: the no-mismatch reference.
matched_decoder = copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, channel).decoder)
model = AlignedDeepJSCC(encoder, matched_decoder, None, val_snr, channel)
_log_baseline("aligned", model, log_file)

# Zero-shot aligner fitted with resolution**2 pilots.
data.flat = True

set_seed(seed)
permutation = torch.randperm(len(data))

aligner = train_zeroshot_aligner(data, permutation, resolution**2, train_snr, resolution**2, device)
aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
_log_baseline("zeroshot", aligned_model, log_file)

unaligned 7.81


FileNotFoundError: [Errno 2] No such file or directory: 'alignment/logs/lines_snr_30_seed_42.txt'

# Least Squares

In [None]:
data.flat = True  # least-squares fit runs on the flat view of the dataset
aligner_type = "linear"

set_seed(seed)
permutation = torch.randperm(len(data))

# Fit one least-squares aligner per pilot budget and save its weights.
# The same permutation is reused for every budget (presumably so smaller pilot
# sets are prefixes of larger ones — confirm in train_linear_aligner).
for n_samples in tqdm(pilots_sets, desc="Training"):
    aligner = train_linear_aligner(data, permutation, n_samples, train_snr)
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    torch.save(aligner.state_dict(), aligner_fp)

Training: 100%|██████████| 19/19 [00:13<00:00,  1.44it/s]


In [None]:
aligner_type = "linear"
aligner = _LinearAlignment(resolution**2)
# logs_folder is created by the config cell, so this path exists on a fresh run.
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Truncate any log left over from a previous run.
with open(log_file, 'w') as f:
    pass

set_seed(seed)

for n_samples in pilots_sets:
    # One aligner instance is reused; each checkpoint overwrites its weights.
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Linear model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")

Linear model, 1 samples got a PSNR of 12.25
Linear model, 2 samples got a PSNR of 12.00
Linear model, 4 samples got a PSNR of 12.98
Linear model, 6 samples got a PSNR of 13.45
Linear model, 11 samples got a PSNR of 13.67
Linear model, 18 samples got a PSNR of 14.34
Linear model, 29 samples got a PSNR of 14.89
Linear model, 48 samples got a PSNR of 15.53
Linear model, 78 samples got a PSNR of 16.24
Linear model, 127 samples got a PSNR of 17.08
Linear model, 206 samples got a PSNR of 17.90
Linear model, 335 samples got a PSNR of 18.87
Linear model, 545 samples got a PSNR of 19.98
Linear model, 885 samples got a PSNR of 21.29
Linear model, 1438 samples got a PSNR of 22.78
Linear model, 2335 samples got a PSNR of 24.46
Linear model, 3792 samples got a PSNR of 26.35
Linear model, 6158 samples got a PSNR of 28.45
Linear model, 10000 samples got a PSNR of 30.73


# Linear Neural

In [None]:
data.flat = False  # gradient-trained aligners use the non-flat dataset view
aligner_type = "neural"

set_seed(seed)
permutation = torch.randperm(len(data))

# Train one aligner per pilot budget and save its weights. The literal 6 is
# forwarded unchanged to train_neural_aligner (meaning not visible here — confirm).
for n_samples in tqdm(pilots_sets, desc="Training"):
    aligner, _epochs = train_neural_aligner(
        data, permutation, n_samples, batch_size, resolution, 6, train_snr, device
    )
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    torch.save(aligner.state_dict(), aligner_fp)

Training: 100%|██████████| 19/19 [05:36<00:00, 17.69s/it]


In [None]:
aligner_type = "neural"
# The "neural" checkpoints are loaded into the same _LinearAlignment module
# that the least-squares aligner uses.
aligner = _LinearAlignment(resolution**2)
# logs_folder is created by the config cell, so this path exists on a fresh run.
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Truncate any log left over from a previous run.
with open(log_file, 'w') as f:
    pass

set_seed(seed)

for n_samples in pilots_sets:
    # One aligner instance is reused; each checkpoint overwrites its weights.
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Neural model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")

Neural model, 1 samples got a PSNR of 11.19
Neural model, 2 samples got a PSNR of 11.17
Neural model, 4 samples got a PSNR of 13.66
Neural model, 6 samples got a PSNR of 14.37
Neural model, 11 samples got a PSNR of 15.00
Neural model, 18 samples got a PSNR of 15.63
Neural model, 29 samples got a PSNR of 16.43
Neural model, 48 samples got a PSNR of 17.35
Neural model, 78 samples got a PSNR of 17.38
Neural model, 127 samples got a PSNR of 18.32
Neural model, 206 samples got a PSNR of 19.88
Neural model, 335 samples got a PSNR of 21.51
Neural model, 545 samples got a PSNR of 22.97
Neural model, 885 samples got a PSNR of 25.09
Neural model, 1438 samples got a PSNR of 27.95
Neural model, 2335 samples got a PSNR of 30.71


Neural model, 3792 samples got a PSNR of 32.64
Neural model, 6158 samples got a PSNR of 34.04
Neural model, 10000 samples got a PSNR of 34.08


# MLP

In [None]:
data.flat = False
aligner_type = "mlp"

set_seed(seed)
permutation = torch.randperm(len(data))

# Train one MLP aligner per pilot budget and persist its state_dict. The literal
# 6 is forwarded unchanged to train_mlp_aligner (meaning not visible here — confirm).
for n_samples in tqdm(pilots_sets, desc="Training"):
    aligner, _epochs = train_mlp_aligner(
        data, permutation, n_samples, batch_size, resolution, 6, train_snr, device
    )
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    torch.save(aligner.state_dict(), aligner_fp)

Training: 100%|██████████| 19/19 [14:37<00:00, 46.18s/it] 


In [None]:
aligner_type = "mlp"
# Architecture must match the one trained above: a single hidden layer of width resolution**2.
aligner = _MLPAlignment(input_dim=resolution**2, hidden_dims=[resolution**2])
# logs_folder is created by the config cell, so this path exists on a fresh run.
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Truncate any log left over from a previous run.
with open(log_file, 'w') as f:
    pass

set_seed(seed)

for n_samples in pilots_sets:
    # One aligner instance is reused; each checkpoint overwrites its weights.
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"MLP model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")

MLP model, 1 samples got a PSNR of 11.06
MLP model, 2 samples got a PSNR of 10.94
MLP model, 4 samples got a PSNR of 10.85
MLP model, 6 samples got a PSNR of 11.96
MLP model, 11 samples got a PSNR of 14.89
MLP model, 18 samples got a PSNR of 15.47
MLP model, 29 samples got a PSNR of 16.13
MLP model, 48 samples got a PSNR of 15.91
MLP model, 78 samples got a PSNR of 17.03
MLP model, 127 samples got a PSNR of 17.57
MLP model, 206 samples got a PSNR of 20.34
MLP model, 335 samples got a PSNR of 21.69
MLP model, 545 samples got a PSNR of 22.99
MLP model, 885 samples got a PSNR of 24.11
MLP model, 1438 samples got a PSNR of 25.10
MLP model, 2335 samples got a PSNR of 26.12
MLP model, 3792 samples got a PSNR of 26.48
MLP model, 6158 samples got a PSNR of 26.88
MLP model, 10000 samples got a PSNR of 27.01


# Convolutional

In [None]:
data.flat = False
aligner_type = "conv"

set_seed(seed)
permutation = torch.randperm(len(data))

# Train one convolutional aligner per pilot budget and write its weights to disk.
for n_samples in tqdm(pilots_sets, desc="Training"):
    aligner, _epochs = train_conv_aligner(
        data, permutation, n_samples, c, batch_size, train_snr, device
    )
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    torch.save(aligner.state_dict(), aligner_fp)

Training: 100%|██████████| 19/19 [08:23<00:00, 26.48s/it]


In [None]:
aligner_type = "conv"
# Must match the architecture produced by train_conv_aligner (2*c channels in/out, 5x5 kernel).
aligner = _ConvolutionalAlignment(in_channels=2*c, out_channels=2*c, kernel_size=5)
# logs_folder is created by the config cell, so this path exists on a fresh run.
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Truncate any log left over from a previous run.
with open(log_file, 'w') as f:
    pass

set_seed(seed)

for n_samples in pilots_sets:
    # One aligner instance is reused; each checkpoint overwrites its weights.
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Conv model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")

Conv model, 1 samples got a PSNR of 32.57
Conv model, 2 samples got a PSNR of 36.56
Conv model, 4 samples got a PSNR of 37.97
Conv model, 6 samples got a PSNR of 37.88
Conv model, 11 samples got a PSNR of 37.34
Conv model, 18 samples got a PSNR of 38.79
Conv model, 29 samples got a PSNR of 39.44
Conv model, 48 samples got a PSNR of 39.82
Conv model, 78 samples got a PSNR of 40.62
Conv model, 127 samples got a PSNR of 41.10
Conv model, 206 samples got a PSNR of 41.32
Conv model, 335 samples got a PSNR of 41.21
Conv model, 545 samples got a PSNR of 41.43
Conv model, 885 samples got a PSNR of 41.40
Conv model, 1438 samples got a PSNR of 41.41
Conv model, 2335 samples got a PSNR of 41.55
Conv model, 3792 samples got a PSNR of 41.43
Conv model, 6158 samples got a PSNR of 41.42
Conv model, 10000 samples got a PSNR of 41.49


# Two Conv

In [None]:
data.flat = False
aligner_type = "twoconv"

set_seed(seed)
permutation = torch.randperm(len(data))

# Train one two-layer convolutional aligner per pilot budget and save each checkpoint.
for n_samples in tqdm(pilots_sets, desc="Training"):
    aligner, _epochs = train_twoconv_aligner(
        data, permutation, n_samples, c, batch_size, train_snr, device
    )
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    torch.save(aligner.state_dict(), aligner_fp)

Training: 100%|██████████| 19/19 [13:56<00:00, 44.01s/it] 


In [None]:
aligner_type = "twoconv"
# Must match the architecture produced by train_twoconv_aligner.
aligner = _TwoConvAlignment(in_channels=2*c, hidden_channels=2*c, out_channels=2*c, kernel_size=5)
# logs_folder is created by the config cell, so this path exists on a fresh run.
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Truncate any log left over from a previous run.
with open(log_file, 'w') as f:
    pass

set_seed(seed)

for n_samples in pilots_sets:
    # One aligner instance is reused; each checkpoint overwrites its weights.
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Twoconv model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")

Twoconv model, 1 samples got a PSNR of 23.81
Twoconv model, 2 samples got a PSNR of 29.21
Twoconv model, 4 samples got a PSNR of 34.44
Twoconv model, 6 samples got a PSNR of 34.40
Twoconv model, 11 samples got a PSNR of 36.36
Twoconv model, 18 samples got a PSNR of 36.96
Twoconv model, 29 samples got a PSNR of 37.17
Twoconv model, 48 samples got a PSNR of 37.41
Twoconv model, 78 samples got a PSNR of 36.48
Twoconv model, 127 samples got a PSNR of 36.96
Twoconv model, 206 samples got a PSNR of 41.38
Twoconv model, 335 samples got a PSNR of 40.59
Twoconv model, 545 samples got a PSNR of 42.24
Twoconv model, 885 samples got a PSNR of 42.50
Twoconv model, 1438 samples got a PSNR of 41.72
Twoconv model, 2335 samples got a PSNR of 43.79
Twoconv model, 3792 samples got a PSNR of 43.35
Twoconv model, 6158 samples got a PSNR of 43.79
Twoconv model, 10000 samples got a PSNR of 42.99


# Zero-shot

In [None]:
data.flat = True
aligner_type = "zeroshot"

set_seed(seed)
permutation = torch.randperm(len(data))

# Fit one zero-shot aligner per pilot budget, skipping the smallest budget
# (pilots_sets[0]); the budget is passed both as the sample count and as the
# fifth argument of train_zeroshot_aligner.
for n_samples in tqdm(pilots_sets[1:], desc="Training"):
    aligner = train_zeroshot_aligner(data, permutation, n_samples, train_snr, n_samples, device)
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    torch.save(aligner.state_dict(), aligner_fp)

Training: 100%|██████████| 18/18 [01:31<00:00,  5.07s/it]


In [None]:
aligner_type = "zeroshot"
# logs_folder is created by the config cell, so this path exists on a fresh run.
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Truncate any log left over from a previous run.
with open(log_file, 'w') as f:
    pass

set_seed(seed)

# Skip pilots_sets[0], matching the zero-shot training cell, which never saves it.
for n_samples in pilots_sets[1:]:

    # Zero-valued placeholders sized for this budget; load_state_dict below
    # replaces their contents with the saved tensors.
    # NOTE(review): G is a 1x1 placeholder regardless of n_samples — confirm
    # how _ZeroShotAlignment restores it.
    aligner = _ZeroShotAlignment(
        F_tilde=torch.zeros(n_samples, resolution**2),
        G_tilde=torch.zeros(resolution**2, n_samples),
        G=torch.zeros(1, 1),
        L=torch.zeros(n_samples, n_samples),
        mean=torch.zeros(n_samples, 1)
    )

    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Zeroshot model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")

Zeroshot model, 2 samples got a PSNR of 11.12
Zeroshot model, 4 samples got a PSNR of 13.74
Zeroshot model, 6 samples got a PSNR of 14.50
Zeroshot model, 11 samples got a PSNR of 15.11
Zeroshot model, 18 samples got a PSNR of 15.75
Zeroshot model, 29 samples got a PSNR of 16.63
Zeroshot model, 48 samples got a PSNR of 17.68
Zeroshot model, 78 samples got a PSNR of 18.83
Zeroshot model, 127 samples got a PSNR of 20.14


Zeroshot model, 206 samples got a PSNR of 20.15
Zeroshot model, 335 samples got a PSNR of 23.83
Zeroshot model, 545 samples got a PSNR of 26.29
Zeroshot model, 885 samples got a PSNR of 29.66
Zeroshot model, 1438 samples got a PSNR of 33.37
Zeroshot model, 2335 samples got a PSNR of 36.31
Zeroshot model, 3792 samples got a PSNR of 37.55
Zeroshot model, 6158 samples got a PSNR of 38.45
Zeroshot model, 10000 samples got a PSNR of 38.43
