In [45]:
import torch
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    TrainingArguments,
    Trainer
)
import numpy as np
import evaluate
from PIL import Image
# Diagnostic only: TrainingArguments is already imported in the grouped
# transformers import above; this repeats the import and prints the module the
# class resolves to, as a quick check of which implementation is in use.
from transformers import TrainingArguments
print(TrainingArguments.__module__)


transformers.training_args


In [46]:
# Report the runtime environment.
# The GPU name is only queried when CUDA is usable: calling
# torch.cuda.get_device_name(0) on a CPU-only machine raises and would abort
# this cell before any later cell could run.
cuda_available = torch.cuda.is_available()
gpu_name = torch.cuda.get_device_name(0) if cuda_available else "none (CPU only)"

print("PyTorch version:", torch.__version__)
print("CUDA available:", cuda_available)
print("GPU name:", gpu_name)


PyTorch version: 2.5.1+cu121
CUDA available: True
GPU name: NVIDIA GeForce RTX 3060 Laptop GPU


In [47]:
# Root folder of the image data, given as a relative path. The dataset-loading
# cell reads the "train", "val" and "test" sub-folders beneath it.
data_dir = "chest_xray"

In [48]:
# Build one ImageFolder per split, in train -> val -> test order. Each split
# lives in a sub-folder of data_dir named after the split.
train_dataset, val_dataset, test_dataset = (
    datasets.ImageFolder(root=f"{data_dir}/{split}")
    for split in ("train", "val", "test")
)

# Summarise split sizes and the class names discovered in the training split.
print("Train samples:", len(train_dataset))
print("Val samples:", len(val_dataset))
print("Test samples:", len(test_dataset))
print("Classes:", train_dataset.classes)

Train samples: 5216
Val samples: 16
Test samples: 624
Classes: ['NORMAL', 'PNEUMONIA']


In [49]:
# Preprocessor loaded from the same checkpoint name that the model cell uses.
# NOTE(review): newer transformers releases are understood to deprecate
# ViTFeatureExtractor in favour of ViTImageProcessor — confirm against the
# installed version before upgrading.
extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")


def vit_transform(image):
    """Run a single image through the ViT preprocessor.

    Args:
        image: The image handed over by the dataset (presumably a PIL image
            from ImageFolder's default loader — confirm).

    Returns:
        The "pixel_values" tensor from the extractor output with
        ``squeeze(0)`` applied, i.e. without the leading axis when that axis
        has size 1.
    """
    encoded = extractor(images=image, return_tensors="pt")
    pixel_values = encoded["pixel_values"]
    return pixel_values.squeeze(0)


# Attach the same preprocessing to every split.
train_dataset.transform = val_dataset.transform = test_dataset.transform = vit_transform

In [50]:
class HFDataset(Dataset):
    """Adapter that re-shapes ``(image, label)`` samples into dictionaries.

    Each item of the wrapped dataset is unpacked into two values and returned
    as ``{"pixel_values": ..., "labels": ...}``.
    """

    def __init__(self, folder_dataset):
        """Keep a reference to the wrapped dataset (no copy is made)."""
        self.dataset = folder_dataset

    def __len__(self):
        """Number of samples, delegated to the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with "pixel_values" and "labels"."""
        pixel_values, labels = self.dataset[idx]
        return {"pixel_values": pixel_values, "labels": labels}

# Wrap each split so its samples come back as dictionaries; order is
# train -> val -> test, matching the names on the left-hand side.
train_dataset_hf, val_dataset_hf, test_dataset_hf = map(
    HFDataset, (train_dataset, val_dataset, test_dataset)
)

In [51]:
# Derive the label mapping from the training ImageFolder instead of hard-coding
# it, so the model config always agrees with the indices the dataset actually
# assigns to each class folder (and the head size follows the class count).
label2id = dict(train_dataset.class_to_idx)
id2label = {index: name for name, index in label2id.items()}

# Load the pretrained ViT backbone with a classification head sized for the
# dataset's classes.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [52]:
# Hyperparameters for the fine-tuning run, grouped by concern.
training_args = TrainingArguments(
    # --- output and logging ---
    output_dir="./vit_xray_results",
    logging_dir="./logs",
    logging_steps=50,
    # --- run length and batch sizes ---
    num_train_epochs=8,               # more epochs so the ViT fine-tunes better
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # --- optimisation ---
    learning_rate=3e-5,               # adjusted to suit ViT
    weight_decay=0.01,
    lr_scheduler_type="cosine",       # cosine schedule (suited to fine-tuning)
    warmup_ratio=0.1,                 # warm up over 10% of the steps
    # --- evaluation and checkpointing ---
    evaluation_strategy="epoch",      # use the full keyword name
    save_strategy="epoch",            # save a checkpoint every epoch
    load_best_model_at_end=True,
    # NOTE(review): in the recorded run the validation split has 16 images,
    # so eval accuracy moves in steps of 1/16; best-model selection on this
    # metric is coarse.
    metric_for_best_model="accuracy", # pick the best model by accuracy
    greater_is_better=True,           # higher accuracy is better
)

In [53]:
# Metric object loaded by name through the `evaluate` library.
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Score predictions with the module-level ``accuracy`` metric.

    Args:
        eval_pred: A pair that unpacks into ``(logits, labels)``.

    Returns:
        Whatever ``accuracy.compute`` returns for the arg-max class of each
        row of logits compared against ``labels``.
    """
    scores, references = eval_pred
    # The predicted class is the index of the largest score on the last axis.
    predicted_classes = np.argmax(scores, axis=-1)
    return accuracy.compute(predictions=predicted_classes, references=references)

In [None]:
# Wire the model, hyperparameters, data splits and metric function together.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset_hf,
    eval_dataset=val_dataset_hf,
)

# Start fine-tuning; kept as the cell's last expression so its return value is
# displayed when training finishes.
trainer.train()

 17%|█▋        | 281/1630 [10:39<51:10,  2.28s/it]
                                                  
  2%|▏         | 50/2608 [00:46<38:14,  1.12it/s]

{'loss': 0.6359, 'grad_norm': 0.9570462703704834, 'learning_rate': 5.747126436781609e-06, 'epoch': 0.15}


                                                  
  4%|▍         | 100/2608 [01:33<39:02,  1.07it/s]

{'loss': 0.3556, 'grad_norm': 2.292985439300537, 'learning_rate': 1.1494252873563218e-05, 'epoch': 0.31}


                                                  
  6%|▌         | 150/2608 [02:20<38:08,  1.07it/s]

{'loss': 0.1781, 'grad_norm': 1.2274476289749146, 'learning_rate': 1.7241379310344828e-05, 'epoch': 0.46}


                                                  
  8%|▊         | 200/2608 [03:06<36:50,  1.09it/s]

{'loss': 0.0874, 'grad_norm': 6.357956409454346, 'learning_rate': 2.2988505747126437e-05, 'epoch': 0.61}


                                                  
 10%|▉         | 250/2608 [03:53<36:36,  1.07it/s]

{'loss': 0.0943, 'grad_norm': 4.9710516929626465, 'learning_rate': 2.8735632183908045e-05, 'epoch': 0.77}


                                                    
 12%|█▏        | 300/2608 [05:01<1:27:48,  2.28s/it]

{'loss': 0.1288, 'grad_norm': 0.7385244369506836, 'learning_rate': 2.9979565434294196e-05, 'epoch': 0.92}


 12%|█▎        | 326/2608 [05:57<1:01:20,  1.61s/it]
                                                    
[A                                              

 12%|█▎        | 326/2608 [05:58<1:01:20,  1.61s/it]
[A
[A

{'eval_loss': 0.15796920657157898, 'eval_accuracy': 0.9375, 'eval_runtime': 1.1633, 'eval_samples_per_second': 13.754, 'eval_steps_per_second': 0.86, 'epoch': 1.0}


                                                    
 13%|█▎        | 350/2608 [06:26<40:43,  1.08s/it]

{'loss': 0.0631, 'grad_norm': 0.1262059509754181, 'learning_rate': 2.9893683384022836e-05, 'epoch': 1.07}


                                                    
 15%|█▌        | 400/2608 [07:30<36:44,  1.00it/s]

{'loss': 0.0487, 'grad_norm': 7.311153411865234, 'learning_rate': 2.974111243076491e-05, 'epoch': 1.23}


                                                    
 17%|█▋        | 450/2608 [08:26<33:33,  1.07it/s]

{'loss': 0.0687, 'grad_norm': 0.14053528010845184, 'learning_rate': 2.9522535735914266e-05, 'epoch': 1.38}


                                                  
 19%|█▉        | 500/2608 [09:12<33:10,  1.06it/s]

{'loss': 0.0647, 'grad_norm': 7.382377624511719, 'learning_rate': 2.9238932012366986e-05, 'epoch': 1.53}


                                                    
 21%|██        | 550/2608 [10:15<59:03,  1.72s/it]

{'loss': 0.0456, 'grad_norm': 1.0640891790390015, 'learning_rate': 2.8891571142174315e-05, 'epoch': 1.69}


                                                  
 23%|██▎       | 600/2608 [11:07<35:32,  1.06s/it]

{'loss': 0.0682, 'grad_norm': 0.05844239145517349, 'learning_rate': 2.8482008490438112e-05, 'epoch': 1.84}


                                                    
 25%|██▍       | 650/2608 [12:31<59:07,  1.81s/it]

{'loss': 0.0677, 'grad_norm': 0.43423500657081604, 'learning_rate': 2.8012077940909242e-05, 'epoch': 1.99}


 25%|██▌       | 652/2608 [12:33<46:39,  1.43s/it]
[A

                                                  
[A                                              
 25%|██▌       | 652/2608 [12:35<46:39,  1.43s/it]
[A

{'eval_loss': 0.5480027794837952, 'eval_accuracy': 0.75, 'eval_runtime': 1.4737, 'eval_samples_per_second': 10.857, 'eval_steps_per_second': 0.679, 'epoch': 2.0}


 26%|██▌       | 683/2608 [13:29<53:00,  1.65s/it]  

In [None]:
# Score the trained model on the held-out test split and show the metrics.
print("Evaluating on test dataset...")
metrics = trainer.evaluate(eval_dataset=test_dataset_hf)
print(metrics)

In [None]:
# Persist the fine-tuned weights and config.
save_dir = "./vit_xray_finetuned"
trainer.save_model(save_dir)

# The Trainer above was created without a tokenizer/processor, so save_model()
# has no preprocessor to write out. Store the extractor's config alongside the
# weights so the directory can be reloaded for inference without needing the
# original checkpoint name.
extractor.save_pretrained(save_dir)

print(f"Model saved to {save_dir}")