In [1]:
import torch
import torch.nn as nn
from timm import create_model
from torchvision import datasets, transforms
import lightning as L
from torch.utils.data import DataLoader, random_split
from torchmetrics.classification import Accuracy
from lightning.pytorch.loggers import TensorBoardLogger
import warnings
warnings.filterwarnings('ignore')


DEFAULT_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet >> 224, 224 size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# PyTorch Lightning 모델 정의
class LitCIFAR10(L.LightningModule):
    def __init__(self, lr=0.001):
        super().__init__()
        self.save_hyperparameters()
        self.model = create_model('resnet34', pretrained=True, num_classes=10)
        # 모든 파라미터를 freeze
        for param in self.model.parameters():
            param.requires_grad = False

        # classifier만 학습되도록 requires_grad 설정
        for param in self.model.get_classifier().parameters():
            param.requires_grad = True
            
        self.loss_fn = nn.CrossEntropyLoss()
        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.val_acc = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = self.train_acc(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = self.val_acc(logits, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

# 데이터 로딩을 위한 DataModule 정의
class CIFAR10DataModule(L.LightningDataModule):
    def __init__(self, data_dir="./data", batch_size=256):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        datasets.CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        full_train = datasets.CIFAR10(root=self.data_dir, train=True,
                                      transform=DEFAULT_TRANSFORM)
        self.train_set, self.val_set = random_split(full_train, [45000, 5000])
        self.test_set = datasets.CIFAR10(root=self.data_dir, train=False,
                                         transform=DEFAULT_TRANSFORM)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size,
                          shuffle=True, num_workers=5)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size,
                          shuffle=False, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size,
                          shuffle=False, num_workers=4)
    

In [3]:
# 모델 및 데이터 준비
model = LitCIFAR10(lr=0.125)
data = CIFAR10DataModule(batch_size=512)

# TensorBoardLogger 설정
logger = TensorBoardLogger(save_dir="tb_logs", name="resnet34_cifar10")

# Trainer 정의
trainer = L.Trainer(
    max_epochs=10,
    accelerator="auto",
    logger=logger
)

# 모델 학습
trainer.fit(model, datamodule=data)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | ResNet             | 21.3 M | train
1 | loss_fn   | CrossEntropyLoss   | 0      | train
2 | train_acc | MulticlassAccuracy | 0      | train
3 | val_acc   | MulticlassAccuracy | 0      | train
---------------------------------------------------------
5.1 K     Trainable params
21.3 M    Non-trainable params
21.3 M    Total params
85.159    Total estimated model params size (MB)
169       Modules in train mode
0         Modules in eval mode


Epoch 9: 100%|██████████| 88/88 [03:31<00:00,  0.42it/s, v_num=3, train_loss=0.807, train_acc=0.746, val_loss=0.929, val_acc=0.692] 

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 88/88 [03:32<00:00,  0.41it/s, v_num=3, train_loss=0.807, train_acc=0.746, val_loss=0.929, val_acc=0.692]


In [7]:
!tensorboard --logdir=tb_logs

TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)
^C
