# In [None]:  (notebook cell marker, kept as a comment so the file parses as Python)
import torch
import clip
import pandas as pd
from PIL import Image
import os
import numpy as np
from ultralytics import YOLO
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from datetime import datetime, timedelta
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# 1. 数据预处理
class MultiModalDataset(Dataset):
    """Dataset that pairs every ``.jpg`` in a folder with its nearest-in-time CSV row.

    Each item is ``(image, csv_feature, crack_label)``:

    * ``image`` is whatever ``clip_preprocess`` returns for the opened image;
    * ``csv_feature`` is the standardised feature row (float32 tensor) whose
      ``Time`` offset is closest to the timestamp in the image file name;
    * ``crack_label`` is ``1`` when the image timestamp lies more than
      ``CRACK_THRESHOLD_SECONDS`` after ``BASE_TIME``, otherwise ``0``.

    Args:
        csv_path: CSV file with a ``Time`` column parseable by
            ``pd.to_timedelta`` plus the columns listed in ``FEATURE_COLS``.
        img_dir: Directory holding images named ``%Y-%m-%d-%H-%M-%S.jpg``.
        clip_preprocess: Callable applied to the opened PIL image.
        input_size: Stored on the instance only; nothing in this class reads it.
        indices: Optional positions into the aligned (image, row) pairs that
            this instance exposes. ``None`` exposes every pair.

    NOTE: the ``StandardScaler`` is fit on every CSV row, independent of
    ``indices``, so train/validation views built from the same file share
    identical scaling statistics.
    """

    # Instant that both the CSV ``Time`` offsets and the image file-name
    # timestamps are measured from.
    BASE_TIME = datetime(2023, 9, 3)
    # Images later than this many seconds after BASE_TIME get label 1.
    CRACK_THRESHOLD_SECONDS = 18 * 3600
    # Timestamp layout of the image file names (without extension).
    FILENAME_TIME_FORMAT = "%Y-%m-%d-%H-%M-%S"
    # CSV columns used as the numeric feature vector (runtime keys, keep as-is).
    FEATURE_COLS = ['上升时间', '计数', '能量', '持续时间', '幅值', '平均频率', 'RMS', '峰值频率', '绝对能量']

    def __init__(self, csv_path, img_dir, clip_preprocess, input_size=224, indices=None):
        self.df = pd.read_csv(csv_path)
        self.img_dir = img_dir
        self.preprocess = clip_preprocess
        self.input_size = input_size

        # Express every CSV row as seconds elapsed since BASE_TIME.
        self.df['abs_time'] = pd.to_datetime(self.BASE_TIME) + pd.to_timedelta(self.df['Time'])
        self.df['total_seconds'] = (self.df['abs_time'] - self.BASE_TIME).dt.total_seconds()

        # Image timestamps come from the file names, on the same time axis.
        self.image_files = sorted([f for f in os.listdir(img_dir) if f.endswith('.jpg')])
        self.image_times = [
            (datetime.strptime(name.split('.')[0], self.FILENAME_TIME_FORMAT) - self.BASE_TIME).total_seconds()
            for name in self.image_files
        ]

        # Align: for each image, the position of the CSV row closest in time.
        # The column is converted to numpy once instead of once per image.
        nearest_rows = self._nearest_row_indices(
            self.image_times, self.df['total_seconds'].to_numpy()
        )
        self.pairs = list(enumerate(nearest_rows))

        # Standardise the feature columns (fit on all rows, see class note).
        self.scaler = StandardScaler()
        self.features = self.scaler.fit_transform(self.df[self.FEATURE_COLS])

        if indices is None:
            self.indices = list(range(len(self.pairs)))
        else:
            self.indices = indices

    @staticmethod
    def _nearest_row_indices(image_times, csv_seconds):
        """Return, for each image time, the position of the closest CSV time.

        Ties resolve to the first (lowest) position. NaN entries in
        ``csv_seconds`` are ignored; an empty or all-NaN ``csv_seconds``
        raises ``ValueError`` as soon as there is an image to align.
        """
        csv_seconds = np.asarray(csv_seconds, dtype=float)
        return [int(np.nanargmin(np.abs(csv_seconds - t))) for t in image_times]

    @classmethod
    def _label_for_time(cls, seconds):
        """Return 1 when ``seconds`` is strictly past the crack threshold, else 0."""
        return 1 if seconds > cls.CRACK_THRESHOLD_SECONDS else 0

    def __len__(self):
        """Number of (image, row) pairs exposed by this instance."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, csv_feature, crack_label)`` for the ``idx``-th exposed pair."""
        img_idx, csv_idx = self.pairs[self.indices[idx]]

        # Image: the context manager releases the file handle as soon as the
        # preprocessing callable has produced its output.
        img_path = os.path.join(self.img_dir, self.image_files[img_idx])
        with Image.open(img_path) as img:
            image = self.preprocess(img)

        # CSV features of the time-aligned row.
        csv_feature = torch.as_tensor(self.features[csv_idx], dtype=torch.float32)

        # Label derived purely from the image timestamp.
        crack_label = self._label_for_time(self.image_times[img_idx])

        return image, csv_feature, crack_label

# 自注意力机制
class SelfAttention(torch.nn.Module):
    """Single-head dot-product attention whose queries, keys and values share one input.

    Scores are the raw ``Q @ K^T`` products (no scaling factor is applied) and
    are normalised by a softmax over the last dimension.

    NOTE(review): for a 2-D input of shape ``(N, input_dim)`` the score matrix
    is ``(N, N)``, i.e. the N rows attend to each other. If callers pass a
    ``(batch, dim)`` tensor this mixes samples within the batch - confirm
    that is intended.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Layers are created in query/key/value order so seeded
        # initialisation yields the same weights as before.
        make_projection = torch.nn.Linear
        self.query = make_projection(input_dim, input_dim)
        self.key = make_projection(input_dim, input_dim)
        self.value = make_projection(input_dim, input_dim)
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        """Return the attention-weighted values for ``x`` (same shape as ``x``)."""
        scores = self.query(x) @ self.key(x).transpose(-2, -1)
        weights = self.softmax(scores)
        return weights @ self.value(x)

# 交叉注意力机制
class CrossAttention(torch.nn.Module):
    """Single-head dot-product attention from one input onto another.

    Queries are projected from ``x1``; keys and values are projected from
    ``x2``. Scores are the raw ``Q @ K^T`` products (no scaling factor) and
    are normalised by a softmax over the last dimension.

    NOTE(review): with 2-D inputs of shape ``(N, input_dim)`` the score matrix
    is ``(N, N)``, so each row of ``x1`` attends over all rows of ``x2``. If
    callers pass ``(batch, dim)`` tensors this mixes samples within the batch
    - confirm that is intended.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Same creation order as before (query, key, value) to keep seeded
        # initialisation identical.
        projections = {}
        for role in ('query', 'key', 'value'):
            projections[role] = torch.nn.Linear(input_dim, input_dim)
        self.query = projections['query']
        self.key = projections['key']
        self.value = projections['value']
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x1, x2):
        """Return values of ``x2`` weighted by how strongly ``x1`` attends to them."""
        queries = self.query(x1)
        keys_t = self.key(x2).transpose(-2, -1)
        attention = self.softmax(queries @ keys_t)
        return attention @ self.value(x2)

# 2. 多模态
class FusionModel(torch.nn.Module):
    """Two-class classifier that fuses CLIP image embeddings with encoded CSV features.

    Pipeline (see ``forward``): a frozen CLIP image encoder and a small MLP over
    the 9 CSV features each produce a 512-wide vector; two self-attention and
    two cross-attention blocks are applied, their four outputs are
    concatenated, reduced by the fusion MLP and classified into 2 logits.

    Args:
        yolo_weights: Path passed to ``YOLO(...)``. The resulting network is
            stored as ``self.yolo`` but is not used by ``forward``.
        clip_model: Name passed to ``clip.load``. All layer widths below are
            hard-coded to 512, so the chosen CLIP model must emit 512-wide
            image embeddings (presumably true for the default ``ViT-B/32`` -
            verify before switching models).
    """

    def __init__(self, yolo_weights, clip_model='ViT-B/32'):
        super().__init__()
        # Pretrained backbones. ``self.yolo`` stays registered as a submodule
        # so the state_dict keys remain compatible with saved checkpoints,
        # even though forward() never calls it.
        self.yolo = YOLO(yolo_weights).model
        self.clip_model, _ = clip.load(clip_model)

        # Freeze both backbones: CLIP is only run under no_grad, and YOLO is
        # unused, so neither should be offered to the optimizer as trainable.
        for backbone in (self.clip_model, self.yolo):
            for param in backbone.parameters():
                param.requires_grad = False

        # CSV feature encoder: 9 raw features -> 512.
        self.csv_encoder = torch.nn.Sequential(
            torch.nn.Linear(9, 256),
            torch.nn.ReLU(),
            torch.nn.LayerNorm(256),
            torch.nn.Linear(256, 512)
        )

        # Self-attention per modality.
        self.img_self_attn = SelfAttention(512)
        self.csv_self_attn = SelfAttention(512)

        # Cross-attention in both directions.
        self.img_csv_cross_attn = CrossAttention(512)
        self.csv_img_cross_attn = CrossAttention(512)

        # Fusion of the four 512-wide attention outputs.
        self.fusion = torch.nn.Sequential(
            torch.nn.Linear(512 * 4, 1024),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(1024, 512)
        )

        # Classification head (2 logits).
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(256, 2)
        )

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen CLIP encoder in eval mode.

        Without this, ``model.train()`` would also put the frozen encoder into
        training mode, which matters for any CLIP variant containing
        mode-dependent layers.
        """
        super().train(mode)
        self.clip_model.eval()
        return self

    def forward(self, x_img, x_csv):
        """Return the 2-class logits for a batch of images and CSV feature rows.

        Args:
            x_img: Image batch accepted by ``clip_model.encode_image``.
            x_csv: Tensor whose last dimension is 9 (the CSV features).
        """
        # Image embedding from the frozen encoder; cast to float32 so it
        # matches the dtype of the trainable layers.
        with torch.no_grad():
            img_features = self.clip_model.encode_image(x_img)
        img_features = img_features.float()

        # CSV embedding.
        csv_features = self.csv_encoder(x_csv)

        # Self- and cross-attention branches, in the order the fusion layer
        # expects them.
        branches = [
            self.img_self_attn(img_features),
            self.csv_self_attn(csv_features),
            self.img_csv_cross_attn(img_features, csv_features),
            self.csv_img_cross_attn(csv_features, img_features),
        ]

        # Fuse and classify.
        fused = self.fusion(torch.cat(branches, dim=1))
        return self.classifier(fused)

# 3. 训练
def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass over ``loader``: a training pass if ``optimizer`` is given, else evaluation.

    Returns:
        ``(mean_batch_loss, accuracy, confusion_matrix)`` where the confusion
        matrix is always 2x2 (labels 0 and 1).
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    progress = tqdm(loader, desc=desc, unit='batch')
    with torch.set_grad_enabled(training):
        # ``step`` counts finished batches explicitly; tqdm's own ``n`` is
        # only refreshed when the bar redraws, so it cannot be used for the
        # running average.
        for step, (images, csv_feats, labels) in enumerate(progress, start=1):
            images = images.to(device)
            csv_feats = csv_feats.to(device)
            labels = labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(images, csv_feats)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

            progress.set_postfix({'Loss': loss_sum / step, 'Accuracy': correct / total})

    mean_loss = loss_sum / len(loader)
    accuracy = correct / total
    # Fix the label set so the matrix stays 2x2 even if one class is absent.
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    return mean_loss, accuracy, conf_matrix


def _plot_confusion_matrices(train_conf_matrix, val_conf_matrix):
    """Show the training and validation confusion matrices side by side."""
    fig = plt.figure(figsize=(12, 5))
    panels = (
        ('Training Confusion Matrix', train_conf_matrix),
        ('Validation Confusion Matrix', val_conf_matrix),
    )
    for position, (title, matrix) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues')
        plt.title(title)
        plt.xlabel('Predicted')
        plt.ylabel('True')
    plt.show()
    # Release the figure so hundreds of epochs do not accumulate open figures.
    plt.close(fig)


def train():
    """Train ``FusionModel`` on the image/CSV dataset with early stopping.

    Reads ``output_processed.csv`` and the ``picture`` directory, splits the
    aligned pairs 80/20, trains for up to 500 epochs, saves the weights with
    the lowest validation loss to ``best_multimodal_2.pth`` and plots the
    confusion matrices after every epoch that does not trigger early stopping.
    """
    from torch.utils.data import Subset

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Only the preprocessing transform is needed here; load this throwaway
    # CLIP copy on the CPU so it does not occupy GPU memory next to the one
    # inside FusionModel.
    _, preprocess = clip.load('ViT-B/32', device='cpu')

    # Dataset: built once, then viewed through index subsets. This yields the
    # same items as constructing separate datasets with ``indices=...`` but
    # avoids re-reading and re-aligning the CSV for each split.
    csv_path = r'output_processed.csv'
    img_dir = r'picture'
    dataset = MultiModalDataset(csv_path, img_dir, preprocess)

    indices = list(range(len(dataset)))
    train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)

    train_dataset = Subset(dataset, train_idx)
    val_dataset = Subset(dataset, val_idx)

    dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16)

    # Model.
    model = FusionModel(r'best.pt').to(device)

    # Optimizer and loss; frozen parameters are not handed to the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)
    criterion = torch.nn.CrossEntropyLoss()

    # Learning-rate schedule.
    # NOTE(review): T_max=50 is shorter than the 500-epoch budget, so past
    # epoch 50 the cosine schedule presumably rises again - confirm intended.
    scheduler = CosineAnnealingLR(optimizer, T_max=50)

    # Early-stopping state.
    num_epochs = 500
    patience = 10
    early_stopping_counter = 0
    best_val_loss = float('inf')
    best_val_acc = 0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')

        train_loss, train_acc, train_conf_matrix = _run_epoch(
            model, dataloader, criterion, device,
            desc=f'Training Epoch {epoch + 1}', optimizer=optimizer,
        )
        val_loss, val_acc, val_conf_matrix = _run_epoch(
            model, val_loader, criterion, device,
            desc=f'Validation Epoch {epoch + 1}',
        )

        print(f'Train Loss: {train_loss:.8f}, Train Acc: {train_acc:.8f}, Val Loss: {val_loss:.8f}, Val Acc: {val_acc:.8f}')

        # Keep the weights with the lowest validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_multimodal_2.pth')
            print('Saved best model')
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1

        # Early stopping.
        if early_stopping_counter >= patience:
            print(f'Early stopping at epoch {epoch + 1}')
            break

        # Advance the learning-rate schedule.
        scheduler.step()

        # Confusion matrices for this epoch.
        _plot_confusion_matrices(train_conf_matrix, val_conf_matrix)

    print(f'Best validation loss: {best_val_loss:.8f}, Best validation accuracy: {best_val_acc:.8f}')

# 4. 推理
def infer(model_path, img_path, csv_data):
    """Predict the crack probability for one image and its CSV feature vector.

    Args:
        model_path: Path to a ``FusionModel`` state_dict saved with ``torch.save``.
        img_path: Path to the image file.
        csv_data: The 9 CSV features, either as a flat sequence or already
            batched as ``(1, 9)``. The values are fed to the model as given;
            no scaler is applied in this function, so callers must supply
            features scaled the same way as during training.

    Returns:
        ``{'crack_probability': float}`` - the softmax probability of class 1
        for the first (only) sample.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Only the preprocessing transform is needed from this load; keep the
    # throwaway CLIP copy on the CPU.
    _, preprocess = clip.load('ViT-B/32', device='cpu')

    # Load the model. ``map_location`` lets a checkpoint saved on a GPU be
    # restored on a CPU-only machine.
    model = FusionModel(r'best.pt').to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Prepare inputs; both get a leading batch dimension of 1.
    with Image.open(img_path) as img:
        image = preprocess(img).unsqueeze(0).to(device)
    csv_feature = torch.as_tensor(csv_data, dtype=torch.float32).to(device)
    if csv_feature.dim() == 1:
        # A flat feature vector has no batch axis; add one so it lines up
        # with the (1, ...) image batch inside the model.
        csv_feature = csv_feature.unsqueeze(0)

    # Inference.
    with torch.no_grad():
        output = model(image, csv_feature)

    # Post-processing: logits -> probabilities of the single sample.
    prob = torch.softmax(output, dim=1)[0]
    return {'crack_probability': prob[1].item()}

# Script entry point: run training when this file is executed directly.
if __name__ == "__main__":
    train()