# In [1]:
import cv2
import torch
import numpy as np
import torch.nn as nn
from torchvision import transforms
from PIL import Image

# In [3]:
# Input video to denoise, and where the processed video is written.
video_path = './vid/noise1.mp4'
output_path = 'output.mp4'

# Channel widths read by deep58 when it is constructed:
# nf is the base feature count, nh the bottleneck width.
nf = 8
nh = 32
class deep58(nn.Module):
    """Small encoder/decoder network mapping a 3-channel image to a 3-channel image.

    Encoder: three Conv2d -> BatchNorm2d -> ReLU stages producing ``nf``,
    ``2 * nf`` and ``nh`` channels (the last two use stride 2).
    Decoder: transposed-convolution stages that upsample again, with two skip
    connections -- the second encoder stage, and the first three channels of
    the first encoder stage. The last two stages end in Tanh, so the output
    lies in [-1, 1].

    The skip concatenations only line up when the input height and width are
    multiples of 4; in that case the output has the same spatial size as the
    input.

    The module-level ``nf`` and ``nh`` are read at construction time. The
    attribute names and the layer order inside each ``nn.Sequential`` must
    stay unchanged so saved checkpoints remain loadable.
    """

    @staticmethod
    def _stage(layer, channels, activation):
        """Bundle ``layer`` -> BatchNorm2d(channels) -> ``activation`` into one Sequential."""
        return nn.Sequential(layer, nn.BatchNorm2d(channels), activation)

    def __init__(self):
        super().__init__()
        width, wide, bottleneck = nf, nf * 2, nh

        # Encoder.
        self.conv1 = self._stage(
            nn.Conv2d(3, width, kernel_size=8, stride=1, padding=0),
            width, nn.ReLU())
        self.conv2 = self._stage(
            nn.Conv2d(width, wide, kernel_size=5, stride=2, padding=0),
            wide, nn.ReLU())
        self.conv3 = self._stage(
            nn.Conv2d(wide, bottleneck, kernel_size=5, stride=2, padding=0),
            bottleneck, nn.ReLU())

        # Decoder; convTrans2 and convTrans3 take concatenated skip features.
        self.convTrans1 = self._stage(
            nn.ConvTranspose2d(bottleneck, wide, kernel_size=5, stride=2, padding=0),
            wide, nn.ReLU())
        self.convTrans2 = self._stage(
            nn.ConvTranspose2d(wide + wide, width, kernel_size=5, stride=2, padding=0),
            width, nn.ReLU())
        self.convTrans3 = self._stage(
            nn.ConvTranspose2d(width + 3, 3, kernel_size=5, stride=1, padding=0),
            3, nn.Tanh())
        self.convTrans4 = self._stage(
            nn.ConvTranspose2d(3, 3, kernel_size=4, stride=1, padding=0),
            3, nn.Tanh())

    def forward(self, x):
        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        dec1 = self.convTrans1(self.conv3(enc2))
        dec2 = self.convTrans2(torch.cat([enc2, dec1], dim=1))
        # Only the first three channels of the conv1 features form this skip path.
        skip = enc1[:, :3]
        dec3 = self.convTrans3(torch.cat([skip, dec2], dim=1))
        return self.convTrans4(dec3)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The checkpoint is loaded as a complete module (it is called directly below),
# presumably pickled with a reference to the deep58 class, which must therefore
# be defined before this line. Unpickling can execute arbitrary code: only load
# checkpoints from a trusted source. weights_only=False (PyTorch >= 1.13) is
# needed because PyTorch >= 2.6 defaults to weights_only=True, which rejects
# full-module checkpoints.
model = torch.load('deVTXNoise4G.pth', map_location=device, weights_only=False)
# eval() makes BatchNorm use its stored running statistics instead of the
# statistics of each single frame, and stops it from updating them.
model = model.to(device).eval()

# Convert to a PyTorch tensor and map pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=([0.5,0.5,0.5]), std=([0.5,0.5,0.5]))
    ])


def _frame_to_tensor(frame_bgr):
    """Turn an OpenCV BGR uint8 frame into a normalized (1, 3, H, W) tensor."""
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    # ToTensor accepts an H x W x C uint8 ndarray directly; no PIL round-trip needed.
    return transform(frame_rgb).unsqueeze(0)


def _pad_to_multiple(tensor, multiple=4):
    """Reflect-pad the bottom/right of an (N, C, H, W) tensor so H and W are multiples of ``multiple``.

    deep58's skip concatenations only line up for spatial sizes that are
    multiples of 4; the padding is cropped off again after inference.
    """
    h, w = tensor.shape[-2:]
    pad_h = (-h) % multiple
    pad_w = (-w) % multiple
    if pad_h or pad_w:
        # Pad tuple order for the last two dims: (left, right, top, bottom).
        tensor = nn.functional.pad(tensor, (0, pad_w, 0, pad_h), mode='reflect')
    return tensor


def _tensor_to_frame(output, frame_h, frame_w):
    """Convert a (1, 3, H', W') tensor in [0, 1] to a BGR uint8 frame of size frame_h x frame_w.

    Any padding is cropped away. Values are clamped and rounded (rather
    than truncated) before the uint8 conversion.
    """
    rgb = output[0, :, :frame_h, :frame_w].clamp(0.0, 1.0).mul(255.0).round().to(torch.uint8)
    # permute leaves a non-contiguous view; make it contiguous for OpenCV.
    rgb = rgb.permute(1, 2, 0).contiguous().cpu().numpy()
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


# Open the input video and fail loudly instead of writing an empty output.
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    raise IOError(f'Cannot open input video: {video_path}')

# Basic video properties. fps stays a float: truncating e.g. 29.97 to 29
# makes the output play at the wrong speed.
fps = cap.get(cv2.CAP_PROP_FPS)
if fps <= 0:
    # Some containers do not report a frame rate; assume 30 fps (adjust if needed).
    fps = 30.0
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# Create the output video.
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
if not out.isOpened():
    cap.release()
    raise IOError(f'Cannot open output video for writing: {output_path}')

try:
    with torch.no_grad():
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_h, frame_w = frame.shape[:2]
            frame_tensor = _pad_to_multiple(_frame_to_tensor(frame).to(device))

            # Undo Normalize(0.5, 0.5): map the Tanh output from [-1, 1] to [0, 1].
            output = model(frame_tensor) * 0.5 + 0.5

            # Write the processed frame.
            out.write(_tensor_to_frame(output, frame_h, frame_w))
finally:
    # Release both handles even if processing raises mid-video.
    cap.release()
    out.release()