挂载Google Drive

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


解压文件

In [2]:
!cp -r /content/drive/My\ Drive/Scene/Image_Classification.zip ./  #将google云盘中的数据集压缩文件拷贝到当前运行环境
!unzip Image_Classification.zip  #将数据集压缩文件解压，在当前运行环境得到'train'文件夹、'test'文件夹和'train.csv'文件

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
  inflating: train/55.jpg            
  inflating: train/550.jpg           
  inflating: train/5500.jpg          
  inflating: train/5501.jpg          
  inflating: train/5502.jpg          
  inflating: train/5503.jpg          
  inflating: train/5504.jpg          
  inflating: train/5505.jpg          
  inflating: train/5506.jpg          
  inflating: train/5507.jpg          
  inflating: train/5508.jpg          
  inflating: train/5509.jpg          
  inflating: train/551.jpg           
  inflating: train/5510.jpg          
  inflating: train/5511.jpg          
  inflating: train/5512.jpg          
  inflating: train/5513.jpg          
  inflating: train/5514.jpg          
  inflating: train/5515.jpg          
  inflating: train/5516.jpg          
  inflating: train/5517.jpg          
  inflating: train/5518.jpg          
  inflating: train/5519.jpg          
  inflating: train/552.jpg           
  inflating: train/5520.jpg          
  inflati

创建一个文件夹存放训练好的模型

In [3]:
! mkdir /content/drive/My\ Drive/Scene/checkpoint

mkdir: cannot create directory ‘/content/drive/My Drive/Scene/checkpoint’: File exists


导入所有要用的包（需要一个写一个）

In [4]:
import torch
import pandas as pd
from PIL import Image
from torchvision import transforms, models
from torch.utils.data import random_split, DataLoader
import os
import torch.nn as nn
import time
import torch.optim as optim

查看是否使用GPU

In [5]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'CPU')
print(device)

cuda:0


读取训练集的带标签文件（此处为CSV格式文件）

In [6]:
def readLabelFile():
  label_file = pd.read_csv('train.csv')
  return label_file['filename'],label_file['label']

filename,filelabel = readLabelFile()
map = ['buildings', 'street', 'forest', 'sea', 'mountain', 'glacier']
num_class = len(map)
#将label中的字符串转换为数字
for i in range(len(map)):
  filelabel[filelabel==map[i]] = i
#将对象转换为列表  
filename = filename.values
filelabel = filelabel.values

定义读取数据集的类(包括训练集和测试集)

In [7]:
class TrainDataset(torch.utils.data.Dataset):

  def __init__(self, root, img_list, label_list, transform = None):
    self.root = root
    self.img_list = img_list
    self.label_list = label_list
    self.transform = transform
  
  def __getitem__(self, index):
    img = Image.open(self.root + self.img_list[index]).convert('RGB')
    label = self.label_list[index]
    if self.transform:
      img = self.transform(img)
    return img,label
  
  def __len__(self):
    return len(self.img_list)


class TestDataset(torch.utils.data.Dataset):

  def __init__(self, img_path, transform = None):
    self.img_path = img_path
    self.transform = transform
  
  def __getitem__(self, index):
    img = Image.open(self.img_path[index]).convert('RGB')
    if self.transform:
      img = self.transform(img)
    return img,index

  def __len__(self):
    return len(self.img_path)

对数据集进行预处理

In [8]:
transform = {
    'train': transforms.Compose([
          transforms.Resize((224, 224),interpolation=2),
          transforms.RandomHorizontalFlip(p=0.5),
          transforms.ToTensor(),
          transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ]),
    'val': transforms.Compose([
          
    ])
}

调用读取数据集的类（包括训练集和测试集）

In [9]:
train_dataset = TrainDataset('./train/', filename, filelabel, transform['train'])
tra_dataset, val_dataset = random_split(train_dataset, [10000, 3627])
test_dataset = TestDataset([x.path for x in os.scandir('./test/')], transform['train'])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
tra_loader = DataLoader(tra_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

train_dataset_num = train_dataset.__len__()
tra_dataset_num = tra_dataset.__len__()

初始化预训练模型

In [10]:
def initializeModel(model_name, num_class, finetuning=False, pretrained=True):

  if model_name == 'alexnet':
    model = models.alexnet(pretrained=pretrained)
  elif model_name == 'vgg11':
    model = models.vgg11(pretrained=pretrained)
  elif model_name == 'vgg11_bn':
    model = models.vgg11_bn(pretrained=pretrained)
  elif model_name == 'vgg13':
    model = models.vgg13(pretrained=pretrained)
  elif model_name == 'vgg13_bn':
    model = models.vgg13_bn(pretrained=pretrained)
  elif model_name == 'vgg16':
    model = models.vgg16(pretrained=pretrained)
  elif model_name == 'vgg16_bn':
    model = models.vgg11(pretrained=pretrained)
  elif model_name == 'vgg19':
    model = models.vgg11(pretrained=pretrained)
  elif model_name == 'vgg19_bn':
    model = models.vgg11(pretrained=pretrained)
  elif model_name == 'resnet18':
    model = models.resnet18(pretrained=pretrained)
  elif model_name == 'resnet34':
    model = models.resnet34(pretrained=pretrained)
  elif model_name == 'resnet50':
    model = models.resnet50(pretrained=pretrained)
  elif model_name == 'resnet101':
    model = models.resnet101(pretrained=pretrained)
  elif model_name == 'resnet152':
    model = models.resnet152(pretrained=pretrained)
  elif model_name == 'squeezenet1_0':
    model = models.squeezenet1_0(pretrained=pretrained)
  elif model_name == 'squeezenet1_1':
    model = models.squeezenet1_1(pretrained=pretrained)
  elif model_name == 'densenet121':
    model = models.densenet121(pretrained=pretrained)
  elif model_name == 'densenet169':
    model = models.densenet169(pretrained=pretrained)
  elif model_name == 'densenet161':
    model = models.densenet161(pretrained=pretrained)
  elif model_name == 'densenet201':
    model = models.densenet201(pretrained=pretrained)
  elif model_name == 'inception_v3':
    model = models.inception_v3(pretrained=pretrained)
  elif modle_name == 'googlenet':
    model = models.googlenet(pretrained=pretrained)
  elif model_name == 'shufflenet_v2_x0_5':
    model = models.shufflenet_v2_x0_5(pretrained=pretrained)
  elif model_name == 'shufflenet_v2_x1_0':
    model = models.shufflenet_v2_x1_0(pretrained=pretrained)
  elif model_name == 'shufflenet_v2_x1_5':
    model = models.shufflenet_v2_x1_5(pretrained=pretrained)
  elif model_name == 'shufflenet_v2_x2_0':
    model = models.shufflenet_v2_x2_0(pretrained=pretrained)
  elif model_name == 'mobilenet_v2':
    model = models.mobilenet_v2(pretrained=pretrained)
  elif model_name == 'resnext50_32x4d':
    model = models.resnext50_32x4d(pretrained=pretrained)
  elif model_name == 'resnext101_32x8d':
    model = models.resnext101_32x8d(pretrained=pretrained)
  elif model_name == 'wide_resnet50_2':
    model = models.wide_resnet50_2(pretrained=pretrained)
  elif model_name == 'wide_resnet101_2':
    model = models.wide_resnet101_2(pretrained=pretrained)
  elif model_name == 'mnasnet0_5':
    model = models.mnasnet0_5(pretrained=pretrained)
  elif model_name == 'mnasnet0_75':
    model = models.mnasnet0_75(pretrained=pretrained)
  elif model_name == 'mnasnet1_0':
    model = models.mnasnet1_0(pretrained=pretrained)
  elif model_name == 'mnasnet1_3':
    model = models.mnasnet1_3(pretrained=pretrained)
  else:
    raise ValueError('No such Model %s' % model_name)

  if finetuning:
    for param in model.parameters():
      param.requires_grad = True
  else:
    for param in model.parameters():
      param.requires_grad = False

  fc_features = model.fc.in_features #提取预训练网络模型fc层中固定的参数
  model.fc = nn.Linear(fc_features, num_class) #将预训练网络模型fc层中最终分类的类别数修改为数据集的类别数
  model = model.to(device) #将模型加载到指定设备(GPU)上
  return model

定义训练方法

In [11]:
def traWay(model, criterion, optimizer, epochs):
  begin_time = time.time()
  once_begin_time = begin_time
  for epoch in range(epochs):
    print('Epoch {}/{}'.format(epoch+1, epochs))
    print('-' * 10)

    running_loss = 0.0
    running_corrects = 0.0

    #遍历数据集
    for img, labels in train_loader:
      img = img.to(device)
      labels = labels.to(device)

      optimizer.zero_grad() #将梯度初始化为零
      outputs = model(img) #前向传播求出预测的值
      preds = torch.argmax(outputs, dim=1)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step() #对参数进行更新

      running_loss += loss.item() * img.size(0)
      running_corrects += torch.sum(preds == labels.data)
    
    epoch_loss = running_loss/train_dataset_num
    epoch_acc = running_corrects/train_dataset_num

    print('Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
    print('Training Time per Epoch {}'.format(time.time() - once_begin_time))
    once_begin_time = time.time()
  
  end_time = time.time() - begin_time
  print('Training complete in {:.0f}m {:.0f}s'.format(end_time // 60, end_time % 60))
  return model

训练

In [13]:
model = initializeModel('resnet152', num_class, True)
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

pre_check_name = r'/content/drive/My Drive/Scene/checkpoint/152_state_best1.tar'

if '152_state_best1.tar' in os.listdir(r'/content/drive/My Drive/Scene/checkpoint'):
  print('loading previous state......')
  checkpoint = torch.load(pre_check_name)
  model.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

model = traWay(model, criterion, optimizer, 5)
check_name = r'/content/drive/My Drive/Scene/checkpoint/152_state_best1.tar'

torch.save({
    'model_state_dict':model.state_dict(),
    'optimizer_state_dict':optimizer.state_dict()
},check_name)

Epoch 1/5
----------
Loss: 0.4334 Acc: 0.8580
Training Time per Epoch 327.4913682937622
Epoch 2/5
----------
Loss: 0.2704 Acc: 0.9036
Training Time per Epoch 336.58992862701416
Epoch 3/5
----------
Loss: 0.2581 Acc: 0.9068
Training Time per Epoch 336.44926953315735
Epoch 4/5
----------
Loss: 0.2461 Acc: 0.9118
Training Time per Epoch 336.57881808280945
Epoch 5/5
----------
Loss: 0.2407 Acc: 0.9111
Training Time per Epoch 336.52711725234985
Training complete in 27m 54s


测试

In [14]:
model = initializeModel('resnet152', num_class, False)
check_name = r'/content/drive/My Drive/Scene/checkpoint/152_state_best1.tar'
checkpoint = torch.load(check_name)
model.load_state_dict(checkpoint['model_state_dict'])

with open('./result.txt', mode='w') as result_file:
  for img, index in test_loader:
    img = img.to(device)

    outputs = model(img)
    preds = torch.argmax(outputs, dim=1)

    for i in range(index.shape[0]):
      print(str(index[i].item())+','+str(map[preds[i]]))
      result_file.write(str(index[i].item())+','+str(map[preds[i]])+'\n')


0,forest
1,street
2,street
3,mountain
4,forest
5,street
6,buildings
7,street
8,sea
9,mountain
10,glacier
11,mountain
12,buildings
13,glacier
14,forest
15,street
16,mountain
17,glacier
18,glacier
19,mountain
20,glacier
21,mountain
22,mountain
23,forest
24,forest
25,forest
26,glacier
27,sea
28,buildings
29,mountain
30,sea
31,mountain
32,street
33,sea
34,forest
35,street
36,sea
37,glacier
38,sea
39,glacier
40,glacier
41,mountain
42,forest
43,sea
44,mountain
45,sea
46,mountain
47,glacier
48,street
49,sea
50,glacier
51,buildings
52,sea
53,mountain
54,glacier
55,forest
56,forest
57,sea
58,mountain
59,street
60,street
61,street
62,sea
63,sea
64,forest
65,sea
66,forest
67,mountain
68,forest
69,forest
70,street
71,street
72,sea
73,mountain
74,sea
75,buildings
76,street
77,forest
78,forest
79,forest
80,street
81,forest
82,forest
83,glacier
84,sea
85,mountain
86,glacier
87,street
88,forest
89,street
90,mountain
91,street
92,mountain
93,buildings
94,mountain
95,sea
96,buildings
97,mountain
98,sea
