In [None]:
# Configure GPU visibility before anything else is imported: enumerate CUDA
# devices in PCI bus order and expose only device "0" to this process.
# NOTE(review): these are set ahead of the star import below, presumably
# because that module imports torch and the variables must be in place
# before CUDA is initialised -- confirm.
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
# NOTE(review): later cells use names such as torch, torchvision, DataLoader,
# functools, time, mutils, get_optimizer, ExponentialMovingAverage,
# ReduceLROnPlateau, our_loss_fn and uncond_loss_fn without importing them,
# so they presumably all come from this star import -- verify in imports.py.
from imports import *

In [None]:
# Experiment configuration.
# Lines tagged with `#@param` are Colab form fields and can be edited from the
# form UI; the remaining values are plain constants read by the later cells.

# Passed to the loss function as `T=T` when the loss is built.
T = 2

dataset_str = "MNIST" #@param ['LIDC', 'MNIST'] {'type':'string'}


eps = 1e-3
device = "cuda"

mode = "ours" #@param ['unconditional', 'ours'] {'type':'string'}
inv_prob = "deblurr"#@param ['ct', 'deblurr'] {'type':'string'}


batch_size =   64#@param {"type":"integer"}

## Parameters below are fixed for our experiment. Changing them will require training a new model for our approach.
# Scales the identity matrix used as `Gam` in the conditional loss.
sig_obs = 0.05

# Forwarded to build_blurring_operator for the "deblurr" problem.
sig_blurr = 4

# Used to derive the number of CT angles and forwarded to build_CT_operator.
max_angle = 15

if dataset_str == "MNIST": 
    ## learning rate
    lr=1e-4 #@param {'type':'number'}
    from custom_configs import MNIST_ddpmpp_continuous as configs  
    from torchvision.datasets import MNIST
    # MNIST images are resized to 32x32 before being converted to tensors.
    dataset = MNIST('.', 
                train=True, 
                transform=torchvision.transforms.Compose(
                    [torchvision.transforms.Resize(32), torchvision.transforms.ToTensor()]),
                download=True)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    if mode == "unconditional":
        n_epochs =   30#@param {'type':'integer'}
    elif mode == "ours":
        if inv_prob == "ct":
            n_epochs = 30#@param {'type':'integer'}
        elif inv_prob == "deblurr":
            n_epochs = 40
        else:
            # Without this, n_epochs would stay undefined (or stale from a
            # previous run) and the training cell would fail much later.
            raise ValueError(
                "Unknown inv_prob %r; expected 'ct' or 'deblurr'." % (inv_prob,))
    else:
        raise ValueError(
            "Unknown mode %r; expected 'unconditional' or 'ours'." % (mode,))
else:
    # The form offers other dataset names, but only MNIST defines `configs`,
    # `lr`, `data_loader` and `n_epochs` here. Fail loudly instead of letting
    # later cells hit a NameError or silently reuse stale kernel state.
    raise NotImplementedError(
        "dataset_str=%r is not set up in this cell; only 'MNIST' is configured."
        % (dataset_str,))

In [None]:
# @title Load the score-based model
config = configs.get_config()

# Push the notebook-level settings into the model configuration.
config.training.batch_size = batch_size
config.eval.batch_size = batch_size
config.optim.lr = lr
size = config.data.image_size

random_seed = 0 #@param {"type": "integer"}
torch.manual_seed(random_seed)

score_model = mutils.create_model(config)

optimizer = get_optimizer(config, score_model.parameters())
ema = ExponentialMovingAverage(score_model.parameters(),
                               decay=config.model.ema_rate)


ema.copy_to(score_model.parameters())
color_channels = config.data.num_channels

# Number of CT projection angles, proportional to the image size and the
# covered angular range (truncated to an integer).
angles = int(size * max_angle / 180.)

if mode == "ours":
    # Prepare the forward map A, the matrix Gam and the run name for the
    # selected inverse problem.
    if inv_prob == "deblurr":
        from utils_ours import build_blurring_operator
        A = build_blurring_operator(size, sig_blurr).to(device)
        Gam = sig_obs * torch.eye(size**2).to(device)
        filename = mode + "_" + "T=" + str(T)+ "_" + dataset_str + inv_prob + str(sig_blurr)
    elif inv_prob == "ct":
        from utils_ours import build_CT_operator
        ct_cache_path = "CT_forward" + str(size) + "angles" + str(angles) + "max_angle" + str(max_angle) +".pt"
        try:
            A = torch.load(ct_cache_path)
        except Exception:
            # Best effort: if the cached operator cannot be loaded for any
            # reason, rebuild it. `Exception` (rather than a bare except) keeps
            # Ctrl-C / SystemExit from being swallowed.
            A = build_CT_operator(size, angles, max_angle) 
        # The cache may already hold a tensor while the builder's result is
        # passed through from_numpy; only convert when it is not a tensor,
        # since torch.from_numpy raises TypeError on tensors.
        if not isinstance(A, torch.Tensor):
            A = torch.from_numpy(A)
        Gam = sig_obs * torch.eye(size*angles).to(device)
        filename = mode + "_" + "T=" + str(T)+ "_" + dataset_str + inv_prob + str(max_angle)
    else:
        raise ValueError(
            "Unknown inv_prob %r; expected 'ct' or 'deblurr'." % (inv_prob,))
    # `.to(device)` alone is sufficient; the former extra `.cuda()` hard-coded
    # the backend regardless of `device`.
    loss_fn = functools.partial(our_loss_fn, A = A.float().to(device), Gam = Gam.to(device), T = T)
    
elif mode == "unconditional":
    loss_fn = functools.partial(uncond_loss_fn, T = T)     
    filename = mode + "_" + "T=" + str(T)+ "_" + dataset_str
else:
    # Otherwise loss_fn / filename stay undefined and the training cell fails
    # with a NameError (or reuses values from an earlier run).
    raise ValueError(
        "Unknown mode %r; expected 'unconditional' or 'ours'." % (mode,))

In [None]:
# Training driver.
#   Train        -- run the optimisation loop below.
#   keepTraining -- resume from the saved checkpoint and loss log of this run.
Train = True
keepTraining = False

# All artefacts of a run live under ./checkpoints and are keyed by `filename`.
ckpt_dir = "checkpoints"
ckpt_path = os.path.join(ckpt_dir, "ckpt_trained_" + filename + ".pth")
log_path = os.path.join(ckpt_dir, "training_prog" + filename + ".txt")

if keepTraining:
    # Only the model weights are restored; the optimizer and EMA objects keep
    # whatever state they were created with.
    ckpt = torch.load(ckpt_path, map_location=device)
    score_model.load_state_dict(ckpt)
    # The log holds one "<loss>, " entry per line; strip separators and parse.
    with open(log_path, "r") as f:
        content = f.read()
    content = content.replace(" ", "").replace(",", "")
    losses_hist = [float(num) for num in content.splitlines() if num]
else:
    losses_hist = []

start = time.time()
if Train:
    # Make sure the output directory exists; otherwise creating the log file
    # (and every later write) fails with FileNotFoundError.
    os.makedirs(ckpt_dir, exist_ok=True)
    try:
        # "x" creates the file and fails if it is already there; the handle is
        # closed immediately because the loop reopens the file each epoch.
        with open(log_path, "x"):
            pass
    except FileExistsError:
        print("training_prog file exists already")

    scheduler = ReduceLROnPlateau(optimizer, 'min')

    for epoch in range(n_epochs):
        avg_loss = 0.
        num_items = 0
        for x, _ in data_loader:
            loss, t = loss_fn(score_model, x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ema.update(score_model.parameters())
            if mode == "unconditional":
                loss_actual = loss
            elif mode == "ours":
                # Rescale the optimised loss by exp(t/2) / sqrt(exp(t) - 1)
                # for reporting only.
                # NOTE(review): presumably undoes a time-dependent weighting
                # applied inside our_loss_fn, and assumes the product is a
                # single-element tensor so that .item() works -- confirm.
                loss_actual = loss  * torch.exp(t/2) / torch.sqrt(torch.exp(t)-1)
            avg_loss += loss_actual.item() * x.shape[0]
            num_items += x.shape[0]

        epoch_loss = avg_loss / num_items
        scheduler.step(epoch_loss)
        losses_hist.append(epoch_loss)

        print("train_error: ", epoch_loss)
        # Rewrite the whole loss history to the log file after every epoch.
        with open(log_path, "w") as f:
            f.write("".join(str(a) + ", \n" for a in losses_hist))
        # Update the checkpoint after each epoch of training. This stores
        # score_model's own weights; the EMA weights are not written.
        torch.save(score_model.state_dict(), ckpt_path)
        print("time elapsed:", time.time() - start)