In [1]:
import torch
import torch.nn as nn
from torchvision import transforms,datasets
import json
import matplotlib.pyplot as plt
import os
import matplotlib.pyplot as plt
import torch.optim as optim
import import_ipynb
from model import resnet34,resnet50,resnet101

importing Jupyter notebook from model.ipynb


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device".format(device))

using cuda:0 device


In [3]:
data_transform = {
    "train":transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    "val":transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

In [4]:
data_root = os.path.abspath(os.path.join(os.getcwd(),"../../"))
image_path = os.path.join(data_root,"datasets","flower_data")
assert os.path.exists(image_path),"{} path does not exist".format(image_path)

In [5]:
train_dataset = datasets.ImageFolder(root=os.path.join(image_path,'train'),
                                    transform=data_transform['train'])
train_num = len(train_dataset)
train_num

3306

In [6]:
# {'daisy':0,'dandelion':1,'roses':2,'sunflower':3,'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val,key) for key,val in flower_list.items())

In [7]:
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json','w') as json_file:
    json_file.write(json_str)

In [8]:
batch_size = 16
nw = min([os.cpu_count(), batch_size if batch_size>1 else 0, 8]) # number of workers
print('using {} dataloader workers every process'.format(nw))

using 8 dataloader workers every process


In [9]:
train_loader = torch.utils.data.DataLoader(train_dataset,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=nw)

In [10]:
validate_dataset = datasets.ImageFolder(root=os.path.join(image_path,'val'),
                                       transform=data_transform['val'])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=nw)
print('using {} images for training,{} images for validation'.format(train_num,val_num))

using 3306 images for training,364 images for validation


In [11]:
net = resnet50()
# load pretrain weights
# model_weight_path = './resnet34-333f7ec4.pth'
model_weight_path = './resnet50-19c8e357.pth'
# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth

assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)

In [12]:
in_channel = net.fc.in_features
net.fc = nn.Linear(in_channel,5)
net.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [13]:
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(),lr=0.0001)

best_acc = 0.0
save_path = './resNet50.pth'
for epoch in range(3):
    # train
    net.train()
    running_loss = 0.0
    for step,data in enumerate(train_loader,start=0):
        images, labels = data
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits,labels.to(device))
        loss.backward()
        optimizer.step()
        
        #print statistics
        running_loss += loss.item()
        #print train process
        rate = (step+1)/len(train_loader)
        a = "*"*int(rate*50)
        b = "."*int((1-rate)*50)
        print("\train loss: {:^3.0f}% [{}->{}]{:.4f}".format(int(rate*100),a,b,loss),end="")
    print()
    
    #validate
    net.eval()
    acc = 0.0 # accumulate accurate number/epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device)) #eval model only have last output layer
            # loss = loss_function(outputs,test_labels)
            predict_y = torch.max(outputs,dim=1)[1]
            acc+=(predict_y==val_labels.to(device)).sum().item()
            
        val_accurate = acc/val_num
        if val_accurate>best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(),save_path)
            
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
                (epoch + 1, running_loss / step, val_accurate))
    print('Finished Training')

	rain loss:  0 % [->.................................................]1.6003	rain loss:  0 % [->.................................................]1.4258	rain loss:  1 % [->.................................................]1.4029	rain loss:  1 % [->.................................................]1.2501	rain loss:  2 % [*->................................................]1.0918	rain loss:  2 % [*->................................................]0.8996	rain loss:  3 % [*->................................................]1.0500	rain loss:  3 % [*->................................................]0.9289	rain loss:  4 % [**->...............................................]0.7069	rain loss:  4 % [**->...............................................]0.8175	rain loss:  5 % [**->...............................................]0.6449	rain loss:  5 % [**->...............................................]0.8500	rain loss:  6 % [***->..............................................]1.1341	rain loss: 