In [1]:
%cd ..

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import copy
import numpy as np
import random

from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from torchvision import datasets

from alignment.alignment_utils import load_deep_jscc, get_batch_psnr
from alignment.alignment_model import _ConvolutionalAlignment, _LinearAlignment, _ZeroShotAlignment, AlignedDeepJSCC
from utils import image_normalization
import matplotlib.pyplot as plt

from dataset import Vanilla
from model import DeepJSCC
from tqdm import tqdm
import pickle
from channel import Channel
from PIL import Image

from concurrent.futures import ThreadPoolExecutor
from functools import partial
from alignment.linear_models_gpu import Baseline

/home/lorenzo/repos/Deep-JSCC-PyTorch


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoints of the two independently trained autoencoders that are compared / aligned.
model1_fp = r'alignment/models/autoencoders/upscaled_42.pkl'
model2_fp = r'alignment/models/autoencoders/upscaled_43.pkl'

# SNR used while extracting latents / training aligners; the training cells only
# inject channel noise when this is not None.
train_snr = None
# SNR passed to load_deep_jscc / AlignedDeepJSCC in the evaluation cells.
val_snr = 7
# Number of repeated forward passes per batch in the validation helpers.
times = 10
# Passed as `c` to DeepJSCC; the convolutional aligner uses 2*c channels.
c = 8

dataset = "cifar10"
# Images are resized to (resolution, resolution) by the data transform.
resolution = 96
# Sub-folder of alignment/models/plots/ where fitted aligners are pickled.
folder = "psnr_vs_pilots_4"

batch_size = 64
num_workers = 4
# Channel type string; read by load_from_checkpoint and some AlignedDeepJSCC calls.
channel = 'AWGN'

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

# Log-spaced sample counts between 1 and 10000; duplicates created by the
# integer truncation are removed by np.unique.
samples = np.unique(np.logspace(0, np.log10(10000), num=100, base=10).astype(int))
seed = 42

os.makedirs(r'alignment/models/plots/'+folder, exist_ok=True)

# When True, the `if test_mode:` cells fit and evaluate a single aligner in place.
test_mode = False

In [3]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU + CUDA) RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    seeders = (
        random.seed,                  # Python RNG
        np.random.seed,               # NumPy RNG
        torch.manual_seed,            # PyTorch CPU RNG
        torch.cuda.manual_seed,       # PyTorch current-GPU RNG
        torch.cuda.manual_seed_all,   # PyTorch RNG on every GPU
    )
    for seeder in seeders:
        seeder(seed)

    # Trade cuDNN autotuning for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Init

## Data

In [4]:
############
# GET DATA #
############

# Every supported dataset shares the same preprocessing: convert to tensor,
# then resize to (resolution, resolution).
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((resolution, resolution))])

if dataset == 'cifar10':
    train_dataset = datasets.CIFAR10(root='../dataset/', train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root='../dataset/', train=False, download=True, transform=transform)

elif dataset == 'imagenet':
    # the size of paper is 128
    print("loading data of imagenet")

    train_dataset = datasets.ImageFolder(root='./dataset/ImageNet/train', transform=transform)
    test_dataset = Vanilla(root='./dataset/ImageNet/val', transform=transform)

elif dataset == 'imagenette':
    train_dataset = datasets.Imagenette(root='../dataset/', split="train", download=True, transform=transform)
    test_dataset = datasets.Imagenette(root='../dataset/', split="val", download=True, transform=transform)

else:
    raise Exception('Unknown dataset')

# Both loaders shuffle and share the batch size / worker count from the config cell.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)

In [5]:
class AlignmentDataset(Dataset):
    """Paired outputs of two models evaluated on the same inputs.

    Construction runs every batch of ``dataloader`` through ``model1`` and
    ``model2`` (both switched to eval mode and moved to the global ``device``)
    and stores one ``(out1, out2)`` CPU tensor pair per sample.

    Args:
        dataloader: Iterable yielding ``(inputs, _)`` batches.
        model1: Model producing the first element of each pair.
        model2: Model producing the second element of each pair.
        flat: When True each per-sample output is flattened before storing.
    """

    def __init__(self, dataloader, model1, model2, flat=False):
        self.outputs = []

        for net in (model1, model2):
            net.eval()
            net.to(device)

        with torch.no_grad():
            for inputs, _ in tqdm(dataloader, desc="Computing model outputs"):
                inputs = inputs.to(device)

                batch_a = model1(inputs)
                batch_b = model2(inputs)

                # Split the batch into per-sample pairs kept on the CPU.
                for sample_a, sample_b in zip(batch_a, batch_b):
                    if flat:
                        sample_a, sample_b = sample_a.flatten(), sample_b.flatten()

                    self.outputs.append((sample_a.cpu(), sample_b.cpu()))

    def __len__(self):
        """Number of stored output pairs."""
        return len(self.outputs)

    def __getitem__(self, idx):
        """Return the ``(out1, out2)`` pair stored at position ``idx``."""
        pair = self.outputs[idx]
        return pair

## Utils

In [6]:
def validation(model, dataloader, times):
    """Average the per-sample score of ``get_batch_psnr`` over a dataloader.

    Each batch is passed through ``model`` ``times`` times and the per-sample
    scores are averaged over those runs. The returned value is the mean over
    all samples, so a smaller final batch is not over-weighted (the same
    aggregation the other validation helpers in this notebook use).

    Args:
        model: Model to evaluate; moved to the global ``device`` and put in eval mode.
        dataloader: Iterable yielding ``(inputs, _)`` batches.
        times: Number of forward passes per batch.

    Returns:
        float: Mean score over all samples seen.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model = model.to(device)
    # Evaluate with inference-mode layers, like the other validators do.
    model.eval()

    total_psnr = 0.0
    total_samples = 0

    with torch.no_grad():
        for inputs, _ in dataloader:

            inputs = inputs.to(device)
            # The reference image does not depend on the run, so compute it once per batch.
            gt = image_normalization('denormalization')(inputs)
            psnr = torch.zeros(size=(inputs.shape[0], ), device=device)

            for _ in range(times):
                demo_image = model(inputs)
                demo_image = image_normalization('denormalization')(demo_image)
                psnr += get_batch_psnr(demo_image, gt)

            psnr /= times
            total_psnr += psnr.sum().item()
            total_samples += inputs.shape[0]

    return total_psnr / total_samples

def load_from_checkpoint(path, snr):
    """Build a ``DeepJSCC`` model from a saved state dict.

    Leading ``module.`` prefixes on parameter names are stripped before
    loading. Only the prefix is removed: a key such as ``submodule.weight``
    is left intact (a plain ``str.replace`` would turn it into ``subweight``).

    Args:
        path: Path of the checkpoint handed to ``torch.load``.
        snr: SNR passed to the ``DeepJSCC`` constructor and ``change_channel``.

    Returns:
        The ``DeepJSCC`` instance with the checkpoint weights loaded, built with
        the notebook-level ``c`` and ``channel`` settings.
    """
    from collections import OrderedDict

    # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
    state_dict = torch.load(path, map_location=device)

    prefix = 'module.'
    new_state_dict = OrderedDict()

    for k, v in state_dict.items():
        name = k
        # Remove every leading `module.` and nothing else.
        while name.startswith(prefix):
            name = name[len(prefix):]
        new_state_dict[name] = v

    model = DeepJSCC(c=c, channel_type=channel, snr=snr)

    model.load_state_dict(new_state_dict)
    model.change_channel(channel, snr)

    return model

## Validation parallelized

In [7]:
def validation_worker(model, inputs, gt, times, worker_id):
    """Run ``times`` forward passes on one batch and return the mean per-sample score.

    Args:
        model: Model to run; switched to eval mode.
        inputs: Input batch, already on the target device.
        gt: Denormalized reference batch compared against each reconstruction.
        times: Number of forward passes to average over.
        worker_id: Identifier supplied by the caller; not used here.

    Returns:
        Tensor with one entry per sample: the accumulated ``get_batch_psnr``
        result divided by ``times``.
    """
    model.eval()
    accumulated = torch.zeros(inputs.shape[0], device=inputs.device)

    with torch.no_grad():
        for _run in range(times):
            reconstruction = image_normalization('denormalization')(model(inputs))
            accumulated += get_batch_psnr(reconstruction, gt)

    mean_per_sample = accumulated / times
    return mean_per_sample

def validation_parallel_inference(model, dataloader, times, num_workers=None):
    """Validate a model, spreading the repeated forward passes over threads.

    For every batch the ``times`` runs are split across ``num_workers``
    threads, each calling ``validation_worker``. Every worker returns the
    mean over its own runs, so the results are weighted by the worker's run
    count and divided by ``times``. Simply summing the worker means would
    inflate the score by the number of active workers.

    Args:
        model: Model to evaluate; moved to the global ``device`` and put in eval mode.
        dataloader: Iterable yielding batches whose first element is the input tensor.
        times: Number of forward passes per batch.
        num_workers: Number of threads. When None it is
            ``min(times, GPU count)`` with CUDA, else ``min(times, CPU count)``.

    Returns:
        float: Mean score over all samples seen.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model = model.to(device)
    model.eval()

    # Auto-detect the number of workers
    if num_workers is None:
        num_workers = min(times, torch.cuda.device_count() if torch.cuda.is_available() else mp.cpu_count())
    # Never fewer than one worker, otherwise the run split below divides by zero.
    num_workers = max(1, num_workers)

    total_psnr = 0.0
    total_samples = 0

    with torch.no_grad():
        for inputs, *_ in dataloader:
            inputs = inputs.to(device)
            batch_size = inputs.shape[0]

            # Denormalize ground truth once
            gt = image_normalization('denormalization')(inputs)

            if times == 1:
                # No need for parallelization with single run
                demo_image = model(inputs)
                demo_image = image_normalization('denormalization')(demo_image)
                batch_psnr = get_batch_psnr(demo_image, gt).sum().item()
            else:
                # Split the runs as evenly as possible across the workers.
                runs_per_worker, remaining_runs = divmod(times, num_workers)

                batch_psnr_sum = torch.zeros(batch_size, device=device)

                with ThreadPoolExecutor(max_workers=num_workers) as executor:
                    jobs = []

                    for i in range(num_workers):
                        worker_runs = runs_per_worker + (1 if i < remaining_runs else 0)
                        if worker_runs > 0:
                            # Grad mode is per-thread; validation_worker enters
                            # torch.no_grad() itself.
                            future = executor.submit(
                                validation_worker,
                                model, inputs, gt, worker_runs, i
                            )
                            jobs.append((worker_runs, future))

                    # Each result is a mean over `worker_runs` runs: turn it back
                    # into a sum so the runs can be pooled correctly.
                    for worker_runs, future in jobs:
                        batch_psnr_sum += future.result() * worker_runs

                # Mean over all `times` runs, then summed over the batch.
                batch_psnr = (batch_psnr_sum / times).sum().item()

            total_psnr += batch_psnr
            total_samples += batch_size

    return total_psnr / total_samples

def validation_parallel_batches(model, dataloader, times, num_workers=4):
    """Validate a model, handing each batch to a thread-pool worker.

    Args:
        model: Model to evaluate; moved to the global ``device`` and put in eval mode.
        dataloader: Iterable yielding batches whose first element is the input tensor.
        times: Number of forward passes per batch.
        num_workers: ``max_workers`` of the thread pool.

    Returns:
        float: Sum of per-sample run-averaged scores divided by the sample count.
    """
    model = model.to(device)
    model.eval()

    def process_batch(batch_data):
        """Score one batch; returns (sum of per-sample mean scores, batch size)."""
        inputs = batch_data[0].to(device)
        n_in_batch = inputs.shape[0]

        # Reference images only need to be denormalized once per batch.
        reference = image_normalization('denormalization')(inputs)

        run_total = torch.zeros(n_in_batch, device=device)

        with torch.no_grad():
            for _ in range(times):
                reconstruction = image_normalization('denormalization')(model(inputs))
                run_total += get_batch_psnr(reconstruction, reference)

        return (run_total / times).sum().item(), n_in_batch

    # Batches are consumed from the dataloader and scored by the pool.
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        per_batch = list(pool.map(process_batch, dataloader))

    psnr_total = 0
    sample_total = 0
    for batch_score, batch_count in per_batch:
        psnr_total += batch_score
        sample_total += batch_count

    return psnr_total / sample_total

def validation_vectorized(model, dataloader, times):
    """Validate a model with all repeated runs stacked into one forward pass.

    For ``times > 1`` each batch is tiled ``times`` times along the batch
    dimension and pushed through the model once; the output is then split
    back into runs and scored against the single ground-truth batch.

    Args:
        model: Model to evaluate; moved to the global ``device`` and put in eval mode.
        dataloader: Iterable yielding batches whose first element is the input tensor.
        times: Number of runs per sample. Memory grows with ``times * batch size``.

    Returns:
        float: Mean (over samples) of the run-averaged ``get_batch_psnr`` score.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model = model.to(device)
    model.eval()

    total_psnr = 0.0
    total_samples = 0

    with torch.no_grad():
        for inputs, *_ in dataloader:
            inputs = inputs.to(device)
            batch_size = inputs.shape[0]

            # Denormalize ground truth once
            gt = image_normalization('denormalization')(inputs)

            if times == 1:
                demo_image = model(inputs)
                demo_image = image_normalization('denormalization')(demo_image)
                batch_psnr = get_batch_psnr(demo_image, gt).sum().item()
            else:
                # Tile the batch so a single forward pass covers every run.
                inputs_repeated = inputs.repeat(times, 1, 1, 1)

                demo_images = model(inputs_repeated)
                demo_images = image_normalization('denormalization')(demo_images)

                # Split into (run, sample, ...). reshape also copes with a
                # non-contiguous model output, where view would raise.
                demo_images = demo_images.reshape(times, batch_size, *demo_images.shape[1:])

                # Every run is compared with the same ground truth, so there is
                # no need to materialise `times` copies of it.
                psnr_all_runs = torch.stack([
                    get_batch_psnr(demo_images[i], gt)
                    for i in range(times)
                ])

                # Average across runs and sum across batch
                batch_psnr = psnr_all_runs.mean(dim=0).sum().item()

            total_psnr += batch_psnr
            total_samples += batch_size

    return total_psnr / total_samples

# Main validation function with automatic method selection
def validation_2(model, dataloader, times, method='auto', num_workers=None):
    """
    Validation entry point that dispatches to one of the strategies in this notebook.

    Args:
        model: The model to validate
        dataloader: Data loader for validation data
        times: Number of inference runs per sample
        method: 'auto', 'vectorized', 'parallel_inference', 'parallel_batches', or 'sequential'
        num_workers: Number of parallel workers (auto-detected if None)

    Returns:
        Whatever the selected validation function returns (mean score over samples).

    Raises:
        ValueError: If ``method`` is not one of the names listed above. A
            misspelt method no longer falls back to the sequential path silently.
    """

    if method == 'auto':
        # Auto-select a method based on the run count and CUDA availability
        if times == 1:
            method = 'sequential'
        elif times <= 4 and torch.cuda.is_available():
            method = 'vectorized'  # few runs on GPU: one stacked forward pass
        elif times > 4:
            method = 'parallel_inference'  # many runs: split them across workers
        else:
            method = 'parallel_batches'  # few runs without CUDA

    if method == 'vectorized':
        return validation_vectorized(model, dataloader, times)
    elif method == 'parallel_inference':
        return validation_parallel_inference(model, dataloader, times, num_workers)
    elif method == 'parallel_batches':
        return validation_parallel_batches(model, dataloader, times, num_workers)
    elif method == 'sequential':
        return validation_sequential(model, dataloader, times)

    raise ValueError(
        f"Unknown validation method {method!r}; expected 'auto', 'vectorized', "
        "'parallel_inference', 'parallel_batches' or 'sequential'."
    )

def validation_sequential(model, dataloader, times):
    """Validate a model by running the repeated forward passes one after another.

    Args:
        model: Model to evaluate; moved to the global ``device`` and put in eval mode.
        dataloader: Iterable yielding batches whose first element is the input tensor.
        times: Number of forward passes per batch.

    Returns:
        float: Sum of per-sample run-averaged scores divided by the sample count.
    """
    model = model.to(device)
    model.eval()

    psnr_accum, n_seen = 0.0, 0

    with torch.no_grad():
        for inputs, *_ in dataloader:
            inputs = inputs.to(device)
            n_batch = inputs.shape[0]

            # The reference is computed once and reused by every run.
            reference = image_normalization('denormalization')(inputs)
            run_total = torch.zeros(n_batch, device=device)

            for _ in range(times):
                reconstruction = image_normalization('denormalization')(model(inputs))
                run_total += get_batch_psnr(reconstruction, reference)

            psnr_accum += (run_total / times).sum().item()
            n_seen += n_batch

    return psnr_accum / n_seen

# No mismatch - Unaligned

In [None]:
# Mismatched baseline: encoder of model 1 feeding the decoder of model 2,
# with no aligner (None) in between.
model1 = load_deep_jscc(model1_fp, val_snr, c, "AWGN")
model2 = load_deep_jscc(model2_fp, val_snr, c, "AWGN")

# Deep copies keep the loaded models untouched.
encoder = copy.deepcopy(model1.encoder)
decoder = copy.deepcopy(model2.decoder)

model = AlignedDeepJSCC(encoder, decoder, None, val_snr, "AWGN")
print(f"Unaligned {validation_vectorized(model, test_loader, times):.2f}")

In [None]:
# No-mismatch reference: encoder and decoder both come from model 1, so no
# aligner is passed (None). The result is printed under the label "Aligned".
model1 = load_deep_jscc(model1_fp, val_snr, c, "AWGN")

encoder = copy.deepcopy(model1.encoder)
decoder = copy.deepcopy(model1.decoder)

model = AlignedDeepJSCC(encoder, decoder, None, val_snr, "AWGN")

print(f"Aligned {validation_vectorized(model, test_loader, times):.2f}")

# Least Squares

In [None]:
def dataset_to_matrices(dataset, batch_size=128):
    """Stack the first and second element of every dataset item into two tensors.

    Args:
        dataset: Dataset whose items are indexable with at least two tensor entries.
        batch_size: Batch size of the internal, unshuffled DataLoader.

    Returns:
        tuple: ``(first, second)`` tensors, each concatenated along dim 0 in
        dataset order.
    """
    batches = list(DataLoader(dataset, batch_size=batch_size, shuffle=False))

    first = torch.cat([batch[0] for batch in batches], dim=0)
    second = torch.cat([batch[1] for batch in batches], dim=0)

    return first, second

def aligner_least_squares(matrix_1, matrix_2, n_samples=None, reg=10000):
    """Fit a ridge-regularised linear map and wrap it in ``_LinearAlignment``.

    With ``Y = matrix_1.T`` and ``Z = matrix_2.T`` the alignment matrix is

        Q = Y Z^T (Z Z^T + reg * I)^-1

    This single definition replaces two same-named functions: the earlier,
    unregularised one was always shadowed by the later one. Passing ``reg=0``
    reproduces the unregularised formula.

    Args:
        matrix_1: 2-D tensor; its transpose is used as ``Y``.
        matrix_2: 2-D tensor; its transpose is used as ``Z``.
        n_samples: Accepted for compatibility with existing callers; unused.
        reg: Weight of the identity added to ``Z Z^T`` before inversion
            (default 10000, the value previously hard-coded).

    Returns:
        ``_LinearAlignment`` built with ``align_matrix=Q``.
    """
    Y = matrix_1.T
    Z = matrix_2.T

    ZZ_T = Z @ Z.T
    YZ_T = Y @ Z.T

    # The diagonal loading keeps the Gram matrix invertible even when there
    # are fewer samples than rows in Z.
    reg_matrix = reg * torch.eye(ZZ_T.size(0), device=ZZ_T.device, dtype=ZZ_T.dtype)
    Q = YZ_T @ torch.linalg.inv(ZZ_T + reg_matrix)

    return _LinearAlignment(align_matrix=Q)

In [None]:
# Only the encoders are needed: the aligner is fitted on pairs of encoder outputs.
model1 = load_from_checkpoint(model1_fp, train_snr).encoder
model2 = load_from_checkpoint(model2_fp, train_snr).encoder

# flat=True stores each encoder output as a 1-D vector.
# NOTE(review): train_loader shuffles and no seed is set in this cell, so the
# order of `data` may differ between runs - confirm whether that is intended.
data = AlignmentDataset(train_loader, model1, model2, flat=True)

In [None]:
# One seeded permutation is shared by all sample counts, so each subset is a
# prefix of the next larger one.
set_seed(seed)
permutation = torch.randperm(len(data))

for sample in samples:
    indices = permutation[:sample]
    subset = Subset(data, indices)

    matrix_1, matrix_2 = dataset_to_matrices(subset)

    aligner = aligner_least_squares(matrix_1, matrix_2, sample)

    # Persist the fitted aligner; the evaluation cell reads it back by sample count.
    with open(r'alignment/models/plots/'+folder+'/aligner_linear_'+str(sample)+'.pkl', 'wb') as f:
        pickle.dump(aligner, f)

    print(f"Done with {sample}.")

In [8]:
set_seed(seed)

model1 = load_deep_jscc(model1_fp, val_snr, c, "AWGN")
model2 = load_deep_jscc(model2_fp, val_snr, c, "AWGN")

# Mismatched pair: encoder from model 1, decoder from model 2.
encoder = copy.deepcopy(model1.encoder)
decoder = copy.deepcopy(model2.decoder)

for sample in samples:

    aligner_fp = r'alignment/models/plots/'+folder+'/aligner_linear_'+str(sample)+'.pkl'

    # NOTE(security): pickle.load executes code from the file; only load
    # aligners written by this notebook.
    with open(aligner_fp, 'rb') as f:
        aligner = pickle.load(f)

    # NOTE(review): `channel` is expected to be the config string here; the
    # training cells further down rebind the name `channel` to a Channel object,
    # so this cell depends on execution order.
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    print(f"Linear model, {sample} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times):.2f}")

Linear model, 1 samples got a PSNR of 11.50
Linear model, 2 samples got a PSNR of 11.09
Linear model, 3 samples got a PSNR of 11.07
Linear model, 4 samples got a PSNR of 11.28
Linear model, 5 samples got a PSNR of 11.41
Linear model, 6 samples got a PSNR of 11.30
Linear model, 7 samples got a PSNR of 11.48
Linear model, 8 samples got a PSNR of 11.44
Linear model, 9 samples got a PSNR of 11.17
Linear model, 10 samples got a PSNR of 11.23
Linear model, 11 samples got a PSNR of 11.18
Linear model, 12 samples got a PSNR of 11.12
Linear model, 13 samples got a PSNR of 11.10
Linear model, 14 samples got a PSNR of 11.12
Linear model, 16 samples got a PSNR of 11.92
Linear model, 17 samples got a PSNR of 11.61
Linear model, 19 samples got a PSNR of 11.61
Linear model, 21 samples got a PSNR of 12.00
Linear model, 23 samples got a PSNR of 11.85
Linear model, 25 samples got a PSNR of 11.71
Linear model, 28 samples got a PSNR of 11.86
Linear model, 31 samples got a PSNR of 11.67
Linear model, 34 sa

In [None]:
# Quick check: fit and evaluate a single least-squares aligner without pickling it.
# Relies on `data`, `encoder` and `decoder` left behind by earlier cells.
if test_mode:
    sample = 10000

    set_seed(seed)
    permutation = torch.randperm(len(data))

    indices = permutation[:sample]
    subset = Subset(data, indices)

    matrix_1, matrix_2 = dataset_to_matrices(subset)

    aligner = aligner_least_squares(matrix_1, matrix_2, sample)

    print(f"Done with {sample}.")

    # NOTE(review): `channel` must still be the config string at this point.
    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, channel)

    print(f"Linear model, {sample} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times):.2f}")

# Linear Neural

In [None]:
# Encoder outputs of both models are the (input, target) pairs for the aligner.
model1 = load_from_checkpoint(model1_fp, train_snr).encoder
model2 = load_from_checkpoint(model2_fp, train_snr).encoder

# flat=False keeps every encoder output in its original shape.
data = AlignmentDataset(train_loader, model1, model2, flat=False)

In [None]:
set_seed(seed)
permutation = torch.randperm(len(data))

# Training hyper-parameters; identical for every sample count, so set once.
epochs_max = 10000      # hard cap on the number of epochs
ratio = 6               # divisor used to derive the aligner size from the image size
patience = 10           # checks without improvement before stopping
check_interval = 1      # epochs between early-stopping checks
min_delta = 1e-5        # minimum loss decrease that counts as an improvement
lambda_reg = 0.001      # Adam weight decay

for sample in samples:
    indices = permutation[:sample]
    subset = Subset(data, indices)
    dataloader = DataLoader(subset, batch_size=batch_size, shuffle=True)

    aligner = _LinearAlignment(size=resolution * resolution * 3 * 2 // ratio).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(aligner.parameters(), lr=1e-3, weight_decay=lambda_reg)
    # Kept under its own name: rebinding `channel` would overwrite the config
    # string that load_from_checkpoint and AlignedDeepJSCC(..., channel) read.
    noise_channel = Channel("AWGN", train_snr) if train_snr is not None else None

    best_loss = float('inf')
    best_model_state = None
    checks_without_improvement = 0
    epoch = 0

    while True:
        epoch_loss = 0.0

        for inputs, targets in dataloader:

            # Noise is only injected when a training SNR is configured.
            if noise_channel is not None:
                inputs = noise_channel(inputs)

            optimizer.zero_grad()
            outputs = aligner(inputs.to(device))

            mse_loss = criterion(outputs, targets.to(device))
            # Scale by batch size so a partial batch contributes proportionally.
            loss = inputs.shape[0] * mse_loss

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        epoch += 1

        # Early stopping on the training loss.
        if epoch % check_interval == 0:
            avg_loss = epoch_loss / len(dataloader)
            if best_loss - avg_loss > min_delta:
                best_loss = avg_loss
                best_model_state = copy.deepcopy(aligner.state_dict())
                checks_without_improvement = 0
            else:
                checks_without_improvement += 1

            if checks_without_improvement >= patience:
                break

        if epoch > epochs_max:
            break

    print(f"Done with {sample}. Trained for {epoch} epochs.")

    # Restore best model
    if best_model_state is not None:
        aligner.load_state_dict(best_model_state)

    aligner_fp = f'alignment/models/plots/{folder}/aligner_neural_{sample}.pkl'

    # Stored on the CPU so it can be unpickled on machines without a GPU.
    with open(aligner_fp, 'wb') as f:
        pickle.dump(aligner.to("cpu"), f)

In [8]:
set_seed(seed)

model1 = load_deep_jscc(model1_fp, val_snr, c, "AWGN")
model2 = load_deep_jscc(model2_fp, val_snr, c, "AWGN")

# Mismatched pair: encoder from model 1, decoder from model 2.
encoder = copy.deepcopy(model1.encoder)
decoder = copy.deepcopy(model2.decoder)

for sample in samples:

    aligner_fp = r'alignment/models/plots/'+folder+'/aligner_neural_'+str(sample)+'.pkl'

    # NOTE(security): pickle.load executes code from the file; only load
    # aligners written by this notebook.
    with open(aligner_fp, 'rb') as f:
        aligner = pickle.load(f)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")

    print(f"Neural model, {sample} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times):.2f}")

Neural model, 1 samples got a PSNR of 10.97
Neural model, 2 samples got a PSNR of 10.91
Neural model, 3 samples got a PSNR of 10.92
Neural model, 4 samples got a PSNR of 10.93
Neural model, 5 samples got a PSNR of 11.04
Neural model, 6 samples got a PSNR of 11.10
Neural model, 7 samples got a PSNR of 11.13
Neural model, 8 samples got a PSNR of 11.22
Neural model, 9 samples got a PSNR of 11.24
Neural model, 10 samples got a PSNR of 11.26
Neural model, 11 samples got a PSNR of 11.33
Neural model, 12 samples got a PSNR of 11.44
Neural model, 13 samples got a PSNR of 11.57
Neural model, 14 samples got a PSNR of 11.64
Neural model, 16 samples got a PSNR of 11.73
Neural model, 17 samples got a PSNR of 11.81
Neural model, 19 samples got a PSNR of 11.79
Neural model, 21 samples got a PSNR of 12.04
Neural model, 23 samples got a PSNR of 12.22
Neural model, 25 samples got a PSNR of 12.19
Neural model, 28 samples got a PSNR of 12.38
Neural model, 31 samples got a PSNR of 12.44
Neural model, 34 sa

In [None]:
# Quick check: train one linear aligner on `sample` pairs and evaluate it in place.
if test_mode:
    sample = 10000

    set_seed(seed)
    permutation = torch.randperm(len(data))

    indices = permutation[:sample]
    subset = Subset(data, indices)
    dataloader = DataLoader(subset, batch_size=batch_size, shuffle=True)

    epochs_max = 10000      # hard cap on the number of epochs
    ratio = 6               # divisor used to derive the aligner size from the image size
    patience = 10           # checks without improvement before stopping
    check_interval = 1      # epochs between early-stopping checks
    min_delta = 1e-5        # minimum loss decrease that counts as an improvement
    lambda_reg = 0.001      # Adam weight decay

    aligner = _LinearAlignment(size=resolution * resolution * 3 * 2 // ratio).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(aligner.parameters(), lr=1e-3, weight_decay=lambda_reg)
    # Uses train_snr like the sweep cell (a literal 30 was hard-coded before), and
    # a separate name so the config string `channel` is not overwritten.
    noise_channel = Channel("AWGN", train_snr) if train_snr is not None else None

    best_loss = float('inf')
    best_model_state = None
    checks_without_improvement = 0
    epoch = 0

    while True:
        epoch_loss = 0.0

        for inputs, targets in dataloader:

            # Noise is only injected when a training SNR is configured.
            if noise_channel is not None:
                inputs = noise_channel(inputs)

            optimizer.zero_grad()
            outputs = aligner(inputs.to(device))

            mse_loss = criterion(outputs, targets.to(device))
            # Scale by batch size so a partial batch contributes proportionally.
            loss = inputs.shape[0] * mse_loss

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        epoch += 1

        # Early stopping on the training loss.
        if epoch % check_interval == 0:
            avg_loss = epoch_loss / len(dataloader)
            if best_loss - avg_loss > min_delta:
                best_loss = avg_loss
                best_model_state = copy.deepcopy(aligner.state_dict())
                checks_without_improvement = 0
            else:
                checks_without_improvement += 1

            if checks_without_improvement >= patience:
                break

        if epoch > epochs_max:
            break

    print(f"Done with {sample}. Trained for {epoch} epochs.")

    # Restore best model
    if best_model_state is not None:
        aligner.load_state_dict(best_model_state)

    model1 = load_deep_jscc(model1_fp, val_snr, c, "AWGN")
    model2 = load_deep_jscc(model2_fp, val_snr, c, "AWGN")

    # Mismatched pair bridged by the freshly trained aligner.
    encoder = copy.deepcopy(model1.encoder)
    decoder = copy.deepcopy(model2.decoder)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")

    print(f"Neural model, {sample} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times):.2f}")

# Convolutional

In [8]:
# Encoder outputs of both models are the (input, target) pairs for the aligner.
model1 = load_from_checkpoint(model1_fp, train_snr).encoder
model2 = load_from_checkpoint(model2_fp, train_snr).encoder

# flat=False keeps every encoder output in its original shape, which the
# convolutional aligner is trained on.
data = AlignmentDataset(train_loader, model1, model2, flat=False)

Computing model outputs: 100%|██████████| 782/782 [00:09<00:00, 84.41it/s] 


In [9]:
set_seed(seed)
permutation = torch.randperm(len(data))

# Training hyper-parameters; identical for every sample count, so set once.
epochs_max=10000        # hard cap on the number of epochs
ratio=6                 # not used by the convolutional aligner; kept for parity with the linear cells
patience=10             # checks without improvement before stopping
check_interval=1        # epochs between early-stopping checks
min_delta=1e-5          # minimum loss decrease that counts as an improvement
reg_val = 0.001         # Adam weight decay

for sample in samples:
    indices = permutation[:sample]
    subset = Subset(data, indices)
    dataloader = DataLoader(subset, batch_size=batch_size, shuffle=True)

    aligner = _ConvolutionalAlignment(in_channels=2*c, out_channels=2*c, kernel_size=5).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(aligner.parameters(), lr=1e-3, weight_decay=reg_val)
    # Kept under its own name: rebinding `channel` would overwrite the config
    # string that load_from_checkpoint and AlignedDeepJSCC(..., channel) read.
    noise_channel = Channel("AWGN", train_snr) if train_snr is not None else None

    best_loss = float('inf')
    # Reset per sample count so a snapshot from a previous (smaller) subset can
    # never be restored, and the name always exists for the check below.
    best_model_state = None
    checks_without_improvement = 0
    epoch = 0

    while True:
        epoch_loss = 0.0

        for inputs, targets in dataloader:

            # Noise is only injected when a training SNR is configured.
            if noise_channel is not None:
                inputs = noise_channel(inputs)

            optimizer.zero_grad()
            outputs = aligner(inputs.to(device))
            loss = criterion(outputs, targets.to(device))
            loss = loss * inputs.shape[0] # scale by batch size
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        epoch += 1

        # Early stopping on the training loss.
        if epoch % check_interval == 0:
            avg_loss = epoch_loss / len(dataloader)
            if best_loss - avg_loss > min_delta:
                best_loss = avg_loss
                best_model_state = copy.deepcopy(aligner.state_dict())
                checks_without_improvement = 0
            else:
                checks_without_improvement += 1

            if checks_without_improvement >= patience:
                break

        if epoch > epochs_max:
            break

    print(f"Done with {sample}. Trained for {epoch} epochs.")

    # Restore best model
    if best_model_state is not None:
        aligner.load_state_dict(best_model_state)

    aligner_fp = r'alignment/models/plots/'+folder+'/aligner_conv_'+str(sample)+'.pkl'

    # Stored on the CPU so it can be unpickled on machines without a GPU.
    with open(aligner_fp, 'wb') as f:
        pickle.dump(aligner.to("cpu"), f)

Done with 1. Trained for 168 epochs.
Done with 2. Trained for 204 epochs.
Done with 3. Trained for 289 epochs.
Done with 4. Trained for 250 epochs.
Done with 5. Trained for 237 epochs.
Done with 6. Trained for 217 epochs.
Done with 7. Trained for 278 epochs.
Done with 8. Trained for 305 epochs.
Done with 9. Trained for 280 epochs.
Done with 10. Trained for 293 epochs.
Done with 11. Trained for 299 epochs.
Done with 12. Trained for 334 epochs.
Done with 13. Trained for 315 epochs.
Done with 14. Trained for 324 epochs.
Done with 16. Trained for 341 epochs.
Done with 17. Trained for 350 epochs.
Done with 19. Trained for 342 epochs.
Done with 21. Trained for 397 epochs.
Done with 23. Trained for 341 epochs.
Done with 25. Trained for 368 epochs.
Done with 28. Trained for 309 epochs.
Done with 31. Trained for 368 epochs.
Done with 34. Trained for 357 epochs.
Done with 37. Trained for 457 epochs.
Done with 41. Trained for 414 epochs.
Done with 45. Trained for 421 epochs.
Done with 49. Trained

In [10]:
set_seed(seed)

model1 = load_deep_jscc(model1_fp, val_snr, c, "AWGN")
model2 = load_deep_jscc(model2_fp, val_snr, c, "AWGN")

# Mismatched pair: encoder from model 1, decoder from model 2.
encoder = copy.deepcopy(model1.encoder)
decoder = copy.deepcopy(model2.decoder)

for sample in samples:

    aligner_fp = r'alignment/models/plots/'+folder+'/aligner_conv_'+str(sample)+'.pkl'

    # NOTE(security): pickle.load executes code from the file; only load
    # aligners written by this notebook.
    with open(aligner_fp, 'rb') as f:
        aligner = pickle.load(f)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")

    print(f"Conv model, {sample} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times):.2f}")

Conv model, 1 samples got a PSNR of 19.85
Conv model, 2 samples got a PSNR of 23.71
Conv model, 3 samples got a PSNR of 23.80
Conv model, 4 samples got a PSNR of 24.41
Conv model, 5 samples got a PSNR of 25.30
Conv model, 6 samples got a PSNR of 25.68
Conv model, 7 samples got a PSNR of 26.41
Conv model, 8 samples got a PSNR of 26.97
Conv model, 9 samples got a PSNR of 27.01
Conv model, 10 samples got a PSNR of 27.43
Conv model, 11 samples got a PSNR of 27.43
Conv model, 12 samples got a PSNR of 27.43
Conv model, 13 samples got a PSNR of 27.86
Conv model, 14 samples got a PSNR of 28.20
Conv model, 16 samples got a PSNR of 27.88
Conv model, 17 samples got a PSNR of 28.30
Conv model, 19 samples got a PSNR of 28.42
Conv model, 21 samples got a PSNR of 28.26
Conv model, 23 samples got a PSNR of 28.24
Conv model, 25 samples got a PSNR of 28.13
Conv model, 28 samples got a PSNR of 28.39
Conv model, 31 samples got a PSNR of 28.20
Conv model, 34 samples got a PSNR of 28.22
Conv model, 37 sampl

In [None]:
# Quick check: train one convolutional aligner on `sample` pairs and evaluate it in place.
if test_mode:

    sample = 10000

    set_seed(seed)
    permutation = torch.randperm(len(data))

    indices = permutation[:sample]
    subset = Subset(data, indices)
    dataloader = DataLoader(subset, batch_size=batch_size, shuffle=True)

    epochs_max=10000        # hard cap on the number of epochs
    ratio=6                 # not used by the convolutional aligner; kept for parity with the linear cells
    patience=10             # checks without improvement before stopping
    check_interval=1        # epochs between early-stopping checks
    min_delta=1e-5          # minimum loss decrease that counts as an improvement
    reg_val = 0.001         # Adam weight decay

    aligner = _ConvolutionalAlignment(in_channels=2*c, out_channels=2*c, kernel_size=5).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(aligner.parameters(), lr=1e-3, weight_decay=reg_val)
    # Kept under its own name: rebinding `channel` would overwrite the config
    # string that load_from_checkpoint and AlignedDeepJSCC(..., channel) read.
    noise_channel = Channel("AWGN", train_snr) if train_snr is not None else None

    best_loss = float('inf')
    # Start without a snapshot so a state left behind by another cell is never restored.
    best_model_state = None
    checks_without_improvement = 0
    epoch = 0

    while True:
        epoch_loss = 0.0

        for inputs, targets in dataloader:

            # Noise is only injected when a training SNR is configured.
            if noise_channel is not None:
                inputs = noise_channel(inputs)

            optimizer.zero_grad()
            outputs = aligner(inputs.to(device))
            loss = criterion(outputs, targets.to(device))
            loss = loss * inputs.shape[0] # scale by batch size
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        epoch += 1

        # Early stopping on the training loss.
        if epoch % check_interval == 0:
            avg_loss = epoch_loss / len(dataloader)
            if best_loss - avg_loss > min_delta:
                best_loss = avg_loss
                best_model_state = copy.deepcopy(aligner.state_dict())
                checks_without_improvement = 0
            else:
                checks_without_improvement += 1

            if checks_without_improvement >= patience:
                break

        if epoch > epochs_max:
            break

    print(f"Done with {sample}. Trained for {epoch} epochs.")

    # Restore best model
    if best_model_state is not None:
        aligner.load_state_dict(best_model_state)

    model1 = load_deep_jscc(model1_fp, val_snr, c, "AWGN")
    model2 = load_deep_jscc(model2_fp, val_snr, c, "AWGN")

    # Mismatched pair bridged by the freshly trained aligner.
    encoder = copy.deepcopy(model1.encoder)
    decoder = copy.deepcopy(model2.decoder)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, val_snr, "AWGN")

    print(f"Conv model, {sample} samples got a PSNR of {validation_vectorized(aligned_model, test_loader, times):.2f}")

# Zero-shot

In [8]:
# Encoder outputs of both models are the (input, target) pairs for the zero-shot baseline.
model1 = load_from_checkpoint(model1_fp, train_snr).encoder
model2 = load_from_checkpoint(model2_fp, train_snr).encoder

# flat=True stores each encoder output as a 1-D vector.
data = AlignmentDataset(train_loader, model1, model2, flat=True)

Computing model outputs: 100%|██████████| 782/782 [00:12<00:00, 64.75it/s]


In [9]:
# Re-seed right before drawing the permutation so the same sample ordering is
# produced every time this cell runs.
set_seed(seed)
permutation = torch.randperm(len(data))

# Zero-shot sample budgets: every odd count is bumped to the next even number
# (even counts are kept), duplicates created by the bump collapse in the set,
# and the result is returned in ascending order.
samples_zeroshot = sorted({count + count % 2 for count in samples})

In [None]:
# Fit one Baseline aligner per sample budget and pickle it under
# alignment/models/plots/<folder>/aligner_zeroshot_<sample>.pkl.
# The [1:] slice skips the smallest budget in samples_zeroshot.

# Loop-invariant: the same value is passed as both input_dim and output_dim.
flattened_image_size = resolution * resolution

for sample in samples_zeroshot[1:]:
    # First `sample` entries of the seeded permutation select the subset.
    indices = permutation[:sample]
    subset = Subset(data, indices)

    # A single batch as large as the subset hands every pair to fit() at once.
    dataloader = DataLoader(subset, batch_size=len(subset))
    # Named fit_inputs/fit_targets so the builtin `input` is not shadowed.
    fit_inputs, fit_targets = next(iter(dataloader))

    try:

        baseline = Baseline(
            input_dim=flattened_image_size,
            output_dim=flattened_image_size,
            channel_matrix=torch.eye(1, dtype=torch.complex64),
            snr=train_snr,
            channel_usage=None,
            typology='pre',
            strategy='PFE',
            use_channel=train_snr is not None,
            seed=seed,
        )

        baseline.fit(fit_inputs, fit_targets)

    except RuntimeError as err:
        # Best-effort sweep: a budget whose construction or fit raises is
        # skipped, but the reason is reported so failures are not silent.
        print(f"Skipped {sample}: {err}")
        continue

    aligner_fp = r'alignment/models/plots/'+folder+'/aligner_zeroshot_'+str(sample)+'.pkl'

    with open(aligner_fp, 'wb') as f:
        pickle.dump(baseline, f)

    print(f"Done with {sample}.")

In [11]:
# Evaluate pickled zero-shot aligners: encoder of model1 -> aligner ->
# decoder of model2, scored with PSNR over test_loader.
set_seed(seed)

model1 = load_deep_jscc(model1_fp, val_snr, c, "AWGN")
model2 = load_deep_jscc(model2_fp, val_snr, c, "AWGN")

encoder = copy.deepcopy(model1.encoder).to(device)
decoder = copy.deepcopy(model2.decoder).to(device)

# Built once; the same denormaliser is applied to reconstructions and inputs.
denormalize = image_normalization('denormalization')

# Only the largest sample budget is evaluated ([-1:]).
for sample in samples_zeroshot[-1:]:

    aligner_fp = r'alignment/models/plots/'+folder+'/aligner_zeroshot_'+str(sample)+'.pkl'

    # NOTE: pickle.load can execute arbitrary code; only load aligner files
    # that were written by this notebook.
    with open(aligner_fp, 'rb') as f:
        aligner = pickle.load(f)

    # Switch the aligner to the validation SNR; the channel is enabled only
    # when an SNR is actually configured.
    aligner.use_channel = val_snr is not None
    aligner.snr = val_snr

    batch_psnr_list = []

    with torch.no_grad():

        for inputs, _ in test_loader:

            inputs = inputs.to(device)
            # NOTE(review): squeeze() drops every size-1 dimension, including
            # the batch dimension if a batch ever holds a single image —
            # confirm the loader never yields one.
            inputs = inputs.squeeze()

            # Ground truth does not change across repetitions, so compute it once.
            # Assumes the denormaliser returns a new tensor and leaves `inputs`
            # untouched — TODO confirm in utils.image_normalization.
            gt = denormalize(inputs)
            psnr_all = 0.0

            # Repeat the pass `times` times and average the PSNR values.
            for _ in range(times):

                latent = encoder(inputs)
                shape = latent.shape

                # The aligner works on (batch, features); restore the encoder
                # output shape before decoding.
                aligned = aligner.transform(latent.reshape(shape[0], -1))
                reconstruction = decoder(aligned.reshape(shape))

                psnr_all += get_batch_psnr(denormalize(reconstruction), gt)

            psnr_all /= times
            batch_mean_psnr = psnr_all.mean().item()
            batch_psnr_list.append(batch_mean_psnr)

    # Unweighted mean of per-batch means (every batch counts equally).
    overall_mean_psnr = sum(batch_psnr_list) / len(batch_psnr_list)

    print(f"Zeroshot model, {sample} samples got a PSNR of {overall_mean_psnr}")

Zeroshot model, 10000 samples got a PSNR of 24.807317502939018


In [None]:
# Debug path (runs only when test_mode is set): fit a single zero-shot
# aligner on 10000 samples and evaluate it immediately, without pickling.
if test_mode:
    set_seed(seed)
    permutation = torch.randperm(len(data))

    sample = 10000

    indices = permutation[:sample]
    subset = Subset(data, indices)

    # A single batch as large as the subset hands every pair to fit() at once.
    dataloader = DataLoader(subset, batch_size=len(subset))
    # Named fit_inputs/fit_targets so the builtin `input` is not shadowed.
    fit_inputs, fit_targets = next(iter(dataloader))

    flattened_image_size = resolution * resolution

    baseline = Baseline(
        input_dim=flattened_image_size,
        output_dim=flattened_image_size,
        channel_matrix=torch.eye(1, dtype=torch.complex64),
        snr=train_snr,
        channel_usage=None,
        typology='pre',
        strategy='PFE',
        use_channel=train_snr is not None,
        seed=seed,
    )

    baseline.fit(fit_inputs, fit_targets)

    print(f"Done with {sample}.")

    # Re-seed before evaluation.
    set_seed(seed)

    model1 = load_deep_jscc(model1_fp, val_snr, c, "AWGN")
    model2 = load_deep_jscc(model2_fp, val_snr, c, "AWGN")

    encoder = copy.deepcopy(model1.encoder).to(device)
    decoder = copy.deepcopy(model2.decoder).to(device)
    # NOTE(review): the aligner keeps the snr/use_channel it was fitted with
    # (train_snr); the cell that evaluates the pickled aligners sets
    # aligner.snr = val_snr first — confirm whether this difference is intended.
    aligner = baseline

    # Built once; the same denormaliser is applied to reconstructions and inputs.
    denormalize = image_normalization('denormalization')

    batch_psnr_list = []

    with torch.no_grad():

        for inputs, _ in test_loader:

            inputs = inputs.to(device)
            # NOTE(review): squeeze() drops every size-1 dimension, including
            # the batch dimension if a batch ever holds a single image —
            # confirm the loader never yields one.
            inputs = inputs.squeeze()

            # Ground truth does not change across repetitions, so compute it once.
            # Assumes the denormaliser returns a new tensor and leaves `inputs`
            # untouched — TODO confirm in utils.image_normalization.
            gt = denormalize(inputs)
            psnr_all = 0.0

            # Repeat the pass `times` times and average the PSNR values.
            for _ in range(times):

                latent = encoder(inputs)
                shape = latent.shape

                # The aligner works on (batch, features); restore the encoder
                # output shape before decoding.
                aligned = aligner.transform(latent.reshape(shape[0], -1))
                reconstruction = decoder(aligned.reshape(shape))

                psnr_all += get_batch_psnr(denormalize(reconstruction), gt)

            psnr_all /= times
            batch_mean_psnr = psnr_all.mean().item()
            batch_psnr_list.append(batch_mean_psnr)

    # Unweighted mean of per-batch means (every batch counts equally).
    overall_mean_psnr = sum(batch_psnr_list) / len(batch_psnr_list)

    print(f"Zeroshot model, {sample} samples got a PSNR of {overall_mean_psnr}")

Done with 10000.
Zeroshot model, 10000 samples got a PSNR of 23.538920226370454
