## Libraries

In [1]:
import timm
import torch
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.optim import AdamW, SGD
from torch import nn
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR

from patch_aug import NegativePatchShuffle, NegativePatchRotate

def select_device(gpu_ids, cuda_available, preferred_index=2):
    """Return the torch device string to train on.

    Args:
        gpu_ids: Indices of the visible CUDA devices (may be empty).
        cuda_available: Result of ``torch.cuda.is_available()``.
        preferred_index: Position in ``gpu_ids`` to use on multi-GPU hosts.
            It is clamped to the last visible GPU, so a host with fewer GPUs
            than ``preferred_index + 1`` no longer raises ``IndexError``.

    Returns:
        ``'cuda:<id>'`` when several GPUs are visible, otherwise ``'cuda'`` if
        CUDA is available, otherwise ``'cpu'``.
    """
    if len(gpu_ids) > 1:
        chosen = gpu_ids[min(preferred_index, len(gpu_ids) - 1)]
        return 'cuda:' + str(chosen)  # GPU Number
    return "cuda" if cuda_available else "cpu"


# Enumerate the visible CUDA devices together with their marketing names.
gpu_ids = list(range(torch.cuda.device_count())) if torch.cuda.is_available() else []
device_names = [torch.cuda.get_device_name(gpu_id) for gpu_id in gpu_ids]
print(gpu_ids)
print(device_names)

gpu = select_device(gpu_ids, torch.cuda.is_available())

[0, 1, 2, 3]
['NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090']


## Hyperparameters

In [2]:
# Execution target picked by the device-selection cell above.
device = gpu

# Optimisation settings: epoch count and the SGD learning rate used for fine-tuning.
NUM_EPOCHS = 1
LEARNING_RATE = 2e-04

# DataLoader settings.
BATCH_SIZE = 64
NUM_WORKERS = 2

# Value assigned to ``model.num_classes`` when the model is built.
NUM_CLASSES = 1000

# Checkpoint restored when fine-tuning starts with ``load=True``.
pre_model_path = './save/ViT_timm_vit_base_patch16_224_in21k_augNegative_i2012_ep8_lr0.0003.pt'
# Destination of the checkpoints written during and after fine-tuning.
fine_model_path = f'./save/ViT_timm_vit_base_patch16_224_in21k_augNegative_i2012_ep{NUM_EPOCHS}_lr{LEARNING_RATE}.pt'

## Dataset

In [3]:
# Training pipeline: random scale/aspect crop to 224x224, then map pixels to [-1, 1].
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# Evaluation pipeline: deterministic resize + centre crop so that validation
# results are reproducible (a random crop here would make every evaluation
# pass see different pixels).
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# pre_train_set = torchvision.datasets.ImageFolder('./data/ImageNet-21k', transform=transform_train)
# pre_train_loader = data.DataLoader(pre_train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
train_set = torchvision.datasets.ImageFolder('../../YJ/ILSVRC2012/train', transform=transform_train)
train_loader = data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_set = torchvision.datasets.ImageFolder('../../YJ/ILSVRC2012/val', transform=transform_test)
# Shuffling is unnecessary for evaluation and only makes runs harder to compare.
test_loader = data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

## Fine-tuning Class

In [4]:
class FineTunner(object):
    """Fine-tune a timm ViT using the ``NegativePatchRotate`` augmentation.

    ``process`` drives the whole pipeline: build the model (optionally
    restoring weights from ``pre_model_path``), train on ``train_loader`` and
    write a checkpoint to ``fine_model_path``.
    """

    def __init__(self):
        # Network being fine-tuned; created by build_model().
        self.model = None
        # Optimizer whose state is written into checkpoints.
        self.optimizer = None
        # Parallel histories, one entry appended per intermediate checkpoint:
        # the 1-based epoch number and the loss of the batch that triggered it.
        self.epochs = [0]
        self.losses = [0]

    def process(self, load=False):
        """Build the model, fine-tune it and save the final checkpoint.

        Args:
            load: When true, restore weights from ``pre_model_path`` first.
        """
        self.build_model(load)
        self.finetune_model()
        self.save_model()

    def build_model(self, load):
        """Create the ViT on ``device`` and optionally restore a checkpoint.

        Args:
            load: When true, load the ``'model'`` state dict stored at
                ``pre_model_path`` and reset the epoch/loss histories.
        """
        self.model = timm.create_model('vit_base_patch16_224_in21k', pretrained=True).to(device)
        # NOTE(review): assigning this attribute does not appear to rebuild the
        # classifier layer (the parameter count printed below is unchanged);
        # timm exposes ``reset_classifier()`` / ``num_classes=`` for that.
        # Left as is because the checkpoint at pre_model_path presumably stores
        # the current head shape - confirm before changing.
        self.model.num_classes = NUM_CLASSES
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        print(f'Classes: {self.model.num_classes}')
        # Placeholder so save_model() always has an optimizer; finetune_model()
        # replaces it with the optimizer that actually trains the model.
        self.optimizer = SGD(self.model.parameters(), lr=0)
        if load:
            # map_location keeps restored tensors on the selected device rather
            # than on whichever GPU the checkpoint was written from.
            # torch.load unpickles the file: only load trusted checkpoints.
            checkpoint = torch.load(pre_model_path, map_location=device)
            self.epochs = checkpoint['epochs']
            self.model.load_state_dict(checkpoint['model'])
            self.losses = checkpoint['losses']
            print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
            print(f'Classes: {self.model.num_classes}')
            print(f'Epoch: {self._last_epoch()}')
            print('****** Reset epochs and losses ******')
            self.epochs = []
            self.losses = []

    def finetune_model(self):
        """Train ``self.model`` for ``NUM_EPOCHS`` epochs over ``train_loader``.

        Uses SGD with momentum 0.9 and a cosine-annealed learning rate stepped
        once per epoch. The mean loss is printed every 100 batches and a
        checkpoint is written every 1000 batches.
        """
        model = self.model
        criterion = nn.CrossEntropyLoss()
        optimizer = SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
        # Publish the real optimizer right away so that every checkpoint,
        # including the final one of a run shorter than 1000 batches, stores
        # its state instead of the lr=0 placeholder from build_model().
        self.optimizer = optimizer
        scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
        aug = NegativePatchRotate(p=0.5)

        for epoch in range(NUM_EPOCHS):
            running_loss = 0.0
            for i, batch in tqdm(enumerate(train_loader, 0), total=len(train_loader)):
                inputs, labels = batch
                # The augmentation draws its per-sample decisions first, then
                # applies them; cal_loss() is given the same aug object.
                aug.roll_the_dice(len(inputs))
                inputs = aug.rotate(inputs)
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = aug.cal_loss(outputs, labels, criterion, device)
                loss.backward()
                optimizer.step()

                # Read the scalar once; .item() forces a device sync.
                batch_loss = loss.item()
                running_loss += batch_loss
                if i % 100 == 99:
                    print(f'[Epoch {epoch}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0
                if i % 1000 == 999:
                    self.epochs.append(epoch + 1)
                    self.losses.append(batch_loss)
                    self.save_model()
            scheduler.step()
        print('****** Finished Fine-tuning ******')
        self.model = model

    def save_model(self):
        """Write model/optimizer state and the histories to ``fine_model_path``."""
        import os

        checkpoint = {
            'epochs': self.epochs,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'losses': self.losses,
        }
        # Create the target directory so a long run cannot fail at save time.
        save_dir = os.path.dirname(fine_model_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(checkpoint, fine_model_path)
        print(f"****** Model checkpoint saved at epochs {self._last_epoch()} ******")

    def _last_epoch(self):
        """Return the most recently recorded epoch, or 0 if none is recorded."""
        return self.epochs[-1] if self.epochs else 0

In [None]:
if __name__ == '__main__':
    # Build the model, restore weights from pre_model_path (load=True),
    # fine-tune it and write the final checkpoint.
    trainer = FineTunner()
    trainer.process(load=True)

Parameter: 102595923
Classes: 1000
Parameter: 102595923
Classes: 1000
Epoch: 8
****** Reset epochs and losses ******


  0%|          | 100/20019 [00:41<2:13:03,  2.49it/s]

[Epoch 0, Batch   100] loss: 0.671


  1%|          | 200/20019 [01:23<2:38:53,  2.08it/s]

[Epoch 0, Batch   200] loss: 0.585


  1%|▏         | 300/20019 [02:06<2:12:52,  2.47it/s]

[Epoch 0, Batch   300] loss: 0.614


  2%|▏         | 400/20019 [02:47<2:37:42,  2.07it/s]

[Epoch 0, Batch   400] loss: 0.685


  2%|▏         | 500/20019 [03:29<2:42:26,  2.00it/s]

[Epoch 0, Batch   500] loss: 0.722


  3%|▎         | 600/20019 [04:11<2:10:27,  2.48it/s]

[Epoch 0, Batch   600] loss: 0.697


  3%|▎         | 700/20019 [04:52<2:10:25,  2.47it/s]

[Epoch 0, Batch   700] loss: 0.709


  4%|▍         | 800/20019 [05:32<2:09:21,  2.48it/s]

[Epoch 0, Batch   800] loss: 0.707


  4%|▍         | 900/20019 [06:14<2:09:08,  2.47it/s]

[Epoch 0, Batch   900] loss: 0.704


  5%|▍         | 999/20019 [06:54<2:07:32,  2.49it/s]

[Epoch 0, Batch  1000] loss: 0.685


  5%|▍         | 1000/20019 [06:56<4:41:56,  1.12it/s]

****** Model checkpoint saved at epochs 1 ******


  5%|▌         | 1100/20019 [07:39<2:06:14,  2.50it/s]

[Epoch 0, Batch  1100] loss: 0.691


  6%|▌         | 1200/20019 [08:22<2:08:10,  2.45it/s]

[Epoch 0, Batch  1200] loss: 0.680


  6%|▋         | 1300/20019 [09:03<2:05:58,  2.48it/s]

[Epoch 0, Batch  1300] loss: 0.742


  7%|▋         | 1400/20019 [09:43<2:04:06,  2.50it/s]

[Epoch 0, Batch  1400] loss: 0.694


  7%|▋         | 1500/20019 [10:26<2:28:28,  2.08it/s]

[Epoch 0, Batch  1500] loss: 0.721


  8%|▊         | 1600/20019 [11:06<2:03:06,  2.49it/s]

[Epoch 0, Batch  1600] loss: 0.718


  8%|▊         | 1700/20019 [11:47<2:02:48,  2.49it/s]

[Epoch 0, Batch  1700] loss: 0.724


  9%|▉         | 1800/20019 [12:29<2:04:24,  2.44it/s]

[Epoch 0, Batch  1800] loss: 0.737


  9%|▉         | 1900/20019 [13:36<3:33:37,  1.41it/s]

[Epoch 0, Batch  1900] loss: 0.701


 10%|▉         | 1999/20019 [14:47<3:30:50,  1.42it/s]

[Epoch 0, Batch  2000] loss: 0.719


 10%|▉         | 2000/20019 [14:50<6:39:34,  1.33s/it]

****** Model checkpoint saved at epochs 1 ******


 10%|█         | 2100/20019 [15:56<2:01:01,  2.47it/s]

[Epoch 0, Batch  2100] loss: 0.665


 11%|█         | 2200/20019 [16:36<1:59:20,  2.49it/s]

[Epoch 0, Batch  2200] loss: 0.734


 11%|█▏        | 2300/20019 [17:20<1:58:29,  2.49it/s]

[Epoch 0, Batch  2300] loss: 0.686


 12%|█▏        | 2400/20019 [18:00<1:58:42,  2.47it/s]

[Epoch 0, Batch  2400] loss: 0.708


 12%|█▏        | 2440/20019 [18:18<4:11:08,  1.17it/s]