## Libraries and device selection

In [1]:
import torch
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.optim import Adam, SGD
from torch import nn
from tqdm.notebook import tqdm

from vit_pooling import ViTPooling

# Enumerate the CUDA devices PyTorch can see (both lists stay empty on a
# CPU-only machine) and echo them so the notebook records the hardware used.
gpu_ids = []
device_names = []
if torch.cuda.is_available():
    gpu_ids = list(range(torch.cuda.device_count()))
    device_names = [torch.cuda.get_device_name(idx) for idx in gpu_ids]
print(gpu_ids)
print(device_names)

# Pick the device string: with several GPUs pin to the first one explicitly,
# with exactly one use the default CUDA device, otherwise fall back to CPU.
if len(gpu_ids) > 1:
    gpu = f'cuda:{gpu_ids[0]}'  # GPU Number
elif torch.cuda.is_available():
    gpu = "cuda"
else:
    gpu = "cpu"

[0, 1, 2, 3]
['TITAN Xp', 'TITAN Xp', 'TITAN Xp', 'TITAN Xp']


## Hyperparameters

In [2]:
# --- Model architecture (passed to the ViTPooling constructor) ---
IMAGE_SIZE = 224
PATCH_SIZE = 16
IN_CHANNELS = 3
NUM_CLASSES = 1000
EMBED_DIM = 768
DEPTH = 12
NUM_HEADS = 12

# --- Optimisation / data loading ---
# NOTE: the pre-training and fine-tuning loops both read these same values.
BATCH_SIZE = 16
NUM_EPOCHS = 300
NUM_WORKERS = 2
LEARNING_RATE = 0.003

# --- Runtime device and checkpoint locations ---
device = gpu  # device string chosen in the device-selection cell
pre_model_path = './save/ViT_i2012_ep300_lr0.003_augVanilla.pt'
fine_model_path = './save/ViT_i2012_ep300_lr0.003_augVanilla_i2012_ep8_lr0.03.pt'

## Dataset

In [3]:
# Per-channel normalisation statistics shared by both pipelines (these are the
# values commonly used for ImageNet-trained models).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: random scale/aspect crop as augmentation, output side
# tied to the model's IMAGE_SIZE instead of a second hard-coded 224.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Evaluation pipeline: must be deterministic so validation numbers are
# reproducible. Resize the short side, then take the centre crop
# (0.875 crop ratio -> 256 px resize for a 224 px crop).
transform_test = transforms.Compose([
    transforms.Resize(int(IMAGE_SIZE / 0.875)),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Disabled alternative: pre-train on ImageNet-21k instead of ILSVRC2012.
# pre_train_set = torchvision.datasets.ImageFolder('./data/ImageNet-21k', transform=transform_train)
# pre_train_loader = data.DataLoader(pre_train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
train_set = torchvision.datasets.ImageFolder('../../YJ/ILSVRC2012/train', transform=transform_train)
train_loader = data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_set = torchvision.datasets.ImageFolder('../../YJ/ILSVRC2012/val', transform=transform_test)
# Shuffling adds nothing at evaluation time; keep a stable sample order.
test_loader = data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

## Pre-training Class

In [4]:
class PreTrainer(object):
    """Build a ViTPooling classifier, pre-train it and checkpoint it.

    Instance state:
        model     -- network created by build_model()
        optimizer -- Adam optimizer created by pretrain_model()
        epochs    -- 1-based number of every finished epoch
        losses    -- loss of the final mini-batch of each finished epoch
                     (not the epoch mean)
    """

    def __init__(self):
        self.model = None
        self.optimizer = None
        self.epochs = []
        self.losses = []

    def process(self):
        """Run the full pipeline: build the model, then pre-train it.

        pretrain_model() writes a checkpoint at the end of every epoch, so the
        file on disk is already up to date when it returns; no extra save is
        issued here.
        """
        self.build_model()
        self.pretrain_model()

    def build_model(self):
        """Instantiate ViTPooling from the notebook constants and move it to `device`."""
        self.model = ViTPooling(image_size=IMAGE_SIZE,
                                patch_size=PATCH_SIZE,
                                in_channels=IN_CHANNELS,
                                num_classes=NUM_CLASSES,
                                embed_dim=EMBED_DIM,
                                depth=DEPTH,
                                num_heads=NUM_HEADS,
                                ).to(device)

    def pretrain_model(self):
        """Train self.model on train_loader for NUM_EPOCHS epochs.

        Uses cross-entropy loss and Adam at LEARNING_RATE, prints the mean
        loss of every 100 batches, and checkpoints after each epoch.

        NOTE(review): the run captured in this notebook plateaus at a loss of
        about 6.91, which is ln(1000), i.e. chance level for NUM_CLASSES=1000.
        The learning-rate / warm-up setup is worth re-checking.
        """
        model = self.model
        criterion = nn.CrossEntropyLoss()
        optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
        # Publish the optimizer before training so save_model() can always
        # serialise it.
        self.optimizer = optimizer

        for epoch in range(NUM_EPOCHS):
            running_loss = 0.0
            last_batch_loss = 0.0
            # Named `batch` so the loop does not shadow the `data`
            # (torch.utils.data) module alias.
            for i, batch in tqdm(enumerate(train_loader, 0), total=len(train_loader)):
                inputs, labels = batch
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # Read the scalar once per step and reuse it.
                last_batch_loss = loss.item()
                running_loss += last_batch_loss
                if i % 100 == 99:
                    print(f'[Epoch {epoch}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0
            self.epochs.append(epoch + 1)
            self.losses.append(last_batch_loss)
            self.save_model()
        print('****** Finished Pre-training ******')

    def save_model(self):
        """Write epochs, model weights, optimizer state and losses to pre_model_path.

        The target directory is created on demand so that a missing folder
        cannot abort the run after a full epoch of training.
        """
        import os  # local import: only this method needs it

        target_dir = os.path.dirname(pre_model_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        checkpoint = {
            'epochs': self.epochs,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'losses': self.losses,
        }
        torch.save(checkpoint, pre_model_path)
        print(f"****** Model checkpoint saved at epochs {self.epochs[-1]} ******")

## Fine-tuning Class

In [5]:
class FineTunner(object):
    """Load the pre-trained ViTPooling checkpoint and fine-tune it with SGD.

    Instance state:
        model     -- network created and restored by build_model()
        optimizer -- SGD optimizer created by finetune_model()
        epochs    -- epoch history; build_model() seeds it from the
                     pre-training checkpoint and finetune_model() appends
                     its own 1-based epoch numbers after those entries
        losses    -- loss history, seeded the same way; one final-mini-batch
                     loss is appended per fine-tuning epoch
    """

    def __init__(self):
        self.model = None
        self.optimizer = None
        self.epochs = []
        self.losses = []

    def process(self):
        """Run the full pipeline: restore the model, then fine-tune it.

        finetune_model() writes a checkpoint at the end of every epoch, so the
        file on disk is already up to date when it returns; no extra save is
        issued here.
        """
        self.build_model()
        self.finetune_model()

    def build_model(self):
        """Instantiate ViTPooling and restore weights/history from pre_model_path."""
        self.model = ViTPooling(image_size=IMAGE_SIZE,
                                patch_size=PATCH_SIZE,
                                in_channels=IN_CHANNELS,
                                num_classes=NUM_CLASSES,
                                embed_dim=EMBED_DIM,
                                depth=DEPTH,
                                num_heads=NUM_HEADS,
                                ).to(device)
        # map_location lets a checkpoint written on one device (e.g. cuda:0)
        # be restored on whatever `device` this session uses.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(pre_model_path, map_location=device)
        self.epochs = checkpoint['epochs']
        self.model.load_state_dict(checkpoint['model'])
        self.losses = checkpoint['losses']
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        print(f'Classes: {NUM_CLASSES}')
        print(f'Epoch: {self.epochs[-1]}')

    def finetune_model(self):
        """Fine-tune self.model on train_loader for NUM_EPOCHS epochs.

        Uses cross-entropy loss and plain SGD at LEARNING_RATE, prints the
        mean loss of every 100 batches, and checkpoints after each epoch.
        """
        model = self.model
        criterion = nn.CrossEntropyLoss()
        optimizer = SGD(model.parameters(), lr=LEARNING_RATE)
        # Publish the optimizer before training so save_model() can always
        # serialise it.
        self.optimizer = optimizer

        for epoch in range(NUM_EPOCHS):
            running_loss = 0.0
            last_batch_loss = 0.0
            # Named `batch` so the loop does not shadow the `data`
            # (torch.utils.data) module alias.
            for i, batch in tqdm(enumerate(train_loader, 0), total=len(train_loader)):
                inputs, labels = batch
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # Read the scalar once per step and reuse it.
                last_batch_loss = loss.item()
                running_loss += last_batch_loss
                if i % 100 == 99:
                    print(f'[Epoch {epoch}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0
            self.epochs.append(epoch + 1)
            self.losses.append(last_batch_loss)
            self.save_model()
        # The original message said "Pre-training" (copy-paste from PreTrainer).
        print('****** Finished Fine-tuning ******')

    def save_model(self):
        """Write epochs, model weights, optimizer state and losses to fine_model_path.

        The target directory is created on demand so that a missing folder
        cannot abort the run after a full epoch of training.
        """
        import os  # local import: only this method needs it

        target_dir = os.path.dirname(fine_model_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        checkpoint = {
            'epochs': self.epochs,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'losses': self.losses,
        }
        torch.save(checkpoint, fine_model_path)
        print(f"****** Model checkpoint saved at epochs {self.epochs[-1]} ******")

In [None]:
# Entry point: runs only the pre-training stage (builds the model, trains it,
# and checkpoints it). FineTunner is not invoked by this cell.
if __name__ == '__main__':
    trainer = PreTrainer()
    trainer.process()

  0%|          | 0/80073 [00:00<?, ?it/s]

[Epoch 0, Batch   100] loss: 14.438
[Epoch 0, Batch   200] loss: 7.090
[Epoch 0, Batch   300] loss: 6.919
[Epoch 0, Batch   400] loss: 6.929
[Epoch 0, Batch   500] loss: 6.937
[Epoch 0, Batch   600] loss: 6.918
[Epoch 0, Batch   700] loss: 6.923
[Epoch 0, Batch   800] loss: 6.923
[Epoch 0, Batch   900] loss: 6.922
[Epoch 0, Batch  1000] loss: 6.922
[Epoch 0, Batch  1100] loss: 6.919
[Epoch 0, Batch  1200] loss: 6.912
[Epoch 0, Batch  1300] loss: 6.914
[Epoch 0, Batch  1400] loss: 6.922
[Epoch 0, Batch  1500] loss: 6.921
[Epoch 0, Batch  1600] loss: 6.919
[Epoch 0, Batch  1700] loss: 6.919
[Epoch 0, Batch  1800] loss: 6.911
[Epoch 0, Batch  1900] loss: 6.918
[Epoch 0, Batch  2000] loss: 6.913
[Epoch 0, Batch  2100] loss: 6.914
[Epoch 0, Batch  2200] loss: 6.913
[Epoch 0, Batch  2300] loss: 6.916
[Epoch 0, Batch  2400] loss: 6.915
[Epoch 0, Batch  2500] loss: 6.916
[Epoch 0, Batch  2600] loss: 6.920
[Epoch 0, Batch  2700] loss: 6.913
[Epoch 0, Batch  2800] loss: 6.911
[Epoch 0, Batch  29