In [None]:
import os
import subprocess

def git_repo_root():
    """Return the top-level directory of the Git repository containing the CWD.

    Runs ``git rev-parse --show-toplevel`` in the current working directory.

    Returns:
        The repository root path as a string (trailing newline stripped), or
        ``None`` when the current directory is not inside a Git repository or
        the ``git`` executable cannot be launched (e.g. not installed / not on
        PATH).
    """
    try:
        root = subprocess.check_output(
            ['git', 'rev-parse', '--show-toplevel'],
            universal_newlines=True,
            # Silence git's "fatal: not a git repository" message; the None
            # return value already reports that case to the caller.
            stderr=subprocess.DEVNULL,
        )
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran but the CWD is not inside a repository.
        # OSError (e.g. FileNotFoundError): git itself could not be started.
        return None
    return root.strip()

# Work from the repository root so relative paths used later in the notebook
# (e.g. './epoch49', './data/') resolve the same way no matter where the
# notebook was launched from.
git_root = git_repo_root()
if not git_root:
    print("Not inside a Git repository.")
else:
    os.chdir(git_root)
    print(f"Changed working directory to: {git_root}")

In [None]:
from diffusion import VPSDE
import torchvision
from matplotlib import pyplot as plt
import torchvision.transforms as transforms
from denoising_diffusion_pytorch import Unet
from diffusion import VPSDE
from torch.utils.data import DataLoader
from torch.optim import Adam
from training import train_score_network_mnist
import torch
from guided_diffusion import Net

In [None]:
# Use the GPU when one is present; fall back to CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing on the first .to(device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): this string is never read; `data` is rebound to the MNIST
# dataset object in the data-loading cell.
data = 'MNIST'
model = Unet(channels=1, dim=32).to(device)
# map_location remaps the stored tensors onto `device`; without it a checkpoint
# saved on a GPU cannot be loaded on a CPU-only machine or a different GPU.
model.load_state_dict(torch.load('./epoch49', map_location=device))
# NOTE(review): positional args presumably (num_steps, beta_min, beta_max) —
# confirm against diffusion.VPSDE.
sde = VPSDE(100, 0.1, 20, device=device)

# Hyperparameters
n_epochs = 50
batch_size = 32
lr = 1e-4
optimizer = Adam(model.parameters(), lr=lr)

In [None]:
# ToTensor converts each PIL digit to a float tensor in [0, 1]; Pad(2) adds two
# pixels on every side, turning the 28x28 MNIST digits into 32x32 images.
tfm = transforms.Compose([
  transforms.ToTensor(),
  transforms.Pad(2),
])
data = torchvision.datasets.MNIST('./data/', transform=tfm, download=True)
data_loader = DataLoader(
  data,
  batch_size=batch_size,
  shuffle=True,
  num_workers=4,
)

In [None]:
def plot(score_net, n_samples=5):
  """Draw unconditional samples with the global `sde` and show them in a column.

  Args:
    score_net: score model passed to sde.backward_diffusion.
    n_samples: number of 1x32x32 samples to draw and plot (default 5).
  """
  with torch.no_grad():
    samples = sde.backward_diffusion(score_net, data_shape=(n_samples, 1, 32, 32))
  samples = samples.cpu().numpy()
  # squeeze=False keeps `axes` two-dimensional even when n_samples == 1.
  _, axes = plt.subplots(n_samples, squeeze=False)
  for i in range(n_samples):
    # Index away the single channel and render in grayscale with the same
    # display range as plot_with_infill.
    axes[i, 0].imshow(samples[i, 0], cmap='gray', vmin=0, vmax=1)
    axes[i, 0].axis('off')
  plt.show()

In [None]:
# Sanity check: sample unconditionally from the loaded checkpoint.
plot(score_net=model)

In [None]:
def half_cut(img):
  """Hide one randomly chosen half (top/bottom/left/right) of each image.

  Four half-plane masks are built (keep bottom, keep right, keep top, keep
  left), each with img's full shape. The resulting pool of 4*b masks is
  shuffled with torch.randperm and the first b are used, so each image gets a
  random half hidden.

  Args:
    img: 4-D batch tensor; the slicing below treats dims 2 and 3 as height and
      width, and dim 0 as the batch.

  Returns:
    (total_mask, masked_img): total_mask has img's shape with 1 where pixels
    are kept and 0 where they are hidden; masked_img = img * total_mask.
  """
  # Split at the middle of the actual spatial size (16 for 32x32 inputs).
  h_mid = img.size(2) // 2
  w_mid = img.size(3) // 2

  keep_bottom = torch.ones_like(img)
  keep_bottom[:, :, :h_mid, :] = 0
  keep_top = torch.ones_like(img)
  keep_top[:, :, h_mid:, :] = 0
  keep_left = torch.ones_like(img)
  keep_left[:, :, :, w_mid:] = 0
  keep_right = torch.ones_like(img)
  keep_right[:, :, :, :w_mid] = 0

  b = img.size(0)
  # Pool order is kept fixed so a given RNG seed always selects the same masks.
  pool = torch.cat([keep_bottom, keep_right, keep_top, keep_left], dim=0)
  idx = torch.randperm(4 * b)
  # Only the first b shuffled indices are needed.
  total_mask = pool[idx[:b]]

  return total_mask, img * total_mask




# Preview the random half-masking on the first batch (samples 0, 1, 3 and 4).
for x, y in data_loader:
  mask, x = half_cut(x)
  for k in (0, 1, 3, 4):
    plt.imshow(x[k][0])
    plt.show()
  break

In [None]:
def plot_with_infill(score_net, x_org, x, masks, n_rows=10, n_samples=8):
  """Show originals, masked inputs and several infilled samples side by side.

  Grid layout: column 0 = original, column 1 = masked input, column 2 = hidden
  spacer, columns 3.. = one independent sde.infilling_diffusion run each.

  Args:
    score_net: score model passed to the global sde.infilling_diffusion.
    x_org: unmasked images; row i is displayed as x_org[i][0].
    x: masked images fed to the sampler; its full shape is used as data_shape.
    masks: masks matching x (as produced by half_cut: 1 where x is kept).
    n_rows: number of batch items to display (default 10, capped at batch size).
    n_samples: number of infilling runs, one column each (default 8).
  """
  n_rows = min(n_rows, x.shape[0])
  n_cols = 3 + n_samples
  # figsize goes to subplots itself: a separate plt.figure() call would open an
  # extra, empty figure and leave this grid at the default size.
  _, axe = plt.subplots(n_rows, n_cols, figsize=(18, 18), squeeze=False)

  axe[0, 0].set_title('Original', {'fontsize': 8})
  axe[0, 1].set_title('Masked', {'fontsize': 8})

  for i in range(n_rows):
    axe[i, 0].imshow(x_org[i][0].cpu(), cmap='gray', vmin=0, vmax=1)
    axe[i, 1].imshow(x[i][0].cpu(), cmap='gray', vmin=0, vmax=1)
    axe[i, 0].axis('off')
    axe[i, 1].axis('off')
    # Column 2 is an empty spacer between the inputs and the samples.
    axe[i, 2].set_visible(False)

  for s in range(n_samples):
    with torch.no_grad():
      # Use the `masks` argument (not a global) and the batch's real shape
      # rather than a hard-coded (32, 1, 32, 32).
      samples = sde.infilling_diffusion(
          score_net, x, masks, data_shape=tuple(x.shape)
      ).detach().cpu().numpy()
    for j in range(n_rows):
      axe[j, 3 + s].imshow(samples[j][0], cmap='gray', vmin=0, vmax=1)
      axe[j, 3 + s].axis('off')
  plt.tight_layout(pad=0.1)
  plt.show()

In [None]:
# Infill the first batch: keep an untouched copy for the "Original" column,
# hide a random half of each image, and move the model inputs to `device`.
for x, y in data_loader:
  # clone() makes an independent copy; torch.Tensor(x) would only alias x's
  # storage, so any in-place edit of x would silently change the "original".
  x_org = x.clone()
  mask, x = half_cut(x)
  x = x.to(device)
  mask = mask.to(device)
  plot_with_infill(model, x_org, x, mask)
  break

In [None]:
# One-hot encode a batch of five "5" labels on `device`. num_classes is pinned
# to the 10 MNIST digit classes: without it one_hot infers max(label) + 1
# (= 6 here), so the encoding width would depend on which labels are present.
torch.nn.functional.one_hot(
  torch.full((5,), 5, dtype=torch.int64, device=device),
  num_classes=10,
)