In [None]:
%cd ..

import os
import torch
import copy
import numpy as np
from tqdm.notebook import tqdm


from alignment.alignment_utils import load_deep_jscc
from alignment.alignment_model import *
from alignment.alignment_model import _LinearAlignment, _MLPAlignment, _ConvolutionalAlignment, _ZeroShotAlignment, _TwoConvAlignment
from alignment.alignment_training import *
from alignment.alignment_validation import *

In [None]:
# --- Experiment configuration ---------------------------------------------
snr = 30
seed = 42
resolution = 96

# Pre-trained autoencoder checkpoints. The encoder is taken from model 1 and
# the decoder from model 2 (see `encoder` / `decoder` below).
model1_fp = f'alignment/models/autoencoders/snr_{snr}_seed_42.pkl'
model2_fp = f'alignment/models/autoencoders/snr_{snr}_seed_43.pkl'

# Sub-folder of alignment/models/plots where the trained aligners are saved.
folder = 'psnr_vs_pilots'
os.makedirs(f'alignment/models/plots/{folder}', exist_ok=True)

dataset = "cifar10"
channel = 'AWGN'
batch_size = 64
num_workers = 4

logs_folder = f'alignment/logs_{resolution}'
os.makedirs(logs_folder, exist_ok=True)

train_snr = snr
val_snr = snr
times = 10  # forwarded to validation_vectorized by the evaluation cells
c = 8       # forwarded to load_deep_jscc / load_alignment_dataset

# Log-spaced sample counts between 1 and 10000. astype(int) truncates and
# np.unique drops the duplicates this creates, so the array may hold fewer
# than n_points values.
n_points = 20
pilots_sets = np.unique(np.logspace(0, np.log10(10000), num=n_points, base=10).astype(int))
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

# Use the configured `channel` rather than repeating the literal, so that
# changing the channel above affects every model built in this notebook.
encoder = copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, channel).encoder)
decoder = copy.deepcopy(load_deep_jscc(model2_fp, val_snr, c, channel).decoder)

train_loader, test_loader = get_data_loaders(dataset, resolution, batch_size, num_workers)
data = load_alignment_dataset(model1_fp, model2_fp, train_snr, train_loader, c, device)

# Baselines: unaligned, no mismatch (aligned), and zero-shot with the maximum number of pilots

In [None]:
# Reference PSNR values: three baselines written to a single log file.
log_file = f"{logs_folder}/lines_snr_{snr}_seed_{seed}.txt"

# Start from an empty log so that re-running this cell does not append
# duplicate lines (the other evaluation cells truncate their logs as well).
with open(log_file, 'w') as f:
    pass

# unaligned: encoder and decoder come from different checkpoints, no aligner
model = AlignedDeepJSCC(encoder, decoder, None, val_snr, channel)

result_msg = f"unaligned {validation_vectorized(model, test_loader, times, device):.2f}"
print(result_msg)
with open(log_file, 'a') as f:
    f.write(f"{result_msg}\n")

# aligned: the decoder is loaded from model1_fp, the same checkpoint the
# encoder was taken from, so no aligner is needed
model = AlignedDeepJSCC(encoder, copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, channel).decoder), None, val_snr, channel)

result_msg = f"aligned {validation_vectorized(model, test_loader, times, device):.2f}"
print(result_msg)
with open(log_file, 'a') as f:
    f.write(f"{result_msg}\n")

# zeroshot: trained once with resolution**2 as both sample-count arguments
data.flat = True

set_seed(seed)
permutation = torch.randperm(len(data))

aligner = train_zeroshot_aligner(data, permutation, resolution**2, train_snr, resolution**2, device)
aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

result_msg = f"zeroshot {validation_vectorized(aligned_model, test_loader, times, device):.2f}"
print(result_msg)
with open(log_file, 'a') as f:
    f.write(f"{result_msg}\n")

# Least Squares

In [None]:
# Train one least-squares aligner per pilot count and save its weights.
aligner_type = "linear"
data.flat = True

set_seed(seed)
permutation = torch.randperm(len(data))

checkpoint_dir = f'alignment/models/plots/{folder}'

for n_pilots in tqdm(pilots_sets, desc="Training"):
    # Destination file is keyed by aligner type and number of samples.
    aligner_fp = f'{checkpoint_dir}/aligner_{aligner_type}_{n_pilots}.pth'

    aligner = train_linear_aligner(data, permutation, n_pilots, train_snr)
    torch.save(aligner.state_dict(), aligner_fp)

In [None]:
# Evaluate every saved least-squares aligner and log its PSNR.
aligner_type = "linear"
aligner = _LinearAlignment(resolution**2)
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Create the log file, or empty it if it already exists.
open(log_file, 'w').close()

set_seed(seed)

checkpoint_dir = f'alignment/models/plots/{folder}'

for n_pilots in pilots_sets:
    # Reuse the same module and swap in the weights saved for this pilot count.
    aligner_fp = f'{checkpoint_dir}/aligner_{aligner_type}_{n_pilots}.pth'
    weights = torch.load(aligner_fp, map_location=device)
    aligner.load_state_dict(weights)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Linear model, {n_pilots} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    # Append immediately so partial results survive an interrupted run.
    with open(log_file, 'a') as log:
        log.write(f"{result_msg}\n")

# Linear Neural

In [None]:
# Train one "neural" aligner per pilot count and save its weights.
aligner_type = "neural"
data.flat = False

set_seed(seed)
permutation = torch.randperm(len(data))

checkpoint_dir = f'alignment/models/plots/{folder}'

for n_pilots in tqdm(pilots_sets, desc="Training"):
    aligner_fp = f'{checkpoint_dir}/aligner_{aligner_type}_{n_pilots}.pth'

    # The trainer also returns an epoch count, which is not used here.
    aligner, n_epochs = train_neural_aligner(
        data, permutation, n_pilots, batch_size, resolution, 6, train_snr, device
    )
    torch.save(aligner.state_dict(), aligner_fp)

In [None]:
# Evaluate every saved "neural" aligner and log its PSNR.
aligner_type = "neural"
# The "neural" checkpoints are loaded into a _LinearAlignment module.
aligner = _LinearAlignment(resolution**2)
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Truncate any log left over from a previous run; results are appended below.
with open(log_file, 'w') as f:
    pass

set_seed(seed)

for n_samples in pilots_sets:

    # Checkpoint keyed by aligner type and number of samples.
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    # Use the configured `channel` (as the linear evaluation cell does)
    # instead of repeating the "AWGN" literal.
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Neural model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")

# MLP

In [None]:
# Train one MLP aligner per pilot count and save its weights.
aligner_type = "mlp"
data.flat = False

set_seed(seed)
permutation = torch.randperm(len(data))

checkpoint_dir = f'alignment/models/plots/{folder}'

for n_pilots in tqdm(pilots_sets, desc="Training"):
    aligner_fp = f'{checkpoint_dir}/aligner_{aligner_type}_{n_pilots}.pth'

    # The trainer also returns an epoch count, which is not used here.
    aligner, n_epochs = train_mlp_aligner(
        data, permutation, n_pilots, batch_size, resolution, 6, train_snr, device
    )
    torch.save(aligner.state_dict(), aligner_fp)

In [None]:
# Evaluate every saved MLP aligner and log its PSNR.
aligner_type = "mlp"
aligner = _MLPAlignment(input_dim=resolution**2, hidden_dims=[resolution**2])
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Truncate any log left over from a previous run; results are appended below.
with open(log_file, 'w') as f:
    pass

set_seed(seed)

for n_samples in pilots_sets:

    # Checkpoint keyed by aligner type and number of samples.
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    # Use the configured `channel` (as the linear evaluation cell does)
    # instead of repeating the "AWGN" literal.
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"MLP model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")

# Convolutional

In [None]:
# Train one convolutional aligner per pilot count and save its weights.
aligner_type = "conv"
data.flat = False

set_seed(seed)
permutation = torch.randperm(len(data))

checkpoint_dir = f'alignment/models/plots/{folder}'

for n_pilots in tqdm(pilots_sets, desc="Training"):
    aligner_fp = f'{checkpoint_dir}/aligner_{aligner_type}_{n_pilots}.pth'

    # The trainer also returns an epoch count, which is not used here.
    aligner, n_epochs = train_conv_aligner(
        data, permutation, n_pilots, c, batch_size, train_snr, device
    )
    torch.save(aligner.state_dict(), aligner_fp)

In [None]:
# Evaluate every saved convolutional aligner and log its PSNR.
aligner_type = "conv"
aligner = _ConvolutionalAlignment(in_channels=2*c, out_channels=2*c, kernel_size=5)
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Truncate any log left over from a previous run; results are appended below.
with open(log_file, 'w') as f:
    pass

set_seed(seed)

for n_samples in pilots_sets:

    # Checkpoint keyed by aligner type and number of samples.
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    # Use the configured `channel` (as the linear evaluation cell does)
    # instead of repeating the "AWGN" literal.
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Conv model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")

# Two Conv

In [None]:
# Train one two-convolution aligner per pilot count and save its weights.
aligner_type = "twoconv"
data.flat = False

set_seed(seed)
permutation = torch.randperm(len(data))

checkpoint_dir = f'alignment/models/plots/{folder}'

for n_pilots in tqdm(pilots_sets, desc="Training"):
    aligner_fp = f'{checkpoint_dir}/aligner_{aligner_type}_{n_pilots}.pth'

    # The trainer also returns an epoch count, which is not used here.
    aligner, n_epochs = train_twoconv_aligner(
        data, permutation, n_pilots, c, batch_size, train_snr, device
    )
    torch.save(aligner.state_dict(), aligner_fp)

In [None]:
# Evaluate every saved two-convolution aligner and log its PSNR.
aligner_type = "twoconv"
aligner = _TwoConvAlignment(in_channels=2*c, hidden_channels=2*c, out_channels=2*c, kernel_size=5)
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Truncate any log left over from a previous run; results are appended below.
with open(log_file, 'w') as f:
    pass

set_seed(seed)

for n_samples in pilots_sets:

    # Checkpoint keyed by aligner type and number of samples.
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    # Use the configured `channel` (as the linear evaluation cell does)
    # instead of repeating the "AWGN" literal.
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Twoconv model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")

# Zero-shot

In [None]:
# Train one zero-shot aligner per pilot count (skipping the first entry of
# pilots_sets) and save its weights.
aligner_type = "zeroshot"
data.flat = True

set_seed(seed)
permutation = torch.randperm(len(data))

checkpoint_dir = f'alignment/models/plots/{folder}'

for n_pilots in tqdm(pilots_sets[1:], desc="Training"):
    aligner_fp = f'{checkpoint_dir}/aligner_{aligner_type}_{n_pilots}.pth'

    # The pilot count is passed for both sample-count arguments.
    aligner = train_zeroshot_aligner(data, permutation, n_pilots, train_snr, n_pilots, device)
    torch.save(aligner.state_dict(), aligner_fp)

In [None]:
# Evaluate every saved zero-shot aligner and log its PSNR.
aligner_type = "zeroshot"
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Truncate any log left over from a previous run; results are appended below.
with open(log_file, 'w') as f:
    pass

set_seed(seed)

# The first entry of pilots_sets is skipped, matching the training cell.
for n_samples in pilots_sets[1:]:

    # The tensor sizes depend on n_samples, so a fresh module is built on
    # every iteration. The zeros are placeholders; presumably load_state_dict
    # overwrites them with the saved values -- TODO confirm against
    # _ZeroShotAlignment.
    aligner = _ZeroShotAlignment(
        F_tilde=torch.zeros(n_samples, resolution**2),
        G_tilde=torch.zeros(resolution**2, n_samples),
        G=torch.zeros(1, 1),
        L=torch.zeros(n_samples, n_samples),
        mean=torch.zeros(n_samples, 1)
    )

    # Checkpoint keyed by aligner type and number of samples.
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    # Use the configured `channel` (as the linear evaluation cell does)
    # instead of repeating the "AWGN" literal.
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Zeroshot model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")