In [48]:
import torch

from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

import math
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import ignite.metrics as ig

import torchvision
from torchvision import datasets, transforms

from matplotlib import pyplot as plt
from IPython.display import clear_output

import os
from PIL import Image

import time

In [2]:
# Number of images to load from the recording folder.
teach_im_count = 160
# Mini-batch size handed to the DataLoader.
batch_size = 8
# Index of the first listed file to load; files before it are skipped.
begin_offset = 0

In [3]:
folder_path = 'Record_1'

# Collect the PNG files in the folder. os.listdir returns entries in arbitrary
# order, so sort the names to make the selected window reproducible between runs.
# (The sort is lexicographic on the file name.)
image_files = sorted(f for f in os.listdir(folder_path) if f.endswith('.png'))

# Keep `teach_im_count` files starting at position `begin_offset`.
selected_files = image_files[begin_offset:begin_offset + teach_im_count]

# Load every selected image and convert it to a numpy array.
image_arrays = []
for image_file in selected_files:
    image_path = os.path.join(folder_path, image_file)
    # The context manager closes the file handle as soon as the pixels are copied.
    with Image.open(image_path) as img:
        image_arrays.append(np.array(img))

# Stack the per-image arrays into a single numpy array.
record_array = np.array(image_arrays)

print(record_array.shape)

class OneImDataset(Dataset):
    """Dataset over an indexable collection of images.

    Every item is returned after being passed through torchvision's
    ``ToTensor`` transform; there are no labels.
    """

    def __init__(self, x_arr):
        # Keep a reference to the image collection; nothing is copied here.
        self.x = x_arr
        # Single-step pipeline applied to each item on access.
        self.transform1 = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        # One sample per entry of the wrapped collection.
        return len(self.x)

    def __getitem__(self, idx):
        sample = self.x[idx]
        return self.transform1(sample)
        
# Wrap the loaded frames in a dataset and serve them as shuffled mini-batches.
dataset = OneImDataset(record_array)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Drop the notebook-level name; the dataset keeps its own reference to the array.
record_array = None

(160, 420, 420, 3)


In [16]:
def show_images_layer(images, size):
    """Show the first ``size`` images side by side in a single row.

    Args:
        images: Indexable collection of tensors. Each one is transposed with
            ``(1, 2, 0)`` before display, so a channel-first (C, H, W) layout
            is expected.
        size: Number of images to draw (one subplot column per image).
    """
    # squeeze=False keeps `axes` a 2-D array even when size == 1; otherwise
    # plt.subplots returns a bare Axes object that has no .ravel().
    f, axes = plt.subplots(1, size, figsize=(30, 10), squeeze=False)
    for i, ax in enumerate(axes.ravel()):
        # .cpu() lets this accept tensors that live on an accelerator, the same
        # way t_to_im does.
        image = images[i].detach().cpu().numpy()
        ax.imshow(image.transpose((1, 2, 0)))
    plt.show()

In [5]:
def t_to_im(tensor):
    """Return ``tensor`` as a numpy array with its axes reordered to (1, 2, 0).

    The tensor is moved to the CPU and detached from the autograd graph first.
    """
    array = tensor.cpu().detach().numpy()
    return array.transpose(1, 2, 0)

In [6]:
def show_images_real(orig, record, real, size):
    """Plot three rows of images: ``orig`` on top, then ``record``, then ``real``.

    Args:
        orig, record, real: Indexable collections of tensors. Each image is
            transposed with ``(1, 2, 0)`` before display, so a channel-first
            (C, H, W) layout is expected.
        size: Number of images drawn per row (subplot columns).
    """
    # squeeze=False guarantees a (3, size) array of axes for every size.
    f, axes = plt.subplots(3, size, figsize=(30, 16), squeeze=False)
    # One row of axes per image source, replacing three copy-pasted branches.
    for row_axes, batch in zip(axes, (orig, record, real)):
        for col, ax in enumerate(row_axes):
            # .cpu() lets this accept tensors that live on an accelerator, the
            # same way t_to_im does.
            image = batch[col].detach().cpu().numpy()
            ax.imshow(image.transpose((1, 2, 0)))
    plt.show()

In [26]:
def show_images_all(images, col_count, size):
    """Show the first ``size`` images on a grid with ``col_count`` columns.

    Args:
        images: Indexable collection of tensors accepted by ``t_to_im``.
        col_count: Number of subplot columns.
        size: Number of images to display.
    """
    # Round the row count up so a trailing partial row is still shown; floor
    # division silently dropped those images and produced zero rows when
    # size < col_count.
    rows = (size + col_count - 1) // col_count
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    f, axes = plt.subplots(rows, col_count, figsize=(30, 9 * rows), squeeze=False)
    for i, ax in enumerate(axes.ravel()):
        if i < size:
            ax.imshow(t_to_im(images[i]))
        else:
            # Hide the unused cells of the last row.
            ax.axis('off')
    plt.show()

In [8]:
def add_noise_linear1(image, t, T, beta_0=0.0001, beta_T=0.02):
    """Noise ``image`` to step ``t`` of a linear beta schedule with ``T`` steps.

    The betas are ``T`` values evenly spaced from ``beta_0`` to ``beta_T``. With
    ``alpha_bar`` the product of ``1 - beta`` over the first ``t`` betas, the
    result is ``sqrt(alpha_bar) * image + sqrt(1 - alpha_bar) * noise``, where
    ``noise`` is standard normal with the shape of ``image``.

    Args:
        image: Tensor to be noised.
        t: Number of schedule steps applied; ``t == 0`` returns the image
            unchanged because the empty product is 1.
        T: Total number of steps in the schedule.
        beta_0: First beta of the schedule.
        beta_T: Last beta of the schedule.

    Returns:
        Tensor with the same shape as ``image``.
    """
    betas = np.linspace(beta_0, beta_T, T)
    # Convert to a plain Python float so the arithmetic below is ordinary
    # scalar-times-tensor maths and keeps the dtype of `image`.
    alpha_bar = float(np.prod(1 - betas[0:t]))
    noise = torch.randn_like(image)
    return alpha_bar ** 0.5 * image + (1 - alpha_bar) ** 0.5 * noise

In [38]:
def add_noise_trig(image, t, T):
    """Noise ``image`` by mixing it with Gaussian noise along an angle schedule.

    The angle grows linearly with ``t / T`` from ``arccos(1)`` to
    ``arccos(0.02)``; the image is weighted by the cosine of that angle and the
    noise by its sine.
    """
    # Cosines of the first and last angle, i.e. the image weight at t = 0 and t = T.
    cos_at_start = 1
    cos_at_end = 0.02
    first_angle = np.arccos(cos_at_start)
    last_angle = np.arccos(cos_at_end)

    progress = t / T
    theta = first_angle + progress * (last_angle - first_angle)

    eps = torch.randn_like(image)
    return np.cos(theta) * image + np.sin(theta) * eps

In [18]:
# Grab a single batch to experiment with. The previous loop walked the whole
# DataLoader only to keep its last batch; taking the first one avoids loading
# every image. As before, `x` is None when the loader is empty.
x = next(iter(dataloader), None)

In [19]:
# Take the first sample of the batch fetched in the previous cell.
image = x[0]

In [13]:
# Inspect the dimensions of the selected sample.
image.shape

torch.Size([3, 420, 420])

In [20]:
# Noise the sample at every step 0 .. T-1 of the linear schedule.
T = 100
im_array = [add_noise_linear1(image, step, T) for step in range(T)]

In [None]:
# Pass the actual list length instead of repeating the literal 100.
show_images_all(im_array, 5, len(im_array))

In [39]:
# Noise the sample at every step 0 .. T-1 of the trigonometric schedule.
T = 100
im_array = [add_noise_trig(image, step, T) for step in range(T)]

In [None]:
# Pass the actual list length instead of repeating the literal 100.
show_images_all(im_array, 5, len(im_array))

In [41]:
def sinusoidal_embedding(timesteps, dim):
    """Build sinusoidal embeddings for a 1-D tensor of timesteps.

    The first ``dim // 2`` columns are sines and the next ``dim // 2`` columns
    are cosines of ``timesteps * freq``, with frequencies spaced geometrically
    from 1 down to 1/10000.

    Args:
        timesteps: 1-D tensor of timesteps.
        dim: Width of the embedding. For odd ``dim`` a zero column is appended
            so the result always has exactly ``dim`` columns.

    Returns:
        Tensor of shape ``(len(timesteps), dim)`` on the device of ``timesteps``.
    """
    half_dim = dim // 2
    # max(..., 1) avoids a ZeroDivisionError when half_dim == 1 (dim 2 or 3).
    scale = math.log(10000) / max(half_dim - 1, 1)
    # Build the frequencies on the same device as the timesteps so that
    # accelerator-resident inputs do not hit a device mismatch.
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -scale
    )
    args = timesteps[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2 == 1:
        # sin/cos only fill 2 * half_dim columns; pad to the requested width.
        emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=-1)
    return emb

In [44]:
# Eight random timesteps drawn uniformly from [0, 1).
t = torch.rand(8)

In [49]:
# 320-dimensional sinusoidal embedding for each timestep in `t`.
emb = sinusoidal_embedding(t, 320)

In [45]:
# Display the sampled timesteps.
t

tensor([0.2124, 0.4290, 0.6410, 0.9734, 0.6906, 0.4237, 0.1744, 0.4501])

In [51]:
# Check the embedding dimensions: one row per timestep.
emb.shape

torch.Size([8, 320])

In [53]:
# Append two singleton axes to the embedding and check the resulting shape.
emb[:,:,None,None].shape

torch.Size([8, 320, 1, 1])

In [54]:
# Random image-shaped batch. NOTE: this rebinds `t`, which held the timesteps
# in the earlier cells.
t = torch.randn((8, 3, 420, 420))

In [56]:
# Broadcasting check: this addition fails (see the RuntimeError below) because
# dimension 1 is 3 for `t` but 320 for the embedding.
t + emb[:,:,None,None]

RuntimeError: The size of tensor a (3) must match the size of tensor b (320) at non-singleton dimension 1

In [67]:
class Swish(nn.Module):
    """Activation module computing ``x * sigmoid(x)`` element-wise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

In [68]:
class TimeEmbedding(nn.Module):
    """Map a batch of timesteps to one ``n_channels``-wide vector per sample.

    The timesteps are encoded with ``half_dim`` sine and ``half_dim`` cosine
    features and then passed through Linear -> Swish -> Linear.

    Args:
        n_channels: Width of the returned embedding.
    """

    def __init__(self, n_channels: int):
        super().__init__()
        self.n_channels = n_channels
        # Number of sine (and of cosine) features. Use at least one frequency:
        # for n_channels < 8 the previous value was 0, which made lin1 a
        # Linear(0, n) whose output did not depend on t at all.
        self.half_dim = max(n_channels // 8, 1)
        # Size lin1 from the features forward() really produces. The former
        # `n_channels // 4` only equals 2 * (n_channels // 8) when it is even.
        self.lin1 = nn.Linear(2 * self.half_dim, self.n_channels)
        self.act = Swish()
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        """Return the embedding for ``t``.

        Args:
            t: 1-D tensor of timesteps (integer or floating point).

        Returns:
            Tensor of shape ``(len(t), n_channels)``.
        """
        half_dim = self.half_dim
        # max(..., 1) avoids a ZeroDivisionError when there is one frequency.
        scale = math.log(10_000) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, dtype=torch.float32, device=t.device) * -scale
        )
        # Cast so integer or float64 timesteps match the dtype lin1 expects.
        emb = t[:, None].to(freqs.dtype) * freqs[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)

        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)
        return emb

In [95]:
def compress(in_channels1, in_channels2, in_channels3, out_channels, ks1, ks2, ks3, size):
    """Build an encoder stage: six Conv2d+ReLU pairs followed by 2x max pooling.

    Channel flow: in_channels1 -> in_channels2 -> in_channels2 -> in_channels3
    -> in_channels3 -> out_channels -> out_channels.

    The fourth and sixth convolutions used to declare ``in_channels2`` and
    ``in_channels3`` as their input widths although the preceding layers emit
    ``in_channels3`` and ``out_channels``; the block therefore only worked when
    all three widths were equal. Every call in this file passes equal widths,
    for which the layer stack is unchanged.

    Args:
        in_channels1: Channels of the input tensor.
        in_channels2: Width after the first two convolutions.
        in_channels3: Width after the third and fourth convolutions.
        out_channels: Channels of the output tensor.
        ks1: Kernel size of the first convolution (padding 0).
        ks2: Kernel size of the second convolution (padding 1, so any kernel
            larger than 3 shrinks the feature map).
        ks3: Unused; kept for call compatibility.
        size: Unused; kept for call compatibility.

    Returns:
        nn.Sequential implementing the stage.
    """
    block = nn.Sequential(
        nn.Conv2d(in_channels=in_channels1, out_channels=in_channels2, kernel_size=ks1, padding=0),
        nn.ReLU(),
        nn.Conv2d(in_channels=in_channels2, out_channels=in_channels2, kernel_size=ks2, padding=1),
        nn.ReLU(),
        nn.Conv2d(in_channels=in_channels2, out_channels=in_channels3, kernel_size=5, padding=2),
        nn.ReLU(),
        # Input width must match the in_channels3 produced by the layer above.
        nn.Conv2d(in_channels=in_channels3, out_channels=in_channels3, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(in_channels=in_channels3, out_channels=out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        # Input width must match the out_channels produced by the layer above.
        nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2),
        # Kept so the layer indices stay the same; it is a no-op on the
        # already non-negative pooled activations.
        nn.ReLU(),
    )

    return block
    
def upsample_without_batchnorm(in_channels1, in_channels2, in_channels3, out_channels, ks1, ks2, ks3, scale, size):
    """Build a decoder stage: resize to ``size``, then six Conv2d+ReLU pairs.

    The first convolution maps ``in_channels1`` to ``in_channels2`` with kernel
    ``ks1`` and no padding; the remaining five keep ``in_channels2`` channels
    with kernels ``ks2``, 5, 3, 3, 3 and paddings 1, 2, 1, 1, 1.

    NOTE(review): ``in_channels3``, ``out_channels``, ``ks3`` and ``scale`` are
    accepted but never used, so the stage outputs ``in_channels2`` channels
    rather than ``out_channels`` -- confirm whether that is intended.
    """
    width = in_channels2
    layers = [
        nn.Upsample(size=size),
        nn.Conv2d(in_channels=in_channels1, out_channels=width, kernel_size=ks1, padding=0),
        nn.ReLU(),
    ]
    # (kernel_size, padding) of the five width-preserving convolutions, in order.
    for kernel, pad in ((ks2, 1), (5, 2), (3, 1), (3, 1), (3, 1)):
        layers.append(nn.Conv2d(in_channels=width, out_channels=width, kernel_size=kernel, padding=pad))
        layers.append(nn.ReLU())

    return nn.Sequential(*layers)

# Channel widths: input image first, then one entry per encoder stage.
ch_size = [3, 32, 64, 128, 256, 512]

class Model8(nn.Module):
    """U-Net style network with a time embedding added before each encoder stage.

    Five ``compress`` stages halve the resolution (the size comments below
    assume a 420x420 input); five ``upsample_without_batchnorm`` stages resize
    back and are concatenated with the matching encoder outputs.

    NOTE(review): ``upsample_without_batchnorm`` emits ``in_channels2``
    channels, so the returned tensor has ``ch_size[1]`` (32) channels rather
    than ``ch_size[0]`` (3), and it is non-negative because the last layer is a
    ReLU. Confirm whether a final projection back to image channels is missing.
    """

    def __init__(self):
        super().__init__()
        self.time_emb1 = TimeEmbedding(3)
        self.c1 = compress(ch_size[0], ch_size[1], ch_size[1], ch_size[1], 1, 3, 3, 420)  # 420 -> 210
        self.time_emb2 = TimeEmbedding(32)
        self.c2 = compress(ch_size[1], ch_size[2], ch_size[2], ch_size[2], 1, 5, 3, 208)  # 210 -> 208 -> 104
        self.time_emb3 = TimeEmbedding(64)
        self.c3 = compress(ch_size[2], ch_size[3], ch_size[3], ch_size[3], 1, 3, 3, 208)  # 104 -> 52
        self.time_emb4 = TimeEmbedding(128)
        self.c4 = compress(ch_size[3], ch_size[4], ch_size[4], ch_size[4], 1, 3, 2, 208)  # 52 -> 26
        self.time_emb5 = TimeEmbedding(256)
        self.c5 = compress(ch_size[4], ch_size[5], ch_size[5], ch_size[5], 1, 3, 2, 208)  # 26 -> 13

        # Each decoder stage outputs its second argument as channel count. After
        # concatenation with the skip tensor the next stage therefore receives
        # 2 * ch + ch = 3 * ch channels.
        self.d1 = upsample_without_batchnorm(ch_size[5], ch_size[5], ch_size[5], ch_size[4], 1, 3, 3, 2.5, 26)  # 13 -> 26
        self.d2 = upsample_without_batchnorm(3 * ch_size[4], ch_size[4], ch_size[4], ch_size[3], 1, 3, 3, 2.5, 52)  # 26 -> 52
        self.d3 = upsample_without_batchnorm(3 * ch_size[3], ch_size[3], ch_size[3], ch_size[2], 1, 3, 3, 2.18, 104)  # 52 -> 104
        self.d4 = upsample_without_batchnorm(3 * ch_size[2], ch_size[2], ch_size[2], ch_size[1], 1, 3, 3, 2.08, 210)  # 104 -> 210
        self.d5 = upsample_without_batchnorm(3 * ch_size[1], ch_size[1], ch_size[1], ch_size[0], 1, 3, 3, 2.04, 420)  # 210 -> 420

    def forward(self, x, t):
        """Run the network.

        Args:
            x: Batch of images with ``ch_size[0]`` channels. It is no longer
                modified in place.
            t: 1-D tensor with one timestep per sample.

        Returns:
            Output of the last decoder stage.
        """
        # Negative timesteps are clamped to zero, as before; computed once
        # because the value is the same for every stage.
        t = F.relu(t)

        # The adds below are out-of-place on purpose. The previous `x += ...`
        # (a) overwrote the caller's input tensor, (b) modified the skip tensors
        # through aliasing (x1 was the same object as x), and (c) changed ReLU
        # outputs in place, which autograd rejects during backward().
        x = x + self.time_emb1(t)[:, :, None, None]
        x1 = self.c1(x)

        x = x1 + self.time_emb2(t)[:, :, None, None]
        x2 = self.c2(x)

        x = x2 + self.time_emb3(t)[:, :, None, None]
        x3 = self.c3(x)

        x = x3 + self.time_emb4(t)[:, :, None, None]
        x4 = self.c4(x)

        x = x4 + self.time_emb5(t)[:, :, None, None]
        x = self.c5(x)

        # Decoder: resize, then concatenate with the encoder output of equal size.
        x = self.d1(x)
        x = torch.cat((x, x4), 1)

        x = self.d2(x)
        x = torch.cat((x, x3), 1)

        x = self.d3(x)
        x = torch.cat((x, x2), 1)

        x = self.d4(x)
        x = torch.cat((x, x1), 1)

        x = self.d5(x)
        return x

In [None]:
def conv(in_channels, out_channels):
    """Build a three-convolution block that preserves the spatial size.

    A 1x1 convolution changes the channel count from ``in_channels`` to
    ``out_channels``; two 3x3 convolutions with padding 1 follow. Every
    convolution is followed by a ReLU.
    """
    layers = [
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
        nn.ReLU(),
    ]
    for _ in range(2):
        layers += [
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
    return nn.Sequential(*layers)

def time_emb(size, channels):
    """Build an MLP mapping one scalar to ``size ** 2 * channels`` values.

    Layers: Linear(1, width) -> ReLU -> Linear(width, width) -> Tanh, where
    ``width = size ** 2 * channels``.
    """
    width = size ** 2 * channels
    return nn.Sequential(
        nn.Linear(1, width),
        nn.ReLU(),
        nn.Linear(width, width),
        nn.Tanh(),
    )



class DifUNet(nn.Module):
    """U-Net whose bottleneck is concatenated with a learned time-embedding map.

    The encoder applies five ``MaxPool2d(2)`` steps and the time embedding is
    reshaped to ``(3, emb_size, emb_size)`` before being concatenated with the
    bottleneck features. The input height and width must therefore equal
    ``32 * emb_size`` (320 for the default), otherwise the concatenations fail.

    Args:
        emb_size: Spatial size of the bottleneck / time-embedding map. Defaults
            to 10, the value that used to be hard-coded.
    """

    def __init__(self, emb_size=10):
        super().__init__()
        ch = [3, 32, 64, 128, 256, 512, 1024]
        self.emb_size = emb_size
        self.mp = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor = 2)

        # Encoder blocks.
        self.c1 = conv(ch[0], ch[1])
        self.c2 = conv(ch[1], ch[2])
        self.c3 = conv(ch[2], ch[3])
        self.c4 = conv(ch[3], ch[4])
        self.c5 = conv(ch[4], ch[5])
        self.c6 = conv(ch[5], ch[6])

        # Maps a scalar timestep to a 3-channel emb_size x emb_size map.
        self.emb = time_emb(emb_size, 3)

        # Decoder blocks; the first one also receives the 3 embedding channels.
        self.cc1 = conv(3 + ch[6], ch[5])
        self.cc2 = conv(2 * ch[5], ch[4])
        self.cc3 = conv(2 * ch[4], ch[3])
        self.cc4 = conv(2 * ch[3], ch[2])
        self.cc5 = conv(2 * ch[2], ch[1])
        self.cc6 = conv(2 * ch[1], ch[0])

    def forward(self, x, t):
        """Run the network.

        Args:
            x: Batch of 3-channel images of size ``32 * emb_size`` per side.
            t: One timestep per sample, shaped ``(B,)`` or ``(B, 1)``.

        Returns:
            Tensor with 3 channels and the spatial size of ``x``.
        """
        # Encoder: keep each block's output for the skip connections.
        x = self.c1(x)
        x5 = x
        x = self.mp(x)

        x = self.c2(x)
        x4 = x
        x = self.mp(x)

        x = self.c3(x)
        x3 = x
        x = self.mp(x)

        x = self.c4(x)
        x2 = x
        x = self.mp(x)

        x = self.c5(x)
        x1 = x
        x = self.mp(x)

        x = self.c6(x)

        # The embedding MLP starts with Linear(1, ...), so it needs one feature
        # per sample. Reshaping to (B, 1) makes a plain (B,) timestep tensor
        # work; previously it was fed as-is and failed on the feature dimension.
        t = self.emb(t.to(torch.float32).reshape(-1, 1))
        t = t.view(-1, 3, self.emb_size, self.emb_size)
        x = torch.cat((t, x), dim = 1)

        # Decoder: convolve, upsample, then concatenate the matching skip.
        x = self.cc1(x)
        x = self.up(x)

        x = torch.cat((x1, x), dim = 1)
        x = self.cc2(x)
        x = self.up(x)

        x = torch.cat((x2, x), dim = 1)
        x = self.cc3(x)
        x = self.up(x)

        x = torch.cat((x3, x), dim = 1)
        x = self.cc4(x)
        x = self.up(x)

        x = torch.cat((x4, x), dim = 1)
        x = self.cc5(x)
        x = self.up(x)

        x = torch.cat((x5, x), dim = 1)
        x = self.cc6(x)

        return x

In [58]:
# Random stand-in for a batch of eight 3x420x420 inputs.
tensor = torch.rand((8,3,420, 420))

In [72]:
# One integer timestep (16) per sample.
# NOTE(review): this rebinds `time`, shadowing the `time` module imported in the
# first cell; consider a different name.
time = torch.tensor([16]*8)

In [83]:
# Instantiate the model.
md = Model8()

In [None]:
# Re-create the model and run a single forward pass on the random batch.
md = Model8()
md(tensor, time)