In [6]:
# Configuration and Hyperparameters
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mode = 1  # Set to 0 for Learn, 1 for Explain
save_path = "./lear_results/"
os.makedirs(save_path, exist_ok=True)

# Hyperparameters (from original M_config)
disc_ch = cfmap_ch = 32
num_epochs = 100
batch_size = 256
if mode == 0:
    lr = 0.0005
    lr_decay = 0.98
else:
    g_step, d_step = 1, 1
    lr_g, lr_d = 0.001, 0.001
    lr_decay = 0.99
    beta1 = 0.9
    one_sided_label_smoothing = 0.1

# Loss weights (lambda values from hyper_param list in M_config)
loss_weights = {'cls': 1.0, 'norm': 10.0, 'GAN': 1.0, 'cyc': 1.0, 'dis': 0.5}

In [7]:
from torchvision import datasets

def load_data():
    """Loads MNIST data as NumPy arrays."""
    train = datasets.MNIST(root=".", train=True, download=True)
    test  = datasets.MNIST(root=".", train=False, download=True)
    images_train = train.data.numpy().astype(np.float32)  # shape (60000,28,28)
    labels_train = train.targets.numpy().astype(np.int64) # shape (60000,)
    images_test  = test.data.numpy().astype(np.float32)   # shape (10000,28,28)
    labels_test  = test.targets.numpy().astype(np.int64)
    return images_train, labels_train, images_test, labels_test

def create_splits(labels_train, labels_test):
    """Creates train/val/test index splits using a fixed seed."""
    rng = np.random.RandomState(seed=970304)
    all_idx = rng.permutation(np.where(labels_train >= 0)[0])
    n_val = int(len(all_idx) * 0.1)
    valid_idx = all_idx[:n_val]
    train_idx = all_idx[n_val:]
    test_idx = rng.permutation(np.where(labels_test >= 0)[0])
    return train_idx, valid_idx, test_idx

def separate_data(idx, all_dat, all_lbl, center=True):
    """
    Selects a batch by indices and preprocesses:
      - center=True: normalize each image to [0,1].
      - center=False: pad 2 pixels, random crop back to 28x28, then normalize.
    Returns (batch,1,28,28) images and (batch,10) one-hot labels.
    """
    batch_size = len(idx)
    dat = all_dat[idx]  # shape (batch,28,28)
    if not center:
        # Pad and random crop
        padded = np.pad(dat, ((0,0),(2,2),(2,2)), mode='constant')  # to 32x32
        cropped = np.empty((batch_size,28,28), dtype=np.float32)
        for i in range(batch_size):
            h_off = np.random.randint(0, 5)
            w_off = np.random.randint(0, 5)
            cropped[i] = padded[i, h_off:h_off+28, w_off:w_off+28]
        dat = cropped
    # Per-image normalization to [0,1]
    for i in range(batch_size):
        img = dat[i]
        img_min = img.min()
        img_max = img.max()
        if img_max > img_min:
            dat[i] = (img - img_min) / (img_max - img_min)
        else:
            dat[i] = img  # constant image
    dat = np.expand_dims(dat, axis=1)  # to (batch,1,28,28)
    # One-hot encode labels
    lbl = all_lbl[idx]
    lbl_onehot = np.eye(10)[lbl]
    lbl_onehot = lbl_onehot.astype(np.float32)
    return dat.astype(np.float32), lbl_onehot

def code_creator(size):
    """Generates random target one-hot codes for each sample in a batch."""
    target_c = np.zeros((size,10), dtype=np.float32)
    for i in range(size):
        c = np.random.randint(0, 10)  # pick random class
        target_c[i, c] = 1.0
    return target_c

# Update ipython-input-22-47982226
def codemap(target_c):
    """
    Converts target class vectors into spatial code maps c1 (14x14), c2 (7x7), c3 (4x4).
    Each map has 10 channels: for class k, the (h,w) entries are target_c[:,k].
    """
    batch = target_c.shape[0]
    c1 = np.zeros((batch, 10, 14, 14), dtype=np.float32)
    c2 = np.zeros((batch, 10, 7, 7), dtype=np.float32)
    # c3 size should be 4x4 to match the encoder's conv3_2 output after correction
    c3 = np.zeros((batch, 10, 4, 4), dtype=np.float32)
    for b in range(batch):
        for k in range(10):
            if target_c[b, k] != 0:
                c1[b, k, :, :] = target_c[b, k]
                c2[b, k, :, :] = target_c[b, k]
                c3[b, k, :, :] = target_c[b, k]
    return c1, c2, c3

In [8]:
# Update ipython-input-23-47982226
class Encoder(nn.Module):
    def __init__(self, ch=32):
        super().__init__()
        # conv1 block (28x28 -> 14x14)
        self.conv1_1 = nn.Conv2d(1,   ch, kernel_size=3, stride=1, padding=1)
        self.bn1_1   = nn.BatchNorm2d(ch)
        self.conv1_2 = nn.Conv2d(ch,  ch, kernel_size=4, stride=2, padding=1)  # 28->14
        self.bn1_2   = nn.BatchNorm2d(ch)
        # conv2 block (14x14 -> 7x7)
        self.conv2_1 = nn.Conv2d(ch,  ch*2, kernel_size=3, stride=1, padding=1)
        self.bn2_1   = nn.BatchNorm2d(ch*2)
        self.conv2_2 = nn.Conv2d(ch*2, ch*2, kernel_size=4, stride=2, padding=1)  # 14->7
        self.bn2_2   = nn.BatchNorm2d(ch*2)
        # conv3 block (7x7 -> 4x4) - Changed kernel size and padding to achieve 4x4 output
        self.conv3_1 = nn.Conv2d(ch*2, ch*4, kernel_size=3, stride=1, padding=1)
        self.bn3_1   = nn.BatchNorm2d(ch*4)
        # (7 - kernel_size + 2*padding)/stride + 1 = (7 - 3 + 2*1)/2 + 1 = 4x4
        self.conv3_2 = nn.Conv2d(ch*4, ch*4, kernel_size=3, stride=2, padding=1)  # 7->4 # Corrected
        self.bn3_2   = nn.BatchNorm2d(ch*4)
        # conv4 block (4x4 -> 2x2) - Now takes 4x4 input and correctly outputs 2x2
        self.conv4_1 = nn.Conv2d(ch*4, ch*8, kernel_size=3, stride=1, padding=1)
        self.bn4_1   = nn.BatchNorm2d(ch*8)
        # (4 - kernel_size + 2*padding)/stride + 1 = (4 - 4 + 2*1)/2 + 1 = 2x2
        self.conv4_2 = nn.Conv2d(ch*8, ch*8, kernel_size=4, stride=2, padding=1)  # 4->2 # Corrected
        self.bn4_2   = nn.BatchNorm2d(ch*8)

    def forward(self, x):
        x = F.relu(self.bn1_1(self.conv1_1(x)))
        x = F.relu(self.bn1_2(self.conv1_2(x)))
        conv1_2 = x  # size [B, ch, 14,14]
        x = F.relu(self.bn2_1(self.conv2_1(x)))
        x = F.relu(self.bn2_2(self.conv2_2(x)))
        conv2_2 = x  # [B, 2ch, 7,7]
        x = F.relu(self.bn3_1(self.conv3_1(x)))
        x = F.relu(self.bn3_2(self.conv3_2(x)))
        conv3_2 = x  # [B, 4ch, 4,4] # This comment is now correct with the change
        x = F.relu(self.bn4_1(self.conv4_1(x)))
        x = F.relu(self.bn4_2(self.conv4_2(x)))
        conv4_2 = x  # [B, 8ch, 2,2] # This comment is now correct
        return conv1_2, conv2_2, conv3_2, conv4_2

class Classifier(nn.Module):
    def __init__(self, ch=32):
        super().__init__()
        self.encoder = Encoder(ch)
        # Two fully-connected layers with dropouts
        self.dropout1 = nn.Dropout(0.5)
        # Corrected input size to FC layer: flattened size of [B, 8ch, 2, 2] is B * 8ch * 2 * 2 = B * 32ch
        self.fc1 = nn.Linear(ch * 8 * 4, 128) # Adjusted input size to FC layer
        self.dropout2 = nn.Dropout(0.25)
        self.fc2      = nn.Linear(128, 10)

    def forward(self, x):
        # x: [B,1,28,28]
        _, _, _, enc_out = self.encoder(x)      # enc_out [B,8ch,2,2] # This comment is now correct
        # Flatten [B, 8ch, 2, 2] to [B, 8ch * 4]
        flat = enc_out.view(x.size(0), -1)      # [B, 8ch*4] # This comment is now correct
        x = self.dropout1(flat)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)                         # [B,10] logits
        return x

class Decoder(nn.Module):
    def __init__(self, ch=32):
        super().__init__()
        # Upsampling (bilinear) layers
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        # dec_upconv3: from enc_conv4_2 upsample (2x2 -> 4x4) # This is correct now
        self.dec_upconv3 = nn.Conv2d(ch*8, ch*4, kernel_size=3, stride=1, padding=1)
        self.bn_upconv3  = nn.BatchNorm2d(ch*4)
        # dec_code_conv3: after concatenating with c3 (size [4ch+10,4,4]) # This is correct now
        self.dec_code_conv3 = nn.Conv2d(ch*4 + 10, ch*4, kernel_size=3, stride=1, padding=1)
        self.bn_code_conv3  = nn.BatchNorm2d(ch*4)
        # dec_conv3: after merging skip and up paths (size [8ch,4,4]) # This is correct now
        self.dec_conv3 = nn.Conv2d(ch*8, ch*4, kernel_size=3, stride=1, padding=1)
        self.bn_conv3  = nn.BatchNorm2d(ch*4)
        # dec_upconv2: (4x4 -> 8x8 -> 7x7 after valid conv)
        # Input is 4x4. Upsample by 2 -> 8x8.
        # dec_upconv2: kernel_size=2, stride=1, padding=0. Output: (8-2+0)/1 + 1 = 6+1 = 7x7. OK.
        self.dec_upconv2 = nn.Conv2d(ch*4, ch*2, kernel_size=2, stride=1, padding=0)  # 8->7 # Corrected comment
        self.bn_upconv2  = nn.BatchNorm2d(ch*2)
        # dec_code_conv2: cat with c2 (size [2ch+10,7,7]). OK
        self.dec_code_conv2 = nn.Conv2d(ch*2 + 10, ch*2, kernel_size=3, stride=1, padding=1)
        self.bn_code_conv2  = nn.BatchNorm2d(ch*2)
        # dec_conv2: after merging (size [4ch,7,7]). OK.
        self.dec_conv2 = nn.Conv2d(ch*4, ch*2, kernel_size=3, stride=1, padding=1)
        self.bn_conv2  = nn.BatchNorm2d(ch*2)
        # dec_upconv1: (7x7 -> 14x14)
        self.dec_upconv1 = nn.Conv2d(ch*2, ch, kernel_size=3, stride=1, padding=1)
        self.bn_upconv1  = nn.BatchNorm2d(ch)
        # dec_code_conv1: cat with c1 (size [ch+10,14,14]). OK.
        self.dec_code_conv1 = nn.Conv2d(ch + 10, ch, kernel_size=3, stride=1, padding=1)
        self.bn_code_conv1  = nn.BatchNorm2d(ch)
        # dec_conv1: after merging (size [2ch,14,14]). OK.
        self.dec_conv1 = nn.Conv2d(ch*2, ch, kernel_size=3, stride=1, padding=1)
        self.bn_conv1  = nn.BatchNorm2d(ch)
        # Final upsampling via ConvTranspose2d (14->28)
        # Input: 14x14. Output: (14-1)*2 - 2*1 + 4 = 13*2 - 2 + 4 = 26 - 2 + 4 = 28x28. OK.
        self.dec_up = nn.ConvTranspose2d(ch, 1, kernel_size=4, stride=2, padding=1)

    def forward(self, enc1, enc2, enc3, enc4, c1, c2, c3):
        """
        enc1: conv1_2 (B, ch,14,14)
        enc2: conv2_2 (B,2ch,7,7)
        enc3: conv3_2 (B,4ch,4,4) # Corrected comment
        enc4: conv4_2 (B,8ch,2,2) # Corrected comment
        c1: [B,10,14,14]; c2: [B,10,7,7]; c3: [B,10,4,4] # This comment should match the code after fixing codemap
        """
        # Stage 3
        x = self.upsample(enc4)                          # (2x2 -> 4x4)
        x = F.relu(self.bn_upconv3(self.dec_upconv3(x))) # (B,4ch,4,4)
        # We changed enc3 to be 4x4 by modifying the encoder, and adjusted c3 to be 4x4 by modifying codemap.
        # So concatenation [enc3 (B,4ch,4,4), c3 (B,10,4,4)] along dim=1 should work now.
        cat3 = torch.cat([enc3, c3], dim=1)             # (B,4ch+10,4,4) # This line is now correct
        y = F.relu(self.bn_code_conv3(self.dec_code_conv3(cat3)))  # (B,4ch,4,4)
        z = torch.cat([y, x], dim=1)                    # (B,8ch,4,4)
        dec3 = F.relu(self.bn_conv3(self.dec_conv3(z))) # (B,4ch,4,4)
        # Stage 2
        x = self.upsample(dec3)                         # (4x4 -> 8x8)
        x = F.relu(self.bn_upconv2(self.dec_upconv2(x)))# (B,2ch,7,7) # Corrected comment
        cat2 = torch.cat([enc2, c2], dim=1)             # (B,2ch+10,7,7)
        y = F.relu(self.bn_code_conv2(self.dec_code_conv2(cat2)))  # (B,2ch,7,7)
        z = torch.cat([y, x], dim=1)                    # (B,4ch,7,7)
        dec2 = F.relu(self.bn_conv2(self.dec_conv2(z))) # (B,2ch,7,7)
        # Stage 1
        x = self.upsample(dec2)                         # (7x7 -> 14x14)
        x = F.relu(self.bn_upconv1(self.dec_upconv1(x)))# (B,ch,14,14)
        cat1 = torch.cat([enc1, c1], dim=1)             # (B,ch+10,14,14)
        y = F.relu(self.bn_code_conv1(self.dec_code_conv1(cat1)))  # (B,ch,14,14)
        z = torch.cat([y, x], dim=1)                    # (B,2ch,14,14)
        dec1 = F.relu(self.bn_conv1(self.dec_conv1(z))) # (B,ch,14,14)
        # Final upsampling
        out = self.dec_up(dec1)  # (B,1,28,28)
        out = torch.tanh(out)    # counterfactual map in [-1,1]
        return out

class Discriminator(nn.Module):
    def __init__(self, ch=32):
        super().__init__()
        # conv1 (28x28)
        self.conv1_1 = nn.Conv2d(1, ch, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(ch, ch, kernel_size=4, stride=2, padding=1)  # 28->14
        self.bn1_2   = nn.BatchNorm2d(ch)
        # conv2 (14x14)
        self.conv2_1 = nn.Conv2d(ch, ch*2, kernel_size=3, stride=1, padding=1)
        self.bn2_1   = nn.BatchNorm2d(ch*2)
        self.conv2_2 = nn.Conv2d(ch*2, ch*2, kernel_size=4, stride=2, padding=1)  # 14->7
        self.bn2_2   = nn.BatchNorm2d(ch*2)
        # conv3 (7x7)
        self.conv3_1 = nn.Conv2d(ch*2, ch*4, kernel_size=3, stride=1, padding=1)
        self.bn3_1   = nn.BatchNorm2d(ch*4)
        self.conv3_2 = nn.Conv2d(ch*4, ch*4, kernel_size=4, stride=2, padding=1)  # 7->4
        self.bn3_2   = nn.BatchNorm2d(ch*4)
        # conv4 (4x4)
        self.conv4_1 = nn.Conv2d(ch*4, ch*8, kernel_size=3, stride=1, padding=1)
        self.bn4_1   = nn.BatchNorm2d(ch*8)
        self.conv4_2 = nn.Conv2d(ch*8, ch*8, kernel_size=4, stride=2, padding=1)  # 4->2
        # Output layer
        self.fc      = nn.Linear(ch*8, 1)

    def forward(self, x):
        x = self.conv1_1(x)
        x = self.conv1_2(x);  x = self.bn1_2(x); x = F.leaky_relu(x, 0.2)
        x = self.conv2_1(x);  x = self.bn2_1(x); x = F.leaky_relu(x, 0.2)
        x = self.conv2_2(x);  x = self.bn2_2(x); x = F.leaky_relu(x, 0.2)
        x = self.conv3_1(x);  x = self.bn3_1(x); x = F.leaky_relu(x, 0.2)
        x = self.conv3_2(x);  x = self.bn3_2(x); x = F.leaky_relu(x, 0.2)
        x = self.conv4_1(x);  x = self.bn4_1(x); x = F.leaky_relu(x, 0.2)
        x = self.conv4_2(x)   # [B, ch*8, 2,2]
        x = x.view(x.size(0), -1)
        x = self.fc(x)        # [B, 1] raw logits
        return x

In [9]:
# Loss objects
ce_loss_fn = nn.CrossEntropyLoss()
mse_loss   = nn.MSELoss()
l1_loss    = nn.L1Loss()

def classification_accuracy(logits, labels_idx):
    """Compute accuracy given logits and true class indices."""
    preds = torch.argmax(logits, dim=1)
    return (preds == labels_idx).float().mean().item()

def one_sided_smooth(target_onehot):
    """Apply one-sided label smoothing (1->0.9, 0->0) for one-hot vectors."""
    # Convert numpy to torch if needed
    # target_onehot is a float tensor [B,10]
    return torch.where(target_onehot == 1.0, 0.9, target_onehot)

In [10]:
if mode == 0:
    # Load data and create splits
    images_train, labels_train, images_test, labels_test = load_data()
    train_idx, valid_idx, test_idx = create_splits(labels_train, labels_test)

    # Initialize model, optimizer, scheduler
    model = Classifier(ch=cfmap_ch).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_decay)

    best_val_acc = 0.0
    for epoch in range(num_epochs):
        np.random.shuffle(train_idx)
        model.train()
        running_loss = 0.0
        # --- Training ---
        for i in range(0, len(train_idx), batch_size):
            batch_idx = train_idx[i:i+batch_size]
            batch_x, batch_y = separate_data(batch_idx, images_train, labels_train, center=False)
            inputs = torch.tensor(batch_x).to(device)          # [B,1,28,28]
            labels_onehot = torch.tensor(batch_y).to(device)
            labels_idx = torch.tensor(labels_train[batch_idx]).to(device)

            optimizer.zero_grad()
            logits = model(inputs)     # [B,10]
            loss = ce_loss_fn(logits, labels_idx)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)

        scheduler.step()  # decay LR

        # --- Validation ---
        model.eval()
        with torch.no_grad():
            val_loss = 0.0; val_acc = 0.0; count = 0
            for i in range(0, len(valid_idx), batch_size):
                val_batch = valid_idx[i:i+batch_size]
                vx, vy = separate_data(val_batch, images_train, labels_train, center=True)
                vx = torch.tensor(vx).to(device)
                vy_idx = torch.tensor(labels_train[val_batch]).to(device)
                logits = model(vx)
                loss = ce_loss_fn(logits, vy_idx)
                acc = classification_accuracy(logits, vy_idx)
                val_loss += loss.item()
                val_acc  += acc
                count += 1
            val_loss /= count
            val_acc  /= count
            print(f"Epoch {epoch+1}/{num_epochs}: Train Loss={running_loss/len(train_idx):.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")

            # Save best model by validation accuracy
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), os.path.join(save_path, f"cls_model_best.pth"))

        # --- Test Evaluation ---
        test_loss = 0.0; test_acc = 0.0; tcount = 0
        for i in range(0, len(test_idx), batch_size):
            test_batch = test_idx[i:i+batch_size]
            tx, ty = separate_data(test_batch, images_test, labels_test, center=True)
            tx = torch.tensor(tx).to(device)
            ty_idx = torch.tensor(labels_test[test_batch]).to(device)
            logits = model(tx)
            loss = ce_loss_fn(logits, ty_idx)
            acc = classification_accuracy(logits, ty_idx)
            test_loss += loss.item()
            test_acc  += acc
            tcount += 1
        test_loss /= tcount
        test_acc  /= tcount
        print(f"Epoch {epoch+1}: Test Loss={test_loss:.4f}, Test Acc={test_acc:.4f}")

    # Save final models (encoder + classifier)
    torch.save(model.state_dict(), os.path.join(save_path, "cls_model_final.pth"))
    torch.save(model.encoder.state_dict(), os.path.join(save_path, "enc_model_final.pth"))

elif mode == 1:
    # Load data and splits (re-use indices from mode0 or recompute)
    images_train, labels_train, images_test, labels_test = load_data()
    train_idx, valid_idx, test_idx = create_splits(labels_train, labels_test)

    # Initialize models
    classifier = Classifier(ch=cfmap_ch).to(device)
    classifier.load_state_dict(torch.load(os.path.join(save_path, "cls_model_best.pth")))
    for param in classifier.parameters():
        param.requires_grad = False  # freeze classifier
    decoder = Decoder(ch=cfmap_ch).to(device)
    discriminator = Discriminator(ch=disc_ch).to(device)

    # Separate encoder for obtaining features
    encoder = classifier.encoder

    # Optimizers and schedulers
    gen_params = list(decoder.parameters())
    disc_params = list(discriminator.parameters())
    gen_opt = optim.Adam(gen_params, lr=lr_g, betas=(beta1, 0.999))
    disc_opt = optim.Adam(disc_params, lr=lr_d, betas=(beta1, 0.999))
    gen_scheduler = optim.lr_scheduler.ExponentialLR(gen_opt, gamma=lr_decay)
    disc_scheduler = optim.lr_scheduler.ExponentialLR(disc_opt, gamma=lr_decay)

    for epoch in range(num_epochs):
        np.random.shuffle(train_idx)
        # Training
        decoder.train(); discriminator.train()
        for step in range(0, len(train_idx), batch_size):
            batch_idx = train_idx[step:step+batch_size]
            batch_x, batch_y = separate_data(batch_idx, images_train, labels_train, center=False)
            inputs = torch.tensor(batch_x).to(device)           # [B,1,28,28]
            labels_onehot = torch.tensor(batch_y).to(device)   # [B,10]
            labels_idx = torch.tensor(labels_train[batch_idx]).to(device)

            # Prepare target class codes for generator
            target_c = torch.tensor(code_creator(len(batch_idx)), dtype=torch.float32).to(device)

            # === Discriminator step ===
            # Generate CF-map and pseudo-images
            with torch.no_grad():
                enc1, enc2, enc3, enc4 = encoder(inputs)  # get encoder features
                c1_np, c2_np, c3_np = codemap(target_c.cpu().numpy())
                c1 = torch.tensor(c1_np).to(device)
                c2 = torch.tensor(c2_np).to(device)
                c3 = torch.tensor(c3_np).to(device)
                CFmap = decoder(enc1, enc2, enc3, enc4, c1, c2, c3)
                pseudo = inputs + CFmap

            real_logits = discriminator(inputs)
            fake_logits = discriminator(pseudo)
            # Real labels=1, Fake labels=0
            d_loss_real = mse_loss(real_logits, torch.ones_like(real_logits))
            d_loss_fake = mse_loss(fake_logits, torch.zeros_like(fake_logits))
            d_loss = loss_weights['dis'] * (d_loss_real + d_loss_fake)

            disc_opt.zero_grad()
            d_loss.backward()
            disc_opt.step()

            # === Generator (Decoder) step ===
            # Recompute CF-map (now with grad) and pseudo
            enc1, enc2, enc3, enc4 = encoder(inputs)
            CFmap = decoder(enc1, enc2, enc3, enc4, c1, c2, c3)
            pseudo = inputs + CFmap

            # Classification output on pseudo-images
            logits_fake = classifier(pseudo)
            # Optionally apply one-sided smoothing to target codes
            if one_sided_label_smoothing:
                smoothed = one_sided_smooth(target_c)
                cls_loss = loss_weights['cls'] * torch.mean(-torch.sum(smoothed.to(device) * F.log_softmax(logits_fake, dim=1), dim=1))
            else:
                # direct CE with integer targets
                cls_loss = loss_weights['cls'] * ce_loss_fn(logits_fake, torch.tensor(np.argmax(target_c.cpu().numpy(), axis=1)).to(device))

            # GAN loss (generator tries to make discriminator output 1)
            gan_loss = loss_weights['GAN'] * mse_loss(discriminator(pseudo), torch.ones_like(fake_logits))

            # Cycle-consistency: feed pseudo back into decoder using predicted label
            with torch.no_grad():
                pred_probs = F.softmax(logits_fake, dim=1)
            c1_pred_np, c2_pred_np, c3_pred_np = codemap(pred_probs.cpu().numpy())
            c1_pred = torch.tensor(c1_pred_np).to(device)
            c2_pred = torch.tensor(c2_pred_np).to(device)
            c3_pred = torch.tensor(c3_pred_np).to(device)
            enc1_p, enc2_p, enc3_p, enc4_p = encoder(pseudo)
            tilde_map = decoder(enc1_p, enc2_p, enc3_p, enc4_p, c1_pred, c2_pred, c3_pred)
            like_input = pseudo + tilde_map
            cyc_loss = loss_weights['cyc'] * torch.mean(torch.abs(like_input - inputs))

            # L1 norm on CF-map (sparsity/regularization)
            l1_norm_loss = loss_weights['norm'] * torch.mean(torch.abs(CFmap))

            # Total generator loss
            g_loss = cls_loss + gan_loss + cyc_loss + l1_norm_loss

            gen_opt.zero_grad()
            g_loss.backward()
            gen_opt.step()

        gen_scheduler.step()
        disc_scheduler.step()

        # Validation printout (accuracy of classifier on reconstruction)
        decoder.eval()
        total_tst_acc = 0.0
        batches = 0
        with torch.no_grad():
            for step in range(0, len(test_idx), batch_size):
                batch_idx = test_idx[step:step+batch_size]
                tx, ty = separate_data(batch_idx, images_test, labels_test, center=True)
                tx = torch.tensor(tx).to(device)
                ty_idx = torch.tensor(labels_test[batch_idx]).to(device)
                # Generate pseudo images targeting random classes (or use code_creator)
                target_c = torch.tensor(code_creator(len(batch_idx)), dtype=torch.float32).to(device)
                enc1, enc2, enc3, enc4 = encoder(tx)
                c1_np, c2_np, c3_np = codemap(target_c.cpu().numpy())
                c1 = torch.tensor(c1_np).to(device)
                c2 = torch.tensor(c2_np).to(device)
                c3 = torch.tensor(c3_np).to(device)
                CFmap = decoder(enc1, enc2, enc3, enc4, c1, c2, c3)
                pseudo = tx + CFmap
                pred = classifier(pseudo)
                # Compare predicted class to the target code's class
                target_idx = torch.tensor(np.argmax(target_c.cpu().numpy(), axis=1)).to(device)
                acc = classification_accuracy(pred, target_idx)
                total_tst_acc += acc
                batches += 1
            avg_tst_acc = total_tst_acc / batches
            print(f"Epoch {epoch+1}/{num_epochs} (Explain): Test Acc (pseudo vs target) = {avg_tst_acc:.4f}")

    # Save final decoder model
    torch.save(decoder.state_dict(), os.path.join(save_path, "dec_model_final.pth"))

Epoch 1/100 (Explain): Test Acc (pseudo vs target) = 0.9826
Epoch 2/100 (Explain): Test Acc (pseudo vs target) = 0.9607
Epoch 3/100 (Explain): Test Acc (pseudo vs target) = 0.9866
Epoch 4/100 (Explain): Test Acc (pseudo vs target) = 0.9462
Epoch 5/100 (Explain): Test Acc (pseudo vs target) = 0.9818
Epoch 6/100 (Explain): Test Acc (pseudo vs target) = 0.9545
Epoch 7/100 (Explain): Test Acc (pseudo vs target) = 0.9591
Epoch 8/100 (Explain): Test Acc (pseudo vs target) = 0.9794
Epoch 9/100 (Explain): Test Acc (pseudo vs target) = 0.9844
Epoch 10/100 (Explain): Test Acc (pseudo vs target) = 0.9871
Epoch 11/100 (Explain): Test Acc (pseudo vs target) = 0.9896
Epoch 12/100 (Explain): Test Acc (pseudo vs target) = 0.9849
Epoch 13/100 (Explain): Test Acc (pseudo vs target) = 0.9953
Epoch 14/100 (Explain): Test Acc (pseudo vs target) = 0.9851
Epoch 15/100 (Explain): Test Acc (pseudo vs target) = 0.9821
Epoch 16/100 (Explain): Test Acc (pseudo vs target) = 0.9958
Epoch 17/100 (Explain): Test Acc 

In [11]:
import matplotlib.pyplot as plt
from scipy.linalg import sqrtm
import numpy as np
import torch # Import torch

def visualize_counterfactual(save_path, images_test, decoder, encoder, epoch):
    """
    Selects a subset of test images and for each target class (0-9),
    generates the counterfactual map and pseudo-image, then plots in a grid.
    """
    plt_dir = os.path.join(save_path, "plt")
    os.makedirs(plt_dir, exist_ok=True)
    # Select 10 representative test images (could be random or fixed indices)
    idxs = np.array([0,1,2,3,4,5,6,7,8,9])  # first 10 digits
    fig_dat = []
    for i in idxs:
        img = images_test[i].astype(np.float32)
        # Normalize each image to [0,1] for consistent display
        img_min = img.min()
        img_max = img.max()
        if img_max > img_min:
            img = (img - img_min) / (img_max - img_min)
        else:
            img = img # constant image
        fig_dat.append(img)
    fig_dat = np.stack(fig_dat, axis=0)  # shape (10,28,28)
    code = np.eye(10, dtype=np.float32)  # identity (target codes)
    # Build batch of input repeats: for each image, we will vary target from 0..9
    total_input = []
    total_codes = []
    for i in range(10):
        for j in range(10):
            total_input.append(fig_dat[i])
            total_codes.append(code[j])
    total_input = np.stack(total_input, axis=0)  # (100,28,28)
    total_codes = np.stack(total_codes, axis=0)  # (100,10)
    total_input = np.expand_dims(total_input, axis=1)  # (100,1,28,28)
    total_input_t = torch.tensor(total_input, dtype=torch.float32).to(device)
    total_codes_t = torch.tensor(total_codes).to(device)
    # Generate CF maps
    with torch.no_grad():
        enc1, enc2, enc3, enc4 = encoder(total_input_t)
        # Ensure codemap is called with the correct input type and moved to CPU for numpy conversion
        c1_np, c2_np, c3_np = codemap(total_codes_t.cpu().numpy())
        c1 = torch.tensor(c1_np).to(device)
        c2 = torch.tensor(c2_np).to(device)
        c3 = torch.tensor(c3_np).to(device)
        CFmap = decoder(enc1, enc2, enc3, enc4, c1, c2, c3)  # (100,1,28,28)
        CFmap_np = CFmap.cpu().numpy().squeeze(1)
        pseudo_np = total_input[:,0,:,:] + CFmap_np

    # Plot grid: 10 rows (input index), each row shows: original, then 10 CF maps, then 10 pseudo-images
    rows, cols = 10, 21
    fig = plt.figure(figsize=(cols/2, rows/2))
    for i in range(10):
        # Original image in first column
        ax_in = fig.add_subplot(rows, cols, i*cols + 1)
        ax_in.imshow(fig_dat[i], cmap='gray'); ax_in.axis('off')
        if i==0:
            ax_in.set_title("Input")
        # Each target
        for j in range(10):
            ax_cf = fig.add_subplot(rows, cols, i*cols + 2 + j)
            ax_cf.imshow(CFmap_np[i*10+j], cmap='gray'); ax_cf.axis('off')
            ax_pseudo = fig.add_subplot(rows, cols, i*cols + 12 + j)
            ax_pseudo.imshow(pseudo_np[i*10+j], cmap='gray'); ax_pseudo.axis('off')
            if i==0:
                ax_cf.set_title(f"CF {j}")
                ax_pseudo.set_title(f"Class {j}")
    plt.suptitle(f"Counterfactual Maps - Epoch {epoch}")
    plt.savefig(os.path.join(plt_dir, f"epoch{epoch}.png"))
    plt.close()

def calculate_fid(act1, act2):
    """
    Compute a simple FID-like score between two batches of images (arrays).
    Inputs are expected to be (batch, height, width).
    """
    # Reshape to (batch, height * width)
    act1 = act1.reshape(act1.shape[0], -1).astype(np.float64)
    act2 = act2.reshape(act2.shape[0], -1).astype(np.float64)

    mu1 = np.mean(act1, axis=0)
    mu2 = np.mean(act2, axis=0)
    # Use bias=False for sample covariance
    sigma1 = np.cov(act1, rowvar=False, bias=False)
    sigma2 = np.cov(act2, rowvar=False, bias=False)

    diff = np.sum((mu1 - mu2)**2) # Sum squared difference for vectors
    covmean = sqrtm(sigma1.dot(sigma2))
    # numerical stability
    if np.iscomplexobj(covmean):
        covmean = covmean.real
        # Optional: check if imaginary part is significant
        # if np.max(np.abs(covmean.imag)) > 1e-6:
        #     print(f"Warning: Significant imaginary part in sqrtm result: {np.max(np.abs(covmean.imag))}")

    fid = diff + np.trace(sigma1 + sigma2 - 2*covmean)
    return fid

def estimate_fid(images_test, labels_test, decoder, encoder):
    """
    Estimates FID-like scores for each class by comparing real vs. fake distributions.
    For each class c, we pick samples of class c (real) and generate
    pseudo-images of class c from other real images (fake), and compute FID.
    """
    code = np.eye(10, dtype=np.float32)
    min_samples_per_class = 800 # Ensure enough samples per class
    fid_scores = {}

    for num in range(10):
        # 1. Get real images of class 'num'
        real_class_idx = np.where(labels_test == num)[0]
        rng = np.random.RandomState(seed=970304 + num) # Use different seed per class
        # Shuffle and select a fixed number of samples
        rng.shuffle(real_class_idx)
        if len(real_class_idx) < min_samples_per_class:
             print(f"Warning: Not enough real samples for class {num}. Found {len(real_class_idx)}, need {min_samples_per_class}.")
             continue # Skip FID calculation for this class if not enough samples
        real_images_class_num = images_test[real_class_idx[:min_samples_per_class]]

        # Normalize real images (per image as done in separate_data)
        real_norm = []
        for img in real_images_class_num:
            im = img.astype(np.float32) # Ensure float type
            img_min = im.min()
            img_max = im.max()
            if img_max > img_min:
                im = (im - img_min) / (img_max - img_min)
            else:
                im = im # constant image
            real_norm.append(im)
        real_norm = np.stack(real_norm, axis=0) # (min_samples_per_class, 28, 28)


        # 2. Generate fake images (pseudo-images) targeting class 'num'
        # Pick images from *other* classes as input, and set target to 'num'
        other_classes_idx = np.where(labels_test != num)[0]
        # Shuffle and select a fixed number of samples to generate pseudo-images from
        rng.shuffle(other_classes_idx)
        if len(other_classes_idx) < min_samples_per_class:
             print(f"Warning: Not enough samples from other classes to generate pseudo for class {num}. Found {len(other_classes_idx)}, need {min_samples_per_class}.")
             continue # Skip FID calculation for this class if not enough samples
        input_images_for_pseudo = images_test[other_classes_idx[:min_samples_per_class]]

        # Normalize input images for pseudo generation (using separate_data logic)
        input_norm_pseudo = []
        for img in input_images_for_pseudo:
            im = img.astype(np.float32) # Ensure float type
            img_min = im.min()
            img_max = im.max()
            if img_max > img_min:
                im = (im - img_min) / (img_max - img_min)
            else:
                im = im # constant image
            input_norm_pseudo.append(im)
        input_norm_pseudo = np.stack(input_norm_pseudo, axis=0) # (min_samples_per_class, 28, 28)

        # Add channel dimension and convert to tensor
        input_norm_pseudo_t = torch.tensor(np.expand_dims(input_norm_pseudo, axis=1), dtype=torch.float32).to(device)

        # Create target code for class 'num'
        target_code_num = np.zeros((min_samples_per_class, 10), dtype=np.float32)
        target_code_num[:, num] = 1.0
        target_code_num_t = torch.tensor(target_code_num).to(device)

        # Generate pseudo-images
        decoder.eval() # Ensure decoder is in eval mode
        encoder.eval() # Ensure encoder is in eval mode
        with torch.no_grad():
            enc1, enc2, enc3, enc4 = encoder(input_norm_pseudo_t)
            c1_np, c2_np, c3_np = codemap(target_code_num_t.cpu().numpy())
            c1 = torch.tensor(c1_np).to(device)
            c2 = torch.tensor(c2_np).to(device)
            c3 = torch.tensor(c3_np).to(device)
            CFmap = decoder(enc1, enc2, enc3, enc4, c1, c2, c3)
            pseudo = input_norm_pseudo_t[:,0,:,:] + CFmap.squeeze(1) # Pseudo is (batch, 28, 28)

        fake_norm = pseudo.cpu().numpy() # (min_samples_per_class, 28, 28)


        # 3. Compute FID between real and fake images of class 'num'
        try:
            fid = calculate_fid(real_norm, fake_norm)
            fid_scores[num] = fid
            print(f"Class {num} FID (approx): {fid:.4f}")
        except ValueError as e:
            print(f"Could not compute FID for class {num}: {e}")
            # print shapes for debugging
            # print(f"Shapes: real_norm={real_norm.shape}, fake_norm={fake_norm.shape}")
            # print(f"np.cov(real_norm, rowvar=False).shape={np.cov(real_norm.reshape(real_norm.shape[0], -1), rowvar=False, bias=False).shape}")
            # print(f"np.cov(fake_norm, rowvar=False).shape={np.cov(fake_norm.reshape(fake_norm.shape[0], -1), rowvar=False, bias=False).shape}")

    # Optionally, print average FID
    if fid_scores:
        avg_fid = np.mean(list(fid_scores.values()))
        print(f"\nAverage FID (approx) across classes: {avg_fid:.4f}")

In [12]:
visualize_counterfactual(save_path, images_test, decoder, encoder, epoch=100)

In [13]:
estimate_fid(images_test, labels_test, decoder, encoder)

Class 0 FID (approx): 40.3272
Class 1 FID (approx): 47.6529
Class 2 FID (approx): 29.9583
Class 3 FID (approx): 27.9147
Class 4 FID (approx): 26.8909
Class 5 FID (approx): 18.6595
Class 6 FID (approx): 36.4761
Class 7 FID (approx): 36.3927
Class 8 FID (approx): 22.2486
Class 9 FID (approx): 26.4769

Average FID (approx) across classes: 31.2998
