# Diffusion 개량
1. 커스텀 학습 대신 커스텀 데이터셋을 이용한 디퓨전 학습


# 사전준비

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
import tensorflow as tf
import os
from PIL import Image
import numpy as np

# Dataset and checkpoint locations (machine-specific Windows paths).
pathTrain = r"C:\Users\82109\Desktop\train"  # GIFs used for training
pathTest = r"C:\Users\82109\Desktop\test"  # GIFs used for validation
# pathGeneratorTest = r"C:\Users\82109\Desktop\generatorTest"  # images for generation tests
pathSave = r"C:\Users\82109\Desktop\checkPoint"  # where model checkpoints are written

# Number of GIF clips per batch in the tf.data pipelines.
batchSize = 1

In [None]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    # Memory growth must be configured for every GPU before the TensorFlow
    # runtime initializes its devices. Listing logical devices triggers that
    # initialization, so it must run after the loop: inside the loop, the
    # second GPU's set_memory_growth call would raise RuntimeError.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.experimental.list_logical_devices('GPU')

# 이미지 전처리 함수 모음

In [None]:
def GetFilePath(path, end=".gif"):
    """Return the full paths of the entries in *path* whose names end with *end*.

    Args:
        path: Directory to list (not searched recursively).
        end: A single suffix string such as ``".gif"``, or an iterable of
            suffix strings such as ``(".gif", ".GIF")``.

    Returns:
        List of ``os.path.join(path, name)`` strings, in ``os.listdir`` order.
    """
    # tuple(".gif") would be ('.', 'g', 'i', 'f'), i.e. "any name ending in one
    # of those characters", which also matches ".png", ".jpg", ... .
    # A lone string must therefore be wrapped, not converted.
    suffixes = (end,) if isinstance(end, str) else tuple(end)
    return [
        os.path.join(path, name)
        for name in os.listdir(path)
        if name.endswith(suffixes)
    ]

def PreprocessGif(path, frame=5, maxSize=512):
    """Delete *path* if the GIF is unusable for training, and print the verdict.

    A GIF is unusable when it has fewer than *frame* frames or when its longer
    side exceeds *maxSize* pixels (the original author notes that images that
    are too large cause an overflow).

    Args:
        path: GIF file to check. It is removed from disk with ``os.remove``
            when it fails the check.
        frame: Minimum number of frames required.
        maxSize: Maximum allowed length, in pixels, of the longer image side.
    """
    # The context manager releases the file handle even if reading the
    # header raises; the handle is closed before any os.remove below.
    with Image.open(path) as gif:
        targetFrame = gif.n_frames
        targetSize = max(gif.size)  # longer of width / height

    if targetFrame < frame or targetSize > maxSize:
        print(path, ": ", targetFrame, " ", targetSize, "사용불가능")
        os.remove(path=path)
    else:
        print(path, ": ", targetFrame, " ", targetSize, " 사용가능")


#gif를 읽고 넘파이 배열로 노멀라이즈해줌
def LoadGif(path, paddingSize=32):
  gif=Image.open(path)
  flip=np.random.randint(0,7)
  extractFrame=np.random.randint(4,11)
  remainFrame=gif.n_frames-extractFrame
  start=0
  end=0
  if remainFrame<=1:
    start=1
    end=gif.n_frames
  else:
    start=np.random.randint(1,remainFrame+1)
    end=start+extractFrame
  images=[]
  for i in range(start,end):
    gif.seek(i)
    temp=gif.transpose(flip).convert("RGBA")
    temp=np.array(temp)
    height=(paddingSize-temp.shape[0]%paddingSize)%paddingSize
    width=(paddingSize-temp.shape[1]%paddingSize)%paddingSize
    temp=np.pad(temp, pad_width=((0,height),(0,width),(0,0)),mode="constant",constant_values=0)
    images.append(temp)
  gif.close()
  return np.array(images)/255.0


def LoadGifAll(path, paddingSize=32):
  gif=Image.open(path)
  flip=np.random.randint(0,7)

  images=[]
  for i in range(1,gif.n_frames):
    gif.seek(i)
    temp=gif.transpose(flip).convert("RGBA")
    temp=np.array(temp)
    height=(paddingSize-temp.shape[0]%paddingSize)%paddingSize
    width=(paddingSize-temp.shape[1]%paddingSize)%paddingSize
    temp=np.pad(temp, pad_width=((0,height),(0,width),(0,0)),mode="constant",constant_values=0)
    images.append(temp)
  gif.close()
  return np.array(images)/255.0


def Divide(arr):
    """Split a frame sequence into interleaved (input, target) halves.

    Even-indexed frames form the first half and odd-indexed frames the second.
    When the length is odd, the final even-indexed frame is dropped so both
    halves have the same length.

    Returns:
        ``[evens, odds]``: a two-element list of array slices.
    """
    usable = arr.shape[0] - arr.shape[0] % 2
    trimmed = arr[:usable]
    return [trimmed[::2], trimmed[1::2]]
  
def diffusionSchedule(diffusionTime, maxSignalRate=0.99, minSignalRate=0.1):
    """Cosine schedule: map a diffusion time in [0, 1] to (signal, noise) rates.

    The angle is interpolated linearly from ``arccos(maxSignalRate)`` at time 0
    to ``arccos(minSignalRate)`` at time 1. The rates are its cosine and sine,
    so ``sigRate**2 + noiseRate**2 == 1`` at every time.

    With the defaults:
      * time 0 -> sigRate 0.99, noiseRate sqrt(1 - 0.99**2) ~= 0.141
      * time 1 -> sigRate 0.1,  noiseRate sqrt(1 - 0.1**2)  ~= 0.995

    Args:
        diffusionTime: Scalar or NumPy array of times; larger means noisier.
        maxSignalRate: Signal rate at time 0.
        minSignalRate: Signal rate at time 1.

    Returns:
        ``(sigRate, noiseRate)``, each shaped like *diffusionTime*.
    """
    startAng = np.arccos(maxSignalRate)
    endAng = np.arccos(minSignalRate)
    diffusionAng = startAng + diffusionTime * (endAng - startAng)
    sigRate = np.cos(diffusionAng)
    noiseRate = np.sin(diffusionAng)
    return sigRate, noiseRate
        
  
def DatasetGenerater(gifPath):
    """Yield one (model input, target noise) training pair per GIF path.

    For each GIF a random clip is loaded with ``LoadGif`` and split by
    ``Divide`` into conditioning frames ``x`` and target frames ``y``. One
    random diffusion time is drawn per clip, and ``y`` is noised with the rates
    from ``diffusionSchedule``.

    Args:
        gifPath: Iterable of GIF file paths.

    Yields:
        ``(inputs, noise)``. ``inputs`` concatenates on the last axis ``x``
        (RGBA, 4 channels), the noisy ``y`` (4 channels) and one channel
        filled with the signal rate (9 channels total). ``noise`` is the
        noise that was added to ``y``.
    """
    for path in gifPath:
        x, y = Divide(LoadGif(path))
        if len(y) == 0:
            # No (input, target) frame pair could be formed from this GIF.
            continue
        # Standard-normal noise: the schedule keeps sig**2 + noise**2 == 1,
        # which only preserves variance for zero-mean, unit-variance noise.
        # The previous np.random.rand drew biased uniform [0, 1) noise.
        # NOTE(review): any sampling code must start from the same distribution.
        noise = np.random.standard_normal(y.shape)
        sigRate, noiseRate = diffusionSchedule(np.random.rand())
        noisyImage = sigRate * y + noiseRate * noise
        # Extra input channel telling the network the current signal rate.
        step = np.full(x.shape[:-1] + (1,), sigRate)
        yield np.concatenate([x, noisyImage, step], axis=-1), noise

def saveGif(path, images):
  imgs=[]
  for i in images:
    img=Image.fromarray((i*255).round().astype(np.int8), mode="RGBA")
    imgs.append(img)
  imgs[0].save(path, save_all=True, append_images=imgs[1:], disposal = 2,duration=150, loop=0)


# U-Net 모델 생성 함수 모음

In [None]:
def seperableConv(filter, input):
    """Depthwise-separable 3D convolution built from two ``Conv3D`` layers.

    A 3x3x3 grouped convolution with one group per input channel (depthwise)
    is followed by a 1x1x1 convolution that mixes the channels into *filter*
    outputs (pointwise). Both use "same" padding and no activation.
    """
    channels = input.shape[-1]
    spatial = tf.keras.layers.Conv3D(channels, 3, padding="same", groups=channels)
    mixer = tf.keras.layers.Conv3D(filter, 1, padding="same")
    return mixer(spatial(input))
    
def block(filter, input):
    """Two stages of (separable conv -> LayerNormalization -> swish) with *filter* channels."""
    x = input
    for _ in range(2):
        x = seperableConv(filter, x)
        x = tf.keras.layers.LayerNormalization()(x)
        x = tf.keras.layers.Activation("swish")(x)
    return x

def unetModel(inputShape=(None, None, None, 9),
              encoderFilters=(32, 64, 128, 256),
              bottleneckFilters=512,
              outputChannels=4):
    """Build a U-Net of 3D separable convolutions that rescales only the last two spatial axes.

    Inputs are 5D ``(batch, d1, d2, d3, channels)``. Pooling and upsampling
    use ``(1, 2, 2)``, so ``d1`` (the frame axis of the clips built by
    ``DatasetGenerater``) is never resized. ``d2``/``d3`` must be divisible by
    ``2 ** len(encoderFilters)`` (16 by default); otherwise the skip
    concatenations will not line up.

    Args:
        inputShape: Per-sample input shape (without the batch axis).
        encoderFilters: Channel width of each encoder level, shallowest first;
            the decoder mirrors them in reverse.
        bottleneckFilters: Channel width of the bottleneck block.
        outputChannels: Channels of the (linear) output layer.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    encoderFilters = tuple(encoderFilters)
    inputImage = tf.keras.Input(shape=inputShape)

    # Encoder: conv block, keep it for the skip connection, then halve d2/d3.
    # Layers are created in the same order as the original hand-unrolled
    # version, so layer names and checkpoints stay compatible.
    skips = []
    x = inputImage
    for filters in encoderFilters:
        x = block(filters, x)
        skips.append(x)
        x = tf.keras.layers.MaxPooling3D(pool_size=(1, 2, 2))(x)

    x = block(bottleneckFilters, x)

    # Decoder: upsample, project with a separable conv, concat the skip, conv block.
    for filters, skip in zip(reversed(encoderFilters), reversed(skips)):
        x = tf.keras.layers.UpSampling3D(size=(1, 2, 2))(x)
        x = seperableConv(filters, x)
        x = tf.keras.layers.Concatenate()([x, skip])
        x = block(filters, x)

    outputImage = seperableConv(outputChannels, x)
    return tf.keras.Model(inputImage, outputImage)


unetModel().summary()

# 학습 (학습할 때만 사용)

In [None]:

def _prepareGifPaths(path):
    """Delete unusable GIFs in *path* via PreprocessGif and return the remaining GIF paths."""
    for gif in GetFilePath(path):
        PreprocessGif(gif)
    return GetFilePath(path)


def _makeDataset(paths):
    """Wrap DatasetGenerater over *paths* in a shuffled, batched, prefetched tf.data pipeline.

    Elements are ``(inputs, noise)`` with per-sample shapes
    ``(None, None, None, 9)`` and ``(None, None, None, 4)``, cast to float32.
    """
    dataset = tf.data.Dataset.from_generator(
        DatasetGenerater,
        args=[paths],
        # output_signature replaces the deprecated output_types/output_shapes pair.
        output_signature=(
            tf.TensorSpec(shape=(None, None, None, 9), dtype=tf.float32),
            tf.TensorSpec(shape=(None, None, None, 4), dtype=tf.float32),
        ),
    )
    return dataset.shuffle(5).batch(batchSize).prefetch(tf.data.AUTOTUNE)


print("")
print("트레인셋 전처리")
gifPath = _prepareGifPaths(pathTrain)

print("")
print("총 트레인셋 갯수", len(gifPath))

print("")
print("테스트셋 전처리")
gifPathTest = _prepareGifPaths(pathTest)

print("")
print("총 테스트셋 갯수", len(gifPathTest))

# (inputImages, outputImages)
trainDataset = _makeDataset(gifPath)
testDataset = _makeDataset(gifPathTest)



In [None]:
# Build the model and resume from the saved weights if that checkpoint exists.
model = unetModel(inputShape=(None, None, None, 9))
resumeCheckpoint = os.path.join(pathSave, "1cp-0001.ckpt")
# TensorFlow-format weight checkpoints are a path prefix; "<prefix>.index"
# exists only when the checkpoint was actually written.
if tf.io.gfile.exists(resumeCheckpoint + ".index"):
    model.load_weights(resumeCheckpoint)
else:
    print("checkpoint not found, training from scratch:", resumeCheckpoint)

# Callbacks: save weights every epoch, stop early, and decay the LR on plateaus.
cpCallback = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(pathSave, "cp-{epoch}.ckpt"),
    verbose=1,
    save_weights_only=True,
)
esCallback = tf.keras.callbacks.EarlyStopping(patience=5)
rlrCallback = tf.keras.callbacks.ReduceLROnPlateau(
    factor=0.1,
    patience=3,
    min_lr=1e-5,
)

# Compile. The target is continuous noise, so classification "accuracy" is
# meaningless; mean absolute error is a readable companion to the MSE loss.
# NOTE(review): 1e-2 is a high AdamW learning rate; confirm it is intended.
model.compile(
    tf.keras.optimizers.experimental.AdamW(
        learning_rate=1e-2, weight_decay=1e-4
    ),
    loss=tf.keras.losses.mean_squared_error,
    metrics=["mean_absolute_error"],
)


In [None]:
# Train on the GIF pipeline, validating on the held-out set after each epoch.
model.fit(
    trainDataset,
    validation_data=testDataset,
    epochs=10,
    callbacks=[cpCallback, esCallback, rlrCallback],
)