In [1]:
%cd ..

import os
import torch
import copy
import numpy as np
from tqdm.notebook import tqdm


from alignment.alignment_utils import load_deep_jscc
from alignment.alignment_model import *
from alignment.alignment_model import _LinearAlignment, _MLPAlignment, _ConvolutionalAlignment, _ZeroShotAlignment, _TwoConvAlignment
from alignment.alignment_training import *
from alignment.alignment_validation import *

/home/lorenzo/repos/Deep-JSCC-PyTorch


In [2]:
# ---- experiment identity ----------------------------------------------------
snr = 30
seed = 42
resolution = 96
train_snr = snr   # SNR used while fitting the aligners
val_snr = snr     # SNR used while evaluating them

# ---- data / evaluation settings -----------------------------------------------
dataset = "cifar10"
channel = 'AWGN'
batch_size = 1024
num_workers = 4
times = 10        # forwarded to validation_vectorized
c = 8             # forwarded to load_deep_jscc; the conv aligners use 2*c channels

# ---- file-system layout -------------------------------------------------------
model1_fp = f'alignment/models/autoencoders/snr_{snr}_seed_42.pkl'
model2_fp = f'alignment/models/autoencoders/snr_{snr}_seed_43.pkl'
folder = 'psnr_vs_pilots'
logs_folder = f'alignment/logs_{resolution}'
for out_dir in (f'alignment/models/plots/{folder}', logs_folder):
    os.makedirs(out_dir, exist_ok=True)

# ---- pilot budgets: log-spaced between 1 and 10000, duplicates dropped ------
n_points = 20
pilots_sets = np.unique(
    np.logspace(0, np.log10(10000), num=n_points, base=10).astype(int)
)
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

# ---- mismatched pair: encoder from checkpoint 1, decoder from checkpoint 2 --
encoder = copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, channel).encoder)
decoder = copy.deepcopy(load_deep_jscc(model2_fp, val_snr, c, channel).decoder)

train_loader, test_loader = get_data_loaders(dataset, resolution, batch_size, num_workers)
data = load_alignment_dataset(model1_fp, model2_fp, train_snr, train_loader, c, device)

Caching inputs: 100%|██████████| 49/49 [00:02<00:00, 18.08it/s]


In [3]:
# Multi-seed sweep of the neural aligner: for every seed, train one aligner per
# pilot budget, then evaluate each checkpoint and log its PSNR.
# NOTE: the loop variable rebinds the notebook-level `seed`; once this cell has
# finished, `seed` holds the last value iterated and later cells read that value.
for seed in [42, 43, 44, 45, 46]:

    aligner_type = "neural"
    data.flat = False

    set_seed(seed)
    permutation = torch.randperm(len(data))

    # Stage 1: fit and checkpoint one aligner per pilot budget.
    for n_samples in tqdm(pilots_sets, desc="Training"):

        aligner, epoch = train_neural_aligner(data, permutation, n_samples, batch_size, resolution, 6, train_snr, device)

        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
        torch.save(aligner.state_dict(), aligner_fp)

    # Stage 2: reload every checkpoint into a fresh aligner and measure PSNR.
    aligner = _LinearAlignment(resolution**2)

    # Log under `logs_folder`, the directory the configuration cell creates and
    # every other cell of this notebook writes to, so open() cannot fail on a
    # missing directory.
    log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

    # Truncate any log left over from an earlier run of this seed.
    with open(log_file, 'w') as f:
        pass

    set_seed(seed)

    for n_samples in pilots_sets:

        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
        aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")

        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        result_msg = f"Neural model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
        print(result_msg)

        # Append per result so an interrupted sweep keeps what it already measured.
        with open(log_file, 'a') as f:
            f.write(f"{result_msg}\n")

Training:  68%|██████▊   | 13/19 [14:48<06:50, 68.37s/it] 


KeyboardInterrupt: 

# No mismatch - Unaligned - Zeroshot max

In [None]:
log_file = f"{logs_folder}/lines_snr_{snr}_seed_{seed}.txt"

# unaligned: mismatched encoder/decoder pair with no aligner in between
model = AlignedDeepJSCC(encoder, decoder, None, val_snr, "AWGN")
unaligned_psnr = validation_vectorized(model, test_loader, times, device)

result_msg = f"unaligned {unaligned_psnr:.2f}"
print(result_msg)
with open(log_file, 'a') as f:
    print(result_msg, file=f)

# aligned: decoder reloaded from the same checkpoint as the encoder (no mismatch)
matched_decoder = copy.deepcopy(load_deep_jscc(model1_fp, val_snr, c, "AWGN").decoder)
model = AlignedDeepJSCC(encoder, matched_decoder, None, val_snr, "AWGN")
aligned_psnr = validation_vectorized(model, test_loader, times, device)

result_msg = f"aligned {aligned_psnr:.2f}"
print(result_msg)
with open(log_file, 'a') as f:
    print(result_msg, file=f)

# zeroshot: aligner fitted by train_zeroshot_aligner on the flat view of the data
data.flat = True

set_seed(seed)
permutation = torch.randperm(len(data))

n_latent = resolution**2
aligner = train_zeroshot_aligner(data, permutation, n_latent, train_snr, n_latent, device)
aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")
zeroshot_psnr = validation_vectorized(aligned_model, test_loader, times, device)

result_msg = f"zeroshot {zeroshot_psnr:.2f}"
print(result_msg)
with open(log_file, 'a') as f:
    print(result_msg, file=f)

unaligned 11.72
aligned 43.71
zeroshot 27.91


# Least Squares

In [None]:
aligner_type = "linear"
data.flat = True  # presumably makes the dataset yield flattened latents — TODO confirm

set_seed(seed)
permutation = torch.randperm(len(data))

# One least-squares aligner per pilot budget, each checkpointed under its budget.
for n_samples in tqdm(pilots_sets, desc="Training"):
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'

    aligner = train_linear_aligner(data, permutation, n_samples, train_snr)
    torch.save(aligner.state_dict(), aligner_fp)

Training: 100%|██████████| 1/1 [00:01<00:00,  1.77s/it]


In [None]:
aligner_type = "linear"
aligner = _LinearAlignment(resolution**2)
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Start from an empty log file.
open(log_file, 'w').close()

set_seed(seed)

# Evaluate the checkpoint of every pilot budget with the mismatched pair.
for n_samples in pilots_sets:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    state = torch.load(aligner_fp, map_location=device)
    aligner.load_state_dict(state)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)
    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Linear model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    # Append immediately so partial sweeps are still recorded.
    with open(log_file, 'a') as f:
        print(result_msg, file=f)

Linear model, 10000 samples got a PSNR of 42.61


# Linear Neural

In [None]:
def train_neural_aligner(data, permutation, n_samples, batch_size, resolution, ratio, train_snr, device):
    """
    Train a `_LinearAlignment` aligner with Adam, a one-cycle learning-rate
    schedule and (on CUDA) mixed precision, with early stopping.

    Args:
        data: alignment dataset; it is wrapped in `AlignmentSubset`.
        permutation: index permutation over `data`; its first `n_samples`
            entries are the pilots used here.
        n_samples: number of pilots. Below 10 there is no validation split and
            early stopping monitors the training loss; otherwise the last
            `max(1, int(0.1 * n_samples))` pilots are held out for validation.
        batch_size: batch size of the train and validation loaders.
        resolution, ratio: the aligner size is
            `resolution * resolution * 3 * 2 // ratio`.
        train_snr: SNR handed to the AWGN `Channel`; when None the inputs are
            not passed through the channel.
        device: device on which the aligner is trained.

    Returns:
        (aligner, epoch): the aligner restored to its best monitored state and
        moved to CPU, and the number of epochs that were run.
    """

    # train settings
    epochs_max = 10000
    patience = 20
    min_delta = 1e-5
    base_lr = 1e-3     # peak learning rate of the one-cycle schedule
    final_lr = 1e-4    # enters the schedule only through final_div_factor (see NOTE)
    warmup_frac = 0.05

    # Mixed precision is only enabled on CUDA; elsewhere autocast and the
    # scaler are disabled and the very same loop runs in full precision.
    use_amp = torch.device(device).type == 'cuda'

    # prepare data with train/validation split
    indices = permutation[:n_samples]

    if n_samples < 10:
        # too few pilots to hold any out
        use_val = False
        train_indices = indices
    else:
        use_val = True
        val_size = max(1, int(0.1 * n_samples))
        train_size = n_samples - val_size
        train_indices = indices[:train_size]
        val_indices = indices[train_size:]

    # create datasets and dataloaders
    train_subset = AlignmentSubset(data, train_indices)
    train_dataloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)

    if use_val:
        val_subset = AlignmentSubset(data, val_indices)
        val_dataloader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    # prepare model and optimizer
    aligner = _LinearAlignment(size=resolution * resolution * 3 * 2 // ratio).to(device)
    channel = Channel("AWGN", train_snr)
    criterion = nn.MSELoss(reduction='mean')
    optimizer = optim.Adam(aligner.parameters(), lr=base_lr)

    # One-cycle schedule sized for exactly epochs_max epochs.
    # NOTE(review): OneCycleLR starts at max_lr / div_factor (div_factor defaults
    # to 25) and ends at that initial LR divided by final_div_factor, so with
    # these settings the schedule ends well below `final_lr` — confirm intent.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=base_lr,
        steps_per_epoch=len(train_dataloader),
        epochs=epochs_max,
        final_div_factor=base_lr / final_lr,
        pct_start=warmup_frac,
        anneal_strategy='cos'
    )

    # Device-generic AMP API: here the first argument really is the device type.
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    # init train state
    best_loss = float('inf')
    best_model_state = None
    checks_without_improvement = 0
    epoch = 0

    while True:
        aligner.train()
        train_loss = 0.0

        for inputs, targets in train_dataloader:
            if train_snr is not None:
                inputs = channel(inputs)

            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            with torch.amp.autocast('cuda', enabled=use_amp):
                outputs = aligner(inputs)
                loss = criterion(outputs, targets)
                # weight the batch mean by the batch size
                loss_scaled = loss * inputs.shape[0]

            scaler.scale(loss_scaled).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            train_loss += loss_scaled.item()

        # validation phase
        if use_val:
            aligner.eval()
            val_loss = 0.0
            with torch.no_grad():
                for inputs, targets in val_dataloader:
                    if train_snr is not None:
                        inputs = channel(inputs)
                    inputs, targets = inputs.to(device), targets.to(device)
                    with torch.amp.autocast('cuda', enabled=use_amp):
                        outputs = aligner(inputs)
                        loss = criterion(outputs, targets)
                        loss_scaled = loss * inputs.shape[0]
                        val_loss += loss_scaled.item()

            # Normalised by the number of batches, not samples; the value is
            # only compared against itself across epochs.
            current_loss = val_loss / len(val_dataloader)
        else:
            current_loss = train_loss / len(train_dataloader)

        epoch += 1

        if best_loss - current_loss > min_delta:
            best_loss = current_loss
            best_model_state = copy.deepcopy(aligner.state_dict())
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1

        # Stop after at most epochs_max epochs: one more epoch would step the
        # scheduler past the total number of steps it was built for.
        if checks_without_improvement >= patience or epoch >= epochs_max:
            break

    # restore best model
    if best_model_state is not None:
        aligner.load_state_dict(best_model_state)

    return aligner.cpu(), epoch

In [None]:
aligner_type = "neural"
data.flat = False

set_seed(seed)
permutation = torch.randperm(len(data))

# Train and checkpoint a single aligner at the largest pilot budget.
for n_samples in tqdm([10000], desc="Training"):
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'

    aligner, epoch = train_neural_aligner(data, permutation, n_samples, batch_size, resolution, 6, train_snr, device)
    torch.save(aligner.state_dict(), aligner_fp)

# Reload that checkpoint into a fresh aligner and evaluate it once.
aligner = _LinearAlignment(resolution**2)

set_seed(seed)

n_samples = 10000
aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
state = torch.load(aligner_fp, map_location=device)
aligner.load_state_dict(state)

aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")
psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

result_msg = f"Neural model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
print(result_msg)

  scaler = amp.GradScaler()
  with amp.autocast():
  with amp.autocast():


In [None]:
aligner_type = "neural"
data.flat = False

set_seed(seed)
permutation = torch.randperm(len(data))

# One Adam-trained aligner per pilot budget, each checkpointed under its budget.
for n_samples in tqdm(pilots_sets, desc="Training"):
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'

    aligner, epoch = train_neural_aligner(data, permutation, n_samples, batch_size, resolution, 6, train_snr, device)
    torch.save(aligner.state_dict(), aligner_fp)

Training: 100%|██████████| 19/19 [07:41<00:00, 24.29s/it]


In [None]:
aligner_type = "neural"
aligner = _LinearAlignment(resolution**2)
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Start from an empty log file.
open(log_file, 'w').close()

set_seed(seed)

# Evaluate the checkpoint of every pilot budget with the mismatched pair.
for n_samples in pilots_sets:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    state = torch.load(aligner_fp, map_location=device)
    aligner.load_state_dict(state)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")
    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Neural model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    # Append immediately so partial sweeps are still recorded.
    with open(log_file, 'a') as f:
        print(result_msg, file=f)

Neural model, 1 samples got a PSNR of 12.60


Neural model, 2 samples got a PSNR of 12.78
Neural model, 4 samples got a PSNR of 12.06
Neural model, 6 samples got a PSNR of 12.69
Neural model, 11 samples got a PSNR of 12.92
Neural model, 18 samples got a PSNR of 13.27
Neural model, 29 samples got a PSNR of 13.52
Neural model, 48 samples got a PSNR of 13.96
Neural model, 78 samples got a PSNR of 14.55
Neural model, 127 samples got a PSNR of 15.00
Neural model, 206 samples got a PSNR of 15.41
Neural model, 335 samples got a PSNR of 15.98
Neural model, 545 samples got a PSNR of 16.68
Neural model, 885 samples got a PSNR of 18.00
Neural model, 1438 samples got a PSNR of 18.58
Neural model, 2335 samples got a PSNR of 20.09
Neural model, 3792 samples got a PSNR of 21.40
Neural model, 6158 samples got a PSNR of 23.84
Neural model, 10000 samples got a PSNR of 25.22


# MLP

In [11]:
def train_mlp_aligner(data, permutation, n_samples, batch_size, resolution, ratio, train_snr, device, lr=1e-3, reg=0.0):
    """
    Train an `_MLPAlignment` aligner with Adam and ReduceLROnPlateau, with
    early stopping on a held-out split when one exists.

    Args:
        data: alignment dataset; it is wrapped in `AlignmentSubset`.
        permutation: index permutation over `data`; its first `n_samples`
            entries are the pilots used here.
        n_samples: number of pilots. Below 10 there is no validation split and
            early stopping monitors the training loss; otherwise the last
            `max(1, int(0.1 * n_samples))` pilots are held out for validation.
        batch_size: batch size of the train and validation loaders.
        resolution, ratio: input size and the single hidden width are both
            `resolution * resolution * 3 * 2 // ratio`.
        train_snr: SNR handed to the AWGN `Channel`; when None the inputs are
            not passed through the channel.
        device: device on which the aligner is trained.
        lr: initial Adam learning rate (default 1e-3).
        reg: L2 regularisation strength, applied as Adam `weight_decay`
            (default 0.0, i.e. no regularisation).

    Returns:
        (aligner, epoch): the aligner restored to its best monitored state and
        moved to CPU, and the number of epochs that were run.
    """

    # train settings
    epochs_max = 1000
    patience = 20
    min_delta = 1e-5

    # prepare data with train/validation split
    indices = permutation[:n_samples]

    if n_samples < 10:
        # too few pilots to hold any out: train on everything
        use_val = False
        train_indices = indices
    else:
        use_val = True

        # split into 90 train - 10 validation
        val_size = max(1, int(0.1 * n_samples))
        train_size = n_samples - val_size

        train_indices = indices[:train_size]
        val_indices = indices[train_size:]

    # create datasets and dataloaders
    train_subset = AlignmentSubset(data, train_indices)
    train_dataloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)

    if use_val:
        val_subset = AlignmentSubset(data, val_indices)
        val_dataloader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    # prepare model and optimizer
    size = resolution * resolution * 3 * 2 // ratio
    aligner = _MLPAlignment(size, [size]).to(device)
    channel = Channel("AWGN", train_snr)
    criterion = nn.MSELoss(reduction='mean')
    optimizer = optim.Adam(aligner.parameters(), lr=lr, weight_decay=reg)

    # halve the learning rate when the monitored loss plateaus
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, min_lr=1e-6)

    # init train state
    best_loss = float('inf')
    best_model_state = None
    checks_without_improvement = 0
    epoch = 0

    # train loop
    while True:
        # training phase
        aligner.train()
        train_loss = 0.0

        for inputs, targets in train_dataloader:
            if train_snr is not None:
                inputs = channel(inputs)

            optimizer.zero_grad()
            outputs = aligner(inputs.to(device))
            loss = criterion(outputs, targets.to(device))
            # weight the batch mean by the batch size
            loss = loss * inputs.shape[0]
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # validation phase
        if use_val:
            aligner.eval()
            val_loss = 0.0

            with torch.no_grad():
                for inputs, targets in val_dataloader:
                    if train_snr is not None:
                        inputs = channel(inputs)

                    outputs = aligner(inputs.to(device))
                    loss = criterion(outputs, targets.to(device))
                    loss = loss * inputs.shape[0]
                    val_loss += loss.item()

            # Normalised by the number of batches, not samples; the value is
            # only compared against itself across epochs.
            current_loss = val_loss / len(val_dataloader)
        else:
            # use training loss if no validation set
            current_loss = train_loss / len(train_dataloader)

        epoch += 1

        # step the scheduler with current loss
        scheduler.step(current_loss)

        # check if improvement
        if best_loss - current_loss > min_delta:
            best_loss = current_loss
            best_model_state = copy.deepcopy(aligner.state_dict())
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1

        # stop on exhausted patience, or once exactly epochs_max epochs have run
        if checks_without_improvement >= patience or epoch >= epochs_max:
            break

    # restore best model
    if best_model_state is not None:
        aligner.load_state_dict(best_model_state)

    return aligner.cpu(), epoch

In [12]:
aligner_type = "mlp"
data.flat = False

set_seed(seed)
permutation = torch.randperm(len(data))

# Train and checkpoint a single MLP aligner at the largest pilot budget.
for n_samples in tqdm([10000], desc="Training"):
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'

    aligner, epoch = train_mlp_aligner(data, permutation, n_samples, batch_size, resolution, 6, train_snr, device)
    torch.save(aligner.state_dict(), aligner_fp)

# Reload that checkpoint into a fresh aligner and evaluate it once.
aligner = _MLPAlignment(input_dim=resolution**2, hidden_dims=[resolution**2])

set_seed(seed)

n_samples = 10000
aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
state = torch.load(aligner_fp, map_location=device)
aligner.load_state_dict(state)

aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")
psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

result_msg = f"MLP model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
print(result_msg)

Training: 100%|██████████| 1/1 [19:33<00:00, 1173.36s/it]


MLP model, 10000 samples got a PSNR of 17.56


In [None]:
aligner_type = "mlp"
data.flat = False

set_seed(seed)
permutation = torch.randperm(len(data))

# Grid over regularisation strength (outer) and learning rate (inner).
# The call below passes `lr` and `reg` as the 9th and 10th positional
# arguments, so train_mlp_aligner has to accept them.
for reg in [0.0001, 0.001, 0.01]:
    for lr in [1e-5, 1e-4, 1e-3]:

        # Train at the largest pilot budget; the checkpoint path does not
        # depend on (lr, reg), so each configuration overwrites the previous.
        for n_samples in tqdm([10000], desc="Training"):
            aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'

            aligner, epoch = train_mlp_aligner(data, permutation, n_samples, batch_size, resolution, 6, train_snr, device, lr, reg)
            torch.save(aligner.state_dict(), aligner_fp)

        # Evaluate the configuration that was just trained.
        aligner = _MLPAlignment(input_dim=resolution**2, hidden_dims=[resolution**2])

        set_seed(seed)

        n_samples = 10000
        aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
        state = torch.load(aligner_fp, map_location=device)
        aligner.load_state_dict(state)

        aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")
        psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

        result_msg = f"LR: {lr} REG: {reg}, MLP model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
        print(result_msg)

In [None]:
aligner_type = "mlp"
data.flat = False

set_seed(seed)
permutation = torch.randperm(len(data))

# One MLP aligner per pilot budget, each checkpointed under its budget.
for n_samples in tqdm(pilots_sets, desc="Training"):
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'

    aligner, epoch = train_mlp_aligner(data, permutation, n_samples, batch_size, resolution, 6, train_snr, device)
    torch.save(aligner.state_dict(), aligner_fp)

Training: 100%|██████████| 1/1 [02:20<00:00, 140.62s/it]


In [None]:
aligner_type = "mlp"
aligner = _MLPAlignment(input_dim=resolution**2, hidden_dims=[resolution**2])
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Start from an empty log file.
open(log_file, 'w').close()

set_seed(seed)

# Evaluate the checkpoint of every pilot budget with the mismatched pair.
for n_samples in pilots_sets:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    state = torch.load(aligner_fp, map_location=device)
    aligner.load_state_dict(state)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")
    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"MLP model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    # Append immediately so partial sweeps are still recorded.
    with open(log_file, 'a') as f:
        print(result_msg, file=f)

MLP model, 10000 samples got a PSNR of 27.07


# Convolutional

In [6]:
def train_conv_aligner(data, permutation, n_samples, c, batch_size, train_snr, device):
    """
    Train a `_ConvolutionalAlignment` aligner with Adam, with early stopping
    on a held-out split when one exists.

    Args:
        data: alignment dataset; it is wrapped in `AlignmentSubset`.
        permutation: index permutation over `data`; its first `n_samples`
            entries are the pilots used here.
        n_samples: number of pilots. Below 10 there is no validation split and
            early stopping monitors the training loss; otherwise the last
            `max(1, int(0.1 * n_samples))` pilots are held out for validation.
        c: the aligner maps `2*c` input channels to `2*c` output channels.
        batch_size: batch size of the train and validation loaders.
        train_snr: SNR handed to the AWGN `Channel`; when None the inputs are
            not passed through the channel.
        device: device on which the aligner is trained.

    Returns:
        (aligner, epoch): the aligner restored to its best monitored state and
        moved to CPU, and the number of epochs that were run.
    """

    # train settings
    epochs_max = 10000
    patience = 10
    min_delta = 1e-5

    # prepare data with train/validation split
    indices = permutation[:n_samples]

    if n_samples < 10:
        # too few pilots to hold any out: train on everything
        use_val = False
        train_indices = indices
    else:
        use_val = True

        # split into 90 train - 10 validation
        val_size = max(1, int(0.1 * n_samples))
        train_size = n_samples - val_size

        train_indices = indices[:train_size]
        val_indices = indices[train_size:]

    # create datasets and dataloaders
    train_subset = AlignmentSubset(data, train_indices)
    train_dataloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)

    if use_val:
        val_subset = AlignmentSubset(data, val_indices)
        val_dataloader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    # prepare model and optimizer
    aligner = _ConvolutionalAlignment(in_channels=2*c, out_channels=2*c, kernel_size=5).to(device)
    channel = Channel("AWGN", train_snr)
    criterion = nn.MSELoss(reduction='mean')
    optimizer = optim.Adam(aligner.parameters(), lr=1e-4)

    # init train state
    best_loss = float('inf')
    best_model_state = None
    checks_without_improvement = 0
    epoch = 0

    # train loop
    while True:
        # training phase
        aligner.train()
        train_loss = 0.0

        for inputs, targets in train_dataloader:
            if train_snr is not None:
                inputs = channel(inputs)

            optimizer.zero_grad()
            outputs = aligner(inputs.to(device))
            loss = criterion(outputs, targets.to(device))
            # weight the batch mean by the batch size
            loss = loss * inputs.shape[0]
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # validation phase
        if use_val:
            aligner.eval()
            val_loss = 0.0

            with torch.no_grad():
                for inputs, targets in val_dataloader:
                    if train_snr is not None:
                        inputs = channel(inputs)

                    outputs = aligner(inputs.to(device))
                    loss = criterion(outputs, targets.to(device))
                    loss = loss * inputs.shape[0]
                    val_loss += loss.item()

            # Normalised by the number of batches, not samples; the value is
            # only compared against itself across epochs.
            current_loss = val_loss / len(val_dataloader)
        else:
            # use training loss if no validation set
            current_loss = train_loss / len(train_dataloader)

        epoch += 1

        # check if improvement
        if best_loss - current_loss > min_delta:
            best_loss = current_loss
            best_model_state = copy.deepcopy(aligner.state_dict())
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1

        # stop on exhausted patience, or once exactly epochs_max epochs have run
        if checks_without_improvement >= patience or epoch >= epochs_max:
            break

    # restore best model
    if best_model_state is not None:
        aligner.load_state_dict(best_model_state)

    return aligner.cpu(), epoch

In [7]:
aligner_type = "conv"
data.flat = False

set_seed(seed)
permutation = torch.randperm(len(data))

# Train and checkpoint a single convolutional aligner at the largest pilot budget.
for n_samples in tqdm([10000], desc="Training"):
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'

    aligner, epoch = train_conv_aligner(data, permutation, n_samples, c, batch_size, train_snr, device)
    torch.save(aligner.state_dict(), aligner_fp)

# Reload that checkpoint into a fresh aligner and evaluate it once.
aligner = _ConvolutionalAlignment(in_channels=2*c, out_channels=2*c, kernel_size=5)

set_seed(seed)

n_samples = 10000
aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
state = torch.load(aligner_fp, map_location=device)
aligner.load_state_dict(state)

aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")
psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

result_msg = f"Conv model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
print(result_msg)

Training: 100%|██████████| 1/1 [04:32<00:00, 272.33s/it]


Conv model, 10000 samples got a PSNR of 41.66


In [None]:
aligner_type = "conv"
data.flat = False

set_seed(seed)
permutation = torch.randperm(len(data))

# One convolutional aligner per pilot budget, each checkpointed under its budget.
for n_samples in tqdm(pilots_sets, desc="Training"):
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'

    aligner, epoch = train_conv_aligner(data, permutation, n_samples, c, batch_size, train_snr, device)
    torch.save(aligner.state_dict(), aligner_fp)

Training: 100%|██████████| 19/19 [03:37<00:00, 11.45s/it]


In [None]:
aligner_type = "conv"
aligner = _ConvolutionalAlignment(in_channels=2*c, out_channels=2*c, kernel_size=5)
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# Start from an empty log file.
open(log_file, 'w').close()

set_seed(seed)

# Evaluate the checkpoint of every pilot budget with the mismatched pair.
for n_samples in pilots_sets:
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    state = torch.load(aligner_fp, map_location=device)
    aligner.load_state_dict(state)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")
    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Conv model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    # Append immediately so partial sweeps are still recorded.
    with open(log_file, 'a') as f:
        print(result_msg, file=f)

Conv model, 1 samples got a PSNR of 24.77
Conv model, 2 samples got a PSNR of 30.57
Conv model, 4 samples got a PSNR of 32.55
Conv model, 6 samples got a PSNR of 33.52
Conv model, 11 samples got a PSNR of 32.98
Conv model, 18 samples got a PSNR of 34.33
Conv model, 29 samples got a PSNR of 34.56
Conv model, 48 samples got a PSNR of 34.82
Conv model, 78 samples got a PSNR of 34.81
Conv model, 127 samples got a PSNR of 35.05
Conv model, 206 samples got a PSNR of 35.31
Conv model, 335 samples got a PSNR of 35.41
Conv model, 545 samples got a PSNR of 35.34
Conv model, 885 samples got a PSNR of 35.34
Conv model, 1438 samples got a PSNR of 35.35
Conv model, 2335 samples got a PSNR of 35.33
Conv model, 3792 samples got a PSNR of 35.40
Conv model, 6158 samples got a PSNR of 35.26
Conv model, 10000 samples got a PSNR of 35.22


# Two Conv

In [3]:
def train_twoconv_aligner(data, permutation, n_samples, c, batch_size, train_snr, device,
                          epochs_max=10000, patience=20, min_delta=1e-5, lr=1e-4):
    """
    Train a two-layer convolutional aligner (`_TwoConvAlignment`) with Adam and early stopping.

    The first `n_samples` entries of `permutation` select the pilot samples from `data`.
    With at least 10 pilots, the last 10% (at least one sample) are held out and the
    validation loss drives early stopping; with fewer pilots all of them are used for
    training and the training loss drives early stopping.

    Args:
        data: alignment dataset, indexed through `AlignmentSubset`.
        permutation: sequence of dataset indices; its first `n_samples` entries are used.
        n_samples: number of pilot samples (must be >= 1).
        c: the aligner uses `2*c` input, hidden and output channels.
        batch_size: batch size for both the training and the validation loader.
        train_snr: SNR passed to the AWGN `Channel`; when None the inputs are not passed
            through the channel.
        device: torch device on which the aligner is trained.
        epochs_max: maximum number of training epochs.
        patience: number of consecutive epochs without improvement before stopping.
        min_delta: minimum decrease of the tracked loss that counts as an improvement.
        lr: Adam learning rate.

    Returns:
        (aligner, epoch): the aligner moved to CPU with the best tracked weights restored,
        and the number of epochs that were run.

    Raises:
        ValueError: if `n_samples` is smaller than 1 or `permutation` yields no indices.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")

    # prepare data with train/validation split
    indices = permutation[:n_samples]

    # size the split on what was actually selected, so a permutation shorter than
    # n_samples cannot produce an empty validation set
    n_available = len(indices)
    if n_available == 0:
        raise ValueError("permutation does not contain any index")

    # handle small datasets (< 10 samples): train on everything, no validation split
    use_val = n_available >= 10

    if use_val:
        # split into 90 train - 10 validation
        val_size = max(1, int(0.1 * n_available))
        train_size = n_available - val_size

        train_indices = indices[:train_size]
        val_indices = indices[train_size:]
    else:
        train_indices = indices
        val_indices = []

    # create datasets and dataloaders
    train_subset = AlignmentSubset(data, train_indices)
    train_dataloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)

    if use_val:
        val_subset = AlignmentSubset(data, val_indices)
        val_dataloader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    # prepare model and optimizer
    aligner = _TwoConvAlignment(in_channels=2*c, hidden_channels=2*c, out_channels=2*c, kernel_size=5).to(device)
    channel = Channel("AWGN", train_snr)
    criterion = nn.MSELoss(reduction='mean')
    optimizer = optim.Adam(aligner.parameters(), lr=lr)

    # init train state
    best_loss = float('inf')
    best_model_state = None
    checks_without_improvement = 0
    epoch = 0

    # train loop
    while True:
        # training phase
        train_loss = _run_twoconv_epoch(aligner, train_dataloader, channel, criterion,
                                        train_snr, device, optimizer=optimizer)

        # NOTE(review): batch losses are weighted by batch size but the epoch total is divided
        # by the number of batches, not samples. This only rescales the tracked loss (and so
        # the effective `min_delta`); kept as-is so existing results stay comparable.
        if use_val:
            # validation phase: use validation loss for early stopping
            val_loss = _run_twoconv_epoch(aligner, val_dataloader, channel, criterion,
                                          train_snr, device)
            current_loss = val_loss / len(val_dataloader)
        else:
            # use training loss if no validation set
            current_loss = train_loss / len(train_dataloader)

        epoch += 1

        # check if improvement
        if best_loss - current_loss > min_delta:
            best_loss = current_loss
            best_model_state = copy.deepcopy(aligner.state_dict())
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1

        # break if patience exceeded
        if checks_without_improvement >= patience:
            break

        # break once the epoch budget is used up (">" would allow one extra epoch)
        if epoch >= epochs_max:
            break

    # restore best model
    if best_model_state is not None:
        aligner.load_state_dict(best_model_state)

    return aligner.cpu(), epoch


def _run_twoconv_epoch(aligner, dataloader, channel, criterion, train_snr, device, optimizer=None):
    """Run one pass over `dataloader` and return the sum of the batch-size-weighted batch losses.

    With an `optimizer` the aligner is put in train mode and updated on every batch;
    without one it is put in eval mode and evaluated with gradients disabled.
    """
    is_training = optimizer is not None
    aligner.train(is_training)
    total_loss = 0.0

    with torch.set_grad_enabled(is_training):
        for inputs, targets in dataloader:
            # the channel is applied to the inputs only, before they are moved to `device`
            if train_snr is not None:
                inputs = channel(inputs)

            if is_training:
                optimizer.zero_grad()

            outputs = aligner(inputs.to(device))
            # weight the mean batch loss by the batch size
            loss = criterion(outputs, targets.to(device)) * inputs.shape[0]

            if is_training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()

    return total_loss

In [4]:
# Single-point check: train the two-conv aligner on 10000 pilots, then evaluate it right away.
aligner_type = "twoconv"
data.flat = False  # presumably makes the dataset return unflattened feature maps - confirm

# seed before drawing the pilot permutation
set_seed(seed)
permutation = torch.randperm(len(data))

# --- training ---
for n_samples in tqdm([10000], desc="Training"):

    aligner, epoch = train_twoconv_aligner(
        data, permutation, n_samples, c, batch_size, train_snr, device
    )

    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    torch.save(aligner.state_dict(), aligner_fp)

# --- evaluation ---
# fresh aligner instance; its weights are overwritten by the checkpoint loaded below
aligner = _TwoConvAlignment(
    in_channels=2*c,
    hidden_channels=2*c,
    out_channels=2*c,
    kernel_size=5,
)

set_seed(seed)

for n_samples in [10000]:

    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")
    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Twoconv model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

Training:   0%|          | 0/1 [00:00<?, ?it/s]

Training: 100%|██████████| 1/1 [21:32<00:00, 1292.57s/it]


Twoconv model, 10000 samples got a PSNR of 44.29


In [None]:
# Train one two-conv aligner per pilot count in `pilots_sets` and save each checkpoint.
aligner_type = "twoconv"
data.flat = False  # presumably makes the dataset return unflattened feature maps - confirm

# seed before drawing the pilot permutation shared by every pilot count
set_seed(seed)
permutation = torch.randperm(len(data))

for n_samples in tqdm(pilots_sets, desc="Training"):

    # `epoch` is the number of epochs the aligner was trained for
    aligner, epoch = train_twoconv_aligner(
        data, permutation, n_samples, c, batch_size, train_snr, device
    )

    # one checkpoint per pilot count, keyed by aligner type and sample count
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    torch.save(aligner.state_dict(), aligner_fp)

Training: 100%|██████████| 19/19 [16:56<00:00, 53.52s/it] 


In [None]:
# Evaluate every saved two-conv aligner on the test set and log one PSNR line per pilot count.
aligner_type = "twoconv"
aligner = _TwoConvAlignment(
    in_channels=2*c,
    hidden_channels=2*c,
    out_channels=2*c,
    kernel_size=5,
)
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# opening in 'w' mode truncates any log left over from a previous run
with open(log_file, 'w') as f:
    pass

set_seed(seed)

for n_samples in pilots_sets:

    # load the checkpoint saved for this pilot count into the reusable aligner instance
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    # encoder -> aligner -> decoder pipeline evaluated at the validation SNR
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")
    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Twoconv model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    # append immediately so partial results survive an interrupted run
    with open(log_file, 'a') as f:
        print(result_msg, file=f)

Twoconv model, 1 samples got a PSNR of 16.13
Twoconv model, 2 samples got a PSNR of 24.65
Twoconv model, 4 samples got a PSNR of 30.05
Twoconv model, 6 samples got a PSNR of 31.60
Twoconv model, 11 samples got a PSNR of 32.94
Twoconv model, 18 samples got a PSNR of 33.66
Twoconv model, 29 samples got a PSNR of 35.74
Twoconv model, 48 samples got a PSNR of 36.88
Twoconv model, 78 samples got a PSNR of 37.24
Twoconv model, 127 samples got a PSNR of 37.55
Twoconv model, 206 samples got a PSNR of 37.68
Twoconv model, 335 samples got a PSNR of 37.92
Twoconv model, 545 samples got a PSNR of 37.86
Twoconv model, 885 samples got a PSNR of 38.43
Twoconv model, 1438 samples got a PSNR of 38.45
Twoconv model, 2335 samples got a PSNR of 38.72
Twoconv model, 3792 samples got a PSNR of 38.53
Twoconv model, 6158 samples got a PSNR of 38.11
Twoconv model, 10000 samples got a PSNR of 38.74


# Zero-shot

In [None]:
# Fit one zero-shot aligner per pilot count and save each checkpoint.
aligner_type = "zeroshot"
data.flat = True  # presumably makes the dataset return flattened features - confirm

# seed before drawing the pilot permutation shared by every pilot count
set_seed(seed)
permutation = torch.randperm(len(data))

# the first (smallest) pilot count is skipped for this aligner
for n_samples in tqdm(pilots_sets[1:], desc="Training"):

    # NOTE(review): n_samples is passed twice; the second occurrence is presumably a
    # size/batch argument of train_zeroshot_aligner - confirm against its signature
    aligner = train_zeroshot_aligner(
        data, permutation, n_samples, train_snr, n_samples, device
    )

    # one checkpoint per pilot count, keyed by aligner type and sample count
    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    torch.save(aligner.state_dict(), aligner_fp)

Training: 100%|██████████| 18/18 [01:31<00:00,  5.10s/it]


In [None]:
# Evaluate every saved zero-shot aligner on the test set and log one PSNR line per pilot count.
aligner_type = "zeroshot"
log_file = f"{logs_folder}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt"

# opening in 'w' mode truncates any log left over from a previous run
with open(log_file, 'w') as f:
    pass

set_seed(seed)

# the first (smallest) pilot count is skipped, matching the checkpoints that were saved
for n_samples in pilots_sets[1:]:

    # zero-filled placeholder tensors sized for this pilot count;
    # the checkpoint loaded below replaces their values
    aligner = _ZeroShotAlignment(
        F_tilde=torch.zeros(n_samples, resolution**2),
        G_tilde=torch.zeros(resolution**2, n_samples),
        G=torch.zeros(1, 1),
        L=torch.zeros(n_samples, n_samples),
        mean=torch.zeros(n_samples, 1),
    )

    aligner_fp = f'alignment/models/plots/{folder}/aligner_{aligner_type}_{n_samples}.pth'
    aligner.load_state_dict(torch.load(aligner_fp, map_location=device))

    # encoder -> aligner -> decoder pipeline evaluated at the validation SNR
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")
    psnr_result = validation_vectorized(aligned_model, test_loader, times, device)

    result_msg = f"Zeroshot model, {n_samples} samples got a PSNR of {psnr_result:.2f}"
    print(result_msg)

    # append immediately so partial results survive an interrupted run
    with open(log_file, 'a') as f:
        print(result_msg, file=f)

Zeroshot model, 2 samples got a PSNR of 11.28
Zeroshot model, 4 samples got a PSNR of 11.55
Zeroshot model, 6 samples got a PSNR of 11.77
Zeroshot model, 11 samples got a PSNR of 11.90
Zeroshot model, 18 samples got a PSNR of 12.72
Zeroshot model, 29 samples got a PSNR of 13.34
Zeroshot model, 48 samples got a PSNR of 13.29
Zeroshot model, 78 samples got a PSNR of 14.30
Zeroshot model, 127 samples got a PSNR of 15.11
Zeroshot model, 206 samples got a PSNR of 15.08
Zeroshot model, 335 samples got a PSNR of 16.24
Zeroshot model, 545 samples got a PSNR of 15.94
Zeroshot model, 885 samples got a PSNR of 16.95
Zeroshot model, 1438 samples got a PSNR of 19.43
Zeroshot model, 2335 samples got a PSNR of 18.31
Zeroshot model, 3792 samples got a PSNR of 22.53
Zeroshot model, 6158 samples got a PSNR of 26.78
Zeroshot model, 10000 samples got a PSNR of 28.33
