In [None]:
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.dataloader import default_collate
from torchsummary import summary
from tqdm import tqdm

# Seed PyTorch's global RNG (intended to make runs repeatable).
torch.manual_seed(42)

# Pick the first CUDA GPU when one is present; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Per-channel normalization statistics shared by both preprocessing pipelines.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Augmentation pipeline for contrastive learning. Every call draws fresh random
# parameters (crop, flip, colour jitter, grayscale), so applying it twice to the
# same image yields two different views.
_color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
contrastive_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([_color_jitter], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# Deterministic test-set preprocessing: tensor conversion plus normalization only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# Custom collate_fn: keep images as a plain list and batch only the labels.
def custom_collate(batch):
    """Collate ``(image, label)`` samples without stacking the images.

    Args:
        batch: Sequence of ``(image, label)`` pairs from the dataset.

    Returns:
        A pair ``(images, labels)``. ``images`` is a list of the untouched image
        objects in batch order. ``labels`` is ``default_collate`` applied to the
        list of labels, which gives a tensor for integer labels.
    """
    raw_images = []
    targets = []
    for sample in batch:
        raw_images.append(sample[0])
        targets.append(sample[1])
    return raw_images, default_collate(targets)

def info_nce_loss(features, temperature=0.1):
    """Compute the InfoNCE (NT-Xent) contrastive loss for a batch of paired views.

    Args:
        features: Tensor of shape (2N, D). Rows ``i`` and ``i + N`` must be the
            embeddings of two augmentations of the same sample. This is the layout
            produced by ``torch.cat([features1, features2], dim=0)``.
        temperature: Temperature that divides the cosine similarities before the
            softmax.

    Returns:
        Scalar tensor holding the mean cross-entropy over all 2N anchors. Each
        anchor's positive is its paired view, and the other 2N - 2 rows are its
        negatives.

    Raises:
        ValueError: If ``features`` has an odd number of rows.
    """
    n_rows = features.shape[0]
    if n_rows % 2 != 0:
        raise ValueError(f"info_nce_loss expects an even number of rows, got {n_rows}")
    batch_size = n_rows // 2

    # Cosine similarity via L2-normalised dot products. This needs (2N, 2N) memory,
    # instead of the (2N, 2N, D) intermediate created by broadcasting
    # F.cosine_similarity over unsqueezed copies.
    normed = F.normalize(features, dim=1)
    similarity_matrix = normed @ normed.T

    # Remove self-similarity: exp(-inf) == 0, so the diagonal drops out of the
    # softmax exactly as if it were deleted from the logits.
    self_mask = torch.eye(n_rows, dtype=torch.bool, device=features.device)
    logits = (similarity_matrix / temperature).masked_fill(self_mask, float('-inf'))

    # Row i's positive is its paired view at column (i + N) mod 2N.
    targets = (torch.arange(n_rows, device=features.device) + batch_size) % n_rows
    return F.cross_entropy(logits, targets)


if __name__ == '__main__':
    # Training set for contrastive learning. transform=None leaves the samples
    # untransformed so two independent random augmentations can be applied per
    # batch in the training loop; custom_collate keeps them as a plain list.
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=None)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=128, shuffle=True, num_workers=0, collate_fn=custom_collate)

    # Test set, preprocessed with the deterministic test transform.
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=2)

    # CIFAR-10 class names.
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    # Use the current working directory as TORCH_HOME, torchvision's cache
    # location for downloaded pretrained weights.
    current_dir = os.getcwd()
    os.environ['TORCH_HOME'] = current_dir

    # Pretrained ResNet-18 from torchvision.
    # NOTE(review): the strict load_state_dict below overwrites every parameter,
    # so this download only matters if that load is removed.
    model = torchvision.models.resnet18(pretrained=True)

    # Adapt the network to CIFAR-10. Changes: a 10-way classifier head, a 3x3
    # stride-1 stem convolution, and no initial max-pool.
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 10)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()

    model = model.to(device)

    # torchsummary's `device` argument defaults to "cuda". Pass the real device
    # type so the summary also works on CPU-only machines.
    summary(model, (3, 32, 32), device=device.type)

    # Path of the previously trained weights.
    model_weights_path = r'E:\shuyang\代码文件\教学课件\面试班\25复试班\进阶项目\CNN\model_weights.pth'

    # map_location remaps stored tensors onto `device`, so weights saved on a GPU
    # can also be loaded on a CPU-only machine.
    try:
        state_dict = torch.load(model_weights_path, map_location=device)
    except FileNotFoundError:
        print(f'Weight file {model_weights_path} not found.', file=sys.stderr)
        # Use sys.exit, not the site-injected exit() helper meant for interactive use.
        sys.exit(1)
    model.load_state_dict(state_dict)
    print(f'Model weights loaded from {model_weights_path}')

    # Contrastive-learning hyperparameters.
    temperature = 0.1
    contrastive_epochs = 10
    learning_rate = 0.001

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Contrastive training.
    # NOTE(review): the loss is computed on model(...) outputs, which come from the
    # 10-way model.fc layer. That classifier head is therefore trained as a 10-d
    # embedding. The evaluation below reuses it as a classifier with no supervised
    # fine-tuning or linear probe, which likely explains the low test accuracy in
    # the recorded run. Consider a separate projection head on the backbone
    # features, followed by a linear-probe stage.
    for epoch in range(contrastive_epochs):
        running_loss = 0.0
        model.train()
        start_time = time.time()
        for images, _ in tqdm(trainloader, desc=f'Epoch {epoch + 1}/{contrastive_epochs}'):
            # Two independent random augmentations of every image in the batch.
            images1 = torch.stack([contrastive_transform(img) for img in images]).to(device)
            images2 = torch.stack([contrastive_transform(img) for img in images]).to(device)

            # Rows i and i + B of `features` are the two views of sample i, which
            # is the layout info_nce_loss expects.
            features1 = model(images1)
            features2 = model(images2)
            features = torch.cat([features1, features2], dim=0)

            loss = info_nce_loss(features, temperature)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        end_time = time.time()
        epoch_loss = running_loss / len(trainloader)
        print(f'Epoch {epoch + 1}/{contrastive_epochs}, Contrastive Loss: {epoch_loss:.4f}, Time taken: {end_time - start_time:.2f} seconds')

    # Optionally save the contrastively trained weights. Saving is off by default,
    # matching the previously commented-out save call.
    contrastive_model_weights_path = r'E:\shuyang\代码文件\教学课件\面试班\25复试班\进阶项目\CNN\contrastive_model_weights.pth'
    save_contrastive_weights = False
    if save_contrastive_weights:
        torch.save(model.state_dict(), contrastive_model_weights_path)
        print(f'Contrastive model weights saved to {contrastive_model_weights_path}')

    # Evaluate the model on the test set after contrastive training.
    criterion = nn.CrossEntropyLoss()
    model.eval()
    val_running_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in tqdm(testloader, desc="Inference on test set"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            val_running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    epoch_val_loss = val_running_loss / len(testloader)
    epoch_val_accuracy = 100 * val_correct / val_total

    print(f'Test Loss after contrastive learning: {epoch_val_loss:.4f}, Test Acc: {epoch_val_accuracy:.2f}%')

Using device: cuda:0
Files already downloaded and verified
Files already downloaded and verified
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,728
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
          Identity-4           [-1, 64, 32, 32]               0
            Conv2d-5           [-1, 64, 32, 32]          36,864
       BatchNorm2d-6           [-1, 64, 32, 32]             128
              ReLU-7           [-1, 64, 32, 32]               0
            Conv2d-8           [-1, 64, 32, 32]          36,864
       BatchNorm2d-9           [-1, 64, 32, 32]             128
             ReLU-10           [-1, 64, 32, 32]               0
       BasicBlock-11           [-1, 64, 32, 32]               0
           Conv2d-12           [-1, 64, 32, 32]          36,864
      

Epoch 1/10: 100%|██████████| 391/391 [01:12<00:00,  5.38it/s]


Epoch 1/10, Contrastive Loss: 3.2421, Time taken: 72.71 seconds


Epoch 2/10: 100%|██████████| 391/391 [01:13<00:00,  5.35it/s]


Epoch 2/10, Contrastive Loss: 2.6218, Time taken: 73.11 seconds


Epoch 3/10: 100%|██████████| 391/391 [01:12<00:00,  5.37it/s]


Epoch 3/10, Contrastive Loss: 2.4693, Time taken: 72.79 seconds


Epoch 4/10: 100%|██████████| 391/391 [01:12<00:00,  5.39it/s]


Epoch 4/10, Contrastive Loss: 2.3679, Time taken: 72.50 seconds


Epoch 5/10: 100%|██████████| 391/391 [01:12<00:00,  5.38it/s]


Epoch 5/10, Contrastive Loss: 2.3050, Time taken: 72.65 seconds


Epoch 6/10: 100%|██████████| 391/391 [01:12<00:00,  5.39it/s]


Epoch 6/10, Contrastive Loss: 2.2593, Time taken: 72.51 seconds


Epoch 7/10: 100%|██████████| 391/391 [01:12<00:00,  5.37it/s]


Epoch 7/10, Contrastive Loss: 2.2250, Time taken: 72.84 seconds


Epoch 8/10: 100%|██████████| 391/391 [01:12<00:00,  5.42it/s]


Epoch 8/10, Contrastive Loss: 2.1736, Time taken: 72.13 seconds


Epoch 9/10: 100%|██████████| 391/391 [01:11<00:00,  5.43it/s]


Epoch 9/10, Contrastive Loss: 2.1488, Time taken: 71.95 seconds


Epoch 10/10: 100%|██████████| 391/391 [01:11<00:00,  5.44it/s]


Epoch 10/10, Contrastive Loss: 2.1425, Time taken: 71.83 seconds
Contrastive model weights saved to E:\shuyang\代码文件\教学课件\面试班\25复试班\进阶项目\CNN\contrastive_model_weights.pth


Inference on test set: 100%|██████████| 100/100 [00:02<00:00, 42.74it/s]

Test Loss after contrastive learning: 3.2534, Test Acc: 24.05%



