In [1]:
import os
import torch
import torchvision
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# from torch.utils.tensorboard import summary
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np

from model import UNet
from model_v2 import UNet_v2
from diffusion_discrete import DiscreteDiffusion, generate_betas

In [4]:
# Scratch check: walk the list [1, 2, 2, 2] from its last index down to 0 and
# print 128 * entry and the index (presumably UNet channel multipliers with a
# base width of 128 — confirm against model_v2).
_ch_mult = [1, 2, 2, 2]
for i in range(len(_ch_mult) - 1, -1, -1):
    print(128 * _ch_mult[i])
    print(i)

256
3
256
2
256
1
128
0


In [10]:
# Scratch check: print the indices in [0, 1000) that are multiples of 100.
for i in range(1000):
    if i % 100:
        continue
    print(i)

0
100
200
300
400
500
600
700
800
900


In [2]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
gpu = torch.cuda.is_available()
if gpu:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("device:", device)

device: cpu


In [3]:
def count_parameters(model):
    """Return the total number of scalar elements in the trainable parameters of ``model``.

    Parameters with ``requires_grad=False`` (e.g. frozen layers) are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total


def imshow(img, w_shape=False, gray=False):
    """Display a single image tensor with matplotlib and flush the figure.

    Args:
        img: a CPU torch tensor (``.numpy()`` is called on it). Unless ``w_shape``
            is set, it is passed to ``plt.imshow`` as-is, so it is presumably
            already (H, W) or (H, W, C) — confirm at call sites.
        w_shape: if True, ``img`` is first permuted with ``(1, 2, 0)``, moving the
            leading dimension last (CHW -> HWC).
        gray: if True, render with the ``'gray'`` colormap.
    """
    if w_shape:
        img = img.permute(1, 2, 0)
    npimg = img.numpy()

    if gray:
        plt.imshow(npimg, cmap='gray')
    else:
        plt.imshow(npimg)
    # The colour branch previously wrote `plt.show` without calling it, so the
    # figure was never flushed. When called in a loop, successive images could
    # land on the same figure.
    plt.show()

In [4]:
# Seed torch's global RNG before the model is built so weight init is reproducible.
torch.manual_seed(50)

# Training hyper-parameters.
num_epochs, batch_size, lr = 10, 128, 2e-4

# Earlier single-channel configuration (presumably for the MNIST setup):
# model = UNet(image_channels=1, model_output='logistic_pars').to(device)
# NOTE(review): this model is built with model_output='logits', while the
# DiscreteDiffusion instance in this notebook uses model_output='logistic_pars' —
# confirm the two are meant to agree.
model = UNet_v2(image_channels=3, model_output='logits').to(device)
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [5]:
# Trainable-parameter count, then the full module tree.
print("Number of model parameters: ", count_parameters(model))
print(repr(model))

Number of model parameters:  165649152
UNet_v2(
  (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (down): ModuleList(
    (0-1): 2 x ResnetBlock(
      (norm1): Normalize(
        (norm): GroupNorm(32, 128, eps=1e-05, affine=True)
      )
      (act1): Swish()
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm2): Normalize(
        (norm): GroupNorm(32, 128, eps=1e-05, affine=True)
      )
      (act2): Swish()
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (time_emb): Linear(in_features=512, out_features=128, bias=True)
      (time_act): Swish()
      (dropout): Dropout(p=0.1, inplace=False)
      (shortcut): Identity()
    )
    (2): Downsample(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (3): ResnetBlock(
      (norm1): Normalize(
        (norm): GroupNorm(32, 128, eps=1e-05, affine=True)
      )
      (act1): Swish()
  

In [17]:
# CIFAR-10 loaders. PILToTensor keeps raw integer pixel values; random horizontal
# flips are applied to the training split only. (An earlier MNIST variant resized
# 28x28 images to 32 so the side length was a power of two.)
test_transform = transforms.Compose([transforms.PILToTensor()])
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.PILToTensor(),
])

# The training split is created first because it is the one that downloads the data.
trainset = datasets.CIFAR10(root='./data', train=True, download=True,
                            transform=train_transform)
train_loader = DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)

testset = datasets.CIFAR10(root='./data', train=False, transform=test_transform)
test_loader = DataLoader(dataset=testset, batch_size=batch_size, shuffle=False)

In [18]:
# Beta schedule: type='linear', start=1e-4, stop=0.02, 1000 steps, moved to `device`.
# NOTE(review): DiscreteDiffusion re-wraps this with torch.tensor(...), which is
# what emits the copy warning shown below this cell.
betas = generate_betas(type='linear', start=1e-4, stop=0.02, num_steps=1000)
betas = betas.to(device)

In [19]:
# Discrete diffusion process (num_bits=8) driven by `betas`; the model predicts
# x_start and the loss is the 'hybrid' objective weighted by hybrid_coeff.
# NOTE(review): model_output='logistic_pars' here, but UNet_v2 is constructed with
# model_output='logits' — confirm these are meant to match.
diffusion = DiscreteDiffusion(
    betas=betas,
    device=device,
    transition_mat_type='gaussian',
    transition_bands=None,
    num_bits=8,
    model_prediction='x_start',
    model_output='logistic_pars',
    loss_type='hybrid',
    hybrid_coeff=0.001,
)

  self.betas = betas = torch.tensor(betas, dtype=torch.float64)


In [20]:
# Per-epoch lists of per-batch losses (one inner list per completed epoch).
train_losses = []
test_losses = []

for e in range(1, num_epochs + 1):
    # ---- training ----
    model.train()
    train_loss_vals = []
    for x, _ in tqdm(train_loader, total=len(train_loader)):
        x = x.to(device, dtype=torch.int32)
        # framework expects data shape (B, H, W, C)
        x = x.permute(0, 2, 3, 1)

        optimizer.zero_grad()
        # NOTE(review): the same rng=25 is passed on every step. If DiscreteDiffusion
        # seeds its timestep/noise sampling from it, every batch sees identical
        # timesteps — confirm.
        loss = diffusion.training_losses(model, x_start=x, rng=25).mean()
        loss.backward()
        optimizer.step()
        train_loss_vals.append(loss.item())

    # Mean over the number of batches. The old `/= batch_idx` divided by the last
    # index, which is off by one and raises ZeroDivisionError for a single batch.
    train_loss = sum(train_loss_vals) / max(len(train_loss_vals), 1)

    # ---- evaluation ----
    model.eval()
    test_loss_vals = []
    with torch.no_grad():
        for x_test, _ in test_loader:
            x_test = x_test.to(device, dtype=torch.int32).permute(0, 2, 3, 1)
            l = diffusion.training_losses(model, x_start=x_test, rng=25).mean()
            test_loss_vals.append(l.item())

        # Sample inside no_grad: the sampling loop presumably calls the model once
        # per diffusion step, so tracking gradients would only waste memory.
        # (1, 32, 32, 3) is one CIFAR-10 image in (B, H, W, C) layout.
        samples = diffusion.p_sample_loop(model_fn=model, shape=(1, 32, 32, 3), rng=25)

    test_loss = sum(test_loss_vals) / max(len(test_loss_vals), 1)

    train_losses.append(train_loss_vals)
    test_losses.append(test_loss_vals)

    imshow(samples[0].detach().cpu())

    print("\tEpoch,", e, "complete!", "\tTrain Loss: ", train_loss,
          "\tTest Loss: ", test_loss)
        

  0%|          | 0/469 [00:00<?, ?it/s]

torch.Size([128, 1, 32, 32])
tensor(255, dtype=torch.uint8)
tensor(0, dtype=torch.uint8)
torch.Size([128, 32, 32, 1])


  0%|          | 0/469 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Mean loss per recorded epoch. The x-axis is sized from the recorded losses rather
# than num_epochs, so this cell still works after an interrupted training run
# (previously it raised a length-mismatch error).
fig, ax = plt.subplots()
ax.plot(np.arange(1, len(train_losses) + 1), [np.mean(ls) for ls in train_losses],
        lw=2.5, label='train')
ax.plot(np.arange(1, len(test_losses) + 1), [np.mean(ls) for ls in test_losses],
        lw=2.5, label='test')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
# ax.set_yscale('log')
ax.set_title('Average training/test loss per epoch')
ax.legend()
plt.show()

In [20]:
# Bits-per-dim on the test set. Runs in eval mode under no_grad: no gradients are
# needed, and calc_bpd_loop presumably evaluates the model at many diffusion steps.
model.eval()

total_bpd_list = []   # per-batch mean of loss_dict['total'] over dim 0
prior_bpd_list = []   # per-batch mean of loss_dict['prior'] over dim 0
total_bpd_sum = 0.0
prior_bpd_sum = 0.0
num_examples = 0

with torch.no_grad():
    for x, _ in tqdm(test_loader, total=len(test_loader)):
        x = x.to(device, dtype=torch.int32)
        # framework expects data shape (B, H, W, C)
        x = x.permute(0, 2, 3, 1)

        loss_dict = diffusion.calc_bpd_loop(model, x_start=x)
        # assumes dim 0 of each entry is the batch dimension (as the mean over
        # dim=0 implies) — TODO confirm against calc_bpd_loop
        total = loss_dict['total'].detach().cpu()
        prior = loss_dict['prior'].detach().cpu()

        total_bpd_list.append(total.mean(dim=0))
        prior_bpd_list.append(prior.mean(dim=0))
        total_bpd_sum = total_bpd_sum + total.sum(dim=0)
        prior_bpd_sum = prior_bpd_sum + prior.sum(dim=0)
        num_examples += x.shape[0]

# Example-weighted averages. The last test batch can be smaller (drop_last is not
# set), so averaging per-batch means would over-weight it. The old
# torch.tensor(list_of_tensors) also failed whenever the entries were not scalars.
avg_total_bpd = total_bpd_sum / num_examples
avg_prior_bpd = prior_bpd_sum / num_examples

  0%|          | 0/79 [00:00<?, ?it/s]

: 

In [4]:
# Scratch check: diffusion step indices in descending order, 999 down to 0.
for t in reversed(range(1000)):
    print(t)

999
998
997
996
995
994
993
992
991
990
989
988
987
986
985
984
983
982
981
980
979
978
977
976
975
974
973
972
971
970
969
968
967
966
965
964
963
962
961
960
959
958
957
956
955
954
953
952
951
950
949
948
947
946
945
944
943
942
941
940
939
938
937
936
935
934
933
932
931
930
929
928
927
926
925
924
923
922
921
920
919
918
917
916
915
914
913
912
911
910
909
908
907
906
905
904
903
902
901
900
899
898
897
896
895
894
893
892
891
890
889
888
887
886
885
884
883
882
881
880
879
878
877
876
875
874
873
872
871
870
869
868
867
866
865
864
863
862
861
860
859
858
857
856
855
854
853
852
851
850
849
848
847
846
845
844
843
842
841
840
839
838
837
836
835
834
833
832
831
830
829
828
827
826
825
824
823
822
821
820
819
818
817
816
815
814
813
812
811
810
809
808
807
806
805
804
803
802
801
800
799
798
797
796
795
794
793
792
791
790
789
788
787
786
785
784
783
782
781
780
779
778
777
776
775
774
773
772
771
770
769
768
767
766
765
764
763
762
761
760
759
758
757
756
755
754
753
752
751
750
