In [None]:
import os
import math
import time
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision, torchvision.transforms as T


# Utilities
def count_params(model):
    """Count the scalar parameters of ``model``.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors
            (e.g. an ``nn.Module``).

    Returns:
        ``(total, trainable)`` where ``total`` is the number of elements over
        all parameters and ``trainable`` counts only those whose
        ``requires_grad`` flag is set.
    """
    total = 0
    trainable = 0
    # Single pass: every parameter contributes to the total, and additionally
    # to the trainable count when it participates in gradient updates.
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable

def try_compute_flops(model, input_size=(1,3,32,32)):
    """Best-effort FLOPs estimate for a single forward pass of ``model``.

    Uses the optional ``thop`` package. Profiling runs on a deep copy of the
    model so the caller's model is left exactly as it was: its train/eval
    mode is not switched, and any hooks or counters the profiler may attach
    never reach the real module or its ``state_dict``.

    Args:
        model: The module to profile. It must have at least one parameter,
            which is used to discover the device.
        input_size: Shape of the random dummy input fed to the model.

    Returns:
        ``2 * MACs`` as reported by ``thop.profile``, or ``None`` when
        ``thop`` is not installed or profiling fails for any reason.
    """
    try:
        from thop import profile
    except ImportError:
        # Optional dependency missing: FLOPs are simply not reported.
        return None

    import copy

    try:
        device = next(model.parameters()).device
        # Profile a throw-away copy in eval mode instead of mutating `model`.
        probe = copy.deepcopy(model).eval()
        dummy = torch.randn(*input_size).to(device)
        with torch.no_grad():
            macs, _ = profile(probe, inputs=(dummy,), verbose=False)
    except Exception:
        # Deliberately best-effort: a profiling failure must never abort the
        # experiment, so the caller just gets "unknown".
        return None
    # One multiply-accumulate is counted as two floating point operations.
    return 2 * macs


# Model building blocks
class DWConv(nn.Module):
    """Depthwise conv -> BatchNorm -> ReLU applied to a token sequence.

    The tokens are temporarily laid out as a ``[B, C, H, W]`` feature map so
    that a per-channel (``groups=dim``) convolution can be applied, then
    flattened back into ``[B, N, C]`` tokens.

    Args:
        dim: Number of channels ``C``.
        kernel: Convolution kernel size; padding is ``kernel // 2``.
    """

    def __init__(self, dim, kernel=3):
        super().__init__()
        depthwise = nn.Conv2d(
            dim,
            dim,
            kernel_size=kernel,
            padding=kernel // 2,
            groups=dim,
            bias=False,
        )
        self.op = nn.Sequential(depthwise, nn.BatchNorm2d(dim), nn.ReLU(inplace=True))

    def forward(self, x, H, W):
        """Apply the depthwise block to tokens ``x`` of shape ``[B, N, C]``.

        ``H`` and ``W`` give the spatial grid the ``N`` tokens are reshaped
        into before the convolution.
        """
        batch, _, channels = x.shape
        # [B, N, C] -> [B, C, N] -> [B, C, H, W]
        feature_map = x.transpose(1, 2).reshape(batch, channels, H, W)
        feature_map = self.op(feature_map)
        # [B, C, H', W'] -> [B, C, N'] -> [B, N', C]
        tokens = feature_map.reshape(batch, channels, -1)
        return tokens.transpose(1, 2)

class FFN(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the intermediate layer.
        drop: Dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, drop=0.0):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),  # expand
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_dim, dim),  # project back
            nn.Dropout(drop),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the network; the last dimension must be ``dim``."""
        projected = self.net(x)
        return projected

class CascadedGroupAttention(nn.Module):
    """Multi-head attention where each head sees its own channel slice.

    The ``dim`` input channels are split into ``num_heads`` equal slices.
    Head ``j`` attends over slice ``j``; for ``j > 0`` the previous head's
    output is added to the slice first (the "cascade"). Queries and keys are
    projected to a reduced width ``qk_dim``; values keep ``head_dim``. Each
    head's queries are passed through a depthwise conv block (``q_dw``)
    before the attention scores are computed.

    NOTE: earlier revisions built ``q_dw`` but never applied it (the queries
    were only reshaped to a feature map and straight back), leaving those
    parameters unused. Applying it changes the numerics of this module;
    checkpoints produced before this change hold ``q_dw`` weights that were
    never part of the forward pass.

    Args:
        dim: Total channel count; must be divisible by ``num_heads``.
        num_heads: Number of heads / channel slices.
        qk_ratio: ``qk_dim = max(1, int(head_dim * qk_ratio))``.
        kv_expand_v: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, dim, num_heads=4, qk_ratio=0.5, kv_expand_v=False):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.h = num_heads
        self.head_dim = dim // num_heads

        # Reduced query/key width, never smaller than one channel.
        self.qk_dim = max(1, int(self.head_dim * qk_ratio))

        # One independent projection per head (no weight sharing across heads).
        self.q_projs = nn.ModuleList([nn.Linear(self.head_dim, self.qk_dim, bias=False) for _ in range(self.h)])
        self.k_projs = nn.ModuleList([nn.Linear(self.head_dim, self.qk_dim, bias=False) for _ in range(self.h)])
        self.v_projs = nn.ModuleList([nn.Linear(self.head_dim, self.head_dim, bias=False) for _ in range(self.h)])
        self.out_proj = nn.Linear(dim, dim)

        # Per-head depthwise conv block applied to the queries.
        self.q_dw = nn.ModuleList([DWConv(self.qk_dim) for _ in range(self.h)])

    def forward(self, x, H, W):
        """Attend over tokens ``x`` of shape ``[B, N, C]`` with ``N == H * W``.

        Returns a tensor of shape ``[B, N, C]``.
        """
        B, N, C = x.shape

        # [B, N, C] -> [h, B, N, head_dim]: one channel slice per head.
        x_split = x.reshape(B, N, self.h, self.head_dim).permute(2, 0, 1, 3)
        # qk_dim >= 1 by construction, so the scale is always well defined.
        scale = math.sqrt(float(self.qk_dim))

        head_outputs = []
        prev_out = None
        for j in range(self.h):
            xj = x_split[j]
            if prev_out is not None:
                # Cascade: feed the previous head's result into this head.
                xj = xj + prev_out

            q = self.q_projs[j](xj)   # [B, N, qk_dim]
            k = self.k_projs[j](xj)   # [B, N, qk_dim]
            v = self.v_projs[j](xj)   # [B, N, head_dim]

            # Depthwise conv over the H x W token grid; shape is preserved
            # for the default 3x3 kernel with padding 1.
            q = self.q_dw[j](q, H, W)  # [B, N, qk_dim]

            scores = torch.matmul(q, k.transpose(-2, -1)) / scale  # [B, N, N]
            attn = torch.softmax(scores, dim=-1)
            out_j = torch.matmul(attn, v)  # [B, N, head_dim]

            head_outputs.append(out_j)
            prev_out = out_j

        out = torch.cat(head_outputs, dim=-1)  # [B, N, C]
        return self.out_proj(out)

class SandwichBlock(nn.Module):
    """Pre-norm residual block: FFN -> depthwise conv -> attention -> FFN.

    Each of the four sub-layers is preceded by its own ``LayerNorm`` and
    wrapped in a residual connection.

    Args:
        dim: Token feature size.
        num_heads: Heads used by the attention sub-layer.
        mlp_ratio: Hidden width of both FFNs is ``int(dim * mlp_ratio)``.
        drop: Dropout probability inside the FFNs.
        qk_ratio: Query/key width ratio forwarded to the attention sub-layer.
    """

    def __init__(self, dim, num_heads=4, mlp_ratio=2.0, drop=0.0, qk_ratio=0.5):
        super().__init__()
        ffn_width = int(dim * mlp_ratio)
        self.ffn1 = FFN(dim, ffn_width, drop)
        self.dw1 = DWConv(dim)
        self.attn = CascadedGroupAttention(dim, num_heads=num_heads, qk_ratio=qk_ratio)
        # NOTE(review): dw2 is registered (and counted in the parameter total)
        # but forward() never calls it — confirm whether that is intended.
        self.dw2 = DWConv(dim)
        self.ffn2 = FFN(dim, ffn_width, drop)
        # norm1..norm4: one LayerNorm in front of each residual sub-layer.
        for index in range(1, 5):
            setattr(self, f"norm{index}", nn.LayerNorm(dim))

    def forward(self, x, H, W):
        """Transform tokens ``x`` of shape ``[B, N, C]`` laid out on an ``H x W`` grid."""
        x = x + self.ffn1(self.norm1(x))
        x = x + self.dw1(self.norm2(x), H, W)
        x = x + self.attn(self.norm3(x), H, W)
        return x + self.ffn2(self.norm4(x))

# -------------------------
# TinyEfficientViT 
# -------------------------
class TinyEfficientViT(nn.Module):
    """Small multi-stage vision transformer classifier.

    Pipeline: a two-conv stem (the second conv has stride 2), then one stage
    per entry of ``stage_channels``. Every stage after the first starts with
    a stride-2 conv that changes the channel count, followed by a 1x1
    projection; each stage then runs its ``SandwichBlock``s on the flattened
    tokens. Global average pooling and a linear layer produce the logits.

    Args:
        img_size: Stored on the instance; not otherwise used here.
        num_classes: Size of the output logits.
        stage_channels: Channel width of each stage.
        stage_depths: Number of blocks in each stage.
        stage_heads: Attention heads used in each stage.
        embed_dim: Accepted but unused by this implementation.
        mlp_ratio: FFN expansion ratio forwarded to every block.
        qk_ratio: Query/key width ratio forwarded to every block.
    """

    def __init__(self, img_size=32, num_classes=100,
                 stage_channels=(64,128,192), stage_depths=(1,2,3),
                 stage_heads=(4,4,4), embed_dim=64, mlp_ratio=2.0, qk_ratio=0.5):
        super().__init__()
        self.img_size = img_size
        stem_width = stage_channels[0]
        self.stem = nn.Sequential(
            nn.Conv2d(3, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(stem_width),
            nn.ReLU(inplace=True),
            nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(stem_width),
            nn.ReLU(inplace=True),
        )
        self.stage_channels = stage_channels
        self.stage_depths = stage_depths
        self.stage_heads = stage_heads

        self.stage_blocks = nn.ModuleList()
        self.stage_proj = nn.ModuleList()

        last_stage = len(stage_channels) - 1
        stage_specs = zip(stage_channels, stage_depths, stage_heads)
        for idx, (width, depth, heads) in enumerate(stage_specs):
            # NOTE(review): stage_proj[0] is created but forward() only
            # applies the projections of stages >= 1.
            self.stage_proj.append(nn.Sequential(
                nn.Conv2d(width, width, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(width),
                nn.ReLU(inplace=True),
            ))

            self.stage_blocks.append(nn.ModuleList([
                SandwichBlock(dim=width, num_heads=heads, mlp_ratio=mlp_ratio, qk_ratio=qk_ratio)
                for _ in range(depth)
            ]))

            if idx < last_stage:
                # Downsampling into the next stage lives in attribute down_<idx>.
                next_width = stage_channels[idx + 1]
                reduce = nn.Conv2d(width, next_width, kernel_size=3, stride=2, padding=1, bias=False)
                setattr(self, f"down_{idx}", nn.Sequential(reduce, nn.BatchNorm2d(next_width), nn.ReLU(inplace=True)))

        # classifier head
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(stage_channels[-1], num_classes)

    def _run_blocks(self, x, blocks):
        """Apply ``blocks`` to feature map ``x`` ([B, C, H, W]) via a token view."""
        for blk in blocks:
            batch, channels, height, width = x.shape
            tokens = x.reshape(batch, channels, -1).permute(0, 2, 1)  # [B, N, C]
            tokens = blk(tokens, height, width)
            x = tokens.permute(0, 2, 1).reshape(batch, channels, height, width)
        return x

    def forward(self, x):
        """Map an image batch ``[B, 3, H, W]`` to class logits ``[B, num_classes]``."""
        x = self.stem(x)  # [B, C1, H/2, W/2]
        # First stage runs directly on the stem output.
        x = self._run_blocks(x, self.stage_blocks[0])

        for idx in range(1, len(self.stage_blocks)):
            x = getattr(self, f"down_{idx-1}")(x)
            x = self.stage_proj[idx](x)
            x = self._run_blocks(x, self.stage_blocks[idx])

        # head
        pooled = self.global_pool(x).reshape(x.shape[0], -1)
        return self.head(pooled)


# CIFAR-100 
def get_cifar100_loaders(bs=128, num_workers=4):
    """Build the CIFAR-100 train and test ``DataLoader``s.

    The datasets are stored under ``./data`` and downloaded if missing.
    Training images get a padded random crop and a random horizontal flip;
    both splits are converted to tensors and normalised per channel.

    Args:
        bs: Batch size for both loaders.
        num_workers: Worker processes for both loaders.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    # Per-channel mean / std passed to Normalize (presumably CIFAR-100
    # training-set statistics).
    channel_mean = (0.5071,0.4865,0.4409)
    channel_std = (0.2673,0.2564,0.2761)

    def tensor_and_normalize():
        # Fresh transform instances for each pipeline.
        return [T.ToTensor(), T.Normalize(channel_mean, channel_std)]

    t_train = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        *tensor_and_normalize(),
    ])
    t_test = T.Compose(tensor_and_normalize())

    cifar = torchvision.datasets.CIFAR100
    train_set = cifar(root='./data', train=True, download=True, transform=t_train)
    test_set = cifar(root='./data', train=False, download=True, transform=t_test)

    train_loader = DataLoader(train_set, batch_size=bs, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=bs, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader

# Training loop, metrics, plotting
def train_one_epoch(model, loader, opt, device, loss_fn):
    """Train ``model`` for one pass over ``loader`` and report metrics.

    Args:
        model: Module producing class logits of shape ``[B, num_classes]``.
        loader: Iterable of ``(inputs, labels)`` batches.
        opt: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar
            loss that is averaged over the batch.

    Returns:
        ``(avg_loss, top1, top5)``: the sample-weighted mean loss and the
        top-1 / top-5 accuracies as fractions in ``[0, 1]``. With fewer than
        five classes, "top-5" uses all available classes.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct_top1 = 0
    correct_top5 = 0
    total = 0
    pbar = tqdm(loader, leave=False)
    for x, y in pbar:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        out = model(x)
        loss = loss_fn(out, y)
        loss.backward()
        opt.step()

        batch_size = x.size(0)
        # loss is a batch mean; re-weight so the epoch average is per-sample.
        total_loss += loss.item() * batch_size
        total += batch_size

        # Clamp k so heads with fewer than five classes do not crash topk.
        k = min(5, out.size(1))
        pred = out.topk(k, dim=1)[1]  # [B, k] class indices, best first
        # Vectorised comparison on-device: one host sync per metric instead
        # of one per sample.
        hits = pred == y[:, None]
        correct_top1 += hits[:, 0].sum().item()
        # topk indices are distinct within a row, so each row holds at most
        # one True and the plain sum counts samples, not matches.
        correct_top5 += hits.sum().item()

    avg_loss = total_loss / total
    top1 = correct_top1 / total
    top5 = correct_top5 / total
    return avg_loss, top1, top5

def evaluate(model, loader, device, loss_fn):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Module producing class logits of shape ``[B, num_classes]``.
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar
            loss that is averaged over the batch.

    Returns:
        ``(avg_loss, top1, top5)``: the sample-weighted mean loss and the
        top-1 / top-5 accuracies as fractions in ``[0, 1]``. With fewer than
        five classes, "top-5" uses all available classes.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct_top1 = 0
    correct_top5 = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = loss_fn(out, y)

            batch_size = x.size(0)
            # loss is a batch mean; re-weight so the average is per-sample.
            total_loss += loss.item() * batch_size
            total += batch_size

            # Clamp k so heads with fewer than five classes do not crash topk.
            k = min(5, out.size(1))
            pred = out.topk(k, dim=1)[1]  # [B, k] class indices, best first
            # Vectorised comparison on-device: one host sync per metric
            # instead of one per sample.
            hits = pred == y[:, None]
            correct_top1 += hits[:, 0].sum().item()
            # topk indices are distinct within a row, so each row holds at
            # most one True and the plain sum counts samples, not matches.
            correct_top5 += hits.sum().item()

    avg_loss = total_loss / total
    top1 = correct_top1 / total
    top5 = correct_top5 / total
    return avg_loss, top1, top5

def plot_metrics(history, out_dir='outputs'):
    """Save loss and accuracy curves as PNG files in ``out_dir``.

    Writes ``loss.png`` (train and validation loss) and ``accuracy.png``
    (train top-1, validation top-1 and validation top-5, in percent). The
    directory is created if it does not exist.

    Args:
        history: Mapping with per-epoch lists under ``train_loss``,
            ``val_loss``, ``train_top1``, ``val_top1`` and ``val_top5``.
        out_dir: Destination directory for the images.
    """
    os.makedirs(out_dir, exist_ok=True)
    epochs = len(history['train_loss'])

    def draw(keys, ylabel, title, filename, percent=False):
        # One figure per call; each series is labelled with its history key.
        plt.figure()
        for key in keys:
            series = history[key]
            if percent:
                series = np.array(series)*100
            plt.plot(range(1, epochs + 1), series, label=key)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid()
        plt.title(title)
        plt.savefig(os.path.join(out_dir, filename), dpi=200)
        plt.close()

    # Loss
    draw(('train_loss', 'val_loss'), 'Loss', 'Loss', 'loss.png')
    # Accuracy (stored as fractions, plotted as percentages)
    draw(('train_top1', 'val_top1', 'val_top5'), 'Accuracy (%)', 'Accuracy', 'accuracy.png', percent=True)


# Main experiment
def main():
    """Train TinyEfficientViT on CIFAR-100 and save checkpoints and plots.

    Builds the model, reports parameter count and (if available) FLOPs,
    trains with AdamW under a cosine learning-rate schedule, and writes to
    ``./outputs``: the best checkpoint by top-1 accuracy on the held-out
    split, a checkpoint every 25 epochs, the final checkpoint, and the
    loss/accuracy plots.

    Every checkpoint stores the model, optimizer and scheduler state plus the
    epoch and best accuracy so far, so an interrupted run can be resumed with
    the correct learning rate.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size = 128
    epochs = 150
    lr = 3e-4
    weight_decay = 0.05
    num_classes = 100

    stage_channels = (64, 128, 192)
    stage_depths = (1, 2, 3)
    stage_heads = (4, 4, 4)
    mlp_ratio = 2.0
    qk_ratio = 0.25

    print("Building model...")
    model = TinyEfficientViT(img_size=32, num_classes=num_classes,
                             stage_channels=stage_channels,
                             stage_depths=stage_depths,
                             stage_heads=stage_heads,
                             mlp_ratio=mlp_ratio,
                             qk_ratio=qk_ratio).to(device)

    total_params, trainable = count_params(model)
    print(f"Total params: {total_params:,} ({total_params/1e6:.3f}M) | Trainable: {trainable:,}")

    flops = try_compute_flops(model, input_size=(1,3,32,32))
    if flops is not None:
        print(f"FLOPs (approx): {flops:.2e}")
    else:
        print("thop not available — FLOPs not computed (install 'thop' to compute)")

    # Data
    train_loader, test_loader = get_cifar100_loaders(bs=batch_size, num_workers=4)

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    loss_fn = nn.CrossEntropyLoss()

    history = {'train_loss':[], 'train_top1':[], 'train_top5':[], 'val_loss':[], 'val_top1':[], 'val_top5':[]}
    best_acc = 0.0
    os.makedirs('outputs', exist_ok=True)

    def save_checkpoint(filename, epoch):
        # Single definition of the checkpoint payload. Scheduler state and
        # best accuracy are included so a resumed run continues the cosine
        # schedule instead of restarting it.
        torch.save({'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': opt.state_dict(),
                    'scheduler_state_dict': sched.state_dict(),
                    'best_acc': best_acc,
                    'epoch': epoch},
                   os.path.join('outputs', filename))

    for ep in range(1, epochs+1):
        t0 = time.time()
        train_loss, train_top1, train_top5 = train_one_epoch(model, train_loader, opt, device, loss_fn)
        # NOTE(review): the "val" metrics are computed on the CIFAR-100 test
        # split, which is also used for best-checkpoint selection.
        val_loss, val_top1, val_top5 = evaluate(model, test_loader, device, loss_fn)
        sched.step()

        epoch_metrics = {'train_loss': train_loss, 'train_top1': train_top1, 'train_top5': train_top5,
                         'val_loss': val_loss, 'val_top1': val_top1, 'val_top5': val_top5}
        for key, value in epoch_metrics.items():
            history[key].append(value)

        if val_top1 > best_acc:
            best_acc = val_top1
            save_checkpoint('best_tinyefficientvit_cifar100.pth', ep)

        t1 = time.time()
        print(f"Epoch {ep:03d}/{epochs} | train_loss {train_loss:.4f} | train_top1 {train_top1*100:.2f}% | val_top1 {val_top1*100:.2f}% | val_top5 {val_top5*100:.2f}% | time {t1-t0:.1f}s")

        # Periodic snapshot, independent of whether this epoch was the best.
        if ep % 25 == 0:
            save_checkpoint(f'ckpt_ep{ep:03d}.pth', ep)

    save_checkpoint('last_tinyefficientvit_cifar100.pth', epochs)

    plot_metrics(history, out_dir='outputs')

    print("Training done. Best val top-1: {:.2f}%".format(best_acc*100))
    print("Saved outputs & plots to ./outputs")

# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()


Building model...
Total params: 1,829,828 (1.830M) | Trainable: 1,829,828
thop not available — FLOPs not computed (install 'thop' to compute)


100%|██████████| 169M/169M [00:05<00:00, 28.8MB/s]
                                                 

Epoch 001/150 | train_loss 3.6706 | train_top1 13.58% | val_top1 21.74% | val_top5 49.66% | time 34.8s


                                                 

Epoch 002/150 | train_loss 2.9327 | train_top1 26.03% | val_top1 30.11% | val_top5 61.88% | time 33.1s


                                                 

Epoch 003/150 | train_loss 2.5629 | train_top1 33.27% | val_top1 36.88% | val_top5 70.04% | time 33.3s


                                                 

Epoch 004/150 | train_loss 2.3050 | train_top1 38.78% | val_top1 41.43% | val_top5 73.58% | time 34.1s


                                                 

Epoch 005/150 | train_loss 2.1156 | train_top1 42.95% | val_top1 44.58% | val_top5 75.75% | time 34.3s


                                                 

Epoch 006/150 | train_loss 1.9740 | train_top1 46.20% | val_top1 45.70% | val_top5 78.10% | time 33.8s


                                                 

Epoch 007/150 | train_loss 1.8439 | train_top1 49.32% | val_top1 48.12% | val_top5 79.36% | time 34.0s


                                                 

Epoch 008/150 | train_loss 1.7389 | train_top1 51.63% | val_top1 49.89% | val_top5 80.69% | time 33.6s


                                                 

Epoch 009/150 | train_loss 1.6530 | train_top1 53.69% | val_top1 51.34% | val_top5 82.00% | time 33.9s


                                                 

Epoch 010/150 | train_loss 1.5653 | train_top1 55.71% | val_top1 52.65% | val_top5 82.90% | time 33.9s


                                                 

Epoch 011/150 | train_loss 1.4895 | train_top1 57.44% | val_top1 53.01% | val_top5 83.04% | time 34.2s


                                                 

Epoch 012/150 | train_loss 1.4252 | train_top1 58.96% | val_top1 54.88% | val_top5 84.36% | time 33.9s


                                                 

Epoch 013/150 | train_loss 1.3608 | train_top1 60.67% | val_top1 55.61% | val_top5 84.01% | time 33.9s


                                                 

Epoch 014/150 | train_loss 1.2950 | train_top1 62.40% | val_top1 55.95% | val_top5 84.97% | time 34.0s


                                                 

Epoch 015/150 | train_loss 1.2418 | train_top1 63.77% | val_top1 57.61% | val_top5 85.48% | time 33.8s


                                                 

Epoch 016/150 | train_loss 1.1924 | train_top1 65.16% | val_top1 57.69% | val_top5 85.45% | time 34.0s


                                                 

Epoch 017/150 | train_loss 1.1336 | train_top1 66.51% | val_top1 57.58% | val_top5 85.69% | time 34.0s


                                                 

Epoch 018/150 | train_loss 1.0951 | train_top1 67.57% | val_top1 58.66% | val_top5 86.44% | time 34.6s


                                                 

Epoch 019/150 | train_loss 1.0427 | train_top1 68.89% | val_top1 58.85% | val_top5 86.16% | time 34.7s


                                                 

Epoch 020/150 | train_loss 1.0057 | train_top1 69.76% | val_top1 59.46% | val_top5 86.09% | time 34.4s


                                                 

Epoch 021/150 | train_loss 0.9604 | train_top1 70.85% | val_top1 60.14% | val_top5 86.69% | time 34.5s


                                                 

Epoch 022/150 | train_loss 0.9250 | train_top1 71.82% | val_top1 60.57% | val_top5 86.83% | time 34.0s


                                                 

Epoch 023/150 | train_loss 0.8816 | train_top1 73.06% | val_top1 60.83% | val_top5 87.24% | time 34.9s


                                                 

Epoch 024/150 | train_loss 0.8386 | train_top1 74.29% | val_top1 60.43% | val_top5 87.29% | time 34.4s


                                                 

Epoch 025/150 | train_loss 0.8056 | train_top1 75.07% | val_top1 61.82% | val_top5 87.64% | time 34.2s


                                                 

Epoch 026/150 | train_loss 0.7751 | train_top1 75.76% | val_top1 61.44% | val_top5 87.11% | time 33.9s


                                                 

Epoch 027/150 | train_loss 0.7266 | train_top1 77.26% | val_top1 61.51% | val_top5 87.11% | time 34.2s


                                                 

Epoch 028/150 | train_loss 0.7006 | train_top1 77.96% | val_top1 61.83% | val_top5 87.13% | time 34.6s


                                                 

Epoch 029/150 | train_loss 0.6731 | train_top1 78.86% | val_top1 61.31% | val_top5 87.13% | time 34.2s


                                                 

Epoch 030/150 | train_loss 0.6347 | train_top1 79.94% | val_top1 62.01% | val_top5 87.63% | time 34.0s


                                                 

Epoch 031/150 | train_loss 0.6014 | train_top1 80.87% | val_top1 61.76% | val_top5 87.57% | time 34.5s


                                                 

Epoch 032/150 | train_loss 0.5805 | train_top1 81.30% | val_top1 62.14% | val_top5 87.08% | time 33.7s


                                                 

Epoch 033/150 | train_loss 0.5490 | train_top1 82.25% | val_top1 61.78% | val_top5 87.01% | time 33.7s


                                                 

Epoch 034/150 | train_loss 0.5212 | train_top1 83.14% | val_top1 62.05% | val_top5 87.42% | time 33.8s


                                                 

Epoch 035/150 | train_loss 0.4931 | train_top1 83.87% | val_top1 62.48% | val_top5 87.35% | time 34.0s


                                                 

Epoch 036/150 | train_loss 0.4744 | train_top1 84.38% | val_top1 62.66% | val_top5 87.57% | time 33.9s


                                                 

Epoch 037/150 | train_loss 0.4447 | train_top1 85.52% | val_top1 61.89% | val_top5 87.36% | time 34.0s


                                                 

Epoch 038/150 | train_loss 0.4218 | train_top1 86.15% | val_top1 63.00% | val_top5 87.59% | time 34.4s


                                                 

Epoch 039/150 | train_loss 0.4028 | train_top1 86.84% | val_top1 62.76% | val_top5 87.32% | time 33.6s


                                                 

Epoch 040/150 | train_loss 0.3909 | train_top1 87.20% | val_top1 62.14% | val_top5 87.11% | time 34.0s


                                                 

Epoch 041/150 | train_loss 0.3626 | train_top1 87.83% | val_top1 62.45% | val_top5 87.22% | time 33.7s


                                                 

Epoch 042/150 | train_loss 0.3431 | train_top1 88.51% | val_top1 63.19% | val_top5 87.36% | time 33.8s


                                                 

Epoch 043/150 | train_loss 0.3314 | train_top1 88.94% | val_top1 62.45% | val_top5 87.08% | time 33.7s


                                                 

Epoch 044/150 | train_loss 0.3133 | train_top1 89.68% | val_top1 62.47% | val_top5 87.10% | time 34.1s


                                                 

Epoch 045/150 | train_loss 0.2989 | train_top1 90.06% | val_top1 62.21% | val_top5 87.27% | time 34.0s


                                                 

Epoch 046/150 | train_loss 0.2781 | train_top1 90.82% | val_top1 62.62% | val_top5 87.48% | time 34.1s


                                                 

Epoch 047/150 | train_loss 0.2714 | train_top1 91.07% | val_top1 63.61% | val_top5 87.22% | time 34.1s


                                                 

Epoch 048/150 | train_loss 0.2524 | train_top1 91.65% | val_top1 62.60% | val_top5 87.55% | time 34.0s


                                                 

Epoch 049/150 | train_loss 0.2510 | train_top1 91.63% | val_top1 62.91% | val_top5 87.11% | time 33.7s


                                                 

Epoch 050/150 | train_loss 0.2397 | train_top1 91.99% | val_top1 63.27% | val_top5 87.38% | time 34.0s


                                                 

Epoch 051/150 | train_loss 0.2256 | train_top1 92.54% | val_top1 62.17% | val_top5 86.71% | time 33.7s


                                                 

Epoch 052/150 | train_loss 0.2123 | train_top1 92.97% | val_top1 62.65% | val_top5 86.92% | time 33.9s


                                                 

Epoch 053/150 | train_loss 0.2073 | train_top1 93.10% | val_top1 62.99% | val_top5 87.32% | time 33.6s


                                                 

Epoch 054/150 | train_loss 0.1936 | train_top1 93.59% | val_top1 63.40% | val_top5 86.99% | time 33.8s


                                                 

Epoch 055/150 | train_loss 0.1939 | train_top1 93.52% | val_top1 62.97% | val_top5 87.27% | time 33.9s


                                                 

Epoch 056/150 | train_loss 0.1874 | train_top1 93.80% | val_top1 63.44% | val_top5 87.02% | time 33.8s


                                                 

Epoch 057/150 | train_loss 0.1719 | train_top1 94.27% | val_top1 63.50% | val_top5 87.04% | time 33.7s


                                                 

Epoch 058/150 | train_loss 0.1670 | train_top1 94.44% | val_top1 62.88% | val_top5 86.76% | time 34.0s


                                                 

Epoch 059/150 | train_loss 0.1644 | train_top1 94.46% | val_top1 63.20% | val_top5 86.95% | time 33.8s


                                                 

Epoch 060/150 | train_loss 0.1527 | train_top1 94.91% | val_top1 63.47% | val_top5 87.11% | time 34.0s


                                                 

Epoch 061/150 | train_loss 0.1433 | train_top1 95.35% | val_top1 63.53% | val_top5 87.26% | time 33.9s


                                                 

Epoch 062/150 | train_loss 0.1440 | train_top1 95.26% | val_top1 63.23% | val_top5 87.25% | time 33.9s


                                                 

Epoch 063/150 | train_loss 0.1391 | train_top1 95.55% | val_top1 62.94% | val_top5 87.08% | time 34.0s


                                                 

Epoch 064/150 | train_loss 0.1317 | train_top1 95.71% | val_top1 63.13% | val_top5 87.38% | time 33.9s


                                                 

Epoch 065/150 | train_loss 0.1247 | train_top1 96.02% | val_top1 63.43% | val_top5 86.98% | time 34.0s


                                                 

Epoch 066/150 | train_loss 0.1224 | train_top1 95.87% | val_top1 63.31% | val_top5 86.89% | time 33.9s


                                                 

Epoch 067/150 | train_loss 0.1153 | train_top1 96.19% | val_top1 63.93% | val_top5 87.17% | time 33.9s


                                                 

Epoch 068/150 | train_loss 0.1148 | train_top1 96.28% | val_top1 63.96% | val_top5 87.23% | time 33.9s


                                                 

Epoch 069/150 | train_loss 0.1060 | train_top1 96.54% | val_top1 63.27% | val_top5 86.84% | time 33.8s


                                                 

Epoch 070/150 | train_loss 0.1038 | train_top1 96.67% | val_top1 64.25% | val_top5 87.47% | time 33.8s


                                                 

Epoch 071/150 | train_loss 0.1025 | train_top1 96.67% | val_top1 63.64% | val_top5 87.23% | time 33.8s


                                                 

Epoch 072/150 | train_loss 0.0926 | train_top1 97.01% | val_top1 63.71% | val_top5 87.26% | time 33.7s


                                                 

Epoch 073/150 | train_loss 0.0914 | train_top1 97.03% | val_top1 63.55% | val_top5 86.90% | time 33.8s


                                                 

Epoch 074/150 | train_loss 0.0904 | train_top1 97.11% | val_top1 64.18% | val_top5 87.46% | time 34.0s


                                                 

Epoch 075/150 | train_loss 0.0861 | train_top1 97.17% | val_top1 63.86% | val_top5 87.31% | time 33.9s


                                                 

Epoch 076/150 | train_loss 0.0767 | train_top1 97.61% | val_top1 63.97% | val_top5 87.50% | time 33.9s


                                                 

Epoch 077/150 | train_loss 0.0778 | train_top1 97.52% | val_top1 64.24% | val_top5 87.23% | time 34.1s


                                                 

Epoch 078/150 | train_loss 0.0729 | train_top1 97.75% | val_top1 64.46% | val_top5 87.19% | time 34.0s


                                                 

Epoch 079/150 | train_loss 0.0751 | train_top1 97.63% | val_top1 64.42% | val_top5 87.21% | time 33.4s


                                                 

Epoch 080/150 | train_loss 0.0704 | train_top1 97.81% | val_top1 64.22% | val_top5 87.12% | time 33.7s


                                                 

Epoch 081/150 | train_loss 0.0668 | train_top1 97.90% | val_top1 63.23% | val_top5 87.39% | time 33.8s


                                                 

Epoch 082/150 | train_loss 0.0657 | train_top1 97.91% | val_top1 64.14% | val_top5 87.16% | time 33.9s


                                                 

Epoch 083/150 | train_loss 0.0602 | train_top1 98.19% | val_top1 63.94% | val_top5 87.35% | time 33.8s


                                                 

Epoch 084/150 | train_loss 0.0571 | train_top1 98.26% | val_top1 63.87% | val_top5 87.32% | time 33.9s


                                                 

Epoch 085/150 | train_loss 0.0562 | train_top1 98.28% | val_top1 64.41% | val_top5 87.43% | time 33.9s


                                                 

Epoch 086/150 | train_loss 0.0561 | train_top1 98.26% | val_top1 64.42% | val_top5 87.29% | time 35.2s


                                                 

Epoch 087/150 | train_loss 0.0547 | train_top1 98.30% | val_top1 64.57% | val_top5 87.64% | time 35.1s


                                                 

Epoch 088/150 | train_loss 0.0523 | train_top1 98.42% | val_top1 64.51% | val_top5 87.46% | time 36.1s


                                                 

Epoch 089/150 | train_loss 0.0449 | train_top1 98.67% | val_top1 64.70% | val_top5 87.50% | time 35.7s


                                                 

Epoch 090/150 | train_loss 0.0462 | train_top1 98.63% | val_top1 64.55% | val_top5 87.53% | time 35.7s


                                                 

Epoch 091/150 | train_loss 0.0434 | train_top1 98.70% | val_top1 64.67% | val_top5 87.45% | time 35.1s


                                                 

Epoch 092/150 | train_loss 0.0421 | train_top1 98.76% | val_top1 64.37% | val_top5 87.53% | time 35.4s


                                                 

Epoch 093/150 | train_loss 0.0404 | train_top1 98.81% | val_top1 64.85% | val_top5 87.75% | time 34.6s


                                                 

Epoch 094/150 | train_loss 0.0386 | train_top1 98.86% | val_top1 64.28% | val_top5 87.49% | time 34.4s


                                                 

Epoch 095/150 | train_loss 0.0372 | train_top1 98.90% | val_top1 64.80% | val_top5 87.17% | time 34.7s


                                                 

Epoch 096/150 | train_loss 0.0345 | train_top1 98.99% | val_top1 64.90% | val_top5 87.21% | time 34.4s


                                                 

Epoch 097/150 | train_loss 0.0336 | train_top1 99.02% | val_top1 64.91% | val_top5 87.33% | time 34.1s


                                                 

Epoch 098/150 | train_loss 0.0304 | train_top1 99.15% | val_top1 65.26% | val_top5 87.25% | time 34.2s


                                                 

Epoch 099/150 | train_loss 0.0289 | train_top1 99.20% | val_top1 65.05% | val_top5 87.71% | time 34.2s


                                                 

Epoch 100/150 | train_loss 0.0280 | train_top1 99.20% | val_top1 65.02% | val_top5 87.38% | time 34.4s


                                                 

Epoch 101/150 | train_loss 0.0262 | train_top1 99.26% | val_top1 64.91% | val_top5 87.55% | time 34.5s


                                                 

Epoch 102/150 | train_loss 0.0245 | train_top1 99.36% | val_top1 64.86% | val_top5 87.43% | time 34.2s


                                                 

Epoch 103/150 | train_loss 0.0246 | train_top1 99.34% | val_top1 64.82% | val_top5 87.62% | time 34.4s


                                                 

Epoch 104/150 | train_loss 0.0228 | train_top1 99.38% | val_top1 65.25% | val_top5 87.59% | time 34.5s


                                                 

Epoch 105/150 | train_loss 0.0249 | train_top1 99.30% | val_top1 65.11% | val_top5 87.81% | time 34.1s


                                                 

Epoch 106/150 | train_loss 0.0198 | train_top1 99.47% | val_top1 65.31% | val_top5 87.79% | time 34.0s


                                                 

Epoch 107/150 | train_loss 0.0200 | train_top1 99.46% | val_top1 64.87% | val_top5 87.74% | time 34.2s


                                                 

Epoch 108/150 | train_loss 0.0213 | train_top1 99.41% | val_top1 64.88% | val_top5 87.69% | time 34.1s


                                                 

Epoch 109/150 | train_loss 0.0179 | train_top1 99.57% | val_top1 65.27% | val_top5 87.62% | time 34.2s


                                                 

Epoch 110/150 | train_loss 0.0173 | train_top1 99.57% | val_top1 65.38% | val_top5 88.01% | time 34.9s


                                                 

Epoch 111/150 | train_loss 0.0150 | train_top1 99.64% | val_top1 65.70% | val_top5 87.68% | time 34.7s


                                                 

Epoch 112/150 | train_loss 0.0153 | train_top1 99.64% | val_top1 65.32% | val_top5 87.90% | time 34.5s


                                                 

Epoch 113/150 | train_loss 0.0141 | train_top1 99.63% | val_top1 65.27% | val_top5 87.86% | time 34.1s


                                                 

Epoch 114/150 | train_loss 0.0147 | train_top1 99.61% | val_top1 65.00% | val_top5 87.84% | time 34.3s


                                                 

Epoch 115/150 | train_loss 0.0132 | train_top1 99.69% | val_top1 65.44% | val_top5 87.87% | time 34.1s


                                                 

Epoch 116/150 | train_loss 0.0119 | train_top1 99.73% | val_top1 65.37% | val_top5 87.99% | time 33.9s


                                                 

Epoch 117/150 | train_loss 0.0115 | train_top1 99.77% | val_top1 65.42% | val_top5 87.88% | time 34.1s


                                                 

Epoch 118/150 | train_loss 0.0113 | train_top1 99.74% | val_top1 65.46% | val_top5 87.69% | time 33.9s


                                                 

Epoch 119/150 | train_loss 0.0098 | train_top1 99.81% | val_top1 65.80% | val_top5 87.76% | time 33.8s


                                                 

Epoch 120/150 | train_loss 0.0099 | train_top1 99.78% | val_top1 65.80% | val_top5 87.74% | time 34.7s


                                                 

Epoch 121/150 | train_loss 0.0109 | train_top1 99.75% | val_top1 65.57% | val_top5 87.73% | time 34.6s


                                                 

Epoch 122/150 | train_loss 0.0091 | train_top1 99.80% | val_top1 65.56% | val_top5 87.94% | time 34.5s


                                                 

Epoch 123/150 | train_loss 0.0084 | train_top1 99.82% | val_top1 65.84% | val_top5 87.90% | time 34.1s


                                                 

Epoch 124/150 | train_loss 0.0090 | train_top1 99.80% | val_top1 65.31% | val_top5 87.99% | time 33.9s


                                                 

Epoch 125/150 | train_loss 0.0090 | train_top1 99.77% | val_top1 65.52% | val_top5 88.03% | time 34.1s


                                                 

Epoch 126/150 | train_loss 0.0081 | train_top1 99.84% | val_top1 65.42% | val_top5 88.09% | time 34.1s


                                                 

Epoch 127/150 | train_loss 0.0076 | train_top1 99.85% | val_top1 65.79% | val_top5 87.88% | time 34.4s


                                                 

Epoch 128/150 | train_loss 0.0067 | train_top1 99.87% | val_top1 65.85% | val_top5 87.87% | time 34.1s


                                                 

Epoch 129/150 | train_loss 0.0071 | train_top1 99.83% | val_top1 65.71% | val_top5 88.07% | time 34.4s


                                                 

Epoch 130/150 | train_loss 0.0062 | train_top1 99.86% | val_top1 65.90% | val_top5 88.20% | time 34.8s


                                                 

Epoch 131/150 | train_loss 0.0063 | train_top1 99.86% | val_top1 65.95% | val_top5 88.13% | time 34.3s


                                                 

Epoch 132/150 | train_loss 0.0061 | train_top1 99.89% | val_top1 65.96% | val_top5 88.08% | time 34.1s


                                                 

Epoch 133/150 | train_loss 0.0056 | train_top1 99.89% | val_top1 66.10% | val_top5 87.94% | time 34.3s


                                                 

Epoch 134/150 | train_loss 0.0056 | train_top1 99.89% | val_top1 66.06% | val_top5 88.04% | time 34.6s


                                                 

Epoch 135/150 | train_loss 0.0050 | train_top1 99.91% | val_top1 65.89% | val_top5 88.11% | time 34.6s


                                                 

Epoch 136/150 | train_loss 0.0058 | train_top1 99.87% | val_top1 66.02% | val_top5 88.13% | time 34.6s


                                                 

Epoch 137/150 | train_loss 0.0051 | train_top1 99.90% | val_top1 65.94% | val_top5 88.16% | time 34.2s


                                                 

Epoch 138/150 | train_loss 0.0051 | train_top1 99.89% | val_top1 65.86% | val_top5 87.97% | time 34.4s


                                                 

Epoch 139/150 | train_loss 0.0053 | train_top1 99.91% | val_top1 65.88% | val_top5 87.95% | time 34.4s


                                                 

Epoch 140/150 | train_loss 0.0044 | train_top1 99.92% | val_top1 66.07% | val_top5 88.02% | time 34.4s


                                                 

Epoch 141/150 | train_loss 0.0043 | train_top1 99.91% | val_top1 66.02% | val_top5 87.89% | time 34.6s


                                                 

Epoch 142/150 | train_loss 0.0048 | train_top1 99.90% | val_top1 66.13% | val_top5 88.07% | time 34.5s


                                                 

Epoch 143/150 | train_loss 0.0045 | train_top1 99.92% | val_top1 65.97% | val_top5 88.01% | time 34.2s


                                                 

Epoch 144/150 | train_loss 0.0044 | train_top1 99.92% | val_top1 66.12% | val_top5 87.96% | time 34.2s


                                                 

Epoch 145/150 | train_loss 0.0043 | train_top1 99.91% | val_top1 65.95% | val_top5 88.02% | time 34.1s


                                                 

Epoch 146/150 | train_loss 0.0043 | train_top1 99.93% | val_top1 66.07% | val_top5 87.95% | time 34.1s


                                                 

Epoch 147/150 | train_loss 0.0040 | train_top1 99.95% | val_top1 66.19% | val_top5 87.94% | time 34.2s


                                                 

Epoch 148/150 | train_loss 0.0044 | train_top1 99.92% | val_top1 66.10% | val_top5 88.03% | time 34.1s


                                                 

Epoch 149/150 | train_loss 0.0042 | train_top1 99.93% | val_top1 66.24% | val_top5 87.99% | time 34.6s


                                                 

Epoch 150/150 | train_loss 0.0043 | train_top1 99.91% | val_top1 66.04% | val_top5 88.04% | time 34.2s
Training done. Best val top-1: 66.24%
Saved outputs & plots to ./outputs
