In [1]:
%cd ..

import os
import torch
import copy
import numpy as np
from tqdm.notebook import tqdm


from alignment.alignment_utils import load_deep_jscc
from alignment.alignment_model import *
from alignment.alignment_model import _LinearAlignment, _MLPAlignment, _ConvolutionalAlignment, _ZeroShotAlignment, _TwoConvAlignment
from alignment.alignment_training import *
from alignment.alignment_validation import *

/home/lorenzo/repos/Deep-JSCC-PyTorch


In [2]:
# Encoder is taken from model 1 (seed 42) and decoder from model 2 (seed 43):
# the mismatched pair that the aligners in this notebook are inserted between.
model1_fp = r'alignment/models/autoencoders/snr_30_seed_42.pkl'
model2_fp = r'alignment/models/autoencoders/snr_30_seed_43.pkl'
folder = r'psnr_vs_pilots'
os.makedirs(f'alignment/models/plots/{folder}', exist_ok=True)
# Every evaluation cell writes its results under alignment/logs/; create it up front
# so a fresh checkout does not fail with FileNotFoundError on the first open().
os.makedirs('alignment/logs', exist_ok=True)

dataset = "cifar10"
resolution = 96
channel = 'AWGN'
batch_size = 64
num_workers = 4

snr = 30
seed = 43

train_snr = snr
val_snr = snr
times = 10  # forwarded to validation_vectorized — presumably the number of evaluation repeats; TODO confirm
c = 8

# Pilot-set sizes: n_points log-spaced values from 1 to max_pilots, truncated to int and
# de-duplicated (so fewer than n_points sizes may remain).
n_points = 20
max_pilots = 10000
pilots_sets = np.unique(np.logspace(0, np.log10(max_pilots), num=n_points, base=10).astype(int))
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

encoder = copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, channel).encoder)
decoder = copy.deepcopy(load_deep_jscc(model2_fp, val_snr, c, channel).decoder)

train_loader, test_loader = get_data_loaders(dataset, resolution, batch_size, num_workers)
data = load_alignment_dataset(model1_fp, model2_fp, train_snr, train_loader, c, device)

Caching inputs: 100%|██████████| 782/782 [00:04<00:00, 169.25it/s]


# Baselines: no mismatch (matched decoder), unaligned, and zero-shot with max pilots

In [3]:
log_file = f"alignment/logs/lines_snr_{snr}_seed_{seed}.txt"

# Truncate first so re-running this cell rewrites the reference lines instead of
# appending duplicates (same policy as the per-aligner evaluation cells).
with open(log_file, 'w') as f:
    pass


def log_result(msg, path):
    """Print `msg` and append it as a single line to the text file at `path`."""
    print(msg)
    with open(path, 'a') as f:
        f.write(f"{msg}\n")


# unaligned: model-1 encoder fed straight into the model-2 decoder, no aligner
model = AlignedDeepJSCC(encoder, decoder, None, val_snr, channel)
log_result(f"unaligned {validation_vectorized(model, test_loader, times, device):.2f}", log_file)

# aligned (no mismatch): model-1 encoder with its own model-1 decoder
model = AlignedDeepJSCC(encoder, copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, channel).decoder), None, val_snr, channel)
log_result(f"aligned {validation_vectorized(model, test_loader, times, device):.2f}", log_file)

# zeroshot: aligner built from resolution**2 pilots
data.flat = True

set_seed(seed)
permutation = torch.randperm(len(data))

aligner = train_zeroshot_aligner(data, permutation, resolution**2, train_snr, resolution**2, device)
aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
log_result(f"zeroshot {validation_vectorized(aligned_model, test_loader, times, device):.2f}", log_file)

unaligned 5.39
aligned 48.50
zeroshot 38.35


# Least Squares

In [None]:
# Fit one least-squares aligner per pilot-set size and checkpoint its state_dict.
aligner_type = "linear"
data.flat = True  # presumably makes `data` yield flattened samples for the linear solver

# One seeded permutation is shared by every call below.
set_seed(seed)
permutation = torch.randperm(len(data))

checkpoint_dir = f'alignment/models/plots/{folder}'
for n_samples in tqdm(pilots_sets, desc="Training"):
    aligner = train_linear_aligner(data, permutation, n_samples, train_snr)
    torch.save(aligner.state_dict(), f'{checkpoint_dir}/aligner_{aligner_type}_{n_samples}.pth')

Training:   5%|▌         | 1/19 [00:00<00:08,  2.04it/s]

Done with 1


Training:  11%|█         | 2/19 [00:00<00:08,  2.06it/s]

Done with 2


Training:  16%|█▌        | 3/19 [00:01<00:07,  2.08it/s]

Done with 4


Training:  21%|██        | 4/19 [00:01<00:07,  2.02it/s]

Done with 6


Training:  26%|██▋       | 5/19 [00:02<00:06,  2.02it/s]

Done with 11


Training:  32%|███▏      | 6/19 [00:02<00:06,  2.04it/s]

Done with 18


Training:  37%|███▋      | 7/19 [00:03<00:05,  2.06it/s]

Done with 29


Training:  42%|████▏     | 8/19 [00:03<00:05,  2.05it/s]

Done with 48


Training:  47%|████▋     | 9/19 [00:04<00:04,  2.04it/s]

Done with 78


Training:  53%|█████▎    | 10/19 [00:04<00:04,  2.03it/s]

Done with 127


Training:  58%|█████▊    | 11/19 [00:05<00:03,  2.02it/s]

Done with 206


Training:  63%|██████▎   | 12/19 [00:05<00:03,  2.00it/s]

Done with 335


Training:  68%|██████▊   | 13/19 [00:06<00:03,  1.98it/s]

Done with 545


Training:  74%|███████▎  | 14/19 [00:06<00:02,  1.91it/s]

Done with 885


Training:  79%|███████▉  | 15/19 [00:07<00:02,  1.72it/s]

Done with 1438


Training:  84%|████████▍ | 16/19 [00:08<00:01,  1.58it/s]

Done with 2335


Training:  89%|████████▉ | 17/19 [00:09<00:01,  1.42it/s]

Done with 3792


Training:  95%|█████████▍| 18/19 [00:10<00:00,  1.09it/s]

Done with 6158


Training: 100%|██████████| 19/19 [00:12<00:00,  1.51it/s]

Done with 10000





In [7]:
# Reload every least-squares checkpoint, measure its test PSNR and record it.
aligner_type = "linear"
log_file = f"alignment/logs/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"
aligner = _LinearAlignment(resolution**2)  # one module reused; weights are swapped in per pilot count

open(log_file, 'w').close()  # start this run with an empty log

set_seed(seed)

for n_samples in pilots_sets:
    checkpoint = torch.load(f'alignment/models/plots/{folder}/aligner_linear_{n_samples}.pth', map_location=device)
    aligner.load_state_dict(checkpoint)

    psnr_result = validation_vectorized(
        AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel),
        test_loader, times, device,
    )

    result_msg = f"Linear model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)
    with open(log_file, 'a') as log:
        print(result_msg, file=log)

Linear model, 1 samples got a PSNR of 12.21
Linear model, 2 samples got a PSNR of 11.83
Linear model, 4 samples got a PSNR of 12.51
Linear model, 6 samples got a PSNR of 13.03
Linear model, 11 samples got a PSNR of 13.55
Linear model, 18 samples got a PSNR of 13.95
Linear model, 29 samples got a PSNR of 14.89
Linear model, 48 samples got a PSNR of 15.48
Linear model, 78 samples got a PSNR of 16.20
Linear model, 127 samples got a PSNR of 17.12
Linear model, 206 samples got a PSNR of 18.03
Linear model, 335 samples got a PSNR of 18.95
Linear model, 545 samples got a PSNR of 20.07
Linear model, 885 samples got a PSNR of 21.32
Linear model, 1438 samples got a PSNR of 22.80
Linear model, 2335 samples got a PSNR of 24.44
Linear model, 3792 samples got a PSNR of 26.33
Linear model, 6158 samples got a PSNR of 28.43
Linear model, 10000 samples got a PSNR of 30.71


# Linear Neural

In [None]:
# Train one gradient-based ("neural") aligner per pilot-set size and checkpoint it.
aligner_type = "neural"
data.flat = False  # presumably keeps samples un-flattened for this trainer

# One seeded permutation is shared by every call below.
set_seed(seed)
permutation = torch.randperm(len(data))

checkpoint_dir = f'alignment/models/plots/{folder}'
for n_samples in tqdm(pilots_sets, desc="Training"):
    # Second return value is the number of epochs actually trained.
    # NOTE(review): the literal 6 is a positional trainer argument whose meaning is not visible here.
    aligner, epoch = train_neural_aligner(data, permutation, n_samples, batch_size, resolution, 6, train_snr, device)
    torch.save(aligner.state_dict(), f'{checkpoint_dir}/aligner_{aligner_type}_{n_samples}.pth')

Training:   5%|▌         | 1/19 [00:03<01:06,  3.71s/it]

Done with 1. Trained for 260 epochs.


Training:  11%|█         | 2/19 [00:08<01:16,  4.49s/it]

Done with 2. Trained for 373 epochs.


Training:  16%|█▌        | 3/19 [00:14<01:24,  5.30s/it]

Done with 4. Trained for 458 epochs.


Training:  21%|██        | 4/19 [00:21<01:24,  5.67s/it]

Done with 6. Trained for 474 epochs.


Training:  26%|██▋       | 5/19 [00:33<01:51,  7.98s/it]

Done with 11. Trained for 867 epochs.


Training:  32%|███▏      | 6/19 [00:50<02:22, 10.96s/it]

Done with 18. Trained for 1074 epochs.


Training:  37%|███▋      | 7/19 [01:10<02:48, 14.04s/it]

Done with 29. Trained for 1293 epochs.


Training:  42%|████▏     | 8/19 [01:32<03:02, 16.57s/it]

Done with 48. Trained for 1212 epochs.


Training:  47%|████▋     | 9/19 [01:40<02:19, 13.97s/it]

Done with 78. Trained for 245 epochs.


Training:  53%|█████▎    | 10/19 [01:50<01:55, 12.80s/it]

Done with 127. Trained for 257 epochs.


Training:  58%|█████▊    | 11/19 [02:01<01:37, 12.22s/it]

Done with 206. Trained for 197 epochs.


Training:  63%|██████▎   | 12/19 [02:14<01:25, 12.25s/it]

Done with 335. Trained for 128 epochs.


Training:  68%|██████▊   | 13/19 [02:27<01:14, 12.50s/it]

Done with 545. Trained for 83 epochs.


Training:  74%|███████▎  | 14/19 [02:40<01:03, 12.78s/it]

Done with 885. Trained for 58 epochs.


Training:  79%|███████▉  | 15/19 [02:55<00:53, 13.38s/it]

Done with 1438. Trained for 39 epochs.


Training:  84%|████████▍ | 16/19 [03:17<00:47, 15.89s/it]

Done with 2335. Trained for 35 epochs.


Training:  89%|████████▉ | 17/19 [03:43<00:38, 19.18s/it]

Done with 3792. Trained for 27 epochs.


Training:  95%|█████████▍| 18/19 [04:24<00:25, 25.62s/it]

Done with 6158. Trained for 25 epochs.


Training: 100%|██████████| 19/19 [05:24<00:00, 17.06s/it]

Done with 10000. Trained for 23 epochs.





In [5]:
# Evaluate each saved "neural" aligner (loaded into a _LinearAlignment) and log its test PSNR.
aligner_type = "neural"
aligner = _LinearAlignment(resolution**2)
log_file = f"alignment/logs/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

with open(log_file, 'w') as f:
    pass  # truncate results from any previous run

set_seed(seed)

for n_samples in pilots_sets:

    # Path derived from aligner_type, matching the name the training cell saved under.
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Neural model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")

Neural model, 1 samples got a PSNR of 11.66
Neural model, 2 samples got a PSNR of 12.32
Neural model, 4 samples got a PSNR of 12.60
Neural model, 6 samples got a PSNR of 13.10
Neural model, 11 samples got a PSNR of 14.53
Neural model, 18 samples got a PSNR of 15.41
Neural model, 29 samples got a PSNR of 16.28
Neural model, 48 samples got a PSNR of 17.20
Neural model, 78 samples got a PSNR of 17.14
Neural model, 127 samples got a PSNR of 18.17
Neural model, 206 samples got a PSNR of 19.47
Neural model, 335 samples got a PSNR of 20.94
Neural model, 545 samples got a PSNR of 22.76
Neural model, 885 samples got a PSNR of 25.18
Neural model, 1438 samples got a PSNR of 27.43
Neural model, 2335 samples got a PSNR of 31.19
Neural model, 3792 samples got a PSNR of 31.81
Neural model, 6158 samples got a PSNR of 33.30
Neural model, 10000 samples got a PSNR of 34.02


# MLP

In [None]:
# Train one MLP aligner per pilot-set size and checkpoint it.
aligner_type = "mlp"
data.flat = False  # presumably keeps samples un-flattened for this trainer

# One seeded permutation is shared by every call below.
set_seed(seed)
permutation = torch.randperm(len(data))

checkpoint_dir = f'alignment/models/plots/{folder}'
for n_samples in tqdm(pilots_sets, desc="Training"):
    # Second return value is the number of epochs actually trained.
    # NOTE(review): the literal 6 is a positional trainer argument whose meaning is not visible here.
    aligner, epoch = train_mlp_aligner(data, permutation, n_samples, batch_size, resolution, 6, train_snr, device)
    torch.save(aligner.state_dict(), f'{checkpoint_dir}/aligner_{aligner_type}_{n_samples}.pth')

Training:   5%|▌         | 1/19 [00:05<01:31,  5.07s/it]

Done with 1. Trained for 155 epochs.


Training:  11%|█         | 2/19 [00:06<00:54,  3.18s/it]

Done with 2. Trained for 21 epochs.


Training:  16%|█▌        | 3/19 [00:12<01:10,  4.43s/it]

Done with 4. Trained for 220 epochs.


Training:  21%|██        | 4/19 [00:14<00:51,  3.43s/it]

Done with 6. Trained for 21 epochs.


Training:  26%|██▋       | 5/19 [00:18<00:47,  3.41s/it]

Done with 11. Trained for 76 epochs.


Training:  32%|███▏      | 6/19 [00:21<00:44,  3.46s/it]

Done with 18. Trained for 75 epochs.


Training:  37%|███▋      | 7/19 [01:07<03:27, 17.29s/it]

Done with 29. Trained for 1573 epochs.


Training:  42%|████▏     | 8/19 [01:26<03:15, 17.80s/it]

Done with 48. Trained for 605 epochs.


Training:  47%|████▋     | 9/19 [01:32<02:22, 14.24s/it]

Done with 78. Trained for 89 epochs.


Training:  53%|█████▎    | 10/19 [01:38<01:45, 11.76s/it]

Done with 127. Trained for 75 epochs.


Training:  58%|█████▊    | 11/19 [02:06<02:11, 16.46s/it]

Done with 206. Trained for 285 epochs.


Training:  63%|██████▎   | 12/19 [02:51<02:56, 25.22s/it]

Done with 335. Trained for 308 epochs.


Training:  68%|██████▊   | 13/19 [03:57<03:45, 37.61s/it]

Done with 545. Trained for 273 epochs.


Training:  74%|███████▎  | 14/19 [05:10<04:01, 48.32s/it]

Done with 885. Trained for 186 epochs.


Training:  79%|███████▉  | 15/19 [06:22<03:41, 55.35s/it]

Done with 1438. Trained for 113 epochs.


Training:  84%|████████▍ | 16/19 [07:48<03:14, 64.79s/it]

Done with 2335. Trained for 88 epochs.


Training:  89%|████████▉ | 17/19 [09:19<02:25, 72.62s/it]

Done with 3792. Trained for 56 epochs.


In [None]:
# Evaluate each saved MLP aligner and log its test PSNR.
aligner_type = "mlp"
# Must match the architecture train_mlp_aligner produced, or load_state_dict will fail.
aligner = _MLPAlignment(input_dim=resolution**2, hidden_dims=[resolution**2])
log_file = f"alignment/logs/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

with open(log_file, 'w') as f:
    pass  # truncate results from any previous run

set_seed(seed)

for n_samples in pilots_sets:

    # Path derived from aligner_type, matching the name the training cell saved under.
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"MLP model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")

MLP model, 1 samples got a PSNR of 9.16
MLP model, 2 samples got a PSNR of 10.12
MLP model, 4 samples got a PSNR of 10.90
MLP model, 6 samples got a PSNR of 10.84
MLP model, 11 samples got a PSNR of 12.84
MLP model, 18 samples got a PSNR of 11.43
MLP model, 29 samples got a PSNR of 11.86
MLP model, 48 samples got a PSNR of 10.14
MLP model, 78 samples got a PSNR of 16.26
MLP model, 127 samples got a PSNR of 16.63
MLP model, 206 samples got a PSNR of 17.96
MLP model, 335 samples got a PSNR of 21.63
MLP model, 545 samples got a PSNR of 18.68
MLP model, 885 samples got a PSNR of 24.24
MLP model, 1438 samples got a PSNR of 25.15
MLP model, 2335 samples got a PSNR of 26.25
MLP model, 3792 samples got a PSNR of 26.39
MLP model, 6158 samples got a PSNR of 26.81
MLP model, 10000 samples got a PSNR of 27.13


# Convolutional

In [None]:
# Train one convolutional aligner per pilot-set size and checkpoint it.
aligner_type = "conv"
data.flat = False  # presumably keeps samples un-flattened for the conv trainer

# One seeded permutation is shared by every call below.
set_seed(seed)
permutation = torch.randperm(len(data))

checkpoint_dir = f'alignment/models/plots/{folder}'
for n_samples in tqdm(pilots_sets, desc="Training"):
    # Second return value is the number of epochs actually trained.
    aligner, epoch = train_conv_aligner(data, permutation, n_samples, c, batch_size, train_snr, device)
    torch.save(aligner.state_dict(), f'{checkpoint_dir}/aligner_{aligner_type}_{n_samples}.pth')

Done with 1. Trained for 589 epochs.
Done with 2. Trained for 514 epochs.
Done with 4. Trained for 499 epochs.
Done with 6. Trained for 764 epochs.
Done with 11. Trained for 513 epochs.
Done with 18. Trained for 511 epochs.
Done with 29. Trained for 508 epochs.
Done with 48. Trained for 537 epochs.
Done with 78. Trained for 474 epochs.
Done with 127. Trained for 661 epochs.
Done with 206. Trained for 631 epochs.
Done with 335. Trained for 333 epochs.
Done with 545. Trained for 334 epochs.
Done with 885. Trained for 225 epochs.
Done with 1438. Trained for 162 epochs.
Done with 2335. Trained for 179 epochs.
Done with 3792. Trained for 109 epochs.
Done with 6158. Trained for 101 epochs.
Done with 10000. Trained for 60 epochs.


In [None]:
# Evaluate each saved convolutional aligner and log its test PSNR.
aligner_type = "conv"
# Must match the architecture train_conv_aligner produced, or load_state_dict will fail.
aligner = _ConvolutionalAlignment(in_channels=2*c, out_channels=2*c, kernel_size=5)
log_file = f"alignment/logs/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

with open(log_file, 'w') as f:
    pass  # truncate results from any previous run

set_seed(seed)

for n_samples in pilots_sets:

    # Path derived from aligner_type, matching the name the training cell saved under.
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Conv model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")

Conv model, 1 samples got a PSNR of 22.47
Conv model, 2 samples got a PSNR of 32.35
Conv model, 4 samples got a PSNR of 36.16
Conv model, 6 samples got a PSNR of 38.26
Conv model, 11 samples got a PSNR of 37.65
Conv model, 18 samples got a PSNR of 38.36
Conv model, 29 samples got a PSNR of 37.95
Conv model, 48 samples got a PSNR of 39.04
Conv model, 78 samples got a PSNR of 39.81
Conv model, 127 samples got a PSNR of 40.63
Conv model, 206 samples got a PSNR of 41.07
Conv model, 335 samples got a PSNR of 41.00
Conv model, 545 samples got a PSNR of 41.32
Conv model, 885 samples got a PSNR of 41.47
Conv model, 1438 samples got a PSNR of 41.29
Conv model, 2335 samples got a PSNR of 41.57
Conv model, 3792 samples got a PSNR of 41.27
Conv model, 6158 samples got a PSNR of 41.51
Conv model, 10000 samples got a PSNR of 41.49


# Two Conv

In [None]:
# Train one two-layer convolutional aligner per pilot-set size and checkpoint it.
aligner_type = "twoconv"
data.flat = False  # presumably keeps samples un-flattened for the conv trainer

# One seeded permutation is shared by every call below.
set_seed(seed)
permutation = torch.randperm(len(data))

checkpoint_dir = f'alignment/models/plots/{folder}'
for n_samples in tqdm(pilots_sets, desc="Training"):
    # Second return value is the number of epochs actually trained.
    aligner, epoch = train_twoconv_aligner(data, permutation, n_samples, c, batch_size, train_snr, device)
    torch.save(aligner.state_dict(), f'{checkpoint_dir}/aligner_{aligner_type}_{n_samples}.pth')

Done with 1. Trained for 667 epochs.
Done with 2. Trained for 625 epochs.
Done with 4. Trained for 713 epochs.
Done with 6. Trained for 500 epochs.
Done with 11. Trained for 430 epochs.
Done with 18. Trained for 419 epochs.
Done with 29. Trained for 427 epochs.
Done with 48. Trained for 534 epochs.
Done with 78. Trained for 168 epochs.
Done with 127. Trained for 153 epochs.
Done with 206. Trained for 416 epochs.
Done with 335. Trained for 162 epochs.
Done with 545. Trained for 187 epochs.
Done with 885. Trained for 101 epochs.
Done with 1438. Trained for 160 epochs.
Done with 2335. Trained for 235 epochs.
Done with 3792. Trained for 138 epochs.
Done with 6158. Trained for 129 epochs.
Done with 10000. Trained for 164 epochs.


In [None]:
# Evaluate each saved two-conv aligner and log its test PSNR.
aligner_type = "twoconv"
# Must match the architecture train_twoconv_aligner produced, or load_state_dict will fail.
aligner = _TwoConvAlignment(in_channels=2*c, hidden_channels=2*c, out_channels=2*c, kernel_size=5)
log_file = f"alignment/logs/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

with open(log_file, 'w') as f:
    pass  # truncate results from any previous run

set_seed(seed)

for n_samples in pilots_sets:

    # Path derived from aligner_type, matching the name the training cell saved under.
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Twoconv model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")

Twoconv model, 1 samples got a PSNR of 25.85
Twoconv model, 2 samples got a PSNR of 27.45
Twoconv model, 4 samples got a PSNR of 32.31
Twoconv model, 6 samples got a PSNR of 33.19
Twoconv model, 11 samples got a PSNR of 35.35
Twoconv model, 18 samples got a PSNR of 35.64
Twoconv model, 29 samples got a PSNR of 36.32
Twoconv model, 48 samples got a PSNR of 36.36
Twoconv model, 78 samples got a PSNR of 34.07
Twoconv model, 127 samples got a PSNR of 33.25
Twoconv model, 206 samples got a PSNR of 40.30
Twoconv model, 335 samples got a PSNR of 39.25
Twoconv model, 545 samples got a PSNR of 40.32
Twoconv model, 885 samples got a PSNR of 40.26
Twoconv model, 1438 samples got a PSNR of 42.06
Twoconv model, 2335 samples got a PSNR of 43.75
Twoconv model, 3792 samples got a PSNR of 42.08
Twoconv model, 6158 samples got a PSNR of 42.85
Twoconv model, 10000 samples got a PSNR of 44.48


# Zero-shot

In [None]:
# Build one zero-shot aligner per pilot-set size and checkpoint it.
aligner_type = "zeroshot"
data.flat = True  # presumably makes `data` yield flattened samples for the zero-shot construction

# One seeded permutation is shared by every call below.
set_seed(seed)
permutation = torch.randperm(len(data))

checkpoint_dir = f'alignment/models/plots/{folder}'
# The smallest pilot-set size (pilots_sets[0]) is skipped here and in the evaluation cell.
for n_samples in tqdm(pilots_sets[1:], desc="Training"):
    aligner = train_zeroshot_aligner(data, permutation, n_samples, train_snr, n_samples, device)
    torch.save(aligner.state_dict(), f'{checkpoint_dir}/aligner_{aligner_type}_{n_samples}.pth')

Done with 2.
Done with 4.
Done with 6.
Done with 11.
Done with 18.
Done with 29.
Done with 48.
Done with 78.
Done with 127.
Done with 206.
Done with 335.
Done with 545.
Done with 885.
Done with 1438.
Done with 2335.
Done with 3792.
Done with 6158.
Done with 10000.


In [None]:
# Evaluate each saved zero-shot aligner and log its test PSNR.
# aligner_type must be "zeroshot": the log path below is derived from it, and a wrong
# label (e.g. "twoconv") would truncate and overwrite another aligner's results log.
aligner_type = "zeroshot"
log_file = f"alignment/logs/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

with open(log_file, 'w') as f:
    pass  # truncate results from any previous run

set_seed(seed)

# pilots_sets[0] is skipped, matching the training cell (no checkpoint was saved for it).
for n_samples in pilots_sets[1:]:

    # Zero-filled placeholders only provide tensors of the right shapes;
    # load_state_dict below copies the saved values into them.
    aligner = _ZeroShotAlignment(
        F_tilde=torch.zeros(n_samples, resolution**2),
        G_tilde=torch.zeros(resolution**2, n_samples),
        G=torch.zeros(1, 1),
        L=torch.zeros(n_samples, n_samples),
        mean=torch.zeros(n_samples, 1)
    )

    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Zeroshot model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")

Zeroshot model, 2 samples got a PSNR of 11.35
Zeroshot model, 4 samples got a PSNR of 13.35
Zeroshot model, 6 samples got a PSNR of 13.75
Zeroshot model, 11 samples got a PSNR of 14.65
Zeroshot model, 18 samples got a PSNR of 15.47
Zeroshot model, 29 samples got a PSNR of 16.40
Zeroshot model, 48 samples got a PSNR of 17.64
Zeroshot model, 78 samples got a PSNR of 18.57
Zeroshot model, 127 samples got a PSNR of 20.34
Zeroshot model, 206 samples got a PSNR of 21.87
Zeroshot model, 335 samples got a PSNR of 23.04
Zeroshot model, 545 samples got a PSNR of 26.47
Zeroshot model, 885 samples got a PSNR of 29.72
Zeroshot model, 1438 samples got a PSNR of 33.36
Zeroshot model, 2335 samples got a PSNR of 36.35
Zeroshot model, 3792 samples got a PSNR of 37.58
Zeroshot model, 6158 samples got a PSNR of 38.89
Zeroshot model, 10000 samples got a PSNR of 38.43
