In [1]:
import csv
import datetime
import os
import pickle
import re
import sys

import numpy as np
import pandas as pd
import torch.optim
import torch.utils.data
import yaml
from easydict import EasyDict
from tqdm import tqdm

sys.path.append(".")
sys.path.append("..")
sys.path.append("../src")
from configs.config import CONF  # noqa: E402
from dataset.all import CustomDataset, DetectionDataset  # noqa: E402
from model.all import (  # noqa: E402
    CNNModel,
    CustomFasterRCNN1,
    CustomFasterRCNN2,
    CustomResNetModel,
    CustomViTModel,
    CustomMobileNetV2Model
)

from utils import seed_worker, fix_seed  # noqa: E402

import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from typing import List, Tuple

import matplotlib.patches as patches
import numpy as np
from PIL import Image

# Ref) https://cocodrips.hateblo.jp/entry/2020/05/04/210156

def add_bboxes_to_image(ax, image: np.ndarray,
                        bboxes: List[Tuple[int, int, int, int]],
                        labels: List[str] = None,
                        label_size: int = 10,
                        line_width: int = 2,
                        border_color=(0, 1, 0, 1)):
    """
    Draw an image on ``ax`` and overlay bounding boxes with optional labels.

    :param ax: axes-like object; ``imshow``, ``add_patch`` and ``text`` are
        called on it.
    :param image: image passed straight to ``ax.imshow`` with a gray colormap.
    :param bboxes: sequence of boxes. Each box is unpacked as
        ``(top, left, bottom, right)`` and drawn with ``(left, top)`` as the
        rectangle origin.
        NOTE(review): the previous docstring said (left, top, right, bottom);
        the unpacking order of the code is kept here - confirm it matches the
        box format of the data being plotted.
    :param labels: one label per box, or None for no labels. Entries that are
        ``None`` or an empty string are not drawn. Any other value - including
        the integer ``0`` - is rendered with ``str()``. If fewer labels than
        boxes are given, the remaining boxes are drawn without a label.
    :param label_size: font size of the label text.
    :param line_width: line width of the rectangle and the label frame.
    :param border_color: color used for the rectangle edge and label background.
    :return: ``ax``, to allow chaining.
    """
    # Display the image
    ax.imshow(image, cmap='gray')

    n_boxes = len(bboxes)
    if labels is None:
        labels = [None] * n_boxes
    else:
        # Pad so that zip() below never silently drops a box.
        labels = list(labels)
        labels += [None] * (n_boxes - len(labels))

    for bbox, label in zip(bboxes, labels):
        # Add bounding box
        top, left, bottom, right = bbox
        rect = patches.Rectangle((left, top), right - left, bottom - top,
                                 linewidth=line_width,
                                 edgecolor=border_color,
                                 facecolor='none')
        ax.add_patch(rect)

        # Skip only "no label" markers; a plain truthiness test would also
        # hide a valid class id of 0.
        if label is None or (isinstance(label, str) and not label):
            continue

        bbox_props = dict(boxstyle="square,pad=0",
                          linewidth=line_width, facecolor=border_color,
                          edgecolor=border_color)
        ax.text(left, top, str(label),
                ha="left", va="bottom", rotation=0,
                size=label_size, bbox=bbox_props)
    return ax

def show(X_data, y_data, data_dicts_df, n_rows=3, n_cols=10):
    """
    Plot the first samples in a grid with their bounding boxes and labels.

    Sample ``k`` is drawn in row ``k // n_cols``, column ``k % n_cols``. When
    fewer than ``n_rows * n_cols`` samples are available the remaining cells
    are left blank instead of raising an IndexError.

    :param X_data: indexable collection of images.
    :param y_data: indexable collection of per-image labels (shown in the title).
    :param data_dicts_df: DataFrame read positionally with ``.iloc``; the
        columns ``boxes``, ``labels`` and ``image_id`` are used.
    :param n_rows: number of grid rows (default 3).
    :param n_cols: number of grid columns (default 10).
    """
    n_items = min(len(X_data), len(y_data), len(data_dicts_df))
    # 1.5 inches per cell reproduces the original (15, 4.5) figure for 3x10.
    # squeeze=False keeps `axes` two-dimensional even for a single row/column.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(1.5 * n_cols, 1.5 * n_rows), squeeze=False
    )
    for index, ax in enumerate(axes.flat):
        if index < n_items:
            data_dict = data_dicts_df.iloc[index]
            add_bboxes_to_image(ax, X_data[index], data_dict["boxes"], data_dict["labels"])
            ax.set_title(f"ID: {data_dict['image_id']}, Label: {y_data[index]}", fontsize=6)
        ax.axis('off')
    plt.show()

In [3]:
# Run-time options for this notebook. They are merged over the YAML config
# with `cfg.update(args)` further down, so values given here take precedence.
args = EasyDict(
    dict(
        hidden=512,
        dropout=0.5,
        no_decay=['bias', 'LayerNorm.weight'],
        weight_decay=0.01,
        eps='1e-10',  # kept as a string; converted with float() by the Trainer
        lr=0.002,
        batch_size=8,
        epochs=5,
        data_dir='all+comp',
        detection=True,
        # alternative: model='fasterrcnn1'
        model='fasterrcnn2',
        model_layer=18,
        pretrained=True,
        # alternative: config='../configs/comp_det_resnet_18.yaml'
        config='../configs/comp_det_resnet_50_2.yaml',
        seed=42,
        gpu='0',
        now_str='20231120_000000',
        folder='comp_det_resnet_18',
    )
)

In [8]:
class Trainer:
    """
    Builds datasets, model and optimizer from a config and runs training.

    Two modes are supported, selected by ``cfg.detection``:
    classification (``CustomDataset`` + cnn / vit / resnet / mobilenetv2) and
    detection (``DetectionDataset`` + fasterrcnn1 / fasterrcnn2).

    Training logs are appended to ``history.csv`` and checkpoints are written
    as ``*.pth`` files inside ``CONF.PATH.OUTPUT/<now_str>_<folder>``.
    """

    # Side length that detection boxes are rescaled to before training.
    # NOTE(review): presumably the input resolution of the detection model -
    # confirm against CustomFasterRCNN1/2.
    DETECTION_SIZE = 224
    # Detection loss terms that are up-weighted when summing the total loss.
    RPN_LOSS_KEYS = ("loss_objectness", "loss_rpn_box_reg")
    RPN_LOSS_WEIGHT = 10
    # A log row is written every LOG_INTERVAL global steps.
    LOG_INTERVAL = 100
    # Number of highest-scoring predictions drawn per image in show_result.
    TOP_K_PREDICTIONS = 3

    def __init__(self, cfg, generator):
        """
        :param cfg: attribute-style config. Fields read here: ``data_dir``,
            ``detection``, ``batch_size``, ``model``, ``model_layer``,
            ``pretrained``, ``no_decay``, ``weight_decay``, ``lr``, ``eps``,
            ``now_str`` and ``folder``.
        :param generator: random generator handed to both DataLoaders so that
            shuffling is reproducible.
        :raises NotImplementedError: if ``cfg.model`` is not a known model
            name for the selected mode.
        """
        self.cfg = cfg

        # Device selection: CUDA first, then Apple MPS, otherwise CPU.
        if torch.cuda.is_available():
            self.device = "cuda:0"
        elif torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"
        print(self.device)

        # Load the dataset arrays.
        data_root = os.path.join(CONF.PATH.DATASET, cfg.data_dir)
        X_train = np.load(os.path.join(data_root, "X_train_filtered.npy"))
        y_train = np.load(os.path.join(data_root, "y_train_filtered.npy"))
        X_val = np.load(os.path.join(data_root, "X_val_filtered.npy"))
        y_val = np.load(os.path.join(data_root, "y_val_filtered.npy"))
        print(X_train.shape, y_train.shape, X_val.shape, y_val.shape)

        if not cfg.detection:
            self.train_dataset = CustomDataset(X_train, y_train)
            self.val_dataset = CustomDataset(X_val, y_val)
        else:
            # Detection additionally needs the per-image annotation tables.
            train_df = pd.read_json(
                os.path.join(data_root, "train.json")
            ).reset_index(drop=True)
            val_df = pd.read_json(
                os.path.join(data_root, "val.json")
            ).reset_index(drop=True)
            print(train_df.shape, val_df.shape)

            self.train_dataset = DetectionDataset(X_train, y_train, train_df)
            self.val_dataset = DetectionDataset(X_val, y_val, val_df)

        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=cfg.batch_size,
            shuffle=True,
            drop_last=True,
            pin_memory=True,  # keep batches in page-locked host memory
            worker_init_fn=seed_worker,  # fixed seeding
            generator=generator,  # fixed seeding
        )
        self.val_loader = torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=1,
            pin_memory=True,
            worker_init_fn=seed_worker,
            generator=generator,
        )

        # Build the model.
        self.model = self._build_model(cfg)
        self.model.to(self.device)

        # Two parameter groups: weight decay is applied to everything except
        # parameters whose name contains one of the cfg.no_decay fragments.
        trainable = [
            (name, param)
            for name, param in self.model.named_parameters()
            if param.requires_grad
        ]

        def _is_no_decay(name):
            return any(fragment in name for fragment in cfg.no_decay)

        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in trainable if not _is_no_decay(n)],
                "weight_decay": cfg.weight_decay,
            },
            {
                "params": [p for n, p in trainable if _is_no_decay(n)],
                "weight_decay": 0.0,
            },
        ]
        self.optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters, lr=float(cfg.lr), eps=float(cfg.eps)
        )

        self.log_data = []
        self.folder = os.path.join(
            CONF.PATH.OUTPUT, f"{self.cfg.now_str}_{self.cfg.folder}"
        )
        self.csv_file = os.path.join(self.folder, "history.csv")
        self.yaml_file = os.path.join(self.folder, "config.yaml")
        self.pickle_file = os.path.join(self.folder, "config.pkl")

        # Create the output folder if it does not exist yet.
        os.makedirs(self.folder, exist_ok=True)

        # Start from an empty history file (mode "w" truncates on open).
        with open(self.csv_file, "w", newline=""):
            pass

        # Save the config as YAML.
        with open(self.yaml_file, "w") as yaml_file:
            yaml.dump(self.cfg, yaml_file, default_flow_style=False)

        # Save the config as pickle.
        with open(self.pickle_file, "wb") as pickle_file:
            pickle.dump(self.cfg, pickle_file)

    @staticmethod
    def _build_model(cfg):
        """Instantiate the model named by ``cfg.model`` for the selected mode."""
        if not cfg.detection:
            if cfg.model == "cnn":
                return CNNModel()
            if cfg.model == "vit":
                # Built with the model's default hyper-parameters.
                return CustomViTModel()
            if cfg.model == "resnet":
                return CustomResNetModel(
                    model=cfg.model_layer, pretrained=cfg.pretrained
                )
            if cfg.model == "mobilenetv2":
                return CustomMobileNetV2Model()
            raise NotImplementedError(
                f"Unknown classification model: {cfg.model!r}"
            )

        if cfg.model == "fasterrcnn1":
            return CustomFasterRCNN1(
                model=cfg.model_layer, pretrained=cfg.pretrained
            )
        if cfg.model == "fasterrcnn2":
            return CustomFasterRCNN2(pretrained=cfg.pretrained)
        # Fail here with a clear error instead of an AttributeError on
        # self.model later on.
        raise NotImplementedError(f"Unknown detection model: {cfg.model!r}")

    def _record_log(self, log):
        """Print a log row, keep it in memory and append it to the CSV file."""
        print(log)
        self.log_data.append(log)
        self.write_log_to_csv()

    def train_one_epoch(self, epoch=0):
        """
        Trains the model for one epoch.
        Config setting:
            detection
        Inputs:
            current epoch num
        Returns:
            None
        Outputs:
            a log row every LOG_INTERVAL global steps; in detection mode a
            preview figure is shown together with each log row
        """
        self.model.train()
        steps_per_epoch = len(self.train_loader)

        if self.cfg.detection:
            for i, batch in enumerate(self.train_loader):
                global_step = i + epoch * steps_per_epoch
                images, _answers, image_ids, boxes, labels, areas, iscrowds = batch
                images = images.to(self.device)

                # Rescale annotations from image pixels to DETECTION_SIZE.
                rate = self.DETECTION_SIZE / images.shape[2]
                boxes = boxes * rate
                areas = areas * rate * rate

                targets = []
                for box, label, image_id, area, iscrowd in zip(
                    boxes, labels, image_ids, areas, iscrowds
                ):
                    # Only the first `num_box` rows are used, where num_box is
                    # the number of entries with iscrowd == 0.
                    # NOTE(review): presumably the remaining rows are padding
                    # added by DetectionDataset - confirm.
                    num_box = len(iscrowd[iscrowd == 0])
                    targets.append(
                        {
                            "boxes": box[:num_box].to(self.device),
                            "labels": label[:num_box].to(self.device),
                            "image_id": image_id.to(self.device),
                            "area": area[:num_box].to(self.device),
                            "iscrowd": iscrowd[:num_box].to(self.device),
                        }
                    )

                loss_dict = self.model(images, targets)

                # Weighted sum of the loss terms; RPN terms count 10x.
                losses = sum(
                    self.RPN_LOSS_WEIGHT * value
                    if key in self.RPN_LOSS_KEYS
                    else value
                    for key, value in loss_dict.items()
                )
                loss_value = losses.item()

                self.optimizer.zero_grad()
                losses.backward()
                self.optimizer.step()

                if global_step % self.LOG_INTERVAL == 0:
                    log = {
                        "epoch": epoch,
                        "global_step": global_step,
                        "loss": loss_value,
                    }
                    log.update({k: v.item() for k, v in loss_dict.items()})
                    self._record_log(log)
                    self.show_result(images, targets)
        else:
            for i, (X, y) in enumerate(self.train_loader):
                global_step = i + epoch * steps_per_epoch
                X = X.to(self.device)
                y = y.to(self.device)

                output = self.model(X, y)
                loss = output["loss"]

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                if global_step % self.LOG_INTERVAL == 0:
                    self._record_log(
                        {
                            "epoch": epoch,
                            "global_step": global_step,
                            "loss": loss.item(),
                        }
                    )

    def train(self):
        """
        Trains the model for a specified number of epochs.
        Config setting:
            epochs, detection
        Inputs:
            None
        Returns:
            None
        Outputs:
            checkpoints "epochNNNN.pth" (every epoch in detection mode, every
            5th epoch otherwise) and the final "model.pth"
        """
        for epoch in tqdm(range(self.cfg.epochs)):
            self.train_one_epoch(epoch)
            if self.cfg.detection or (epoch + 1) % 5 == 0:
                self.save_model(f"epoch{epoch + 1:04d}")

        self.save_model("model")

    def save_model(self, name):
        """Save the model's state_dict as ``<name>.pth`` in the output folder."""
        torch.save(self.model.state_dict(), os.path.join(self.folder, f"{name}.pth"))

    def write_log_to_csv(self):
        """
        Append the most recent log row to the history CSV.

        The column names are taken from the first recorded log row; the header
        is written only while the file is still empty.
        """
        with open(self.csv_file, "a", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=self.log_data[0].keys())

            if f.tell() == 0:
                writer.writeheader()

            writer.writerow(self.log_data[-1])

    def show_result(self, images, targets):
        """
        Show ground-truth boxes (top row) and predictions (bottom row).

        For each image the TOP_K_PREDICTIONS highest-scoring predictions are
        drawn (all of them if there are fewer). The model is switched to eval
        mode for inference and is always switched back to train mode, even if
        plotting fails.

        :param images: batch of images; indexed along dim 0, and ``shape[2]``
            is used as the image side length.
        :param targets: one dict per image with ``boxes`` and ``labels``
            tensors, in the rescaled coordinates built by train_one_epoch.
        """
        n_images = images.shape[0]
        # Inverse of the scaling applied in train_one_epoch, so boxes are
        # drawn in the pixel coordinates of the displayed image.
        rate = images.shape[2] / self.DETECTION_SIZE

        # squeeze=False keeps `axes` two-dimensional for a batch of one.
        fig, axes = plt.subplots(
            2, n_images, figsize=(n_images, 2), squeeze=False
        )
        self.model.eval()
        try:
            for i in range(n_images):
                image = images[i].cpu()
                target = targets[i]
                boxes = target["boxes"].cpu().numpy()
                labels = target["labels"].cpu().numpy()
                add_bboxes_to_image(axes[0, i], image, boxes * rate, labels)
                axes[0, i].axis('off')

            with torch.no_grad():
                outputs = self.model(images)

            for i, output in enumerate(outputs):
                image = images[i].cpu()
                boxes = output["boxes"].cpu().numpy()
                labels = output["labels"].cpu().numpy()
                scores = output["scores"].cpu().numpy()
                # Indices of the k best scores. The previous strict
                # `scores > k-th score` filter dropped the k-th box itself.
                top = np.argsort(-scores, kind="stable")[: self.TOP_K_PREDICTIONS]
                add_bboxes_to_image(axes[1, i], image, boxes[top] * rate, labels[top])
                axes[1, i].axis('off')
        finally:
            self.model.train()

        # Title the whole figure rather than only the last subplot.
        fig.suptitle(f"Step: {self.log_data[-1]['global_step']}")
        plt.show()
        # Release the figure; one is created at every logging step.
        plt.close(fig)

In [9]:
# Seed everything and get the generator that the DataLoaders will share.
g = fix_seed(args.seed)

# Load the YAML config inside a context manager so the file handle is closed,
# then let the notebook-level `args` override the values from the file.
with open(args.config) as config_file:
    cfg = EasyDict(yaml.safe_load(config_file))
cfg.update(args)

trainer = Trainer(cfg, generator=g)
trainer.train()

cuda:0
(192000, 400) (192000,) (48000, 400) (48000,)
(192000, 5) (48000, 5)


  0%|          | 0/5 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 314.00 MiB. GPU 0 has a total capacty of 7.78 GiB of which 206.75 MiB is free. Process 3084 has 155.06 MiB memory in use. Including non-PyTorch memory, this process has 7.00 GiB memory in use. Of the allocated memory 6.65 GiB is allocated by PyTorch, and 129.41 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF