# Train Model

In [None]:
import json
from utils import TrainWrapper
# Where the trained model is written and where the training settings live.
save = "/home/anubis/memdir/diploma/diff_unlearn/model/model_fromscratch"
config_path = "/home/anubis/memdir/diploma/diff_unlearn/model/config.json"

# Parse the JSON settings; its top-level keys are forwarded to TrainWrapper
# as keyword arguments below.
with open(config_path) as fd:
    config = json.load(fd)

# Build the trainer from the loaded settings and run training.
trainer = TrainWrapper(
    model_save=save,
    unlearn_label=0,
    **config,
)
trainer.train()

# Unlearn Model

In [None]:
import os
import json
from utils import UnlearnWrapper

# Directory holding one JSON config per unlearning run.
conf_path = "/home/anubis/memdir/diploma/diff_unlearn/model_unlearn4"

# os.listdir returns entries in arbitrary order; sorting makes the sequence of
# runs reproducible between executions.
for config_name in sorted(os.listdir(conf_path)):
    config = os.path.join(conf_path, config_name)
    # Skip sub-directories (e.g. output folders) that cannot be opened as a
    # config file.
    if not os.path.isfile(config):
        continue

    with open(config, "r") as fd:
        conf_dict = json.load(fd)

    # Each config's top-level keys are passed to UnlearnWrapper as keyword
    # arguments; calling the resulting object starts the run.
    unlearner = UnlearnWrapper(**conf_dict)
    unlearner()

# Sample

In [None]:
from diffusers import DDPMScheduler
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
import torchvision
import torch

from unet import MNIST_Unet
# DDPM scheduler with 1000 training timesteps and the squared-cosine beta schedule.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


In [None]:
# Instantiate the network and restore the weights of the chosen checkpoint.
model = MNIST_Unet()
# map_location="cpu" lets the checkpoint be deserialized on machines without
# the device it was saved from; the model is moved to DEVICE later, right
# before sampling.
state_dict = torch.load("model_unlearn/checkpoints/500_10.pt", map_location="cpu")
model.load_state_dict(state_dict)

In [None]:
# Start from pure Gaussian noise: 80 tensors of shape (1, 28, 28).
x = torch.randn(80, 1, 28, 28).to(DEVICE)
# Conditioning labels: eight consecutive copies of each value 0..9.
y = torch.tensor([[label] * 8 for label in range(10)]).flatten().to(DEVICE)

model.to(DEVICE)
# A freshly constructed module is in training mode; switch to eval mode so
# that any train-time-only layers (dropout, batch norm — if the model has
# them) behave as at inference.
model.eval()

# Sampling loop. Iterating the timesteps directly (instead of wrapping them in
# enumerate) lets tqdm see the length and display a proper total.
for t in tqdm(noise_scheduler.timesteps):
    with torch.no_grad():
        residual = model(x, t, y)  # Again, note that we pass in our labels y

    # Let the scheduler compute the sample for the previous timestep.
    x = noise_scheduler.step(residual, t, x).prev_sample

In [None]:
import os

# savefig does not create missing folders, so make sure the target exists.
os.makedirs("tests", exist_ok=True)

fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.axis("off")
# Clip samples to [-1, 1], tile them eight per row, and show the first channel
# of the resulting grid as a greyscale image.
grid = torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)
ax.imshow(grid[0], cmap='Greys')
plt.suptitle("Сгенерированные изображения измененной модели\nS=500, K=10")
plt.savefig("tests/sample_500_10.jpg")
plt.show()

# MIA Plots

In [None]:
from metrics import MIA
import matplotlib.pyplot as plt
import matplotlib
import os
from tqdm.auto import tqdm

In [None]:
ckpts = "model_unlearn/checkpoints"
batch_size = 50
# Fixed model whose losses are plotted alongside every checkpoint
# (hoisted: it does not change between iterations).
models = ["/home/anubis/memdir/diploma/diff_unlearn/model/model15.pt"]

# Non-interactive backend: figures are only written to disk.
matplotlib.use("Agg")
# savefig does not create missing folders, so make sure the target exists.
os.makedirs("visual", exist_ok=True)

# Sorted for a reproducible processing order (os.listdir order is arbitrary).
for pt in sorted(os.listdir(ckpts)):
    # Ignore anything that is not a checkpoint file (stray folders, logs, ...).
    if not pt.endswith(".pt"):
        continue

    # Checkpoint files are named "<S>_<K>.pt"; parse the name up front so a
    # malformed name fails before the expensive MIA run instead of after it.
    stem = pt.split('.')[0]
    s, k = stem.split('_')
    model_path = os.path.join(ckpts, pt)

    for forget in [0, 1]:
        print(f"MODEL: {pt}")
        # Calling the MIA object returns a mapping from model identifier to a
        # plottable sequence of loss values.
        losses = MIA(batch_size=batch_size, models=models + [model_path], forget=forget)()
        print(losses)
        model_names = list(losses.keys())

        fig = plt.figure(figsize=(10, 6))
        plt.plot(losses[model_names[0]], label=os.path.basename(model_names[0]))
        plt.plot(losses[model_names[1]], label=os.path.basename(model_names[1]))

        plt.legend()
        plt.xlabel('Step')
        plt.ylabel('Loss')
        plt.title(f'S = {s}, K = {k}')
        plt.savefig(f"visual/{stem}_{forget}.jpg")
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so memory does not grow with the number of checkpoints.
        plt.close(fig)