In [1]:
import numpy as np
import os

# torch
import torch
import torch.nn as nn
import torch.optim as optim
from collections import OrderedDict

# submodules
from net.st_gcn import ST_GCN_Model
from feeder.feeder import Feeder

In [2]:
class ST_GCN_Processor:
    """Training / evaluation driver for an ``ST_GCN_Model``.

    Construction builds the model (restoring the weights stored at
    ``model_path`` when that file exists), an SGD optimizer with Nesterov
    momentum, and the train/test ``DataLoader`` objects.  ``train()`` runs
    ``epoch`` epochs and writes a checkpoint after each one; ``test()`` prints
    the mean loss and the top-1 .. top-``topk`` accuracy on the test loader.

    Args:
        train_data_path: Forwarded to ``Feeder`` as ``data_path`` (train set).
        train_label_path: Forwarded to ``Feeder`` as ``label_path`` (train set).
        test_data_path: Forwarded to ``Feeder`` as ``data_path`` (test set).
        test_label_path: Forwarded to ``Feeder`` as ``label_path`` (test set).
        in_channels: Forwarded to ``ST_GCN_Model``.
        num_class: Forwarded to ``ST_GCN_Model``.
        dropout: Forwarded to ``ST_GCN_Model``.
        edge_importance_weighting: Forwarded to ``ST_GCN_Model``.
        graph_args: Forwarded to ``ST_GCN_Model`` (a shallow copy is passed).
        step: Epoch numbers at which the learning rate is multiplied by 0.1.
            A falsy value (``None`` / ``[]``) disables the schedule.
        base_lr: Initial learning rate of the SGD optimizer.
        batch_size: Batch size of the training loader.
        test_batch_size: Batch size of the test loader.
        epoch: Number of epochs run by ``train()``.
        topk: ``test()`` reports top-1 up to top-``topk`` accuracy.
        debug: Forwarded to ``Feeder``.
        model_path: Checkpoint file.  Loaded at construction when it exists
            and overwritten after every epoch of ``train()``.  ``None``
            disables both loading and saving.
    """

    def __init__(
        self,
        train_data_path,
        train_label_path,
        test_data_path,
        test_label_path,
        in_channels=3,
        num_class=60,
        dropout=0.5,
        edge_importance_weighting=True,
        graph_args={"layout": "ntu-rgb+d", "strategy": "spatial"},
        step=[10, 50],
        base_lr=0.01,
        batch_size=64,
        test_batch_size=64,
        epoch=80,
        topk=5,
        debug=False,
        model_path="epoch_model.pt",
    ):
        # The list/dict defaults in the signature are shared by every call, so
        # store/forward shallow copies instead of aliasing them.
        self.step = list(step) if isinstance(step, (list, tuple)) else step
        if isinstance(graph_args, dict):
            graph_args = dict(graph_args)

        self.base_lr = base_lr
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size
        self.debug = debug
        self.epoch = epoch
        self.meta_info = dict(epoch=0, iter=1)
        self.epoch_info = dict()
        # Replaced by a numpy array of model outputs once a test pass has run.
        self.result = dict()
        self.model_path = model_path
        self.topk = topk
        self.dev = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        self.__load_model(
            in_channels,
            num_class,
            dropout,
            edge_importance_weighting,
            graph_args,
        )
        # Created after the model has been moved to ``self.dev`` so that the
        # optimizer tracks the on-device parameters.
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=self.base_lr,
            momentum=0.9,
            nesterov=True,
            weight_decay=0.0001,
        )

        self.__load_data(
            train_data_path,
            train_label_path,
            test_data_path,
            test_label_path,
        )

    def __weights_init(self, m):
        """Initialise one sub-module in place (used with ``Module.apply``).

        Modules whose class name contains ``Conv1d`` or ``Conv2d`` get
        N(0, 0.02) weights and a zero bias; modules whose class name contains
        ``BatchNorm`` get N(1, 0.02) weights and a zero bias.  Parameters that
        are ``None`` (e.g. ``bias=False`` or ``affine=False``) are skipped.
        """
        classname = m.__class__.__name__
        if "Conv1d" in classname or "Conv2d" in classname:
            weight_mean = 0.0
        elif "BatchNorm" in classname:
            weight_mean = 1.0
        else:
            return

        weight = getattr(m, "weight", None)
        bias = getattr(m, "bias", None)
        if weight is not None:
            nn.init.normal_(weight, weight_mean, 0.02)
        if bias is not None:
            nn.init.zeros_(bias)

    def __load_model(
        self,
        in_channels,
        num_class,
        dropout,
        edge_importance_weighting,
        graph_args,
    ):
        """Build the model and loss, move the model to ``self.dev`` and, when
        ``self.model_path`` points at an existing file, load its weights."""
        self.model = ST_GCN_Model(
            in_channels=in_channels,
            num_class=num_class,
            dropout=dropout,
            edge_importance_weighting=edge_importance_weighting,
            graph_args=graph_args,
        )
        self.model.apply(self.__weights_init)
        # Every batch is moved to ``self.dev`` in the train/test loops, so the
        # parameters have to live on the same device.
        self.model.to(self.dev)
        self.loss = nn.CrossEntropyLoss()

        # If a checkpoint already exists at model_path, load it.
        if self.model_path is not None and os.path.exists(self.model_path):
            self.model.load_state_dict(
                torch.load(self.model_path, map_location=self.dev, weights_only=True)
            )
            print(f"The model has been loaded from {self.model_path}.")

    def __adjust_lr(self):
        """Apply the step schedule for the current ``meta_info["epoch"]``.

        The learning rate is ``base_lr * 0.1 ** n`` where ``n`` is the number
        of entries of ``self.step`` that the current epoch has reached.  The
        value is written to every optimizer param group and to ``self.lr``.
        With a falsy ``self.step`` only ``self.lr`` is set (to ``base_lr``).
        """
        if self.step:
            lr = self.base_lr * (
                0.1 ** np.sum(self.meta_info["epoch"] >= np.array(self.step))
            )
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = lr
            self.lr = lr
        else:
            self.lr = self.base_lr

    def __load_data(
        self,
        train_data_path,
        train_label_path,
        test_data_path,
        test_label_path,
        num_workers=11,
    ):
        """Create ``self.data_loader["train"]`` and ``self.data_loader["test"]``.

        The training loader shuffles and drops the last incomplete batch; the
        test loader keeps the dataset order and every sample.
        """
        self.data_loader = dict()
        self.data_loader["train"] = torch.utils.data.DataLoader(
            dataset=Feeder(
                data_path=train_data_path,
                label_path=train_label_path,
                debug=self.debug,
            ),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=num_workers,
            drop_last=True,
        )
        self.data_loader["test"] = torch.utils.data.DataLoader(
            dataset=Feeder(
                data_path=test_data_path,
                label_path=test_label_path,
                debug=self.debug,
            ),
            batch_size=self.test_batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

    def __save_model(self, model):
        """Write ``model``'s state dict (as CPU tensors) to ``self.model_path``.

        A leading ``"module."`` key prefix is stripped.  Missing parent
        directories are created.  Nothing is written when ``self.model_path``
        is ``None``.
        """
        if self.model_path is None:
            print("model_path is None; the model has not been saved.")
            return

        prefix = "module."
        # Only a leading prefix is removed; keys that merely contain
        # "module." (e.g. "submodule.weight") must stay intact or the
        # checkpoint could not be loaded back into the same model.
        weights = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v.cpu())
            for k, v in model.state_dict().items()
        )

        save_dir = os.path.dirname(self.model_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(weights, self.model_path)
        print(f"The model has been saved as {self.model_path}.")

    def __show_topk(self, k):
        """Print the top-``k`` accuracy of ``self.result`` against ``self.label``.

        A sample counts as a hit when its label is among the ``k`` highest
        scoring columns of its row in ``self.result``.
        """
        rank = self.result.argsort()
        labels = np.asarray(self.label).reshape(-1, 1)
        hit_top_k = (rank[:, -k:] == labels).any(axis=1)
        accuracy = hit_top_k.sum() * 1.0 / len(hit_top_k)
        print(f"\tTop{k}: {100 * accuracy:.2f}%")

    def __train_once(self):
        """Run one pass over the training loader and update the model.

        Stores the mean batch loss in ``self.epoch_info["mean_loss"]`` and
        leaves the number of processed batches in ``self.meta_info["iter"]``.
        """
        self.model.train()
        self.__adjust_lr()
        loader = self.data_loader["train"]
        loss_value = []

        if self.meta_info["epoch"] == 0:
            print("loader length: ", len(loader))

        self.meta_info["iter"] = 0
        for data, label in loader:
            print(f"iter: {self.meta_info['iter']}")
            # get data
            data = data.float().to(self.dev)
            label = label.long().to(self.dev)

            # forward
            output = self.model(data)
            loss = self.loss(output, label)

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # statistics
            loss_value.append(loss.item())
            self.meta_info["iter"] += 1

        if loss_value:
            self.epoch_info["mean_loss"] = np.mean(loss_value)
        else:
            # The training loader uses drop_last=True, so a dataset smaller
            # than batch_size yields no batch at all.
            print("The training loader produced no batches; nothing was trained.")
            self.epoch_info["mean_loss"] = float("nan")

    def __test(self, evaluation=True):
        """Run the model over the test loader and store its outputs.

        The concatenated outputs are stored in ``self.result``.  With
        ``evaluation=True`` the labels are stored in ``self.label``, the mean
        loss in ``self.epoch_info["mean_loss"]``, and the epoch info plus the
        top-1 .. top-``self.topk`` accuracies are printed.  With
        ``evaluation=False`` labels are ignored and nothing is printed.

        Raises:
            ValueError: If the test loader yields no batch.
        """
        self.model.eval()
        loader = self.data_loader["test"]
        loss_value = []
        result_frag = []
        label_frag = []

        with torch.no_grad():
            for data, label in loader:
                # inference
                data = data.float().to(self.dev)
                output = self.model(data)
                result_frag.append(output.cpu().numpy())

                # get loss
                if evaluation:
                    label = label.long().to(self.dev)
                    loss_value.append(self.loss(output, label).item())
                    label_frag.append(label.cpu().numpy())

        if not result_frag:
            raise ValueError("The test data loader produced no batches.")

        self.result = np.concatenate(result_frag)
        if not evaluation:
            # Accuracy needs labels; without them only the raw outputs are kept.
            return

        self.label = np.concatenate(label_frag)
        self.epoch_info["mean_loss"] = np.mean(loss_value)

        for k, v in self.epoch_info.items():
            print(f"\t{k}: {v}")

        for k in range(1, self.topk + 1):
            self.__show_topk(k)

    def train_once(self):
        """Run a single training pass without saving a checkpoint."""
        print("Start training")
        self.__train_once()
        print("End of training")

    def train(self):
        """Train for ``self.epoch`` epochs, saving the model after each one."""
        for epoch in range(1, self.epoch + 1):
            self.meta_info["epoch"] = epoch
            print(f"Epoch: {self.meta_info['epoch']}")
            self.__train_once()

            self.__save_model(self.model)

    def test(self):
        """Evaluate on the test loader and print loss and top-k accuracy."""
        print("Start testing")
        self.__test()
        print("End of testing")

In [3]:
# Build the processor from the xsub train_* / val_* files; the checkpoint at
# model_path is loaded by the constructor when that file already exists.
p = ST_GCN_Processor(
    train_data_path="./data/xsub/train_data.npy",
    train_label_path="./data/xsub/train_label.pkl",
    test_data_path="./data/xsub/val_data.npy",
    test_label_path="./data/xsub/val_label.pkl",
    model_path="./models/epoch_model.pt",
    epoch=10,
    topk=3,
)

# Train for all epochs (a checkpoint is written after each one), then report
# the loss and top-k accuracy on the val_* files.
p.train()
p.test()

The model has been loaded from ./models/epoch_model.pt.
Epoch: 1
iter: 0
The model has been saved as ./models/epoch_model.pt.
Epoch: 2
iter: 0
The model has been saved as ./models/epoch_model.pt.
Epoch: 3
iter: 0
The model has been saved as ./models/epoch_model.pt.
Epoch: 4
iter: 0
The model has been saved as ./models/epoch_model.pt.
Epoch: 5
iter: 0
The model has been saved as ./models/epoch_model.pt.
Epoch: 6
iter: 0
The model has been saved as ./models/epoch_model.pt.
Epoch: 7
iter: 0
The model has been saved as ./models/epoch_model.pt.
Epoch: 8
iter: 0
The model has been saved as ./models/epoch_model.pt.
Epoch: 9
iter: 0
The model has been saved as ./models/epoch_model.pt.
Epoch: 10
iter: 0
The model has been saved as ./models/epoch_model.pt.
Start testing
	mean_loss: 1.493636965751648
	Top1: 33.33%
	Top2: 66.67%
	Top3: 100.00%
End of testing
