In [None]:
# Uncomment on a fresh environment; `%pip` installs into the running kernel's environment.
# %pip install transformers datasets torchvision tensorboard evaluate

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-man

## ViT model with the Hugging Face Trainer

🤖 모델: ViTForImageClassification (HuggingFace, pretrained)

🖼️ 입력 전처리: ViTImageProcessor 사용 (구버전 ViTFeatureExtractor 대체)

⚙️ 옵티마이저: AdamW

🧈 손실 함수: CrossEntropyLoss

📈 스케줄러: OneCycleLR

📉 조기 종료: EarlyStoppingCallback

📊 로깅: TensorBoard

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from transformers import (
    ViTForImageClassification,
    ViTImageProcessor,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
    TrainerCallback
)
from datasets import Dataset
import evaluate
import numpy as np

# Device: use CUDA when it is available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ✅ Image processor that matches the pretrained checkpoint
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

# ✅ Raw torchvision CIFAR-10 splits (no transform, so samples stay PIL images)
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True)


def _to_records(split):
    """Return one {"pixel_values": image, "labels": label} dict per (image, label) sample."""
    return [{"pixel_values": image, "labels": label} for image, label in split]


# ✅ Wrap both splits as Hugging Face Datasets with "pixel_values" / "labels" columns
train_ds = Dataset.from_list(_to_records(train_dataset))
test_ds = Dataset.from_list(_to_records(test_dataset))

# ✅ Run the ViT image processor per batch, inside the data collator
def collate_fn(batch):
    """Build model inputs from a list of {"pixel_values": image, "labels": int} records.

    Returns a dict with the ``pixel_values`` tensor produced by ``image_processor``
    and a ``labels`` tensor built from the records' labels.
    """
    pil_images, label_ids = [], []
    for record in batch:
        pil_images.append(record["pixel_values"])
        label_ids.append(record["labels"])
    label_tensor = torch.tensor(label_ids)
    encoded = image_processor(images=pil_images, return_tensors="pt")
    return {"pixel_values": encoded["pixel_values"], "labels": label_tensor}

# ✅ Metric definitions (Hugging Face `evaluate`), loaded in a fixed order
accuracy, precision, recall, f1 = (
    evaluate.load(metric_name)
    for metric_name in ("accuracy", "precision", "recall", "f1")
)


def compute_metrics(eval_pred):
    """Trainer hook: accuracy plus macro-averaged precision, recall and F1.

    Args:
        eval_pred: ``(logits, labels)`` pair from the Trainer; logits are
            argmax-ed over axis 1 (presumably shape (N, num_classes)).

    Returns:
        Dict with keys "accuracy", "precision", "recall", "f1".
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    scores = {"accuracy": accuracy.compute(predictions=predictions, references=labels)["accuracy"]}
    for key, metric in (("precision", precision), ("recall", recall), ("f1", f1)):
        scores[key] = metric.compute(predictions=predictions, references=labels, average="macro")[key]
    return scores

# ✅ OneCycleLR callback definition
class OneCycleLRSchedulerCallback(TrainerCallback):
    """Drive a ``torch.optim.lr_scheduler.OneCycleLR`` schedule from Trainer callbacks.

    The scheduler is built on the Trainer's optimizer in ``on_train_begin`` and
    stepped once per ``on_step_end``.

    NOTE(review): the Trainer also creates and steps its own LR scheduler. This
    callback steps after it, so OneCycleLR presumably sets the LR used by the next
    optimizer step, but the ``learning_rate`` the Trainer logs comes from its own
    scheduler. Passing ``optimizers=(optimizer, OneCycleLR(...))`` to ``Trainer``
    avoids two schedules competing for the same optimizer.

    Args:
        max_lr: Peak learning rate of the cycle.
        steps_per_epoch: Optimizer steps per epoch; only used when the Trainer
            has not filled in ``state.max_steps``.
        epochs: Number of epochs; same fallback role as ``steps_per_epoch``.
    """

    def __init__(self, max_lr, steps_per_epoch, epochs):
        self.max_lr = max_lr
        self.steps_per_epoch = steps_per_epoch
        self.epochs = epochs
        self.scheduler = None

    def on_train_begin(self, args, state, control, **kwargs):
        optimizer = kwargs['optimizer']
        # Size the cycle from the Trainer's own step budget. A caller-supplied
        # `len(ds) // batch_size` drops each epoch's final partial batch, so the
        # cycle would be too short and OneCycleLR raises ValueError when it is
        # stepped past `total_steps`.
        total_steps = state.max_steps or self.steps_per_epoch * self.epochs
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.max_lr,
            total_steps=total_steps,
        )

    def on_step_end(self, args, state, control, **kwargs):
        # Never step past the end of the cycle: OneCycleLR raises rather than
        # clamping, which would abort training on the last steps.
        if self.scheduler is not None and self.scheduler.last_epoch < self.scheduler.total_steps:
            self.scheduler.step()

# ✅ Callback that echoes Trainer log records to stdout during training
class LoggingCallback(TrainerCallback):
    """Print every non-empty log dict emitted by the Trainer, tagged with the global step."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        current_step = state.global_step
        print(f"[Step {current_step}] Log: {logs}")

import math

# ✅ Load the pretrained backbone with a freshly initialised 10-class head
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=10
)
model.to(device)

# Hyperparameters shared by TrainingArguments and the LR schedule, so the batch
# size / epoch count used to size OneCycleLR cannot drift from the real run.
BATCH_SIZE = 32
NUM_EPOCHS = 10
MAX_LR = 5e-5

# ✅ Training configuration
training_args = TrainingArguments(
    output_dir="./vit-cifar10",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=NUM_EPOCHS,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    logging_dir="./logs",  # TensorBoard
    logging_steps=10,
    report_to="tensorboard"
)

# ✅ Optimizer + OneCycleLR, handed to the Trainer via `optimizers=`.
# The Trainer then steps (and logs) this schedule itself instead of building its
# default linear schedule, which would otherwise compete with a callback-driven
# OneCycleLR on the same optimizer.
# Each epoch ends with a partial batch unless the dataset divides evenly, so
# optimizer steps per epoch is a ceiling division; floor division under-counts
# and OneCycleLR raises ValueError once stepped past its total.
# Assumes gradient_accumulation_steps == 1 (the TrainingArguments default).
steps_per_epoch = math.ceil(len(train_ds) / training_args.train_batch_size)
optimizer = torch.optim.AdamW(
    model.parameters(), lr=MAX_LR, weight_decay=training_args.weight_decay
)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=MAX_LR,
    steps_per_epoch=steps_per_epoch,
    epochs=NUM_EPOCHS,
)

# ✅ Build the Trainer and train
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    # NOTE(review): the CIFAR-10 *test* split doubles as the validation set, so
    # early stopping and best-checkpoint selection both see test data and the
    # reported scores are optimistic. Hold out part of train_ds for a clean estimate.
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
    optimizers=(optimizer, lr_scheduler),
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=3),
        LoggingCallback(),
    ],
)

trainer.train()
# With load_best_model_at_end=True, the weights saved here are the best-F1 checkpoint.
trainer.save_model("./vit-cifar10/best-model")
# tensorboard --logdir=./logs

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.2924,0.28926,0.9717,0.971824,0.9717,0.971666
2,0.0536,0.089542,0.9791,0.979299,0.9791,0.979121
3,0.0716,0.083222,0.9796,0.979933,0.9796,0.979631
4,0.0437,0.080801,0.9795,0.979847,0.9795,0.979511
5,0.0265,0.083366,0.9791,0.979183,0.9791,0.979031
6,0.0184,0.114579,0.9781,0.978317,0.9781,0.978114


In [None]:
# !tensorboard --logdir=./logs

In [None]:
# !pip install pytorch_lightning

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.5.1-py3-none-any.whl.metadata (20 kB)
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Downloading pytorch_lightning-2.5.1-py3-none-any.whl (822 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.0/823.0 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.14.3-py3-none-any.whl (28 kB)
Downloading torchmetrics-1.7.1-py3-none-any.whl (961 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m961.5/961.5 kB[0m [31m43.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lightning-utilities, torchmetrics, pytorch_lightning
Successfully installed lightning-utilities-0.14.3 pytorch_lightning-2.5.1 torchmetrics-1.7.1


## ViT model with PyTorch Lightning

🤖 모델: ViTForImageClassification (HuggingFace, pretrained)

🖼️ 입력 전처리: ViTImageProcessor 사용

- 데이터: torchvision의 CIFAR10
- PyTorch Lightning의 Trainer로 학습
- ⚙️ 옵티마이저: AdamW
- 🧈 손실 함수: CrossEntropyLoss

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTForImageClassification, ViTImageProcessor
from torchmetrics.classification import Accuracy

# 1. Image processor: source of the checkpoint's normalization mean/std
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

# 2. Per-sample transform: resize to the 224x224 ViT input, to tensor, normalize
vit_transform_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
]
transform = transforms.Compose(vit_transform_steps)

# 3. CIFAR-10 splits and their loaders (the test split serves as validation)
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

loader_kwargs = dict(batch_size=128, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# 4. Wrap the Hugging Face ViT classifier in a LightningModule
class ViTLightning(pl.LightningModule):
    """LightningModule around a pretrained HF ViT with a fresh ``num_classes`` head.

    The HF model computes the loss itself when ``labels`` are passed. Loss and
    multiclass accuracy are logged for training and validation.

    Args:
        num_classes: Number of output classes.
        lr: AdamW learning rate (stored in ``self.hparams``).
    """

    def __init__(self, num_classes, lr=2e-5):
        super().__init__()
        self.save_hyperparameters()
        self.model = ViTForImageClassification.from_pretrained(
            "google/vit-base-patch16-224-in21k",
            num_labels=num_classes
        )
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x).logits

    def _loss_and_preds(self, batch):
        """Run the model on an ``(images, labels)`` batch; return ``(loss, preds, labels)``."""
        x, y = batch
        outputs = self.model(x, labels=y)
        preds = torch.argmax(outputs.logits, dim=1)
        return outputs.loss, preds, y

    def training_step(self, batch, batch_idx):
        loss, preds, y = self._loss_and_preds(batch)
        # forward() caches the per-batch value; logging the Metric object lets
        # Lightning log that value per step and manage the metric's state.
        self.train_acc(preds, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, preds, y = self._loss_and_preds(batch)
        # Accumulate over the whole validation epoch; Lightning calls compute()
        # and reset() at epoch end, so val_acc is the exact epoch-level score
        # rather than an average of per-batch scores.
        self.val_acc.update(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_acc, prog_bar=True, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """AdamW over all parameters at ``self.hparams.lr``."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

# 5. Instantiate the model and fit it (the test split serves as validation)
num_classes = len(train_dataset.classes)
model = ViTLightning(num_classes=num_classes)

trainer_config = dict(
    max_epochs=5,
    accelerator="auto",
    devices=1,
    log_every_n_steps=10,
)
trainer = pl.Trainer(**trainer_config)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:pytorch_lightning.utilities.rank_zero:You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type                      | Params | Mode 
----------------------------------------------------------------
0 | model     | ViT

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


In [6]:
!tensorboard --logdir lightning_logs/

2025-04-15 13:09:08.727971: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744722548.749714   15761 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744722548.756352   15761 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.18.0 at http://localhost:6006/ (Press CTRL+C to quit)


## ViT model with PyTorch Lightning and Callback Functions

🤖 모델: ViTForImageClassification (HuggingFace, pretrained)

🖼️ 입력 전처리: ViTImageProcessor 사용

- 데이터: torchvision의 CIFAR10
- PyTorch Lightning의 Trainer로 학습
- 옵티마이저: AdamW
- 손실 함수: CrossEntropyLoss
- Metric: Accuracy, Precision, Recall, F1Score
- Callbacks: ModelCheckpoint, EarlyStopping

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTForImageClassification, ViTImageProcessor
from torchmetrics.classification import (
    Accuracy, Precision, Recall, F1Score
)
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# 1. Image processor (supplies the checkpoint's normalization mean/std)
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

# 2. Transform: resize to the 224x224 ViT input, convert to tensor, normalize
normalize = transforms.Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), normalize])

# 3. Datasets & loaders (train split first, then test split, which serves as validation)
train_dataset, test_dataset = (
    datasets.CIFAR10(root='./data', train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
val_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

# 4. LightningModule
class ViTLightning(pl.LightningModule):
    """LightningModule around a pretrained HF ViT with a fresh ``num_classes`` head.

    Tracks multiclass accuracy, F1, precision and recall (torchmetrics) for the
    training and validation loops. The HF model computes the loss itself when
    ``labels`` are passed.

    Args:
        num_classes: Number of output classes.
        lr: AdamW learning rate (stored in ``self.hparams``).
    """

    def __init__(self, num_classes, lr=2e-5):
        super().__init__()
        self.save_hyperparameters()
        self.model = ViTForImageClassification.from_pretrained(
            "google/vit-base-patch16-224-in21k",
            num_labels=num_classes
        )
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.train_f1 = F1Score(task="multiclass", num_classes=num_classes)
        self.train_precision = Precision(task="multiclass", num_classes=num_classes)
        self.train_recall = Recall(task="multiclass", num_classes=num_classes)

        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_f1 = F1Score(task="multiclass", num_classes=num_classes)
        self.val_precision = Precision(task="multiclass", num_classes=num_classes)
        self.val_recall = Recall(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x).logits

    def _loss_and_preds(self, batch):
        """Run the model on an ``(images, labels)`` batch; return ``(loss, preds, labels)``."""
        x, y = batch
        outputs = self.model(x, labels=y)
        preds = torch.argmax(outputs.logits, dim=1)
        return outputs.loss, preds, y

    def training_step(self, batch, batch_idx):
        loss, preds, y = self._loss_and_preds(batch)
        # forward() caches each metric's per-batch value; logging the Metric
        # objects lets Lightning log those values per step and manage state.
        self.train_acc(preds, y)
        self.train_f1(preds, y)
        self.train_precision(preds, y)
        self.train_recall(preds, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_acc, prog_bar=True)
        self.log("train_f1", self.train_f1, prog_bar=True)
        self.log("train_precision", self.train_precision, prog_bar=False)
        self.log("train_recall", self.train_recall, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, preds, y = self._loss_and_preds(batch)
        # Accumulate over the whole validation epoch; Lightning calls compute()
        # and reset() at epoch end. Logging per-batch values would instead make
        # Lightning average batch-level scores, which for non-decomposable
        # metrics such as F1 is not the epoch-level score.
        self.val_acc.update(preds, y)
        self.val_f1.update(preds, y)
        self.val_precision.update(preds, y)
        self.val_recall.update(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_acc, prog_bar=True, on_step=False, on_epoch=True)
        self.log("val_f1", self.val_f1, prog_bar=True, on_step=False, on_epoch=True)
        self.log("val_precision", self.val_precision, prog_bar=False, on_step=False, on_epoch=True)
        self.log("val_recall", self.val_recall, prog_bar=False, on_step=False, on_epoch=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return argmax class predictions for a batch; labels in the batch are ignored."""
        x, _ = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        return preds

    def configure_optimizers(self):
        """AdamW over all parameters at ``self.hparams.lr``."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

# 5. Callbacks: keep only the checkpoint with the lowest val_loss, and stop
#    training after 3 validation checks without val_loss improvement.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="vit-best",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)

early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=3,
    verbose=True,
)

# 6. Trainer & training run
model = ViTLightning(num_classes=len(train_dataset.classes))

trainer = pl.Trainer(
    max_epochs=10,
    accelerator="auto",
    devices=1,
    log_every_n_steps=10,
    callbacks=[checkpoint_callback, early_stopping_callback]
)

# NOTE(review): val_loader wraps the CIFAR-10 test split, so checkpoint selection
# and early stopping see test data; the final scores are optimistic.
trainer.fit(model, train_loader, val_loader)

# 7. Predictions on the test split, using the best checkpoint.
# Without ckpt_path, predict() uses the in-memory weights from the last epoch,
# which may be several epochs past the best val_loss (e.g. when early stopping
# fires), not the checkpoint that checkpoint_callback kept.
predictions = trainer.predict(model, val_loader, ckpt_path="best")
all_preds = torch.cat(predictions).cpu().numpy()


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name            | Type                      | Params | Mode 
----------------------------------------------------------------------
0 | model           | ViTForImageClassification | 85.8 M | eval 
1 | train_acc       | MulticlassAccuracy        | 0      | train
2 | train_f1        | MulticlassF1Score         | 0  

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved. New best score: 0.245


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.103 >= min_delta = 0.0. New best score: 0.142


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.036 >= min_delta = 0.0. New best score: 0.106


Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
!tensorboard --logdir lightning_logs/