In [None]:
%cd ..

import os
import torch
import copy
import numpy as np
from tqdm.notebook import tqdm


from alignment.alignment_utils import load_deep_jscc
from alignment.alignment_model import *
from alignment.alignment_model import _LinearAlignment, _MLPAlignment, _ConvolutionalAlignment, _ZeroShotAlignment, _TwoConvAlignment
from alignment.alignment_training import *
from alignment.alignment_validation import *

In [None]:
# --- Experiment configuration ---------------------------------------------------
snr_ae = 20          # SNR tag in the autoencoder checkpoint names; also passed to the loaders
n_samples = 10000    # number of samples passed to the aligner trainers
resolution = 96      # passed to get_data_loaders; resolution**2 is the flat aligner size used below
channel = 'Rayleigh'

snrs = [-20, -10, 0, 10, 20, 30]  # channel SNRs swept by every experiment below
seeds = [42]

# Rayleigh checkpoints and logs carry a "rayleigh_" prefix; other channels have none.
channel_prefix = "rayleigh_" if channel == "Rayleigh" else ""

# The encoder is taken from model 1 (seed 42) and the decoder from model 2 (seed 43);
# presumably their latent spaces differ, which is the mismatch the aligners correct.
model1_fp = f'alignment/models/autoencoders/{channel_prefix}snr_{snr_ae}_seed_42.pkl'
model2_fp = f'alignment/models/autoencoders/{channel_prefix}snr_{snr_ae}_seed_43.pkl'

# Trained aligner checkpoints are written under alignment/models/plots/<folder>/.
# NOTE(review): this path does not include the channel, so runs for different
# channels overwrite each other's aligner checkpoints — confirm this is intended.
folder = 'psnr_vs_snr'
os.makedirs(f'alignment/models/plots/{folder}', exist_ok=True)

dataset = "cifar10"
batch_size = 64
num_workers = 4

# Follow the same channel-prefix convention as the checkpoints instead of
# hard-coding "rayleigh", so changing `channel` does not write into the Rayleigh logs.
logs_folder = f'alignment/logs_{channel_prefix}{resolution}'
os.makedirs(logs_folder, exist_ok=True)

times = 10  # forwarded to validation_vectorized (presumably number of evaluation repeats — TODO confirm)
c = 8       # forwarded to the model loaders; the conv aligners below use 2*c channels

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

train_loader, test_loader = get_data_loaders(dataset, resolution, batch_size, num_workers)
# NOTE(review): unlike load_deep_jscc below, `channel` is not passed here — confirm
# the alignment dataset is built for the intended channel.
data = load_alignment_dataset(model1_fp, model2_fp, snr_ae, train_loader, c, device)

encoder = copy.deepcopy(load_deep_jscc(model1_fp, snr_ae, c, channel).encoder)
decoder = copy.deepcopy(load_deep_jscc(model2_fp, snr_ae, c, channel).decoder)

# Baselines: unaligned, no mismatch (matched encoder/decoder), and zero-shot (max samples)

In [None]:
def _log_result(log_file, result_msg):
    """Print `result_msg` and append it as a single line to `log_file`."""
    print(result_msg)
    with open(log_file, 'a') as f:
        f.write(f"{result_msg}\n")


for snr in snrs:
    for seed in seeds:

        log_file = f"{logs_folder}/lines_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt"

        # Start from an empty file so re-running the cell does not duplicate lines
        # (matches the per-aligner evaluation cells below).
        with open(log_file, 'w'):
            pass

        # Unaligned: model-1 encoder feeding the model-2 decoder with no aligner.
        model = AlignedDeepJSCC(encoder, decoder, None, snr, channel)
        _log_result(log_file, f"unaligned {validation_vectorized(model, test_loader, times, device):.2f}")

        # No mismatch ("aligned" in the log): encoder and decoder both from model 1.
        # Loaded with the sweep `snr` rather than `snr_ae` (as in the original code).
        matched_decoder = copy.deepcopy(load_deep_jscc(model1_fp, snr, c, channel).decoder)
        model = AlignedDeepJSCC(encoder, matched_decoder, None, snr, channel)
        _log_result(log_file, f"aligned {validation_vectorized(model, test_loader, times, device):.2f}")

        # Zero-shot aligner built with resolution**2 samples, on flattened features.
        data.flat = True

        set_seed(seed)
        permutation = torch.randperm(len(data))

        aligner = train_zeroshot_aligner(data, permutation, resolution**2, snr, resolution**2, channel, device)
        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, channel)
        _log_result(log_file, f"zeroshot {validation_vectorized(aligned_model, test_loader, times, device):.2f}")

# Least Squares

In [None]:
# Least-squares linear aligner: fit one per (SNR, seed) on flattened features and save it.
aligner_type = "linear"
data.flat = True

seeds = [42]

for snr, seed in [(s, sd) for s in snrs for sd in seeds]:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'

    # Re-seed before drawing the sample permutation so each seed is reproducible.
    set_seed(seed)
    permutation = torch.randperm(len(data))

    aligner = train_linear_aligner(data, permutation, n_samples, snr, channel)
    torch.save(aligner.state_dict(), aligner_fp)

    print(f"Done with SNR {snr} SEED {seed}.")

In [None]:
# Evaluate the saved least-squares aligners: reload each (SNR, seed) checkpoint into
# a single reusable module and record the resulting PSNR.
aligner_type = "linear"
aligner = _LinearAlignment(resolution**2)

for snr in snrs:
    for seed in seeds:
        suffix = f"{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}"
        log_file = f"{logs_folder}/aligner_{suffix}.txt"
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{suffix}.pth'

        # Truncate any log left over from a previous run.
        with open(log_file, 'w'):
            pass

        set_seed(seed)
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, channel)
        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        result_msg = f"Linear model, AE {snr_ae} SNR {snr} SEED {seed} got a PSNR of {psnr_result:.2f}"
        print(result_msg)
        with open(log_file, 'a') as log:
            log.write(result_msg + "\n")

# Linear Neural

In [None]:
# Train the "neural" aligner for every (SNR, seed) and save its weights.
aligner_type = "neural"
data.flat = False

for snr in snrs:
    for seed in seeds:
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'

        # Re-seed before drawing the sample permutation so each seed is reproducible.
        set_seed(seed)
        permutation = torch.randperm(len(data))

        # NOTE(review): `channel` is not passed to this trainer (the linear and conv
        # trainers receive it), so training may use the trainer's default channel
        # while evaluation uses `channel` — confirm. The positional 6 is a trainer
        # argument whose meaning is not visible here (TODO: give it a named constant).
        aligner, epoch = train_neural_aligner(
            data, permutation, n_samples, batch_size, resolution, 6, snr, device,
        )
        torch.save(aligner.state_dict(), aligner_fp)

        print(f"Done with SNR {snr} SEED {seed}.")

In [None]:
# Evaluate the saved "neural" aligners; their checkpoints load into a _LinearAlignment.
aligner_type = "neural"
aligner = _LinearAlignment(resolution**2)

for snr, seed in [(s, sd) for s in snrs for sd in seeds]:
    tag = f"{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}"
    log_file = f"{logs_folder}/aligner_{tag}.txt"

    # Start each (SNR, seed) run with an empty log file.
    with open(log_file, 'w'):
        pass

    set_seed(seed)

    aligner_fp = f'alignment/models/plots/{folder}/aligner_{tag}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, channel)
    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    # NOTE(review): this label is identical to the least-squares cell's; only the
    # log file name distinguishes the two.
    result_msg = f"Linear model, AE {snr_ae} SNR {snr} SEED {seed} got a PSNR of {psnr_result:.2f}"
    print(result_msg)
    with open(log_file, 'a') as log:
        log.write(result_msg + "\n")

# MLP

In [None]:
# Train the MLP aligner for every (SNR, seed) and save its weights.
aligner_type = "mlp"
data.flat = False

for snr, seed in [(s, sd) for s in snrs for sd in seeds]:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'

    # Re-seed before drawing the sample permutation so each seed is reproducible.
    set_seed(seed)
    permutation = torch.randperm(len(data))

    # NOTE(review): as with the neural trainer, `channel` is not passed here and the
    # positional 6 is unexplained — confirm both against train_mlp_aligner.
    aligner, epoch = train_mlp_aligner(
        data, permutation, n_samples, batch_size, resolution, 6, snr, device,
    )
    torch.save(aligner.state_dict(), aligner_fp)

    print(f"Done with SNR {snr} SEED {seed}.")

In [None]:
# Evaluate the saved MLP aligners: reload each (SNR, seed) checkpoint into a single
# reusable module and record the resulting PSNR.
aligner_type = "mlp"
aligner = _MLPAlignment(input_dim=resolution**2, hidden_dims=[resolution**2])

for snr in snrs:
    for seed in seeds:

        log_file = f"{logs_folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt"

        # Truncate the log so re-running the cell does not accumulate stale results.
        with open(log_file, 'w') as f:
            pass

        set_seed(seed)

        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, channel)

        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        # Label the result with the aligner actually evaluated (not "Linear model").
        result_msg = f"MLP model, AE {snr_ae} SNR {snr} SEED {seed} got a PSNR of {psnr_result:.2f}"
        print(result_msg)

        with open(log_file, 'a') as f:
            f.write(f"{result_msg}\n")

# Convolutional

In [None]:
# Train the convolutional aligner for every (SNR, seed) and save its weights.
aligner_type = "conv"
data.flat = False

for snr in snrs:
    for seed in seeds:
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'

        # Re-seed before drawing the sample permutation so each seed is reproducible.
        set_seed(seed)
        permutation = torch.randperm(len(data))

        aligner, epoch = train_conv_aligner(
            data, permutation, n_samples, c, batch_size, snr, channel, device,
        )
        torch.save(aligner.state_dict(), aligner_fp)

        print(f"Done with SNR {snr} SEED {seed}.")

In [None]:
# Evaluate the saved convolutional aligners (2*c channels in and out, 5x5 kernel).
aligner_type = "conv"
aligner = _ConvolutionalAlignment(in_channels=2*c, out_channels=2*c, kernel_size=5)

for snr, seed in [(s, sd) for s in snrs for sd in seeds]:
    tag = f"{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}"
    log_file = f"{logs_folder}/aligner_{tag}.txt"

    # Start each (SNR, seed) run with an empty log file.
    with open(log_file, 'w'):
        pass

    set_seed(seed)

    aligner_fp = f'alignment/models/plots/{folder}/aligner_{tag}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, channel)
    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Conv model, AE {snr_ae} SNR {snr} SEED {seed} got a PSNR of {psnr_result:.2f}"
    print(result_msg)
    with open(log_file, 'a') as log:
        log.write(result_msg + "\n")

# Two-Layer Convolutional

In [None]:
# Train the two-layer convolutional aligner for every (SNR, seed) and save it.
aligner_type = "twoconv"
data.flat = False

# NOTE(review): this rebinds the notebook-global `seeds`, so every later cell —
# the Two Conv evaluation AND both Zero-shot cells — runs with seed 47 instead of
# the configured 42. Confirm that is intended, or switch to a cell-local seed list.
seeds = [47]

for snr, seed in [(s, sd) for s in snrs for sd in seeds]:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'

    # Re-seed before drawing the sample permutation so each seed is reproducible.
    set_seed(seed)
    permutation = torch.randperm(len(data))

    aligner, epoch = train_twoconv_aligner(
        data, permutation, n_samples, c, batch_size, snr, channel, device,
    )
    torch.save(aligner.state_dict(), aligner_fp)

    print(f"Done with SNR {snr} SEED {seed}.")

In [None]:
# Evaluate the saved two-layer convolutional aligners: reload each (SNR, seed)
# checkpoint into a single reusable module and record the resulting PSNR.
aligner_type = "twoconv"
aligner = _TwoConvAlignment(in_channels=2*c, hidden_channels=2*c, out_channels=2*c, kernel_size=5)

for snr in snrs:
    for seed in seeds:

        log_file = f"{logs_folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt"

        # Truncate the log so re-running the cell does not accumulate stale results.
        with open(log_file, 'w') as f:
            pass

        set_seed(seed)

        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, channel)

        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        # Label the result with the aligner actually evaluated (not "Linear model").
        result_msg = f"TwoConv model, AE {snr_ae} SNR {snr} SEED {seed} got a PSNR of {psnr_result:.2f}"
        print(result_msg)

        with open(log_file, 'a') as f:
            f.write(f"{result_msg}\n")

# Zero-shot

In [None]:
# Build the zero-shot aligner for every (SNR, seed) on flattened features and save it.
aligner_type = "zeroshot"
data.flat = True

# NOTE(review): `seeds` is inherited from the last cell that assigned it; in a
# top-to-bottom run that is the Two Conv cell ([47]), not the configured [42].
for snr in snrs:
    for seed in seeds:
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'

        # Re-seed before drawing the sample permutation so each seed is reproducible.
        set_seed(seed)
        permutation = torch.randperm(len(data))

        # n_samples is passed in two positional slots, as in the original call.
        aligner = train_zeroshot_aligner(
            data, permutation, n_samples, snr, n_samples, channel, device,
        )
        torch.save(aligner.state_dict(), aligner_fp)

        print(f"Done with SNR {snr} SEED {seed}.")

In [None]:
# Evaluate the saved zero-shot aligners: reload each (SNR, seed) checkpoint into a
# single reusable module and record the resulting PSNR.
aligner_type = "zeroshot"
# Zero-filled placeholders sized to match the checkpoints written by the training
# cell, so load_state_dict can copy the saved tensors in.
# NOTE(review): G is a (1, 1) placeholder — confirm the saved G has that shape or is
# not part of the state dict, otherwise load_state_dict will raise a size mismatch.
aligner = _ZeroShotAlignment(
    F_tilde=torch.zeros(n_samples, resolution**2),
    G_tilde=torch.zeros(resolution**2, n_samples),
    G=torch.zeros(1, 1),
    L=torch.zeros(n_samples, n_samples),
    mean=torch.zeros(n_samples, 1)
)

for snr in snrs:
    for seed in seeds:

        log_file = f"{logs_folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt"

        # Truncate the log so re-running the cell does not accumulate stale results.
        with open(log_file, 'w') as f:
            pass

        set_seed(seed)

        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.pth'
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, channel)

        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        # Label the result with the aligner actually evaluated (not "Linear model").
        result_msg = f"ZeroShot model, AE {snr_ae} SNR {snr} SEED {seed} got a PSNR of {psnr_result:.2f}"
        print(result_msg)

        with open(log_file, 'a') as f:
            f.write(f"{result_msg}\n")