# 垃圾分类

查看当前工作目录（本笔记本在 Google Colab 上运行）。

In [None]:
!ls

drive  sample_data


In [None]:
import os

# Show the directory the notebook is currently running in.
current_dir = os.getcwd()
print(current_dir)


/content


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
import numpy as np

**请修改数据集的地址，保证代码正常运行。**

In [None]:
# Dataset root directory (here: a folder on the mounted Google Drive).
# Keep the trailing slash: file names are appended to this prefix by plain
# string concatenation.
# Alternative when the data sits next to the notebook:
# data_path = os.getcwd() + '/train_data_v2/'
data_path = '/content/drive/My Drive/train_data_v2/'

## 自定义数据集

In [None]:

# Default image reader used by the dataset.
def default_loader(path):
    """Open the image at ``path`` with PIL and return it converted to RGB mode."""
    image = Image.open(path)
    rgb_image = image.convert('RGB')
    return rgb_image


# Custom dataset: MyDataset, a subclass of torch's Dataset.
class MyDataset(Dataset):
    """Image-classification dataset backed by an annotation text file.

    Each non-blank line of the annotation file holds an image file name
    followed by an integer class label, separated by whitespace. A trailing
    comma on the file name is stripped, so ``img_1.jpg, 0`` is accepted.

    Args:
        txt: Annotation file name; it is appended to ``data_path``.
        data_path: Prefix prepended, by plain string concatenation, to
            ``txt`` and to every image name (so it normally ends with a
            path separator). ``None`` is treated as an empty prefix.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the label tensor.
        loader: Callable that takes an image path and returns the image.

    Raises:
        OSError: If the annotation file cannot be opened.
        IndexError: If a non-blank line has fewer than two fields.
    """

    def __init__(self, txt, data_path=None, transform=None, target_transform=None, loader=default_loader):
        super(MyDataset, self).__init__()  # initialise the parent class
        # A missing prefix means "paths are already usable as given".
        if data_path is None:
            data_path = ''
        # Collect (image name, label string) pairs from the annotation file.
        imgs = []
        # The context manager guarantees the file handle is released.
        with open(data_path + txt, 'r') as file:
            for line in file:
                fields = line.split()
                if not fields:
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                # split() already removed the newline from the label field.
                imgs.append((fields[0].rstrip(','), fields[1]))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.data_path = data_path

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``.

        The label is always returned as a 0-dimensional integer tensor,
        independent of whether an image ``transform`` was supplied.
        """
        img_name, label = self.imgs[index]
        img = self.loader(self.data_path + img_name)
        if self.transform is not None:
            img = self.transform(img)
        # Convert the label unconditionally so that default collation works
        # even when no image transform is configured.
        label = torch.from_numpy(np.array(int(label)))
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        """Return the number of samples listed in the annotation file."""
        return len(self.imgs)


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.models import resnet50
from torch.utils.data import Dataset, DataLoader


## 模型训练及测试

In [None]:


# Preprocessing pipeline:
#   - resize every image to the 224x224 input size used for ResNet-50,
#   - convert the PIL image to a tensor,
#   - normalize each channel: subtract the mean, divide by the standard
#     deviation.
my_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.545, 0.506, 0.459], [0.207, 0.212, 0.220])])

# Build the training and test datasets from their annotation files.
train_data = MyDataset(txt='train.txt', data_path=data_path, transform=my_tf)
test_data = MyDataset(txt='test.txt', data_path=data_path, transform=my_tf)

# Wrap the datasets in DataLoaders.
train_loader = DataLoader(dataset=train_data, batch_size=16, shuffle=True, num_workers=2)
test_loader = DataLoader(dataset=test_data, batch_size=1, shuffle=False, num_workers=2)

# Start from an ImageNet-pretrained ResNet-50.
my_resnet50 = resnet50(pretrained=True)

# Freeze all pretrained parameters; only the new classifier head is trained.
for param in my_resnet50.parameters():
    param.requires_grad = False

# Replace the final fully connected layer so that it outputs the number of
# garbage-classification classes (40). The new layer is trainable by default.
in_f = my_resnet50.fc.in_features
my_resnet50.fc = nn.Linear(in_f, 40)

# Hyperparameters.
learn_rate = 0.001
num_epoches = 20
# Multi-class cross-entropy loss with default settings.
criterion = nn.CrossEntropyLoss()
# SGD with momentum, optimizing only the parameters of the new head.
optimizer = optim.SGD(my_resnet50.fc.parameters(), lr=learn_rate, momentum=0.9)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Training phase.
my_resnet50.to(device)
my_resnet50.train()
for epoch in range(num_epoches):
    print(f"epoch: {epoch+1}")
    for idx, (img, label) in enumerate(train_loader):
        images = img.to(device)
        labels = label.to(device)
        output = my_resnet50(images)
        loss = criterion(output, labels)
        loss.backward()  # back-propagate the loss
        optimizer.step()  # update the trainable parameters
        optimizer.zero_grad()  # clear gradients before the next batch
        if idx % 100 == 0:
            print(f"current loss = {loss.item()}")


# Evaluation phase.
# eval() switches the model to inference mode; it does not freeze parameters.
my_resnet50.eval()
total, correct = 0, 0
# Gradients are not needed for evaluation; disabling them avoids building
# autograd graphs and lowers memory use.
with torch.no_grad():
    for img, label in test_loader:
        images = img.to(device)
        labels = label.to(device)
        output = my_resnet50(images)
        _, predicted = torch.max(output, 1)  # index of the highest score
        total += labels.size(0)  # number of images seen
        # .item() keeps the counter a plain Python int instead of a tensor
        # living on the compute device.
        correct += (predicted == labels).sum().item()
# Guard against an empty test set.
accuracy = 100.0 * correct / total if total else 0.0
print("correct_num: ", correct)
print("total_image_num: ", total)
print(f"accuracy:{accuracy}")

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/checkpoints/resnet50-19c8e357.pth


HBox(children=(FloatProgress(value=0.0, max=102502400.0), HTML(value='')))


epoch: 1
current loss = 3.638876438140869
current loss = 2.8584587574005127
current loss = 2.2275004386901855
current loss = 1.4816701412200928
current loss = 2.0192432403564453
current loss = 1.52982497215271
current loss = 1.020609736442566
epoch: 2
current loss = 1.6001406908035278
current loss = 1.0698637962341309
current loss = 1.090205192565918
current loss = 0.8029191493988037
current loss = 0.7371137142181396
current loss = 0.7341311573982239
current loss = 0.8181235194206238
epoch: 3
current loss = 0.7901668548583984
current loss = 0.8786979913711548
current loss = 0.9400674700737
current loss = 0.6751846075057983
current loss = 0.4387664496898651
current loss = 0.2824462652206421
current loss = 0.6610844731330872
epoch: 4
current loss = 0.6385390758514404
current loss = 0.847917914390564
current loss = 0.9723794460296631
current loss = 0.8351149559020996
current loss = 0.6543305516242981
current loss = 0.46012964844703674
current loss = 0.5872468948364258
epoch: 5
current lo