In [1]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

# 1. 加载数据集文件
def load_dataset_files(data_path, begin, end):
    """获取 1~10 文件夹下以 M_ 开头的视频文件和对应的标注文件"""
    matched_files = []
    for i in range(begin, end + 1):
        folder_path = os.path.join(data_path, str(i))
        if not os.path.isdir(folder_path):
            continue
        for file in os.listdir(folder_path):
            if file.startswith("M_") and file.lower().endswith('.mp4'):
                video_path = os.path.join(folder_path, file)
                base_name = file.replace("M_", "").replace(".MP4", "")
                annotation_file_name = base_name + ".txt"
                annotation_file_path = os.path.join(folder_path, annotation_file_name)
                if os.path.exists(annotation_file_path):
                    matched_files.append((video_path, annotation_file_path))
                else:
                    print(f"{video_path} 视频未找到相对应的文本文件")
    print(f"找到 {len(matched_files)} 对（视频-文本）文件对")
    return matched_files

# 2. 加载标注文件
def load_annotation_file(annotation_file_path):
    """读取标注文件"""
    annotations = []
    try:
        with open(annotation_file_path, 'r') as file:
            lines = file.readlines()
            if lines[0].startswith("Frame"):
                lines = lines[1:]
            for line in lines:
                parts = line.strip().split()
                if len(parts) == 2:
                    try:
                        frame = int(parts[0])
                        phase = parts[1]
                        annotations.append({'frame': frame, 'phase': phase})
                    except ValueError:
                        print(f"⚠️ 无法解析行: {line}")
    except FileNotFoundError:
        print(f"⚠️ 找不到文件: {annotation_file_path}")
    print(f"从 {annotation_file_path} 加载了 {len(annotations)} 条标注")
    return annotations

# 3. 获取视频的帧率和总帧数
def check_video_fps_and_frames(video_path):
    """获取视频的帧率和总帧数"""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"无法打开视频文件: {video_path}")
        return None, None
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    print(f"视频 {video_path} 帧率: {fps}, 总帧数: {total_frames}")
    return fps, total_frames

# 4. 自定义数据集类
class VideoFrameDataset(Dataset):
    def __init__(self, matched_files, transform=None):
        self.matched_files = matched_files
        self.transform = transform
        self.frames = []
        self.labels = []
        
        # 定义类别到整数的映射
        self.class_to_idx = {
            'Preparation': 0,
            'Estimation': 1,
            'Marking': 2,
            'Injection':3,
            'Incision':4,
            'ESD':5,
            'Vessel-treatment':6,
            'Clip':7
            # 添加其他类别
        }
        
        # 加载视频帧和标注
        for video_path, annotation_path in matched_files:
            annotations = load_annotation_file(annotation_path)
            cap = cv2.VideoCapture(video_path)
            fps, total_frames = check_video_fps_and_frames(video_path)
            selected_annotations = [ann for ann in annotations if ann['frame'] % 50 == 0]
            for annotation in selected_annotations:
                frame_idx = annotation['frame']
                phase = annotation['phase']
                
                # 跳转到指定帧
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                ret, frame = cap.read()
                if ret:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = Image.fromarray(frame)
                    if self.transform:
                        frame = self.transform(frame)
                    self.frames.append(frame)
                    # 将 phase 转换为整数标签
                    self.labels.append(self.class_to_idx.get(phase, -1))  # 如果 phase 不在映射中，返回 -1
                    print(f"✅ 成功读取帧: {frame_idx} from {video_path}") 
                    
            cap.release()
            print(f"🎯 从 {video_path} 提取了 {len(frame)} 帧")
    
    def __len__(self):
        return len(self.frames)
    
    def __getitem__(self, idx):
        return self.frames[idx], self.labels[idx]

# 5. 定义数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 6. 加载数据集
data_path = r'D:\\ESD\\test'
begin = 1
end = 2
matched_files = load_dataset_files(data_path, begin, end)

# 创建数据集
dataset = VideoFrameDataset(matched_files, transform=transform)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# 7. 定义CNN模型
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=8):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 56 * 56, 512)
        self.fc2 = nn.Linear(512, num_classes)
    
    def forward(self, x):
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = x.view(-1, 64 * 56 * 56)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleCNN(num_classes=8)

# 8. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 9. 训练模型
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

# 10. 验证模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in train_loader:  # 这里可以用验证集替换
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy on training set: {100 * correct / total:.2f}%')

找到 4 对（视频-文本）文件对
从 D:\\ESD\\test\1\20230822145107_U2291907_1_001_0001-01.txt 加载了 35250 条标注
视频 D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4 帧率: 50.0, 总帧数: 40060
✅ 成功读取帧: 0 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 50 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 100 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 150 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 200 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 250 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 300 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 350 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 400 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 450 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 500 from D:\\ESD\\test\1\M_202308221451

NameError: name 'extracted_frames' is not defined

In [1]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from torchvision import models
from torch.optim.lr_scheduler import StepLR

# 1. 加载数据集文件
def load_dataset_files(data_path, begin, end):
    """获取 1~10 文件夹下以 M_ 开头的视频文件和对应的标注文件"""
    matched_files = []
    for i in range(begin, end + 1):
        folder_path = os.path.join(data_path, str(i))
        if not os.path.isdir(folder_path):
            continue
        for file in os.listdir(folder_path):
            if file.startswith("M_") and file.lower().endswith('.mp4'):
                video_path = os.path.join(folder_path, file)
                base_name = file.replace("M_", "").replace(".MP4", "")
                annotation_file_name = base_name + ".txt"
                annotation_file_path = os.path.join(folder_path, annotation_file_name)
                if os.path.exists(annotation_file_path):
                    matched_files.append((video_path, annotation_file_path))
                else:
                    print(f"{video_path} 视频未找到相对应的文本文件")
    print(f"找到 {len(matched_files)} 对（视频-文本）文件对")
    return matched_files

# 2. 加载标注文件
def load_annotation_file(annotation_file_path):
    """读取标注文件"""
    annotations = []
    try:
        with open(annotation_file_path, 'r') as file:
            lines = file.readlines()
            if lines[0].startswith("Frame"):
                lines = lines[1:]
            for line in lines:
                parts = line.strip().split()
                if len(parts) == 2:
                    try:
                        frame = int(parts[0])
                        phase = parts[1]
                        annotations.append({'frame': frame, 'phase': phase})
                    except ValueError:
                        print(f"⚠️ 无法解析行: {line}")
    except FileNotFoundError:
        print(f"⚠️ 找不到文件: {annotation_file_path}")
    print(f"从 {annotation_file_path} 加载了 {len(annotations)} 条标注")
    return annotations

# 3. 获取视频的帧率和总帧数
def check_video_fps_and_frames(video_path):
    """获取视频的帧率和总帧数"""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"无法打开视频文件: {video_path}")
        return None, None
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    print(f"视频 {video_path} 帧率: {fps}, 总帧数: {total_frames}")
    return fps, total_frames

# 4. 自定义数据集类
class VideoFrameDataset(Dataset):
    def __init__(self, matched_files, transform=None):
        self.matched_files = matched_files
        self.transform = transform
        self.frames = []
        self.labels = []
        
        # 定义类别到整数的映射
        self.class_to_idx = {
            'Preparation': 0,
            'Estimation': 1,
            'Marking': 2,
            'Injection': 3,
            'Incision': 4,
            'ESD': 5,
            'Vessel-treatment': 6,
            'Clip': 7
        }
        
        # 加载视频帧和标注
        for video_path, annotation_path in matched_files:
            annotations = load_annotation_file(annotation_path)
            cap = cv2.VideoCapture(video_path)
            fps, total_frames = check_video_fps_and_frames(video_path)
            selected_annotations = [ann for ann in annotations if ann['frame'] % 50 == 0]
            for annotation in selected_annotations:
                frame_idx = annotation['frame']
                phase = annotation['phase']
                
                # 跳转到指定帧
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                ret, frame = cap.read()
                if ret:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = Image.fromarray(frame)
                    if self.transform:
                        frame = self.transform(frame)
                    label = self.class_to_idx.get(phase, -1)
                    if label != -1:  # 只加载有效类别的样本
                        self.frames.append(frame)
                        self.labels.append(label)
                        print(f"✅ 成功读取帧: {frame_idx} from {video_path}")
                    
            cap.release()
            print(f"🎯 从 {video_path} 提取了 {len(self.frames)} 帧")
    
    def __len__(self):
        return len(self.frames)
    
    def __getitem__(self, idx):
        return self.frames[idx], self.labels[idx]

# 5. 定义数据预处理和数据增强
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.RandomRotation(5),     # 轻微旋转（5度）
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 6. 加载数据集
data_path = r'D:\\ESD\\test'
begin = 1
end = 2
matched_files = load_dataset_files(data_path, begin, end)

# 创建数据集
dataset = VideoFrameDataset(matched_files, transform=transform)

# 划分训练集和验证集
train_size = int(0.8 * len(dataset))  # 80% 训练集
val_size = len(dataset) - train_size  # 20% 验证集
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# 创建 DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# 7. 使用预训练的 ResNet 模型
class ResNetModel(nn.Module):
    def __init__(self, num_classes=8):
        super(ResNetModel, self).__init__()
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)  # 使用 ResNet50
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
    
    def forward(self, x):
        return self.model(x)

model = ResNetModel(num_classes=8)

# 8. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)  # 使用 AdamW 和权重衰减
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)  # 学习率调度器

# 9. 训练模型
num_epochs = 20
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

best_val_accuracy = 0.0
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    # 打印当前 epoch 和学习率
    current_lr = optimizer.param_groups[0]['lr']
    print(f"\nEpoch [{epoch+1}/{num_epochs}], Learning Rate: {current_lr:.6f}")
    
    # 训练阶段
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        # 每 10 个 batch 打印一次训练状态
        if (batch_idx + 1) % 10 == 0:
            batch_loss = running_loss / (batch_idx + 1)
            batch_accuracy = 100 * correct / total
            print(f"  Batch [{batch_idx + 1}/{len(train_loader)}], "
                  f"Train Loss: {batch_loss:.4f}, "
                  f"Train Accuracy: {batch_accuracy:.2f}%")
    
    # 计算训练集的平均损失和准确率
    train_loss = running_loss / len(train_loader)
    train_accuracy = 100 * correct / total
    
    # 验证阶段
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()
    
    # 计算验证集的平均损失和准确率
    val_loss /= len(val_loader)
    val_accuracy = 100 * val_correct / val_total
    
    # 打印训练和验证结果
    print(f"  Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")
    print(f"  Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")
    
    # 保存最佳模型
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"🎉 保存最佳模型，验证集准确率: {val_accuracy:.2f}%")
    
    # 更新学习率
    scheduler.step()

找到 4 对（视频-文本）文件对
从 D:\\ESD\\test\1\20230822145107_U2291907_1_001_0001-01.txt 加载了 35250 条标注
视频 D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4 帧率: 50.0, 总帧数: 40060
✅ 成功读取帧: 0 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 50 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 100 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 150 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 200 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 250 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 300 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 350 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 400 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 450 from D:\\ESD\\test\1\M_20230822145107_U2291907_1_001_0001-01.MP4
✅ 成功读取帧: 500 from D:\\ESD\\test\1\M_202308221451