In [7]:
import os
import glob
import re
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torchvision.transforms.functional as F
import random

In [8]:
# ----------------------------
# 0. Image preprocessing helpers: min-max contrast normalization + Gaussian noise
# ----------------------------
def normalize_image_safe_pil(img):
    """Min-max normalize a PIL image so every channel spans the full 0-255 range.

    Each channel is stretched independently: its minimum maps to 0 and its
    maximum maps to 255. A channel whose value range is (almost) zero is left
    unchanged, because stretching it would divide by ~0. Any channel count is
    supported (e.g. L, LA, RGB, RGBA), not only 3-channel images.

    Args:
        img: PIL image (anything ``np.array`` converts to a 2-D or
            H x W x C array).

    Returns:
        A new ``PIL.Image`` built from the normalized ``uint8`` array.
    """
    pixels = np.array(img).astype(np.float32)

    # Per-channel statistics for H x W x C input, global statistics otherwise.
    reduce_axes = (0, 1) if pixels.ndim == 3 else None
    lowest = pixels.min(axis=reduce_axes, keepdims=True)
    highest = pixels.max(axis=reduce_axes, keepdims=True)
    span = highest - lowest

    # Only stretch channels with a meaningful range; flat channels pass through.
    stretchable = span > 1e-5
    # Replace near-zero spans by 1 so the division below never hits ~0; the
    # corresponding results are discarded by the np.where that follows.
    divisor = np.where(stretchable, span, np.float32(1.0))
    stretched = (pixels - lowest) * 255 / divisor
    out = np.where(stretchable, stretched, pixels)

    out = np.clip(out, 0, 255).astype(np.uint8)
    return Image.fromarray(out)

def add_gaussian_noise(img, mean=0.0, std=0.05):
    """Return a copy of a PIL image with additive Gaussian noise.

    Args:
        img: PIL image to perturb.
        mean: Mean of the noise, in [0, 1] pixel units.
        std: Standard deviation of the noise, in [0, 1] pixel units.

    Returns:
        A PIL image whose pixel values were perturbed and clamped back
        into the valid [0, 1] range before conversion.
    """
    pixels = F.to_tensor(img)  # float tensor scaled to [0, 1]
    noisy = pixels + (torch.randn_like(pixels) * std + mean)
    return F.to_pil_image(noisy.clamp(0.0, 1.0))

In [9]:
# ----------------------------
# 1. データセット
# ----------------------------
class DigitsDataset(Dataset):
    """Image-classification dataset built from PNG files named ``*_<label>.png``.

    The integer label is parsed from the digits between the last underscore
    and the ``.png`` extension. PNG files whose name does not follow that
    pattern are skipped.

    Args:
        folders: Iterable of directory paths to scan (non-recursively).
        transform: Optional callable applied to each contrast-normalized
            RGB PIL image before it is returned.

    Attributes:
        images: File paths, sorted within each folder.
        labels: Integer labels aligned with ``images``.
    """

    def __init__(self, folders, transform=None):
        self.images = []
        self.labels = []
        self.transform = transform

        label_pattern = re.compile(r'_(\d+)\.png$')
        for folder in folders:
            # glob returns files in arbitrary order; sorting keeps the
            # index -> sample mapping stable across runs and machines.
            for path in sorted(glob.glob(os.path.join(folder, "*.png"))):
                match = label_pattern.search(path)
                if match:
                    self.images.append(path)
                    self.labels.append(int(match.group(1)))

    def __len__(self):
        """Return the number of labelled images found."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``."""
        img_path = self.images[idx]
        label = self.labels[idx]

        # Close the file handle deterministically; convert() returns an
        # independent, fully loaded copy.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')

        # Stretch contrast so dark images use the full 0-255 range.
        image = normalize_image_safe_pil(image)

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# ----------------------------
# 2. Data transforms (training-time augmentation)
# ----------------------------
# DigitsDataset.__getitem__ already applies min-max contrast normalization to
# each image, so the pipeline starts directly with the augmentations.
transform = transforms.Compose([
    # Small geometric jitter: rotation, shift, zoom and perspective warp.
    transforms.RandomAffine(
        degrees=5,
        translate=(0.05, 0.05),
        scale=(0.9, 1.1)
    ),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),

    # Photometric jitter.
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
    transforms.RandomGrayscale(p=0.3),

    transforms.RandomApply([
        transforms.GaussianBlur(3, sigma=(0.1, 1.0))
    ], p=0.3),

    # Additive sensor-like noise, applied to every sample.
    transforms.Lambda(lambda img: add_gaussian_noise(img, std=0.05)),

    transforms.Resize((224, 224)),
    transforms.ToTensor(),

    # ImageNet channel statistics, matching the pretrained ResNet backbone.
    transforms.Normalize(
        mean=[0.485,0.456,0.406],
        std=[0.229,0.224,0.225]
    ),

    # Erase AFTER Normalize: with value='random' torchvision fills the patch
    # with standard-normal samples, which only match the data scale once the
    # tensor has been normalized. Placed before Normalize, the fill would be
    # rescaled by the ImageNet mean/std into far more extreme values.
    transforms.RandomErasing(
        p=0.3,
        scale=(0.02, 0.08),
        ratio=(0.3, 3.3),
        value='random'
    ),
])

# ----------------------------
# 3. Dataset and data loader
# ----------------------------
folders = ["./digits"]
dataset = DigitsDataset(folders, transform=transform)

# Fail fast with a clear message: a wrong path or unexpected file naming would
# otherwise only surface later as a sampler error or a division by zero.
if len(dataset) == 0:
    raise RuntimeError(f"No '*_<label>.png' images found in {folders}")

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [10]:
# ----------------------------
# 4. Model definition (ResNet18)
# ----------------------------
num_classes = 10  # digits 0-9
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torchvision >= 0.13 selects pretrained weights through the `weights=` enum and
# deprecates the boolean `pretrained=` flag; keep the old call as a fallback so
# the notebook still runs on older releases.
if hasattr(models, "ResNet18_Weights"):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:
    model = models.resnet18(pretrained=True)

# Replace the 1000-way ImageNet head with a fresh classifier for our classes.
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

In [11]:
# ----------------------------
# 5. Loss function and optimizer
# ----------------------------
# Cross-entropy over the class logits; Adam updates every model parameter
# (backbone and classifier head) with a small fine-tuning learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [12]:
num_epochs = 20  # trial run

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample losses seen so far in this epoch
    correct = 0         # number of correctly classified samples so far
    total = 0           # number of samples seen so far

    # Show mini-batch progress with tqdm.
    loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
    for imgs, labels in loop:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Update running statistics. The criterion returns the batch mean, so
        # weight it by the batch size to accumulate a per-sample sum.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        total += batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()

        # Display the running loss and accuracy on the progress bar.
        loop.set_postfix(loss=running_loss/total, acc=correct/total)

    if total == 0:
        raise RuntimeError("The dataloader yielded no samples; nothing to train on.")

    # Use the number of samples actually seen as the denominator for BOTH
    # metrics, so they stay consistent even if the loader drops or subsamples
    # data.
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    print(f"Epoch {epoch+1}/{num_epochs} 終了 -> Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

# Save the trained weights.
torch.save(model.state_dict(), "digit_classifier.pth")
print("モデルを digit_classifier.pth として保存しました")

Epoch 1/20: 100%|██████████| 17/17 [00:28<00:00,  1.70s/batch, acc=0.375, loss=1.9] 


Epoch 1/20 終了 -> Loss: 1.8974, Accuracy: 0.3755


Epoch 2/20: 100%|██████████| 17/17 [00:29<00:00,  1.75s/batch, acc=0.604, loss=1.29]


Epoch 2/20 終了 -> Loss: 1.2857, Accuracy: 0.6038


Epoch 3/20: 100%|██████████| 17/17 [00:31<00:00,  1.83s/batch, acc=0.708, loss=0.963]


Epoch 3/20 終了 -> Loss: 0.9626, Accuracy: 0.7075


Epoch 4/20: 100%|██████████| 17/17 [00:30<00:00,  1.79s/batch, acc=0.785, loss=0.724]


Epoch 4/20 終了 -> Loss: 0.7237, Accuracy: 0.7849


Epoch 5/20: 100%|██████████| 17/17 [00:30<00:00,  1.81s/batch, acc=0.851, loss=0.54] 


Epoch 5/20 終了 -> Loss: 0.5403, Accuracy: 0.8509


Epoch 6/20: 100%|██████████| 17/17 [00:30<00:00,  1.80s/batch, acc=0.908, loss=0.382]


Epoch 6/20 終了 -> Loss: 0.3823, Accuracy: 0.9075


Epoch 7/20: 100%|██████████| 17/17 [00:30<00:00,  1.77s/batch, acc=0.942, loss=0.244]


Epoch 7/20 終了 -> Loss: 0.2443, Accuracy: 0.9415


Epoch 8/20: 100%|██████████| 17/17 [00:31<00:00,  1.84s/batch, acc=0.962, loss=0.195]


Epoch 8/20 終了 -> Loss: 0.1950, Accuracy: 0.9623


Epoch 9/20: 100%|██████████| 17/17 [00:30<00:00,  1.82s/batch, acc=0.97, loss=0.141] 


Epoch 9/20 終了 -> Loss: 0.1410, Accuracy: 0.9698


Epoch 10/20: 100%|██████████| 17/17 [00:30<00:00,  1.81s/batch, acc=0.979, loss=0.107]


Epoch 10/20 終了 -> Loss: 0.1069, Accuracy: 0.9792


Epoch 11/20: 100%|██████████| 17/17 [00:31<00:00,  1.86s/batch, acc=0.983, loss=0.086] 


Epoch 11/20 終了 -> Loss: 0.0860, Accuracy: 0.9830


Epoch 12/20: 100%|██████████| 17/17 [00:29<00:00,  1.75s/batch, acc=0.983, loss=0.0859]


Epoch 12/20 終了 -> Loss: 0.0859, Accuracy: 0.9830


Epoch 13/20: 100%|██████████| 17/17 [00:30<00:00,  1.78s/batch, acc=0.972, loss=0.104] 


Epoch 13/20 終了 -> Loss: 0.1042, Accuracy: 0.9717


Epoch 14/20: 100%|██████████| 17/17 [00:31<00:00,  1.83s/batch, acc=0.987, loss=0.0559]


Epoch 14/20 終了 -> Loss: 0.0559, Accuracy: 0.9868


Epoch 15/20: 100%|██████████| 17/17 [00:30<00:00,  1.78s/batch, acc=0.994, loss=0.0381]


Epoch 15/20 終了 -> Loss: 0.0381, Accuracy: 0.9943


Epoch 16/20: 100%|██████████| 17/17 [00:29<00:00,  1.75s/batch, acc=0.989, loss=0.0413]


Epoch 16/20 終了 -> Loss: 0.0413, Accuracy: 0.9887


Epoch 17/20: 100%|██████████| 17/17 [00:29<00:00,  1.73s/batch, acc=0.992, loss=0.0428]


Epoch 17/20 終了 -> Loss: 0.0428, Accuracy: 0.9925


Epoch 18/20: 100%|██████████| 17/17 [00:29<00:00,  1.72s/batch, acc=0.996, loss=0.0415]


Epoch 18/20 終了 -> Loss: 0.0415, Accuracy: 0.9962


Epoch 19/20: 100%|██████████| 17/17 [00:29<00:00,  1.75s/batch, acc=0.985, loss=0.0403]


Epoch 19/20 終了 -> Loss: 0.0403, Accuracy: 0.9849


Epoch 20/20: 100%|██████████| 17/17 [00:29<00:00,  1.73s/batch, acc=0.991, loss=0.0402]

Epoch 20/20 終了 -> Loss: 0.0402, Accuracy: 0.9906
モデルを digit_classifier.pth として保存しました



