In [1]:
import os
import time
import numpy as np
import pandas as pd
import random
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.hub import load_state_dict_from_url
from typing import Union, List, Dict, Any, cast

import matplotlib.pyplot as plt
from PIL import Image

import json
from tqdm import tqdm

import sys

sys.path.insert(0, "../..") # to include ../helper_evaluate.py etc.
from helper_utils import set_all_seeds, set_deterministic

if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

### parameters

In [2]:
random_seed = 47
learning_rate = 0.0001
batch_size = 8
epochs = 10

num_classes = 5

DEVICE = "cuda"

save_path = "vgg16_flower.pth"

### setting

In [3]:
set_all_seeds(random_seed)

set_deterministic()

### data

In [4]:
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform = transforms.Compose([transforms.RandomResizedCrop((224, 224)),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

image_path = "D:/work/data/Python/flower_data/"
assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"), transform=train_transform)
test_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"), transform=test_transform)

flower_list = train_dataset.class_to_idx
class_dict = dict((val, key) for key, val in flower_list.items())

# dump dict too json file
json_str = json.dumps(class_dict, indent=4)
with open("class_indices.json", "w") as json_file:
    json_file.write(json_str)
#     json.dump(class_dict, json_file, ensure_ascii=False)

num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 0])
print('Using {} dataloader workers every process'.format(num_workers))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

train_num = len(train_dataset)
test_num = len(test_dataset)

print("train images num: ", train_num)
print("test images num: ", test_num)

for images, labels in train_loader:
    print("images shape: ", images.size())
    print("labels shape: ", labels.size())
    break

Using 0 dataloader workers every process
train images num:  3306
test images num:  364
images shape:  torch.Size([8, 3, 224, 224])
labels shape:  torch.Size([8])


### model

In [5]:
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}

cfgs: Dict[str, List[Union[str, int]]] = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

class VGG(nn.Module):
    def __init__(self, features=nn.Module, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 *7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes)
        )
        
        if init_weights:
            self._initialize_weights()
    
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            v = cast(int, v)
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, stride=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model

def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 16-layer model
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg16', 'vgg16', False, pretrained, progress, **kwargs)


def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 16-layer model with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg16_bn', 'vgg16', True, pretrained, progress, **kwargs)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = vgg16(num_classes=num_classes, init_weights=True)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

cuda


In [7]:
image_mean = torch.tensor([0.485, 0.456, 0.406])

image_std = torch.tensor([0.229, 0.224, 0.225])

def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = (img * image_std + image_mean)*255     # unnormalize
    npimg = img.numpy().astype('uint8')
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

### TensorBoard

#### TensorBoard setup

In [8]:
from torch.utils.tensorboard import SummaryWriter

if not os.path.exists('runs'):
    os.mkdir('runs')

# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter('runs/fashion_mnist_experiment_1')

# 将模型写入tensorboard
init_img = torch.zeros((1, 3, 224, 224), device=device)
writer.add_graph(model, init_img)

In [9]:
from data_utils import generate_label_file, plot_class_preds
from train_evaluate_utils import train_one_epoch, evaluate_accuracy, evaluate_one
import torch.optim.lr_scheduler as lr_scheduler

weights_path = "../vgg16_flower.pth"
if os.path.exists(weights_path):
#     weights_dict = torch.load(weights_path, map_location=device)
#     load_weights_dict = {k: v for k, v in weights_dict.items()
#                          if model.state_dict()[k].numel() == v.numel()}
#     model.load_state_dict(load_weights_dict, strict=False)

    model.load_state_dict(torch.load(weights_path, map_location=device))
else:
    print("not using pretrain-weights.")
    
# 是否冻结权重
freeze_layers = []
if freeze_layers:
    print("freeze layers except fc layer.")
    for name, para in model.named_parameters():
        # 除最后的全连接层外，其他权重全部冻结
        if "fc" not in name:
            para.requires_grad_(False)

lrf = 0.1
pg = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(pg, lr=learning_rate, momentum=0.9, weight_decay=0.005)
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf  # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

In [10]:
print(model)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
    (18): ReLU(inplace=True)
  

In [11]:
plot_images_dir = "D:/work/data/Python/flower_data/flower_photos/"
for epoch in range(epochs):
    # train
    mean_loss = train_one_epoch(model=model,
                                loss_fn=F.cross_entropy,
                                optimizer=optimizer,
                                data_loader=train_loader,
                                device=device,
                                epoch=epoch)
    # update learning rate
    scheduler.step()

    # validate
    acc = evaluate_accuracy(model=model,
                            data_loader=test_loader,
                            device=device)

    # add loss, acc and lr into tensorboard
    print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
    tags = ["train_loss", "accuracy", "learning_rate"]
    writer.add_scalar(tags[0], mean_loss, epoch)
    writer.add_scalar(tags[1], acc, epoch)
    writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

    # add figure into tensorboard
    fig = plot_class_preds(net=model,
                           images_dir=plot_images_dir,
                           transform=test_transform,
                           num_plot=5,
                           device=device)
    if fig is not None:
        writer.add_figure("predictions vs. actuals",
                             figure=fig,
                             global_step=epoch)

    # add conv1 weights into tensorboard
    writer.add_histogram(tag="conv1",
                            values=model.features[0].weight,
                            global_step=epoch)

    if (epoch+1) % (epochs // 2) == 0:
        # save weights
        torch.save(model.state_dict(), "./vgg16_flower_model-{}.pth".format(epoch))

[epoch 0] mean loss 1.485:   9%|████▊                                                 | 37/414 [00:24<04:09,  1.51it/s]


KeyboardInterrupt: 

In [None]:
fig = plot_class_preds(net=model,
                       images_dir="D:/work/data/Python/flower_data/flower_photos/",
                       transform=test_transform,
                       num_plot=5,
                       device=device)