In [1]:
import os
# Select the GPU for this notebook through CUDA environment variables:
# order devices by PCI bus id and make only device "1" visible to this process.
# These are assigned before `import torch` further down — presumably so the
# settings are in place when CUDA is first initialised (TODO confirm).
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import torchvision
from torchvision.utils import make_grid
import time

from csgm import ConditionalScoreModel2D, ConditionalScoreModel2Dy
from csgm.utils import CustomLRScheduler

from utils_ours import uncond_loss_fn, our_loss_fn, cond_loss_fn
from samplers import *

from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import functools
import math



In [3]:
# ---------------------------------------------------------------------------
# Experiment configuration. Everything a reader might tune lives in this cell.
# ---------------------------------------------------------------------------

# Time horizon; forwarded as `nt=T` to the score model and `T=T` to the loss.
T = 1

# Side length used when resizing images and when building the forward operator.
size = 175

# Which dataset branch to run; the cell below handles "MNIST", "Celeb", "LIDC" and "GP".
dataset_str = "Celeb" #@param ['MNIST', 'Celeb', 'LIDC', 'GP'] {'type':'string'}


# NOTE(review): `eps` is not referenced in the visible cells — presumably consumed by the samplers; confirm.
eps = 5e-4
device = "cuda"

# Training objective: selects both the network class and the loss function.
mode = "ours" #@param ['unconditional', 'conditional', 'ours'] {'type':'string'}
# Inverse problem; the operator-building cell handles "deblurr", "deblurr3c", "ct" and "seismic".
inv_prob = "deblurr3c"#@param ['deblurr', 'deblurr3c', 'ct', 'seismic'] {'type':'string'}


# Mini-batch sizes for the training and validation loaders.
batch_size =   50
val_batch_size = 50

## Parameters below are fixed for our experiment. Changing them will require training a new model for our approach.

# Passed to build_blurring_operator(size, sig_blurr) and embedded in the checkpoint file name.
sig_blurr = 5

# CT geometry: passed to build_CT_operator(size, angles, max_angle).
max_angle = 45
#angles = int(size * max_angle / 180.)
angles = size

    
## learning rate
# `lr` is the AdamW learning rate; `lr` and `lr_final` are both handed to CustomLRScheduler.
lr= 0.002 #@param {'type':'number'}
lr_final = 0.0005




# ---------------------------------------------------------------------------
# Dataset selection.
#
# Each branch sets the dataset-specific hyper-parameters (`modes`, `hidden_dim`,
# `sig_obs`, `n_epochs`) and builds a train / validation dataset pair. The
# loaders are then created once at the bottom, so the training loader always
# uses the training batch size and the validation loader always uses
# `val_batch_size` (previously two branches passed `batch_size` to both).
# ---------------------------------------------------------------------------

def _first_axis_last(x):
    """Return `x` with its first axis moved to the end, i.e. `x.permute(1, 2, 0)`."""
    return x.permute(1, 2, 0)


if dataset_str == "MNIST":
    modes = 15
    hidden_dim = 32
    sig_obs = 0.05
    n_epochs = 30
    train_batch_size = batch_size
    loader_workers = 4

    image_transform = torchvision.transforms.Compose([
        # torchvision.transforms.Resize(32),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(_first_axis_last),
    ])
    train_dataset = torchvision.datasets.MNIST('.', train=True, transform=image_transform, download=True)
    val_dataset = torchvision.datasets.MNIST('.', train=False, transform=image_transform, download=True)

elif dataset_str == "Celeb":
    modes = 50
    hidden_dim = 128
    sig_obs = 0.01
    n_epochs = 30
    train_batch_size = batch_size
    loader_workers = 0

    # NOTE(review): machine-specific absolute path; move to a configurable data directory.
    celeba_root = "/fabian/work/Project CT Diffusion/SBD-task--/SBD-task-dependent-main/dataset/celeba/"
    image_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((size, size)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(_first_axis_last),
    ])
    train_dataset = torchvision.datasets.CelebA(root=celeba_root, split="train", transform=image_transform)
    val_dataset = torchvision.datasets.CelebA(root=celeba_root, split="valid", transform=image_transform)

elif dataset_str == "LIDC":
    from utils_ours import LungDataset

    modes = 50
    hidden_dim = 128
    sig_obs = 0.05
    n_epochs = 30
    # The LIDC training loader has always used a fixed batch of 16 rather than
    # `batch_size`; kept as-is (reason not visible here — TODO confirm).
    train_batch_size = 16
    loader_workers = 0

    image_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(size, antialias=True),
        torchvision.transforms.Lambda(_first_axis_last),
    ])
    train_dataset = LungDataset(train=True, transform=image_transform)
    val_dataset = LungDataset(train=False, transform=image_transform)

elif dataset_str == "GP":
    from utils_ours import GP

    modes = 15
    hidden_dim = 64
    sig_obs = 0.01
    n_epochs = 30
    train_batch_size = batch_size
    loader_workers = 4

    # GP samples are used without any transform.
    train_dataset = GP(train=True, size=size)
    val_dataset = GP(train=False, size=size)

else:
    # Fail loudly instead of continuing with undefined (or stale) loaders.
    raise ValueError(
        "Unknown dataset_str {!r}; expected one of 'MNIST', 'Celeb', 'LIDC', 'GP'".format(dataset_str)
    )

data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True,
                         num_workers=loader_workers)
val_data_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=True,
                             num_workers=loader_workers)

# Backward compatibility: this cell used to leave `dataset` bound to the validation set.
dataset = val_dataset

In [4]:
# Smoke test: pull a single batch from the training loader to confirm the data pipeline works.
u = next(iter(data_loader))

In [8]:
# @title Load the score-based model

# Validate the configuration up front so that a typo cannot leave `A`,
# `filename` or `loss_fn` undefined (or silently reuse stale values from an
# earlier run of this cell).
if mode not in ("unconditional", "conditional", "ours"):
    raise ValueError(
        "Unknown mode {!r}; expected 'unconditional', 'conditional' or 'ours'".format(mode)
    )

# NOTE(review): `d3` is only True for inv_prob == "deblurr3d", but no operator
# branch below handles "deblurr3d" (only "deblurr" / "deblurr3c"). Confirm
# whether "deblurr3c" was meant here.
d3 = inv_prob == "deblurr3d"

# "conditional" uses the y-conditioned network; "ours" and "unconditional"
# share ConditionalScoreModel2D with identical hyper-parameters.
model_cls = ConditionalScoreModel2Dy if mode == "conditional" else ConditionalScoreModel2D
model = model_cls(
    modes=modes,
    hidden_dim=hidden_dim,
    nlayers=12,
    nt=T,
    d3=d3,
).to(device)
color_channels = 3

# Common prefix of every checkpoint / log file name.
run_prefix = mode + "_" + "T=" + str(T) + "_" + dataset_str

if mode == "ours" or mode == "conditional":
    # Prepare the forward map `A` and the run-specific file name.
    if inv_prob == "deblurr" or inv_prob == "deblurr3c":
        from utils_ours import build_blurring_operator
        A = build_blurring_operator(size, sig_blurr).to(device)
        filename = "test" + run_prefix + inv_prob + str(sig_blurr) + str(size) + str(hidden_dim)
    elif inv_prob == "ct":
        from utils_ours import build_CT_operator
        ct_cache_path = "CT_forward" + str(size) + "angles" + str(angles) + "max_angle" + str(max_angle) + ".pt"
        try:
            A = torch.load(ct_cache_path)
        except Exception as err:
            # Best effort: any failure to read the cache falls back to rebuilding
            # the operator. `Exception` (not a bare except) keeps Ctrl-C working.
            print("Could not load cached CT operator ({}); rebuilding.".format(err))
            A = build_CT_operator(size, angles, max_angle)
        # Only convert when a NumPy array was obtained; a cached tensor is used as-is.
        if isinstance(A, np.ndarray):
            A = torch.from_numpy(A)
        filename = run_prefix + inv_prob + str(max_angle) + str(size) + str(hidden_dim)
    elif inv_prob == "seismic":
        A = torch.load("seismic_forward.pt")
        filename = run_prefix + inv_prob
    else:
        raise ValueError(
            "Unknown inv_prob {!r}; expected 'deblurr', 'deblurr3c', 'ct' or 'seismic'".format(inv_prob)
        )


# Bind the fixed arguments of the chosen objective; the training loop then
# calls `loss_fn(model, x)`.
if mode == "ours":
    loss_fn = functools.partial(our_loss_fn, A=A.float().to(device), sig_obs=sig_obs, T=T)
elif mode == "conditional":
    loss_fn = functools.partial(cond_loss_fn, A=A.float().to(device), sig_obs=sig_obs, T=T, d3=d3)
else:  # mode == "unconditional" (validated above)
    loss_fn = functools.partial(uncond_loss_fn, T=T)
    filename = run_prefix + str(size) + str(hidden_dim)

# NOTE: a stray, indented `A = torch.load("Blurr_forward<size>sig:<sig_blurr>.pt")`
# used to sit here. It was an IndentationError at top level, and it would have
# rebound `A` only after `loss_fn` had already captured the operator, so it
# has been dropped. Load a cached blur operator inside the "deblurr" branch
# above if that behaviour is wanted.


In [None]:
%matplotlib inline

keepTraining = False

# All training artefacts (progress log + checkpoint) live in one directory so
# that resuming reads exactly what a previous run wrote. Resume previously
# looked in "checkpoints" while training wrote to "cpts".
ckpt_dir = "cpts"
os.makedirs(ckpt_dir, exist_ok=True)
log_path = os.path.join(ckpt_dir, "training_prog" + filename + ".txt")
ckpt_path = os.path.join(ckpt_dir, "ckpt_trained_" + filename + ".pth")


def _prepare_batch(batch):
    """Extract the input tensor from a loader batch and move it to `device`.

    For the datasets listed below the loader yields a sequence whose first
    element is the input, so it is unpacked here. "seismic" batches are
    additionally sliced with `[:, 0, :, :, 0]` and given a trailing singleton axis.
    """
    if dataset_str in ("MNIST", "seismic", "CIFAR", "Celeb"):
        batch = batch[0]
    if dataset_str == "seismic":
        batch = batch[:, 0, :, :, 0]
        batch = batch[:, :, :, None]
    return batch.to(device)


if keepTraining:
    # Resume: restore the weights into `model` (the original referenced an
    # undefined `score_model`) and reload the validation-loss history.
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt)
    with open(log_path, "r") as f:
        content = f.read()
    # The log holds one "<value>, " entry per line; strip separators and skip blanks.
    content = content.replace(" ", "").replace(",", "")
    losses_hist = [float(num) for num in content.splitlines() if num]
else:
    losses_hist = []


start = time.time()

# Create the progress log if it does not exist yet (and close the handle right away).
try:
    with open(log_path, "x"):
        pass
except FileExistsError:
    print("training_prog file exists already")

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# Setup the learning rate scheduler.
scheduler = CustomLRScheduler(optimizer, lr, lr_final,
                              n_epochs)


for epoch in range(n_epochs):
    # --- validation pass (runs before this epoch's training updates) ---
    model.eval()
    val_loss_sum = 0.
    with torch.no_grad():
        for x in val_data_loader:
            x = _prepare_batch(x)
            val_loss_sum += loss_fn(model, x).item()
    losses_hist.append(val_loss_sum / len(val_data_loader))

    # --- training pass ---
    model.train()
    # The scheduler is stepped once per epoch ahead of the training pass, as
    # before. CustomLRScheduler's semantics are not visible here, so the order
    # is deliberately left unchanged.
    scheduler.step()
    with tqdm(data_loader, unit=' itr', colour='#B5F2A9', dynamic_ncols=True) as pb:
        for x in pb:
            x = _prepare_batch(x)
            # Clear gradients first so a previously interrupted run of this
            # cell cannot leak stale gradients into the first update.
            optimizer.zero_grad()
            loss = loss_fn(model, x)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            pb.set_postfix({
                'epoch': epoch,
                'train obj': "{:.2f}".format(loss.item()),
                'val obj': "{:.2f}".format(losses_hist[-1])
            })

    # Write the validation-loss history to the logging file (rewritten in full each epoch).
    with open(log_path, "w") as f:
        f.write("".join(str(a) + ", \n" for a in losses_hist))
    # Update the checkpoint after each epoch of training.
    torch.save(model.state_dict(), ckpt_path)

 65%|[38;2;181;242;169m███▏ [0m| 2105/3256 [58:57<32:14,  1.68s/ itr, epoch=0, train obj=705.32, val obj=27327.77][0m

# 