In [None]:
import os
import subprocess

def git_repo_root():
    """Return the top-level directory of the enclosing Git repository.

    Runs ``git rev-parse --show-toplevel`` in the current working directory.

    Returns:
        str | None: The command's output with surrounding whitespace stripped,
        or ``None`` when the command exits with a non-zero status (e.g. the
        current directory is not inside a repository) or when the ``git``
        executable cannot be found.
    """
    try:
        root = subprocess.check_output(
            ['git', 'rev-parse', '--show-toplevel'],
            universal_newlines=True,
        ).strip()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # CalledProcessError: git ran but reported failure (not a repository).
        # FileNotFoundError: git is not installed / not on PATH.
        return None
    return root

# Locate the repository root once; later cells open checkpoints through
# relative paths (./models/...), presumably relative to this root.
git_root = git_repo_root()

if not git_root:
    # Lookup failed: stay in the current working directory.
    print("Not inside a Git repository.")
else:
    # Make the repository root the working directory for the rest of the notebook.
    os.chdir(git_root)
    print(f"Changed working directory to: {git_root}")

In [None]:
from diffusion import VPSDE
import torchvision
from matplotlib import pyplot as plt
import torchvision.transforms as transforms
from denoising_diffusion_pytorch import Unet
from diffusion import VPSDE
from torch.utils.data import DataLoader
from torch.optim import Adam
from training import train_score_network_mnist
import torch
from guided_diffusion import Net

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
data = 'MNIST'

# Score network: single-channel U-Net whose weights are restored from disk.
model = Unet(channels = 1, dim = 32).to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load('./models/MNIST/epoch49', map_location=device))
# NOTE(review): the meaning of the positional arguments (100, 0.1, 20) is
# defined by diffusion.VPSDE and is not visible here — confirm before tuning.
sde = VPSDE(100, 0.1, 20, device = device)

# Hyper-parameters. They are not consumed by any cell visible in this notebook
# section; presumably kept for (re)training — confirm.
n_epochs =   50
batch_size =  32
lr=1e-4
optimizer = Adam(model.parameters(), lr=lr)

In [None]:
def plot(score_net, n_samples=5):
  """Draw samples with the module-level `sde` and show them in one column.

  NOTE: a later cell redefines `plot` with a `(score_net, classifier)`
  signature; this definition must be (re)run before calling `plot(model)`.

  Args:
    score_net: score network handed to `sde.backward_diffusion`.
    n_samples: number of images to sample and display (>= 1). Defaults to 5,
      the value that was previously hard-coded.
  """
  with torch.no_grad():
    samples = sde.backward_diffusion(score_net, data_shape = (n_samples, 1, 32, 32)).detach().cpu().numpy()
  # Move axis 1 to the end (same result as swapaxes(1, 2) then swapaxes(2, 3)).
  # Assumes the sampler returns an array laid out like `data_shape` — TODO confirm.
  samples = samples.transpose(0, 2, 3, 1)
  # squeeze=False always yields a 2-D axes array, so n_samples == 1 works too.
  fig, axes = plt.subplots(n_samples, squeeze=False)
  for ax, sample in zip(axes[:, 0], samples[:n_samples]):
    ax.imshow(sample)
  plt.show()

In [None]:
# Show samples from the loaded score network without classifier guidance.
# Relies on the single-argument `plot` defined above; a later cell redefines
# `plot` with a required `classifier` argument, after which this call fails.
plot(model)

In [None]:
# Classifier used to guide the sampler; weights are restored from disk.
# map_location remaps the stored tensors onto `device`, so the checkpoint
# loads even when the device it was saved from is not available.
# NOTE(review): if Net contains dropout or batch-norm layers, calling
# classifier.eval() before sampling is probably intended — confirm.
classifier = Net().to(device)
classifier.load_state_dict(torch.load('./models/MNISTClassifier/epoch99', map_location=device))

In [None]:
def get_numbers(score_net, classifier, batch_size, number):
  """Sample `batch_size` images, all conditioned on the class label `number`.

  Uses the module-level `sde` and `device`.

  Args:
    score_net: score network handed to the sampler.
    classifier: classifier handed to the sampler for guidance.
    batch_size: number of images to sample.
    number: class label applied to every sample (truncated to an integer).

  Returns:
    The sampler's output detached, moved to the CPU and converted to a NumPy
    array. The requested data shape is (batch_size, 1, 32, 32).
  """
  # Create the int64 labels directly on the target device instead of building
  # a float tensor on the host, casting it, and then copying it over.
  classes = torch.full((batch_size,), int(number), dtype = torch.int64, device = device)

  samples = sde.classifier_guided_backward_diffusion(score_net, classifier,
                                                     data_shape = (batch_size, 1, 32, 32),
                                                     classes = classes)

  return samples.detach().cpu().numpy()

def plot(score_net, classifier, n_per_class=3, n_classes=10):
  """Show a grid of classifier-guided samples, one column per class label.

  NOTE: this shadows the single-argument `plot(score_net)` defined in an
  earlier cell.

  Args:
    score_net: score network forwarded to `get_numbers`.
    classifier: classifier forwarded to `get_numbers`.
    n_per_class: number of samples (grid rows) per label, >= 1. Defaults to 3,
      the value that was previously hard-coded.
    n_classes: labels 0 .. n_classes - 1 are drawn, one per column. Defaults
      to 10, the value that was previously hard-coded.
  """
  # squeeze=False keeps `axes` two-dimensional even for a single row or column.
  fig, axes = plt.subplots(n_per_class, n_classes, squeeze=False)

  for label in range(n_classes):
    # The column header is the class label the samples were guided towards.
    axes[0, label].set_title(f'{label}')
    samples = get_numbers(score_net, classifier, n_per_class, label)
    for row in range(n_per_class):
      # samples[row][0]: first (only requested) channel of the row-th sample.
      axes[row, label].imshow(samples[row][0], cmap = 'gray', vmin=0, vmax=1)
      axes[row, label].axis('off')

  plt.tight_layout(pad=0.1)
  plt.show()

In [None]:
# Render a grid of classifier-guided samples (one column per label).
# No random seed is set in the visible cells, so each run presumably yields a
# different grid — confirm.
plot(model, classifier)

In [None]:
# Same call as the previous cell: draws another grid of guided samples.
plot(model, classifier)

In [None]:
# Same call as the previous cell: draws another grid of guided samples.
plot(model, classifier)

In [None]:
# Same call as the previous cell: draws another grid of guided samples.
plot(model, classifier)

In [None]:
# Same call as the previous cell: draws another grid of guided samples.
plot(model, classifier)

In [None]:
# Same call as the previous cell: draws another grid of guided samples.
plot(model, classifier)