In [20]:
import os
import time
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch import nn
from torch.utils.data import Dataset
from torchvision import transforms, models
from tqdm import tqdm

In [21]:
labels_dataframe = pd.read_csv('./dataset/classify_leaves/train.csv')
# labels_dataframe.head(5)
labels_test=pd.read_csv('./dataset/classify_leaves/test.csv')
labels_test

Unnamed: 0,image
0,images/18353.jpg
1,images/18354.jpg
2,images/18355.jpg
3,images/18356.jpg
4,images/18357.jpg
...,...
8795,images/27148.jpg
8796,images/27149.jpg
8797,images/27150.jpg
8798,images/27151.jpg


In [22]:
np.asarray(labels_dataframe.iloc[2:10, 0])

array(['images/2.jpg', 'images/3.jpg', 'images/4.jpg', 'images/5.jpg',
       'images/6.jpg', 'images/7.jpg', 'images/8.jpg', 'images/9.jpg'],
      dtype=object)

In [23]:
labels_dataframe.describe()

Unnamed: 0,image,label
count,18353,18353
unique,18353,176
top,images/129.jpg,maclura_pomifera
freq,1,353


In [24]:
leaves_labels = sorted(labels_dataframe['label'].unique())

In [25]:
cls_to_num = dict(zip(leaves_labels, range(len(leaves_labels))))

In [26]:
num_to_cls = dict(zip(range(len(leaves_labels)), leaves_labels))

In [27]:
len(labels_dataframe)

18353

In [28]:
class LeavesData(Dataset):
    def __init__(self, csv_path, img_path, mode='train', valid_ration=0.2, resize_h=256, resize_w=256):
        """
        Args:
            csv_path (string): csv 文件路径
            img_path (string): 图像文件所在路径
            mode (string): 训练模式还是测试模式
            valid_ratio (float): 验证集比例
        """
        self.resize_h, self.resize_w = resize_h, resize_w
        self.img_path = img_path
        self.mode = mode
        self.data_info = pd.read_csv(csv_path, header=None)
        self.data_len = len(self.data_info) - 1
        self.train_data_len = int(self.data_len * (1 - valid_ration))
        # print(self.data_len)
        if mode == 'train':
            self.train_img = np.asarray(self.data_info.iloc[1:self.train_data_len, 0])
            self.train_label = np.asarray(self.data_info.iloc[1:self.train_data_len, 1])
            self.img_arr = self.train_img
            self.label_arr = self.train_label
        elif mode == 'valid':
            self.valid_img = np.asarray(self.data_info.iloc[self.train_data_len:, 0])
            self.valid_label = np.asarray(self.data_info.iloc[self.train_data_len:, 1])
            self.img_arr = self.valid_img
            self.label_arr = self.valid_label
        else:
            self.test_img = np.asarray(self.data_info.iloc[1:, 0])
            self.img_arr = self.test_img
        self.real_len = len(self.img_arr)
        print(f'Finished reading the {mode} set of Leaves Dataset ({self.real_len} samples found)')

    def __getitem__(self, index):
        img_name = self.img_arr[index]
        img_as_img = Image.open(os.path.join(self.img_path, img_name))
        # 如果需要将RGB三通道的图片转换成灰度图片可参考下面两行
        #         if img_as_img.mode != 'L':
        #             img_as_img = img_as_img.convert('L')
        if self.mode == 'train':
            transform = transforms.Compose([
                # transforms.RandomResizedCrop((224,224),scale=(0.8,1.0),ratio=(0.8,1.2)),#随机裁剪
                # transforms.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),#随机调整图片的亮度、对比度、饱和度、色调
                # transforms.Resize((224,224)),#缩放图片
                transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
                transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转
                transforms.ToTensor(), ]  # 将图片转换成Tensor
            )
        else:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),  # 缩放图片
                transforms.ToTensor(), ]  # 将图片转换成Tensor
            )
        img_as_tensor = transform(img_as_img)
        if self.mode == 'test':
            return img_as_tensor
        else:
            label = self.label_arr[index]
            num_label = cls_to_num[label]
            return img_as_tensor, num_label

    def __len__(self):
        return self.real_len

In [29]:
train_path = 'dataset/classify_leaves/train.csv'
test_path = 'dataset/classify_leaves/test.csv'
img_path = 'dataset/classify_leaves/'
train_dataset = LeavesData(train_path, img_path, mode='train')
val_dataset = LeavesData(train_path, img_path, mode='valid')
test_dataset = LeavesData(test_path, img_path, mode='test')
print(train_dataset)
print(val_dataset)
print(test_dataset)

Finished reading the train set of Leaves Dataset (14681 samples found)
Finished reading the valid set of Leaves Dataset (3672 samples found)
Finished reading the test set of Leaves Dataset (8800 samples found)
<__main__.LeavesData object at 0x7fb4b2546340>
<__main__.LeavesData object at 0x7fb4b25463a0>
<__main__.LeavesData object at 0x7fb4b2546190>


In [30]:
train_iter = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=False)
val_iter = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=False)
test_iter = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=False)

In [31]:
def set_parameter_requires_grad(model, feature_extracting):
    # if feature_extracting:
    #     model=model
    #     for i,param in enumerate(model.children()):
    #         if i==8:
    #             break
    #         param.requires_grad = False
    if feature_extracting:
        model = model
        for param in model.parameters():
            param.requires_grad = False


# resnet34
def res_model(num_classes, feature_extract=False, use_pretrained=True):
    model_ft = models.resnet34(pretrained=use_pretrained)
    set_parameter_requires_grad(model_ft, feature_extract)
    num_ftrs = model_ft.fc.in_features
    #     model_ft.fc = nn.Sequential(
    #         nn.Linear(num_ftrs, 512),
    #         nn.ReLU(inplace=True),
    #         nn.Dropout(.3),
    #         nn.Linear(512, len(num_to_class))
    #     )
    model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, num_classes))
    return model_ft


# resnext50模型
def resnext_model(num_classes, feature_extract=False, use_pretrained=True):
    model_ft = models.resnext50_32x4d(pretrained=use_pretrained)
    set_parameter_requires_grad(model_ft, feature_extract)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, num_classes))
    return model_ft

In [32]:
learning_rate = 3e-4
weight_decay = 1e-3
num_epoch = 50
model_path = 'dataset/classify_leaves/pre_resnext_model.ckpt'

In [14]:
device = 'cuda'
# model = res_model(176)
model = resnext_model(176)
model = model.to(device)
model.device = device
# For the classification task, we use cross-entropy as the measurement of performance.
criterion = nn.CrossEntropyLoss()
# Initialize optimizer, you may fine-tune some hyperparameters such as learning rate on your own.
# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = learning_rate, weight_decay=weight_decay)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# 余弦退火，last_epoch表示上一次训练的最后一个epoch,-1表示当前轮从0开始
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0, last_epoch=-1)
# The number of training epochs.
n_epochs = num_epoch

In [None]:
best_acc = 0.0
for epoch in range(n_epochs):
    time.sleep(0.5)
    model.train()
    train_loss = []
    train_acc = []
    i = 0
    # Iterate the training set by batches.
    for batch in tqdm(train_iter):
        time.sleep(0.05)
        imgs, labels = batch
        imgs = imgs.to(device)
        labels = labels.to(device)
        logits = model(imgs)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        if (i % 500 == 0):
            print("learning_rate:", scheduler.get_last_lr()[0])
        i += 1
        # print((logits.argmax(dim=1) == labels))
        acc = (logits.argmax(dim=1) == labels).float().mean()
        train_loss.append(loss.item())
        train_acc.append(acc)
    # The average loss and accuracy of the training set is the average of the recorded values.
    train_loss = sum(train_loss) / len(train_loss)
    train_acc = sum(train_acc) / len(train_acc)
    model.eval()
    # These are used to record information in validation.
    valid_loss = []
    valid_accs = []
    for batch in tqdm(val_iter):
        imgs, labels = batch
        # We don't need gradient in validation.
        # Using torch.no_grad() accelerates the forward process.
        with torch.no_grad():
            logits = model(imgs.to(device))

        # We can still compute the loss (but not the gradient).
        loss = criterion(logits, labels.to(device))

        # Compute the accuracy for current batch.
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()

        # Record the loss and accuracy.
        valid_loss.append(loss.item())
        valid_accs.append(acc)
        #     # The average loss and accuracy for entire validation set is the average of the recorded values.
    valid_loss = sum(valid_loss) / len(valid_loss)
    valid_acc = sum(valid_accs) / len(valid_accs)
    # Print the information.
    print(f"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")
    if valid_acc > best_acc:
        best_acc = valid_acc
        torch.save(model.state_dict(), model_path)
        print('saving model with acc {:.3f}'.format(best_acc))

  0%|          | 2/918 [00:05<35:58,  2.36s/it]  

learning_rate: 0.00029265847744427303


 55%|█████▍    | 503/918 [00:52<00:38, 10.83it/s]

learning_rate: 0.0002926584774442748


100%|██████████| 918/918 [01:30<00:00, 10.11it/s]
100%|██████████| 230/230 [00:08<00:00, 28.11it/s]


[ Valid | 001/050 ] loss = 1.12001, acc = 0.66087
saving model with acc 0.661


  0%|          | 3/918 [00:01<08:07,  1.88it/s]

learning_rate: 0.00029265847744432296


 55%|█████▍    | 503/918 [00:47<00:37, 11.16it/s]

learning_rate: 0.0002926584774443747


100%|██████████| 918/918 [01:25<00:00, 10.73it/s]
100%|██████████| 230/230 [00:09<00:00, 24.87it/s]


[ Valid | 002/050 ] loss = 0.57960, acc = 0.82201
saving model with acc 0.822


  0%|          | 2/918 [00:02<17:14,  1.13s/it]

learning_rate: 0.00023816778784394873


 55%|█████▍    | 502/918 [00:48<00:37, 11.02it/s]

learning_rate: 0.000238167787843908


100%|██████████| 918/918 [01:26<00:00, 10.64it/s]
100%|██████████| 230/230 [00:08<00:00, 26.62it/s]


[ Valid | 003/050 ] loss = 0.54143, acc = 0.83913
saving model with acc 0.839


  0%|          | 3/918 [00:02<09:07,  1.67it/s]

learning_rate: 0.00015000000000013435


 55%|█████▍    | 503/918 [00:48<00:37, 10.98it/s]

learning_rate: 0.0001500000000000938


100%|██████████| 918/918 [01:26<00:00, 10.66it/s]
100%|██████████| 230/230 [00:08<00:00, 27.50it/s]


[ Valid | 004/050 ] loss = 0.82339, acc = 0.77527


  0%|          | 3/918 [00:02<10:05,  1.51it/s]

learning_rate: 6.183221215612329e-05


 55%|█████▍    | 503/918 [00:48<00:37, 11.16it/s]

learning_rate: 6.183221215614604e-05


100%|██████████| 918/918 [01:26<00:00, 10.63it/s]
100%|██████████| 230/230 [00:08<00:00, 26.36it/s]


[ Valid | 005/050 ] loss = 0.55107, acc = 0.84293
saving model with acc 0.843


  0%|          | 2/918 [00:01<12:35,  1.21it/s]

learning_rate: 7.34152255572697e-06


 55%|█████▍    | 502/918 [00:47<00:37, 10.99it/s]

learning_rate: 7.34152255572697e-06


100%|██████████| 918/918 [01:25<00:00, 10.69it/s]
100%|██████████| 230/230 [00:08<00:00, 25.93it/s]


[ Valid | 006/050 ] loss = 0.48470, acc = 0.86766
saving model with acc 0.868


  0%|          | 3/918 [00:02<09:23,  1.62it/s]

learning_rate: 7.341522555728597e-06


 55%|█████▍    | 503/918 [00:48<00:37, 11.16it/s]

learning_rate: 7.341522555736413e-06


100%|██████████| 918/918 [01:26<00:00, 10.60it/s]
100%|██████████| 230/230 [00:08<00:00, 28.19it/s]


[ Valid | 007/050 ] loss = 0.44337, acc = 0.87418
saving model with acc 0.874


  0%|          | 3/918 [00:01<07:27,  2.05it/s]

learning_rate: 6.183221215620828e-05


 55%|█████▍    | 502/918 [00:49<00:38, 10.94it/s]

learning_rate: 6.183221215607377e-05


100%|██████████| 918/918 [01:28<00:00, 10.41it/s]
100%|██████████| 230/230 [00:08<00:00, 27.62it/s]


[ Valid | 008/050 ] loss = 1.69550, acc = 0.60842


  0%|          | 3/918 [00:01<07:08,  2.14it/s]

learning_rate: 0.00015000000000010337


 55%|█████▍    | 503/918 [00:47<00:38, 10.88it/s]

learning_rate: 0.00015000000000037688


100%|██████████| 918/918 [01:25<00:00, 10.73it/s]
100%|██████████| 230/230 [00:08<00:00, 27.07it/s]


[ Valid | 009/050 ] loss = 0.58069, acc = 0.83234


  0%|          | 3/918 [00:01<07:28,  2.04it/s]

learning_rate: 0.00023816778784413724


 55%|█████▍    | 503/918 [00:48<00:37, 10.96it/s]

learning_rate: 0.0002381677878439048


100%|██████████| 918/918 [01:26<00:00, 10.60it/s]
100%|██████████| 230/230 [00:08<00:00, 28.18it/s]


[ Valid | 010/050 ] loss = 0.34194, acc = 0.90870
saving model with acc 0.909


  0%|          | 3/918 [00:01<07:59,  1.91it/s]

learning_rate: 0.00029265847744386233


 55%|█████▍    | 503/918 [00:47<00:37, 11.08it/s]

learning_rate: 0.0002926584774444529


100%|██████████| 918/918 [01:25<00:00, 10.76it/s]
100%|██████████| 230/230 [00:08<00:00, 25.92it/s]


[ Valid | 011/050 ] loss = 0.27153, acc = 0.91902
saving model with acc 0.919


  0%|          | 3/918 [00:01<06:56,  2.19it/s]

learning_rate: 0.0002926584774448399


 55%|█████▍    | 503/918 [00:46<00:37, 11.16it/s]

learning_rate: 0.0002926584774445823


100%|██████████| 918/918 [01:25<00:00, 10.74it/s]
100%|██████████| 230/230 [00:08<00:00, 27.57it/s]


[ Valid | 012/050 ] loss = 0.37814, acc = 0.88668


  0%|          | 3/918 [00:02<08:15,  1.84it/s]

learning_rate: 0.0002381677878444142


 55%|█████▍    | 503/918 [00:48<00:38, 10.84it/s]

learning_rate: 0.00023816778784358774


100%|██████████| 918/918 [01:26<00:00, 10.59it/s]
100%|██████████| 230/230 [00:08<00:00, 25.62it/s]


[ Valid | 013/050 ] loss = 0.39615, acc = 0.89457


  0%|          | 3/918 [00:01<07:43,  1.97it/s]

learning_rate: 0.00015000000000037686


 55%|█████▍    | 502/918 [00:48<00:39, 10.48it/s]

learning_rate: 0.00014999999999990038


100%|██████████| 918/918 [01:27<00:00, 10.47it/s]
100%|██████████| 230/230 [00:08<00:00, 26.51it/s]


[ Valid | 014/050 ] loss = 0.30463, acc = 0.91033


  0%|          | 3/918 [00:02<08:26,  1.81it/s]

learning_rate: 6.18322121561555e-05


 55%|█████▍    | 502/918 [00:47<00:36, 11.27it/s]

learning_rate: 6.183221215606186e-05


100%|██████████| 918/918 [01:25<00:00, 10.71it/s]
100%|██████████| 230/230 [00:08<00:00, 27.17it/s]


[ Valid | 015/050 ] loss = 0.41045, acc = 0.88913


  0%|          | 2/918 [00:02<13:20,  1.14it/s]

learning_rate: 7.34152255572697e-06


 55%|█████▍    | 503/918 [00:48<00:37, 11.17it/s]

learning_rate: 7.34152255572697e-06


100%|██████████| 918/918 [01:27<00:00, 10.54it/s]
100%|██████████| 230/230 [00:10<00:00, 22.95it/s]


[ Valid | 016/050 ] loss = 0.68864, acc = 0.82663


  0%|          | 3/918 [00:01<07:38,  2.00it/s]

learning_rate: 7.3415225557185945e-06


 55%|█████▍    | 503/918 [00:49<00:40, 10.24it/s]

learning_rate: 7.341522555789652e-06


100%|██████████| 918/918 [01:28<00:00, 10.35it/s]
100%|██████████| 230/230 [00:08<00:00, 26.66it/s]


[ Valid | 017/050 ] loss = 0.43216, acc = 0.88043


  0%|          | 3/918 [00:02<08:15,  1.85it/s]

learning_rate: 6.183221215633051e-05


 55%|█████▍    | 502/918 [00:48<00:38, 10.85it/s]

learning_rate: 6.183221215590225e-05


100%|██████████| 918/918 [01:25<00:00, 10.69it/s]
100%|██████████| 230/230 [00:09<00:00, 24.90it/s]


[ Valid | 018/050 ] loss = 0.39915, acc = 0.88587


  0%|          | 2/918 [00:01<10:58,  1.39it/s]

learning_rate: 0.0001500000000003028


 55%|█████▍    | 502/918 [00:46<00:35, 11.66it/s]

learning_rate: 0.0001500000000000092


100%|██████████| 918/918 [01:24<00:00, 10.86it/s]
100%|██████████| 230/230 [00:08<00:00, 27.26it/s]


[ Valid | 019/050 ] loss = 1.47073, acc = 0.65761


  0%|          | 3/918 [00:01<07:40,  1.99it/s]

learning_rate: 0.00023816778784359243


 55%|█████▍    | 503/918 [00:47<00:38, 10.82it/s]

learning_rate: 0.00023816778784483772


100%|██████████| 918/918 [01:26<00:00, 10.61it/s]
100%|██████████| 230/230 [00:11<00:00, 19.19it/s]


[ Valid | 020/050 ] loss = 0.29414, acc = 0.91576


  0%|          | 3/918 [00:02<09:09,  1.67it/s]

learning_rate: 0.0002926584774449294


 55%|█████▍    | 503/918 [00:48<00:38, 10.79it/s]

learning_rate: 0.00029265847744470075


100%|██████████| 918/918 [01:27<00:00, 10.52it/s]
100%|██████████| 230/230 [00:08<00:00, 27.39it/s]


[ Valid | 021/050 ] loss = 0.30327, acc = 0.91413


  0%|          | 3/918 [00:01<07:28,  2.04it/s]

learning_rate: 0.00029265847744423606


 55%|█████▍    | 502/918 [00:47<00:40, 10.24it/s]

learning_rate: 0.0002926584774456586


100%|██████████| 918/918 [01:26<00:00, 10.56it/s]
100%|██████████| 230/230 [00:08<00:00, 27.33it/s]


[ Valid | 022/050 ] loss = 0.30913, acc = 0.91984
saving model with acc 0.920


  0%|          | 3/918 [00:01<07:29,  2.04it/s]

learning_rate: 0.00023816778784463478


 55%|█████▍    | 503/918 [00:48<00:37, 11.10it/s]

learning_rate: 0.0002381677878444371


100%|██████████| 918/918 [01:27<00:00, 10.53it/s]
100%|██████████| 230/230 [00:08<00:00, 27.31it/s]


[ Valid | 023/050 ] loss = 0.40138, acc = 0.89837


  0%|          | 2/918 [00:01<11:24,  1.34it/s]

learning_rate: 0.0001500000000009709


 55%|█████▍    | 502/918 [00:47<00:37, 11.03it/s]

learning_rate: 0.0001500000000008569


100%|██████████| 918/918 [01:26<00:00, 10.65it/s]
100%|██████████| 230/230 [00:08<00:00, 27.18it/s]


[ Valid | 024/050 ] loss = 0.37841, acc = 0.89185


  0%|          | 3/918 [00:01<06:45,  2.26it/s]

learning_rate: 6.183221215636514e-05


 55%|█████▍    | 503/918 [00:48<00:38, 10.88it/s]

learning_rate: 6.183221215586132e-05


100%|██████████| 918/918 [01:26<00:00, 10.58it/s]
100%|██████████| 230/230 [00:08<00:00, 27.55it/s]


[ Valid | 025/050 ] loss = 2.33721, acc = 0.53152


  0%|          | 2/918 [00:01<10:06,  1.51it/s]

learning_rate: 7.34152255572697e-06


 55%|█████▍    | 502/918 [00:47<00:38, 10.68it/s]

learning_rate: 7.34152255572697e-06


100%|██████████| 918/918 [01:27<00:00, 10.52it/s]
100%|██████████| 230/230 [00:08<00:00, 27.71it/s]


[ Valid | 026/050 ] loss = 0.41422, acc = 0.89266


  0%|          | 3/918 [00:02<08:35,  1.77it/s]

learning_rate: 7.341522555729688e-06


 55%|█████▍    | 503/918 [00:48<00:37, 11.12it/s]

learning_rate: 7.3415225557164295e-06


100%|██████████| 918/918 [01:26<00:00, 10.65it/s]
100%|██████████| 230/230 [00:09<00:00, 25.47it/s]


[ Valid | 027/050 ] loss = 0.37195, acc = 0.89946


  0%|          | 3/918 [00:01<07:17,  2.09it/s]

learning_rate: 6.183221215626936e-05


 55%|█████▍    | 503/918 [00:46<00:37, 11.08it/s]

learning_rate: 6.183221215630655e-05


100%|██████████| 918/918 [01:24<00:00, 10.86it/s]
100%|██████████| 230/230 [00:08<00:00, 27.89it/s]


[ Valid | 028/050 ] loss = 0.30972, acc = 0.91902


  0%|          | 3/918 [00:02<08:13,  1.85it/s]

learning_rate: 0.00014999999999979865


 55%|█████▍    | 502/918 [00:47<00:37, 10.97it/s]

learning_rate: 0.00015000000000050295


100%|██████████| 918/918 [01:25<00:00, 10.78it/s]
100%|██████████| 230/230 [00:08<00:00, 28.28it/s]


[ Valid | 029/050 ] loss = 0.41054, acc = 0.88995


  0%|          | 3/918 [00:02<08:44,  1.75it/s]

learning_rate: 0.00023816778784315813


 55%|█████▍    | 503/918 [00:47<00:37, 11.00it/s]

learning_rate: 0.0002381677878456611


100%|██████████| 918/918 [01:25<00:00, 10.79it/s]
100%|██████████| 230/230 [00:08<00:00, 27.68it/s]


[ Valid | 030/050 ] loss = 0.28260, acc = 0.91821


  0%|          | 3/918 [00:02<09:40,  1.58it/s]

learning_rate: 0.00029265847744429444


 55%|█████▍    | 503/918 [00:47<00:37, 11.06it/s]

learning_rate: 0.00029265847744402356


100%|██████████| 918/918 [01:26<00:00, 10.64it/s]
100%|██████████| 230/230 [00:08<00:00, 27.23it/s]


[ Valid | 031/050 ] loss = 0.45746, acc = 0.88098


  0%|          | 3/918 [00:01<07:52,  1.94it/s]

learning_rate: 0.00029265847744527093


 55%|█████▍    | 502/918 [00:48<00:41, 10.10it/s]

learning_rate: 0.0002926584774450127


100%|██████████| 918/918 [01:27<00:00, 10.44it/s]
100%|██████████| 230/230 [00:09<00:00, 25.49it/s]


[ Valid | 032/050 ] loss = 0.38456, acc = 0.90000


  0%|          | 3/918 [00:02<09:01,  1.69it/s]

learning_rate: 0.0002381677878453185


100%|██████████| 918/918 [01:26<00:00, 10.63it/s]
100%|██████████| 230/230 [00:08<00:00, 26.59it/s]


[ Valid | 033/050 ] loss = 0.42307, acc = 0.89348


  0%|          | 3/918 [00:02<08:13,  1.85it/s]

learning_rate: 0.0001500000000014285


 55%|█████▍    | 503/918 [00:47<00:37, 11.20it/s]

learning_rate: 0.00014999999999959176


100%|██████████| 918/918 [01:24<00:00, 10.80it/s]
100%|██████████| 230/230 [00:08<00:00, 27.55it/s]


[ Valid | 034/050 ] loss = 0.34737, acc = 0.90870


  0%|          | 3/918 [00:01<08:01,  1.90it/s]

learning_rate: 6.183221215627496e-05


 55%|█████▍    | 503/918 [00:47<00:37, 11.21it/s]

learning_rate: 6.183221215601576e-05


100%|██████████| 918/918 [01:25<00:00, 10.75it/s]
100%|██████████| 230/230 [00:08<00:00, 27.14it/s]


[ Valid | 035/050 ] loss = 0.34421, acc = 0.90190


  0%|          | 3/918 [00:02<08:15,  1.85it/s]

learning_rate: 7.34152255572697e-06


 55%|█████▍    | 502/918 [00:48<00:37, 10.96it/s]

learning_rate: 7.34152255572697e-06


100%|██████████| 918/918 [01:26<00:00, 10.67it/s]
100%|██████████| 230/230 [00:08<00:00, 26.24it/s]


[ Valid | 036/050 ] loss = 0.44180, acc = 0.88505


  0%|          | 2/918 [00:01<10:59,  1.39it/s]

learning_rate: 7.341522555740778e-06


 55%|█████▍    | 502/918 [00:47<00:37, 10.96it/s]

learning_rate: 7.341522555727503e-06


100%|██████████| 918/918 [01:25<00:00, 10.69it/s]
100%|██████████| 230/230 [00:08<00:00, 26.38it/s]


[ Valid | 037/050 ] loss = 1.55395, acc = 0.65109


  0%|          | 3/918 [00:02<09:11,  1.66it/s]

learning_rate: 6.183221215656334e-05


 55%|█████▍    | 503/918 [00:48<00:37, 11.01it/s]

learning_rate: 6.183221215626926e-05


100%|██████████| 918/918 [01:26<00:00, 10.61it/s]
100%|██████████| 230/230 [00:08<00:00, 27.07it/s]


[ Valid | 038/050 ] loss = 0.47340, acc = 0.87853


  0%|          | 2/918 [00:01<12:33,  1.22it/s]

learning_rate: 0.00014999999999956696


 55%|█████▍    | 502/918 [00:47<00:37, 11.09it/s]

learning_rate: 0.00015000000000113284


 84%|████████▍ | 770/918 [01:11<00:13, 10.61it/s]

In [2]:
saveFileName = './dataset/classify_leaves/submission.csv'

In [33]:
device='cuda'
model = resnext_model(176)
model = model.to(device)
model.load_state_dict(torch.load(model_path))
model.eval()  #batchnormize用全局的batch算
preds = []
for batch in tqdm(test_iter):
    imgs = batch
    imgs = imgs.to(device)
    with torch.no_grad():
        logits = model(imgs)
    preds.extend(logits.argmax(dim=-1).cpu().numpy().tolist())
res=[num_to_cls[i] for i in preds]
test_data=pd.read_csv(test_path)
test_data['label']=pd.Series(res)
submission=pd.concat([test_data['image'],test_data['label']],axis=1)
submission.to_csv(saveFileName,index=False)

100%|██████████| 550/550 [00:51<00:00, 10.63it/s]
