In [None]:
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import models
import argparse
from timm.optim import create_optimizer
from timm.utils import accuracy, AverageMeter
from datasets.mnist.mnist_data import get_mnist_dataset
from util.show_images import ShowImages
from util.visualization import Visualization
from torchvision import transforms
import time
import matplotlib.pyplot as plt

In [None]:
def get_args_parser():
  """Build the argument parser that holds the training configuration.

  The parser groups its options into run settings (batch size, epochs,
  device, start epoch), the model input size, optimizer settings and the
  learning rate. Option names with dashes are exposed on the parsed
  namespace with underscores (e.g. ``--batch-size`` -> ``batch_size``).

  NOTE(review): the optimizer options (``opt``, ``opt_eps``, ``opt_betas``,
  ``momentum``, ``weight_decay``, ``lr``) look like they are meant to be
  consumed by ``timm.optim.create_optimizer`` -- confirm against the caller.

  Returns:
    argparse.ArgumentParser: the configured parser; call ``parse_args`` on
    it to obtain the configuration namespace.
  """
  parser = argparse.ArgumentParser(description="Training Config")

  # Run parameters
  parser.add_argument('--batch-size', default=64, type=int,
                      help='batch size (default: 64)')
  parser.add_argument('--epochs', default=300, type=int,
                      help='number of epochs (default: 300)')
  parser.add_argument('--device', default='cuda:0',
                      help='device to use for training / testing')
  # Spelled with an underscore, unlike the other options; kept as-is so
  # existing invocations keep working.
  parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                      help='start epoch')

  # Model parameters
  parser.add_argument("--input-size", default=None, nargs=3, type=int,
                      help='images input size')

  # Optimizer parameters
  parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                      help='Optimizer (default: "adamw")')
  parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                      help='Optimizer Epsilon (default: 1e-8)')
  parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                      help='Optimizer Betas (default: None, use opt default)')
  parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                      help='Clip gradient norm (default: None, no clipping)')
  parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                      help='SGD momentum (default: 0.9)')
  parser.add_argument('--weight-decay', type=float, default=0.05,
                      help='weight decay (default: 0.05)')

  # Learning rate schedule parameters
  parser.add_argument("--lr", type=float, default=0.01, metavar="LR",
                      help='learning rate (default: 0.01)')

  return parser


In [None]:
# Parse the configuration from an explicit argument list instead of sys.argv.
# This run uses an input size of 3 32 32 and overrides the default epoch
# count with 10.
args = get_args_parser().parse_args(["--input-size", "3", "32", "32", "--epochs", "10"])

# Honor the parsed --device option (default 'cuda:0') rather than repeating the
# device string here; fall back to the CPU when CUDA is not available.
Device = torch.device(args.device if torch.cuda.is_available() else "cpu")

# Module-level shortcuts for the parsed epoch count and batch size.
EPOCHS = args.epochs
BATCH_SIZE = args.batch_size

In [None]:
# Preprocessing applied to every image before it is handed to the model.
transform = transforms.Compose(
    [
        # Convert the 1-channel image into a 3-channel image.
        transforms.Grayscale(num_output_channels=3),
        # The ViT model needs a larger image, so resize to 32x32.
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ]
)

# Build the training loader first, then the test loader, with identical
# preprocessing and batch size.
train_data_loader, test_data_loader = (
    get_mnist_dataset(train=is_train, transform=transform, batch_size=BATCH_SIZE)
    for is_train in (True, False)
)

# Pull the first batch out of the training loader.
examples = next(iter(train_data_loader))


In [None]:
# Preview images drawn from the training loader.
show_images_instance = ShowImages()

# NOTE: `dataset_lodaer` (sic) is the keyword name used by this call; it is
# kept verbatim -- presumably it matches the helper's signature, verify in
# util.show_images before renaming.
show_images_instance.show_images_loader(
    dataset_lodaer=train_data_loader,
    num_images_to_display=30,
)

In [None]:
# Build the classifier: 10 output classes, 32x32 input images (the size the
# preprocessing pipeline resizes to).
# Alternative constructor that was tried here:
#   models.beit_base_16_224(num_classes=10, img_size=32)
model = models.beit2_base_16_224(
    num_classes=10,
    img_size=32,
)