1. Install and set up required libraries

In [66]:
# Install necessary libraries
!pip install --upgrade jax jaxlib flax optax pandas scikit-learn Pillow

# Check GPU availability
import jax
print("Available devices:", jax.devices())


Available devices: [CudaDevice(id=0)]


2. Preparing the dataset

In [67]:
import zipfile
zip_file = "skeletonized_image.zip"
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
    zip_ref.extractall("/content/")

csv_file = "/content/skeletonized_labels.csv"
base_dir = "/content/skeletonized_image"

3. Data loading and preprocessing

In [69]:
import os
import pandas as pd
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

def load_data_from_csv(csv_file, img_size=(28, 28)):
    """
    Load and preprocess skeletonized images from a CSV file.

    Args:
        csv_file (str): Path to the CSV file containing image paths and labels.
        img_size (tuple): Target size for resizing the images.

    Returns:
        images (np.ndarray): Array of preprocessed images.
        labels (np.ndarray): Array of corresponding labels.
    """
    # Read the CSV file
    data = pd.read_csv(csv_file)

    images, labels = [], []
    for _, row in data.iterrows():
        img_path = row['file_path']  # Image path from CSV
        label = row['label']        # Label from CSV

        try:
            img = Image.open(img_path).convert('L')  # Convert to grayscale
            img = img.resize(img_size)  # Resize to target size
            images.append(np.array(img) / 255.0)  # Normalize pixel values
            labels.append(label)
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")

    return np.array(images, dtype=np.float32), np.array(labels, dtype=np.int32)

# Set paths
csv_file = "/content/skeletonized_labels.csv"  # Path to the CSV file

# Load data
images, labels = load_data_from_csv(csv_file)

# Split dataset into training and testing sets
train_images, test_images, train_labels, test_labels = train_test_split(
    images, labels, test_size=0.2, random_state=42
)

# Add batch dimension (JAX expects [N, H, W, C] format)
train_images = train_images[..., np.newaxis]
test_images = test_images[..., np.newaxis]

print(f"Training images: {train_images.shape}, Training labels: {train_labels.shape}")
print(f"Testing images: {test_images.shape}, Testing labels: {test_labels.shape}")

print("Unique labels in dataset:", np.unique(labels))
print("Number of unique labels:", len(np.unique(labels)))

num_classes = len(np.unique(labels))

Training images: (498, 28, 28, 1), Training labels: (498,)
Testing images: (125, 28, 28, 1), Testing labels: (125,)
Unique labels in dataset: [0 1 2 3 4 5 6 7 8 9]
Number of unique labels: 10


4. CNN model definition

In [75]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax

# Define the CNN model
class SkeletonCNN(nn.Module):
    num_classes: int = len(np.unique(labels))

    @nn.compact
    def __call__(self, x):
        # 기존 첫 번째 Conv: 채널 32 -> 64로 늘림
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        # 기존 두 번째 Conv: 채널 64 -> 128로 늘림
        x = nn.Conv(features=128, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        # **새롭게 추가한** 세 번째 Conv
        x = nn.Conv(features=128, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        # 필요하면 pooling을 한 번 더 넣을 수도 있음
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Flatten
        x = x.reshape((x.shape[0], -1))

        # Dense layer
        x = nn.Dense(features=256)(x)  # 기존 128 -> 256로 증가
        x = nn.relu(x)

        # Output layer
        x = nn.Dense(features=self.num_classes)(x)
        return x

model = SkeletonCNN(num_classes=10)
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

params = model.init(init_rng, jnp.ones([1, 28, 28, 1]))['params']


5. Training and Evaluation Functions


- Training Step

In [76]:
@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        logits = SkeletonCNN().apply({'params': params}, batch['image'])
        loss = jnp.mean(optax.softmax_cross_entropy(
            logits=logits,
            labels=jax.nn.one_hot(batch['label'], num_classes=num_classes)))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits, batch['label'])
    return state, metrics

@jax.jit
def eval_step(params, batch):
    logits = SkeletonCNN().apply({'params': params}, batch['image'])
    return compute_metrics(logits, batch['label'])

def compute_metrics(logits, labels):
    loss = jnp.mean(optax.softmax_cross_entropy(
        logits, jax.nn.one_hot(labels, num_classes=num_classes)
    ))
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return {'loss': loss, 'accuracy': accuracy}

- Training Loop

In [77]:
def train_epoch(state, images, labels, batch_size, rng):
    num_samples = images.shape[0]
    perms = jax.random.permutation(rng, num_samples)
    steps_per_epoch = num_samples // batch_size  # remainder는 버린다. (원한다면 유지해도 됨)

    batch_metrics = []
    for i in range(steps_per_epoch):
        start_idx = i * batch_size
        end_idx = start_idx + batch_size

        batch = {
            'image': images[perms[start_idx:end_idx], ...],
            'label': labels[perms[start_idx:end_idx], ...]
        }
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    batch_metrics = jax.device_get(batch_metrics)
    epoch_metrics = {
        k: np.mean([m[k] for m in batch_metrics])
        for k in batch_metrics[0]
    }
    return state, epoch_metrics

6. Model Initialization and Training


In [84]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

model = SkeletonCNN()
params = model.init(init_rng, jnp.ones([1, 28, 28, 1]))['params']

# Adam
tx = optax.adam(learning_rate=1e-3)

state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx
)

num_epochs = 20
batch_size = 10

for epoch in range(1, num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_metrics = train_epoch(
        state, train_images, train_labels,
        batch_size, input_rng
    )
    test_metrics = eval_step(
        state.params, {'image': test_images, 'label': test_labels}
    )
    print(f"[Epoch {epoch:2d}] "
          f"Train Loss={train_metrics['loss']:.4f}, Train Acc={train_metrics['accuracy']:.2%}, "
          f"Test Loss={test_metrics['loss']:.4f}, Test Acc={test_metrics['accuracy']:.2%}")

[Epoch  1] Train Loss=2.3187, Train Acc=7.96%, Test Loss=2.2999, Test Acc=11.20%
[Epoch  2] Train Loss=2.3005, Train Acc=11.84%, Test Loss=2.3425, Test Acc=11.20%
[Epoch  3] Train Loss=2.2237, Train Acc=15.10%, Test Loss=2.1139, Test Acc=19.20%
[Epoch  4] Train Loss=2.0119, Train Acc=18.37%, Test Loss=1.8548, Test Acc=24.80%
[Epoch  5] Train Loss=1.7518, Train Acc=30.20%, Test Loss=1.6721, Test Acc=35.20%
[Epoch  6] Train Loss=1.5792, Train Acc=38.78%, Test Loss=1.4736, Test Acc=36.80%
[Epoch  7] Train Loss=1.4243, Train Acc=44.29%, Test Loss=1.3341, Test Acc=49.60%
[Epoch  8] Train Loss=1.2428, Train Acc=52.04%, Test Loss=1.2500, Test Acc=50.40%
[Epoch  9] Train Loss=1.1186, Train Acc=57.55%, Test Loss=1.0429, Test Acc=62.40%
[Epoch 10] Train Loss=0.9751, Train Acc=63.06%, Test Loss=0.9429, Test Acc=67.20%
[Epoch 11] Train Loss=0.7998, Train Acc=67.14%, Test Loss=0.7763, Test Acc=79.20%
[Epoch 12] Train Loss=0.6548, Train Acc=75.71%, Test Loss=0.7095, Test Acc=75.20%
[Epoch 13] Train 