## config

In [None]:
# Hyperparameter configuration

class Hyperparameter:
    """Single source of truth for paths, model switches and training settings.

    Everything is a plain class attribute; the rest of the notebook reads the
    values through the module-level ``HP`` instance.
    """

    # ---------------------------------------------------------------- data
    device: str = 'cuda'  # NOTE(review): hard-coded; CPU-only hosts will fail
    data_root: str = './data/'
    # presumably the output of an upstream "prepare" notebook -- confirm
    origin_jpg_root: str = '../input/resnet-transferlearning-action-prepare/data'
    cls_mapper_path: str = '../input/resnet-transferlearning-action-prepare/cls_mapper.json'

    # split files written by the preprocess cell: one "cls_id|image_path" per line
    metadata_train_path: str = './data/meta_train.txt'
    metadata_dev_path: str = './data/meta_dev.txt'
    metadata_test_path: str = './data/meta_test.txt'

    class_num: int = 11  # output size of the classifier head in TLNet
    seed: int = 1234  # random seed

    # ------------------------------------------------------ model structure
    if_conv_frozen: bool = False  # True -> freeze all pretrained ResNet weights

    # ----------------------------------------------------------- experiment
    batch_size: int = 128
    init_lr: float = 5e-4  # Adam learning rate
    epochs: int = 30
    verbose_step: int = 20  # evaluate / log dev loss every N steps
    save_step: int = 100  # write a checkpoint every N steps


HP = Hyperparameter()


## utils

In [None]:
import os
from PIL import Image


# Collect every file under a folder (recursively) whose extension is in `suffix`; return a list of paths.
def recursive_fetching(root, suffix=('jpg', 'png')):
    """Return the paths of all files below ``root`` whose extension is in ``suffix``.

    Args:
        root: directory to walk recursively.
        suffix: iterable of extensions, with or without the leading dot; a
            single string is also accepted. Matching is case-insensitive.
            (Tuple default avoids the shared-mutable-default pitfall.)

    Returns:
        list of file paths built with ``os.path.join`` onto ``root``, in
        directory-traversal order.
    """
    if isinstance(suffix, str):
        suffix = (suffix,)
    # Normalise once so e.g. 'JPG' or '.jpg' passed by a caller also match.
    wanted = {s.lower().lstrip('.') for s in suffix}

    all_file_path = []

    def get_all_files(path):
        for name in os.listdir(path):
            filepath = os.path.join(path, name)
            if os.path.isdir(filepath):
                get_all_files(filepath)  # recurse into sub-folders
            elif os.path.isfile(filepath):
                all_file_path.append(filepath)

    get_all_files(root)

    # splitext yields '' for extension-less names, so a file literally named
    # 'jpg' is no longer mistaken for a .jpg file.
    return [p for p in all_file_path
            if os.path.splitext(p)[1][1:].lower() in wanted]


def load_meta(meta_path):
    """Read a ``cls_id|path`` metadata file into a list of string-field rows.

    Each line is stripped and split on ``'|'``. Blank lines (e.g. a trailing
    empty line added by an editor) are skipped rather than returned as
    ``['']`` rows, which would otherwise fail ``int()`` later in the dataset.
    """
    with open(meta_path, 'r') as fr:
        stripped_lines = (line.strip() for line in fr)
        return [text.split('|') for text in stripped_lines if text]


def load_image(image_path):
    """Open ``image_path`` and return it as a fully loaded RGB image.

    ``Image.open`` is lazy and keeps the underlying file open; ``convert``
    forces the pixel data to be read, and the ``with`` block then closes the
    file. Converting to RGB guarantees 3 channels, so grayscale, palette or
    RGBA files work with the 3-value ``Normalize`` in ``ac_transform``.
    """
    with Image.open(image_path) as img:
        return img.convert('RGB')


## preprocess

In [None]:
import json
import os
import random

random.seed(HP.seed)

# Mapping loaded from HP.cls_mapper_path; its 'cls2id' entry maps name tokens to ids.
with open(HP.cls_mapper_path, 'r') as fr:
    cls_mapper = json.load(fr)

# Output folders used by later cells (metadata, tensorboard logs, checkpoints).
for foldername in ['data', 'log', 'model_save']:
    os.makedirs(foldername, exist_ok=True)

# os.listdir order is filesystem-dependent; sort first so the seeded shuffle
# produces the same split on every machine.
data_list = sorted(os.listdir(HP.origin_jpg_root))
random.shuffle(data_list)

train_ratio, dev_ratio, test_ratio = 0.8, 0.1, 0.1
avi_list_len = len(data_list)
train_len = int(avi_list_len * train_ratio)
dev_len = int(avi_list_len * dev_ratio)
train_set = data_list[:train_len]
dev_set = data_list[train_len:train_len + dev_len]
test_set = data_list[train_len + dev_len:]  # remainder, absorbs int() rounding


def _write_split_meta(meta_path, names, mapper):
    """Write one ``cls_id|full_image_path`` line per entry of ``names``."""
    with open(meta_path, 'w') as fw:
        for name in names:
            # the 2nd '_'-separated token of the file name is the cls2id key
            fn_start = os.path.split(name)[-1].split('_')[1]
            cls_id = mapper['cls2id'][fn_start]
            fw.write('%d|%s\n' % (cls_id, os.path.join(HP.origin_jpg_root, name)))


for _meta_path, _names in ((HP.metadata_train_path, train_set),
                           (HP.metadata_dev_path, dev_set),
                           (HP.metadata_test_path, test_set)):
    _write_split_meta(_meta_path, _names, cls_mapper)


## dataset_action

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms as T

# Preprocessing + augmentation pipeline applied to every loaded image.
# NOTE(review): ActionDataset uses this same pipeline for the dev and test
# splits too, so evaluation images are also randomly rotated/blurred/flipped.
_ac_steps = [
    T.Resize((112, 112)),                # same input shape for every image
    T.RandomRotation(degrees=45),        # random tilt, for robustness to tilted images
    T.GaussianBlur(kernel_size=(3, 3)),  # blur, for robustness to blurry images
    T.RandomHorizontalFlip(),            # random left-right flip
    T.ToTensor(),                        # PIL image -> float32 tensor scaled to [0, 1]
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),  # ImageNet channel mean/std
]
ac_transform = T.Compose(_ac_steps)


class ActionDataset(torch.utils.data.Dataset):
    """Map-style dataset over a ``cls_id|image_path`` metadata file.

    Each item is ``(ac_transform(image) moved to HP.device, int class id)``.
    """

    def __init__(self, metadata_path):
        # rows of string fields as returned by load_meta: [cls_id, image_path, ...]
        self.dataset = load_meta(metadata_path)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]
        label = int(row[0])
        image_path = row[1]
        tensor = ac_transform(load_image(image_path))
        # NOTE(review): moving to HP.device here instead of in the training loop
        # means DataLoader worker processes (num_workers > 0) would touch CUDA.
        return tensor.to(HP.device), label


## model

In [None]:
from torch import nn
import torchvision


class TLNet(nn.Module):
    """Transfer-learning classifier: a pretrained torchvision ResNet-34 whose
    final fully-connected layer is replaced by an ``HP.class_num``-way head."""

    def __init__(self):
        super(TLNet, self).__init__()
        backbone = torchvision.models.resnet34(pretrained=True)
        if HP.if_conv_frozen:
            # Freeze every pretrained weight; the new head created below
            # is a fresh nn.Linear and therefore stays trainable.
            for param in backbone.parameters():
                param.requires_grad = False
        backbone.fc = nn.Linear(backbone.fc.in_features, HP.class_num)
        self.model = backbone

    def forward(self, x):
        return self.model(x)

## trainer

In [None]:
import os.path
import random
import torch
import numpy as np
from tensorboardX import SummaryWriter
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

# TensorBoard writer used by train(); event files are written under ./log
logger = SummaryWriter('./log')

# Seed every RNG the pipeline touches so runs are repeatable.
# NOTE(review): cuDNN nondeterminism is not disabled here, so GPU runs may
# still differ slightly from run to run.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed, torch.cuda.manual_seed):
    _seed_fn(HP.seed)


def evaluate(model, devloader, crit):
    """Return the per-sample mean loss of ``model`` over ``devloader``.

    Args:
        model: network to evaluate. It is put in eval mode for the pass and
            restored to whatever train/eval mode it had before.
        devloader: iterable of ``(x, y)`` batches.
        crit: loss function that returns the *mean* loss of a batch
            (e.g. the default ``nn.CrossEntropyLoss``).

    Returns:
        float: the loss averaged over samples. The average is weighted by
        batch size, so a smaller final batch (the dev loader uses
        ``drop_last=False``) is not over-weighted. Returns ``nan`` when the
        loader yields no samples, instead of dividing by zero.
    """
    # Send data to wherever the model lives; fall back to HP.device for a
    # parameter-free model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else HP.device

    was_training = model.training
    model.eval()
    total_loss, total_count = 0., 0
    try:
        with torch.no_grad():
            for x, y in devloader:
                x, y = x.to(device), y.to(device)
                loss = crit(model(x), y)
                batch_n = y.size(0)
                total_loss += loss.item() * batch_n  # undo the batch mean
                total_count += batch_n
    finally:
        model.train(was_training)
    return total_loss / total_count if total_count else float('nan')


def save_checkpoint(model, epoch, opt, save_path):
    """Serialise training state to ``save_path`` via ``torch.save``.

    The saved dict has the keys ``'epoch'``, ``'model_state_dict'`` and
    ``'optimizer_state_dict'`` (in that order).
    """
    checkpoint = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=opt.state_dict(),
    )
    torch.save(checkpoint, save_path)


def train():
    """Fine-tune TLNet on the train split with periodic evaluation and checkpoints.

    Every ``HP.verbose_step`` steps the dev loss is computed, logged to
    TensorBoard and printed together with the current train loss. Every
    ``HP.save_step`` steps a checkpoint is written to ``model_save/``. One more
    checkpoint is written after the last step if that step was not already
    saved, so the final weights are never lost. The TensorBoard logger is
    closed even if training raises.
    """
    model = TLNet().to(HP.device)
    criterion = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=HP.init_lr)

    trainset = ActionDataset(HP.metadata_train_path)
    train_loader = DataLoader(trainset, batch_size=HP.batch_size, shuffle=True, drop_last=True)

    devset = ActionDataset(HP.metadata_dev_path)
    # Evaluation does not depend on sample order, so no shuffling is needed.
    dev_loader = DataLoader(devset, batch_size=HP.batch_size, shuffle=False, drop_last=False)

    start_epoch, step = 0, 0
    model.train()

    try:
        for epoch in range(start_epoch, HP.epochs):
            print('Start Epoch: %d, Steps: %d' % (epoch, len(train_loader)))
            for x, y in train_loader:
                # no-op when the dataset already produced tensors on HP.device
                x, y = x.to(HP.device), y.to(HP.device)
                opt.zero_grad()
                loss = criterion(model(x), y)
                loss.backward()
                opt.step()

                # log a plain float rather than a graph-attached tensor
                train_loss = loss.item()
                logger.add_scalar('Loss/Train', train_loss, step)

                if not step % HP.verbose_step:
                    eval_loss = evaluate(model, dev_loader, criterion)
                    logger.add_scalar('Loss/Dev', eval_loss, step)
                    logger.flush()
                    # printed only here, so the dev loss shown is never stale
                    print('Epoch:[%d/%d], step:%d, Train Loss:%.5f, Dev Loss:%.5f' % (
                        epoch, HP.epochs, step, train_loss, eval_loss))

                if not step % HP.save_step:
                    model_path = 'model_%d_%d.model' % (epoch, step)
                    save_checkpoint(model, epoch, opt, os.path.join('model_save', model_path))

                step += 1

        # Keep the final weights when the last step index was not a save step.
        last_step = step - 1
        if step and last_step % HP.save_step:
            model_path = 'model_%d_%d.model' % (epoch, last_step)
            save_checkpoint(model, epoch, opt, os.path.join('model_save', model_path))
    finally:
        logger.close()


## train (训练)

In [None]:
# Run the training loop defined in the trainer cell, using the settings in HP.
train()