In [None]:
%cd ..

import os
import torch
import copy
import numpy as np
from tqdm.notebook import tqdm


from alignment.alignment_utils import load_deep_jscc
from alignment.alignment_model import *
from alignment.alignment_model import _LinearAlignment, _MLPAlignment, _ConvolutionalAlignment, _ZeroShotAlignment, _TwoConvAlignment
from alignment.alignment_training import *
from alignment.alignment_validation import *

In [None]:
# --- Experiment configuration -------------------------------------------
snr = 30          # shared by train_snr and val_snr below
seed = 42
resolution = 96

# model1 supplies the encoder and model2 the decoder (see the loads below).
model1_fp = 'alignment/models/autoencoders/upscaled_42.pkl'
model2_fp = 'alignment/models/autoencoders/upscaled_43.pkl'
folder = 'high_res'
os.makedirs(f'alignment/models/plots/{folder}', exist_ok=True)

dataset = "cifar10"
channel = 'AWGN'
batch_size = 64
num_workers = 4

logs_folder = f'alignment/logs_{resolution}'
os.makedirs(logs_folder, exist_ok=True)

train_snr = snr
val_snr = snr
times = 10
c = 8             # NOTE(review): passed as `c` to load_deep_jscc / load_alignment_dataset; presumably the latent channel count — confirm

# Pilot counts log-spaced from 1 to 10 000. The int truncation can produce
# duplicates at the low end; np.unique removes them (and sorts the result).
n_points = 20
pilots_sets = np.unique(np.logspace(0, np.log10(10000), num=n_points, base=10).astype(int))
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

# Pass the configured `channel` rather than a hard-coded "AWGN", so that
# changing `channel` above actually affects which models are loaded.
encoder = copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, channel).encoder)
decoder = copy.deepcopy(load_deep_jscc(model2_fp, val_snr, c, channel).decoder)

train_loader, test_loader = get_data_loaders(dataset, resolution, batch_size, num_workers)
data = load_alignment_dataset(model1_fp, model2_fp, train_snr, train_loader, c, device)

In [None]:
# Train one convolutional aligner per channel SNR and save its weights.
aligner_type = "conv"
data.flat = False  # NOTE(review): presumably makes `data` yield non-flattened latents for conv aligners — confirm

# NOTE(review): third positional argument of train_conv_aligner; presumably the
# number of training samples/pilots — confirm against alignment_training.
n_train = 50000
aligner_snrs = [-20, -10, 0, 10, 20, 30]

set_seed(seed)
permutation = torch.randperm(len(data))

# Use `aligner_snr` rather than `snr` so the configuration cell's global `snr`
# is not overwritten as a side effect of running this cell.
for aligner_snr in aligner_snrs:
    # Reseed per SNR so each aligner (assuming set_seed seeds the RNGs the
    # trainer uses) is reproducible regardless of which SNRs ran before it.
    # The shared `permutation` above is unaffected.
    set_seed(seed)
    aligner, epoch = train_conv_aligner(data, permutation, n_train, c, batch_size, aligner_snr, device)

    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{aligner_snr}.pth'
    torch.save(aligner.state_dict(), aligner_fp)

In [None]:
# Train one two-layer convolutional aligner per channel SNR and save its weights.
aligner_type = "twoconv"
data.flat = False  # NOTE(review): presumably makes `data` yield non-flattened latents for conv aligners — confirm

# NOTE(review): third positional argument of train_twoconv_aligner; presumably
# the number of training samples/pilots — confirm against alignment_training.
n_train = 50000
aligner_snrs = [-20, -10, 0, 10, 20, 30]

set_seed(seed)
permutation = torch.randperm(len(data))

# Use `aligner_snr` rather than `snr` so the configuration cell's global `snr`
# is not overwritten as a side effect of running this cell.
for aligner_snr in aligner_snrs:
    # Reseed per SNR so each aligner (assuming set_seed seeds the RNGs the
    # trainer uses) is reproducible regardless of which SNRs ran before it.
    # The shared `permutation` above is unaffected.
    set_seed(seed)
    aligner, epoch = train_twoconv_aligner(data, permutation, n_train, c, batch_size, aligner_snr, device)

    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{aligner_snr}.pth'
    torch.save(aligner.state_dict(), aligner_fp)