In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import medmnist
from medmnist import BreastMNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
import lpips

In [3]:
# Define the LeNet model
class LeNet(nn.Module):
    def __init__(self, channel=1, hidden=768, num_classes=2):
        super(LeNet, self).__init__()
        act = nn.Sigmoid
        self.body = nn.Sequential(
            nn.Conv2d(channel, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
            act(),
        )
        self.fc = nn.Sequential(
            nn.Linear(hidden, num_classes)
        )
    
    def forward(self, x):
        out = self.body(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

In [4]:
def compute_mse(img1, img2):
    return ((img1 - img2) ** 2).mean()

In [5]:
# Function to initialize weights
def weights_init(m):
    try:
        if hasattr(m, "weight"):
            m.weight.data.uniform_(-0.5, 0.5)
    except Exception:
        print(f'Warning: failed in weights_init for {m._get_name()}.weight')
    try:
        if hasattr(m, "bias"):
            m.bias.data.uniform_(-0.5, 0.5)
    except Exception:
        print(f'Warning: failed in weights_init for {m._get_name()}.bias')

In [6]:
# Load BreastMNIST dataset
def load_breastmnist(data_flag='train'):
    transform = transforms.Compose([transforms.ToTensor()])
    dataset = BreastMNIST(split=data_flag, download=True, transform=transform)
    return dataset

In [7]:
dataset = 'BreastMNIST'
num_classes = 2
channel = 1
hidden = 588
lr = 1.0
num_dummy = 1
Iteration = 50
num_exp = 3

use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'

print(f'Using dataset: {dataset}, device: {device}')

Using dataset: BreastMNIST, device: cuda


In [8]:
lpips_model = lpips.LPIPS(net='alex').to(device)

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /home/anas/PrivFedTL/anasenv/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth


In [9]:
# Load dataset
dst = load_breastmnist()
dataloader = DataLoader(dst, batch_size=1, shuffle=True)

In [13]:
for idx_net in range(num_exp):
    net = LeNet(channel=channel, hidden=hidden, num_classes=num_classes)
    net.apply(weights_init)
    net.to(device)
    
    print(f'Running {idx_net+1}/{num_exp} experiment')

    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        target = target.view(-1).long()
        break  # Get a single batch
    
    # Compute original gradient
    criterion = nn.CrossEntropyLoss().to(device)
    out = net(data)
    loss = criterion(out, target)
    dy_dx = torch.autograd.grad(loss, net.parameters())
    original_dy_dx = [_.detach().clone() for _ in dy_dx]
    
    # Generate dummy data
    dummy_data = torch.randn_like(data, requires_grad=True, device=device)
    # dummy_label = torch.randn((data.shape[0], num_classes), requires_grad=True, device=device)
    
    optimizer = torch.optim.LBFGS([dummy_data], lr=lr)
    label_pred = torch.argmin(torch.sum(original_dy_dx[-2], dim=-1), dim=-1).detach().reshape((1,))

    history = []
    history_iters = []

    def closure():
        optimizer.zero_grad()
        pred = net(dummy_data)
        dummy_loss = criterion(pred, label_pred)
        dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)
        grad_diff = sum(((gx - gy) ** 2).sum() for gx, gy in zip(dummy_dy_dx, original_dy_dx))
        grad_diff.backward()
        return grad_diff
        
    for iters in range(Iteration):

        optimizer.step(closure)

        # Compute loss, MSE, SSIM and LPIP
        with torch.no_grad():
            mse_val = compute_mse(dummy_data.cpu().numpy(), data.cpu().numpy())
            ssim_val = ssim(dummy_data.cpu().squeeze().numpy(), data.cpu().squeeze().numpy(), data_range=1)
            
            resize_transform = transforms.Resize((32, 32))
            dummy_resized = resize_transform(dummy_data)
            data_resized = resize_transform(data)
            lpips_val = lpips_model(dummy_resized, data_resized).item()
        
        if iters % (Iteration // 10) == 0 or (iters+1) == Iteration:
            print(f"Iteration {iters+1}: Loss={closure().item():.6f}, MSE={mse_val:.6f}, SSIM={ssim_val:.6f}, LPIPS={lpips_val:.6f}")
            history.append(dummy_data.detach().cpu().squeeze().numpy())
            history_iters.append(iters)
    
    print(f'Experiment {idx_net+1}')
    fig, axes = plt.subplots(1, len(history) + 2, figsize=(15, 3))
    for i, (img, iter_num) in enumerate(zip(history, history_iters)):
        axes[i].imshow(img, cmap='gray')
        axes[i].set_title(f"Iter {iter_num}")
        axes[i].axis('off')
    
    axes[-2].imshow(dummy_data.detach().cpu().squeeze(), cmap='gray')
    axes[-2].set_title(f"Iter {history_iters[-1]}")
    axes[-2].axis('off')
    axes[-2].text(0.5, -0.15, f"Pred: {label_pred.item()}", size=12, ha="center", va="top", transform=axes[-2].transAxes)

    axes[-1].imshow(data.cpu().squeeze(), cmap='gray')
    axes[-1].set_title("Original")
    axes[-1].axis('off')
    axes[-1].text(0.5, -0.15, f"Label: {target.item()}", size=12, ha="center", va="top", transform=axes[-1].transAxes)
    
    plt.savefig(f"reconstruction_progress_exp{idx_net+1}.png")
    plt.close()

Running 1/3 experiment
Iteration 1: Loss=0.000202, MSE=1.107912, SSIM=0.000532, LPIPS=0.487206
Iteration 6: Loss=0.000005, MSE=0.193219, SSIM=0.109814, LPIPS=0.257825
Iteration 11: Loss=0.000000, MSE=0.039441, SSIM=0.476558, LPIPS=0.119238
Iteration 16: Loss=0.000000, MSE=0.028182, SSIM=0.581216, LPIPS=0.093872
Iteration 21: Loss=0.000000, MSE=0.028182, SSIM=0.581216, LPIPS=0.093872
Iteration 26: Loss=0.000000, MSE=0.028182, SSIM=0.581216, LPIPS=0.093872
Iteration 31: Loss=0.000000, MSE=0.028182, SSIM=0.581216, LPIPS=0.093872
Iteration 36: Loss=0.000000, MSE=0.028182, SSIM=0.581216, LPIPS=0.093872
Iteration 41: Loss=0.000000, MSE=0.028182, SSIM=0.581216, LPIPS=0.093872
Iteration 46: Loss=0.000000, MSE=0.028182, SSIM=0.581216, LPIPS=0.093872
Iteration 50: Loss=0.000000, MSE=0.028182, SSIM=0.581216, LPIPS=0.093872
Experiment 1
Running 2/3 experiment
Iteration 1: Loss=16.957926, MSE=0.781006, SSIM=0.014709, LPIPS=0.531675
Iteration 6: Loss=0.012822, MSE=0.012339, SSIM=0.790329, LPIPS=0.04