## Libraries

In [1]:
import torch
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.optim import AdamW, SGD
from torch import nn
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR

from vit_paper import ViT

# Enumerate the CUDA devices visible to this process (empty lists on CPU-only hosts).
gpu_ids = list(range(torch.cuda.device_count())) if torch.cuda.is_available() else []
device_names = [torch.cuda.get_device_name(idx) for idx in gpu_ids]
print(gpu_ids)
print(device_names)

# Pick the device string used by the rest of the notebook:
#   several GPUs -> pin to the first one explicitly,
#   one GPU      -> the default "cuda" device,
#   no GPU       -> "cpu".
if len(gpu_ids) > 1:
    gpu = f'cuda:{gpu_ids[0]}'  # GPU Number
elif torch.cuda.is_available():
    gpu = "cuda"
else:
    gpu = "cpu"

[0, 1, 2, 3]
['NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090']


## Hyperparameters

In [2]:
# --- Runtime -----------------------------------------------------------------
device = gpu  # device string selected by the device-detection cell

# --- Optimisation ------------------------------------------------------------
NUM_EPOCHS = 50
WARMUP_EPOCHS = 5        # epochs of linear learning-rate warm-up
BATCH_SIZE = 64
NUM_WORKERS = 2          # DataLoader worker processes
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.3       # passed to AdamW by PreTrainer
DROP_RATE = 0.1          # passed to ViT as drop_rate

# --- ViT architecture --------------------------------------------------------
IMAGE_SIZE, PATCH_SIZE, IN_CHANNELS = 224, 16, 3
NUM_CLASSES = 1000
EMBED_DIM = 768
DEPTH = 12
NUM_HEADS = 12

# --- Checkpoint files --------------------------------------------------------
# pre_model_path : written by PreTrainer.save_model
# load_model_path: read by PreTrainer.build_model(load=True) and FineTunner.build_model
# fine_model_path: written by FineTunner.save_model
pre_model_path = './save/ViT_i2012_ep300_lr0.0003.pt'
load_model_path = './save/ViT_i2012_ep300_lr0.0003.pt'
fine_model_path = './save/ViT_i2012_ep300_lr0.0003_augVanilla_i2012_ep7_lr0.03.pt'

## Dataset

In [3]:
# Training pipeline: random scale/aspect crop as augmentation, then scale pixel
# values from [0, 1] to [-1, 1] (mean 0.5 / std 0.5 per channel).
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# Evaluation pipeline: must be deterministic, so use a fixed resize + centre
# crop instead of the random crop used for training.
transform_test = transforms.Compose([
    transforms.Resize(int(IMAGE_SIZE / 0.875)),  # 256 for a 224 crop
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# Train on the *train* split. The previous code pointed both datasets at
# './data/ImageNet/val', i.e. the model was trained on its own evaluation data.
# NOTE(review): the recorded run shows 20019 batches/epoch at batch size 64,
# which matches the ImageNet-1k train split - confirm the directory name.
train_set = torchvision.datasets.ImageFolder('./data/ImageNet/train', transform=transform_train)
train_loader = data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_set = torchvision.datasets.ImageFolder('./data/ImageNet/val', transform=transform_test)
# Shuffling has no benefit for evaluation and makes per-batch logs irreproducible.
test_loader = data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

## Pre-training Class

In [4]:
class PreTrainer(object):
    """Pre-train the ViT classifier on ``train_loader`` and checkpoint it.

    ``process`` runs ``build_model`` -> ``pretrain_model`` -> ``save_model``.
    Hyperparameters, the device string, the data loader and the checkpoint
    paths are read from the module-level names defined in the configuration
    and dataset cells.
    """

    def __init__(self):
        self.model = None
        self.optimizer = None
        self.scheduler = None
        # Training history; one entry is appended every 1000 batches.
        self.epochs = [0]
        self.losses = [0]
        self.accuracies = [0]

    def process(self, load=False):
        """Build the model (optionally from a checkpoint), train it and save it.

        Args:
            load: when true, restore weights from ``load_model_path`` first.
        """
        self.build_model(load)
        self.pretrain_model()
        self.save_model()

    def build_model(self, load):
        """Instantiate the ViT on ``device`` and optionally restore its weights.

        Args:
            load: when true, load ``checkpoint['model']`` from
                ``load_model_path``, report the stored history and then reset
                the history lists so the new run starts with empty logs.
        """
        self.model = ViT(image_size=IMAGE_SIZE,
                         patch_size=PATCH_SIZE,
                         in_channels=IN_CHANNELS,
                         num_classes=NUM_CLASSES,
                         embed_dim=EMBED_DIM,
                         depth=DEPTH,
                         num_heads=NUM_HEADS,
                         drop_rate=DROP_RATE,
                         ).to(device)
        if load:
            # map_location lets a checkpoint saved on one device be restored on
            # another (e.g. saved on cuda:0, loaded on a CPU-only host).
            # NOTE: torch.load unpickles the file - only load trusted checkpoints.
            checkpoint = torch.load(load_model_path, map_location=device)
            self.model.load_state_dict(checkpoint['model'])
            self.epochs = checkpoint['epochs']
            self.losses = checkpoint['losses']
            self.accuracies = checkpoint['accuracies']
            print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
            print(f'Classes: {NUM_CLASSES}')
            # A checkpoint written before any history entry existed has an empty list.
            print(f'Epoch: {self.epochs[-1] if self.epochs else 0}')
            print(f'****** Reset epochs and losses ******')
            self.epochs = []
            self.losses = []
            self.accuracies = []

    def pretrain_model(self):
        """Train with AdamW, linear LR warm-up and cosine annealing.

        For the first ``WARMUP_EPOCHS`` epochs the learning rate is set by hand
        to ``(epoch + 1) / WARMUP_EPOCHS * LEARNING_RATE``; on the last warm-up
        epoch a fresh ``CosineAnnealingLR`` is created and drives the LR from
        then on. A checkpoint is written at the end of every epoch.
        """
        model = self.model.train()
        criterion = nn.CrossEntropyLoss()
        optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
        # Expose them immediately so save_model() works even if no epoch runs.
        self.optimizer = optimizer
        self.scheduler = scheduler

        for epoch in range(NUM_EPOCHS):
            if epoch < WARMUP_EPOCHS:
                # Manual warm-up: overrides whatever the scheduler set last epoch.
                lr_warmup = ((epoch + 1) / WARMUP_EPOCHS) * LEARNING_RATE
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_warmup
                if epoch + 1 == WARMUP_EPOCHS:
                    # Restart the cosine schedule from the full learning rate.
                    # NOTE(review): T_max=NUM_EPOCHS is longer than the epochs
                    # that remain, so the LR never reaches its minimum - confirm.
                    scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
            print(f"epoch {epoch + 1} learning rate : {optimizer.param_groups[0]['lr']}")
            running_loss = 0.0   # summed over the current 100-batch window
            saving_loss = 0.0    # summed over the current 1000-batch window
            correct = 0
            total = 0
            for i, batch in tqdm(enumerate(train_loader, 0), total=len(train_loader)):
                inputs, labels = batch
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                running_loss += batch_loss
                # Accumulate (the old code overwrote), so the stored value below
                # is the mean loss of the window, not last-batch-loss / 1000.
                saving_loss += batch_loss
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                if i % 100 == 99:
                    print(f'[Epoch {epoch}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}, acc: {correct/total*100:.2f} %')
                    running_loss = 0.0
                if i % 1000 == 999:
                    self.epochs.append(epoch + 1)
                    self.losses.append(saving_loss/1000)
                    self.accuracies.append(correct/total*100)
                    saving_loss = 0.0
                    correct = 0
                    total = 0
            self.model = model
            self.optimizer = optimizer
            self.scheduler = scheduler
            self.save_model()
            scheduler.step()
        print('****** Finished Pre-training ******')

    def save_model(self):
        """Write model/optimizer/scheduler state and the history to ``pre_model_path``."""
        checkpoint = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'epochs': self.epochs,
            'losses': self.losses,
            'accuracies': self.accuracies,
        }
        torch.save(checkpoint, pre_model_path)
        # The history is empty after a reset until the first 1000-batch window
        # completes; indexing [-1] unconditionally raised IndexError then.
        last_epoch = self.epochs[-1] if self.epochs else 0
        print(f"****** Model checkpoint saved at epochs {last_epoch} ******")

## Fine-tuning Class

In [5]:
class FineTunner(object):
    """Fine-tune a pre-trained ViT checkpoint on ``train_loader`` with SGD.

    ``process`` runs ``build_model`` -> ``finetune_model`` -> ``save_model``.
    Weights are restored from ``load_model_path`` and checkpoints are written
    to ``fine_model_path``; hyperparameters, the device string and the data
    loader are read from the module-level configuration.
    """

    def __init__(self):
        self.model = None
        self.optimizer = None
        self.scheduler = None
        # Training history; one entry is appended every 1000 batches.
        self.epochs = [0]
        self.losses = [0]
        self.accuracies = [0]

    def process(self):
        """Restore the pre-trained model, fine-tune it and save the result."""
        self.build_model()
        self.finetune_model()
        self.save_model()

    def build_model(self):
        """Instantiate the ViT on ``device`` and load the pre-trained weights.

        The history stored in the checkpoint is reported and then discarded so
        that the fine-tuning run starts with empty logs.
        """
        self.model = ViT(image_size=IMAGE_SIZE,
                         patch_size=PATCH_SIZE,
                         in_channels=IN_CHANNELS,
                         num_classes=NUM_CLASSES,
                         embed_dim=EMBED_DIM,
                         depth=DEPTH,
                         num_heads=NUM_HEADS,
                         drop_rate=DROP_RATE,
                         ).to(device)
        # map_location lets a checkpoint saved on one device be restored on
        # another (e.g. saved on cuda:0, loaded on a CPU-only host).
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        checkpoint = torch.load(load_model_path, map_location=device)
        self.model.load_state_dict(checkpoint['model'])
        self.epochs = checkpoint['epochs']
        self.losses = checkpoint['losses']
        self.accuracies = checkpoint['accuracies']
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        print(f'Classes: {NUM_CLASSES}')
        # A checkpoint written before any history entry existed has an empty list.
        print(f'Epoch: {self.epochs[-1] if self.epochs else 0}')
        print(f'****** Reset epochs and losses ******')
        self.epochs = []
        self.losses = []
        self.accuracies = []

    def finetune_model(self):
        """Train with SGD (momentum 0.9), linear LR warm-up and cosine annealing.

        For the first ``WARMUP_EPOCHS`` epochs the learning rate is set by hand
        to ``(epoch + 1) / WARMUP_EPOCHS * LEARNING_RATE``; on the last warm-up
        epoch a fresh ``CosineAnnealingLR`` is created and drives the LR from
        then on. A checkpoint is written at the end of every epoch.
        """
        model = self.model.train()
        criterion = nn.CrossEntropyLoss()
        optimizer = SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
        # The scheduler must exist before the first epoch ends: the loop saves
        # and steps it every epoch. Previously it was only created on the last
        # warm-up epoch, so the first epoch failed with UnboundLocalError.
        scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
        self.optimizer = optimizer
        self.scheduler = scheduler

        for epoch in range(NUM_EPOCHS):
            if epoch < WARMUP_EPOCHS:
                # Manual warm-up: overrides whatever the scheduler set last epoch.
                lr_warmup = ((epoch + 1) / WARMUP_EPOCHS) * LEARNING_RATE
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_warmup
                if epoch + 1 == WARMUP_EPOCHS:
                    # Restart the cosine schedule from the full learning rate.
                    # NOTE(review): T_max=NUM_EPOCHS is longer than the epochs
                    # that remain, so the LR never reaches its minimum - confirm.
                    scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
            print(f"epoch {epoch + 1} learning rate : {optimizer.param_groups[0]['lr']}")
            running_loss = 0.0   # summed over the current 100-batch window
            saving_loss = 0.0    # summed over the current 1000-batch window
            correct = 0
            total = 0
            for i, batch in tqdm(enumerate(train_loader, 0), total=len(train_loader)):
                inputs, labels = batch
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                running_loss += batch_loss
                saving_loss += batch_loss
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                if i % 100 == 99:
                    print(f'[Epoch {epoch}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}, acc: {correct/total*100:.2f} %')
                    running_loss = 0.0
                if i % 1000 == 999:
                    self.epochs.append(epoch + 1)
                    self.losses.append(saving_loss/1000)
                    self.accuracies.append(correct/total*100)
                    saving_loss = 0.0
                    correct = 0
                    total = 0
            self.model = model
            self.optimizer = optimizer
            self.scheduler = scheduler
            self.save_model()
            scheduler.step()
        print('****** Finished Fine-tuning ******')

    def save_model(self):
        """Write model/optimizer/scheduler state and the history to ``fine_model_path``."""
        checkpoint = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'epochs': self.epochs,
            'losses': self.losses,
            'accuracies': self.accuracies,
        }
        torch.save(checkpoint, fine_model_path)
        # build_model() empties the history, and it stays empty until the first
        # 1000-batch window completes; indexing [-1] raised IndexError then.
        last_epoch = self.epochs[-1] if self.epochs else 0
        print(f"****** Model checkpoint saved at epochs {last_epoch} ******")

In [None]:
if __name__ == '__main__':
    # Run the full pre-training pipeline (build -> train -> save).
    # load=False: build_model() does not restore the checkpoint at load_model_path.
    trainer = PreTrainer()
    trainer.process(load=False)

  0%|          | 100/20019 [00:38<2:04:43,  2.66it/s]

[Epoch 0, Batch   100] loss: 6.950


  1%|          | 200/20019 [01:17<2:07:30,  2.59it/s]

[Epoch 0, Batch   200] loss: 6.895


  1%|▏         | 300/20019 [01:55<2:03:21,  2.66it/s]

[Epoch 0, Batch   300] loss: 6.841


  2%|▏         | 400/20019 [02:33<2:02:31,  2.67it/s]

[Epoch 0, Batch   400] loss: 6.794


  2%|▏         | 500/20019 [03:10<2:02:00,  2.67it/s]

[Epoch 0, Batch   500] loss: 6.746


  3%|▎         | 600/20019 [03:48<2:01:19,  2.67it/s]

[Epoch 0, Batch   600] loss: 6.702


  3%|▎         | 700/20019 [04:25<2:01:10,  2.66it/s]

[Epoch 0, Batch   700] loss: 6.669


  4%|▍         | 800/20019 [05:07<2:15:34,  2.36it/s]

[Epoch 0, Batch   800] loss: 6.658


  4%|▍         | 900/20019 [05:49<2:22:15,  2.24it/s]

[Epoch 0, Batch   900] loss: 6.627


  5%|▍         | 999/20019 [06:32<2:10:34,  2.43it/s]

[Epoch 0, Batch  1000] loss: 6.587


  5%|▍         | 1000/20019 [06:35<6:58:10,  1.32s/it]

****** Model checkpoint saved at epochs 1 ******


  5%|▌         | 1100/20019 [07:25<2:08:30,  2.45it/s]

[Epoch 0, Batch  1100] loss: 6.551


  6%|▌         | 1200/20019 [08:07<2:30:50,  2.08it/s]

[Epoch 0, Batch  1200] loss: 6.536


  6%|▋         | 1300/20019 [08:50<2:25:16,  2.15it/s]

[Epoch 0, Batch  1300] loss: 6.525


  7%|▋         | 1400/20019 [09:32<2:07:17,  2.44it/s]

[Epoch 0, Batch  1400] loss: 6.490


  7%|▋         | 1500/20019 [10:14<2:06:30,  2.44it/s]

[Epoch 0, Batch  1500] loss: 6.477


  8%|▊         | 1600/20019 [10:55<2:09:31,  2.37it/s]

[Epoch 0, Batch  1600] loss: 6.439


  8%|▊         | 1700/20019 [11:38<2:02:00,  2.50it/s]

[Epoch 0, Batch  1700] loss: 6.415


  9%|▉         | 1800/20019 [12:20<2:21:04,  2.15it/s]

[Epoch 0, Batch  1800] loss: 6.418


  9%|▉         | 1900/20019 [13:03<2:01:43,  2.48it/s]

[Epoch 0, Batch  1900] loss: 6.377


 10%|▉         | 1999/20019 [13:44<2:11:24,  2.29it/s]

[Epoch 0, Batch  2000] loss: 6.374


 10%|▉         | 2000/20019 [13:48<6:12:38,  1.24s/it]

****** Model checkpoint saved at epochs 1 ******


 10%|█         | 2100/20019 [14:32<2:15:19,  2.21it/s]

[Epoch 0, Batch  2100] loss: 6.317


 11%|█         | 2200/20019 [15:17<2:00:50,  2.46it/s]

[Epoch 0, Batch  2200] loss: 6.297


 11%|█▏        | 2300/20019 [15:59<2:13:27,  2.21it/s]

[Epoch 0, Batch  2300] loss: 6.298


 12%|█▏        | 2400/20019 [16:40<1:58:48,  2.47it/s]

[Epoch 0, Batch  2400] loss: 6.265


 12%|█▏        | 2500/20019 [17:22<2:02:40,  2.38it/s]

[Epoch 0, Batch  2500] loss: 6.268


 13%|█▎        | 2600/20019 [18:03<1:52:26,  2.58it/s]

[Epoch 0, Batch  2600] loss: 6.228


 13%|█▎        | 2700/20019 [18:44<1:53:47,  2.54it/s]

[Epoch 0, Batch  2700] loss: 6.237


 14%|█▍        | 2800/20019 [19:25<1:50:59,  2.59it/s]

[Epoch 0, Batch  2800] loss: 6.203


 14%|█▍        | 2900/20019 [20:05<1:58:56,  2.40it/s]

[Epoch 0, Batch  2900] loss: 6.175


 15%|█▍        | 2999/20019 [20:49<1:52:03,  2.53it/s]

[Epoch 0, Batch  3000] loss: 6.189


 15%|█▍        | 3000/20019 [20:52<6:04:15,  1.28s/it]

****** Model checkpoint saved at epochs 1 ******


 15%|█▌        | 3100/20019 [21:41<1:54:29,  2.46it/s]

[Epoch 0, Batch  3100] loss: 6.170


 16%|█▌        | 3200/20019 [22:21<1:48:33,  2.58it/s]

[Epoch 0, Batch  3200] loss: 6.131


 16%|█▋        | 3300/20019 [23:02<1:48:21,  2.57it/s]

[Epoch 0, Batch  3300] loss: 6.118


 17%|█▋        | 3400/20019 [23:44<1:53:03,  2.45it/s]

[Epoch 0, Batch  3400] loss: 6.087


 17%|█▋        | 3500/20019 [24:25<1:47:56,  2.55it/s]

[Epoch 0, Batch  3500] loss: 6.063


 18%|█▊        | 3600/20019 [25:05<1:46:45,  2.56it/s]

[Epoch 0, Batch  3600] loss: 6.055


 18%|█▊        | 3700/20019 [25:46<1:56:04,  2.34it/s]

[Epoch 0, Batch  3700] loss: 6.026


 19%|█▉        | 3800/20019 [26:26<1:41:44,  2.66it/s]

[Epoch 0, Batch  3800] loss: 6.030


 19%|█▉        | 3900/20019 [27:05<1:41:46,  2.64it/s]

[Epoch 0, Batch  3900] loss: 6.007


 20%|█▉        | 3999/20019 [27:45<1:42:18,  2.61it/s]

[Epoch 0, Batch  4000] loss: 5.973


 20%|█▉        | 4000/20019 [27:48<5:33:00,  1.25s/it]

****** Model checkpoint saved at epochs 1 ******


 20%|██        | 4100/20019 [28:33<1:48:16,  2.45it/s]

[Epoch 0, Batch  4100] loss: 5.958


 21%|██        | 4200/20019 [29:15<1:43:28,  2.55it/s]

[Epoch 0, Batch  4200] loss: 5.961


 21%|██▏       | 4300/20019 [29:55<1:37:52,  2.68it/s]

[Epoch 0, Batch  4300] loss: 5.907


 22%|██▏       | 4400/20019 [30:35<1:38:12,  2.65it/s]

[Epoch 0, Batch  4400] loss: 5.912


 22%|██▏       | 4500/20019 [31:14<1:37:03,  2.67it/s]

[Epoch 0, Batch  4500] loss: 5.908


 23%|██▎       | 4600/20019 [31:52<1:36:13,  2.67it/s]

[Epoch 0, Batch  4600] loss: 5.874


 23%|██▎       | 4700/20019 [32:32<1:37:52,  2.61it/s]

[Epoch 0, Batch  4700] loss: 5.865


 24%|██▍       | 4800/20019 [33:11<1:43:48,  2.44it/s]

[Epoch 0, Batch  4800] loss: 5.837


 24%|██▍       | 4900/20019 [33:50<1:34:36,  2.66it/s]

[Epoch 0, Batch  4900] loss: 5.809


 25%|██▍       | 4999/20019 [34:30<1:36:37,  2.59it/s]

[Epoch 0, Batch  5000] loss: 5.786


 25%|██▍       | 5000/20019 [34:33<5:20:27,  1.28s/it]

****** Model checkpoint saved at epochs 1 ******


 25%|██▌       | 5100/20019 [35:16<1:40:58,  2.46it/s]

[Epoch 0, Batch  5100] loss: 5.759


 26%|██▌       | 5200/20019 [35:59<1:41:28,  2.43it/s]

[Epoch 0, Batch  5200] loss: 5.791


 26%|██▋       | 5300/20019 [36:41<1:38:33,  2.49it/s]

[Epoch 0, Batch  5300] loss: 5.779


 27%|██▋       | 5400/20019 [37:25<1:39:44,  2.44it/s]

[Epoch 0, Batch  5400] loss: 5.759


 27%|██▋       | 5500/20019 [38:07<1:42:39,  2.36it/s]

[Epoch 0, Batch  5500] loss: 5.731


 28%|██▊       | 5600/20019 [38:52<1:45:16,  2.28it/s]

[Epoch 0, Batch  5600] loss: 5.704


 28%|██▊       | 5700/20019 [39:35<1:43:35,  2.30it/s]

[Epoch 0, Batch  5700] loss: 5.721


 29%|██▉       | 5800/20019 [40:20<1:52:16,  2.11it/s]

[Epoch 0, Batch  5800] loss: 5.702


 29%|██▉       | 5900/20019 [41:03<1:41:25,  2.32it/s]

[Epoch 0, Batch  5900] loss: 5.664


 30%|██▉       | 5999/20019 [41:46<1:33:49,  2.49it/s]

[Epoch 0, Batch  6000] loss: 5.663


 30%|██▉       | 6000/20019 [41:50<5:00:25,  1.29s/it]

****** Model checkpoint saved at epochs 1 ******


 30%|███       | 6100/20019 [42:36<1:36:55,  2.39it/s]

[Epoch 0, Batch  6100] loss: 5.631


 31%|███       | 6200/20019 [43:23<1:35:38,  2.41it/s]

[Epoch 0, Batch  6200] loss: 5.679


 31%|███▏      | 6300/20019 [44:06<1:34:11,  2.43it/s]

[Epoch 0, Batch  6300] loss: 5.666


 32%|███▏      | 6400/20019 [44:48<1:33:10,  2.44it/s]

[Epoch 0, Batch  6400] loss: 5.615


 32%|███▏      | 6500/20019 [45:30<1:37:43,  2.31it/s]

[Epoch 0, Batch  6500] loss: 5.602


 33%|███▎      | 6600/20019 [46:12<1:43:49,  2.15it/s]

[Epoch 0, Batch  6600] loss: 5.595


 33%|███▎      | 6700/20019 [46:54<1:28:03,  2.52it/s]

[Epoch 0, Batch  6700] loss: 5.592


 34%|███▍      | 6800/20019 [47:36<1:38:56,  2.23it/s]

[Epoch 0, Batch  6800] loss: 5.588


 34%|███▍      | 6900/20019 [48:17<1:27:42,  2.49it/s]

[Epoch 0, Batch  6900] loss: 5.560


 35%|███▍      | 6999/20019 [49:00<1:35:25,  2.27it/s]

[Epoch 0, Batch  7000] loss: 5.540


 35%|███▍      | 7000/20019 [49:03<5:08:42,  1.42s/it]

****** Model checkpoint saved at epochs 1 ******


 35%|███▌      | 7100/20019 [49:48<1:24:59,  2.53it/s]

[Epoch 0, Batch  7100] loss: 5.531


 36%|███▌      | 7200/20019 [50:33<1:24:08,  2.54it/s]

[Epoch 0, Batch  7200] loss: 5.563


 36%|███▋      | 7300/20019 [51:15<1:41:46,  2.08it/s]

[Epoch 0, Batch  7300] loss: 5.522


 37%|███▋      | 7400/20019 [51:56<1:19:25,  2.65it/s]

[Epoch 0, Batch  7400] loss: 5.488


 37%|███▋      | 7500/20019 [52:36<1:33:48,  2.22it/s]

[Epoch 0, Batch  7500] loss: 5.525


 38%|███▊      | 7600/20019 [53:16<1:18:52,  2.62it/s]

[Epoch 0, Batch  7600] loss: 5.527


 38%|███▊      | 7700/20019 [53:57<1:25:16,  2.41it/s]

[Epoch 0, Batch  7700] loss: 5.499


 39%|███▉      | 7800/20019 [54:40<1:20:22,  2.53it/s]

[Epoch 0, Batch  7800] loss: 5.471


 39%|███▉      | 7900/20019 [55:22<1:22:01,  2.46it/s]

[Epoch 0, Batch  7900] loss: 5.519


 40%|███▉      | 7999/20019 [56:08<1:35:22,  2.10it/s]

[Epoch 0, Batch  8000] loss: 5.463


 40%|███▉      | 8000/20019 [56:11<4:32:06,  1.36s/it]

****** Model checkpoint saved at epochs 1 ******


 40%|████      | 8100/20019 [57:06<1:36:13,  2.06it/s]

[Epoch 0, Batch  8100] loss: 5.492


 41%|████      | 8200/20019 [58:03<1:42:06,  1.93it/s]

[Epoch 0, Batch  8200] loss: 5.453


 41%|████▏     | 8300/20019 [58:55<1:52:03,  1.74it/s]

[Epoch 0, Batch  8300] loss: 5.441


 42%|████▏     | 8400/20019 [59:49<1:50:36,  1.75it/s]

[Epoch 0, Batch  8400] loss: 5.404


 42%|████▏     | 8500/20019 [1:00:40<1:42:16,  1.88it/s]

[Epoch 0, Batch  8500] loss: 5.435


 43%|████▎     | 8600/20019 [1:01:32<1:30:42,  2.10it/s]

[Epoch 0, Batch  8600] loss: 5.430


 43%|████▎     | 8700/20019 [1:02:24<1:37:13,  1.94it/s]

[Epoch 0, Batch  8700] loss: 5.396


 44%|████▍     | 8800/20019 [1:03:15<1:40:06,  1.87it/s]

[Epoch 0, Batch  8800] loss: 5.416


 44%|████▍     | 8900/20019 [1:04:07<1:40:25,  1.85it/s]

[Epoch 0, Batch  8900] loss: 5.405


 45%|████▍     | 8999/20019 [1:04:58<1:29:20,  2.06it/s]

[Epoch 0, Batch  9000] loss: 5.346


 45%|████▍     | 9000/20019 [1:05:01<4:08:42,  1.35s/it]

****** Model checkpoint saved at epochs 1 ******


 45%|████▌     | 9100/20019 [1:05:51<1:24:54,  2.14it/s]

[Epoch 0, Batch  9100] loss: 5.388


 46%|████▌     | 9200/20019 [1:06:49<1:29:25,  2.02it/s]

[Epoch 0, Batch  9200] loss: 5.359


 46%|████▋     | 9300/20019 [1:07:39<1:37:01,  1.84it/s]

[Epoch 0, Batch  9300] loss: 5.337


 47%|████▋     | 9400/20019 [1:08:29<1:19:37,  2.22it/s]

[Epoch 0, Batch  9400] loss: 5.309


 47%|████▋     | 9500/20019 [1:09:19<1:22:29,  2.13it/s]

[Epoch 0, Batch  9500] loss: 5.348


 48%|████▊     | 9600/20019 [1:10:07<1:23:54,  2.07it/s]

[Epoch 0, Batch  9600] loss: 5.332


 48%|████▊     | 9700/20019 [1:10:58<1:22:49,  2.08it/s]

[Epoch 0, Batch  9700] loss: 5.383


 49%|████▉     | 9800/20019 [1:11:46<1:21:21,  2.09it/s]

[Epoch 0, Batch  9800] loss: 5.345


 49%|████▉     | 9900/20019 [1:12:35<1:16:48,  2.20it/s]

[Epoch 0, Batch  9900] loss: 5.324


 50%|████▉     | 9999/20019 [1:13:23<1:20:37,  2.07it/s]

[Epoch 0, Batch 10000] loss: 5.308


 50%|████▉     | 10000/20019 [1:13:26<3:39:09,  1.31s/it]

****** Model checkpoint saved at epochs 1 ******


 50%|█████     | 10019/20019 [1:13:35<1:24:48,  1.97it/s]