<a href="https://colab.research.google.com/github/Song20011219/song/blob/main/ASD_CNNS3D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
from datetime import datetime
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
from dataset import ASD_Isolated
from train import train_epoch
from validation import val_epoch
from Conv3D import CNN3D  # 导入Conv3D模型

# Paths. data_path/model_path are absolute paths under /content/drive/MyDrive
# (presumably the Colab Google Drive mount); log_path and sum_path are relative,
# so they are created in the current working directory.
data_path = "/content/drive/MyDrive/output_frames"
model_path = "/content/drive/MyDrive/cnn3d_models"
# Take one timestamp so the log file and the TensorBoard run directory always
# share the same suffix (two separate datetime.now() calls can disagree if a
# second boundary passes between them).
run_started = datetime.now()
log_path = "cnn3d_log_{:%Y-%m-%d_%H-%M-%S}.log".format(run_started)
sum_path = "cnn3d_runs_{:%Y-%m-%d_%H-%M-%S}".format(run_started)

# Log to both a file and the console; `writer` (TensorBoard) is handed to
# train_epoch/val_epoch further down.
# force=True (Python 3.8+) replaces any handlers already attached to the root
# logger -- in a notebook kernel these may exist (e.g. from an earlier run of
# this cell), and without it basicConfig would silently do nothing.
logging.basicConfig(
    level=logging.INFO,
    format='%(message)s',
    handlers=[logging.FileHandler(log_path), logging.StreamHandler()],
    force=True,
)
logger = logging.getLogger('CNN3D')
logger.info('记录到文件...')
writer = SummaryWriter(sum_path)

# Compute device: the GPU when CUDA is available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hyperparameters.
num_classes = 2           # number of output classes (model and loss)
epochs = 10               # passes over the training split
batch_size = 8            # samples per DataLoader batch
learning_rate = 1e-4      # Adam learning rate
log_interval = 20         # logging interval forwarded to train_epoch
sample_size = 128         # frames are resized to sample_size x sample_size
sample_duration = 30      # video samples of 30 frames

# Data loading: per-frame preprocessing handed to the dataset.
# NOTE(review): assumes ASD_Isolated applies `transform` to each frame as a PIL
# image (Resize + ToTensor expect that) -- confirm against dataset.py.
transform = transforms.Compose([
    transforms.Resize([sample_size, sample_size]),
    transforms.ToTensor(),
    # With ToTensor's [0, 1] output, mean=0.5/std=0.5 maps each of the three
    # channels to [-1, 1].
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

dataset = ASD_Isolated(data_path=data_path, transform=transform)
total_samples = len(dataset)
if total_samples == 0:
    # Fail with a clear message instead of an obscure sampler error later.
    raise RuntimeError("No samples found under data_path={!r}".format(data_path))

# 80/20 train/validation split.
train_size = int(0.8 * total_samples)
val_size = total_samples - train_size
# Seed the split so the same samples land in the validation set on every run;
# otherwise a checkpoint reloaded after a restart would be evaluated on a
# validation set that overlaps the data it was trained on.
split_seed = 42
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
use_pin_memory = device.type == "cuda"
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=use_pin_memory)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=use_pin_memory)

# Build the 3D CNN and move it to the selected device.
# The class imported above is `CNN3D`; `Conv3D` is only the module name and is
# not bound by `from Conv3D import CNN3D`, so it must not be called here.
# NOTE(review): assumes CNN3D accepts these keyword arguments -- confirm in Conv3D.py.
model = CNN3D(sample_size=sample_size, sample_duration=sample_duration, num_classes=num_classes).to(device)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop: train, validate and checkpoint once per epoch.
logger.info("开始训练".center(60, '#'))
# torch.save does not create parent directories; make sure the checkpoint
# folder exists before the first save.
os.makedirs(model_path, exist_ok=True)
try:
    for epoch in range(epochs):
        # Train for one epoch.
        train_epoch(model, criterion, optimizer, train_loader, device, epoch, logger, log_interval, writer)

        # Evaluate on the validation split.
        val_epoch(model, criterion, val_loader, device, epoch, logger, writer)

        # Save this epoch's weights; file numbering is 1-based (cnn3d_epoch001.pth, ...).
        torch.save(model.state_dict(), os.path.join(model_path, "cnn3d_epoch{:03d}.pth".format(epoch+1)))
        logger.info("第 {} 轮模型已保存".format(epoch+1).center(60, '#'))
finally:
    # Flush pending TensorBoard events and release the writer, even if training fails.
    writer.close()

logger.info("训练完成".center(60, '#'))
