Skip to content

Commit

Permalink
docs: change chinese comment to English
Browse files Browse the repository at this point in the history
  • Loading branch information
Kiteretsu77 committed Jan 2, 2024
1 parent e51790f commit e2eeced
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 69 deletions.
11 changes: 5 additions & 6 deletions Real_CuGAN/cunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def forward(self, x):
# x3 = F.leaky_relu(x3, 0.1, inplace=True)
# z = self.conv_bottom(x3)

#大概逻辑就是一个变小+conv+deconv,另外一个正常conv,然后两个加载一起并且deconv到大概两倍的样子
return z

def forward_a(self, x):
Expand Down Expand Up @@ -163,7 +162,7 @@ def __init__(self, in_channels, out_channels, deconv):
nn.init.constant_(m.bias, 0)

def forward(self, x, alpha = 1):
#整体就是unet结构
# UNet Structure
x1 = self.conv1(x)
x2 = self.conv1_down(x1)
x1 = F.pad(x1, (-16, -16, -16, -16))
Expand All @@ -185,28 +184,28 @@ def forward(self, x, alpha = 1):
z = self.conv_bottom(x5)
return z

def forward_a(self, x):#conv234结尾有se
def forward_a(self, x):
x1 = self.conv1(x)
x2 = self.conv1_down(x1)
x1 = F.pad(x1, (-16, -16, -16, -16))
x2 = F.leaky_relu(x2, 0.1, inplace=True)
x2 = self.conv2.conv(x2)
return x1,x2

def forward_b(self, x2): # conv234结尾有se
def forward_b(self, x2):
x3 = self.conv2_down(x2)
x2 = F.pad(x2, (-4, -4, -4, -4))
x3 = F.leaky_relu(x3, 0.1, inplace=True)
x3 = self.conv3.conv(x3)
return x2,x3

def forward_c(self, x2,x3): # conv234结尾有se
def forward_c(self, x2,x3):
x3 = self.conv3_up(x3)
x3 = F.leaky_relu(x3, 0.1, inplace=True)
x4 = self.conv4.conv(x2 + x3)
return x4

def forward_d(self, x1,x4): # conv234结尾有se
def forward_d(self, x1,x4):
x4 = self.conv4_up(x4)
x4 = F.leaky_relu(x4, 0.1, inplace=True)
x5 = self.conv5(x1 + x4)
Expand Down
4 changes: 2 additions & 2 deletions Real_CuGAN/upcunet_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ def __init__(self, unet_full_weight_path, adjust):


def forward(self, input):
x = F.pad(input, (18, 18, 18, 18), 'reflect') # pad最后一个倒数第二个dim各上下18个(总计36个)
x = F.pad(input, (18, 18, 18, 18), 'reflect') # (18, 18, 18, 18) is emperical padding hard-code

######################## Neural Network Inference #############################
unet_full_output = self.unet_model_full(x)
#############################################################################


# 目前默认是pro mode (pro跟weight有关)
# pro mode in default process
return ((unet_full_output - 0.15) * (255/0.7)).round().clamp_(0, 255).byte()


Expand Down
3 changes: 0 additions & 3 deletions process/crop_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ def crop4partition_SR(img, position = 3):

# Preparation
scale = configuration.scale
# TODO: 这里还没有完全考虑rescale后的办法,这个需要再想想 !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

height, width, _ = img.shape
crop_base = height // 3 # 160 for 480p
side_extra_padding = configuration.pixel_padding // 3
Expand Down Expand Up @@ -114,7 +112,6 @@ def combine_partitions_SR(crop1, crop2, crop3):


scale = configuration.scale
# TODO: 这里还没有完全考虑rescale后的办法,这个需要再想想 !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
cropped_padding = configuration.pixel_padding
h, w, c = crop1.shape

Expand Down
57 changes: 26 additions & 31 deletions process/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(self, configuration, model_full_name, model_partition_name, process
print("This process id is ", self.process_id)


self.max_cache_loop = (self.nt + self.full_model_num) * 200 # The is for bug report purpose (正常来说,取值要小于total frame num,不然就查不出来了)
self.max_cache_loop = (self.nt + self.full_model_num) * 200 # The is for bug report purpose (should be less thantotal frame num)
#################################################################################################################################

################################### Load model ##################################################################################
Expand Down Expand Up @@ -254,7 +254,7 @@ def __call__(self, input_path, output_path):

video_decode_loop_start = ttime()
######################################### video decode loop #######################################################
for frame_idx, frame in enumerate(objVideoreader.iter_frames(fps=self.decode_fps)): # 删掉了target fps
for frame_idx, frame in enumerate(objVideoreader.iter_frames(fps=self.decode_fps)):

# Rescale the video frame at the beginning if we want a different output resolution
if self.rescale_factor != 1:
Expand All @@ -264,14 +264,12 @@ def __call__(self, input_path, output_path):
if frame_idx % 50 == 0 or int(self.total_frame_number) == frame_idx:
# 以后print这边用config统一管理
print("Total frame:%s\t video decoded frames:%s"%(int(self.total_frame_number), frame_idx))
sleep(self.decode_sleep) # 否则解帧会一直抢主进程的CPU到100%,不给其他线程CPU空间进行图像预处理和后处理
# 目前nt=1的情况来说,不写也无所谓
sleep(self.decode_sleep) # This is very needed to avoid conflict in the process that is too overwhelming


# Check if NN process too slow to catch up the frames decoded
decode_processed_diff = frame_idx - self.now_idx
if decode_processed_diff >= 650: # This is an empirical value
#TODO: 这个插值也要假如config中
self.frame_write()
if decode_processed_diff >= 1000:
# Have to do this else it's possible to raise bugs
Expand All @@ -290,7 +288,6 @@ def __call__(self, input_path, output_path):
queue_put_idx = [3]

# Init the reference_frame and reference_idx
# TODO: reference_frame && reference_idx 写到一个func中
for i in range(3):
cropX = eval("crop%s"%i)
self.reference_frame[i] = cropX[:, :, 0] # We only store Single Red Channel to compare to accelerate
Expand All @@ -300,7 +297,7 @@ def __call__(self, input_path, output_path):
self.time2switchFULL -= 1 # Update the counter
self.momentum_used_times += 1 # Update the number of frames we process with momentum

# 根据full_model_num和nt 进行调整
# Abjust based on full_model_num and nt
if self.full_model_num > 0:
queue_put_idx = [3]
else: # In this case, we can only use partition queue
Expand Down Expand Up @@ -336,7 +333,6 @@ def __call__(self, input_path, output_path):
self.momentum_reference.append(True)
if sum(self.momentum_reference) == self.momentum_reference_size:
# If we have momentum_reference_size amount of frames that have big MSE difference between consequent frames, we activate MOMENTUM mechanism
# 考虑到reference_frame的重置,我们time2switchFULL影响到的frame是 momentum_skip_crop_frame_num + 1
# print("We need to use momentum at frame {}".format(frame_idx))
self.time2switchFULL = self.momentum_skip_crop_frame_num # Set how many frames we will skip
self.reference_frame = [None, None, None] # Reset reference
Expand Down Expand Up @@ -375,7 +371,7 @@ def __call__(self, input_path, output_path):
# print("We skip frame_idx {} and partition_idx {}!".format(frame_idx, partition_idx))


# Put partition/full frames into the queue 这里只是管理queue的,其他的比如idx2res这些都在上面处理完了(样子设计就是为了更好的程序设计)
# Put partition/full frames into the queue
if 3 in queue_put_idx:
# Full frame put into the queue
assert(len(queue_put_idx) == 1) # We cannot have partition idx here
Expand All @@ -399,8 +395,7 @@ def __call__(self, input_path, output_path):


################################################ 后面残留的计算 ##################################################
frame_idx += 1 # 调整成frames总数量
#等待所有的处理完,最后读取一遍全部的图片
frame_idx += 1 # Fit the total number of frames
while True:
self.frame_write()

Expand All @@ -415,13 +410,14 @@ def __call__(self, input_path, output_path):
assert(self.inp_q.qsize() == 0)
assert(self.res_q.qsize() == 0)
break

# TODO: 这个目前发现不运行就会出问题, 要不要用decode sleep统一处理, 这个bug是不是res_q满载了的原因
sleep(self.decode_sleep) # 原本0.01
# TODO: optimize this code if needed
sleep(self.decode_sleep)

print("Final image index is ", self.now_idx)

for _ in range(self.nt): # 全部结果拿到后,关掉模型线程
# Closs all queues
for _ in range(self.nt):
self.inp_q.put(None)
for _ in range(self.full_model_num):
self.inp_q_full.put(None)
Expand All @@ -431,9 +427,9 @@ def __call__(self, input_path, output_path):


video_decode_loop_end = ttime()
################################################################################################################
################################################################################################################################################

##################################### 分析汇总 ##################################################################
############################################################ Analysis ##########################################################################
# Calculation
full_time_spent = video_decode_loop_end - video_decode_loop_start
total_exe_fps = self.total_frame_number / full_time_spent
Expand All @@ -446,8 +442,7 @@ def __call__(self, input_path, output_path):
print("Done! Total time cost:", full_time_spent)
else:
print("Done! Total time cost: %d min %d s" %(full_time_spent//60, full_time_spent%60))
# print("The total duration is ", total_duration)
# print("The scaling of processing_time/total_video_duration is {} %".format((full_time_spent/total_duration) * 100))


# Details report
print("The following is the detailed report:")
Expand All @@ -456,10 +451,10 @@ def __call__(self, input_path, output_path):
self.parition_processed_num, 100 * self.parition_processed_num / (self.total_frame_number * 3)))
print("\t Total full_frame_cal_num is %d which is %.2f %%" %(self.full_frame_cal_num, 100*full_frame_portion))
print("\t Total momentum used num is %d which is %.2f %%" %(self.momentum_used_times, 100*self.momentum_used_times//self.total_frame_number))
################################################################################################################
#############################################################################################################################################


##################################### Generate Final Report ####################################################
##################################### Generate Final Report #################################################################################
report = {}
report["input_path"] = input_path
report["full_time_spent"] = full_time_spent
Expand All @@ -477,22 +472,22 @@ def frame_write(self):
''' Extract parition/full frame from res_q and write to ffmpeg writer (moviepy)
'''

#Step1:写入暂存器,因为多进程多线程的结果是不均匀出来的
while True: # 取出处理好的所有结果
#Step1:Write to tmp places because multi-process and multi-threading is unblanced
while True: # Get all the processed results
if self.res_q.empty():
break
iidx, position, res = self.res_q.get()
self.idx2res[iidx][position] = res


#Step2: 把暂存器的内容写到writer中
#Step2: put the tmp result to video writer
while True: # 按照idx排序写帧
if not self.res_q.empty():
iidx, position, res = self.res_q.get()
self.idx2res[iidx][position] = res

#这里一定保证是sequential的,所以repeat frame前面的reference完全有加载
if self.loop_counter == self.max_cache_loop: #####这个系数也要config管理!#####
# For sanity check purpose
if self.loop_counter == self.max_cache_loop:
self.writer.close()
print("Ends at frame ", self.now_idx)
print("\t Continuously not found, end the program and store what's stored")
Expand All @@ -505,11 +500,11 @@ def frame_write(self):
self.loop_counter = 0


########################################## 下面确保了frame的所有部分都是在的 ############################################
######################################## The following is safe to execute with all frames needed #######################################
if self.now_idx % 50 == 0:
print("Process {} had written frames: {}".format(self.process_id, self.now_idx))

# 3种类型的crop处理
# 3 types of cropping
if 3 not in self.idx2res[self.now_idx]:
# Partition Frame cases
crops = []
Expand All @@ -526,13 +521,13 @@ def frame_write(self):
# This one is NN inferenced result
crops.append(self.idx2res[self.now_idx][idx])

combined_frame = combine_partitions_SR(*crops) # adjust是完全固定的值
combined_frame = combine_partitions_SR(*crops) # adjust is a fixed value
else:
# Full frame cases
combined_frame = self.idx2res[self.now_idx][3]

# Write the frame
# cv2.imwrite(str(self.now_idx)+".png", cv2.cvtColor(combined_frame, cv2.COLOR_BGR2RGB)) # For Debug purpose (这个只能process=1的时候,不然会相互write without protection)
# cv2.imwrite(str(self.now_idx)+".png", cv2.cvtColor(combined_frame, cv2.COLOR_BGR2RGB)) # For Debug purpose (Please set the process_num = 1)
self.writer.write_frame(combined_frame)


Expand All @@ -545,7 +540,7 @@ def frame_write(self):


def queue_put(self, frame_idx, position, frame, full):
''' Put into the queue (集中管理)
''' Put into the queue
Args:
frame_idx (int): Global frame index
position (int): Position of frame 0|1|2|3
Expand Down
1 change: 0 additions & 1 deletion process/mass_production.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def mass_process(input_folder_dir, output_dir_parent):


# Process the video
# TODO: 利用log的report看看要不要减少partition的thread数量,毕竟相同视频类型都是相似的
parallel_process(input_dir, output_name, parallel_num=configuration.process_num)


Expand Down
3 changes: 1 addition & 2 deletions process/single_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def split_video(input_file, parallel_num):
print("duration is ", duration)
divide_time = math.ceil(duration // parallel_num) + 1

# TODO: 直接拆分audio出来,这样子就不会出现中途有卡壳的情况

# Split audio
audio_split_cmd = "ffmpeg -i " + input_file + " -map 0:a -c copy tmp/output_audio.m4a"
os.system(audio_split_cmd)
Expand Down Expand Up @@ -252,7 +252,6 @@ def single_process(model_full_name, model_partition_name, params, process_id):
config_preprocess(params, configuration)


# TODO: 我觉得这里应该直接读取video height和width然后直接选择模型,不然每次自己手动很麻烦
video_upscaler = VideoUpScaler(configuration, model_full_name, model_partition_name, process_id)
print("="*100)
print("Current Processing file is ", configuration.inp_path)
Expand Down
27 changes: 3 additions & 24 deletions tensorrt_weight_generator/weight_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def cunet_pre_process(self, array):
# Don't forget this np2tensor
tensor = np2tensor(array, pro=True)
# tensor = ToTensor()(array).unsqueeze(0).cuda()
input = F.pad(tensor, (18, 18, 18, 18), 'reflect').cuda() # pad最后一个倒数第二个dim各上下18个(总计36个)
input = F.pad(tensor, (18, 18, 18, 18), 'reflect').cuda() # (18, 18, 18, 18) is hard-code padding

return input

Expand All @@ -64,26 +64,6 @@ def rrdb_preprocess(self, array):

return tensor

def after_process(self, x):

####Q: 是不是add以后这边变成cpu更加节约时间

# h0 = 480
# w0 = 640
# ph = ((h0 - 1) // 2 + 1) * 2 # 应该是用来确认奇数偶数的
# pw = ((w0 - 1) // 2 + 1) * 2
# if w0 != pw or h0 != ph:
# x = x[:, :, :h0 * 2, :w0 * 2] #调整成偶数的size

if self.h%2 != 0 or self.w%2 != 0:
print("ensure that width and height to be even number")
os._exit(0)

########目前默认是pro mode
temp = ((x - 0.15) * (255/0.7)).round().clamp_(0, 255).byte()
# print("After after-process, the shape is ", temp.shape)
return temp


def model_weight_transform(self, input):

Expand Down Expand Up @@ -151,7 +131,7 @@ def model_weight_transform(self, input):
print("Finish generating the tensorRT weight and save at {}".format(save_path))


# 测试一下output
# Transform to trt model
output = model_trt_model(input)
print("TensorRT Sample input shape is ", input.shape)
print("TensoRT Sample output shape is ", output.shape)
Expand All @@ -160,7 +140,6 @@ def model_weight_transform(self, input):


def weight_generate(self):
# 如果要从头开始weight生成的话,dont_calculate_transform为false;只是image大量测试,就用true就行
self.dont_calculate_transform = False

self.model_weight_transform(self.sample_input)
Expand Down Expand Up @@ -210,7 +189,7 @@ def crop_image(sample_img_dir, target_h, target_w):
print("Such height and/or width is not supported, please use a larger sample input")
os._exit(0)

croped_img = img[:target_h, :target_w,:] # 第一个是height,第二个是width
croped_img = img[:target_h, :target_w,:]
print("Size after crop is ", croped_img.shape)
cv2.imwrite(configuration.full_croppped_img_dir, croped_img)

Expand Down

0 comments on commit e2eeced

Please sign in to comment.