In [19]:
import os

def load_labels(label_file_path):
    image_labels = []
    with open(label_file_path, 'r') as f:
        for line in f:
            # 每行格式：图片文件名 + 标签
            line = line.strip()
            if line:
                image_file, label = line.split()
                image_labels.append((image_file, int(label)))  # 图片名和标签
    return image_labels

label_file_path = 'labelled_stoat.txt'
image_labels = load_labels(label_file_path)
# image_labels


In [20]:
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

class StoatDataset(Dataset):
    def __init__(self, image_labels, image_dir, transform=None):
        self.image_labels = image_labels
        self.image_dir = image_dir
        self.transform = transform
    
    def __len__(self):
        return len(self.image_labels)
    
    def __getitem__(self, idx):
        image_file, label = self.image_labels[idx]
        image_path = os.path.join(self.image_dir, image_file)
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# 图像目录
image_dir = '/raid/yil708/stoat_data/auxiliary_network_pics/labelled_auxiliary_network_pics/labelled_auxiliary_network_pics/'

# 图像转换，数据增强
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载自定义的 Dataset
stoat_dataset = StoatDataset(image_labels, image_dir, transform=transform)


In [21]:
from torch.utils.data import DataLoader, random_split

# 将数据集划分为80%训练集和20%验证集
train_size = int(0.8 * len(stoat_dataset))
val_size = len(stoat_dataset) - train_size
train_dataset, val_dataset = random_split(stoat_dataset, [train_size, val_size])

# 创建 DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


In [22]:
import torch
import torch.nn as nn
import torchvision.models as models

# 选择GPU设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载预训练的ResNet50模型
model = models.resnet50(pretrained=True)

# 修改最后的全连接层，输出4个分类
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 4)  # 4类：正面、背面、左侧面、右侧面

# 将模型转移到GPU上
model = model.to(device)

In [23]:
import torch.optim as optim

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [24]:
# 训练模型函数
def train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs=10):
    for epoch in range(num_epochs):
        # 训练模式
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        train_accuracy = correct / total * 100
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/total:.4f}, Accuracy: {train_accuracy:.2f}%")
        
        # 验证模式
        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
        
        val_accuracy = val_correct / val_total * 100
        print(f"Validation Accuracy: {val_accuracy:.2f}%")

# 开始训练
train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs=10)

Epoch 1/10, Loss: 0.0457, Accuracy: 31.91%


  return F.conv2d(input, weight, bias, self.stride,


Validation Accuracy: 32.99%
Epoch 2/10, Loss: 0.0422, Accuracy: 37.32%
Validation Accuracy: 36.91%
Epoch 3/10, Loss: 0.0413, Accuracy: 37.53%
Validation Accuracy: 38.14%
Epoch 4/10, Loss: 0.0399, Accuracy: 42.22%
Validation Accuracy: 37.94%
Epoch 5/10, Loss: 0.0385, Accuracy: 44.54%
Validation Accuracy: 38.97%
Epoch 6/10, Loss: 0.0371, Accuracy: 45.00%
Validation Accuracy: 47.63%
Epoch 7/10, Loss: 0.0345, Accuracy: 52.47%
Validation Accuracy: 47.01%
Epoch 8/10, Loss: 0.0330, Accuracy: 54.90%
Validation Accuracy: 50.72%
Epoch 9/10, Loss: 0.0302, Accuracy: 59.59%
Validation Accuracy: 58.76%
Epoch 10/10, Loss: 0.0275, Accuracy: 64.43%
Validation Accuracy: 59.38%
