In [None]:
'''Train ImageNet with PyTorch.'''
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from tqdm import tqdm
from ViT import ViT
import torch.backends.cudnn as cudnn

from torchvision.models import vit_b_16, ViT_B_16_Weights

#os.environ['CUDA_VISIBLE_DEVICES'] = '1'
cudnn.benchmark = True

# if torch.cuda.is_available():
#     device_ids = [0]
#     for device_id in device_ids:
#         torch.cuda.set_device(device_id)
# else:
#     device = 'cpu'

device0 = 'cuda:0' if torch.cuda.is_available() else 'cpu'
#print(device0)
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

temp_dir = 'temp'

# save temp pth
if not os.path.exists(temp_dir):
    os.makedirs(temp_dir)
    
train_directory = './imagenet/train'
test_directory = './imagenet/val'

transform_train = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.ImageFolder(root=train_directory, transform=transform_train)
testset = torchvision.datasets.ImageFolder(root=test_directory, transform=transform_test)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=450, shuffle=True, num_workers = 2, pin_memory=True)
#trainloader = tqdm(trainloader, total=len(trainloader))

testloader = torch.utils.data.DataLoader(
    testset, batch_size=450, shuffle=False,num_workers = 2, pin_memory=True)

# trainloader = tqdm(testloader, total=len(testloader))

def load_state_dict_ignore_mismatch(model, state_dict):
    model_dict = model.state_dict()
   
    matched_state_dict = {k: v for k, v in state_dict.items() if k in model_dict and v.shape == model_dict[k].shape}
    
    model_dict.update(matched_state_dict)

    model.load_state_dict(model_dict)

    return len(matched_state_dict), len(model_dict)



# load resnet101
net = ViT(num_classes=1000, image_size=224, patch_size=16, hidden_dim=768, num_heads=12, num_layers=12, mlp_dim=3072)
#net = ViT(num_classes=1000, img_size=224, patch_size=16, d_model=768, n_head=12, n_layers=12, d_mlp=3072)
import hashlib

def get_model_hash(model):
    md5 = hashlib.md5()
    for param in model.parameters():
        md5.update(param.data.cpu().numpy().tobytes())
    return md5.hexdigest()

# 加载预训练权重前的哈希值
initial_hash = get_model_hash(net)
print(f"Initial model hash: {initial_hash}")

pre = vit_b_16(weights = ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1)
pretrained_state_dict = pre.state_dict()
#net.load_state_dict(state)
#pretrained_state_dict = torch.load('pretrained_image1k.pth')
#net.load_state_dict(pretrained_state_dict)
matched_params, total_params = load_state_dict_ignore_mismatch(net, pretrained_state_dict)
print(f"Loaded {matched_params} out of {total_params} parameters")

loaded_hash = get_model_hash(net)
print(f"Loaded model hash: {loaded_hash}")

# state_dict = torch.load('temp/vit_epoch1.pth')
# new_state_dict = {}
# for k, v in state_dict.items():
#     if k.startswith("module."):
#         new_key = k[len("module."):]
#     else:
#         new_key = k
#     new_state_dict[new_key] = v
#net = nn.DataParallel(net, device_ids=device_ids)
if torch.cuda.device_count() > 1:
    #print(torch.cuda.device_count())
    net = nn.DataParallel(net)
    
net.to(device0)

criterion_1 = nn.CrossEntropyLoss().to(device0)

# define optimizer and loss function
optimizer_1 = optim.SGD([
    {'params': net.module.encoder.layers[:3].parameters()},
], lr=0.001, weight_decay=5e-4,momentum=0.9)  # update first layer

optimizer_2 = optim.SGD([
    {'params': net.module.encoder.layers[3:6].parameters()},
], lr=0.001, weight_decay=5e-4,momentum=0.9)  # update second layer

optimizer_3 = optim.SGD([
     {'params': net.module.encoder.layers[6:9].parameters()},
], lr=0.001, weight_decay=5e-4,momentum=0.9)  # update third layer

optimizer_4 = optim.SGD([
    {'params': net.module.encoder.layers[9:].parameters()},
    {'params': net.module.heads.parameters()},
], lr=0.001,weight_decay=5e-4,momentum=0.9)  # update fourth layer

# for param in net.module.mlp_head2.parameters():
#     param.requires_grad = False

# for param in net.module.mlp_head3.parameters():
#     param.requires_grad = False

# for param in net.module.mlp_head4.parameters():
#     param.requires_grad = False
    
# optimizer_1 = optim.SGD([
#     {'params': net.module.transformer_blocks[:3].parameters()},
#     {'params': net.module.mlp_head1.parameters()}
# ], lr=0.001, momentum=0.9, weight_decay=5e-4)

# optimizer_2 = optim.SGD([
#     {'params': net.module.transformer_blocks[3:6].parameters()},
#      {'params': net.module.mlp_head2.parameters()}
# ], lr=0.001, momentum=0.9, weight_decay=5e-4)

# optimizer_3 = optim.SGD([
#     {'params': net.module.transformer_blocks[6:9].parameters()},
#      {'params': net.module.mlp_head3.parameters()}
# ], lr=0.001, momentum=0.9, weight_decay=5e-4)

# optimizer_4 = optim.SGD([
#     {'params': net.module.transformer_blocks[9:].parameters()},
#      {'params': net.module.mlp_head4.parameters()}
# ], lr=0.001, momentum=0.9, weight_decay=5e-4)

scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_1, T_max=200)
scheduler_2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_2, T_max=200)
scheduler_3 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_3, T_max=200)
scheduler_4 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_4, T_max=200)



train_losses = []
test_losses = []

log_file = open("vit_training_log.txt", "w")

# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for i, (inputs, targets) in tqdm(enumerate(trainloader)):
        inputs, targets = inputs.to(device0), targets.to(device0)
        
        if targets.max() >= 1000 or targets.min() < 0:
            raise ValueError(f"Targets out of range: min {targets.min()}, max {targets.max()}")

        optimizer_1.zero_grad()
        optimizer_2.zero_grad()
        optimizer_3.zero_grad()
        optimizer_4.zero_grad()

        outputs, extra_1, extra_2, extra_3 = net(inputs)
        loss_1 = criterion_1(extra_1, targets)
        #loss_1.backward(retain_graph=True)
        
        loss_2 = criterion_1(extra_2, targets)
        #loss_2.backward(retain_graph=True)
        
        loss_3 = criterion_1(extra_3, targets)
        #loss_3.backward(retain_graph=True)
        
        loss_4 = criterion_1(outputs, targets)
        #loss_4.backward(retain_graph=True)
        
        #loss = 0.5*loss_1 + 0.5*loss_2 + 0.5*loss_3 + loss_4
        loss = 0.5 * (loss_1 + loss_2 + loss_3) + loss_4
        loss.backward()

        optimizer_1.step()
        optimizer_2.step()
        optimizer_3.step()
        optimizer_4.step()

        train_loss += loss_4.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        if i % 100 == 0:
            print(f'Step [{i}/{len(trainloader)}] | Loss: {loss.item():.4f}')
            log_file.write(f'Step [{i}/{len(trainloader)}] | Loss: {loss.item():.4f}\n')
            log_file.flush()
            print(f'Loss1:{loss_1.item():.4f}')
            print(f'Loss2:{loss_2.item():.4f}')
            print(f'Loss3:{loss_3.item():.4f}')
            print(f'Loss4:{loss_4.item():.4f}')
            log_file.write(f'loss1:{loss_1.item():.4f},loss2:{loss_2.item():.4f},loss3:{loss_3.item():.4f},loss4:{loss_4.item():.4f}\n ')
            log_file.flush()
    print('Train Loss: %.3f | Acc: %.3f%% (%d/%d)'
          % (train_loss / len(trainloader), 100. * correct / total, correct, total))
    log_file.write(
        f"Epoch {epoch}: Train Loss = {train_loss / len(trainloader):.3f}, Accuracy = {100. * correct / total:.3f}%")
    log_file.flush()
    train_losses.append(train_loss / len(trainloader))

def test(epoch):
    net.eval()
    test_loss = 0
    correct_top1 = 0
    correct_top5 = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in tqdm(testloader):
            inputs, targets = inputs.to(device0), targets.to(device0)
            outputs, _, _, _ = net(inputs)
            loss = criterion_1(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.topk(5, 1)

            total += targets.size(0)
            correct_top1 += predicted[:, 0].eq(targets).sum().item()

            for i in range(targets.size(0)):
                if targets[i] in predicted[i]:
                    correct_top5 += 1

    top1_accuracy = 100. * correct_top1 / total
    top5_accuracy = 100. * correct_top5 / total

    print('Test Loss: %.3f | Top-1 Acc: %.3f%% | Top-5 Acc: %.3f%% (%d/%d)'
          % (test_loss / len(testloader), top1_accuracy, top5_accuracy, correct_top1, total))
    log_file.write(
        f" Test Loss = {test_loss / len(testloader):.3f}, Top-1 Accuracy = {top1_accuracy:.3f}%, Top-5 Accuracy = {top5_accuracy:.3f}%\n")
    log_file.flush()
    test_losses.append(test_loss / len(testloader))

for epoch in range(start_epoch, start_epoch + 100):
    train(epoch)
    test(epoch)
  
    check_path = os.path.join('temp', f'vit_epoch{epoch + 1}.pth')
    torch.save(net.state_dict(), check_path)
    
    scheduler_1.step()
    scheduler_2.step()
    scheduler_3.step()
    scheduler_4.step()

log_file.close()
# Save the trained weights
save_path = 'vit_4out_imagenet.pth'
torch.save(net.state_dict(), save_path)
print("Trained weights saved to:", save_path)

Initial model hash: 299e89b467c34c6ce28cc70ee20b2d8c
Loaded 151 out of 152 parameters
Loaded model hash: b44732fe4f0ea2b095790e3945bab2a5

Epoch: 0


1it [00:20, 20.40s/it]

Step [0/2848] | Loss: 11.9238
Loss1:6.9044
Loss2:6.9033
Loss3:6.9020
Loss4:1.5690


101it [22:36, 13.41s/it]

Step [100/2848] | Loss: 10.9545
Loss1:6.8999
Loss2:6.8944
Loss3:6.7914
Loss4:0.6617


201it [44:58, 13.43s/it]

Step [200/2848] | Loss: 10.9404
Loss1:6.9000
Loss2:6.8863
Loss3:6.7082
Loss4:0.6931


301it [1:07:19, 13.41s/it]

Step [300/2848] | Loss: 10.7286
Loss1:6.8868
Loss2:6.8709
Loss3:6.6278
Loss4:0.5359


401it [1:29:40, 13.40s/it]

Step [400/2848] | Loss: 10.7403
Loss1:6.8804
Loss2:6.8515
Loss3:6.5648
Loss4:0.5920


501it [1:52:01, 13.40s/it]

Step [500/2848] | Loss: 10.6186
Loss1:6.8641
Loss2:6.7898
Loss3:6.4976
Loss4:0.5428


601it [2:14:22, 13.43s/it]

Step [600/2848] | Loss: 10.4813
Loss1:6.8501
Loss2:6.7648
Loss3:6.4893
Loss4:0.4292


701it [2:36:43, 13.43s/it]

Step [700/2848] | Loss: 10.5826
Loss1:6.8338
Loss2:6.7618
Loss3:6.4500
Loss4:0.5598


801it [2:59:04, 13.40s/it]

Step [800/2848] | Loss: 10.3482
Loss1:6.7890
Loss2:6.6889
Loss3:6.3379
Loss4:0.4403


901it [3:21:25, 13.40s/it]

Step [900/2848] | Loss: 10.3909
Loss1:6.7424
Loss2:6.6307
Loss3:6.2897
Loss4:0.5595


1001it [3:43:46, 13.40s/it]

Step [1000/2848] | Loss: 10.2461
Loss1:6.6825
Loss2:6.5882
Loss3:6.2402
Loss4:0.4907


1101it [4:06:05, 13.40s/it]

Step [1100/2848] | Loss: 10.1711
Loss1:6.6640
Loss2:6.6069
Loss3:6.1359
Loss4:0.4677


1201it [4:28:24, 13.41s/it]

Step [1200/2848] | Loss: 10.1658
Loss1:6.5994
Loss2:6.5808
Loss3:6.2132
Loss4:0.4692


1301it [4:50:46, 13.42s/it]

Step [1300/2848] | Loss: 10.0338
Loss1:6.5288
Loss2:6.5354
Loss3:6.0708
Loss4:0.4663


1401it [5:13:06, 13.40s/it]

Step [1400/2848] | Loss: 9.9754
Loss1:6.5017
Loss2:6.5007
Loss3:6.1130
Loss4:0.4177


1501it [5:35:26, 13.40s/it]

Step [1500/2848] | Loss: 9.8663
Loss1:6.4444
Loss2:6.4346
Loss3:6.0486
Loss4:0.4025


1601it [5:57:49, 13.40s/it]

Step [1600/2848] | Loss: 9.8665
Loss1:6.4678
Loss2:6.4696
Loss3:5.9979
Loss4:0.3988


1701it [6:20:09, 13.40s/it]

Step [1700/2848] | Loss: 9.7055
Loss1:6.3624
Loss2:6.3757
Loss3:5.9126
Loss4:0.3802


1801it [6:42:31, 13.40s/it]

Step [1800/2848] | Loss: 9.7294
Loss1:6.3317
Loss2:6.3605
Loss3:5.9076
Loss4:0.4295


1901it [7:04:50, 13.40s/it]

Step [1900/2848] | Loss: 9.7058
Loss1:6.3445
Loss2:6.3802
Loss3:5.9341
Loss4:0.3763


2001it [7:27:10, 13.40s/it]

Step [2000/2848] | Loss: 9.6905
Loss1:6.3244
Loss2:6.3131
Loss3:5.9611
Loss4:0.3911


2101it [7:49:31, 13.40s/it]

Step [2100/2848] | Loss: 9.5418
Loss1:6.2443
Loss2:6.2489
Loss3:5.8821
Loss4:0.3542


2201it [8:11:52, 13.43s/it]

Step [2200/2848] | Loss: 9.5169
Loss1:6.1998
Loss2:6.1885
Loss3:5.8114
Loss4:0.4171


2301it [8:34:13, 13.40s/it]

Step [2300/2848] | Loss: 9.4410
Loss1:6.1402
Loss2:6.1748
Loss3:5.8063
Loss4:0.3803


2401it [8:56:32, 13.40s/it]

Step [2400/2848] | Loss: 9.5981
Loss1:6.2001
Loss2:6.2215
Loss3:5.9378
Loss4:0.4184


2501it [9:18:53, 13.40s/it]

Step [2500/2848] | Loss: 9.5083
Loss1:6.1444
Loss2:6.1157
Loss3:5.8535
Loss4:0.4514


2601it [9:41:12, 13.39s/it]

Step [2600/2848] | Loss: 9.4927
Loss1:6.1195
Loss2:6.1224
Loss3:5.7495
Loss4:0.4969


2701it [10:03:31, 13.40s/it]

Step [2700/2848] | Loss: 9.1964
Loss1:6.0195
Loss2:5.9948
Loss3:5.7043
Loss4:0.3371


2801it [10:25:52, 13.44s/it]

Step [2800/2848] | Loss: 9.4525
Loss1:6.1054
Loss2:6.0881
Loss3:5.8568
Loss4:0.4274


2848it [10:36:09, 13.40s/it]


Train Loss: 0.470 | Acc: 86.347% (1106251/1281167)


100%|██████████| 112/112 [12:45<00:00,  6.83s/it]


Test Loss: 0.812 | Top-1 Acc: 78.710% | Top-5 Acc: 94.786% (39355/50000)

Epoch: 1


1it [00:17, 17.98s/it]

Step [0/2848] | Loss: 9.2329
Loss1:6.0200
Loss2:5.9800
Loss3:5.7789
Loss4:0.3435


101it [22:38, 13.43s/it]

Step [100/2848] | Loss: 9.1841
Loss1:6.0295
Loss2:5.9882
Loss3:5.7338
Loss4:0.3083


201it [44:59, 13.40s/it]

Step [200/2848] | Loss: 9.3273
Loss1:6.0765
Loss2:6.0587
Loss3:5.7849
Loss4:0.3672


301it [1:07:21, 13.44s/it]

Step [300/2848] | Loss: 9.0367
Loss1:5.8914
Loss2:5.9042
Loss3:5.6450
Loss4:0.3164


401it [1:29:42, 13.41s/it]

Step [400/2848] | Loss: 8.9938
Loss1:5.9107
Loss2:5.8949
Loss3:5.6439
Loss4:0.2690


501it [1:52:03, 13.40s/it]

Step [500/2848] | Loss: 8.9272
Loss1:5.8282
Loss2:5.8191
Loss3:5.5674
Loss4:0.3199


601it [2:14:23, 13.41s/it]

Step [600/2848] | Loss: 9.1336
Loss1:5.9885
Loss2:5.9604
Loss3:5.7969
Loss4:0.2606


701it [2:36:45, 13.40s/it]

Step [700/2848] | Loss: 9.1590
Loss1:5.9223
Loss2:5.9180
Loss3:5.7551
Loss4:0.3613


801it [2:59:05, 13.41s/it]

Step [800/2848] | Loss: 8.9905
Loss1:5.8699
Loss2:5.8731
Loss3:5.6440
Loss4:0.2970


901it [3:21:26, 13.42s/it]

Step [900/2848] | Loss: 8.9468
Loss1:5.8270
Loss2:5.8144
Loss3:5.5872
Loss4:0.3325


1001it [3:43:46, 13.40s/it]

Step [1000/2848] | Loss: 9.0857
Loss1:5.9222
Loss2:5.9120
Loss3:5.7267
Loss4:0.3052


1101it [4:06:07, 13.41s/it]

Step [1100/2848] | Loss: 8.9425
Loss1:5.8745
Loss2:5.8574
Loss3:5.6136
Loss4:0.2698


1201it [4:28:28, 13.48s/it]

Step [1200/2848] | Loss: 8.7615
Loss1:5.7205
Loss2:5.7041
Loss3:5.5284
Loss4:0.2850


1301it [4:50:49, 13.45s/it]

Step [1300/2848] | Loss: 8.9160
Loss1:5.8414
Loss2:5.8009
Loss3:5.6450
Loss4:0.2723


1401it [5:13:10, 13.40s/it]

Step [1400/2848] | Loss: 8.8690
Loss1:5.7900
Loss2:5.7857
Loss3:5.6126
Loss4:0.2748


1501it [5:35:32, 13.41s/it]

Step [1500/2848] | Loss: 9.1066
Loss1:5.8868
Loss2:5.8936
Loss3:5.6694
Loss4:0.3817


1601it [5:57:54, 13.40s/it]

Step [1600/2848] | Loss: 8.9197
Loss1:5.7836
Loss2:5.8016
Loss3:5.5775
Loss4:0.3383


1701it [6:20:15, 13.40s/it]

Step [1700/2848] | Loss: 8.9429
Loss1:5.7966
Loss2:5.8182
Loss3:5.6474
Loss4:0.3118


1801it [6:42:37, 13.40s/it]

Step [1800/2848] | Loss: 8.6843
Loss1:5.6502
Loss2:5.6182
Loss3:5.4213
Loss4:0.3394


1901it [7:04:58, 13.40s/it]

Step [1900/2848] | Loss: 8.7556
Loss1:5.6898
Loss2:5.7023
Loss3:5.5407
Loss4:0.2893


2001it [7:27:18, 13.40s/it]

Step [2000/2848] | Loss: 9.0694
Loss1:5.8032
Loss2:5.7997
Loss3:5.6533
Loss4:0.4413


2101it [7:49:40, 13.40s/it]

Step [2100/2848] | Loss: 8.6789
Loss1:5.6101
Loss2:5.6265
Loss3:5.4579
Loss4:0.3316


2201it [8:12:02, 13.43s/it]

Step [2200/2848] | Loss: 8.6266
Loss1:5.5879
Loss2:5.5701
Loss3:5.4286
Loss4:0.3333


2301it [8:34:21, 13.40s/it]

Step [2300/2848] | Loss: 8.7869
Loss1:5.7211
Loss2:5.6909
Loss3:5.6017
Loss4:0.2801


2401it [8:56:43, 13.40s/it]

Step [2400/2848] | Loss: 8.8083
Loss1:5.7389
Loss2:5.6953
Loss3:5.5794
Loss4:0.3016


2501it [9:19:03, 13.40s/it]

Step [2500/2848] | Loss: 8.7354
Loss1:5.6942
Loss2:5.6853
Loss3:5.5208
Loss4:0.2852


2601it [9:41:23, 13.40s/it]

Step [2600/2848] | Loss: 8.6234
Loss1:5.6355
Loss2:5.5973
Loss3:5.4468
Loss4:0.2837


2701it [10:03:47, 13.45s/it]

Step [2700/2848] | Loss: 8.8029
Loss1:5.7053
Loss2:5.6946
Loss3:5.5297
Loss4:0.3381


2801it [10:26:08, 13.40s/it]

Step [2800/2848] | Loss: 8.7529
Loss1:5.6549
Loss2:5.6101
Loss3:5.5048
Loss4:0.3680


2848it [10:36:26, 13.41s/it]


Train Loss: 0.315 | Acc: 90.704% (1162068/1281167)


100%|██████████| 112/112 [12:38<00:00,  6.77s/it]

Test Loss: 0.814 | Top-1 Acc: 79.086% | Top-5 Acc: 94.900% (39543/50000)

Epoch: 2



1it [00:18, 18.45s/it]

Step [0/2848] | Loss: 8.6059
Loss1:5.6612
Loss2:5.6390
Loss3:5.5059
Loss4:0.2028


101it [22:39, 13.40s/it]

Step [100/2848] | Loss: 8.5396
Loss1:5.6071
Loss2:5.5857
Loss3:5.4370
Loss4:0.2247


201it [44:59, 13.42s/it]

Step [200/2848] | Loss: 8.5904
Loss1:5.6222
Loss2:5.6420
Loss3:5.4537
Loss4:0.2314


301it [1:07:21, 13.41s/it]

Step [300/2848] | Loss: 8.5610
Loss1:5.6035
Loss2:5.5813
Loss3:5.4541
Loss4:0.2416


401it [1:29:40, 13.41s/it]

Step [400/2848] | Loss: 8.5991
Loss1:5.6403
Loss2:5.6515
Loss3:5.4583
Loss4:0.2240


501it [1:52:01, 13.40s/it]

Step [500/2848] | Loss: 8.5309
Loss1:5.6269
Loss2:5.5751
Loss3:5.4454
Loss4:0.2072


601it [2:14:21, 13.40s/it]

Step [600/2848] | Loss: 8.5137
Loss1:5.6026
Loss2:5.5643
Loss3:5.4354
Loss4:0.2126


701it [2:36:38, 13.40s/it]

Step [700/2848] | Loss: 8.4877
Loss1:5.5627
Loss2:5.5328
Loss3:5.3805
Loss4:0.2497


801it [2:58:59, 13.41s/it]

Step [800/2848] | Loss: 8.5449
Loss1:5.5729
Loss2:5.5266
Loss3:5.3953
Loss4:0.2976


901it [3:21:20, 13.40s/it]

Step [900/2848] | Loss: 8.3826
Loss1:5.4606
Loss2:5.4615
Loss3:5.3371
Loss4:0.2530


1001it [3:43:39, 13.40s/it]

Step [1000/2848] | Loss: 8.3258
Loss1:5.4602
Loss2:5.4165
Loss3:5.3195
Loss4:0.2277


1101it [4:05:58, 13.45s/it]

Step [1100/2848] | Loss: 8.6129
Loss1:5.6140
Loss2:5.6037
Loss3:5.4915
Loss4:0.2583


1201it [4:28:19, 13.41s/it]

Step [1200/2848] | Loss: 8.4030
Loss1:5.5062
Loss2:5.4687
Loss3:5.3675
Loss4:0.2319


1301it [4:50:39, 13.38s/it]

Step [1300/2848] | Loss: 8.5591
Loss1:5.6090
Loss2:5.5865
Loss3:5.4825
Loss4:0.2201


1401it [5:13:01, 13.40s/it]

Step [1400/2848] | Loss: 8.6184
Loss1:5.6099
Loss2:5.5718
Loss3:5.4404
Loss4:0.3073


1501it [5:35:21, 13.40s/it]

Step [1500/2848] | Loss: 8.3143
Loss1:5.4875
Loss2:5.4123
Loss3:5.3001
Loss4:0.2144


1601it [5:57:40, 13.37s/it]

Step [1600/2848] | Loss: 8.4927
Loss1:5.5928
Loss2:5.5202
Loss3:5.4178
Loss4:0.2273


1701it [6:19:59, 13.40s/it]

Step [1700/2848] | Loss: 8.6165
Loss1:5.6325
Loss2:5.6050
Loss3:5.4976
Loss4:0.2490


1801it [6:42:21, 13.40s/it]

Step [1800/2848] | Loss: 8.2916
Loss1:5.4460
Loss2:5.4203
Loss3:5.3157
Loss4:0.2007


1901it [7:04:43, 13.41s/it]

Step [1900/2848] | Loss: 8.3372
Loss1:5.4212
Loss2:5.4142
Loss3:5.3135
Loss4:0.2627


2001it [7:27:03, 13.40s/it]

Step [2000/2848] | Loss: 8.4984
Loss1:5.5698
Loss2:5.5495
Loss3:5.4259
Loss4:0.2258


2101it [7:49:24, 13.40s/it]

Step [2100/2848] | Loss: 8.3045
Loss1:5.4206
Loss2:5.4107
Loss3:5.3147
Loss4:0.2315


2201it [8:11:45, 13.40s/it]

Step [2200/2848] | Loss: 8.5942
Loss1:5.5975
Loss2:5.5510
Loss3:5.5066
Loss4:0.2667


2301it [8:34:05, 13.40s/it]

Step [2300/2848] | Loss: 8.3398
Loss1:5.4160
Loss2:5.4397
Loss3:5.3191
Loss4:0.2524


2401it [8:56:26, 13.40s/it]

Step [2400/2848] | Loss: 8.3801
Loss1:5.4196
Loss2:5.4639
Loss3:5.3424
Loss4:0.2672


2501it [9:18:48, 13.42s/it]

Step [2500/2848] | Loss: 8.2925
Loss1:5.4472
Loss2:5.3977
Loss3:5.2880
Loss4:0.2260


2601it [9:41:11, 13.40s/it]

Step [2600/2848] | Loss: 8.4643
Loss1:5.4914
Loss2:5.5095
Loss3:5.3822
Loss4:0.2727


2701it [10:03:32, 13.40s/it]

Step [2700/2848] | Loss: 8.5075
Loss1:5.5384
Loss2:5.4939
Loss3:5.4453
Loss4:0.2687


2801it [10:25:54, 13.40s/it]

Step [2800/2848] | Loss: 8.3968
Loss1:5.5026
Loss2:5.4724
Loss3:5.3388
Loss4:0.2399


2848it [10:36:12, 13.40s/it]


Train Loss: 0.240 | Acc: 92.927% (1190554/1281167)


100%|██████████| 112/112 [12:42<00:00,  6.81s/it]

Test Loss: 0.823 | Top-1 Acc: 79.080% | Top-5 Acc: 94.894% (39540/50000)

Epoch: 3



1it [00:18, 18.21s/it]

Step [0/2848] | Loss: 7.9808
Loss1:5.2469
Loss2:5.2280
Loss3:5.1431
Loss4:0.1718


101it [22:40, 13.40s/it]

Step [100/2848] | Loss: 8.4083
Loss1:5.5561
Loss2:5.5265
Loss3:5.4290
Loss4:0.1525


201it [45:00, 13.40s/it]

Step [200/2848] | Loss: 8.4433
Loss1:5.5281
Loss2:5.5420
Loss3:5.4337
Loss4:0.1914


301it [1:07:21, 13.40s/it]

Step [300/2848] | Loss: 8.4145
Loss1:5.5299
Loss2:5.5034
Loss3:5.4210
Loss4:0.1873


401it [1:29:41, 13.40s/it]

Step [400/2848] | Loss: 8.3935
Loss1:5.5351
Loss2:5.4717
Loss3:5.4202
Loss4:0.1800


501it [1:52:02, 13.40s/it]

Step [500/2848] | Loss: 8.3789
Loss1:5.5463
Loss2:5.4858
Loss3:5.3668
Loss4:0.1794


601it [2:14:23, 13.43s/it]

Step [600/2848] | Loss: 8.3081
Loss1:5.4566
Loss2:5.4241
Loss3:5.3103
Loss4:0.2126


701it [2:36:46, 13.40s/it]

Step [700/2848] | Loss: 8.2538
Loss1:5.4605
Loss2:5.4424
Loss3:5.2897
Loss4:0.1575


801it [2:59:07, 13.41s/it]

Step [800/2848] | Loss: 8.4675
Loss1:5.5566
Loss2:5.5153
Loss3:5.4242
Loss4:0.2195


901it [3:21:28, 13.44s/it]

Step [900/2848] | Loss: 8.3915
Loss1:5.5385
Loss2:5.4723
Loss3:5.3736
Loss4:0.1994


1001it [3:43:49, 13.40s/it]

Step [1000/2848] | Loss: 8.1425
Loss1:5.3373
Loss2:5.3470
Loss3:5.2034
Loss4:0.1987


1101it [4:06:09, 13.42s/it]

Step [1100/2848] | Loss: 8.2908
Loss1:5.4567
Loss2:5.4334
Loss3:5.3048
Loss4:0.1933


1201it [4:28:30, 13.40s/it]

Step [1200/2848] | Loss: 8.2492
Loss1:5.4556
Loss2:5.4423
Loss3:5.3297
Loss4:0.1354


1301it [4:50:51, 13.45s/it]

Step [1300/2848] | Loss: 8.1124
Loss1:5.3251
Loss2:5.3218
Loss3:5.2115
Loss4:0.1832


1401it [5:13:13, 13.51s/it]

Step [1400/2848] | Loss: 8.0220
Loss1:5.2524
Loss2:5.2198
Loss3:5.2049
Loss4:0.1834


1501it [5:35:35, 13.42s/it]

Step [1500/2848] | Loss: 8.1824
Loss1:5.4369
Loss2:5.3628
Loss3:5.3017
Loss4:0.1317


1592it [5:55:54, 13.40s/it]