In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import numpy as np
import os
import matplotlib.pyplot as plt
from PIL import Image
import argparse
import logging
import copy
import pickle
from datetime import datetime

class OneHotCVAE(nn.Module):
    """Fully connected conditional VAE whose condition is a class vector.

    The encoder maps the concatenation of a flattened input and the class
    vector to the mean and log-variance of a diagonal Gaussian; the decoder
    maps the concatenation of a latent sample and the class vector back to
    an ``x_dim``-wide output squashed into ``[0, 1]`` by a sigmoid.

    Args:
        x_dim: Width of the flattened input / reconstruction.
        h_dim1: Width of the outer hidden layer (first in the encoder,
            last in the decoder).
        h_dim2: Width of the inner hidden layer.
        z_dim: Dimensionality of the latent space.
        class_size: Length of the class-condition vector.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim, class_size=10):
        super().__init__()
        # Kept so forward() flattens to the width fc1 was built for instead
        # of assuming 784-pixel inputs.
        self.x_dim = x_dim
        # Encoder layers.
        self.fc1 = nn.Linear(x_dim + class_size, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # mean head
        self.fc32 = nn.Linear(h_dim2, z_dim)  # log-variance head
        # Decoder layers.
        self.fc4 = nn.Linear(z_dim + class_size, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x, c):
        """Return ``(mu, log_var)`` for flattened input ``x`` and condition ``c``."""
        # F.one_hot yields integer tensors; cast explicitly so the
        # concatenation does not depend on implicit dtype promotion.
        inputs = torch.cat([x, c.to(x.dtype)], dim=1)
        h = F.relu(self.fc1(inputs))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)

    def sampling(self, mu, log_var):
        """Draw ``mu + std * eps`` with ``eps ~ N(0, I)`` (reparameterization)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decoder(self, z, c):
        """Decode latent ``z`` under condition ``c`` into values in ``[0, 1]``."""
        inputs = torch.cat([z, c.to(z.dtype)], dim=1)
        h = F.relu(self.fc4(inputs))
        h = F.relu(self.fc5(h))
        return torch.sigmoid(self.fc6(h))

    def forward(self, x, c):
        """Encode, sample and decode.

        Args:
            x: Input batch; it is flattened to ``(-1, x_dim)`` before encoding.
            c: Class-condition vectors, one row per input.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encoder(x.view(-1, self.x_dim), c)
        z = self.sampling(mu, log_var)
        return self.decoder(z, c), mu, log_var

class Classifier(nn.Module):
    """Small convolutional classifier that returns per-class log-probabilities.

    Two conv + batch-norm + max-pool stages feed two fully connected layers;
    dropout is applied after the second convolution and after the first
    fully connected layer.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.bn1 = nn.BatchNorm2d(num_features=10)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(num_features=20)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(in_features=320, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=10)

    def forward(self, x):
        """Return log-softmax scores over the 10 output classes for batch ``x``."""
        # Stage 1: conv -> batch norm -> 2x2 max pool -> ReLU.
        stage1 = self.bn1(self.conv1(x))
        stage1 = F.relu(F.max_pool2d(stage1, 2))

        # Stage 2: conv -> batch norm -> channel dropout -> 2x2 max pool -> ReLU.
        stage2 = self.conv2_drop(self.bn2(self.conv2(stage1)))
        stage2 = F.relu(F.max_pool2d(stage2, 2))

        # 320 = 20 channels x 4 x 4, which is what a 28x28 input produces.
        flat = stage2.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        # Functional dropout must be told the mode explicitly.
        hidden = F.dropout(hidden, training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

def loss_function(recon_x, x, mu, log_var):
    """Return the summed VAE loss: reconstruction BCE plus KL divergence.

    Args:
        recon_x: Reconstructions with values in ``[0, 1]``, one row per sample.
        x: Targets; reshaped to ``recon_x``'s row width before comparison.
        mu: Latent means from the encoder.
        log_var: Latent log-variances from the encoder.

    Returns:
        Scalar tensor. Both terms are summed (not averaged) over the batch.
    """
    # Flatten to the reconstruction's width rather than a hard-coded 784 so
    # the loss works for any input size the model was built with.
    target = x.view(-1, recon_x.size(-1))
    bce = F.binary_cross_entropy(recon_x, target, reduction='sum')
    # Closed-form KL between N(mu, exp(log_var)) and the standard normal.
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return bce + kld

class Config:
    """Hyperparameters and output locations for one forgetting experiment.

    Constructing a ``Config`` also creates the log, checkpoint, sample and
    base-model directories if they do not exist yet.

    Args:
        forget: Digit or iterable of digits to forget. ``None`` (the
            default) means ``[0]``.
    """

    def __init__(self, forget=None):
        # Normalize to a fresh list: a mutable default argument (or the
        # caller's own list) would otherwise be shared between instances.
        if forget is None:
            forget = [0]
        elif isinstance(forget, int):
            forget = [forget]
        else:
            forget = list(forget)

        self.dataset = "MNIST"

        # Model dimensions.
        self.x_dim = 784
        self.h_dim1 = 512
        self.h_dim2 = 256
        self.z_dim = 8

        # Forgetting setup.
        self.digit_to_forget = forget
        self.lmbda = 100  # weight of the Fisher-weighted parameter penalty
        self.gamma = 1  # weight of the loss on the digits to keep
        self.n_forget_iters = 10000

        # Optimization and sampling.
        self.batch_size = 256
        self.lr = 1e-4
        self.n_samples_per_class = 10

        # Reuse artifacts found in base_model_dir instead of recomputing them.
        self.load_existing_models = True
        self.base_model_dir = "./saved_models"

        forget_label = "_".join(map(str, forget))
        self.exp_root_dir = f"./results/forgetting_test_{forget_label}"
        self.log_dir = os.path.join(self.exp_root_dir, 'logs')
        self.ckpt_dir = os.path.join(self.exp_root_dir, 'ckpts')
        self.sample_dir = os.path.join(self.exp_root_dir, 'samples')

        for directory in (self.log_dir, self.ckpt_dir, self.sample_dir, self.base_model_dir):
            os.makedirs(directory, exist_ok=True)

def save_model(model, filename, config=None):
    """Save ``model``'s state dict, plus the run configuration, to disk.

    Args:
        model: Module whose ``state_dict()`` is written.
        filename: File name, joined onto ``config.base_model_dir``. When
            ``config`` is ``None`` it is written relative to the current
            directory instead.
        config: Optional configuration object. Its attributes are stored
            under the ``'config'`` key as a plain dict.
    """
    # The signature allows config=None, so that case must not crash on
    # attribute access.
    base_dir = config.base_model_dir if config is not None else "."
    save_path = os.path.join(base_dir, filename)

    # Store the attributes rather than the object itself: a pickled custom
    # class makes the checkpoint unreadable with torch.load(weights_only=True).
    config_payload = dict(vars(config)) if config is not None else None

    torch.save({
        'model_state_dict': model.state_dict(),
        'config': config_payload
    }, save_path)
    print(f"Model saved to: {save_path}")

def load_model(model_class, filename, config, device):
    """Load a saved model from ``config.base_model_dir``, if it exists.

    Args:
        model_class: Class instantiated as
            ``model_class(x_dim, h_dim1, h_dim2, z_dim)`` using the
            dimensions stored on ``config``.
        filename: Checkpoint file name inside ``config.base_model_dir``.
        config: Object providing ``base_model_dir`` and the model dimensions.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model with the stored weights loaded, or ``None`` when the
        checkpoint file does not exist.
    """
    load_path = os.path.join(config.base_model_dir, filename)
    if not os.path.exists(load_path):
        print(f"No saved model found at: {load_path}")
        return None

    print(f"Loading model from: {load_path}")
    try:
        checkpoint = torch.load(load_path, map_location=device, weights_only=True)
    except pickle.UnpicklingError:
        # Checkpoints that embed a pickled config object are rejected by the
        # restricted unpickler. Full unpickling can execute arbitrary code,
        # so this fallback is only appropriate for files this project wrote.
        # Only this error is caught; anything else (corrupt file, Ctrl-C, ...)
        # propagates instead of being silently retried.
        checkpoint = torch.load(load_path, map_location=device, weights_only=False)

    model = model_class(config.x_dim, config.h_dim1, config.h_dim2, config.z_dim).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

def save_fisher_matrix(fisher_dict, config):
    """Pickle the Fisher information dictionary to ``fisher_matrix.pkl``.

    Args:
        fisher_dict: Mapping from parameter name to its Fisher estimate.
        config: Object providing ``base_model_dir``, where the file is written.
    """
    fisher_path = os.path.join(config.base_model_dir, 'fisher_matrix.pkl')
    # Store CPU copies so the file can be unpickled on a machine without the
    # device the values were computed on. The caller's dict is not modified.
    portable = {
        name: value.detach().cpu() if torch.is_tensor(value) else value
        for name, value in fisher_dict.items()
    }
    with open(fisher_path, 'wb') as f:
        pickle.dump(portable, f)
    print(f"Fisher matrix saved to: {fisher_path}")

def load_fisher_matrix(config):
    """Return the pickled Fisher dictionary from ``config.base_model_dir``.

    Returns ``None`` (after printing a notice) when ``fisher_matrix.pkl``
    does not exist.
    """
    fisher_path = os.path.join(config.base_model_dir, 'fisher_matrix.pkl')

    if not os.path.exists(fisher_path):
        print(f"No saved Fisher matrix found at: {fisher_path}")
        return None

    print(f"Loading Fisher matrix from: {fisher_path}")
    with open(fisher_path, 'rb') as handle:
        return pickle.load(handle)

def save_classifier(classifier, config):
    """Write ``classifier``'s state dict to ``classifier.pt`` in ``config.base_model_dir``."""
    destination = os.path.join(config.base_model_dir, 'classifier.pt')
    weights = classifier.state_dict()
    torch.save(weights, destination)
    print(f"Classifier saved to: {destination}")

def load_classifier(config, device):
    """Return a ``Classifier`` restored from ``classifier.pt``, or ``None``.

    The file is looked up in ``config.base_model_dir``; weights are mapped
    to ``device`` and the model is moved there before loading them.
    """
    classifier_path = os.path.join(config.base_model_dir, 'classifier.pt')

    if not os.path.exists(classifier_path):
        print(f"No saved classifier found at: {classifier_path}")
        return None

    print(f"Loading classifier from: {classifier_path}")
    restored = Classifier().to(device)
    saved_state = torch.load(classifier_path, map_location=device, weights_only=True)
    restored.load_state_dict(saved_state)
    return restored

def train_cvae(config, device, n_epochs=50):
    """Train a fresh ``OneHotCVAE`` on the MNIST training split and save it.

    Args:
        config: Provides model dimensions, ``batch_size``, ``lr`` and the
            directory the trained model is saved to.
        device: Device to train on.
        n_epochs: Number of passes over the training set (default 50,
            the value that used to be hard-coded).

    Returns:
        The trained model; it is also saved as ``original_vae.pt``.
    """
    print("Training original CVAE")
    train_dataset = datasets.MNIST('./dataset', train=True, download=True,
                                   transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    vae = OneHotCVAE(config.x_dim, config.h_dim1, config.h_dim2, config.z_dim).to(device)
    optimizer = optim.Adam(vae.parameters(), lr=config.lr)
    vae.train()
    for epoch in range(n_epochs):
        train_loss = 0
        for data, labels in train_loader:
            data = data.to(device)
            c = F.one_hot(labels, 10).to(device)
            optimizer.zero_grad()
            recon_batch, mu, log_var = vae(data, c)
            loss = loss_function(recon_batch, data, mu, log_var)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
        if epoch % 10 == 0:
            # The loss is summed within a batch, so this is the average
            # per-batch total, not a per-sample value.
            print(f'Epoch {epoch}, Loss: {train_loss/len(train_loader):.6f}')
    save_model(vae, 'original_vae.pt', config)
    return vae

def calculate_fim(vae, config, device, n_fim_samples=1000):
    """Estimate a diagonal Fisher information value for every parameter.

    For each of ``n_fim_samples`` draws, a random latent and a random class
    label are decoded into a sample by ``vae`` itself; the VAE loss of that
    sample is back-propagated and the squared gradients are averaged.

    Args:
        vae: Model whose parameters are scored.
        config: Provides ``z_dim`` and the directory the result is saved to.
        device: Device the random inputs are moved to.
        n_fim_samples: Number of generated samples to average over
            (default 1000, the value that used to be hard-coded).

    Returns:
        Dict mapping parameter name to a tensor shaped like that parameter.
        The dict is also written to disk via ``save_fisher_matrix``.
    """
    print("Calculating FIM")
    fisher_dict = {
        name: torch.zeros_like(param.data)
        for name, param in vae.named_parameters()
    }
    for _ in range(n_fim_samples):
        # Sampling from the model needs no gradient; only the loss
        # evaluation below contributes to the estimate.
        with torch.no_grad():
            z = torch.randn(1, config.z_dim).to(device)
            c = torch.randint(0, 10, (1,)).to(device)
            c = F.one_hot(c, 10)
            sample = vae.decoder(z, c)
        vae.zero_grad()
        recon_batch, mu, log_var = vae(sample, c)
        loss = loss_function(recon_batch, sample, mu, log_var)
        loss.backward()
        for name, param in vae.named_parameters():
            if param.grad is None:
                # Parameter took no part in the loss (e.g. frozen); its
                # Fisher entry stays zero instead of raising on None.
                continue
            fisher_dict[name] += (param.grad.detach() ** 2) / n_fim_samples
    save_fisher_matrix(fisher_dict, config)
    return fisher_dict

def train_forgetting(original_vae, fisher_dict, config, device):
    """Fine-tune a copy of ``original_vae`` so it forgets the configured digits.

    Each step combines three terms:

    * the VAE loss on uniform-noise images paired with the labels to forget,
    * ``config.gamma`` times the VAE loss on images replayed from a frozen
      copy of the starting model for the remaining labels,
    * ``config.lmbda`` times the Fisher-weighted squared distance between
      the current and the starting parameters.

    Args:
        original_vae: Starting model; it is deep-copied, not modified.
        fisher_dict: Per-parameter Fisher values keyed by parameter name.
        config: Provides ``digit_to_forget``, ``batch_size``, ``z_dim``,
            ``lr``, ``gamma``, ``lmbda`` and ``n_forget_iters``.
        device: Device tensors are moved to.

    Returns:
        Tuple ``(vae, losses)``: the fine-tuned model (also saved to disk)
        and the list of per-step total loss values.

    Raises:
        ValueError: If every digit 0-9 is listed in ``config.digit_to_forget``,
            leaving nothing to replay.
    """
    print(f"Training to forget {config.digit_to_forget}")
    label_choices = [d for d in range(10) if d not in config.digit_to_forget]
    if not label_choices:
        raise ValueError("digit_to_forget covers every digit; no labels left to remember")

    vae = copy.deepcopy(original_vae)
    vae.train()

    # Reference parameters and Fisher weights are constant during training,
    # so move them to the device once instead of on every step.
    params_mle_dict = {
        name: param.detach().clone().to(device)
        for name, param in original_vae.named_parameters()
    }
    fisher_on_device = {name: fisher_dict[name].to(device) for name in params_mle_dict}

    optimizer = optim.Adam(vae.parameters(), lr=config.lr)

    # Frozen snapshot that generates the replay data for the kept digits.
    vae_clone = copy.deepcopy(vae)
    vae_clone.eval()

    losses = []
    for step in range(config.n_forget_iters):
        # .long(): F.one_hot needs int64 indices, but NumPy's default integer
        # type is platform dependent.
        c_remember = torch.from_numpy(np.random.choice(label_choices, size=config.batch_size)).long().to(device)
        c_remember = F.one_hot(c_remember, 10)
        z_remember = torch.randn((config.batch_size, config.z_dim)).to(device)
        forget_labels = np.random.choice(config.digit_to_forget, size=config.batch_size)
        c_forget = torch.from_numpy(forget_labels).long().to(device)
        c_forget = F.one_hot(c_forget, 10)
        # Targets for the forgotten labels are uniform noise images.
        out_forget = torch.rand((config.batch_size, 1, 28, 28)).to(device)
        with torch.no_grad():
            out_remember = vae_clone.decoder(z_remember, c_remember).view(-1, 1, 28, 28)

        optimizer.zero_grad()
        recon_batch, mu, log_var = vae(out_forget, c_forget)
        loss = loss_function(recon_batch, out_forget, mu, log_var)
        recon_batch, mu, log_var = vae(out_remember, c_remember)
        loss += config.gamma * loss_function(recon_batch, out_remember, mu, log_var)
        for n, p in vae.named_parameters():
            _loss = fisher_on_device[n] * (p - params_mle_dict[n]) ** 2
            loss += config.lmbda * _loss.sum()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        if (step + 1) % 1000 == 0:
            print(f'Forgetting Step: {step+1}, Loss: {loss.item():.6f}')

    forgotten_model_name = f'forgotten_vae_digits_{"_".join(map(str, config.digit_to_forget))}.pt'
    save_model(vae, forgotten_model_name, config)
    return vae, losses

def generate_samples(vae, config, device, prefix="original"):
    """Decode and save ``config.n_samples_per_class`` images for each digit 0-9.

    Individual images go to ``<sample_dir>/<prefix>/digit_<d>/sample_<i>.png``
    and a grid of each digit's images to
    ``<sample_dir>/<prefix>_digit_<d>_grid.png``. The model is switched to
    eval mode.
    """
    print(f"Generating {prefix} samples")
    vae.eval()
    per_class = config.n_samples_per_class
    prefix_root = os.path.join(config.sample_dir, prefix)

    with torch.no_grad():
        for digit in range(10):
            latent = torch.randn((per_class, config.z_dim)).to(device)
            labels = torch.full((per_class,), digit, dtype=torch.long).to(device)
            condition = F.one_hot(labels, 10)
            images = vae.decoder(latent, condition).view(-1, 1, 28, 28)

            digit_dir = os.path.join(prefix_root, f'digit_{digit}')
            os.makedirs(digit_dir, exist_ok=True)
            for idx, image in enumerate(images):
                image_path = os.path.join(digit_dir, f'sample_{idx}.png')
                save_image(image, image_path)

            grid_path = os.path.join(config.sample_dir, f'{prefix}_digit_{digit}_grid.png')
            save_image(make_grid(images, nrow=5), grid_path)


def train_classifier(config, device, n_epochs=5):
    """Train a ``Classifier`` on the MNIST training split and save it.

    Args:
        config: Passed to ``save_classifier`` to locate the output directory.
        device: Device to train on.
        n_epochs: Number of passes over the training set (default 5,
            the value that used to be hard-coded).

    Returns:
        The trained classifier. It is left in training mode; callers must
        switch it to eval mode before scoring.
    """
    print("Training classifier")
    # The test-split loader that used to be built here was never read.
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./dataset', train=True, download=True,
                      transform=transforms.ToTensor()),
        batch_size=64, shuffle=True)

    classifier = Classifier().to(device)
    optimizer = optim.Adam(classifier.parameters(), lr=1e-4)
    classifier.train()
    for _ in range(n_epochs):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = classifier(data)
            # The classifier returns log-probabilities, so NLL is the
            # matching loss.
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
    save_classifier(classifier, config)
    return classifier

def _load_sample_batch(directory, device, limit=10):
    """Load up to ``limit`` grayscale images from ``directory`` as one batch, or None."""
    if not os.path.isdir(directory):
        return None
    to_tensor = transforms.ToTensor()
    tensors = []
    # Sorted so the selected files do not depend on filesystem listing order.
    for img_file in sorted(os.listdir(directory))[:limit]:
        # Context manager closes the file handle once the pixels are read.
        with Image.open(os.path.join(directory, img_file)) as img:
            gray = img.convert('L')
        tensors.append(to_tensor(gray).unsqueeze(0).to(device))
    return torch.cat(tensors) if tensors else None


def _score_batch(classifier, batch, digit):
    """Return the mean predictive entropy and mean probability of ``digit`` for ``batch``."""
    with torch.no_grad():
        log_probs = classifier(batch)
        probs = torch.exp(log_probs)
        entropy = -torch.sum(probs * log_probs, dim=1).mean().item()
        target_prob = probs[:, digit].mean().item()
    return {'entropy': entropy, 'target_prob': target_prob}


def evaluate_with_classifier(classifier, sample_dir, device):
    """Score saved original and forgotten samples of every digit.

    Args:
        classifier: Module returning per-class log-probabilities.
        sample_dir: Directory holding ``original/digit_<d>`` and
            ``forgotten/digit_<d>`` image folders.
        device: Device the image tensors are moved to.

    Returns:
        Dict with keys ``original_digit_<d>`` / ``forgotten_digit_<d>`` for
        every folder that contained images; each value holds ``'entropy'``
        and ``'target_prob'``.
    """
    print("Evaluating samples with classifier")
    # A freshly built or freshly trained module is in training mode, where
    # dropout is active and batch norm uses batch statistics; that makes
    # the scores noisy. Evaluate in eval mode and restore the mode after.
    was_training = classifier.training
    classifier.eval()
    results = {}
    try:
        for digit in range(10):
            for prefix in ('original', 'forgotten'):
                batch = _load_sample_batch(
                    os.path.join(sample_dir, prefix, f'digit_{digit}'), device)
                if batch is not None:
                    results[f'{prefix}_digit_{digit}'] = _score_batch(classifier, batch, digit)
    finally:
        classifier.train(was_training)
    return results

def _resolve_forgotten_digits(config):
    """Return the forgotten digits stored on ``config`` as a flat list."""
    raw = getattr(config, "digits_to_forget", None)
    if raw is None:
        raw = config.digit_to_forget
    if isinstance(raw, int):
        return [raw]
    return list(raw)


def _mark_forgotten_digits(digits, forgotten_digits):
    """Draw a dashed red line at the bar position of each forgotten digit."""
    label = f'Forgotten Digits: {forgotten_digits}'
    for fd in forgotten_digits:
        if fd in digits:
            plt.axvline(x=digits.index(fd), color='red', linestyle='--', label=label)
            # Only the first marker gets a legend entry.
            label = '_nolegend_'


def _paired_bar_panel(index, x, digits, original_values, forgotten_values,
                      ylabel, title, forgotten_digits):
    """Draw side-by-side original/forgotten bars in subplot ``index``."""
    width = 0.35
    plt.subplot(2, 2, index)
    plt.bar(x - width/2, original_values, width, label='Original', alpha=0.7)
    plt.bar(x + width/2, forgotten_values, width, label='Forgotten', alpha=0.7)
    _mark_forgotten_digits(digits, forgotten_digits)
    plt.xlabel('Digit')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.xticks(x, digits)
    # Uses the artists' own labels, so both bar series stay correctly named.
    plt.legend()


def _difference_panel(index, x, digits, differences, ylabel, title, forgotten_digits):
    """Draw one bar per digit in subplot ``index``; forgotten digits are red."""
    colors = ['red' if d in forgotten_digits else 'blue' for d in digits]
    plt.subplot(2, 2, index)
    plt.bar(x, differences, color=colors, alpha=0.7)
    plt.xlabel('Digit')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.xticks(x, digits)


def create_comparison_plots(results, config):
    """Plot original-vs-forgotten classifier metrics and save the figure.

    Four panels are drawn: entropy, target probability, and the per-digit
    difference of each. Only digits with both an ``original_digit_<d>`` and
    a ``forgotten_digit_<d>`` entry in ``results`` are shown. The figure is
    written to ``comparison_analysis.png`` in ``config.exp_root_dir``.

    Args:
        results: Dict as produced by ``evaluate_with_classifier``.
        config: Provides the forgotten digits and ``exp_root_dir``.
    """
    # config.digit_to_forget is already a list; wrapping it again produced
    # [[...]], which matched no digit, so nothing was ever highlighted.
    forgotten_digits = _resolve_forgotten_digits(config)

    digits = [
        digit for digit in range(10)
        if f'original_digit_{digit}' in results and f'forgotten_digit_{digit}' in results
    ]
    original_entropies = [results[f'original_digit_{d}']['entropy'] for d in digits]
    forgotten_entropies = [results[f'forgotten_digit_{d}']['entropy'] for d in digits]
    original_probs = [results[f'original_digit_{d}']['target_prob'] for d in digits]
    forgotten_probs = [results[f'forgotten_digit_{d}']['target_prob'] for d in digits]

    x = np.arange(len(digits))
    plt.figure(figsize=(12, 8))

    _paired_bar_panel(1, x, digits, original_entropies, forgotten_entropies,
                      'Entropy', 'Entropy Comparison', forgotten_digits)
    _paired_bar_panel(2, x, digits, original_probs, forgotten_probs,
                      'Target Probability', 'Target Probability Comparison', forgotten_digits)

    entropy_diff = [f - o for o, f in zip(original_entropies, forgotten_entropies)]
    _difference_panel(3, x, digits, entropy_diff,
                      'Entropy Difference (Forgotten - Original)', 'Entropy Difference',
                      forgotten_digits)

    prob_diff = [o - f for o, f in zip(original_probs, forgotten_probs)]
    _difference_panel(4, x, digits, prob_diff,
                      'Probability Difference (Original - Forgotten)', 'Probability Difference',
                      forgotten_digits)

    plt.tight_layout()
    plt.savefig(os.path.join(config.exp_root_dir, 'comparison_analysis.png'), dpi=300, bbox_inches='tight')
    plt.close()

def display_sample_grids(config):
    """Assemble the per-digit sample grids into one 2x10 overview image.

    The top row shows the ``original`` grids and the bottom row the
    ``forgotten`` grids, read from ``config.sample_dir``. Cells whose grid
    file is missing are left as untouched axes. The figure is saved as
    ``sample_comparison_grid.png`` in ``config.exp_root_dir``.
    """
    rows = (('original', 'Original'), ('forgotten', 'Forgotten'))
    fig, axes = plt.subplots(2, 10, figsize=(20, 4))

    for digit in range(10):
        for row, (prefix, caption) in enumerate(rows):
            grid_path = os.path.join(config.sample_dir, f'{prefix}_digit_{digit}_grid.png')
            if not os.path.exists(grid_path):
                continue
            cell = axes[row, digit]
            cell.imshow(Image.open(grid_path))
            cell.set_title(f'{caption} {digit}')
            cell.axis('off')

    plt.tight_layout()
    output_path = os.path.join(config.exp_root_dir, 'sample_comparison_grid.png')
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()


def main(i=None):
    """Run the full forgetting experiment for one set of digits.

    Steps: load or train the original CVAE, load or compute the Fisher
    values, load or train the forgetting model, generate samples from both
    models, load or train the classifier, then score, plot and summarize.

    Args:
        i: Digit or list of digits to forget. ``None`` (the default)
            means ``[0]``.
    """
    if i is None:
        i = [0]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    config = Config(i)

    # Flat list of digits, used for the checkpoint name and the summary.
    forget_digits = config.digit_to_forget
    if isinstance(forget_digits, int):
        forget_digits = [forget_digits]

    print(f"Configuration:")
    print(f"Digit to forget: {config.digit_to_forget}")
    print(f"Lambda (EWC): {config.lmbda}")
    print(f"Gamma: {config.gamma}")
    print(f"Forgetting iterations: {config.n_forget_iters}")
    print(f"Load existing models: {config.load_existing_models}")
    print(f"Output directory: {config.exp_root_dir}")

    original_vae = None
    if config.load_existing_models:
        original_vae = load_model(OneHotCVAE, 'original_vae.pt', config, device)
    if original_vae is None:
        original_vae = train_cvae(config, device)
    else:
        print("Using pre-trained original VAE")

    fisher_dict = None
    if config.load_existing_models:
        fisher_dict = load_fisher_matrix(config)
    if fisher_dict is None:
        fisher_dict = calculate_fim(original_vae, config, device)
    else:
        print("Using pre-calculated Fisher matrix")

    generate_samples(original_vae, config, device, "original")

    # Must match the name train_forgetting saves under; the previous
    # 'forgotten_vae_digit_[0].pt' spelling never matched, so the cached
    # model was never reused.
    digits_tag = "_".join(map(str, forget_digits))
    forgotten_model_name = f'forgotten_vae_digits_{digits_tag}.pt'

    forgotten_vae = None
    if config.load_existing_models:
        forgotten_vae = load_model(OneHotCVAE, forgotten_model_name, config, device)
    if forgotten_vae is None:
        forgotten_vae, _ = train_forgetting(original_vae, fisher_dict, config, device)
    else:
        print(f"Using pre-trained forgetting model for digit {config.digit_to_forget}")

    generate_samples(forgotten_vae, config, device, "forgotten")

    classifier = None
    if config.load_existing_models:
        classifier = load_classifier(config, device)
    if classifier is None:
        classifier = train_classifier(config, device)
    else:
        print("Using pre-trained classifier")

    results = evaluate_with_classifier(classifier, config.sample_dir, device)
    create_comparison_plots(results, config)
    display_sample_grids(config)

    print("\nRESULTS SUMMARY")
    # Result keys are per digit, so report each forgotten digit separately;
    # building the key from the whole list never matched anything.
    for forgotten_digit in forget_digits:
        orig_key = f'original_digit_{forgotten_digit}'
        forg_key = f'forgotten_digit_{forgotten_digit}'
        if orig_key not in results or forg_key not in results:
            continue
        orig = results[orig_key]
        forg = results[forg_key]
        print(f"\nForgotten Digit {forgotten_digit}:")
        print(f"  Original - Entropy: {orig['entropy']:.4f}, Target Prob: {orig['target_prob']:.4f}")
        print(f"  Forgotten - Entropy: {forg['entropy']:.4f}, Target Prob: {forg['target_prob']:.4f}")
        print(f"  Entropy Change: {forg['entropy'] - orig['entropy']:.4f}")
        print(f"  Probability Change: {orig['target_prob'] - forg['target_prob']:.4f}")

    print(f"\nAll results and visualizations saved to: {config.exp_root_dir}")
    print(f"All models saved to: {config.base_model_dir}")

if __name__ == "__main__":
    # One full experiment per digit, forgetting that digit alone.
    for digit in range(10):
        main(digit)
    # Example of forgetting several digits in a single run.
    main([3, 7])

Using device: cuda
Configuration:
Digit to forget: [0]
Lambda (EWC): 100
Gamma: 1
Forgetting iterations: 10000
Load existing models: True
Output directory: ./results/forgetting_test_0
No saved model found at: ./saved_models/original_vae.pt
Training original CVAE


100%|██████████| 9.91M/9.91M [00:02<00:00, 4.55MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 120kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.14MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.47MB/s]


Epoch 0, Loss: 70412.649967
Epoch 10, Loss: 32287.901400
Epoch 20, Loss: 29604.869199
Epoch 30, Loss: 28381.453437
Epoch 40, Loss: 27651.993605
Model saved to: ./saved_models/original_vae.pt
No saved Fisher matrix found at: ./saved_models/fisher_matrix.pkl
Calculating FIM
Fisher matrix saved to: ./saved_models/fisher_matrix.pkl
Generating original samples
No saved model found at: ./saved_models/forgotten_vae_digit_[0].pt
Training to forget [0]
Forgetting Step: 1000, Loss: 175461.125000
Forgetting Step: 2000, Loss: 173720.015625
Forgetting Step: 3000, Loss: 172912.875000
Forgetting Step: 4000, Loss: 172548.984375
Forgetting Step: 5000, Loss: 171329.453125
Forgetting Step: 6000, Loss: 171668.812500
Forgetting Step: 7000, Loss: 171866.671875
Forgetting Step: 8000, Loss: 171814.656250
Forgetting Step: 9000, Loss: 171682.062500
Forgetting Step: 10000, Loss: 169973.281250
Model saved to: ./saved_models/forgotten_vae_digits_0.pt
Generating forgotten samples
No saved classifier found at: ./sav