# FashionMNIST 다운로드

In [1]:
from itertools import chain
from collections import defaultdict
from torch.utils.data import Subset
from torchvision import datasets

def subset_sampler(dataset, classes, max_len):
    """Build a subset holding at most ``max_len`` samples of every class.

    Args:
        dataset: Dataset exposing its integer labels through ``targets`` (or
            through the older ``train_labels`` attribute).
        classes: Sequence of class names; labels ``0 .. len(classes) - 1``
            are the ones collected.
        max_len: Maximum number of samples kept per class.

    Returns:
        A ``Subset`` of ``dataset`` whose indices are grouped by label in
        ascending order, keeping dataset order inside each group.
    """
    # torchvision renamed ``train_labels`` to ``targets``; prefer the new
    # attribute and fall back only when it is missing.
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = dataset.train_labels

    target_idx = defaultdict(list)
    for idx, label in enumerate(labels):
        target_idx[int(label)].append(idx)

    indices = list(
        chain.from_iterable(
            target_idx[class_id][:max_len] for class_id in range(len(classes))
        )
    )
    return Subset(dataset, indices)

# Training split; download=True fetches the FashionMNIST files into ../datasets.
train_dataset = datasets.FashionMNIST(
    root='../datasets',
    train=True,
    download=True,
)

# Held-out test split. train=False is required here: with train=True the
# "test" data would be the training images again and every evaluation score
# would be measured on data the model has already seen.
test_dataset = datasets.FashionMNIST(
    root='../datasets',
    train=False,
    download=False,
)

classes = train_dataset.classes
class_to_idx = train_dataset.class_to_idx

print(classes)
print(class_to_idx)

# Keep at most 1000 training and 100 test samples per class.
subset_train_dataset = subset_sampler(dataset=train_dataset, classes=train_dataset.classes, max_len=1000)
subset_test_dataset = subset_sampler(dataset=test_dataset, classes=test_dataset.classes, max_len=100)

print(f"Training Data Size: {len(subset_train_dataset)}")
print(f"Test Data Size: {len(subset_test_dataset)}")
print(train_dataset[0])


100%|██████████| 26.4M/26.4M [00:01<00:00, 18.6MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 302kB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 5.58MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 14.6MB/s]


['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
{'T-shirt/top': 0, 'Trouser': 1, 'Pullover': 2, 'Dress': 3, 'Coat': 4, 'Sandal': 5, 'Shirt': 6, 'Sneaker': 7, 'Bag': 8, 'Ankle boot': 9}
Training Data Size: 10000
Test Data Size: 1000
(<PIL.Image.Image image mode=L size=28x28 at 0x7EF46042F5D0>, 9)


# ViT 실습

In [2]:
# 이미지 전처리
import torch
from torchvision import transforms
from transformers import AutoImageProcessor

# Load the preprocessing configuration published with the ViT checkpoint.
image_processor = AutoImageProcessor.from_pretrained(
    pretrained_model_name_or_path="google/vit-base-patch16-224-in21k"
)

# Target size is taken from the processor so it matches the checkpoint.
resize_to = (image_processor.size["height"], image_processor.size["width"])

# Pipeline: to tensor -> resize -> concatenate three copies along dim 0 ->
# normalize with the processor's mean / std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=resize_to),
        transforms.Lambda(lambda x: torch.cat([x, x, x], 0)),
        transforms.Normalize(
            mean=image_processor.image_mean,
            std=image_processor.image_std,
        ),
    ]
)

print(f"size: {image_processor.size}")
print(f"mean: {image_processor.image_mean}")
print(f"std: {image_processor.image_std}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


size: {'height': 224, 'width': 224}
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]


In [3]:
#ViT 데이터로더 적용
from torch.utils.data import DataLoader

# User-defined function to pass as the DataLoader's collate_fn.
# Input: the batch of samples pulled by the DataLoader.
def collator(data, transform):
    """Turn a batch of ``(image, label)`` samples into a dict of tensors.

    Args:
        data: Iterable of ``(image, label)`` pairs.
        transform: Callable applied to each image; its results are stacked,
            so they must all have the same shape.

    Returns:
        Dict with ``"pixel_values"`` (stacked transformed images) and
        ``"labels"`` (tensor built from the labels).
    """
    images, labels = zip(*data)
    transformed = list(map(transform, images))
    return {
        "pixel_values": torch.stack(transformed),
        "labels": torch.tensor(list(labels)),
    }

def _collate_with_transform(samples):
    """Call ``collator`` with whatever ``transform`` is bound at module level."""
    return collator(samples, transform)


# Training loader: shuffled batches of 32, trailing partial batch dropped.
train_dataloader = DataLoader(
    dataset=subset_train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=_collate_with_transform,
    drop_last=True,
)

# Validation loader: same settings with batches of 4.
valid_dataloader = DataLoader(
    dataset=subset_test_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=_collate_with_transform,
    drop_last=True,
)

# Pull one batch and show the shape of every tensor in it.
batch = next(iter(train_dataloader))
for key, value in batch.items():
    print(f"{key}:{value.shape}")

pixel_values:torch.Size([32, 3, 224, 224])
labels:torch.Size([32])


In [4]:
# 사전학습된 ViT모델 불러오기
from transformers import ViTForImageClassification

# Invert the name -> index mapping so the model config can map ids to names.
id2label = {idx: label for label, idx in class_to_idx.items()}

model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path="google/vit-base-patch16-224-in21k",
    num_labels=len(classes),
    id2label=id2label,
    label2id=class_to_idx,
    ignore_mismatched_sizes=True,
)

# Show the classification head created for the FashionMNIST classes.
print(model.classifier)

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Linear(in_features=768, out_features=10, bias=True)


In [5]:
embeddings = model.vit.embeddings
print(embeddings)

# Feed one training batch through the embedding modules and report the shapes.
batch = next(iter(train_dataloader))
pixel_values = batch["pixel_values"]
print("image shape: ", pixel_values.shape)

patch_only = embeddings.patch_embeddings(pixel_values)
print("patch embeddings shape: ", patch_only.shape)

with_cls = embeddings(pixel_values)
print("[CLS] + patch embeddings shape: ", with_cls.shape)

ViTEmbeddings(
  (patch_embeddings): ViTPatchEmbeddings(
    (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (dropout): Dropout(p=0.0, inplace=False)
)
image shape:  torch.Size([32, 3, 224, 224])
patch embeddings shape:  torch.Size([32, 196, 768])
[CLS] + patch embeddings shape:  torch.Size([32, 197, 768])


In [6]:
# 하이퍼파라미터 설정
from transformers import TrainingArguments
args = TrainingArguments(
    # Checkpointing and evaluation cadence
    output_dir="../models/ViT-FashinMNIST",  # checkpoint directory
    save_strategy="epoch",  # checkpoint saving interval
    eval_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",  # macro-averaged F1 score
    # Optimisation
    learning_rate=1e-5,
    weight_decay=0.001,  # weight decay
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Logging
    logging_dir="logs",
    logging_steps=125,  # logging interval
    report_to=["none"],
    # Miscellaneous
    remove_unused_columns=False,
    seed=7,
)

In [7]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting datasets>=2.0.0 (from evaluate)
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting dill (from evaluate)
  Downloading dill-0.4.0-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from evaluate)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.18-py311-none-any.whl.metadata (7.5 kB)
Collecting dill (from evaluate)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec>=2021.05.0 (from fsspec[http]>=2021.05.0->evaluate)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [

In [8]:
# 매크로 평균 F1 점수
import evaluate
import numpy as np

def compute_metrics(eval_pred):
    """Compute the macro-averaged F1 score for one evaluation pass.

    Args:
        eval_pred: Pair ``(predictions, labels)``; ``predictions`` holds
            per-class scores with the classes on axis 1.

    Returns:
        The result of the ``f1`` metric's ``compute`` call with
        ``average="macro"``.
    """
    metric = evaluate.load("f1")
    # The references must be bound to ``labels`` here; a misspelled target
    # name would leave ``labels`` below unresolved (NameError) or resolved to
    # an unrelated module-level variable.
    predictions, labels = eval_pred
    predicted_classes = np.argmax(predictions, axis=1)
    return metric.compute(
        predictions=predicted_classes, references=labels, average="macro"
    )

In [17]:
from transformers import TrainingArguments, Trainer


def subset_sampler(dataset, classes, max_len):
    """Build a subset holding at most ``max_len`` samples of every class.

    Args:
        dataset: Dataset exposing its integer labels through ``targets`` (or
            through the older ``train_labels`` attribute).
        classes: Sequence of class names; labels ``0 .. len(classes) - 1``
            are the ones collected.
        max_len: Maximum number of samples kept per class.

    Returns:
        A ``Subset`` of ``dataset`` whose indices are grouped by label in
        ascending order, keeping dataset order inside each group.
    """
    # torchvision renamed ``train_labels`` to ``targets``; prefer the new
    # attribute and fall back only when it is missing.
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = dataset.train_labels

    target_idx = defaultdict(list)
    for idx, label in enumerate(labels):
        target_idx[int(label)].append(idx)

    indices = list(
        chain.from_iterable(
            target_idx[class_id][:max_len] for class_id in range(len(classes))
        )
    )
    return Subset(dataset, indices)

def model_init(classes, class_to_idx):
    """Load the pretrained ViT checkpoint with a head sized for ``classes``.

    Args:
        classes: Sequence of class names; its length sets ``num_labels``.
        class_to_idx: Mapping from class name to integer label.

    Returns:
        The ``ViTForImageClassification`` model returned by ``from_pretrained``.
    """
    index_to_name = {index: name for name, index in class_to_idx.items()}
    return ViTForImageClassification.from_pretrained(
        pretrained_model_name_or_path="google/vit-base-patch16-224-in21k",
        num_labels=len(classes),
        id2label=index_to_name,
        label2id=class_to_idx,
    )

def collator(data, transform):
    """Turn a batch of ``(image, label)`` samples into a dict of tensors.

    Args:
        data: Iterable of ``(image, label)`` pairs.
        transform: Callable applied to each image; its results are stacked,
            so they must all have the same shape.

    Returns:
        Dict with ``"pixel_values"`` (stacked transformed images) and
        ``"labels"`` (tensor built from the labels).
    """
    images, labels = zip(*data)
    batch = {}
    batch["pixel_values"] = torch.stack([transform(img) for img in images])
    batch["labels"] = torch.tensor(list(labels))
    return batch

def compute_metrics(eval_pred):
    """Compute the macro-averaged F1 score for one evaluation pass.

    Args:
        eval_pred: Pair ``(predictions, labels)``; ``predictions`` holds
            per-class scores with the classes on axis 1.

    Returns:
        The result of the ``f1`` metric's ``compute`` call with
        ``average="macro"``.
    """
    metric = evaluate.load("f1")
    # The references must be bound to ``labels`` here; a misspelled target
    # name would leave ``labels`` below unresolved (NameError) or resolved to
    # an unrelated module-level variable.
    predictions, labels = eval_pred
    predicted_classes = np.argmax(predictions, axis=1)
    return metric.compute(
        predictions=predicted_classes, references=labels, average="macro"
    )
# Training split; download=True fetches the FashionMNIST files into ../datasets.
train_dataset = datasets.FashionMNIST(
    root='../datasets',
    train=True,
    download=True,
)

# Held-out test split. train=False is required here: with train=True the
# evaluation set would be the training images again, so the reported F1
# would be measured on data the model was trained on.
test_dataset = datasets.FashionMNIST(
    root='../datasets',
    train=False,
    download=False,
)

classes = train_dataset.classes
class_to_idx = train_dataset.class_to_idx

# Keep at most 1000 training and 100 test samples per class.
subset_train_dataset = subset_sampler(dataset=train_dataset, classes=train_dataset.classes, max_len=1000)
subset_test_dataset = subset_sampler(dataset=test_dataset, classes=test_dataset.classes, max_len=100)

image_processor = AutoImageProcessor.from_pretrained(
    pretrained_model_name_or_path="google/vit-base-patch16-224-in21k",
    use_fast=True)

# Pipeline: to tensor -> resize to the processor's size -> concatenate three
# copies along dim 0 -> normalize with the processor's mean / std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(
        size=(
            image_processor.size["height"],
            image_processor.size["width"]
        )
    ),
    transforms.Lambda(
        lambda x: torch.cat([x, x, x], 0)
    ),
    transforms.Normalize(
        mean=image_processor.image_mean,
        std=image_processor.image_std
    )
])

# NOTE(review): the directory name is spelled "FashinMNIST"; it is a runtime
# path and is deliberately left unchanged here.
args = TrainingArguments(
    output_dir="../models/ViT-FashinMNIST",  # checkpoint directory
    save_strategy="epoch",  # checkpoint saving interval
    eval_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.001,  # weight decay
    load_best_model_at_end=True,
    metric_for_best_model="f1",  # macro-averaged F1 score
    logging_dir="logs",
    logging_steps=125,  # logging interval
    remove_unused_columns=False,
    seed=7,
    report_to=["none"]
)


trainer = Trainer(
    # The single lambda argument is whatever the Trainer passes in; it is unused.
    model_init=lambda x: model_init(classes, class_to_idx),
    args=args,
    train_dataset=subset_train_dataset,
    eval_dataset=subset_test_dataset,
    data_collator=lambda x: collator(x, transform),
    processing_class=image_processor,
    compute_metrics=compute_metrics,
)

trainer.train()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


KeyboardInterrupt: 

In [None]:
#ViT 모델 성능 평가
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Run prediction on the test subset and show the raw prediction output.
outputs = trainer.predict(subset_test_dataset)
print(outputs)

y_true = outputs.label_ids
# The predicted class is the index of the largest score on axis 1.
y_pred = outputs.predictions.argmax(axis=1)

labels = list(classes)
matrix = confusion_matrix(y_true, y_pred)
display = ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=labels)
_, ax = plt.subplots(figsize=(10, 10))
# The keyword is ``xticks_rotation``; the misspelled ``xtricks_rotation``
# is not accepted by ``plot`` and raises TypeError.
display.plot(xticks_rotation=45, ax=ax)
plt.show()

# Swin Transformer

In [9]:
# 사전 학습된 스윈 트랜스포머 모델
from transformers import SwinForImageClassification

# Invert the name -> index mapping so the model config can map ids to names.
swin_id2label = {idx: label for label, idx in train_dataset.class_to_idx.items()}

model = SwinForImageClassification.from_pretrained(
    pretrained_model_name_or_path="microsoft/swin-base-patch4-window7-224-in22k",
    num_labels=len(train_dataset.classes),
    id2label=swin_id2label,
    label2id=train_dataset.class_to_idx,
    ignore_mismatched_sizes=True,
)

# Print the module tree four levels deep. Only a module named "projection"
# is printed in full; every other module is listed by name.
for top_name, top_module in model.named_children():
    print(top_name)
    for child_name, child_module in top_module.named_children():
        print("L", child_name)
        for grand_name, grand_module in child_module.named_children():
            print("| L", grand_name)
            for leaf_name, leaf_module in grand_module.named_children():
                if leaf_name != "projection":
                    print("| | L", leaf_name)
                    continue
                print("| | L", leaf_name, leaf_module)


config.json:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/437M [00:00<?, ?B/s]

Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-base-patch4-window7-224-in22k and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([21841]) in the checkpoint and torch.Size([10]) in the model instantiated
- classifier.weight: found shape torch.Size([21841, 1024]) in the checkpoint and torch.Size([10, 1024]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


swin
L embeddings
| L patch_embeddings
| | L projection Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
| L norm
| L dropout
L encoder
| L layers
| | L 0
| | L 1
| | L 2
| | L 3
L layernorm
L pooler
classifier


In [10]:
# Patch embedding module
batch = next(iter(train_dataloader))
pixel_values = batch["pixel_values"]
print("이미지 차원 : ", pixel_values.shape)

# The module returns a pair; the second element is kept as ``shape``.
patch_embedder = model.swin.embeddings.patch_embeddings
patch_emb_output, shape = patch_embedder(pixel_values)
print("모듈:", patch_embedder)
print("패치 임베딩 차원:", patch_emb_output.shape)

이미지 차원 :  torch.Size([32, 3, 224, 224])
모듈: SwinPatchEmbeddings(
  (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
)
패치 임베딩 차원: torch.Size([32, 3136, 128])


In [11]:
# Swin Transformer block: list the module names of the first encoder layer,
# three levels deep.
first_layer = model.swin.encoder.layers[0]
for top_name, top_module in first_layer.named_children():
    print(top_name)
    for child_name, child_module in top_module.named_children():
        print("L", child_name)
        for grand_name, _grand_module in child_module.named_children():
            print("| L", grand_name)

blocks
L 0
| L layernorm_before
| L attention
| L drop_path
| L layernorm_after
| L intermediate
| L output
L 1
| L layernorm_before
| L attention
| L drop_path
| L layernorm_after
| L intermediate
| L output
downsample
L reduction
L norm


In [12]:
# Show the first block of the first encoder layer.
print(model.swin.encoder.layers[0].blocks[0])

SwinLayer(
  (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (attention): SwinAttention(
    (self): SwinSelfAttention(
      (query): Linear(in_features=128, out_features=128, bias=True)
      (key): Linear(in_features=128, out_features=128, bias=True)
      (value): Linear(in_features=128, out_features=128, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (output): SwinSelfOutput(
      (dense): Linear(in_features=128, out_features=128, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
  )
  (drop_path): Identity()
  (layernorm_after): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (intermediate): SwinIntermediate(
    (dense): Linear(in_features=128, out_features=512, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): SwinOutput(
    (dense): Linear(in_features=512, out_features=128, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
)


In [13]:
# W-MSA and SW-MSA modules
print("패치 임베딩 차원: ", patch_emb_output.shape)

# The two blocks of the first encoder layer, applied one after the other.
first_layer_blocks = model.swin.encoder.layers[0].blocks
W_MSA, SW_MSA = first_layer_blocks[0], first_layer_blocks[1]

# Each block call returns a sequence; element 0 is used as the hidden states.
W_MSA_result = W_MSA(patch_emb_output, W_MSA.input_resolution)
W_MSA_output = W_MSA_result[0]
SW_MSA_result = SW_MSA(W_MSA_output, SW_MSA.input_resolution)
SW_MSA_output = SW_MSA_result[0]

print("W-MSA 차원: ", W_MSA_output.shape)
print("SW-MSA 차원: ", SW_MSA_output.shape)

패치 임베딩 차원:  torch.Size([32, 3136, 128])
W-MSA 차원:  torch.Size([32, 3136, 128])
SW-MSA 차원:  torch.Size([32, 3136, 128])


In [14]:
# Patch merging: apply the first encoder layer's downsample module to the
# SW-MSA output.
patch_merge = model.swin.encoder.layers[0].downsample
print("print_merge 모듈: ", patch_merge)

merge_resolution = patch_merge.input_resolution
output = patch_merge(SW_MSA_output, merge_resolution)
print("patch_merge 결과 차원: ", output.shape)

print_merge 모듈:  SwinPatchMerging(
  (reduction): Linear(in_features=512, out_features=256, bias=False)
  (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)
patch_merge 결과 차원:  torch.Size([32, 784, 256])


In [15]:
from transformers import Trainer

In [None]:
# 스윈 트랜스포머 미세 조정

def subset_sampler(dataset, classes, max_len):
    """Build a subset holding at most ``max_len`` samples of every class.

    Args:
        dataset: Dataset exposing its integer labels through ``targets`` (or
            through the older ``train_labels`` attribute).
        classes: Sequence of class names; labels ``0 .. len(classes) - 1``
            are the ones collected.
        max_len: Maximum number of samples kept per class.

    Returns:
        A ``Subset`` of ``dataset`` whose indices are grouped by label in
        ascending order, keeping dataset order inside each group.
    """
    # torchvision renamed ``train_labels`` to ``targets``; prefer the new
    # attribute and fall back only when it is missing.
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = dataset.train_labels

    target_idx = defaultdict(list)
    for idx, label in enumerate(labels):
        target_idx[int(label)].append(idx)

    indices = list(
        chain.from_iterable(
            target_idx[class_id][:max_len] for class_id in range(len(classes))
        )
    )
    return Subset(dataset, indices)

def model_init(classes, class_to_idx):
    """Load the pretrained Swin checkpoint with a head sized for ``classes``.

    Args:
        classes: Sequence of class names; its length sets ``num_labels``.
        class_to_idx: Mapping from class name to integer label.

    Returns:
        The ``SwinForImageClassification`` model returned by ``from_pretrained``.
    """
    # Imported locally so this function does not depend on an earlier cell
    # having imported the class.
    from transformers import SwinForImageClassification

    # A Swin checkpoint must be loaded with the Swin model class; loading it
    # through ViTForImageClassification does not match the checkpoint's
    # architecture.
    model = SwinForImageClassification.from_pretrained(
        pretrained_model_name_or_path="microsoft/swin-tiny-patch4-window7-224",
        num_labels=len(classes),
        id2label={idx: label for label, idx in class_to_idx.items()},
        label2id=class_to_idx,
        ignore_mismatched_sizes=True,
    )
    return model

def collator(data, transform):
    """Turn a batch of ``(image, label)`` samples into a dict of tensors.

    Args:
        data: Iterable of ``(image, label)`` pairs.
        transform: Callable applied to each image; its results are stacked,
            so they must all have the same shape.

    Returns:
        Dict with ``"pixel_values"`` (stacked transformed images) and
        ``"labels"`` (tensor built from the labels).
    """
    images, labels = zip(*data)
    stacked_images = torch.stack(list(map(transform, images)))
    label_tensor = torch.tensor(list(labels))
    return dict(pixel_values=stacked_images, labels=label_tensor)

def compute_metrics(eval_pred):
    """Compute the macro-averaged F1 score for one evaluation pass.

    Args:
        eval_pred: Pair ``(predictions, labels)``; ``predictions`` holds
            per-class scores with the classes on axis 1.

    Returns:
        The result of the ``f1`` metric's ``compute`` call with
        ``average="macro"``.
    """
    metric = evaluate.load("f1")
    # The references must be bound to ``labels`` here; a misspelled target
    # name would leave ``labels`` below unresolved (NameError) or resolved to
    # an unrelated module-level variable.
    predictions, labels = eval_pred
    predicted_classes = np.argmax(predictions, axis=1)
    return metric.compute(
        predictions=predicted_classes, references=labels, average="macro"
    )
# Training split; download=True fetches the FashionMNIST files into ../datasets.
train_dataset = datasets.FashionMNIST(
    root='../datasets',
    train=True,
    download=True,
)

# Held-out test split. train=False is required here: with train=True the
# evaluation set would be the training images again, so the reported F1
# would be measured on data the model was trained on.
test_dataset = datasets.FashionMNIST(
    root='../datasets',
    train=False,
    download=False,
)

classes = train_dataset.classes
class_to_idx = train_dataset.class_to_idx

# Keep at most 1000 training and 100 test samples per class.
subset_train_dataset = subset_sampler(dataset=train_dataset, classes=train_dataset.classes, max_len=1000)
subset_test_dataset = subset_sampler(dataset=test_dataset, classes=test_dataset.classes, max_len=100)

image_processor = AutoImageProcessor.from_pretrained(
    pretrained_model_name_or_path="microsoft/swin-tiny-patch4-window7-224")

# Pipeline: to tensor -> resize to the processor's size -> concatenate three
# copies along dim 0 -> normalize with the processor's mean / std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(
        size=(
            image_processor.size["height"],
            image_processor.size["width"]
        )
    ),
    transforms.Lambda(
        lambda x: torch.cat([x, x, x], 0)
    ),
    transforms.Normalize(
        mean=image_processor.image_mean,
        std=image_processor.image_std
    )
])

args = TrainingArguments(
    output_dir="../models/Swin-FashionMNIST",  # checkpoint directory
    save_strategy="epoch",  # checkpoint saving interval
    eval_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.001,  # weight decay
    load_best_model_at_end=True,
    metric_for_best_model="f1",  # macro-averaged F1 score
    logging_dir="logs",
    logging_steps=125,  # logging interval
    remove_unused_columns=False,
    seed=7,
    # Same setting as the ViT and CvT training cells in this notebook.
    report_to=["none"]
)


trainer = Trainer(
    # The single lambda argument is whatever the Trainer passes in; it is unused.
    model_init=lambda x: model_init(classes, class_to_idx),
    args=args,
    train_dataset=subset_train_dataset,
    eval_dataset=subset_test_dataset,
    data_collator=lambda x: collator(x, transform),
    # ``processing_class`` replaces the deprecated ``tokenizer`` keyword and
    # matches the other Trainer calls in this notebook.
    processing_class=image_processor,
    compute_metrics=compute_metrics,
)

trainer.train()

# CvT

In [18]:
#이미지 전처리
from transformers import AutoImageProcessor

# Load the preprocessing configuration published with the CvT checkpoint.
image_processor = AutoImageProcessor.from_pretrained(
    pretrained_model_name_or_path="microsoft/cvt-21",
    use_fast=True,
)

# This processor's size dict is read through "shortest_edge"; the same value
# is used for both dimensions of the resize.
edge = image_processor.size["shortest_edge"]

# Pipeline: to tensor -> resize -> concatenate three copies along dim 0 ->
# normalize with the processor's mean / std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(edge, edge)),
        transforms.Lambda(lambda x: torch.cat([x, x, x], 0)),
        transforms.Normalize(
            mean=image_processor.image_mean,
            std=image_processor.image_std,
        ),
    ]
)

print(f"size: {image_processor.size}")
print(f"mean: {image_processor.image_mean}")
print(f"std: {image_processor.image_std}")


preprocessor_config.json:   0%|          | 0.00/266 [00:00<?, ?B/s]

size: {'shortest_edge': 224}
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]


In [20]:
# 사전학습된 CvT 모델
from transformers import CvtForImageClassification

# Invert the name -> index mapping so the model config can map ids to names.
cvt_id2label = {idx: label for label, idx in train_dataset.class_to_idx.items()}

model = CvtForImageClassification.from_pretrained(
    pretrained_model_name_or_path="microsoft/cvt-21",
    num_labels=len(train_dataset.classes),
    id2label=cvt_id2label,
    ignore_mismatched_sizes=True,
)

# Print the names of the module tree, four levels deep.
for top_name, top_module in model.named_children():
    print(top_name)
    for child_name, child_module in top_module.named_children():
        print("L", child_name)
        for grand_name, grand_module in child_module.named_children():
            print("  L", grand_name)
            for leaf_name, _leaf_module in grand_module.named_children():
                print("    L", leaf_name)


config.json:   0%|          | 0.00/70.3k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/127M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/127M [00:00<?, ?B/s]

Some weights of CvtForImageClassification were not initialized from the model checkpoint at microsoft/cvt-21 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 384]) in the checkpoint and torch.Size([10, 384]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


cvt
L encoder
  L stages
    L 0
    L 1
    L 2
layernorm
classifier


In [21]:
# Stage structure of the CvT model: take the encoder's stages and print the first one.
stages = model.cvt.encoder.stages
print(stages[0])

CvtStage(
  (embedding): CvtEmbeddings(
    (convolution_embeddings): CvtConvEmbeddings(
      (projection): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
      (normalization): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (layers): Sequential(
    (0): CvtLayer(
      (attention): CvtAttention(
        (attention): CvtSelfAttention(
          (convolution_projection_query): CvtSelfAttentionProjection(
            (convolution_projection): CvtSelfAttentionConvProjection(
              (convolution): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
              (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
            (linear_projection): CvtSelfAttentionLinearProjection()
          )
          (convolution_projection_key): CvtSelfAttentionProjection(
            (convolution_projection): CvtSelfAtte

In [None]:
# Apply self-attention
batch = next(iter(train_dataloader))
print("이미지 차원: ", batch["pixel_values"].shape)

# Compute the patch embedding with the first CvT stage. Without this line
# ``patch_emb_output`` would still be the 3-D tensor left over from the Swin
# cells, and the four-value unpack below would fail.
patch_emb_output = stages[0].embedding(batch["pixel_values"])
print("패치 임베딩 차원:", patch_emb_output.shape)

# Flatten the two trailing dimensions and move the channel dimension last
# before handing the tensor to the attention module.
batch_size, num_channels, height, width = patch_emb_output.shape
hidden_state = patch_emb_output.view(batch_size, num_channels, height * width).permute(0, 2, 1)
print("셀프 어텐션 입력 차원: ", hidden_state.shape)

attention_output = stages[0].layers[0].attention.attention(hidden_state, height, width)
print("셀프 어텐션 출력 차원: ", attention_output.shape)

In [None]:

def model_init(classes, class_to_idx):
    """Load the pretrained CvT checkpoint with a head sized for ``classes``.

    Args:
        classes: Sequence of class names; its length sets ``num_labels``.
        class_to_idx: Mapping from class name to integer label.

    Returns:
        The ``CvtForImageClassification`` model returned by ``from_pretrained``.
    """
    model = CvtForImageClassification.from_pretrained(
        pretrained_model_name_or_path="microsoft/cvt-21",
        num_labels=len(classes),
        id2label={idx: label for label, idx in class_to_idx.items()},
        label2id=class_to_idx,
        # The checkpoint's classifier has 1000 outputs; this flag lets the
        # head be re-initialised for ``len(classes)`` outputs instead of
        # failing on the shape mismatch.
        ignore_mismatched_sizes=True,
    )
    return model

def collator(data, transform):
    """Turn a batch of ``(image, label)`` samples into a dict of tensors.

    Args:
        data: Iterable of ``(image, label)`` pairs.
        transform: Callable applied to each image; its results are stacked,
            so they must all have the same shape.

    Returns:
        Dict with ``"pixel_values"`` (stacked transformed images) and
        ``"labels"`` (tensor built from the labels).
    """
    images, labels = zip(*data)
    image_tensors = [transform(picture) for picture in images]
    return {
        "pixel_values": torch.stack(image_tensors),
        "labels": torch.tensor(list(labels)),
    }

def compute_metrics(eval_pred):
    """Compute the macro-averaged F1 score for one evaluation pass.

    Args:
        eval_pred: Pair ``(predictions, labels)``; ``predictions`` holds
            per-class scores with the classes on axis 1.

    Returns:
        The result of the ``f1`` metric's ``compute`` call with
        ``average="macro"``.
    """
    metric = evaluate.load("f1")
    # The references must be bound to ``labels`` here; a misspelled target
    # name would leave ``labels`` below unresolved (NameError) or resolved to
    # an unrelated module-level variable.
    predictions, labels = eval_pred
    predicted_classes = np.argmax(predictions, axis=1)
    return metric.compute(
        predictions=predicted_classes, references=labels, average="macro"
    )
# Training split; download=True fetches the FashionMNIST files into ../datasets.
train_dataset = datasets.FashionMNIST(
    root='../datasets',
    train=True,
    download=True,
)

# Held-out test split. train=False is required here: with train=True the
# evaluation set would be the training images again, so the reported F1
# would be measured on data the model was trained on.
test_dataset = datasets.FashionMNIST(
    root='../datasets',
    train=False,
    download=False,
)

classes = train_dataset.classes
class_to_idx = train_dataset.class_to_idx

# Keep at most 1000 training and 100 test samples per class.
subset_train_dataset = subset_sampler(dataset=train_dataset, classes=train_dataset.classes, max_len=1000)
subset_test_dataset = subset_sampler(dataset=test_dataset, classes=test_dataset.classes, max_len=100)


args = TrainingArguments(
    output_dir="../models/CvT-FashionMNIST",  # checkpoint directory
    save_strategy="epoch",  # checkpoint saving interval
    eval_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.001,  # weight decay
    load_best_model_at_end=True,
    metric_for_best_model="f1",  # macro-averaged F1 score
    logging_dir="logs",
    logging_steps=125,  # logging interval
    remove_unused_columns=False,
    seed=7,
    report_to=["none"]
)


trainer = Trainer(
    # The single lambda argument is whatever the Trainer passes in; it is unused.
    model_init=lambda x: model_init(classes, class_to_idx),
    args=args,
    train_dataset=subset_train_dataset,
    eval_dataset=subset_test_dataset,
    # ``transform`` and ``image_processor`` are the module-level objects
    # defined in the CvT preprocessing cell.
    data_collator=lambda x: collator(x, transform),
    processing_class=image_processor,
    compute_metrics=compute_metrics,
)

trainer.train()