## Library

In [None]:
import PIL
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
from timm.data import create_transform, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, Mixup
from timm.loss import SoftTargetCrossEntropy
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.models import resnet50
import math
import os
from tqdm import tqdm_notebook as tqdm

from puzzle_res50 import PuzzleCNNCoord
from puzzle_vit import PuzzleViT
from util.tester import visualDoubleLoss

import facebook_vit
from mae_util import interpolate_pos_embed, RandomResizedCrop, LARS
from timm.models.layers import trunc_normal_

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Enumerate the visible CUDA devices; both lists stay empty on a CPU-only machine.
gpu_ids = list(range(torch.cuda.device_count())) if torch.cuda.is_available() else []
device_names = [torch.cuda.get_device_name(idx) for idx in gpu_ids]
print(gpu_ids)
print(device_names)

# Device string used by the rest of the notebook:
#   * several GPUs -> the second one (index 1),
#   * exactly one  -> the default CUDA device,
#   * none         -> CPU.
if len(gpu_ids) > 1:
    gpu = f'cuda:{gpu_ids[1]}'  # GPU Number
elif torch.cuda.is_available():
    gpu = "cuda"
else:
    gpu = "cpu"

## Hyper parameter

In [None]:
# Torch device string; `gpu` is chosen by the GPU-detection code earlier in the notebook.
device = gpu

'''Pre-training'''
# Run identity (both parts end up in the checkpoint file name).
TASK_NAME = 'puzzle_imagenet_1000'
MODEL_NAME = 'vitPreFalse'

# Optimisation and data-loading settings.
NUM_EPOCHS = 100
BATCH_SIZE = 64
LEARNING_RATE = 1e-05
NUM_WORKERS = 2

# Checkpoint locations:
#   pre_model_path        - written by PreTrainer.save_model
#   pre_load_model_path   - read by PreTrainer.build_model when load=True
#   pre_reload_model_path - read by PreTrainer.pretrain_model when reload=True
pre_model_path = f'./save/{TASK_NAME}_{MODEL_NAME}_ep{NUM_EPOCHS}_lr{LEARNING_RATE}_b{BATCH_SIZE}.pt'
pre_load_model_path = './save/xxx.pt'
pre_reload_model_path = './save/xxx.pt'

'''Fine-tuning'''
# AUGMENTATION = True
# LEARNING_RATE = 2e-02
# BATCH_SIZE = 32
# NUM_EPOCHS = 100
# WARMUP_EPOCHS = 5
# NUM_WORKERS = 2
# TASK_NAME = 'classification_ImageNet'
# fine_load_model_path = './save/puzzle_imagenet_1000_vit_ep20_lr7e-05_b64_c.pt'  # duplicate file
# fine_model_path = fine_load_model_path[:-3] + f'___{TASK_NAME}_ep{NUM_EPOCHS}_lr{LEARNING_RATE}_b{BATCH_SIZE}_SGD_aug.pt'
# fine_reload_model_path = './save/xxx.pt'

'''Linear-probing'''
# LEARNING_RATE = 2
# BATCH_SIZE = 32
# NUM_EPOCHS = 100
# WARMUP_EPOCHS = 5
# NUM_WORKERS = 2
# TASK_NAME = 'linear_ImageNet'
# fine_load_model_path = './save/puzzle_imagenet_1000_vit_ep100_lr1e-05_b64_c.pt'  # duplicate file
# fine_model_path = fine_load_model_path[:-3] + f'___{TASK_NAME}_ep{NUM_EPOCHS}_lr{LEARNING_RATE}_b{BATCH_SIZE}_SGD.pt'
# fine_reload_model_path = './save/xxx.pt'

## Dataset

In [None]:
'''Pre-training'''
# transform = transforms.Compose([
#     transforms.Pad(padding=3),
#     transforms.CenterCrop(30),
#     transforms.ToTensor(),
#     transforms.Normalize((0.5,), (0.5,))
# ])

# Pre-training input pipeline: resize -> 224 centre crop -> pad to 225x225 -> normalised tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=256, interpolation=PIL.Image.BICUBIC),
        transforms.CenterCrop(size=224),
        # Padding is (left, top, right, bottom): one extra pixel on the right and bottom edges,
        # so the 224x224 crop becomes 225x225.
        # NOTE(review): presumably chosen so the image splits evenly into the 75-pixel
        # pieces used by PuzzleViT(size_puzzle=75) - confirm.
        transforms.Pad(padding=(0, 0, 1, 1)),
        transforms.ToTensor(),
        # Per-channel mean and standard deviation, passed positionally.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

'''Fine-tuning'''
# transform = transforms.Compose([
#     transforms.Resize(256, interpolation=PIL.Image.BICUBIC),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

# if AUGMENTATION:
#     transform = create_transform(
#         input_size=224,
#         is_training=True,
#         color_jitter=None,
#         auto_augment='rand-m9-mstd0.5-inc1',
#         interpolation='bicubic',
#         re_prob=0.25,
#         re_mode='pixel',
#         re_count=1,
#         mean=IMAGENET_DEFAULT_MEAN,
#         std=IMAGENET_DEFAULT_STD,
#     )

# mixup_fn = Mixup(
#     mixup_alpha=0.8,
#     cutmix_alpha=1.0,
#     cutmix_minmax=None,
#     prob=1.0,
#     switch_prob=0.5,
#     mode='batch',
#     label_smoothing=0.1,
#     num_classes=1000
# )

'''Linear-probing'''
# transform = transforms.Compose([
#             RandomResizedCrop(224, interpolation=PIL.Image.BICUBIC),
#             transforms.RandomHorizontalFlip(),
#             transforms.ToTensor(),
#             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
# )

# train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
# train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
# val_dataset = Subset(train_dataset, list(range(int(0.2*len(train_dataset)))))
# val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
# test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
# test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Datasets first, loaders afterwards.
train_dataset = datasets.ImageFolder(root='../datasets/ImageNet/train', transform=transform)
# Validation uses the first 1% of the *training* indices.
# NOTE(review): these samples are also seen during training, so this is not a held-out split.
val_dataset = Subset(train_dataset, list(range(int(0.01 * len(train_dataset)))))
test_dataset = datasets.ImageFolder(root='../datasets/ImageNet/val', transform=transform)

# Every loader shuffles and drops the final incomplete batch.
# NOTE(review): for the evaluation loaders this discards a few samples - confirm it is intended.
train_loader = DataLoader(
    dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, drop_last=True
)
val_loader = DataLoader(
    dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, drop_last=True
)
test_loader = DataLoader(
    dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, drop_last=True
)

## Pre-training class

In [None]:
class PreTrainer(object):
    """Pre-train ``PuzzleViT`` on the puzzle task and record the training history.

    Attributes:
        model: network created by ``build_model``.
        optimizer: AdamW optimizer created in ``pretrain_model``.
        scheduler: cosine learning-rate schedule created in ``pretrain_model``.
        epochs: 1-based epoch number stored with every running-loss entry.
        losses_c: running mean of the SmoothL1 coordinate loss (one entry per 100 batches).
        losses_t: running mean of the total loss, aligned with ``losses_c``.
        accuracies: validation accuracy in percent, one entry per ``val_model`` call.
    """

    def __init__(self):
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.epochs = []
        self.losses_c = []
        self.losses_t = []
        self.accuracies = []

    def process(self, load=False, reload=False):
        """Build the model, run pre-training and write the final checkpoint.

        Args:
            load: initialise the weights from ``pre_load_model_path`` (history is reset).
            reload: resume an interrupted run from ``pre_reload_model_path``.
        """
        self.build_model(load)
        self.pretrain_model(reload)
        self.save_model()

    def _last_epoch(self):
        """Return the most recently logged epoch, or 0 when nothing has been logged yet."""
        return self.epochs[-1] if self.epochs else 0

    def build_model(self, load):
        """Create the ``PuzzleViT`` model on ``device``; optionally load pretrained weights.

        Args:
            load: when true, read the weights from ``pre_load_model_path``. The stored
                history is only used for the log line and is cleared afterwards.
        """
        self.model = PuzzleViT(size_puzzle=75).to(device)
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        if load:
            # map_location keeps a checkpoint written on another GPU loadable on this machine.
            checkpoint = torch.load(pre_load_model_path, map_location=device)
            self.epochs = checkpoint['epochs']
            self.model.load_state_dict(checkpoint['model'])
            self.losses_c = checkpoint['losses_coord']
            self.losses_t = checkpoint['losses_total']
            print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
            print(f'Epoch: {self._last_epoch()}')
            print(f'****** Reset epochs and losses ******')
            self.epochs = []
            self.losses_c = []
            self.losses_t = []

    def pretrain_model(self, reload):
        """Train the model for ``NUM_EPOCHS`` epochs, checkpointing and validating every epoch.

        Args:
            reload: when true, restore model, optimizer, scheduler and history from
                ``pre_reload_model_path`` and continue after the last logged epoch.
        """
        model = self.model
        criterion = nn.SmoothL1Loss()
        optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05)
        scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
        # Publish both objects immediately so that save_model() works even when the
        # epoch range below turns out to be empty.
        self.optimizer = optimizer
        self.scheduler = scheduler
        range_epochs = range(NUM_EPOCHS)
        if reload:
            checkpoint = torch.load(pre_reload_model_path, map_location=device)
            model.load_state_dict(checkpoint['model'])
            # Restore into the optimizer/scheduler that actually drive training.
            optimizer.load_state_dict(checkpoint['optimizer'])
            if 'scheduler' in checkpoint:
                scheduler.load_state_dict(checkpoint['scheduler'])
            else:
                # Checkpoints written before the scheduler was saved lack this entry.
                print('****** No scheduler state in checkpoint; cosine schedule restarts ******')
            self.epochs = checkpoint['epochs']
            self.losses_c = checkpoint['losses_coord']
            self.losses_t = checkpoint['losses_total']
            self.accuracies = checkpoint.get('accuracies', [])
            range_epochs = range(self._last_epoch(), NUM_EPOCHS)

        inter = 100  # number of batches between two log/history entries
        for epoch in range_epochs:
            # Re-enter training mode every epoch in case anything switched the model to eval.
            model.train()
            running_loss_c = 0.
            running_loss_t = 0.
            for batch_idx, (inputs, _) in tqdm(enumerate(train_loader, 0), total=len(train_loader)):
                inputs = inputs.to(device)

                optimizer.zero_grad()

                # The model returns its own targets together with an auxiliary loss term.
                outputs, labels, loss_var = model(inputs)
                loss_coord = criterion(outputs, labels)
                loss = loss_coord + loss_var/1e05
                loss.backward()
                optimizer.step()
                running_loss_c += loss_coord.item()
                running_loss_t += loss.item()

                if batch_idx % inter == inter - 1:
                    print(f'[Epoch {epoch + 1}] [Batch {batch_idx + 1}] Loss: {running_loss_c / inter:.4f}')
                    print(f'[Epoch {epoch + 1}] [Batch {batch_idx + 1}] Total Loss: {running_loss_t / inter:.4f}')
                    self.epochs.append(epoch + 1)
                    self.losses_c.append(running_loss_c / inter)
                    self.losses_t.append(running_loss_t / inter)
                    running_loss_c = 0.
                    running_loss_t = 0.
            scheduler.step()
            self.save_model()
            visualDoubleLoss(self.losses_c, self.losses_t)
            self.val_model(epoch)
        print('****** Finished Pre-training ******')
        self.model = model

    def val_model(self, epoch=-1):
        """Evaluate on ``val_loader`` and append the accuracy to ``self.accuracies``.

        The model's previous train/eval mode is restored before returning.

        Args:
            epoch: 0-based epoch index, used only in the log output.
        """
        model = self.model
        was_training = model.training
        model.eval()

        total = 0
        diff = 0
        correct = 0
        with torch.no_grad():
            for batch_idx, (inputs, _) in tqdm(enumerate(val_loader, 0), total=len(val_loader)):
                inputs = inputs.to(device)

                outputs, labels, _ = model(inputs)

                pred = outputs
                total += labels.size(0)
                diff += torch.dist(pred, labels).item()
                # NOTE(review): model.mapping presumably snaps coordinates onto the puzzle
                # grid so predictions can be compared exactly - confirm in PuzzleViT.
                pred_ = model.mapping(pred)
                labels_ = model.mapping(labels)
                # An item counts as correct when every value along dim 2 matches.
                correct += (pred_ == labels_).all(dim=2).sum().item()
        model.train(was_training)

        if total == 0:
            # Possible with drop_last=True and a validation subset smaller than one batch.
            print(f'[Epoch {epoch + 1}] Validation skipped: val_loader produced no batches')
            return

        acc = 100 * correct / (total * labels.size(1))
        print(f'[Epoch {epoch + 1}] Avg diff on the test set: {diff / total:.2f}')
        print(f'[Epoch {epoch + 1}] Accuracy on the test set: {acc:.2f}%')
        torch.set_printoptions(precision=2)
        # Show the first sample of the last batch: prediction and target side by side.
        total = labels.size(1)
        correct = (pred_[0] == labels_[0]).all(dim=1).sum().item()
        print(f'[Sample result]')
        print(torch.cat((pred_[0], labels_[0]), dim=1))
        print(f'Accuracy: {100 * correct / total:.2f}%')
        self.accuracies.append(acc)

    def save_model(self):
        """Write model, optimizer, scheduler and history to ``pre_model_path``."""
        checkpoint = {
            'epochs': self.epochs,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            # Stored so that pretrain_model(reload=True) can continue the schedule.
            'scheduler': self.scheduler.state_dict(),
            'losses_coord': self.losses_c,
            'losses_total': self.losses_t,
            'accuracies': self.accuracies,
        }
        torch.save(checkpoint, pre_model_path)
        # if self.epochs[-1] % 50 == 0:
        #     torch.save(checkpoint, pre_model_path[:-3]+f'_{self.epochs[-1]}l{NUM_EPOCHS}.pt')
        print(f"****** Model checkpoint saved at epochs {self._last_epoch()} ******")

## Fine tuning

In [None]:
class FineTuner(object):
    """Fine-tune a ``vit_base_patch16`` classifier, optionally starting from puzzle weights.

    Attributes:
        model: classifier created by ``build_model``.
        optimizer: SGD optimizer (a placeholder until ``finetune_model`` runs).
        scheduler: cosine learning-rate schedule (a placeholder until ``finetune_model`` runs).
        epochs: 1-based epoch number stored with every running-loss entry.
        losses: running mean of the training loss (one entry per 100 batches).
        accuracies: running training accuracy in percent (only logged without augmentation).
    """

    def __init__(self):
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.epochs = [0]
        self.losses = [0]
        self.accuracies = [0]

    def process(self, load=False, reload=False):
        """Build the model, run fine-tuning and write the final checkpoint.

        Args:
            load: initialise the backbone from ``fine_load_model_path`` (history is reset).
            reload: resume an interrupted run from ``fine_reload_model_path``.
        """
        self.build_model(load)
        self.finetune_model(reload)
        self.save_model()

    def _last_epoch(self):
        """Return the most recently logged epoch, or 0 when nothing has been logged yet."""
        return self.epochs[-1] if self.epochs else 0

    @staticmethod
    def _remap_checkpoint_keys(checkpoint_model):
        """Rename pre-training state-dict keys to the classifier's naming, in place.

        A leading ``vit_features.`` is stripped first; afterwards a leading ``norm.``
        becomes ``fc_norm.``. Only the prefix is rewritten, never a later occurrence.

        Args:
            checkpoint_model: mutable mapping from parameter name to value.

        Returns:
            The same mapping object, for convenience.
        """
        for prefix, replacement in (('vit_features.', ''), ('norm.', 'fc_norm.')):
            for key in list(checkpoint_model.keys()):
                if key.startswith(prefix):
                    checkpoint_model[replacement + key[len(prefix):]] = checkpoint_model.pop(key)
        return checkpoint_model

    def build_model(self, load):
        """Create the classifier on ``device``; optionally load pretrained weights.

        Args:
            load: when true, read the weights from ``fine_load_model_path``, rename the
                keys, drop a mismatching head and re-initialise the head weights.
        """
        self.model = facebook_vit.__dict__['vit_base_patch16'](
            num_classes=1000,
            drop_path_rate=0.1,
            global_pool=True,
        )
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        # Placeholders so that save_model() always has something to serialise;
        # finetune_model() replaces them with the objects that really train.
        self.optimizer = optim.SGD(self.model.parameters(), lr=0)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=NUM_EPOCHS)

        if load:
            checkpoint = torch.load(fine_load_model_path, map_location=device)
            checkpoint_model = self._remap_checkpoint_keys(checkpoint['model'])

            state_dict = self.model.state_dict()
            for k in ['head.weight', 'head.bias']:
                if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                    print(f"Removing key {k} from pretrained checkpoint")
                    del checkpoint_model[k]
            interpolate_pos_embed(self.model, checkpoint_model)
            # strict=False: the classification head is not part of the pre-training weights.
            msg = self.model.load_state_dict(checkpoint_model, strict=False)
            print(msg)
            trunc_normal_(self.model.head.weight, std=2e-5)

            if 'given' not in str(fine_load_model_path):
                self.epochs = checkpoint['epochs']
            print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
            print(f'Epoch: {self._last_epoch()}')
            print(f'****** Reset epochs and losses ******')
            self.epochs = []
            self.losses = []
            self.accuracies = []

        # Always move the model: the training loop sends every batch to `device`,
        # whether or not pretrained weights were loaded.
        self.model.to(device)

    def finetune_model(self, reload):
        """Train the classifier with linear warm-up followed by a cosine schedule.

        Args:
            reload: when true, restore model, optimizer, scheduler and history from
                ``fine_reload_model_path`` and continue after the last logged epoch.
        """
        model = self.model.train()
        # Mixup/CutMix produce soft targets, which need the soft-target loss.
        criterion = SoftTargetCrossEntropy() if AUGMENTATION else nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
        scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
        self.optimizer = optimizer
        self.scheduler = scheduler
        range_epochs = range(NUM_EPOCHS)
        if reload:
            checkpoint = torch.load(fine_reload_model_path, map_location=device)
            model.load_state_dict(checkpoint['model'])
            # Restore into the optimizer/scheduler that actually drive training.
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            self.epochs = checkpoint['epochs']
            self.losses = checkpoint['losses']
            self.accuracies = checkpoint['accuracies']
            range_epochs = range(self._last_epoch(), NUM_EPOCHS)

        inter = 100  # number of batches between two log/history entries
        # Interval for the mid-epoch validations (roughly every third of an epoch);
        # 0 means the loader is too short, in which case they are skipped.
        mid_term = len(train_loader) // 3
        for epoch in range_epochs:
            # Validation switches to eval mode, so re-enable training mode every epoch.
            model.train()
            if epoch < WARMUP_EPOCHS:
                lr_warmup = ((epoch + 1) / WARMUP_EPOCHS) * LEARNING_RATE
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_warmup
                if epoch + 1 == WARMUP_EPOCHS:
                    # Start a fresh cosine schedule once warm-up has reached the full rate.
                    scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
            print(f"epoch {epoch + 1} learning rate : {optimizer.param_groups[0]['lr']}")
            running_loss = 0.0
            saving_loss = 0.0
            correct = 0
            total = 0
            for i, data in tqdm(enumerate(train_loader, 0), total=len(train_loader)):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                if AUGMENTATION:
                    inputs, labels = mixup_fn(inputs, labels)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                saving_loss += loss.item()
                if not AUGMENTATION:
                    # Hard labels are only available without Mixup/CutMix.
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                if i % inter == inter - 1:
                    if AUGMENTATION:
                        print(f'[Epoch {epoch}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}')
                    else:
                        print(f'[Epoch {epoch}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}, acc: {correct / total * 100:.2f} %')
                        self.accuracies.append(correct / total * 100)
                    self.epochs.append(epoch + 1)
                    self.losses.append(saving_loss / inter)
                    running_loss = 0.0
                    saving_loss = 0.0
                    correct = 0
                    total = 0
                if mid_term and i % mid_term == mid_term - 1:
                    self.val_model(epoch)
            self.model = model
            self.optimizer = optimizer
            self.scheduler = scheduler
            self.save_model()
            self.val_model(epoch)
            scheduler.step()
        print('****** Finished Fine-tuning ******')

    def val_model(self, epoch=-1):
        """Print the top-1 accuracy on ``val_loader``.

        The model's previous train/eval mode is restored before returning, so a
        validation in the middle of an epoch does not disable dropout/drop-path
        for the remaining training steps.

        Args:
            epoch: 0-based epoch index, used only in the log output.
        """
        was_training = self.model.training
        self.model.eval()

        correct = 0
        total = 0
        with torch.no_grad():
            for i, data in tqdm(enumerate(val_loader, 0), total=len(val_loader)):
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = self.model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        self.model.train(was_training)

        if total == 0:
            # Possible with drop_last=True and a validation subset smaller than one batch.
            print(f'[Epoch {epoch + 1}] Validation skipped: val_loader produced no batches')
            return
        print(f'[Epoch {epoch + 1}] Accuracy of {len(val_dataset)} test images: {100 * correct / total:.2f} %')

    def save_model(self):
        """Write model, optimizer, scheduler and history to ``fine_model_path``."""
        checkpoint = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'epochs': self.epochs,
            'losses': self.losses,
            'accuracies': self.accuracies,
        }
        torch.save(checkpoint, fine_model_path)
        #         torch.save(checkpoint, dynamic_model_path+str(self.epochs[-1])+f'_lr{LEARNING_RATE}.pt')
        print(f"****** Model checkpoint saved at epochs {self._last_epoch()} ******")

## Linear Probing

In [None]:
class LinearProber(object):
    """Linear-probe a ``vit_base_patch16`` classifier on top of pretrained puzzle weights.

    Attributes:
        model: classifier created by ``build_model``.
        optimizer: SGD optimizer (a placeholder until ``linearprob_model`` runs).
        scheduler: cosine learning-rate schedule (a placeholder until ``linearprob_model`` runs).
        epochs: 1-based epoch number stored with every running-loss entry.
        losses: running mean of the training loss (one entry per 100 batches).
        accuracies: running training accuracy in percent, aligned with ``losses``.
    """

    def __init__(self):
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.epochs = [0]
        self.losses = [0]
        self.accuracies = [0]

    def process(self, load=False, reload=False):
        """Build the model, run linear probing and write the final checkpoint.

        Args:
            load: initialise the backbone from ``fine_load_model_path`` and freeze it.
            reload: resume an interrupted run from ``fine_reload_model_path``.
        """
        self.build_model(load)
        self.linearprob_model(reload)
        self.save_model()

    def _last_epoch(self):
        """Return the most recently logged epoch, or 0 when nothing has been logged yet."""
        return self.epochs[-1] if self.epochs else 0

    @staticmethod
    def _remap_checkpoint_keys(checkpoint_model):
        """Rename pre-training state-dict keys to the classifier's naming, in place.

        A leading ``vit_features.`` is stripped first; afterwards a leading ``norm.``
        becomes ``fc_norm.``. Only the prefix is rewritten, never a later occurrence.

        Args:
            checkpoint_model: mutable mapping from parameter name to value.

        Returns:
            The same mapping object, for convenience.
        """
        for prefix, replacement in (('vit_features.', ''), ('norm.', 'fc_norm.')):
            for key in list(checkpoint_model.keys()):
                if key.startswith(prefix):
                    checkpoint_model[replacement + key[len(prefix):]] = checkpoint_model.pop(key)
        return checkpoint_model

    def build_model(self, load):
        """Create the classifier on ``device``; optionally load and freeze pretrained weights.

        Args:
            load: when true, read the weights from ``fine_load_model_path``, rename the
                keys, drop a mismatching head, re-initialise the head weights and
                freeze every parameter except the head.
        """
        self.model = facebook_vit.__dict__['vit_base_patch16'](
            num_classes=1000,
            drop_path_rate=0.1,
            global_pool=True,
        )
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        # Placeholders so that save_model() always has something to serialise;
        # linearprob_model() replaces them with the objects that really train.
        self.optimizer = optim.SGD(self.model.parameters(), lr=0)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=NUM_EPOCHS)

        if load:
            checkpoint = torch.load(fine_load_model_path, map_location=device)
            checkpoint_model = self._remap_checkpoint_keys(checkpoint['model'])

            state_dict = self.model.state_dict()
            for k in ['head.weight', 'head.bias']:
                if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                    print(f"Removing key {k} from pretrained checkpoint")
                    del checkpoint_model[k]
            interpolate_pos_embed(self.model, checkpoint_model)
            # strict=False: the classification head is not part of the pre-training weights.
            msg = self.model.load_state_dict(checkpoint_model, strict=False)
            print(msg)
            trunc_normal_(self.model.head.weight, std=2e-5)
            # Linear probing: freeze everything, then re-enable only the head.
            # NOTE(review): the freeze is applied only when load=True, as in the
            # original code - confirm that load=False should train all parameters.
            for _, p in self.model.named_parameters():
                p.requires_grad = False
            for _, p in self.model.head.named_parameters():
                p.requires_grad = True

            if 'given' not in str(fine_load_model_path):
                self.epochs = checkpoint['epochs']
            print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
            print(f'Epoch: {self._last_epoch()}')
            print(f'****** Reset epochs and losses ******')
            self.epochs = []
            self.losses = []
            self.accuracies = []

        # Always move the model: the training loop sends every batch to `device`,
        # whether or not pretrained weights were loaded.
        self.model.to(device)

    def linearprob_model(self, reload):
        """Train with linear warm-up followed by a cosine schedule.

        Args:
            reload: when true, restore model, optimizer, scheduler and history from
                ``fine_reload_model_path`` and continue after the last logged epoch.
        """
        model = self.model.train()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=0)
        scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
        self.optimizer = optimizer
        self.scheduler = scheduler
        range_epochs = range(NUM_EPOCHS)
        if reload:
            checkpoint = torch.load(fine_reload_model_path, map_location=device)
            model.load_state_dict(checkpoint['model'])
            # Restore into the optimizer/scheduler that actually drive training.
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            self.epochs = checkpoint['epochs']
            self.losses = checkpoint['losses']
            self.accuracies = checkpoint['accuracies']
            range_epochs = range(self._last_epoch(), NUM_EPOCHS)

        inter = 100  # number of batches between two log/history entries
        # Interval for the mid-epoch validations (roughly every third of an epoch);
        # 0 means the loader is too short, in which case they are skipped.
        mid_term = len(train_loader) // 3
        for epoch in range_epochs:
            # Validation switches to eval mode, so re-enable training mode every epoch.
            model.train()
            if epoch < WARMUP_EPOCHS:
                lr_warmup = ((epoch + 1) / WARMUP_EPOCHS) * LEARNING_RATE
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_warmup
                if epoch + 1 == WARMUP_EPOCHS:
                    # Start a fresh cosine schedule once warm-up has reached the full rate.
                    scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
            print(f"epoch {epoch + 1} learning rate : {optimizer.param_groups[0]['lr']}")
            running_loss = 0.0
            saving_loss = 0.0
            correct = 0
            total = 0
            for i, data in tqdm(enumerate(train_loader, 0), total=len(train_loader)):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                saving_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                if i % inter == inter - 1:
                    print(f'[Epoch {epoch}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}, acc: {correct / total * 100:.2f} %')
                    self.accuracies.append(correct/total*100)
                    self.epochs.append(epoch + 1)
                    self.losses.append(saving_loss/inter)
                    running_loss = 0.0
                    saving_loss = 0.0
                    correct = 0
                    total = 0
                if mid_term and i % mid_term == mid_term - 1:
                    self.val_model(epoch)
            self.model = model
            self.optimizer = optimizer
            self.scheduler = scheduler
            self.save_model()
            self.val_model(epoch)
            scheduler.step()
        print('****** Finished Linear-probing ******')

    def val_model(self, epoch=-1):
        """Print the top-1 accuracy on ``val_loader``.

        The model's previous train/eval mode is restored before returning, so a
        validation in the middle of an epoch does not change how the remaining
        training steps run.

        Args:
            epoch: 0-based epoch index, used only in the log output.
        """
        was_training = self.model.training
        self.model.eval()

        correct = 0
        total = 0
        with torch.no_grad():
            for i, data in tqdm(enumerate(val_loader, 0), total=len(val_loader)):
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = self.model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        self.model.train(was_training)

        if total == 0:
            # Possible with drop_last=True and a validation subset smaller than one batch.
            print(f'[Epoch {epoch + 1}] Validation skipped: val_loader produced no batches')
            return
        print(f'[Epoch {epoch + 1}] Accuracy of {len(val_dataset)} test images: {100 * correct / total:.2f}%')

    def save_model(self):
        """Write model, optimizer, scheduler and history to ``fine_model_path``."""
        checkpoint = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'epochs': self.epochs,
            'losses': self.losses,
            'accuracies': self.accuracies,
        }
        torch.save(checkpoint, fine_model_path)
        #         torch.save(checkpoint, dynamic_model_path+str(self.epochs[-1])+f'_lr{LEARNING_RATE}.pt')
        print(f"****** Model checkpoint saved at epochs {self._last_epoch()} ******")

In [None]:
if __name__ == '__main__':
    # Active experiment: puzzle pre-training from scratch (no weights loaded, no resume).
    trainer = PreTrainer()
    trainer.process(load=False, reload=False)

# if __name__ == '__main__':
#     trainer = FineTuner()
#     trainer.process(load=True)

# if __name__ == '__main__':
#     trainer = LinearProber()
#     trainer.process(load=True)

In [None]:
'''
[Pre-training]
epochs 3에서 사전 테스트
-> c=sL1, batch_size=64, lr=6e-06 : 11.31%
-> c=sL1, batch_size=64, lr=7e-06 : 11.32%
-> c=sL1, batch_size=64, lr=8e-06 : 11.32%
-> c=sL1, batch_size=64, lr=9e-06 : 11.36%
-> c=sL1, batch_size=64, lr=1e-05 : 11.36% (best)
-> c=sL1, batch_size=64, lr=2e-05 : 11.29%
-> c=sL1, batch_size=64, lr=3e-05 : 11.30%

-> c=sL1, batch_size=64, lr=1e-05, ratio=7e03 : 11.37%
-> c=sL1, batch_size=64, lr=1e-05, ratio=1e04 : 11.68% (best) ->  50 epochs, X(학습 불안정)
-> c=sL1, batch_size=64, lr=1e-05, ratio=3e04 : 11.39%
-> c=sL1, batch_size=64, lr=1e-05, ratio=1e05 : 11.36% -> 50 epochs, 75.8%-> 100 epochs, 95.4%
-> c=sL1, batch_size=64, lr=1e-05, ratio=3e05 : 11.38%

[Fine-tuning]
epochs 2에서 사전 테스트, (pre 84 epochs)
-> batch_size=64, lr=2e-01 : 19.70%
-> batch_size=64, lr=1e-01 : 20.22%
-> batch_size=64, lr=7e-02 : 23.13%
-> batch_size=64, lr=5e-02 : 23.88%
-> batch_size=64, lr=4e-02 : 25.87% (best) -> 50 epochs X, 100 epochs O
-> batch_size=64, lr=3e-02 : 22.55%
-> batch_size=64, lr=3e-02 : 22.55%

-> batch_size=32, lr=9e-03 : 18.06%
-> batch_size=32, lr=1e-02 : 18.98%
-> batch_size=32, lr=2e-02 : 23.02% (best) -> 100 epochs, (진행중)
-> batch_size=32, lr=3e-02 : 22.01%
-> batch_size=32, lr=4e-02 : 22.12%

[Linear-probing]
epochs 2에서 사전 테스트, LARS
-> batch_size=256, lr=8 : 8.48%
-> batch_size=256, lr=7 : 8.70%
-> batch_size=256, lr=6 : 11.23%
-> batch_size=256, lr=5 : 12.40% (best) -> 100 epochs, 발산 -> warm up 제거, 발산
-> batch_size=256, lr=4 : 12.00%
-> batch_size=256, lr=3 : 11.62%
-> batch_size=32, lr=5e-01 : 12.54% (best) -> 100 epochs, 발산
-> batch_size=32, lr=4e-01 : 8.55%
-> batch_size=32, lr=2e-01 : 9.14%
-> batch_size=32, lr=1e-01 : 6.98%

epochs 2에서 사전 테스트, SGD
-> batch_size=32, lr=3e-02 : 8.16%
-> batch_size=32, lr=2e-02 : 10.84%
-> batch_size=32, lr=1e-02 : 12.94%
-> batch_size=32, lr=9e-03 : 14.62%
-> batch_size=32, lr=8e-03 : 16.46% (best) -> 100 epochs, 정체
-> batch_size=32, lr=7e-03 : 15.00%
-> batch_size=32, lr=6e-03 : 14.45%
-> batch_size=32, lr=5e-03 : 12.62%
-> batch_size=32, lr=4e-03 : 13.67%
-> batch_size=32, lr=3e-03 : 13.34%

epochs 2에서 사전 테스트, SGD, 아키텍처 fine 이랑 동일하게 수정
-> batch_size=32, lr=8e-02 : 13.91%
-> batch_size=32, lr=7e-02 : 13.12%
-> batch_size=32, lr=6e-02 : 15.66%
-> batch_size=32, lr=5e-02 : 19.67% (best) -> 100 epochs, 초반 불안정 -> warm up 제거, -> 50 epochs, 
-> batch_size=32, lr=4e-02 : 17.91% -> 50 epochs, 
-> batch_size=32, lr=3e-02 : 13.22%
-> batch_size=32, lr=2e-02 : 13.32%
-> batch_size=32, lr=1e-02 : 15.11%
-> batch_size=32, lr=9e-03 : 13.34%




[Ablation study]
loss_same 제거: 
MSE: 
+)
'''