In [11]:
import os
import torch
import numpy as np
import xml.etree.ElementTree as ET
from PIL import Image
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data import Dataset

In [12]:
class CustomVOCDataset(Dataset):
    """Detection dataset laid out like Pascal VOC, for the calculator-symbol labels.

    Expected layout under ``root``:
        JPEGImages/<id>.jpg, Annotations/<id>.xml, ImageSets/Main/<image_set>.txt
    where the split file lists one image id per line.

    Each item is ``(img, target)``:
        img    -- the RGB ``PIL.Image``, or whatever ``transforms`` returns for it.
        target -- dict with ``"boxes"`` (float32 tensor, shape ``(N, 4)``,
                  ``[xmin, ymin, xmax, ymax]`` as read from ``<bndbox>``) and
                  ``"labels"`` (int64 tensor, shape ``(N,)``).

    Args:
        root: dataset root directory.
        image_set: split file name under ``ImageSets/Main`` (without ``.txt``).
        transforms: optional callable applied to the image only; boxes are NOT
            transformed, so it must not change the image geometry.
        label_map: maps the XML ``<name>`` text to an integer class id. Defaults to
            digits and ``+ - * / =`` mapped to 1..15 (0 is unused, presumably
            reserved for background as torchvision detectors expect).
    """

    def __init__(self, root, image_set='train', transforms=None, label_map=None):
        self.root = root
        self.image_dir = os.path.join(root, "JPEGImages")
        self.annotation_dir = os.path.join(root, "Annotations")
        self.transforms = transforms
        self.label_map = label_map or {
            '0': 1, '1': 2, '2': 3, '3': 4, '4': 5, '5': 6,
            '6': 7, '7': 8, '8': 9, '9': 10,
            '+': 11, '-': 12, '*': 13, '/': 14, '=': 15
        }

        # Read the split file (train.txt etc.). Blank lines are skipped so they do
        # not turn into an empty image id ("" -> "JPEGImages/.jpg").
        split_file = os.path.join(root, "ImageSets", "Main", f"{image_set}.txt")
        with open(split_file) as f:
            self.image_ids = [line.strip() for line in f if line.strip()]

    def __len__(self):
        return len(self.image_ids)

    def _parse_annotation(self, xml_path):
        """Parse one VOC XML annotation file into a ``{"boxes", "labels"}`` target.

        Raises:
            KeyError: if an object's ``<name>`` is not in ``self.label_map``.
        """
        root = ET.parse(xml_path).getroot()

        boxes = []
        labels = []
        for obj in root.findall("object"):
            name = obj.find("name").text.strip()
            try:
                label = self.label_map[name]
            except KeyError:
                raise KeyError(f"Unknown label {name!r} in {xml_path}") from None
            bbox = obj.find("bndbox")
            boxes.append([float(bbox.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")])
            labels.append(label)

        return {
            # reshape keeps the (0, 4) shape for images without objects; a bare
            # torch.tensor([]) would be shape (0,), which torchvision detection
            # models reject as malformed boxes.
            "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.tensor(labels, dtype=torch.int64),
        }

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        img_path = os.path.join(self.image_dir, f"{image_id}.jpg")
        xml_path = os.path.join(self.annotation_dir, f"{image_id}.xml")

        # convert() returns a new, fully loaded image, so the source file can be
        # closed immediately instead of waiting for garbage collection.
        with Image.open(img_path) as src:
            img = src.convert("RGB")

        target = self._parse_annotation(xml_path)

        if self.transforms:
            img = self.transforms(img)

        return img, target

In [13]:
# Load the fully pickled SSD model (architecture + weights).
# NOTE(security): weights_only=False unpickles arbitrary objects and can execute code
# embedded in the file -- only load checkpoints from a trusted source.
# add_safe_globals only affects weights_only=True loads; it has no effect here, but keeps
# SSD allow-listed if loading is switched to weights_only=True (submodule classes would
# then need to be added as well).
torch.serialization.add_safe_globals([torchvision.models.detection.ssd.SSD])
model = torch.load('../src/model/ssd_calculator_merge_model4.1.10.pth', map_location='cpu', weights_only=False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# 15 symbol classes in the dataset's default label_map + background (index 0).
NUM_CLASSES = 16
# torchvision's SSDClassificationHead fixes its class count (stored as `num_columns`)
# when its conv layers are built, so assigning `num_classes` afterwards does not resize
# anything. Check that the loaded head actually matches instead of relying on it.
cls_head = model.head.classification_head
head_classes = getattr(cls_head, "num_columns", None)
if head_classes is not None and head_classes != NUM_CLASSES:
    raise ValueError(
        f"Loaded classification head predicts {head_classes} classes, expected {NUM_CLASSES}"
    )
cls_head.num_classes = NUM_CLASSES  # informational only; kept for any code that reads it
model.eval()

SSD(
  (backbone): SSDFeatureExtractorVGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inplace=

In [14]:
# Seed torch's RNG so the DataLoader shuffle order (and any other torch randomness)
# is reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

BATCH_SIZE = 128
LEARNING_RATE = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 0.00023082965571758206  # presumably from a hyperparameter search -- TODO record its source


def collate_detection_batch(batch):
    """Turn a list of ``(image, target)`` pairs into ``(images, targets)`` tuples.

    Images may differ in size and each target holds a variable number of boxes, so
    they are kept as sequences rather than stacked into one tensor. Unlike a lambda,
    a named module-level function is picklable, so ``num_workers > 0`` also works
    with the ``spawn`` multiprocessing start method.
    """
    return tuple(zip(*batch))


transform = transforms.ToTensor()
dataset = CustomVOCDataset(root="sample_dataset", image_set="train", transforms=transform)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_detection_batch)

# Resolved again here so this cell does not silently depend on the loading cell's `device`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)


NUM_EPOCHS = 5
MAX_GRAD_NORM = 1.0

# Training loop
model.train()
for epoch in range(NUM_EPOCHS):
    epoch_loss = 0.0
    num_batches = 0
    num_failed = 0

    for batch_idx, (images, targets) in enumerate(train_loader):
        # Move the batch to the training device.
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Clear gradients left over from the previous step (or from a failed batch).
        optimizer.zero_grad()

        try:
            # Forward pass: with targets in train mode the model returns a dict of loss terms.
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            # Stop the epoch on a non-finite loss (NaN *or* inf) before backward/step;
            # an inf loss would otherwise propagate into the weights just like NaN.
            if not torch.isfinite(losses):
                print(f"Non-finite loss detected at epoch {epoch+1}, batch {batch_idx}")
                print(f"Loss dict: {loss_dict}")
                break

            # Backward pass.
            losses.backward()

            # Clip the gradient norm to keep individual updates bounded.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=MAX_GRAD_NORM)
            # Update the model parameters.
            optimizer.step()

            epoch_loss += losses.item()
            num_batches += 1

        except Exception as e:
            # Best-effort: report and skip the failing batch, but count it so that an
            # epoch where most batches fail is visible in the summary line below.
            num_failed += 1
            print(f"Error at epoch {epoch+1}, batch {batch_idx}: {e!r}")
            continue

    avg_loss = epoch_loss / num_batches if num_batches > 0 else float('inf')
    print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f} ({num_batches} ok, {num_failed} failed)")

Epoch 1 Average Loss: 0.5246
Epoch 2 Average Loss: 0.2848
Epoch 3 Average Loss: 0.5947
Epoch 4 Average Loss: 0.3151
Epoch 5 Average Loss: 0.4738
