# Small Example of iDLG on CIFAR10

In [None]:
from matplotlib import pyplot as plt 
import skimage 
import numpy as np 
from tqdm import tqdm
import torch as torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.optim.lbfgs import LBFGS
from PIL import Image

In [None]:
class iDLG:
    def __init__(
        self,
        model,
        orig_img,
        label,
        device,
        *,
        seed: int | None = None,
        clamp: tuple[float, float] | None = (0.0, 1.0),
    ) -> None:
        # Respect provided device and keep original dtype of the model/weights
        self.device = device if isinstance(device, str) else (device.type if hasattr(device, "type") else "cpu")
        self.model = model.to(self.device)
        self.orig_img = orig_img.to(self.device)
        self.criterion = nn.CrossEntropyLoss(reduction='sum').to(self.device)
        self.label = label.to(self.device)
        self.tt = transforms.ToPILImage()
        self.clamp = clamp

        # Align image dtype to model parameter dtype (usually float32)
        self.param_dtype = next(self.model.parameters()).dtype
        if self.orig_img.dtype != self.param_dtype:
            self.orig_img = self.orig_img.to(self.param_dtype)

        if seed is not None:
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

    def _infer_label_from_grads(self, orig_grads):
        # Map grads to names
        named_grads = {name: g for (name, _), g in zip(self.model.named_parameters(), orig_grads)}
        last_bias_name = None
        for name, param in self.model.named_parameters():
            if name.endswith(".bias") and param.ndim == 1:
                last_bias_name = name  # keep overwriting â†’ last bias

        bias_grad = named_grads[last_bias_name]
        return torch.argmin(bias_grad).detach().reshape((1,))

    def attack(self, iterations=200):
        # iDLG training image reconstruction:
        self.model.eval()
        
        # compute original gradients
        predicted = self.model(self.orig_img)
        loss = self.criterion(predicted, self.label)
        orig_grads = torch.autograd.grad(loss, self.model.parameters())
        orig_grads = list((_.detach().clone() for _ in orig_grads))

        # initialize dummy in the correct iteration, respecting the random seed
        # dummy_data = (torch.randn(self.orig_img.size(), dtype=self.param_dtype, device=self.device).requires_grad_(True))
        dummy_data = torch.as_tensor(self.orig_img)

        # init with ground truth:
        label_pred = self._infer_label_from_grads(orig_grads).requires_grad_(False)
        optimizer = LBFGS(
            [dummy_data], lr=.1, max_iter=50,
            tolerance_grad=1e-09, tolerance_change=1e-11,
            history_size=100, line_search_fn='strong_wolfe'
        )

        history = []
        losses = []

        for iters in tqdm(range(iterations)):
            def closure():
                optimizer.zero_grad()
                dummy_pred = self.model(dummy_data)
                dummy_loss = self.criterion(dummy_pred, label_pred)
                dummy_dy_dx = torch.autograd.grad(dummy_loss, self.model.parameters(), create_graph=True)
                grad_diff = 0
                for gx, gy in zip(dummy_dy_dx, orig_grads):
                    grad_diff += ((gx - gy) ** 2).sum()
                grad_diff.backward()
                return grad_diff

            optimizer.step(closure)

            # Optional: keep dummy within valid input range
            if self.clamp is not None:
                with torch.no_grad():
                    dummy_data.clamp_(self.clamp[0], self.clamp[1])

            if iters % 1 == 0:
                current_loss = closure()
                losses.append(current_loss.item())
                history.append(self.tt(dummy_data[0].cpu()))

        return dummy_data.detach().numpy().squeeze(), label_pred, history, losses

### Define Model Architecture 

In [None]:
class LeNet(nn.Module):
    """Small LeNet-style CNN: three sigmoid-activated 5x5 convs and one linear head.

    Args:
        channel: Number of input channels.
        hidden: Flattened feature size fed to the linear layer. The default
            768 equals 12 channels x 8 x 8, which is what a 32x32 input
            produces after the two stride-2 convolutions.
        num_classes: Number of output logits.
    """

    def __init__(self, channel: int = 3, hidden: int = 768, num_classes: int = 10):
        super().__init__()
        kernel = 5
        width = 12
        # (input channels, stride) for each convolution, in order.
        conv_specs = ((channel, 2), (width, 2), (width, 1))

        layers = []
        for in_channels, stride in conv_specs:
            layers.append(
                nn.Conv2d(
                    in_channels,
                    width,
                    kernel_size=kernel,
                    padding=kernel // 2,
                    stride=stride,
                )
            )
            layers.append(nn.Sigmoid())

        # Sequential indices (body.0, body.2, body.4, fc.0) define the
        # state-dict keys, so the layer order must not change.
        self.body = nn.Sequential(*layers)
        self.fc = nn.Sequential(nn.Linear(hidden, num_classes))

    def forward(self, x):
        """Return the logits for a batch ``x``."""
        features = self.body(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

### Initialize Model

In [None]:
# Build the 5-class LeNet and restore its trained weights.
model = LeNet(num_classes=5)
# map_location="cpu" lets a checkpoint that was saved from a GPU run load on a
# CPU-only machine; the attack below runs on CPU anyway.
checkpoint = torch.load(
    "seed_41_alpha_1.00_th_95_method_None_epoch_1.pth", map_location="cpu"
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

### Get a Datapoint 

In [None]:
# Read the sample to reconstruct from disk (replace with an image from your
# own CIFAR10 dataloader if preferred).
pixels = skimage.io.imread("image.png") / 255  # rescale; assumes 8-bit pixel values
channels_first = np.transpose(pixels, (2, 0, 1))  # HWC -> CHW
input_image = torch.Tensor(np.expand_dims(channels_first, 0))  # add batch dimension
label = torch.tensor([1], dtype=torch.long)

### Attack 

In [None]:
# Run the attack on CPU and keep the per-iteration loss / image history.
idlg = iDLG(
    model=model,
    device=torch.device("cpu"),
    orig_img=input_image,
    label=label,
    clamp=(0, 1),
)
dummy_data_idlg, label_pred_idlg, history_idlg, losses_idlg = idlg.attack(iterations=1000)

# Left: gradient-matching loss per iteration. Right: final reconstruction.
fig, ax = plt.subplots(1, 2, figsize=(12, 3))
ax[0].plot(losses_idlg, 'b-', label='Loss')
ax[0].set_xlabel('Iteration')
ax[0].set_ylabel('Loss')
ax[0].legend()
ax[1].imshow(np.transpose(dummy_data_idlg, (1, 2, 0)))
plt.show()

# Net pixel change between consecutive history frames. The PIL frames convert
# to uint8 arrays, so widen to int64 first; otherwise negative differences
# wrap around modulo 256 and corrupt the sum.
frames = [np.array(img).astype(np.int64) for img in history_idlg]
diff = [int((nxt - cur).sum()) for cur, nxt in zip(frames, frames[1:])]
plt.plot(diff)
plt.show()

