In [1]:
%cd ..

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import copy
import numpy as np
import random

from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from torchvision import datasets

from alignment.alignment_utils import load_deep_jscc, get_batch_psnr
from alignment.alignment_model import _ConvolutionalAlignment, _LinearAlignment, _ZeroShotAlignment, AlignedDeepJSCC
from utils import image_normalization
import matplotlib.pyplot as plt

from dataset import Vanilla
from model import DeepJSCC
from tqdm import tqdm
import pickle
from channel import Channel
from PIL import Image

from concurrent.futures import ThreadPoolExecutor
from functools import partial
from alignment.linear_models_gpu import Baseline

/home/lorenzo/repos/Deep-JSCC-PyTorch


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model1_fp = r'alignment/models/autoencoders/upscaled_42.pkl'
model2_fp = r'alignment/models/autoencoders/upscaled_43.pkl'

snr = 30
times = 10
c = 8

dataset = "cifar10"
resolution = 96
folder = "psnr_vs_pilots_2"

batch_size = 64
num_workers = 4
channel = 'AWGN'

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

samples = np.unique(np.logspace(0, np.log10(10000), num=100, base=10).astype(int))
seed = 42

os.makedirs(r'alignment/models/plots/'+folder, exist_ok=True)

In [3]:
def set_seed(seed):
    random.seed(seed)                      # Python RNG
    np.random.seed(seed)                   # NumPy RNG
    torch.manual_seed(seed)                # PyTorch CPU RNG
    torch.cuda.manual_seed(seed)           # PyTorch GPU RNG
    torch.cuda.manual_seed_all(seed)       # All GPUs
    torch.backends.cudnn.deterministic = True   # Deterministic cuDNN
    torch.backends.cudnn.benchmark = False      # Disable benchmark to ensure reproducibility

# Init

## Data

In [4]:
############
# GET DATA #
############

if dataset == 'cifar10':
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((resolution, resolution))])

    train_dataset = datasets.CIFAR10(root='../dataset/', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)

    test_dataset = datasets.CIFAR10(root='../dataset/', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)

elif dataset == 'imagenet':
    # the size of paper is 128
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((resolution, resolution))])

    print("loading data of imagenet")

    train_dataset = datasets.ImageFolder(root='./dataset/ImageNet/train', transform=transform)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)

    test_dataset = Vanilla(root='./dataset/ImageNet/val', transform=transform)
    test_loader = DataLoader(test_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)

elif dataset == 'imagenette':
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((resolution, resolution))])

    train_dataset = datasets.Imagenette(root='../dataset/', split="train", download=True, transform=transform)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)

    test_dataset = datasets.Imagenette(root='../dataset/', split="val", download=True, transform=transform)
    test_loader = DataLoader(test_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)

else:
    raise Exception('Unknown dataset')

In [5]:
class AlignmentDataset(Dataset):
    def __init__(self, dataloader, model1, model2, flat=False):
        self.outputs = []

        model1.eval()
        model1.to(device)

        model2.eval()
        model2.to(device)

        with torch.no_grad():
            for inputs, _ in tqdm(dataloader, desc="Computing model outputs"):
                inputs = inputs.to(device)

                out1 = model1(inputs)
                out2 = model2(inputs)

                for o1, o2 in zip(out1, out2):
                    if flat:
                        o1 = o1.flatten()
                        o2 = o2.flatten()

                    self.outputs.append((o1.cpu(), o2.cpu()))

    def __len__(self):
        return len(self.outputs)

    def __getitem__(self, idx):
        return self.outputs[idx]  

## Utils

In [6]:
def validation(model, dataloader, times):

    model = model.to(device)
    batch_psnr_list = []

    with torch.no_grad():
        for inputs, _ in dataloader:

            inputs = inputs.to(device)
            psnr = torch.zeros(size=(inputs.shape[0], ), device=device)

            for _ in range(times):
                demo_image = model(inputs)
                demo_image = image_normalization('denormalization')(demo_image)
                gt = image_normalization('denormalization')(inputs)
                psnr += get_batch_psnr(demo_image, gt)

            psnr /= times
            batch_mean_psnr = psnr.mean().item()
            batch_psnr_list.append(batch_mean_psnr)

    overall_mean_psnr = sum(batch_psnr_list) / len(batch_psnr_list)

    return overall_mean_psnr

def load_from_checkpoint(path, snr):
    state_dict = torch.load(path, map_location=device)
    from collections import OrderedDict
    new_state_dict = OrderedDict()

    for k, v in state_dict.items():
        name = k.replace('module.','') # remove `module.`
        new_state_dict[name] = v

    model = DeepJSCC(c=c, channel_type=channel, snr=snr)

    model.load_state_dict(new_state_dict)
    model.change_channel(channel, snr)

    return model

## Validation parallelized

In [7]:
def validation_worker(model, inputs, gt, times, worker_id):
    """Worker function for parallel model inference"""
    model.eval()
    psnr_sum = torch.zeros(inputs.shape[0], device=inputs.device)
    
    with torch.no_grad():
        for _ in range(times):
            demo_image = model(inputs)
            demo_image = image_normalization('denormalization')(demo_image)
            psnr_sum += get_batch_psnr(demo_image, gt)
    
    return psnr_sum / times

def validation_parallel_inference(model, dataloader, times, num_workers=None):
    """Version with parallel inference for multiple runs"""
    model = model.to(device)
    model.eval()
    
    # Auto-detect optimal number of workers
    if num_workers is None:
        num_workers = min(times, torch.cuda.device_count() if torch.cuda.is_available() else mp.cpu_count())
    
    total_psnr = 0.0
    total_samples = 0
    
    with torch.no_grad():
        for inputs, *_ in dataloader:
            inputs = inputs.to(device)
            batch_size = inputs.shape[0]
            
            # Denormalize ground truth once
            gt = image_normalization('denormalization')(inputs)
            
            if times == 1:
                # No need for parallelization with single run
                demo_image = model(inputs)
                demo_image = image_normalization('denormalization')(demo_image)
                batch_psnr = get_batch_psnr(demo_image, gt).sum().item()
            else:
                # Parallel inference for multiple runs
                runs_per_worker = times // num_workers
                remaining_runs = times % num_workers
                
                batch_psnr_sum = torch.zeros(batch_size, device=device)
                
                # Use thread pool for GPU parallelization (better for CUDA)
                with ThreadPoolExecutor(max_workers=num_workers) as executor:
                    futures = []
                    
                    # Submit jobs with different number of runs per worker
                    for i in range(num_workers):
                        worker_runs = runs_per_worker + (1 if i < remaining_runs else 0)
                        if worker_runs > 0:
                            future = executor.submit(
                                validation_worker, 
                                model, inputs, gt, worker_runs, i
                            )
                            futures.append(future)
                    
                    # Collect results
                    for future in futures:
                        batch_psnr_sum += future.result()
                
                batch_psnr = batch_psnr_sum.sum().item()
            
            total_psnr += batch_psnr
            total_samples += batch_size
    
    return total_psnr / total_samples

def validation_parallel_batches(model, dataloader, times, num_workers=4):
    """Version with parallel batch processing"""
    model = model.to(device)
    model.eval()
    
    def process_batch(batch_data):
        inputs, *_ = batch_data
        inputs = inputs.to(device)
        batch_size = inputs.shape[0]
        
        # Denormalize ground truth once
        gt = image_normalization('denormalization')(inputs)
        
        # Accumulate PSNR across multiple runs
        batch_psnr_sum = torch.zeros(batch_size, device=device)
        
        with torch.no_grad():
            for _ in range(times):
                demo_image = model(inputs)
                demo_image = image_normalization('denormalization')(demo_image)
                batch_psnr_sum += get_batch_psnr(demo_image, gt)
        
        batch_mean_psnr = (batch_psnr_sum / times).sum().item()
        return batch_mean_psnr, batch_size
    
    # Process batches in parallel
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        results = list(executor.map(process_batch, dataloader))
    
    total_psnr = sum(psnr for psnr, _ in results)
    total_samples = sum(samples for _, samples in results)
    
    return total_psnr / total_samples

def validation_vectorized(model, dataloader, times):
    """Vectorized version for maximum efficiency when memory allows"""
    model = model.to(device)
    model.eval()
    
    total_psnr = 0.0
    total_samples = 0
    
    with torch.no_grad():
        for inputs, *_ in dataloader:
            inputs = inputs.to(device)
            batch_size = inputs.shape[0]
            
            # Denormalize ground truth once
            gt = image_normalization('denormalization')(inputs)
            
            if times == 1:
                demo_image = model(inputs)
                demo_image = image_normalization('denormalization')(demo_image)
                batch_psnr = get_batch_psnr(demo_image, gt).sum().item()
            else:
                # Vectorized computation - process all runs at once
                # Repeat inputs for all runs
                inputs_repeated = inputs.repeat(times, 1, 1, 1)
                
                # Single forward pass for all runs
                demo_images = model(inputs_repeated)
                demo_images = image_normalization('denormalization')(demo_images)
                
                # Reshape to separate runs and batch dimension
                demo_images = demo_images.view(times, batch_size, *demo_images.shape[1:])
                gt_repeated = gt.unsqueeze(0).repeat(times, 1, 1, 1, 1)
                
                # Compute PSNR for all runs at once
                psnr_all_runs = torch.stack([
                    get_batch_psnr(demo_images[i], gt_repeated[i]) 
                    for i in range(times)
                ])
                
                # Average across runs and sum across batch
                batch_psnr = psnr_all_runs.mean(dim=0).sum().item()
            
            total_psnr += batch_psnr
            total_samples += batch_size
    
    return total_psnr / total_samples

# Main validation function with automatic method selection
def validation(model, dataloader, times, method='auto', num_workers=None):
    """
    Optimized validation with multiple parallelization strategies
    
    Args:
        model: The model to validate
        dataloader: Data loader for validation data
        times: Number of inference runs per sample
        method: 'auto', 'vectorized', 'parallel_inference', 'parallel_batches', or 'sequential'
        num_workers: Number of parallel workers (auto-detected if None)
    """
    
    if method == 'auto':
        # Auto-select best method based on conditions
        if times == 1:
            method = 'sequential'
        elif times <= 4 and torch.cuda.is_available():
            method = 'vectorized'  # Best for GPU with moderate times
        elif times > 4:
            method = 'parallel_inference'  # Best for many inference runs
        else:
            method = 'parallel_batches'  # Best for CPU or complex cases
    
    if method == 'vectorized':
        return validation_vectorized(model, dataloader, times)
    elif method == 'parallel_inference':
        return validation_parallel_inference(model, dataloader, times, num_workers)
    elif method == 'parallel_batches':
        return validation_parallel_batches(model, dataloader, times, num_workers)
    else:  # sequential
        return validation_sequential(model, dataloader, times)

def validation_sequential(model, dataloader, times):
    """Original optimized sequential version for comparison"""
    model = model.to(device)
    model.eval()
    
    total_psnr = 0.0
    total_samples = 0
    
    with torch.no_grad():
        for inputs, *_ in dataloader:
            inputs = inputs.to(device)
            batch_size = inputs.shape[0]
            
            gt = image_normalization('denormalization')(inputs)
            batch_psnr_sum = torch.zeros(batch_size, device=device)
            
            for _ in range(times):
                demo_image = model(inputs)
                demo_image = image_normalization('denormalization')(demo_image)
                batch_psnr_sum += get_batch_psnr(demo_image, gt)
            
            batch_mean_psnr = (batch_psnr_sum / times).sum().item()
            total_psnr += batch_mean_psnr
            total_samples += batch_size
    
    return total_psnr / total_samples

# No mismatch - Unaligned

In [8]:
model1 = load_deep_jscc(model1_fp, snr, c, "AWGN")
model2 = load_deep_jscc(model2_fp, snr, c, "AWGN")

encoder = copy.deepcopy(model1.encoder)
decoder = copy.deepcopy(model2.decoder)

model = AlignedDeepJSCC(encoder, decoder, None, snr, "AWGN")
print(f"Unaligned {validation_vectorized(model, test_loader, times):.2f}")

Unaligned 9.54


In [None]:
model1 = load_deep_jscc(model1_fp, snr, c, "AWGN")

encoder = copy.deepcopy(model1.encoder)
decoder = copy.deepcopy(model1.decoder)

model = AlignedDeepJSCC(encoder, decoder, None, 30, "AWGN")

print(f"Aligned {validation_vectorized(model, test_loader, times):.2f}")

Aligned 40.66


Computing model outputs:  28%|██▊       | 220/782 [02:35<43:24,  4.63s/it]

# Least Squares

In [10]:
def dataset_to_matrices(dataset, batch_size=128):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    data_1 = []
    data_2 = []
    
    for batch in loader:
        data_1.append(batch[0])
        data_2.append(batch[1])

    return torch.cat(data_1, dim=0), torch.cat(data_2, dim=0)

def aligner_least_squares(matrix_1, matrix_2):
    Y = matrix_1.T
    Z = matrix_2.T

    Q = Y @ Z.T @ torch.inverse(Z @ Z.T)

    return _LinearAlignment(align_matrix=Q)

def aligner_least_squares(matrix_1, matrix_2, n_samples):
    Y = matrix_1.T  # [d, n]
    Z = matrix_2.T  # [d, n]

    ZZ_T = Z @ Z.T
    YZ_T = Y @ Z.T

    reg_matrix = (10000) * torch.eye(ZZ_T.size(0), device=ZZ_T.device, dtype=ZZ_T.dtype)
    Q = YZ_T @ torch.linalg.inv(ZZ_T + reg_matrix)

    return _LinearAlignment(align_matrix=Q)

In [11]:
model1 = load_from_checkpoint(model1_fp, snr).encoder
model2 = load_from_checkpoint(model2_fp, snr).encoder

data = AlignmentDataset(train_loader, model1, model2, flat=True)

Computing model outputs: 100%|██████████| 782/782 [00:08<00:00, 88.19it/s] 


In [12]:
set_seed(seed)
permutation = torch.randperm(len(data))

for sample in samples:
    indices = permutation[:sample]
    subset = Subset(data, indices)

    matrix_1, matrix_2 = dataset_to_matrices(subset)

    aligner = aligner_least_squares(matrix_1, matrix_2, sample)

    with open(r'alignment/models/plots/'+folder+'/aligner_linear_'+str(sample)+'.pkl', 'wb') as f:
        pickle.dump(aligner, f)

    print(f"Done with {sample}.")

Done with 1.
Done with 2.
Done with 3.
Done with 4.
Done with 5.
Done with 6.
Done with 7.
Done with 8.
Done with 9.
Done with 10.
Done with 11.
Done with 12.
Done with 13.
Done with 14.
Done with 16.
Done with 17.
Done with 19.
Done with 21.
Done with 23.
Done with 25.
Done with 28.
Done with 31.
Done with 34.
Done with 37.
Done with 41.
Done with 45.
Done with 49.
Done with 54.
Done with 59.
Done with 65.
Done with 72.
Done with 79.
Done with 86.
Done with 95.
Done with 104.
Done with 114.
Done with 126.
Done with 138.
Done with 151.
Done with 166.
Done with 183.
Done with 200.
Done with 220.
Done with 242.
Done with 265.
Done with 291.
Done with 319.
Done with 351.
Done with 385.
Done with 422.
Done with 464.
Done with 509.
Done with 559.
Done with 613.
Done with 673.
Done with 739.
Done with 811.
Done with 890.
Done with 977.
Done with 1072.
Done with 1176.
Done with 1291.
Done with 1417.
Done with 1555.
Done with 1707.
Done with 1873.
Done with 2056.
Done with 2257.
Done with 2477

In [13]:
set_seed(seed)

model1 = load_deep_jscc(model1_fp, snr, c, "AWGN")
model2 = load_deep_jscc(model2_fp, snr, c, "AWGN")

encoder = copy.deepcopy(model1.encoder)
decoder = copy.deepcopy(model2.decoder)

for sample in samples:

    aligner_fp = r'alignment/models/plots/'+folder+'/aligner_linear_'+str(sample)+'.pkl'

    with open(aligner_fp, 'rb') as f:
        aligner = pickle.load(f)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, channel)

    print(f"Linear model, {sample} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times):.2f}")

Linear model, 1 samples got a PSNR of 10.88
Linear model, 2 samples got a PSNR of 10.74
Linear model, 3 samples got a PSNR of 9.95
Linear model, 4 samples got a PSNR of 9.95
Linear model, 5 samples got a PSNR of 10.21
Linear model, 6 samples got a PSNR of 10.45
Linear model, 7 samples got a PSNR of 10.17
Linear model, 8 samples got a PSNR of 10.01
Linear model, 9 samples got a PSNR of 10.55
Linear model, 10 samples got a PSNR of 10.58
Linear model, 11 samples got a PSNR of 10.98
Linear model, 12 samples got a PSNR of 11.13
Linear model, 13 samples got a PSNR of 11.07
Linear model, 14 samples got a PSNR of 10.99
Linear model, 16 samples got a PSNR of 11.05
Linear model, 17 samples got a PSNR of 11.18
Linear model, 19 samples got a PSNR of 11.06
Linear model, 21 samples got a PSNR of 11.29
Linear model, 23 samples got a PSNR of 11.42
Linear model, 25 samples got a PSNR of 11.88
Linear model, 28 samples got a PSNR of 11.86
Linear model, 31 samples got a PSNR of 11.80
Linear model, 34 samp

# Linear Neural

In [8]:
model1 = load_from_checkpoint(model1_fp, snr).encoder
model2 = load_from_checkpoint(model2_fp, snr).encoder

data = AlignmentDataset(train_loader, model1, model2, flat=False)

Computing model outputs: 100%|██████████| 782/782 [00:08<00:00, 87.03it/s] 


In [9]:
set_seed(seed)
permutation = torch.randperm(len(data))

for sample in samples:
    indices = permutation[:sample]
    subset = Subset(data, indices)
    dataloader = DataLoader(subset, batch_size=batch_size, shuffle=True)

    epochs_max = 10000
    ratio = 6
    patience = 10
    check_interval = 1
    min_delta = 1e-5
    lambda_reg = 0

    aligner = _LinearAlignment(size=resolution * resolution * 3 * 2 // ratio).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(aligner.parameters(), lr=1e-3)
    channel = Channel("AWGN", 7)

    best_loss = float('inf')
    best_model_state = None
    checks_without_improvement = 0
    epoch = 0

    while True:
        epoch_loss = 0.0

        for inputs, targets in dataloader:
            inputs = channel(inputs)
            optimizer.zero_grad()
            outputs = aligner(inputs.to(device))

            mse_loss = criterion(outputs, targets.to(device))
            l2_loss = sum((param**2).sum() for param in aligner.parameters())
            loss = inputs.shape[0] * mse_loss + lambda_reg * l2_loss

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        epoch += 1

        if epoch % check_interval == 0:
            avg_loss = epoch_loss / len(dataloader)
            if best_loss - avg_loss > min_delta:
                best_loss = avg_loss
                best_model_state = copy.deepcopy(aligner.state_dict())
                checks_without_improvement = 0
            else:
                checks_without_improvement += 1

            if checks_without_improvement >= patience:
                break

        if epoch > epochs_max:
            break

    print(f"Done with {sample}. Trained for {epoch} epochs.")

    # Restore best model
    if best_model_state is not None:
        aligner.load_state_dict(best_model_state)

    aligner_fp = f'alignment/models/plots/{folder}/aligner_neural_{sample}.pkl'

    with open(aligner_fp, 'wb') as f:
        pickle.dump(aligner.to("cpu"), f)

Done with 1. Trained for 18 epochs.
Done with 2. Trained for 23 epochs.
Done with 3. Trained for 23 epochs.
Done with 4. Trained for 63 epochs.
Done with 5. Trained for 45 epochs.
Done with 6. Trained for 63 epochs.
Done with 7. Trained for 60 epochs.
Done with 8. Trained for 54 epochs.
Done with 9. Trained for 58 epochs.
Done with 10. Trained for 56 epochs.
Done with 11. Trained for 59 epochs.
Done with 12. Trained for 58 epochs.
Done with 13. Trained for 65 epochs.
Done with 14. Trained for 61 epochs.
Done with 16. Trained for 57 epochs.
Done with 17. Trained for 74 epochs.
Done with 19. Trained for 72 epochs.
Done with 21. Trained for 70 epochs.
Done with 23. Trained for 66 epochs.
Done with 25. Trained for 63 epochs.
Done with 28. Trained for 93 epochs.
Done with 31. Trained for 65 epochs.
Done with 34. Trained for 65 epochs.
Done with 37. Trained for 69 epochs.
Done with 41. Trained for 758 epochs.
Done with 45. Trained for 745 epochs.
Done with 49. Trained for 901 epochs.
Done wi

In [10]:
set_seed(seed)

model1 = load_deep_jscc(model1_fp, snr, c, "AWGN")
model2 = load_deep_jscc(model2_fp, snr, c, "AWGN")

encoder = copy.deepcopy(model1.encoder)
decoder = copy.deepcopy(model2.decoder)

for sample in samples:

    aligner_fp = r'alignment/models/plots/'+folder+'/aligner_neural_'+str(sample)+'.pkl'

    with open(aligner_fp, 'rb') as f:
        aligner = pickle.load(f)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, "AWGN")

    print(f"Neural model, {sample} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times):.2f}")

Neural model, 1 samples got a PSNR of 11.42
Neural model, 2 samples got a PSNR of 11.24
Neural model, 3 samples got a PSNR of 11.26
Neural model, 4 samples got a PSNR of 11.37
Neural model, 5 samples got a PSNR of 11.36
Neural model, 6 samples got a PSNR of 11.58
Neural model, 7 samples got a PSNR of 11.50
Neural model, 8 samples got a PSNR of 11.65
Neural model, 9 samples got a PSNR of 11.80
Neural model, 10 samples got a PSNR of 11.76
Neural model, 11 samples got a PSNR of 11.77
Neural model, 12 samples got a PSNR of 11.74
Neural model, 13 samples got a PSNR of 11.81
Neural model, 14 samples got a PSNR of 11.77
Neural model, 16 samples got a PSNR of 11.79
Neural model, 17 samples got a PSNR of 11.97
Neural model, 19 samples got a PSNR of 11.96
Neural model, 21 samples got a PSNR of 12.12
Neural model, 23 samples got a PSNR of 12.16
Neural model, 25 samples got a PSNR of 12.27
Neural model, 28 samples got a PSNR of 12.30
Neural model, 31 samples got a PSNR of 12.33
Neural model, 34 sa

# Convolutional

In [8]:
model1 = load_from_checkpoint(model1_fp, snr).encoder
model2 = load_from_checkpoint(model2_fp, snr).encoder

data = AlignmentDataset(train_loader, model1, model2, flat=False)

Computing model outputs: 100%|██████████| 782/782 [00:08<00:00, 96.96it/s] 


In [9]:
set_seed(seed)
permutation = torch.randperm(len(data))

for sample in samples:
    indices = permutation[:sample]
    subset = Subset(data, indices)
    dataloader = DataLoader(subset, batch_size=batch_size, shuffle=True)

    epochs_max=10000
    ratio=6
    patience=5
    check_interval=1
    min_delta=1e-5

    aligner = _ConvolutionalAlignment(in_channels=2*c, out_channels=2*c, kernel_size=5).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(aligner.parameters(), lr=1e-3)
    channel = Channel("AWGN", 7)

    best_loss = float('inf')
    checks_without_improvement = 0
    epoch = 0

    while True:
        epoch_loss = 0.0

        for inputs, targets in dataloader:
            inputs = channel(inputs)
            optimizer.zero_grad()
            outputs = aligner(inputs.to(device))
            loss = criterion(outputs, targets.to(device))
            loss = loss * inputs.shape[0] # scale by batch size
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        epoch += 1

        if epoch % check_interval == 0:
            avg_loss = epoch_loss / len(dataloader)
            if best_loss - avg_loss > min_delta:
                best_loss = avg_loss
                best_model_state = copy.deepcopy(aligner.state_dict())
                checks_without_improvement = 0
            else:
                checks_without_improvement += 1

            if checks_without_improvement >= patience:
                break

        if epoch > epochs_max:
            break

    print(f"Done with {sample}. Trained for {epoch} epochs.")

    # Restore best model
    if best_model_state is not None:
        aligner.load_state_dict(best_model_state)

    aligner_fp = r'alignment/models/plots/'+folder+'/aligner_conv_'+str(sample)+'.pkl'

    with open(aligner_fp, 'wb') as f:
        pickle.dump(aligner.to("cpu"), f)

Done with 1. Trained for 164 epochs.
Done with 2. Trained for 157 epochs.
Done with 3. Trained for 147 epochs.
Done with 4. Trained for 209 epochs.
Done with 5. Trained for 180 epochs.
Done with 6. Trained for 196 epochs.
Done with 7. Trained for 231 epochs.
Done with 8. Trained for 254 epochs.
Done with 9. Trained for 185 epochs.
Done with 10. Trained for 186 epochs.
Done with 11. Trained for 249 epochs.
Done with 12. Trained for 211 epochs.
Done with 13. Trained for 250 epochs.
Done with 14. Trained for 227 epochs.
Done with 16. Trained for 272 epochs.
Done with 17. Trained for 283 epochs.
Done with 19. Trained for 237 epochs.
Done with 21. Trained for 234 epochs.
Done with 23. Trained for 236 epochs.
Done with 25. Trained for 266 epochs.
Done with 28. Trained for 288 epochs.
Done with 31. Trained for 257 epochs.
Done with 34. Trained for 231 epochs.
Done with 37. Trained for 229 epochs.
Done with 41. Trained for 279 epochs.
Done with 45. Trained for 259 epochs.
Done with 49. Trained

In [10]:
set_seed(seed)

model1 = load_deep_jscc(model1_fp, snr, c, "AWGN")
model2 = load_deep_jscc(model2_fp, snr, c, "AWGN")

encoder = copy.deepcopy(model1.encoder)
decoder = copy.deepcopy(model2.decoder)

for sample in samples:

    aligner_fp = r'alignment/models/plots/'+folder+'/aligner_conv_'+str(sample)+'.pkl'

    with open(aligner_fp, 'rb') as f:
        aligner = pickle.load(f)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, "AWGN")

    print(f"Conv model, {sample} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times):.2f}")

Conv model, 1 samples got a PSNR of 14.75
Conv model, 2 samples got a PSNR of 24.90
Conv model, 3 samples got a PSNR of 26.01
Conv model, 4 samples got a PSNR of 25.22
Conv model, 5 samples got a PSNR of 25.47
Conv model, 6 samples got a PSNR of 25.68
Conv model, 7 samples got a PSNR of 25.54
Conv model, 8 samples got a PSNR of 25.49
Conv model, 9 samples got a PSNR of 24.90
Conv model, 10 samples got a PSNR of 24.83
Conv model, 11 samples got a PSNR of 25.16
Conv model, 12 samples got a PSNR of 25.45
Conv model, 13 samples got a PSNR of 25.90
Conv model, 14 samples got a PSNR of 25.52
Conv model, 16 samples got a PSNR of 25.26
Conv model, 17 samples got a PSNR of 25.43
Conv model, 19 samples got a PSNR of 24.93
Conv model, 21 samples got a PSNR of 25.37
Conv model, 23 samples got a PSNR of 26.05
Conv model, 25 samples got a PSNR of 25.83
Conv model, 28 samples got a PSNR of 25.68
Conv model, 31 samples got a PSNR of 25.62
Conv model, 34 samples got a PSNR of 26.11
Conv model, 37 sampl

# Zero-shot

In [8]:
model1 = load_from_checkpoint(model1_fp, snr).encoder
model2 = load_from_checkpoint(model2_fp, snr).encoder

data = AlignmentDataset(train_loader, model1, model2, flat=True)

Computing model outputs: 100%|██████████| 782/782 [00:08<00:00, 88.36it/s] 


In [9]:
set_seed(seed)
permutation = torch.randperm(len(data))

samples_zeroshot = set()
for num in samples:
    if num % 2 != 0:
        num += 1
    samples_zeroshot.add(num)
samples_zeroshot = sorted(samples_zeroshot)

In [None]:
for sample in samples_zeroshot[1:]:
    indices = permutation[:sample]
    subset = Subset(data, indices)

    dataloader = DataLoader(subset, batch_size=len(subset))
    input, output = next(iter(dataloader))

    flattened_image_size = resolution * resolution

    try:

        baseline = Baseline(
            input_dim=flattened_image_size,
            output_dim=flattened_image_size,
            channel_matrix=torch.eye(1, dtype=torch.complex64),
            snr=7,
            channel_usage=None,
            typology='pre',
            strategy='PFE',
            use_channel=True,
            seed=seed,
        )

        baseline.fit(input, output)

    except RuntimeError:
        
        print(f"Skipped {sample}.")
        continue

    aligner_fp = r'alignment/models/plots/'+folder+'/aligner_zeroshot_'+str(sample)+'.pkl'

    with open(aligner_fp, 'wb') as f:
        pickle.dump(baseline, f)

    print(f"Done with {sample}.")

In [None]:
set_seed(seed)

model1 = load_deep_jscc(model1_fp, snr, c, "AWGN")
model2 = load_deep_jscc(model2_fp, snr, c, "AWGN")

encoder = copy.deepcopy(model1.encoder).to(device)
decoder = copy.deepcopy(model2.decoder).to(device)

for sample in samples_zeroshot[62:]:

    skip = False
    aligner_fp = r'alignment/models/plots/'+folder+'/aligner_zeroshot_'+str(sample)+'.pkl'

    if not os.path.exists(aligner_fp):
        continue

    with open(aligner_fp, 'rb') as f:
        aligner = pickle.load(f)

    aligner.snr = 30
    batch_psnr_list = []

    with torch.no_grad():
    
        for inputs, _ in test_loader:

            inputs = inputs.to(device)
            inputs = inputs.squeeze()
            psnr_all = 0.0

            for _ in range(times):

                demo_image = encoder(inputs)
                shape = demo_image.shape

                demo_image = demo_image.reshape(shape[0], -1)

                try:
                    demo_image = aligner.transform(demo_image)

                except RuntimeError:
                    skip = True
                    print(f"Skipped {sample}")
                    break

                demo_image = demo_image.reshape(shape)
                demo_image = decoder(demo_image)

                demo_image = image_normalization('denormalization')(demo_image)

                gt = image_normalization('denormalization')(inputs)
                psnr_all += get_batch_psnr(demo_image, gt)

            if skip==True:
                break

            psnr_all /= times
            batch_mean_psnr = psnr_all.mean().item()
            batch_psnr_list.append(batch_mean_psnr)

    if skip==True:
        continue

    overall_mean_psnr = sum(batch_psnr_list) / len(batch_psnr_list)

    print(f"Zeroshot model, {sample} samples got a PSNR of {overall_mean_psnr}")

Zeroshot model, 2718 samples got a PSNR of 23.582919260498823
Zeroshot model, 2984 samples got a PSNR of 23.983496113187947
Zeroshot model, 3274 samples got a PSNR of 24.38637371427694
Zeroshot model, 3594 samples got a PSNR of 24.791458907400727
Zeroshot model, 3944 samples got a PSNR of 25.170723082912954
Zeroshot model, 4328 samples got a PSNR of 25.554468106312356
Zeroshot model, 4750 samples got a PSNR of 25.904247721289373
Zeroshot model, 5214 samples got a PSNR of 26.256022556572205
Zeroshot model, 5722 samples got a PSNR of 26.588640954084457
Zeroshot model, 6280 samples got a PSNR of 26.916303829023033
