# Create data

In [1]:
# --- Data pipeline settings -------------------------------------------------
# Which dataset branch the loading cell takes ('cifar10' or 'imagenet').
dataset = "cifar10"
batch_size = 64
num_workers = 4

# --- Channel settings -------------------------------------------------------
channel = 'AWGN'

# --- Model checkpoints ------------------------------------------------------
# NOTE: these two names hold path strings here; a later cell rebinds them to
# the loaded encoder modules.
model1 = r"alignment\models\seed42_v1.pkl"
model2 = r"alignment\models\seed43_v1.pkl"

In [2]:
# Step up to the parent directory -- presumably the repository root, so that the
# project-local modules imported below can be found.
# NOTE: the target is relative to the current directory, so re-running this cell
# without restarting the kernel climbs one more level each time.
%cd ..

d:\Github repos\Deep-JSCC-PyTorch


In [3]:
import os
from torch.utils.data import Dataset
import torch
from PIL import Image
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from dataset import Vanilla
from tqdm import tqdm
import torch.nn as nn
import pickle
from channel import Channel
from model import DeepJSCC, _Encoder, _Decoder

In [4]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

# Load data: each branch only picks the transform and the train/test datasets;
# the loaders are built once below so both branches share identical settings.
if dataset == 'cifar10':
    transform = transforms.Compose([transforms.ToTensor(), ])
    # NOTE(review): this root is '../dataset/' while the ImageNet branch uses
    # './dataset/...' -- confirm which location is intended.
    train_dataset = datasets.CIFAR10(root='../dataset/', train=True,
                                     download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root='../dataset/', train=False,
                                    download=True, transform=transform)
elif dataset == 'imagenet':
    # The paper uses 128x128 inputs.
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Resize((128, 128))])
    print("loading data of imagenet")
    train_dataset = datasets.ImageFolder(root='./dataset/ImageNet/train', transform=transform)
    test_dataset = Vanilla(root='./dataset/ImageNet/val', transform=transform)
else:
    # ValueError is still an Exception, so existing handlers keep working; the
    # message now names the rejected value.
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'cifar10' or 'imagenet'")

train_loader = DataLoader(train_dataset, shuffle=True,
                          batch_size=batch_size, num_workers=num_workers)
# The test loader is shuffled as well, exactly as before.
test_loader = DataLoader(test_dataset, shuffle=True,
                         batch_size=batch_size, num_workers=num_workers)

In [5]:
# SNR value handed to the channel when the models are rebuilt.
snr = 7

# Checkpoint files of the two models whose encoders are compared.
model1_fp = r'alignment\models\seed42_v1.pkl'
model2_fp = r'alignment\models\seed43_v1.pkl'

# Reference training run. load_from_checkpoint parses the name of this file's
# parent directory to recover the model's `c` value (the 8 in 'CIFAR10_8_...').
saved = r'out\checkpoints\CIFAR10_8_7.0_0.17_AWGN_11h35m08s_on_Mar_27_2025\epoch_999.pkl'

In [6]:
def load_from_checkpoint(path, snr, c=None, channel_type=None):
    """Rebuild a DeepJSCC model and load the weights stored at ``path``.

    Args:
        path: Checkpoint file containing a state dict.
        snr: Passed to the ``DeepJSCC`` constructor and to ``change_channel``.
        c: Value for ``DeepJSCC``'s ``c`` argument. When omitted it is parsed
            from the notebook-level ``saved`` path (NOT from ``path``): the
            second ``_``-separated field of that file's parent directory name,
            e.g. ``CIFAR10_8_...`` gives 8. Pass it explicitly to avoid
            depending on that global.
        channel_type: Channel name for the model. When omitted the
            notebook-level ``channel`` variable is used.

    Returns:
        The ``DeepJSCC`` instance with the checkpoint weights loaded.
    """
    from collections import OrderedDict

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code -- only open checkpoints from a trusted source.
    state_dict = torch.load(path, map_location=torch.device('cpu'))

    # Drop every 'module.' fragment from the parameter names (presumably left
    # by a wrapper such as nn.DataParallel) so they match a bare model.
    new_state_dict = OrderedDict(
        (key.replace('module.', ''), value) for key, value in state_dict.items()
    )

    if c is None:
        # Backward-compatible default: infer `c` from the reference run directory.
        run_dir = os.path.basename(os.path.dirname(saved))
        c = int(run_dir.split('_')[1])
    if channel_type is None:
        channel_type = channel

    model = DeepJSCC(c=c, channel_type=channel_type, snr=snr)
    model.load_state_dict(new_state_dict)
    model.change_channel(channel_type, snr)

    return model

# Load both checkpoints in order and keep only their encoders. This rebinds
# model1/model2, which held checkpoint path strings up to this point.
model1, model2 = (
    load_from_checkpoint(checkpoint_fp, snr).encoder
    for checkpoint_fp in (model1_fp, model2_fp)
)

In [7]:
class ModelOutputDataset(Dataset):
    """In-memory dataset of paired, flattened outputs of two models.

    Every batch of ``dataloader`` is fed to both models; sample ``i`` of the
    dataset is the tuple ``(model1_output_i, model2_output_i)``, each flattened
    to one dimension and stored on the CPU.
    """

    def __init__(self, dataloader, model1, model2, device='cpu'):
        """Run both models over ``dataloader`` once and cache the results.

        Args:
            dataloader: Iterable yielding ``(inputs, _)`` pairs; the second
                element is ignored.
            model1: First model; called on each input batch.
            model2: Second model; called on the same input batch.
            device: Device on which the forward passes run.

        Note:
            Both models are switched to eval mode and moved to ``device`` in
            place, so the caller's objects are affected.
        """
        self.outputs = []

        # Inference only: eval mode first, then placement on the target device.
        for net in (model1, model2):
            net.eval()
        for net in (model1, model2):
            net.to(device)

        with torch.no_grad():
            for batch, _ in tqdm(dataloader, desc="Computing model outputs"):
                batch = batch.to(device)
                first, second = model1(batch), model2(batch)

                # Split the batch into per-sample pairs, flattened and moved to the CPU.
                self.outputs.extend(
                    (a.flatten().cpu(), b.flatten().cpu())
                    for a, b in zip(first, second)
                )

    def __len__(self):
        """Number of cached sample pairs."""
        return len(self.outputs)

    def __getitem__(self, idx):
        """Return the ``(model1_output, model2_output)`` pair stored at ``idx``."""
        return self.outputs[idx]

In [8]:
# 50,000 samples at batch size 64 is 781.25, so 782 batches are computed (the last one is partial).

data = ModelOutputDataset(train_loader, model1, model2, device)

Computing model outputs: 100%|██████████| 782/782 [01:38<00:00,  7.97it/s] 


In [9]:
# Shape of the first model's flattened output for the first cached sample.
data[0][0].shape

torch.Size([1024])

In [10]:
# TODO: decide whether to cache the computed outputs on disk (e.g. with pickle) so they need not be recomputed.

# Train model

In [11]:
def dataset_to_matrices(dataset, batch_size=128):
    """Stack a dataset of pairs into two matrices, one per pair position.

    Args:
        dataset: Dataset whose collated batches can be indexed with 0 and 1.
        batch_size: Number of samples collated per step while iterating.

    Returns:
        A 2-tuple ``(first, second)``: the batches' first and second elements,
        each concatenated along dimension 0 in dataset order (no shuffling).
    """
    batches = list(DataLoader(dataset, batch_size=batch_size, shuffle=False))
    first = torch.cat([batch[0] for batch in batches], dim=0)
    second = torch.cat([batch[1] for batch in batches], dim=0)
    return first, second

# Usage: stack the cached per-sample output pairs into one matrix per model,
# one row per sample (each row has 1024 entries in the run shown above).
matrix_1, matrix_2 = dataset_to_matrices(data)

In [14]:
# Closed-form least squares: find Q minimising ||Y - Q Z||_F, where the columns
# of Y / Z are the first / second model's outputs for the same sample.
Y = matrix_1.T
Z = matrix_2.T

# Normal equations: Q (Z Z^T) = Y Z^T. Z Z^T is symmetric, so transposing gives
# (Z Z^T) Q^T = Z Y^T, which is solved directly rather than forming an explicit
# inverse (more accurate and cheaper than torch.inverse followed by a matmul).
# NOTE(review): Q maps the second model's outputs onto the first model's; confirm
# this is the intended direction for the aligner.
Q = torch.linalg.solve(Z @ Z.T, Z @ Y.T).T

Q.shape

torch.Size([1024, 1024])

In [None]:
import torch
import torch.nn as nn
import pickle
from channel import Channel
from model import DeepJSCC, _Encoder, _Decoder

class AlignedDeepJSCC(nn.Module):
    """Encoder of one model and decoder of another, joined by an aligner.

    Data flow: ``encoder -> channel (if any) -> aligner -> decoder``.
    """

    def __init__(self, model1, model2, aligner):
        """Assemble the pipeline from two existing models.

        Args:
            model1: Supplies ``encoder``, ``snr`` and, when ``snr`` is not
                None, ``channel``.
            model2: Supplies ``decoder``.
            aligner: Callable applied to the (channel-processed) encoder output
                before decoding.
        """
        super().__init__()

        # Transmitter side comes from model1.
        self.encoder = model1.encoder
        self.snr = model1.snr
        # The `channel` attribute only exists when model1 carries an SNR.
        if self.snr is not None:
            self.channel = model1.channel

        # Bridge between the two latent spaces.
        self.aligner = aligner

        # Receiver side comes from model2.
        self.decoder = model2.decoder

    def forward(self, x):
        """Encode ``x``, pass it through the channel if set, align and decode."""
        latent = self.encoder(x)

        channel = getattr(self, 'channel', None)
        if channel is not None:
            latent = channel(latent)

        return self.decoder(self.aligner(latent))

    def change_channel(self, channel_type='AWGN', snr=None):
        """Replace the channel; ``snr=None`` disables it."""
        self.channel = Channel(channel_type, snr) if snr is not None else None

    def get_channel(self):
        """Return ``self.channel.get_channel()``, or None when no channel is set."""
        channel = getattr(self, 'channel', None)
        return None if channel is None else channel.get_channel()

    def loss(self, prd, gt):
        """Mean squared error between prediction ``prd`` and target ``gt``."""
        return nn.functional.mse_loss(prd, gt, reduction='mean')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Fit an affine map from the first model's outputs (X) to the second model's
# outputs (Y) by full-batch gradient descent on the mean squared error.
X = matrix_1
Y = matrix_2

# Define the linear model; the layer sizes follow the data instead of being hard-coded.
model = nn.Linear(X.shape[1], Y.shape[1], bias=True)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop: every epoch is one optimisation step over the whole data set.
num_epochs = 1000
# range() yields plain ints, so the loop variable must not be tuple-unpacked
# (the previous `for epoch, _ in ...` raised TypeError on the first iteration).
for epoch in tqdm(range(num_epochs)):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, Y)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f'Epoch [{epoch}/{num_epochs}], Loss: {loss.item():.6f}')

# Get the weight and bias
weights = model.weight.data   # Shape: [out_features, in_features]
bias = model.bias.data        # Shape: [out_features]