## 03-01 Create Module

In [20]:
import os
import random
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# Seed every RNG the pipeline draws from:
# - torch: DataLoader shuffling and weight initialisation.
# - Python's `random`: older torchvision releases (e.g. the 0.4.x line that pairs
#   with torch 1.3) sample RandomCrop offsets with `random`, not torch, so seeding
#   torch alone does not make the augmentation reproducible.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)
torch.__version__

'1.3.0'

### 1. 构建 Dataset 和 DataLoader

In [16]:
# Dataset location. Set the RMB_DATA_PATH environment variable to point at the
# data on another machine; the original absolute path is kept as the default.
rmb_data_path = Path(os.environ.get('RMB_DATA_PATH', '/media/bnu/file/datasets/pytorch-tutorials/rmb_data'))
train_path = rmb_data_path / 'split_data' / 'train'

class RMBDataset(Dataset):
    """Two-class RMB banknote image dataset ('1' vs '100').

    ``image_path`` must contain one sub-directory per class, named after the
    keys of ``label_dict``. Every file whose name ends in ``.jpg`` directly
    inside a class directory becomes one ``(file_path, label)`` sample.
    Sub-directories whose names are not class keys (e.g. ``.ipynb_checkpoints``)
    are ignored. Samples are sorted by path so their order does not depend on
    filesystem listing order.

    Args:
        image_path: Root directory of one data split (``str`` or ``Path``).
        transform: Optional callable applied to each PIL image in ``__getitem__``.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """

    def __init__(self, image_path, transform=None):
        self.label_dict = {'1': 0, '100': 1}
        self.transform = transform

        root = Path(image_path)  # also accept plain strings
        self.image_data = []
        # Only the immediate class folders are scanned; the class label comes
        # from the folder name, so nested folders are not walked.
        for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
            label = self.label_dict.get(class_dir.name)
            if label is None:
                continue  # not a class folder
            for file_path in sorted(class_dir.iterdir()):
                if file_path.name.endswith('.jpg'):
                    self.image_data.append((file_path, label))

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        The image is loaded as an RGB PIL image (pixel values 0-255) and then
        passed through ``self.transform`` if one was given. The returned type
        depends on that transform (e.g. a tensor after ``ToTensor``).
        """
        file_path, label = self.image_data[index]
        # The context manager closes the file handle deterministically;
        # convert() returns a fully loaded copy that outlives the handle.
        with Image.open(file_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def __len__(self):
        """Number of collected samples."""
        return len(self.image_data)

In [18]:
# Per-channel (R, G, B) mean and std for Normalize; these match the values
# commonly used for ImageNet-pretrained models.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_steps = [
    transforms.Resize((32, 32)),               # LeNet's fc1 assumes 32x32 inputs
    transforms.RandomCrop(32, padding=4),      # pad 4 px on each side, crop back to 32x32
    transforms.ToTensor(),                     # PIL image -> float tensor (C, H, W) in [0, 1]
    transforms.Normalize(norm_mean, norm_std),
]
train_transform = transforms.Compose(train_steps)

train_dataset = RMBDataset(image_path=train_path, transform=train_transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)

# Draw a single batch to sanity-check tensor shapes.
for inputs, labels in train_loader:
    print('Batch Shape: ', inputs.shape, 'Labels Shape:', labels.shape)
    break

Batch Shape:  torch.Size([16, 3, 32, 32]) Labels Shape: torch.Size([16])


### 2. 网络模型构建

In [21]:
class LeNet(nn.Module):
    """LeNet-style CNN for 3-channel 32x32 images.

    Two stages of 5x5 convolution -> ReLU -> 2x2 max-pool, followed by three
    fully connected layers (the last one produces unnormalised class scores).
    The ``16 * 5 * 5`` input size of ``fc1`` follows from a 32x32 input:
    32 -conv-> 28 -pool-> 14 -conv-> 10 -pool-> 5.

    Args:
        num_classes: Size of the output layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Attribute names and registration order are kept stable so that
        # state_dict keys stay the same.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 64)
        self.fc3 = nn.Linear(64, num_classes)

    def forward(self, x):
        """Map a batch ``x`` of shape (N, 3, 32, 32) to logits of shape (N, num_classes)."""
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=2)
        x = x.view(len(x), -1)  # flatten everything except the batch dimension
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)

In [24]:
model = LeNet(num_classes=2)

# Push one batch through the untrained network; outputs should be (batch, num_classes).
for inputs, labels in train_loader:
    print('Batch Shape: ', inputs.shape, 'Labels Shape:', labels.shape)
    outputs = model(inputs)
    print('Outputs Shape:', outputs.shape)
    break

Batch Shape:  torch.Size([16, 3, 32, 32]) Labels Shape: torch.Size([16])
Outputs Shape: torch.Size([16, 2])
