# Trainer

## Library

In [1]:
import torch
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.optim import AdamW, SGD
from torch import nn
from torch.utils.data import random_split
from tqdm import tqdm, tqdm_notebook
from torch.optim.lr_scheduler import CosineAnnealingLR
from functools import partial

import facebook_vit
from mae_util import interpolate_pos_embed
from timm.models.layers import trunc_normal_
from facebook_mae import MaskedAutoencoderViT

import facebook_mae
import mae_misc as misc
from mae_misc import NativeScalerWithGradNormCount as NativeScaler


# Enumerate the visible CUDA devices; both lists stay empty when CUDA is
# not available.
if torch.cuda.is_available():
    gpu_ids = list(range(torch.cuda.device_count()))
    device_names = [torch.cuda.get_device_name(idx) for idx in gpu_ids]
else:
    gpu_ids, device_names = [], []
print(gpu_ids)
print(device_names)

# Choose the device string: with several GPUs the first one is addressed
# explicitly, otherwise fall back to generic "cuda" or to "cpu".
if len(gpu_ids) > 1:
    gpu = f'cuda:{gpu_ids[0]}'  # GPU Number
elif torch.cuda.is_available():
    gpu = "cuda"
else:
    gpu = "cpu"

[0, 1, 2, 3]
['NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090']


## Hyperparameters

In [7]:
# Device string chosen by the GPU-detection code earlier in the notebook.
device = gpu

# Optimisation schedule (the author's original inline notes are retained).
NUM_EPOCHS = 15  # 100
WARMUP_EPOCHS = 5  # 5
LEARNING_RATE = 4e-06  # paper: 1e-03 -> implementation: 5e-04

# Data loading.
BATCH_SIZE = 64  # 1024
NUM_WORKERS = 2

# Checkpoint locations: `load_model_path` is read when a trainer is run with
# load=True; the other paths are write targets whose names embed the epoch
# count and learning rate.
load_model_path = './save/MAE/mae_finetuned_vit_base_given.pth'
pre_model_path = f'./save/mae_base_i2012_ep{NUM_EPOCHS}_lr{LEARNING_RATE}.pt'
fine_model_path = f'./save/mae_vit_base_i2012_ep{NUM_EPOCHS}_lr{LEARNING_RATE}.pt'
dynamic_model_path = './save/mae_vit_base_i2012_ep'

## Dataset

In [8]:
# Training pipeline: random crop/rescale augmentation, then tensor conversion
# and per-channel normalisation.
# NOTE(review): mean/std of 0.5 are used for normalisation; confirm they match
# the statistics the loaded checkpoint was trained with.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# Evaluation pipeline: deterministic resize + centre crop so that repeated
# evaluations of the same image see the same pixels (a random crop here would
# make validation metrics vary from run to run).
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
train_set = torchvision.datasets.ImageFolder('../datasets/ImageNet/train', transform=transform_train)
train_loader = data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_set = torchvision.datasets.ImageFolder('../datasets/ImageNet/val', transform=transform_test)
# Evaluation order does not need to be randomised; keep it fixed.
test_loader = data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

## Fine-tuning Class

In [9]:
class FineTuner(object):
    """Fine-tune a ViT-Base classifier and checkpoint it after every epoch.

    The class reads its configuration from module-level names (``device``,
    ``NUM_EPOCHS``, ``WARMUP_EPOCHS``, ``LEARNING_RATE``, ``load_model_path``,
    ``fine_model_path``) and iterates over the module-level ``train_loader``.

    Attributes:
        model: network built by ``build_model``.
        optimizer, scheduler: placeholders after ``build_model``; replaced by
            the real AdamW / cosine objects during ``finetune_model``.
        epochs, losses, accuracies: logs that receive one entry per 1000
            training batches.
    """

    def __init__(self):
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.epochs = [0]
        self.losses = [0]
        self.accuracies = [0]

    def process(self, load=False):
        """Build the model, fine-tune it, then write a final checkpoint.

        Args:
            load: when truthy, initialise the model from ``load_model_path``.
        """
        self.build_model(load)
        self.finetune_model()
        self.save_model()

    def _last_recorded_epoch(self):
        """Return the most recently logged epoch number, or 0 if none is logged."""
        return self.epochs[-1] if self.epochs else 0

    def build_model(self, load):
        """Create the model and optionally restore weights from a checkpoint.

        Args:
            load: when truthy, load ``load_model_path`` (non-strictly) into
                the freshly built model and reset the logged history.
        """
        self.model = facebook_vit.__dict__['vit_base_patch16'](
            num_classes=1000,
            drop_path_rate=0.1,
        )
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        # Placeholder optimizer/scheduler so that save_model() always has
        # state to serialise; finetune_model() installs the real ones.
        self.optimizer = SGD(self.model.parameters(), lr=0)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=NUM_EPOCHS)

        if load:
            # Load onto the CPU first so the checkpoint can be read regardless
            # of the device it was saved from; the model is moved to `device`
            # at the end of this method.
            checkpoint = torch.load(load_model_path, map_location='cpu')
            checkpoint_model = checkpoint['model']
            state_dict = self.model.state_dict()
            # Drop head tensors whose shape does not match this model's head.
            for k in ['head.weight', 'head.bias']:
                if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                    print(f"Removing key {k} from pretrained checkpoint")
                    del checkpoint_model[k]
            interpolate_pos_embed(self.model, checkpoint_model)
            msg = self.model.load_state_dict(checkpoint_model, strict=False)
            print(msg)
            # The head weights are re-initialised after loading, which also
            # overwrites head weights that came from the checkpoint.
            # NOTE(review): confirm this is intended when the checkpoint is
            # already fine-tuned.
            trunc_normal_(self.model.head.weight, std=2e-5)

            if 'given' not in str(load_model_path):
                self.epochs = checkpoint['epochs']
                self.losses = checkpoint['losses']
                self.accuracies = checkpoint['accuracies']
            print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
            print(f'Epoch: {self._last_recorded_epoch()}')
            print('****** Reset epochs and losses ******')
            self.epochs = []
            self.losses = []
            self.accuracies = []

        # Move the model to the target device in every case; previously this
        # only happened when a checkpoint was loaded, so load=False trained a
        # CPU model against inputs placed on `device`.
        self.model.to(device)

    def finetune_model(self):
        """Train for ``NUM_EPOCHS`` epochs with linear warm-up, saving each epoch.

        Every 100 batches the running loss and windowed accuracy are printed;
        every 1000 batches an entry is appended to the history lists and the
        accuracy window is reset.
        """
        # `tqdm.tqdm_notebook` is deprecated in favour of `tqdm.notebook.tqdm`.
        from tqdm.notebook import tqdm as notebook_tqdm

        model = self.model.train()
        criterion = nn.CrossEntropyLoss()
        optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05)
        scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

        for epoch in range(NUM_EPOCHS):
            if epoch < WARMUP_EPOCHS:
                # Linear warm-up: set the learning rate by hand, overriding
                # whatever the scheduler stepped to after the previous epoch.
                lr_warmup = ((epoch + 1) / WARMUP_EPOCHS) * LEARNING_RATE
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_warmup
                if epoch + 1 == WARMUP_EPOCHS:
                    # Restart the cosine schedule at the end of warm-up.
                    # NOTE(review): T_max stays NUM_EPOCHS although only
                    # NUM_EPOCHS - WARMUP_EPOCHS epochs remain; confirm.
                    scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
            print(f"epoch {epoch+1} learning rate : {optimizer.param_groups[0]['lr']}")
            running_loss = 0.0  # reset every 100 batches (console report)
            saving_loss = 0.0   # reset every 1000 batches (history entry)
            correct = 0
            total = 0
            for i, batch in notebook_tqdm(enumerate(train_loader, 0), total=len(train_loader)):
                inputs, labels = batch
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # Read the scalar once; each .item() call synchronises.
                batch_loss = loss.item()
                running_loss += batch_loss
                saving_loss += batch_loss
                _, predicted = torch.max(outputs.detach(), 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                if i % 100 == 99:
                    print(f'[Epoch {epoch}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}, acc: {correct/total*100:.2f} %')
                    running_loss = 0.0
                if i % 1000 == 999:
                    self.epochs.append(epoch + 1)
                    self.losses.append(saving_loss/1000)
                    self.accuracies.append(correct/total*100)
                    saving_loss = 0.0
                    correct = 0
                    total = 0
            # Publish the live objects before saving so the checkpoint holds
            # the AdamW / cosine state rather than the placeholders.
            self.model = model
            self.optimizer = optimizer
            self.scheduler = scheduler
            self.save_model()
            scheduler.step()
        print('****** Finished Fine-tuning ******')

    def save_model(self):
        """Serialise model, optimizer, scheduler and history to ``fine_model_path``."""
        checkpoint = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'epochs': self.epochs,
            'losses': self.losses,
            'accuracies': self.accuracies,
        }
        torch.save(checkpoint, fine_model_path)
        # The history can be empty (it is reset after loading and only grows
        # every 1000 batches), so do not index it unconditionally.
        print(f"****** Model checkpoint saved at epochs {self._last_recorded_epoch()} ******")

## Pre-training Class

In [10]:
class PreTrainer(object):
    """Pre-train a masked-autoencoder ViT-Base and checkpoint it every epoch.

    The class reads its configuration from module-level names (``device``,
    ``NUM_EPOCHS``, ``WARMUP_EPOCHS``, ``LEARNING_RATE``, ``load_model_path``,
    ``pre_model_path``) and iterates over the module-level ``train_loader``.

    Attributes:
        model: network built by ``build_model``.
        optimizer, scheduler: placeholders after ``build_model``; replaced by
            the real AdamW / cosine objects during ``pretrain_model``.
        epochs, losses: logs that receive one entry per 1000 training batches.
        accuracies: kept for checkpoint compatibility; not appended to during
            pre-training.
    """

    def __init__(self):
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.epochs = [0]
        self.losses = [0]
        self.accuracies = [0]

    def process(self, load=False):
        """Build the model, pre-train it, then write a final checkpoint.

        Args:
            load: when truthy, initialise the model from ``load_model_path``.
        """
        self.build_model(load)
        self.pretrain_model()
        self.save_model()

    def _last_recorded_epoch(self):
        """Return the most recently logged epoch number, or 0 if none is logged."""
        return self.epochs[-1] if self.epochs else 0

    def build_model(self, load):
        """Create the MAE model on ``device`` and optionally restore weights.

        Args:
            load: when truthy, load ``load_model_path`` (strictly) into the
                model and reset the logged history.
        """
        self.model = facebook_mae.__dict__['mae_vit_base_patch16_dec512d8b'](norm_pix_loss=False).to(device)
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        # Placeholder optimizer/scheduler so that save_model() always has
        # state to serialise; pretrain_model() installs the real ones.
        self.optimizer = SGD(self.model.parameters(), lr=0)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=NUM_EPOCHS)

        if load:
            # Load onto the CPU first so the checkpoint can be read regardless
            # of the device it was saved from; load_state_dict copies the
            # tensors into the model's own parameters.
            checkpoint = torch.load(load_model_path, map_location='cpu')
            self.model.load_state_dict(checkpoint['model'])
            if 'given' not in str(load_model_path):
                self.epochs = checkpoint['epochs']
                self.losses = checkpoint['losses']
                # Checkpoints written by earlier versions of save_model() had
                # no 'accuracies' entry; tolerate its absence.
                self.accuracies = checkpoint.get('accuracies', [])
            print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
            print(f'Epoch: {self._last_recorded_epoch()}')
            print('****** Reset epochs and losses ******')
            self.epochs = []
            self.losses = []
            self.accuracies = []

    def pretrain_model(self):
        """Train for ``NUM_EPOCHS`` epochs with linear warm-up, saving each epoch.

        Every 100 batches the running loss is printed; every 1000 batches an
        entry is appended to ``epochs`` and ``losses``.
        """
        # `tqdm.tqdm_notebook` is deprecated in favour of `tqdm.notebook.tqdm`.
        from tqdm.notebook import tqdm as notebook_tqdm

        model = self.model.train()
        optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95))
        scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
        loss_scaler = NativeScaler()

        for epoch in range(NUM_EPOCHS):
            if epoch < WARMUP_EPOCHS:
                # Linear warm-up: set the learning rate by hand, overriding
                # whatever the scheduler stepped to after the previous epoch.
                lr_warmup = ((epoch + 1) / WARMUP_EPOCHS) * LEARNING_RATE
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_warmup
                if epoch + 1 == WARMUP_EPOCHS:
                    # Restart the cosine schedule at the end of warm-up.
                    # NOTE(review): T_max stays NUM_EPOCHS although only
                    # NUM_EPOCHS - WARMUP_EPOCHS epochs remain; confirm.
                    scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
            print(f"epoch {epoch + 1} learning rate : {optimizer.param_groups[0]['lr']}")
            running_loss = 0.0  # reset every 100 batches (console report)
            saving_loss = 0.0   # reset every 1000 batches (history entry)
            for i, batch in notebook_tqdm(enumerate(train_loader, 0), total=len(train_loader)):
                samples, _ = batch  # labels are not used here
                samples = samples.to(device, non_blocking=True)

                optimizer.zero_grad()

                loss, _, _ = model(samples, mask_ratio=.75)
                loss_scaler(loss, optimizer, parameters=model.parameters(), update_grad=True)
                # Scaler include loss.backward() and optimizer.step()

                # Read the scalar once; each .item() call synchronises.
                batch_loss = loss.item()
                running_loss += batch_loss
                saving_loss += batch_loss

                if i % 100 == 99:
                    print(f'[Epoch {epoch}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0
                if i % 1000 == 999:
                    self.epochs.append(epoch + 1)
                    self.losses.append(saving_loss/1000)
                    saving_loss = 0.0
            # Publish the live objects before saving so the checkpoint holds
            # the AdamW / cosine state rather than the placeholders.
            self.model = model
            self.optimizer = optimizer
            self.scheduler = scheduler
            self.save_model()
            scheduler.step()
        print('****** Finished Pre-training ******')

    def save_model(self):
        """Serialise model, optimizer, scheduler and history to ``pre_model_path``."""
        checkpoint = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'epochs': self.epochs,
            'losses': self.losses,
            # Written so that build_model() can read the checkpoint back.
            'accuracies': self.accuracies,
        }
        torch.save(checkpoint, pre_model_path)
        # The history can be empty (it is reset after loading and only grows
        # every 1000 batches), so do not index it unconditionally.
        print(f"****** Model checkpoint saved at epochs {self._last_recorded_epoch()} ******")


In [None]:
if __name__ == "__main__":
    # Fine-tuning run: load=True makes build_model() read the checkpoint
    # at load_model_path before training starts.
    trainer = FineTuner()
    trainer.process(load=True)

Parameter: 86567656
_IncompatibleKeys(missing_keys=['norm.weight', 'norm.bias'], unexpected_keys=['fc_norm.weight', 'fc_norm.bias'])
Parameter: 86567656
Epoch: 0
****** Reset epochs and losses ******
epoch 1 learning rate : 8e-07


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 0, Batch   100] loss: 6.919, acc: 0.08 %
[Epoch 0, Batch   200] loss: 6.919, acc: 0.09 %
[Epoch 0, Batch   300] loss: 6.919, acc: 0.08 %
[Epoch 0, Batch   400] loss: 6.916, acc: 0.10 %
[Epoch 0, Batch   500] loss: 6.912, acc: 0.09 %
[Epoch 0, Batch   600] loss: 6.897, acc: 0.09 %
[Epoch 0, Batch   700] loss: 6.883, acc: 0.09 %
[Epoch 0, Batch   800] loss: 6.877, acc: 0.09 %
[Epoch 0, Batch   900] loss: 6.871, acc: 0.09 %
[Epoch 0, Batch  1000] loss: 6.860, acc: 0.10 %
[Epoch 0, Batch  1100] loss: 6.849, acc: 0.17 %
[Epoch 0, Batch  1200] loss: 6.839, acc: 0.15 %
[Epoch 0, Batch  1300] loss: 6.830, acc: 0.15 %
[Epoch 0, Batch  1400] loss: 6.824, acc: 0.14 %
[Epoch 0, Batch  1500] loss: 6.814, acc: 0.14 %
[Epoch 0, Batch  1600] loss: 6.808, acc: 0.14 %
[Epoch 0, Batch  1700] loss: 6.799, acc: 0.16 %
[Epoch 0, Batch  1800] loss: 6.785, acc: 0.16 %
[Epoch 0, Batch  1900] loss: 6.783, acc: 0.17 %
[Epoch 0, Batch  2000] loss: 6.776, acc: 0.18 %
[Epoch 0, Batch  2100] loss: 6.765, acc:

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 1, Batch   100] loss: 5.475, acc: 36.86 %
[Epoch 1, Batch   200] loss: 5.458, acc: 36.35 %
[Epoch 1, Batch   300] loss: 5.461, acc: 36.27 %
[Epoch 1, Batch   400] loss: 5.441, acc: 36.58 %
[Epoch 1, Batch   500] loss: 5.427, acc: 36.92 %
[Epoch 1, Batch   600] loss: 5.421, acc: 37.00 %
[Epoch 1, Batch   700] loss: 5.404, acc: 37.19 %
[Epoch 1, Batch   800] loss: 5.388, acc: 37.39 %
[Epoch 1, Batch   900] loss: 5.382, acc: 37.45 %
[Epoch 1, Batch  1000] loss: 5.368, acc: 37.64 %
[Epoch 1, Batch  1100] loss: 5.360, acc: 39.17 %
[Epoch 1, Batch  1200] loss: 5.347, acc: 38.70 %
[Epoch 1, Batch  1300] loss: 5.329, acc: 38.77 %
[Epoch 1, Batch  1400] loss: 5.331, acc: 38.73 %
[Epoch 1, Batch  1500] loss: 5.313, acc: 38.73 %
[Epoch 1, Batch  1600] loss: 5.301, acc: 38.70 %
[Epoch 1, Batch  1700] loss: 5.293, acc: 38.82 %
[Epoch 1, Batch  1800] loss: 5.280, acc: 38.85 %
[Epoch 1, Batch  1900] loss: 5.264, acc: 38.97 %
[Epoch 1, Batch  2000] loss: 5.258, acc: 38.98 %
[Epoch 1, Batch  210

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 2, Batch   100] loss: 3.578, acc: 60.58 %
[Epoch 2, Batch   200] loss: 3.563, acc: 60.27 %
[Epoch 2, Batch   300] loss: 3.544, acc: 60.54 %
[Epoch 2, Batch   400] loss: 3.532, acc: 60.79 %
[Epoch 2, Batch   500] loss: 3.516, acc: 60.94 %
[Epoch 2, Batch   600] loss: 3.493, acc: 61.16 %
[Epoch 2, Batch   700] loss: 3.510, acc: 61.02 %
[Epoch 2, Batch   800] loss: 3.482, acc: 61.04 %
[Epoch 2, Batch   900] loss: 3.474, acc: 61.03 %
[Epoch 2, Batch  1000] loss: 3.446, acc: 61.10 %
[Epoch 2, Batch  1100] loss: 3.469, acc: 60.36 %
[Epoch 2, Batch  1200] loss: 3.440, acc: 60.36 %
[Epoch 2, Batch  1300] loss: 3.434, acc: 60.71 %
[Epoch 2, Batch  1400] loss: 3.417, acc: 60.87 %
[Epoch 2, Batch  1500] loss: 3.396, acc: 61.06 %
[Epoch 2, Batch  1600] loss: 3.413, acc: 60.96 %
[Epoch 2, Batch  1700] loss: 3.386, acc: 61.14 %
[Epoch 2, Batch  1800] loss: 3.361, acc: 61.26 %
[Epoch 2, Batch  1900] loss: 3.383, acc: 61.21 %
[Epoch 2, Batch  2000] loss: 3.351, acc: 61.27 %
[Epoch 2, Batch  210

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 3, Batch   100] loss: 1.866, acc: 72.92 %
[Epoch 3, Batch   200] loss: 1.857, acc: 73.34 %
[Epoch 3, Batch   300] loss: 1.882, acc: 72.97 %
[Epoch 3, Batch   400] loss: 1.853, acc: 73.02 %
[Epoch 3, Batch   500] loss: 1.867, acc: 72.99 %
[Epoch 3, Batch   600] loss: 1.819, acc: 72.89 %
[Epoch 3, Batch   700] loss: 1.831, acc: 72.85 %
[Epoch 3, Batch   800] loss: 1.849, acc: 72.83 %
[Epoch 3, Batch   900] loss: 1.836, acc: 72.97 %
[Epoch 3, Batch  1000] loss: 1.842, acc: 72.90 %
[Epoch 3, Batch  1100] loss: 1.829, acc: 72.69 %
[Epoch 3, Batch  1200] loss: 1.808, acc: 73.03 %
[Epoch 3, Batch  1300] loss: 1.817, acc: 72.88 %
[Epoch 3, Batch  1400] loss: 1.820, acc: 72.73 %
[Epoch 3, Batch  1500] loss: 1.840, acc: 72.59 %
[Epoch 3, Batch  1600] loss: 1.788, acc: 72.73 %
[Epoch 3, Batch  1700] loss: 1.765, acc: 72.88 %
[Epoch 3, Batch  1800] loss: 1.766, acc: 73.01 %
[Epoch 3, Batch  1900] loss: 1.798, acc: 73.04 %
[Epoch 3, Batch  2000] loss: 1.797, acc: 73.03 %
[Epoch 3, Batch  210

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 4, Batch   100] loss: 1.166, acc: 78.12 %
[Epoch 4, Batch   200] loss: 1.184, acc: 77.54 %
[Epoch 4, Batch   300] loss: 1.161, acc: 77.72 %
[Epoch 4, Batch   400] loss: 1.181, acc: 77.51 %
[Epoch 4, Batch   500] loss: 1.232, acc: 77.26 %
[Epoch 4, Batch   600] loss: 1.159, acc: 77.33 %
[Epoch 4, Batch   700] loss: 1.200, acc: 77.23 %
[Epoch 4, Batch   800] loss: 1.160, acc: 77.24 %
[Epoch 4, Batch   900] loss: 1.180, acc: 77.24 %
[Epoch 4, Batch  1000] loss: 1.151, acc: 77.26 %
[Epoch 4, Batch  1100] loss: 1.155, acc: 78.09 %
[Epoch 4, Batch  1200] loss: 1.138, acc: 77.77 %
[Epoch 4, Batch  1300] loss: 1.161, acc: 77.67 %
[Epoch 4, Batch  1400] loss: 1.179, acc: 77.62 %
[Epoch 4, Batch  1500] loss: 1.127, acc: 77.63 %
[Epoch 4, Batch  1600] loss: 1.164, acc: 77.58 %
[Epoch 4, Batch  1700] loss: 1.159, acc: 77.55 %
[Epoch 4, Batch  1800] loss: 1.134, acc: 77.59 %
[Epoch 4, Batch  1900] loss: 1.138, acc: 77.59 %
[Epoch 4, Batch  2000] loss: 1.191, acc: 77.48 %
[Epoch 4, Batch  210

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 5, Batch   100] loss: 1.015, acc: 78.84 %
[Epoch 5, Batch   200] loss: 1.017, acc: 78.21 %
[Epoch 5, Batch   300] loss: 0.975, acc: 78.42 %
[Epoch 5, Batch   400] loss: 1.018, acc: 78.51 %
[Epoch 5, Batch   500] loss: 0.985, acc: 78.70 %
[Epoch 5, Batch   600] loss: 0.977, acc: 78.77 %
[Epoch 5, Batch   700] loss: 0.997, acc: 78.82 %
[Epoch 5, Batch   800] loss: 0.944, acc: 78.85 %
[Epoch 5, Batch   900] loss: 0.964, acc: 78.92 %
[Epoch 5, Batch  1000] loss: 0.979, acc: 78.95 %
[Epoch 5, Batch  1100] loss: 1.007, acc: 78.72 %
[Epoch 5, Batch  1200] loss: 0.944, acc: 79.04 %
[Epoch 5, Batch  1300] loss: 0.953, acc: 79.37 %
[Epoch 5, Batch  1400] loss: 0.978, acc: 79.52 %
[Epoch 5, Batch  1500] loss: 1.009, acc: 79.19 %
[Epoch 5, Batch  1600] loss: 0.980, acc: 79.12 %
[Epoch 5, Batch  1700] loss: 0.983, acc: 79.11 %
[Epoch 5, Batch  1800] loss: 0.943, acc: 79.14 %
[Epoch 5, Batch  1900] loss: 0.981, acc: 79.15 %
[Epoch 5, Batch  2000] loss: 0.968, acc: 79.19 %
[Epoch 5, Batch  210

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 6, Batch   100] loss: 0.961, acc: 78.34 %
[Epoch 6, Batch   200] loss: 0.914, acc: 79.23 %
[Epoch 6, Batch   300] loss: 0.913, acc: 79.53 %
[Epoch 6, Batch   400] loss: 0.883, acc: 79.67 %
[Epoch 6, Batch   500] loss: 0.906, acc: 79.64 %
[Epoch 6, Batch   600] loss: 0.919, acc: 79.54 %
[Epoch 6, Batch   700] loss: 0.937, acc: 79.40 %
[Epoch 6, Batch   800] loss: 0.898, acc: 79.52 %
[Epoch 6, Batch   900] loss: 0.895, acc: 79.56 %
[Epoch 6, Batch  1000] loss: 0.884, acc: 79.63 %
[Epoch 6, Batch  1100] loss: 0.935, acc: 79.23 %
[Epoch 6, Batch  1200] loss: 0.914, acc: 79.34 %
[Epoch 6, Batch  1300] loss: 0.920, acc: 79.46 %
[Epoch 6, Batch  1400] loss: 0.876, acc: 79.84 %
[Epoch 6, Batch  1500] loss: 0.906, acc: 79.87 %
[Epoch 6, Batch  1600] loss: 0.880, acc: 79.91 %
[Epoch 6, Batch  1700] loss: 0.899, acc: 79.95 %
[Epoch 6, Batch  1800] loss: 0.919, acc: 79.92 %
[Epoch 6, Batch  1900] loss: 0.925, acc: 79.85 %
[Epoch 6, Batch  2000] loss: 0.896, acc: 79.83 %
[Epoch 6, Batch  210

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 7, Batch   100] loss: 0.889, acc: 80.05 %
[Epoch 7, Batch   200] loss: 0.852, acc: 80.50 %
[Epoch 7, Batch   300] loss: 0.851, acc: 80.62 %
[Epoch 7, Batch   400] loss: 0.860, acc: 80.48 %
[Epoch 7, Batch   500] loss: 0.855, acc: 80.41 %
[Epoch 7, Batch   600] loss: 0.927, acc: 80.20 %
[Epoch 7, Batch   700] loss: 0.901, acc: 80.06 %
[Epoch 7, Batch   800] loss: 0.845, acc: 80.09 %
[Epoch 7, Batch   900] loss: 0.883, acc: 80.07 %
[Epoch 7, Batch  1000] loss: 0.855, acc: 80.14 %
[Epoch 7, Batch  1100] loss: 0.887, acc: 80.27 %
[Epoch 7, Batch  1200] loss: 0.900, acc: 80.23 %
[Epoch 7, Batch  1300] loss: 0.870, acc: 80.21 %
[Epoch 7, Batch  1400] loss: 0.840, acc: 80.27 %
[Epoch 7, Batch  1500] loss: 0.867, acc: 80.23 %
[Epoch 7, Batch  1600] loss: 0.843, acc: 80.27 %
[Epoch 7, Batch  1700] loss: 0.850, acc: 80.38 %
[Epoch 7, Batch  1800] loss: 0.875, acc: 80.39 %
[Epoch 7, Batch  1900] loss: 0.873, acc: 80.36 %
[Epoch 7, Batch  2000] loss: 0.905, acc: 80.27 %
[Epoch 7, Batch  210

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 8, Batch   100] loss: 0.836, acc: 81.14 %
[Epoch 8, Batch   200] loss: 0.813, acc: 81.08 %
[Epoch 8, Batch   300] loss: 0.833, acc: 81.07 %
[Epoch 8, Batch   400] loss: 0.834, acc: 80.95 %
[Epoch 8, Batch   500] loss: 0.830, acc: 80.91 %
[Epoch 8, Batch   600] loss: 0.812, acc: 81.03 %
[Epoch 8, Batch   700] loss: 0.858, acc: 80.93 %
[Epoch 8, Batch   800] loss: 0.807, acc: 80.97 %
[Epoch 8, Batch   900] loss: 0.831, acc: 80.96 %
[Epoch 8, Batch  1000] loss: 0.840, acc: 80.93 %
[Epoch 8, Batch  1100] loss: 0.821, acc: 81.02 %
[Epoch 8, Batch  1200] loss: 0.833, acc: 81.04 %
[Epoch 8, Batch  1300] loss: 0.812, acc: 81.15 %
[Epoch 8, Batch  1400] loss: 0.857, acc: 80.95 %
[Epoch 8, Batch  1500] loss: 0.829, acc: 81.02 %
[Epoch 8, Batch  1600] loss: 0.830, acc: 80.92 %
[Epoch 8, Batch  1700] loss: 0.838, acc: 80.85 %
[Epoch 8, Batch  1800] loss: 0.823, acc: 80.88 %
[Epoch 8, Batch  1900] loss: 0.841, acc: 80.86 %
[Epoch 8, Batch  2000] loss: 0.821, acc: 80.86 %
[Epoch 8, Batch  210

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 9, Batch   100] loss: 0.825, acc: 81.19 %
[Epoch 9, Batch   200] loss: 0.804, acc: 80.98 %
[Epoch 9, Batch   300] loss: 0.823, acc: 81.01 %
[Epoch 9, Batch   400] loss: 0.789, acc: 81.18 %
[Epoch 9, Batch   500] loss: 0.844, acc: 80.94 %
[Epoch 9, Batch   600] loss: 0.816, acc: 81.03 %
[Epoch 9, Batch   700] loss: 0.796, acc: 81.02 %
[Epoch 9, Batch   800] loss: 0.780, acc: 81.17 %
[Epoch 9, Batch   900] loss: 0.793, acc: 81.23 %
[Epoch 9, Batch  1000] loss: 0.813, acc: 81.22 %
[Epoch 9, Batch  1100] loss: 0.846, acc: 80.41 %
[Epoch 9, Batch  1200] loss: 0.821, acc: 80.71 %
[Epoch 9, Batch  1300] loss: 0.798, acc: 81.19 %
[Epoch 9, Batch  1400] loss: 0.778, acc: 81.22 %
[Epoch 9, Batch  1500] loss: 0.795, acc: 81.29 %
[Epoch 9, Batch  1600] loss: 0.808, acc: 81.27 %
[Epoch 9, Batch  1700] loss: 0.825, acc: 81.29 %
[Epoch 9, Batch  1800] loss: 0.811, acc: 81.26 %
[Epoch 9, Batch  1900] loss: 0.788, acc: 81.31 %
[Epoch 9, Batch  2000] loss: 0.806, acc: 81.26 %
[Epoch 9, Batch  210

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 10, Batch   100] loss: 0.803, acc: 81.52 %
[Epoch 10, Batch   200] loss: 0.766, acc: 81.72 %
[Epoch 10, Batch   300] loss: 0.781, acc: 81.60 %
[Epoch 10, Batch   400] loss: 0.812, acc: 81.44 %
[Epoch 10, Batch   500] loss: 0.813, acc: 81.46 %
[Epoch 10, Batch   600] loss: 0.786, acc: 81.45 %
[Epoch 10, Batch   700] loss: 0.786, acc: 81.44 %
[Epoch 10, Batch   800] loss: 0.805, acc: 81.49 %
[Epoch 10, Batch   900] loss: 0.771, acc: 81.53 %
[Epoch 10, Batch  1000] loss: 0.780, acc: 81.59 %
[Epoch 10, Batch  1100] loss: 0.815, acc: 81.47 %
[Epoch 10, Batch  1200] loss: 0.806, acc: 81.32 %
[Epoch 10, Batch  1300] loss: 0.801, acc: 81.11 %
[Epoch 10, Batch  1400] loss: 0.783, acc: 81.46 %
[Epoch 10, Batch  1500] loss: 0.803, acc: 81.44 %
[Epoch 10, Batch  1600] loss: 0.803, acc: 81.38 %
[Epoch 10, Batch  1700] loss: 0.824, acc: 81.26 %
[Epoch 10, Batch  1800] loss: 0.842, acc: 81.23 %
[Epoch 10, Batch  1900] loss: 0.780, acc: 81.29 %
[Epoch 10, Batch  2000] loss: 0.772, acc: 81.36 %


  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 11, Batch   100] loss: 0.759, acc: 82.25 %
[Epoch 11, Batch   200] loss: 0.776, acc: 81.94 %
[Epoch 11, Batch   300] loss: 0.755, acc: 82.21 %
[Epoch 11, Batch   400] loss: 0.740, acc: 82.21 %
[Epoch 11, Batch   500] loss: 0.782, acc: 82.03 %
[Epoch 11, Batch   600] loss: 0.771, acc: 82.09 %
[Epoch 11, Batch   700] loss: 0.804, acc: 81.94 %
[Epoch 11, Batch   800] loss: 0.759, acc: 81.96 %
[Epoch 11, Batch   900] loss: 0.787, acc: 81.92 %
[Epoch 11, Batch  1000] loss: 0.801, acc: 81.86 %
[Epoch 11, Batch  1100] loss: 0.791, acc: 81.67 %
[Epoch 11, Batch  1200] loss: 0.752, acc: 82.21 %
[Epoch 11, Batch  1300] loss: 0.794, acc: 82.00 %
[Epoch 11, Batch  1400] loss: 0.757, acc: 81.95 %
[Epoch 11, Batch  1500] loss: 0.788, acc: 81.86 %
[Epoch 11, Batch  1600] loss: 0.809, acc: 81.71 %
[Epoch 11, Batch  1700] loss: 0.762, acc: 81.67 %
[Epoch 11, Batch  1800] loss: 0.771, acc: 81.71 %
[Epoch 11, Batch  1900] loss: 0.780, acc: 81.77 %
[Epoch 11, Batch  2000] loss: 0.809, acc: 81.74 %


  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 12, Batch   100] loss: 0.744, acc: 82.36 %
[Epoch 12, Batch   200] loss: 0.770, acc: 82.17 %
[Epoch 12, Batch   300] loss: 0.748, acc: 82.07 %
[Epoch 12, Batch   400] loss: 0.760, acc: 82.12 %
[Epoch 12, Batch   500] loss: 0.731, acc: 82.31 %
[Epoch 12, Batch   600] loss: 0.732, acc: 82.36 %
[Epoch 12, Batch   700] loss: 0.741, acc: 82.40 %
[Epoch 12, Batch   800] loss: 0.729, acc: 82.47 %
[Epoch 12, Batch   900] loss: 0.759, acc: 82.44 %
[Epoch 12, Batch  1000] loss: 0.767, acc: 82.36 %
[Epoch 12, Batch  1100] loss: 0.755, acc: 82.12 %
[Epoch 12, Batch  1200] loss: 0.771, acc: 81.92 %
[Epoch 12, Batch  1300] loss: 0.796, acc: 81.91 %
[Epoch 12, Batch  1400] loss: 0.775, acc: 81.85 %
[Epoch 12, Batch  1500] loss: 0.781, acc: 81.80 %
[Epoch 12, Batch  1600] loss: 0.733, acc: 81.85 %
[Epoch 12, Batch  1700] loss: 0.762, acc: 81.92 %
[Epoch 12, Batch  1800] loss: 0.756, acc: 81.95 %
[Epoch 12, Batch  1900] loss: 0.776, acc: 81.90 %
[Epoch 12, Batch  2000] loss: 0.750, acc: 81.94 %


In [None]:
if __name__ == "__main__":
    # Pre-training run: load=False skips checkpoint loading in build_model().
    trainer = PreTrainer()
    trainer.process(load=False)