In [1]:
import ResNeXt

import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
import torch.nn.functional as F
import cv2
from PIL import Image
import os
from torchvision import transforms
from tqdm import tqdm

In [2]:
# 引入标签
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")    # 定义训练用
model = ResNeXt.ResNeXt50_32x4d().to(device)                               # 导入模型
loss_fuc = nn.CrossEntropyLoss()                                            # 设置损失函数 
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.8)                  # 设置梯度下降函数
transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

In [3]:
img_labels = pd.read_csv(r'C:\Users\24468\Desktop\python练习\cassava-leaf-disease-classification\train.csv')

In [4]:
img_labels.shape

(21397, 2)

In [5]:
img_dir  = r'C:\Users\24468\Desktop\python练习\cassava-leaf-disease-classification\train_images'
annotations_path = r'C:\Users\24468\Desktop\python练习\cassava-leaf-disease-classification\train.csv'

class MyDataset(Dataset): # 定义类
    def __init__(self, annotations_file, img_dir, img_transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.img_transform = img_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = Image.open(img_path)
        image = self.img_transform(image)
        label = self.img_labels.iloc[idx, 1]
        return image, label
    
dataset = MyDataset(annotations_path, img_dir, img_transform=transform, target_transform=None)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=64)

In [6]:
# 开始训练
Epoch = 3
for epoch in range(Epoch):
    sum_loss = 0
    for i, data in tqdm(enumerate(train_loader)):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fuc(outputs, labels)
        loss.backward()
        optimizer.step()
        sum_loss += loss.item()
        if i % 100 == 99:
            print('[Epoch:%d, batch:%d] train loss: %.03f' % (epoch + 1, i + 1, sum_loss / 100))
            sum_loss = 0.0

100it [29:55, 16.09s/it]

[Epoch:1, batch:100] train loss: 1.471


200it [56:32, 20.62s/it]

[Epoch:1, batch:200] train loss: 1.206


300it [1:31:48, 21.43s/it]

[Epoch:1, batch:300] train loss: 1.145


335it [1:44:03, 18.64s/it]
100it [35:23, 21.23s/it]

[Epoch:2, batch:100] train loss: 1.044


200it [1:10:57, 21.18s/it]

[Epoch:2, batch:200] train loss: 1.037


300it [1:42:25, 12.86s/it]

[Epoch:2, batch:300] train loss: 1.014


335it [1:49:38, 19.64s/it]
100it [21:13, 12.73s/it]

[Epoch:3, batch:100] train loss: 0.967


200it [43:29, 13.95s/it]

[Epoch:3, batch:200] train loss: 0.943


300it [1:07:05, 13.11s/it]

[Epoch:3, batch:300] train loss: 0.902


335it [1:14:55, 13.42s/it]


In [7]:
torch.save(model, r"C:\Users\24468\Desktop\python练习\cassava-leaf-disease-classification\ResNext_model.pt")