In [53]:
import torch
import torchvision
from sdot.modules import SemiDiscreteOptimalTransport, PatchExtractor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

blog that explain diffusion model : https://medium.com/mlearning-ai/enerating-images-with-ddpms-a-pytorch-implementation-cef5a2ba8cb1

other post : https://lilianweng.github.io/posts/2021-07-11-diffusion-models/

In [54]:
def diffusion_step(input, beta):
    noise = torch.randn_like(input)
    diffused = torch.sqrt(1-beta)*input + torch.sqrt(beta)*noise
    return diffused

def beta_scheduler(beta_min, beta_max, num_step):
    return torch.linspace(beta_min, beta_max, num_step)



In [55]:
from PIL import Image
image_to_tensor = torchvision.transforms.ToTensor()
tex_image = image_to_tensor(Image.open('texture.png').convert('RGB')).unsqueeze(0)

In [56]:
class ConvBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_0 = torch.nn.Conv2d(3, 16, 3, padding=1)
        self.conv_1 = torch.nn.Conv2d(16, 16, 3, padding=1)
        self.conv_2 = torch.nn.Conv2d(16, 3, 3, padding=1)
        self.activation = torch.nn.ReLU()

    def forward(self, input):
        out = self.activation(self.conv_0(input))
        out = self.activation(self.conv_1(out))
        out = self.activation(self.conv_2(out))
        return out

class DiffusionModel(torch.nn.Module):
    def __init__(self, num_step):
        super().__init__()
        betas = beta_scheduler(0.01,0.1,num_step)        
        self.model = torch.nn.ModuleList([ConvBlock() for i in range(num_step)])
        self.register_buffer('betas', betas)

    def forward(self, input, time_step):
        return diffusion_step(input, self.betas[time_step])

    def backward(self, input, time_step):
        prediction = self.model[time_step](input)
        beta = self.betas[time_step]
        return prediction# (input - torch.sqrt(beta)*prediction)/torch.sqrt(1-beta)

patch_size = 5
num_step = 10

patch_extractor = PatchExtractor(patch_size)
diffuser = DiffusionModel(num_step)

In [52]:
diff_image = tex_image

for s in range(num_step):
    step = torch.tensor(s, device = device)
    diff_image = diffuser(diff_image, step)
    torchvision.utils.save_image(diff_image, f'tmp/diffused_{step}.png')
    rec_image = diffuser.backward(diff_image, step)
    torchvision.utils.save_image(rec_image, f'tmp/reconstructed_{step}.png')


RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

In [57]:
diffuser = diffuser.to(device)
tex_image = tex_image.to(device)
patch_extractor = patch_extractor.to(device)

target_patches = patch_extractor(tex_image)

model_optimizer = torch.optim.Adam(diffuser.parameters(), lr=0.0001)

ot_discriminator = SemiDiscreteOptimalTransport(target_patches).to(device)
discriminator_optimizer = torch.optim.Adam(ot_discriminator.parameters(), lr=0.0001)

max_iter = 1000
monitoring_step = 10

iteration = 0
while iteration < max_iter + 1:


    time_step = torch.randint(0,num_step, (1,)).to(device)

    diff_image = tex_image
    for step in range(time_step+1):
        prev_image = diff_image
        diff_image = diffuser(diff_image, step)
    
    for i in range(10):
        discriminator_optimizer.zero_grad()
        rec_image = diffuser.backward(diff_image, time_step).detach()
        rec_patches = patch_extractor(rec_image)
        loss = -torch.mean(ot_discriminator(rec_patches))
        loss.backward()
        discriminator_optimizer.step()

    # reconstructed image 
    rec_image = diffuser.backward(diff_image, time_step)
    rec_patches = patch_extractor(rec_image)
    model_optimizer.zero_grad()
    loss = torch.mean(ot_discriminator(rec_patches))
    loss.backward()
    model_optimizer.step()

    iteration+=1
    print(f'iteration {iteration}, loss {loss.item()}')

    if iteration%monitoring_step==0:
        print(f'saving image at iteration {iteration}, loss {loss.item()}')
        torchvision.utils.save_image(rec_image, f'tmp/{iteration}_reconstructed.png')
        torchvision.utils.save_image(prev_image, f'tmp/{iteration}_target.png')



iteration 1, loss 0.1322564333677292
iteration 2, loss 0.15140537917613983
iteration 3, loss 0.04009786248207092
iteration 4, loss 0.10245775431394577
iteration 5, loss 0.14218223094940186
iteration 6, loss 0.043048031628131866
iteration 7, loss 0.04485150799155235
iteration 8, loss 0.1379941701889038
iteration 9, loss 0.09324098378419876
iteration 10, loss 0.048705849796533585
saving image at iteration 10, loss 0.048705849796533585
iteration 11, loss 0.11837611347436905
iteration 12, loss 0.09047222137451172
iteration 13, loss 0.15767505764961243
iteration 14, loss 0.08958911895751953
iteration 15, loss 0.13307610154151917
iteration 16, loss 0.055992480367422104
iteration 17, loss 0.05780224874615669
iteration 18, loss 0.059573911130428314
iteration 19, loss 0.11596029251813889
iteration 20, loss 0.13313181698322296
saving image at iteration 20, loss 0.13313181698322296
iteration 21, loss 0.06177684664726257
iteration 22, loss 0.06727024912834167
iteration 23, loss 0.08294971287250519

In [60]:
diff_image = tex_image
for s in range(num_step):
    step = torch.tensor(s, device = device)
    diff_image = diffuser(diff_image, step)

# print(diff_image)
for step in range(num_step):
    rec_image = diffuser.backward(diff_image, step) + 0.1*torch.randn_like(diff_image)
    diff_image = rec_image
    torchvision.utils.save_image(rec_image, f'tmp/reconstructed_{step}.png')

In [18]:
tex_image.shape

torch.Size([3, 128, 128])

In [43]:
torch.mean(tex_image)

tensor(0.1783, device='cuda:0')

In [62]:
uvxyz = torch.rand(1,1,1,5)
print(uvxyz)
uv = uvxyz[...,0:2]
ux = uvxyz[...,(0,2)]

print(uv)
print(ux)


tensor([[[[0.2270, 0.0695, 0.4281, 0.9451, 0.8638]]]])
tensor([[[[0.2270, 0.0695]]]])
tensor([[[[0.2270, 0.4281]]]])


In [67]:
test = []
test+= torch.zeros(1,2,2)

In [68]:
test

[tensor([[0., 0.],
         [0., 0.]])]