In [1]:
# 序列数据、标签的加载
import os
from PIL import Image
import torch
from torchvision import transforms

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

torch.manual_seed(4)

# 若需更换face和eye，需改图片size
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

data_dir = 'D:/15341/Desktop/datasets_cite/DeepfakeTIMIT/face'

sequences = []
labels = []

for label, class_name in enumerate(['real', 'fake']):
    label_dir = os.path.join(data_dir, class_name)

    for sequence_idx in range(320):
        if sequence_idx == 25 or sequence_idx == 88 or sequence_idx == 167 or sequence_idx == 195:
            continue

        sequence = []
        for image_idx in range(10):
            image_path = os.path.join(label_dir, f'{sequence_idx}_{image_idx}.png')
            image = Image.open(image_path).convert("RGB")
            if transform:
                image = transform(image)
            sequence.append(image)

        sequences.append(torch.stack(sequence))
        labels.append(label)

sequences = torch.stack(sequences)
labels = torch.tensor(labels)
print(sequences.shape)
print(labels.shape)

torch.Size([632, 10, 3, 224, 224])
torch.Size([632])


In [2]:
# DeepfakeTIMIT sequence
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset
import numpy as np

dataset_size = sequences.size(0)
indices = list(range(dataset_size))
np.random.seed(4) # change
np.random.shuffle(indices)
print(indices[:10])

sequences = sequences[indices]
labels = labels[indices]

split1 = int(np.floor(0.8 * dataset_size))
split2 = int(np.floor(0.9 * dataset_size))
train_indices, val_indices, test_indices = indices[:split1], indices[split1:split2], indices[split2:]

# split = int(np.floor(0.8 * dataset_size))
# test_size = dataset_size - split
# # size = np.ceil(0.5 * test_size).astype(int)
# size = int(0.5 * test_size)
# m = 3
# n = 4
# indices_left, test_indices, indices_right = indices[:(test_size * m)], indices[(test_size * m):(test_size * n - size)], indices[(test_size * n):]
# indices_left.extend(indices_right)
# train_indices = indices_left

# train_indices, test_indices = indices[:split], indices[split:]
# test_indices, train_indices = indices[:test_size], indices[test_size:]

dataset = TensorDataset(sequences, labels)

train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
test_sampler = SubsetRandomSampler(test_indices)

seq_train_loader = DataLoader(dataset, batch_size=1, sampler=train_sampler)
seq_val_loader = DataLoader(dataset, batch_size=1, sampler=val_sampler)
seq_test_loader = DataLoader(dataset, batch_size=1, sampler=test_sampler)

print("Training set length:", len(train_indices))
print("Validation set length:", len(val_indices))
print("Test set length:", len(test_indices))

[550, 248, 459, 433, 389, 490, 626, 596, 175, 320]
Training set length: 505
Validation set length: 63
Test set length: 64


In [3]:
# 检查标签是否正常
all_labels = []
for data in seq_train_loader:
    inputs, label = data
    all_labels.extend(label.tolist())
all_labels = torch.tensor(all_labels)
print("Total number of labels:", len(all_labels))
num_positive_samples = (all_labels == 1).sum().item()
num_negative_samples = (all_labels == 0).sum().item()
print(f"Number of positive samples: {num_positive_samples}")
print(f"Number of negative samples: {num_negative_samples}")

Total number of labels: 505
Number of positive samples: 247
Number of negative samples: 258


In [4]:
# DeepfakeTIMIT image
extend_labels = []
images = []

for i in range(sequences.size(0)):
    for j in range(10):
        images.append(sequences[i][j])
        extend_labels.append(labels[i])

images = torch.stack(images)
img_labels = torch.tensor(extend_labels)

img_dataset_size = images.size(0)
img_indices = list(range(img_dataset_size))

img_split1 = 10 * split1
img_split2 = 10 * split2
img_train_indices, img_val_indices, img_test_indices = img_indices[:img_split1], img_indices[img_split1:img_split2], img_indices[img_split2:]

# img_split = 10 * split
# img_test_size = img_dataset_size - img_split
# size = 10 * size

# indices_left, img_test_indices, indices_right = img_indices[:(img_test_size * m)], img_indices[(img_test_size * m):(img_test_size * n - size)], img_indices[(img_test_size * n):]
# indices_left.extend(indices_right)
# img_train_indices = indices_left

# img_train_indices, img_test_indices = img_indices[:img_split], img_indices[img_split:]
# img_test_indices, img_train_indices = img_indices[:img_test_size], img_indices[img_test_size:]

img_dataset = TensorDataset(images, img_labels)

img_train_sampler = SubsetRandomSampler(img_train_indices)
img_val_sampler = SubsetRandomSampler(img_val_indices)
img_test_sampler = SubsetRandomSampler(img_test_indices)

img_train_loader = DataLoader(img_dataset, batch_size=32, sampler=img_train_sampler)
img_val_loader = DataLoader(img_dataset, batch_size=32, sampler=img_val_sampler)
img_test_loader = DataLoader(img_dataset, batch_size=32, sampler=img_test_sampler)

print("Training set length:", len(img_train_indices))
print("Validation set length:", len(img_val_indices))
print("Test set length:", len(img_test_indices))

Training set length: 5050
Validation set length: 630
Test set length: 640


In [5]:
from models import CustomVGG

img_model = CustomVGG().to(device)

In [6]:
import torch.nn as nn
import torch.optim as optim


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(img_model.parameters(), lr=0.001)
# optimizer = optim.SGD(img_model.parameters(), lr=0.01, momentum=0.9)
# torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # 设置 clipnorm
# torch.nn.utils.clip_grad_value_(model.parameters(), 0.5)  # 设置 clipvalue

best_val_loss = float('inf')
patience = 10
counter = 0

# 训练
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in img_train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        img_model.train()

        optimizer.zero_grad()

        # 前向传播
        outputs = img_model(inputs)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    loss = running_loss / len(img_train_loader)
    accuracy = 100 * correct / total
    # print('Epoch %d, loss: %.3f, accuracy: %.2f %%' % (epoch + 1, loss, accuracy))
    
    # 模型评估
    img_model.eval()  # 将模型设置为评估模式，关闭 Dropout
    val_correct = 0
    val_total = 0
    val_running_loss = 0.0
    all_labels = []
    all_predictions = []
    with torch.no_grad():  # 在验证阶段不需要计算梯度
        for val_data in img_val_loader:
            val_inputs, val_labels = val_data
            val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)

            val_outputs = img_model(val_inputs)
            val_loss = criterion(val_outputs, val_labels)
            val_running_loss += val_loss.item()

            _, val_predicted = torch.max(val_outputs.data, 1)
            val_total += val_labels.size(0)
            val_correct += (val_predicted == val_labels).sum().item()

    # 计算验证集上的准确率
    val_accuracy = 100 * val_correct / val_total
    avg_val_loss = val_running_loss / len(img_val_loader)

    print('Epoch %d, loss: %.3f, accuracy: %.2f %%. Validation, loss: %.3f, accuracy: %.2f %%' %
            (epoch + 1, loss, accuracy, avg_val_loss, val_accuracy))
    
    if epoch == 0:
        best_val_accuracy = val_accuracy
    # 保存最佳模型
    if val_accuracy >= best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_model_weights = img_model.state_dict()

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping")
            break

Epoch 1, loss: 0.107, accuracy: 96.36 %. Validation, loss: 0.028, accuracy: 98.89 %
Epoch 2, loss: 0.014, accuracy: 99.76 %. Validation, loss: 0.025, accuracy: 99.21 %
Epoch 3, loss: 0.010, accuracy: 99.80 %. Validation, loss: 0.030, accuracy: 98.73 %
Epoch 4, loss: 0.012, accuracy: 99.62 %. Validation, loss: 0.022, accuracy: 99.21 %
Epoch 5, loss: 0.010, accuracy: 99.78 %. Validation, loss: 0.021, accuracy: 99.37 %
Epoch 6, loss: 0.007, accuracy: 99.82 %. Validation, loss: 0.012, accuracy: 99.52 %
Epoch 7, loss: 0.009, accuracy: 99.74 %. Validation, loss: 0.011, accuracy: 99.84 %
Epoch 8, loss: 0.011, accuracy: 99.68 %. Validation, loss: 0.006, accuracy: 99.84 %
Epoch 9, loss: 0.006, accuracy: 99.88 %. Validation, loss: 0.012, accuracy: 99.68 %
Epoch 10, loss: 0.008, accuracy: 99.80 %. Validation, loss: 0.051, accuracy: 98.57 %


In [7]:
torch.save(img_model, 'D:/15341/Desktop/model/DeepfakeTIMIT/LQ/cnn_face_model(3).pth')

In [8]:
from sklearn.metrics import roc_auc_score

# img_model = torch.load('D:/15341/Desktop/model/DeepfakeTIMIT/LQ/cnn_face_model(2).pth')
correct = 0
total = 0
all_labels = []
all_predictions = []

with torch.no_grad():
    img_model.eval()
    for inputs, label in img_test_loader:
        inputs, label = inputs.to(device), label.to(device)
        outputs = img_model(inputs)
        
        _, predicted = torch.max(outputs, 1)
        total += label.size(0)
        correct += (predicted == label).sum().item()

        predictions = torch.sigmoid(outputs[:, 1])
        all_labels.extend(label.cpu().numpy())
        all_predictions.extend(predictions.cpu().detach().numpy())

test_accuracy = 100 * correct / total
auc = roc_auc_score(all_labels, all_predictions)
print("Test Accuracy: %.2f %%, AUC: %.4f" % (test_accuracy, auc))

Test Accuracy: 99.69 %, AUC: 1.0000


In [9]:
# 序列对图片的提升（序列整体平均）
# img_model = torch.load('D:/15341/Desktop/model/DeepfakeTIMIT/LQ/cnn_face_model(2).pth')
correct = 0

with torch.no_grad():
    img_model.eval()
    for data in seq_val_loader:
        input, label = data
        input, label = input.to(device), label.to(device)
        img_output = []
        output = []

        for t in range(10):
            img_output.append(img_model(input[:, t, :, :, :]))
        img_output = torch.stack(img_output)
        img_output = torch.squeeze(img_output, dim=1)
        mean_img_output = torch.mean(img_output, dim=0, keepdim=True)

        for t in range(10):
            output = img_model(input[:, t, :, :, :])
            output += mean_img_output

            _, predicted = torch.max(output, 1)
            correct += (predicted == label).sum().item()

total = 10 * len(seq_val_loader)
print(total)
test_accuracy = 100 * correct / total
print("Test Accuracy: %.2f %%" % test_accuracy)

630
Test Accuracy: 100.00 %


In [10]:
# img_model = torch.load('./model/DeepfakeTIMIT/HQ/cnn_eye_model_5.pth')
output = []
with torch.no_grad():
    img_model.eval()
    for i, data in enumerate(seq_test_loader):
        input, label = data
        input, label = input.to(device), label.to(device)
        if i == 0:
            output = img_model(input[:, 0, :, :, :])
            print(output.shape)

torch.Size([1, 2])


In [11]:
# 定义两个张量
tensor1 = torch.tensor([[1, 2]])
tensor2 = torch.tensor([[3, 4]])

# 使用torch.cat沿着维度1（列）拼接张量
combined_tensor = torch.cat((tensor1, tensor2), dim=1)

print(combined_tensor)  # 输出应为：tensor([[1, 2, 3, 4]])

tensor([[1, 2, 3, 4]])


In [12]:
# SVM：序列对图片的提升（伪）
# img_model = torch.load('./model/DeepfakeTIMIT/HQ/cnn_eye_model_5.pth')
img_output = []
img_label = []

with torch.no_grad():
    img_model.eval()
    for data in seq_train_loader:
        input, label = data
        input, label = input.to(device), label.to(device)
        single_output = []
        output = []

        for t in range(10):
            single_output.append(img_model(input[:, t, :, :, :]))
        single_output = torch.stack(single_output)
        single_output = torch.squeeze(single_output, dim=1)
        mean_img_output = torch.mean(single_output, dim=0, keepdim=True)

        for t in range(10):
            output = img_model(input[:, t, :, :, :])
            output += mean_img_output
            
            # output = torch.cat((output, mean_img_output), dim=1)
            output = torch.sigmoid(output)
            img_output.append(output)
            img_label.append(label)
print(len(img_output))

5050


In [13]:
from sklearn import svm
from sklearn.metrics import accuracy_score

img_output_tensor = torch.stack(img_output)
img_label_tensor = torch.cat(img_label)

img_output_flattened = img_output_tensor.view(img_output_tensor.size(0), -1).cpu().numpy()
img_label_flattened = img_label_tensor.cpu().numpy()

svm_classifier = svm.SVC(kernel='linear')

svm_classifier.fit(img_output_flattened, img_label_flattened)

y_pred = svm_classifier.predict(img_output_flattened)

accuracy = accuracy_score(img_label_flattened, y_pred)
print(f"Accuracy of SVM classifier on train dataset: {accuracy:.4f}")

Accuracy of SVM classifier on train dataset: 0.9972


In [14]:
import joblib
# 保存训练好的 SVM 模型到文件
model_filename = 'D:/15341/Desktop/model/DeepfakeTIMIT/LQ/svm/cnn_face_new(1).pkl'
joblib.dump(svm_classifier, model_filename)

['D:/15341/Desktop/model/DeepfakeTIMIT/LQ/svm/cnn_face_new(1).pkl']

In [15]:
# SVM：序列对图片的提升（伪）
# img_model = torch.load('./model/DeepfakeTIMIT/HQ/cnn_eye_model_5.pth')
img_output = []
img_label = []

with torch.no_grad():
    img_model.eval()
    for data in seq_test_loader:
        input, label = data
        input, label = input.to(device), label.to(device)
        single_output = []
        output = []

        for t in range(10):
            single_output.append(img_model(input[:, t, :, :, :]))
        single_output = torch.stack(single_output)
        single_output = torch.squeeze(single_output, dim=1)
        mean_img_output = torch.mean(single_output, dim=0, keepdim=True)

        for t in range(10):
            output = img_model(input[:, t, :, :, :])
            output += mean_img_output
            
            # output = torch.cat((output, mean_img_output), dim=1)
            output = torch.sigmoid(output)
            img_output.append(output)
            img_label.append(label)
print(len(img_output))

640
