In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from PIL import Image
from torchvision.transforms.functional import rotate

In [2]:
# Random-rotation transform used to build the rotation-prediction pretext task.
class RotateTransform:
    """Rotate an image by an angle drawn uniformly at random from ``angles``.

    Calling the transform returns a ``(rotated_image, angle)`` pair so the
    caller can recover which rotation was applied and use it as a label.
    """

    def __init__(self, angles):
        # Candidate angles in degrees, as passed to torchvision's ``rotate``.
        self.angles = angles

    def __call__(self, x):
        pick = torch.randint(low=0, high=len(self.angles), size=(1,))
        chosen_angle = self.angles[int(pick)]
        return rotate(x, chosen_angle), chosen_angle

# Self-supervised model for the rotation-prediction pretext task.
class CNNModel(nn.Module):
    """Small CNN that predicts which rotation was applied to its input.

    Two unpadded 3x3 convolutions (stride 1, ReLU) followed by two fully
    connected layers. With the defaults it emits 4 logits, one per rotation
    class: 0°, 90°, 180°, 270°.

    Args:
        in_channels: number of input image channels (3 for RGB).
        input_size: height/width of the square input image. It determines
            ``fc1``'s input dimension; the default of 32 reproduces the
            original ``64 * 28 * 28`` layer.
        num_classes: number of output logits.

    Raises:
        ValueError: if ``input_size`` is too small for the two 3x3 convs.
    """

    def __init__(self, in_channels=3, input_size=32, num_classes=4):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Each unpadded 3x3 conv with stride 1 shrinks every spatial dim by 2,
        # so a 32x32 input reaches fc1 as 64 x 28 x 28.
        feature_size = input_size - 4
        if feature_size < 1:
            raise ValueError(f"input_size must be at least 5, got {input_size}")
        self.fc1 = nn.Linear(64 * feature_size * feature_size, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return rotation logits of shape ``(batch, num_classes)``."""
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Custom dataset: CIFAR-10 images relabelled with their rotation class.
class RotatedCIFAR10(datasets.CIFAR10):
    """CIFAR-10 whose targets are rotation classes instead of object labels.

    Each item is the (transformed) CIFAR-10 image rotated by a randomly
    chosen angle. The original class label is discarded, and the target is
    the index of the chosen angle in ``self.angles``.

    Accepts every argument of ``datasets.CIFAR10`` plus an optional
    keyword-only ``angles`` sequence (default ``[0, 90, 180, 270]``).

    Raises:
        ValueError: if ``angles`` is empty.
    """

    def __init__(self, *args, angles=None, **kwargs):
        super().__init__(*args, **kwargs)
        # The dataset owns its angles instead of reading a module-level global
        # that is defined later in the cell, so it no longer depends on cell
        # execution order or on that global being rebound.
        self.angles = list(angles) if angles is not None else [0, 90, 180, 270]
        if not self.angles:
            raise ValueError("angles must contain at least one rotation angle")
        # One transform instance reused for every item.
        self._rotator = RotateTransform(self.angles)

    def __getitem__(self, index):
        img, _ = super().__getitem__(index)
        rotated_img, angle = self._rotator(img)
        return rotated_img, self.angles.index(angle)

# Data preprocessing
# Seed PyTorch's global RNG so that the following are reproducible on
# Restart & Run All:
#   - weight initialisation,
#   - DataLoader shuffling,
#   - the random rotation choice (torch.randint in RotateTransform).
SEED = 0
torch.manual_seed(SEED)

# Rotation angles (degrees) for the pretext task; a sample's label is the
# index of its angle in this list.
angles = [0, 90, 180, 270]
transform = transforms.Compose([
    # CIFAR-10 images are natively 32x32, so this resize is effectively a
    # safeguard; CNNModel's fc1 size assumes 32x32 inputs.
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

# Load the training split, wrapped so that targets are rotation classes.
train_dataset = RotatedCIFAR10(root='./data', train=True, download=True, transform=transform)

# Create the model, optimizer and loss.
model = CNNModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Data loader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Self-supervised pretraining: learn to predict which rotation was applied.
for epoch in range(10):
    running_loss = 0.0
    for img, label in train_loader:
        # Defensive casts: the model expects float input, and CrossEntropyLoss
        # expects int64 class indices as targets.
        img, label = img.float(), label.long()

        optimizer.zero_grad()
        output = model(img)               # forward pass
        loss = criterion(output, label)   # rotation-classification loss
        loss.backward()                   # backward pass
        optimizer.step()                  # parameter update

        running_loss += loss.item()

    # Report the mean per-batch loss for this epoch.
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

print('Finished Training')


Files already downloaded and verified


KeyboardInterrupt: 

In [21]:
# Save the state dict of the rotation-pretrained `model` trained above.
# Do not re-instantiate CNNModel here: a fresh CNNModel() has random weights,
# which would discard the pretraining and save an untrained checkpoint.
torch.save(model.state_dict(), 'cnn_self_supervised_pretrained.pth')

# To save the whole model (architecture and parameters) instead, use:
# torch.save(model, 'cnn_model_full.pth')


In [None]:
class ClassificationModel(nn.Module):
    """CIFAR-10 classifier built on top of the rotation-pretrained CNN.

    Reuses ``conv1``, ``conv2`` and ``fc1`` from ``pretrained_model`` as a
    feature extractor and adds a new linear head ``fc``.

    The feature path must match ``CNNModel.forward`` exactly: conv -> ReLU
    twice, then flatten, then fc1 -> ReLU. ``fc1`` is sized for the unpooled
    64 x 28 x 28 conv output of a 32x32 input. Inserting pooling here would
    make the flattened features the wrong size for ``fc1``, and the forward
    pass would fail.

    Args:
        pretrained_model: module exposing ``conv1``, ``conv2`` and a linear
            ``fc1`` (e.g. a ``CNNModel`` loaded with pretrained weights).
        num_classes: number of output classes (10 for CIFAR-10).
    """

    def __init__(self, pretrained_model, num_classes=10):
        super(ClassificationModel, self).__init__()
        self.feature_extractor = pretrained_model
        # Head input width follows fc1's output (128 in CNNModel) instead of
        # repeating the literal.
        self.fc = nn.Linear(pretrained_model.fc1.out_features, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = torch.relu(self.feature_extractor.conv1(x))
        x = torch.relu(self.feature_extractor.conv2(x))
        x = torch.flatten(x, 1)
        x = torch.relu(self.feature_extractor.fc1(x))
        return self.fc(x)

# Load the self-supervised pretrained weights.
pretrained_model = CNNModel()
# map_location='cpu' lets the checkpoint load on CPU-only machines even if it
# was saved from GPU tensors; nothing in this notebook moves models to a GPU.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints. Recent PyTorch versions also accept weights_only=True for plain
# state_dicts.
pretrained_state = torch.load('cnn_self_supervised_pretrained.pth', map_location='cpu')
pretrained_model.load_state_dict(pretrained_state)

# Freeze the pretrained layers so they are not updated during fine-tuning.
pretrained_model.requires_grad_(False)

# Build the classifier on top of the frozen pretrained feature extractor.
classification_model = ClassificationModel(pretrained_model)

# Use the labelled CIFAR-10 dataset for the downstream classification task.
train_dataset_classification = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader_classification = torch.utils.data.DataLoader(train_dataset_classification, batch_size=64, shuffle=True)

# Define the optimizer and loss. Only parameters that are still trainable
# (the new classification head) are handed to Adam; the frozen backbone would
# never receive gradients anyway.
trainable_params = [p for p in classification_model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=0.001)
criterion = nn.CrossEntropyLoss()

# Fine-tune on the CIFAR-10 classification task. The pretrained layers were
# frozen before this cell, so the optimizer step effectively updates only the
# new classification head.
for epoch in range(10):
    running_loss = 0.0
    for img, label in train_loader_classification:
        # Defensive casts: float input, int64 class-index targets.
        img, label = img.float(), label.long()

        optimizer.zero_grad()
        output = classification_model(img)   # forward pass
        loss = criterion(output, label)      # classification loss
        loss.backward()                      # backward pass
        optimizer.step()                     # parameter update

        running_loss += loss.item()

    # Report the mean per-batch loss for this epoch.
    print(f"Epoch {epoch+1}, Classification Loss: {running_loss/len(train_loader_classification)}")

# Persist the fine-tuned classifier's state dict.
torch.save(classification_model.state_dict(), 'cnn_classification_model.pth')
