# Create data

In [1]:
%cd ..

import os
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision import datasets

from dataset import Vanilla
from model import DeepJSCC
from tqdm import tqdm
import pickle

d:\File\Repos\Deep-JSCC-PyTorch


In [2]:
# --- Checkpoint locations ---------------------------------------------------
# Paths are assembled with os.path.join so they resolve on any OS. On Windows
# the resulting strings are identical to hard-coded backslash-separated paths.
model1_fp = os.path.join('alignment', 'models', 'seed42_v1.pkl')
model2_fp = os.path.join('alignment', 'models', 'seed43_v1.pkl')
# The name of the directory containing `saved` is parsed by
# load_from_checkpoint (second '_'-separated field) to obtain the model's `c`.
saved = os.path.join(
    'out',
    'checkpoints',
    'CIFAR10_8_7.0_0.17_AWGN_11h35m08s_on_Mar_27_2025',
    'epoch_999.pkl',
)
# SNR value handed to the DeepJSCC constructor and to change_channel().
snr = 7

# --- Data loading -----------------------------------------------------------
dataset = "cifar10"   # selects the branch of the data-loading cell ('cifar10' or 'imagenet')
batch_size = 64
num_workers = 4
channel = 'AWGN'      # channel type handed to DeepJSCC

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

In [3]:
############
# GET DATA #
############

# Reject unsupported dataset names up front, before touching the disk.
if dataset not in ('cifar10', 'imagenet'):
    raise Exception('Unknown dataset')

# Train and test loaders share the same settings; only the dataset differs.
_loader_kwargs = dict(shuffle=True, batch_size=batch_size, num_workers=num_workers)

if dataset == 'cifar10':
    transform = transforms.Compose([transforms.ToTensor()])
    # Train split first, then test split; both are downloaded if missing.
    train_dataset, test_dataset = (
        datasets.CIFAR10(root='../dataset/', train=is_train, download=True, transform=transform)
        for is_train in (True, False)
    )
else:
    # 'imagenet': images are resized to 128x128 (the size used in the paper).
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((128, 128))])

    print("loading data of imagenet")
    train_dataset = datasets.ImageFolder(root='./dataset/ImageNet/train', transform=transform)
    test_dataset = Vanilla(root='./dataset/ImageNet/val', transform=transform)

train_loader = DataLoader(train_dataset, **_loader_kwargs)
test_loader = DataLoader(test_dataset, **_loader_kwargs)

In [4]:
def load_from_checkpoint(path, snr, c=None):
    """Build a DeepJSCC model and load its weights from a checkpoint file.

    Besides its arguments, this function reads the notebook-level globals
    ``device`` (map location for ``torch.load``), ``channel`` (channel type)
    and, when ``c`` is not given, ``saved``.

    Args:
        path: Checkpoint file holding a state dict, read with ``torch.load``.
        snr: SNR value passed to the ``DeepJSCC`` constructor and to
            ``change_channel``.
        c: Value for the ``c`` argument of ``DeepJSCC``. When ``None``
            (default) it is parsed, as before, from the second
            ``_``-separated field of the name of the directory containing
            the global ``saved`` path -- NOT from ``path`` itself. Pass it
            explicitly when ``path`` was trained with a different ``c``.

    Returns:
        The ``DeepJSCC`` instance with the checkpoint weights loaded and its
        channel set via ``change_channel(channel, snr)``.
    """
    from collections import OrderedDict

    state_dict = torch.load(path, map_location=device)

    # Remove every 'module.' occurrence from the keys so they match the
    # parameter names of a bare DeepJSCC (presumably the checkpoints were
    # saved from a wrapped model, e.g. nn.DataParallel -- TODO confirm).
    cleaned_state = OrderedDict(
        (key.replace('module.', ''), value) for key, value in state_dict.items()
    )

    if c is None:
        # Fallback: read `c` from the run directory name of `saved`.
        run_dir = os.path.basename(os.path.dirname(saved))
        c = int(run_dir.split('_')[1])

    model = DeepJSCC(c=c, channel_type=channel, snr=snr)
    model.load_state_dict(cleaned_state)
    model.change_channel(channel, snr)

    return model

# Only the encoder half of each checkpointed model is kept (model1 first, then model2).
model1, model2 = (
    load_from_checkpoint(checkpoint_fp, snr).encoder
    for checkpoint_fp in (model1_fp, model2_fp)
)

In [5]:
class AlignmentDataset(Dataset):
    """Cached pairs of flattened outputs produced by two models on the same inputs.

    At construction time every batch of ``dataloader`` is run through both
    models (in eval mode, under ``torch.no_grad()``). For each sample, the two
    outputs are flattened, moved to the CPU and stored as one tuple.
    """

    def __init__(self, dataloader, model1, model2, device='cpu'):
        """Run both models over ``dataloader`` and cache the per-sample outputs.

        Args:
            dataloader: Iterable yielding ``(inputs, _)`` batches; the second
                element is ignored.
            model1: First model; switched to eval mode and moved to ``device``.
            model2: Second model; switched to eval mode and moved to ``device``.
            device: Device on which the forward passes are executed.
        """
        self.outputs = []

        # Note: this mutates the models passed in (eval mode + device move).
        for net in (model1, model2):
            net.eval()
            net.to(device)

        with torch.no_grad():
            for inputs, _ in tqdm(dataloader, desc="Computing model outputs"):
                inputs = inputs.to(device)
                paired_outputs = zip(model1(inputs), model2(inputs))
                # One (flat_out1, flat_out2) tuple per sample, stored on the CPU.
                self.outputs.extend(
                    (first.flatten().cpu(), second.flatten().cpu())
                    for first, second in paired_outputs
                )

    def __len__(self):
        """Return the number of cached output pairs."""
        return len(self.outputs)

    def __getitem__(self, idx):
        """Return the ``(model1_output, model2_output)`` tuple stored at ``idx``."""
        return self.outputs[idx]

In [6]:
# 50,000 (samples) / 64 (batch size) = 781.25 -> 782 batches to compute
data = AlignmentDataset(train_loader, model1, model2, device)

Computing model outputs: 100%|██████████| 782/782 [00:34<00:00, 22.90it/s]


# Train model

In [7]:
def dataset_to_matrices(dataset, batch_size=128):
    """Concatenate a dataset of paired items into two matrices.

    The dataset is read in order (no shuffling) in batches of ``batch_size``.
    The first elements of all items are concatenated along dim 0 into one
    tensor and the second elements into another, so row ``i`` of both results
    comes from the same dataset item.

    Args:
        dataset: Dataset whose items are pairs; each batch is indexed with
            ``[0]`` and ``[1]``.
        batch_size: Number of items read per batch.

    Returns:
        Tuple ``(first_matrix, second_matrix)`` of concatenated tensors.
    """
    batches = list(DataLoader(dataset, batch_size=batch_size, shuffle=False))
    first_parts = [pair[0] for pair in batches]
    second_parts = [pair[1] for pair in batches]
    return torch.cat(first_parts, dim=0), torch.cat(second_parts, dim=0)

# matrix_1 stacks the first element of every pair in `data`, matrix_2 the second (rows stay aligned).
matrix_1, matrix_2 = dataset_to_matrices(data)

In [11]:
#################
# LEAST SQUARES #
#################

# After transposing, each column is one sample.
Y = matrix_1.T
Z = matrix_2.T

# Closed-form least-squares map Q such that Y ~= Q @ Z:
#     Q = Y Z^T (Z Z^T)^-1
# Instead of forming the inverse explicitly, solve the equivalent normal
# equations (Z Z^T) Q^T = Z Y^T. The Gram matrix Z Z^T is symmetric, so the
# transposed system uses the same matrix. A linear solve is faster and
# numerically more stable than multiplying by torch.inverse(...).
gram = Z @ Z.T
Q = torch.linalg.solve(gram, Z @ Y.T).T.contiguous()

# Persist the aligner; os.path.join keeps the path valid on any OS.
aligner_fp = os.path.join('alignment', 'models', 'aligner.pkl')
with open(aligner_fp, 'wb') as f:
    pickle.dump(Q, f)

In [None]:
#############################
# LINEAR MODEL SGD TRAINING #
#############################

# Learn a linear map from the rows of matrix_1 to the rows of matrix_2.
# NOTE(review): the least-squares cell fits the opposite direction
# (matrix_2 -> matrix_1) -- confirm which direction is intended.
X = matrix_1
Y = matrix_2

# define linear model; layer sizes are taken from the data rather than
# hard-coded (they were fixed to 1024 x 1024 before)
model = nn.Linear(X.shape[1], Y.shape[1], bias=True)

# define loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# training loop: full-batch updates, one optimizer step per epoch
num_epochs = 1000
# range() yields plain ints, so iterate `epoch` directly; unpacking each int
# into two names raises TypeError on the first iteration.
for epoch in tqdm(range(num_epochs)):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, Y)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f'Epoch [{epoch}/{num_epochs}], Loss: {loss.item():.6f}')

# get the weight and bias
weights = model.weight.data   # shape: [out_features, in_features]
bias = model.bias.data        # shape: [out_features]