## Importing Modules

필요한 모듈을 Import 합니다.


In [2]:
# Modules About Hydra
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from omegaconf import DictConfig

# Modules About Torch
import torch
import torch.nn.functional as F
import torchmetrics
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Modules About Pytorch Lightning
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, LightningDataModule
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

# Modules About HuggingFace Transformers
from transformers import ViTImageProcessor, ViTForImageClassification, AdamW


  from .autonotebook import tqdm as notebook_tqdm


## Configure Dataset

Custom Dataset을 구성합니다.
Data_Module에 있는 CustomDataset을 불러옵니다.


In [11]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 64, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize((0.1307,), (0.3081,)),
            # ViT expects 224x224 images
            transforms.Resize((224, 224), antialias=True),
            transforms.Lambda(self.repeat_channels)  # ViT expects 3 channels
        ])

    def repeat_channels(self, x):
        return x.repeat(3, 1, 1)

    def setup(self, stage=None):
        self.mnist_train = datasets.MNIST(
            self.data_dir, train=True, download=True, transform=self.transform
        )
        self.mnist_val = datasets.MNIST(
            self.data_dir, train=False, download=True, transform=self.transform
        )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

In [9]:
class MyModel(pl.LightningModule):
    def __init__(self, config):
        super(MyModel, self).__init__()
        self.config = config
        self.save_hyperparameters()
        self.model = ViTForImageClassification.from_pretrained(
            "google/vit-base-patch16-224")
        self.feature_extractor = ViTImageProcessor.from_pretrained(
            "google/vit-base-patch16-224")

    def forward(self, pixel_values, labels=None):
        output = self.model(pixel_values=pixel_values, labels=labels)
        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        images, labels = batch
        pixel_values = self.feature_extractor(
            images=images, return_tensors="pt").pixel_values
        loss, logits = self(images, labels)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        pixel_values = self.feature_extractor(
            images=images, return_tensors="pt").pixel_values
        loss, logits = self(images, labels)
        # loss.requires_grad_(True)
        self.log('val_loss', loss, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=2e-5)

In [10]:
# 데이터 로드 및 훈련
data_module = MNISTDataModule()
config = {}
model = MyModel(config=config)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, data_module)


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type                      | Params
----------------------------------------------------
0 | model | ViTForImageClassification | 86.6 M
----------------------------------------------------
86.6 M    Trainable params
0         Non-trainable params
86.6 M    Total params
346.271   Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0:   0%|          | 1/938 [00:02<42:01,  2.69s/it, v_num=21, train_loss=9.330]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
