In [1]:
### Necessary Imports and dependencies
### W&B project is "fractual_transformer"; the run name is "Baseline_ImageNet_run"
import os
import shutil
import time
import math
from enum import Enum
from functools import partial
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
from torchvision.transforms import v2
import torchvision.transforms as transforms
from typing import Any, Dict, Union, Type, Callable, Optional, List
from torchvision.models.vision_transformer import MLPBlock
import wandb


# Number of passes over the training set used to derive total_steps below.
num_epochs=90

# Data-loading parameters (the datasets below are CIFAR-100)
batch_size = 128
num_workers = 4 

# Dataset loading code
# Define the CIFAR-100 train/validation datasets (downloaded into ./data)
# NOTE(review): the Normalize mean/std below look like the commonly quoted
# CIFAR-10 channel statistics rather than CIFAR-100 ones - confirm intended.
train_dataset = datasets.CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])
    ])
)

val_dataset = datasets.CIFAR100(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261]),
    ])
)

# Number of training samples.
n = len(train_dataset)

# Total optimizer steps: samples * epochs / batch size, rounded to an integer.
# Because the train loader uses drop_last=True, one pass yields only
# floor(n / batch_size) batches, so this covers slightly more than num_epochs passes.
total_steps = round((n * num_epochs) / batch_size)

# Step to resume from; training iterates over start_step + 1 .. total_steps.
start_step=0

# MixUp over 100 classes, applied per batch inside the train loader's collate_fn.
# The training loop calls accuracy(..., class_prob=True), i.e. it treats the
# resulting targets as per-class probabilities.
mixup = v2.MixUp(alpha=0.2, num_classes=100)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    # Collate the samples into a batch first, then mix the whole batch.
    collate_fn=lambda batch: mixup(*torch.utils.data.default_collate(batch)), 
    drop_last=True
)

# Validation loader: no shuffling, no MixUp, last partial batch is kept.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)

# Number of learning-rate warmup steps used by the scheduler defined further down.
warmup_try=10000

# Taken from https://github.com/lucidrains/vit-pytorch, likely ported from https://github.com/google-research/big_vision/
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    """Build a fixed 2D sine/cosine position embedding for an ``h`` x ``w`` grid.

    Args:
        h: Number of grid rows.
        w: Number of grid columns.
        dim: Embedding width; must be a multiple of 4 because it is split into
            four equal parts (sin/cos of x and sin/cos of y).
        temperature: Base of the geometric frequency progression.
        dtype: Dtype of the returned tensor.

    Returns:
        Tensor of shape ``(h * w, dim)``. Rows follow the row-major ("ij")
        flattening of the grid; columns are ``[sin(x), cos(x), sin(y), cos(y)]``.

    Raises:
        AssertionError: If ``dim`` is not a multiple of 4.
    """
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")

    num_freqs = dim // 4
    # Exponents are spread evenly over [0, 1]. With a single frequency
    # (dim == 4) the plain divisor ``num_freqs - 1`` is zero and 0 / 0 would
    # fill the whole embedding with NaN, so clamp the divisor to at least 1.
    omega = torch.arange(num_freqs) / max(num_freqs - 1, 1)
    omega = 1.0 / (temperature ** omega)

    # Outer product of flattened coordinates with the frequencies.
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)


class EncoderBlock(nn.Module):
    """One Transformer encoder layer: normalized self-attention and MLP, each wrapped in a residual."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention sub-layer: norm -> multi-head self-attention -> dropout
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=attention_dropout,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)

        # Feed-forward sub-layer: norm -> MLP
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

        # Fix init discrepancy between nn.MultiheadAttention and that of big_vision:
        # re-draw both projection weights uniformly from [-limit, limit].
        limit = math.sqrt(3 / hidden_dim)
        attention = self.self_attention
        for weight in (attention.in_proj_weight, attention.out_proj.weight):
            nn.init.uniform_(weight, -limit, limit)

    def forward(self, input: torch.Tensor):
        """Apply the block to a 3-D ``input`` and return a tensor produced by the two residual sums."""
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        normed = self.ln_1(input)
        # need_weights=False: only the attended values are used.
        attended = self.self_attention(normed, normed, normed, need_weights=False)[0]
        after_attention = self.dropout(attended) + input

        mlp_out = self.mlp(self.ln_2(after_attention))
        return after_attention + mlp_out


class Encoder(nn.Module):
    """Input dropout, a stack of ``num_layers`` ``EncoderBlock`` layers, then a final norm layer.

    ``seq_length`` is accepted for signature compatibility but is not used here.
    """

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Blocks are registered as encoder_layer_0 .. encoder_layer_{num_layers - 1}.
        blocks = OrderedDict(
            (
                f"encoder_layer_{idx}",
                EncoderBlock(
                    num_heads,
                    hidden_dim,
                    mlp_dim,
                    dropout,
                    attention_dropout,
                    norm_layer,
                ),
            )
            for idx in range(num_layers)
        )
        self.layers = nn.Sequential(blocks)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        """Run a 3-D ``input`` through dropout, every block in order, and the final norm."""
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        hidden = self.dropout(input)
        hidden = self.layers(hidden)
        return self.ln(hidden)


class SimpleVisionTransformer(nn.Module):
    """Vision Transformer modified per https://arxiv.org/abs/2205.01580.

    Square images are cut into non-overlapping patches by a strided
    convolution, a fixed 2D sin-cos position embedding is added, the patch
    sequence is run through the encoder, and the mean over all patches is fed
    to the classification head.

    Args:
        image_size: Height and width of the (square) input images.
        patch_size: Side of each square patch; must divide ``image_size``.
        num_layers: Number of encoder blocks.
        num_heads: Attention heads per block.
        hidden_dim: Patch embedding width.
        mlp_dim: Hidden width of each block's MLP.
        dropout: Dropout rate passed to the encoder.
        attention_dropout: Dropout rate passed to the attention layers.
        num_classes: Number of output logits.
        representation_size: If given, a ``Linear`` + ``Tanh`` "pre_logits"
            stage of this width is inserted before the final linear head.
        norm_layer: Factory for the normalization layers.
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 100,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        # Patchify stem: kernel == stride == patch_size, so patches do not overlap.
        self.conv_proj = nn.Conv2d(
            in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
        )

        h = w = image_size // patch_size
        seq_length = h * w
        # Fixed (non-learned) position embedding, stored as a buffer.
        self.register_buffer("pos_embedding", posemb_sincos_2d(h=h, w=w, dim=hidden_dim))

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        # Init the patchify stem. conv_proj is always the nn.Conv2d built
        # above, so no alternative "conv stem" initialisation path is needed.
        fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
        # constant is stddev of standard normal truncated to (-2, 2)
        std = math.sqrt(1 / fan_in) / .87962566103423978
        nn.init.trunc_normal_(self.conv_proj.weight, std=std, a=-2 * std, b=2 * std)
        if self.conv_proj.bias is not None:
            nn.init.zeros_(self.conv_proj.bias)

        # The optional pre-logits layer exists exactly when representation_size was given.
        if representation_size is not None:
            pre_logits = self.heads.pre_logits
            nn.init.trunc_normal_(pre_logits.weight, std=math.sqrt(1 / pre_logits.in_features))
            nn.init.zeros_(pre_logits.bias)

        # The classifier starts from all-zero weights and bias.
        nn.init.zeros_(self.heads.head.weight)
        nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        """Turn an ``(n, c, h, w)`` image batch into ``(n, n_h * n_w, hidden_dim)`` patch embeddings."""
        n, _, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)

        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x: torch.Tensor):
        """Return the head output for an image batch ``x`` of shape ``(n, c, image_size, image_size)``."""
        # Reshape and permute the input tensor
        x = self._process_input(x)
        x = x + self.pos_embedding
        x = self.encoder(x)
        # Average over the patch (sequence) dimension before the head.
        x = x.mean(dim = 1)
        x = self.heads(x)

        return x
    
def weight_decay_param(n, p):
    """Tell whether parameter ``p`` (registered under name ``n``) should be weight-decayed.

    Only parameters with at least two dimensions whose name ends in
    ``weight`` qualify.
    """
    if p.ndim < 2:
        return False
    return n.endswith('weight')

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# create model
model = SimpleVisionTransformer(
    image_size=32,
    patch_size=4,
    num_layers=12,
    num_heads=6,
    hidden_dim=384,
    mlp_dim=1536,
).to(device)

# Split the trainable parameters into a decayed group and a non-decayed group,
# as decided by weight_decay_param().
wd_params = [p for n, p in model.named_parameters() if weight_decay_param(n, p) and p.requires_grad]
non_wd_params = [p for n, p in model.named_parameters() if not weight_decay_param(n, p) and p.requires_grad]

# Alias of the same module object; used for validation and checkpointing.
original_model = model

weight_decay = 0.1
learning_rate = 1e-3

criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.AdamW(
    [
        # Use the configured value instead of repeating the literal, so that
        # changing ``weight_decay`` above actually takes effect.
        {"params": wd_params, "weight_decay": weight_decay},
        {"params": non_wd_params, "weight_decay": 0.},
    ],
    lr=learning_rate,
)

# Learning-rate schedule: the LambdaLR multiplier grows as step / warmup_try,
# and after warmup_try scheduler steps SequentialLR switches to cosine
# annealing over the remaining total_steps - warmup_try steps.
warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: step / warmup_try)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps - warmup_try)
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, [warmup, cosine], [warmup_try])

## Path_to_be_changed
# Directory that save_checkpoint() writes into.
checkpoint_path = "/kaggle/working/"

def save_checkpoint(state, is_best, path, filename='baselinecheckpoint_imagenet.pth.tar'):
    """Write ``state`` to ``path/filename`` with ``torch.save``.

    When ``is_best`` is truthy the saved file is additionally copied to
    ``path/model_best.pth.tar``.
    """
    checkpoint_file = os.path.join(path, filename)
    torch.save(state, checkpoint_file)
    if not is_best:
        return
    best_file = os.path.join(path, 'model_best.pth.tar')
    shutil.copyfile(checkpoint_file, best_file)

class Summary(Enum):
    """Selects which statistic a meter prints in its end-of-run summary."""

    # Members take the values 0, 1, 2, 3 in this order.
    NONE, AVERAGE, SUM, COUNT = range(4)

class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes written by the methods below:
        val: Most recent value passed to ``update``.
        sum: Running sum of ``val * n``.
        count: Running sum of ``n``.
        avg: ``sum / count``.
    """
    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        # name: label used when printing; fmt: format spec applied to val/avg;
        # summary_type: which statistic summary() reports.
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Zero all running statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` in the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def all_reduce(self):
        """Sum ``sum`` and ``count`` across distributed processes and refresh ``avg``.

        Does nothing when no process group is initialised, since there is
        nothing to combine in that case.
        """
        # Imported here because the file never imports torch.distributed at
        # module level; the bare ``dist`` name would raise NameError.
        import torch.distributed as dist

        if not (dist.is_available() and dist.is_initialized()):
            return

        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        # Guard against a meter that was never updated on any process.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        # Renders as "<name> <val> (<avg>)" using the configured format spec.
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """Return the summary string selected by ``summary_type``.

        Raises:
            ValueError: If ``summary_type`` is not a known ``Summary`` member.
        """
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)

        return fmtstr.format(**self.__dict__)
    
class ProgressMeter(object):
    """Prints a progress counter followed by the current state of a list of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print one tab-separated line: prefix + "[batch/total]" and ``str()`` of every meter."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def display_summary(self):
        """Print one space-separated line starting with " *" and each meter's ``summary()``."""
        parts = [" *"]
        parts.extend(meter.summary() for meter in self.meters)
        print(' '.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running counter to the width of the total, e.g. "[{:3d}/100]".
        width = len(str(num_batches // 1))
        counter_fmt = f"{{:{width}d}}"
        return f"[{counter_fmt}/{counter_fmt.format(num_batches)}]"
    
def accuracy(output, target, topk=(1,), class_prob=False):
    """Return, for each ``k`` in ``topk``, the fraction of samples whose target is among the top-k predictions.

    ``output`` holds per-class scores with one row per sample. ``target``
    holds class indices, or per-class probabilities when ``class_prob`` is
    true. Each result is a one-element tensor holding a fraction in [0, 1].
    """
    with torch.no_grad():
        num_samples = target.size(0)

        if class_prob:
            # with e.g. MixUp the target is a probability per class; reduce it
            # to the index of the most probable class
            target = target.topk(1, dim=1, largest=True, sorted=True).indices.squeeze(dim=1)

        # Row r of ``top_preds`` holds every sample's (r + 1)-th best class.
        top_preds = output.topk(max(topk), dim=1, largest=True, sorted=True).indices.t()
        hits = top_preds.eq(target.view(1, -1).expand_as(top_preds))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(1.0 / num_samples)
            for k in topk
        ]
    
log_steps = 2500

# Authenticate with Weights & Biases. The API key is read from the
# WANDB_API_KEY environment variable instead of being embedded in the source;
# when it is unset, wandb falls back to its own credential lookup.
# SECURITY: the key that used to be hard-coded here has been exposed and
# should be revoked and regenerated.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

# Initialize a new run
wandb.init(project="fractual_transformer", name="Baseline_ImageNet_run")

def validate(val_loader, model, criterion, step, use_wandb=False, accum_freq=1, print_freq=100):
    """Evaluate ``model`` on ``val_loader`` and return the average top-1 accuracy.

    Args:
        val_loader: Sized iterable yielding ``(images, target)`` batches.
        model: Module to evaluate. It is put in eval mode for the duration of
            the call and returned to its previous train/eval mode afterwards.
        criterion: Loss called as ``criterion(output, target)``.
        step: Training step used as the x-axis value when logging to W&B.
        use_wandb: When true, log loss / acc@1 / acc@5 averages via ``wandb.log``.
        accum_freq: Number of chunks each batch is split into before the
            forward pass.
        print_freq: Print progress every ``print_freq`` batches.

    Returns:
        The running average of top-1 accuracy over all evaluated samples.
    """
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # Feed the inputs to wherever the model's parameters live, rather than
    # guessing from CUDA/MPS availability (which breaks when they disagree).
    # A parameter-free model leaves the batches where the loader put them.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    # switch to evaluate mode, remembering the caller's mode so that a
    # training loop is not silently left in eval mode afterwards
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            torch.cuda.empty_cache()
            end = time.time()
            for i, (images, target) in enumerate(val_loader):
                if device is not None:
                    images = images.to(device, non_blocking=True)
                    target = target.to(device, non_blocking=True)

                for img, trt in zip(images.chunk(accum_freq), target.chunk(accum_freq)):
                    # compute output
                    output = model(img)
                    loss = criterion(output, trt)

                    # measure accuracy and record loss, weighted by chunk size
                    acc1, acc5 = accuracy(output, trt, topk=(1, 5))
                    losses.update(loss.item(), img.size(0))
                    top1.update(acc1[0].item(), img.size(0))
                    top5.update(acc5[0].item(), img.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % print_freq == 0:
                    progress.display(i)
    finally:
        model.train(was_training)

    progress.display_summary()

    if use_wandb:
        log_data = {
            'val/loss': losses.avg,
            'val/acc@1': top1.avg,
            'val/acc@5': top5.avg,
        }
        wandb.log(log_data, step=step)

    return top1.avg

def train(train_loader, val_loader, start_step, total_steps, original_model, model, criterion, optimizer, scheduler, device):
    """Run the step-based training loop with periodic validation and checkpointing.

    Args:
        train_loader: Iterable of ``(images, target)`` batches; it is
            restarted whenever it is exhausted. Targets are treated as
            per-class probabilities when computing training accuracy.
        val_loader: Loader handed to ``validate``.
        start_step: Last completed step; training runs steps
            ``start_step + 1`` through ``total_steps`` inclusive.
        total_steps: Final step number.
        original_model: Module that is validated and whose ``state_dict`` is
            checkpointed.
        model: Module used for the forward/backward passes.
        criterion: Loss called as ``criterion(output, target)``.
        optimizer: Optimizer stepped once per training step.
        scheduler: LR scheduler stepped once per training step.
        device: Device the batches are moved to.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    print_freq = 100  # Print frequency (adjust as needed)
    log_steps = 2500  # Validation/checkpoint frequency (adjust as needed)
    accum_freq = 1  # Gradient accumulation frequency (adjust as needed)
    # Only log when a W&B run is active; testing the module object itself
    # (``if wandb:``) is always truthy and therefore checks nothing.
    use_wandb = wandb.run is not None

    progress = ProgressMeter(
        total_steps,
        [batch_time, data_time, losses, top1, top5]
    )

    # switch to train mode
    model.train()
    end = time.time()
    best_acc1 = 0

    def infinite_loader():
        # Restart the loader each time it runs out so steps are not tied to epochs.
        while True:
            yield from train_loader

    # range() comes first in zip() so no extra batch is fetched after the last step.
    for step, (images, target) in zip(range(start_step + 1, total_steps + 1), infinite_loader()):
        # measure data loading time
        data_time.update(time.time() - end)

        # move data to the same device as model
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        step_loss = step_acc1 = step_acc5 = 0.0

        for img, trt in zip(images.chunk(accum_freq), target.chunk(accum_freq)):
            # compute output
            output = model(img)
            loss = criterion(output, trt)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, trt, topk=(1, 5), class_prob=True)
            step_loss += loss.item()
            step_acc1 += acc1[0].item()
            step_acc5 += acc5[0].item()

            # accumulate gradient, scaled so the chunks average rather than add up
            (loss / accum_freq).backward()

        step_loss /= accum_freq
        step_acc1 /= accum_freq
        step_acc5 /= accum_freq

        losses.update(step_loss, images.size(0))
        top1.update(step_acc1, images.size(0))
        top5.update(step_acc5, images.size(0))

        # clip the gradient norm to 1.0, then take the optimizer step
        l2_grads = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0:
            print(step)
            progress.display(step)
            if use_wandb:

                with torch.no_grad():
                    l2_params = sum(p.square().sum().item() for p in model.parameters())

                # Throughput from the size of the batch that was just processed.
                samples_per_second_per_gpu = images.size(0) / batch_time.val
                samples_per_second = samples_per_second_per_gpu
                log_data = {
                    "train/loss": step_loss,
                    'train/acc@1': step_acc1,
                    'train/acc@5': step_acc5,
                    "data_time": data_time.val,
                    "batch_time": batch_time.val,
                    "samples_per_second": samples_per_second,
                    "samples_per_second_per_gpu": samples_per_second_per_gpu,
                    "lr": scheduler.get_last_lr()[0],
                    "l2_grads": l2_grads.item(),
                    "l2_params": math.sqrt(l2_params)
                }
                wandb.log(log_data, step=step)

        if step % log_steps == 0 or step == total_steps:

            # Forward use_wandb so validation metrics are logged alongside
            # the training metrics.
            val_acc1 = validate(val_loader, original_model, criterion, step, use_wandb=use_wandb)
            # validate() switches the module to eval mode; make sure the next
            # training step runs in train mode again.
            model.train()

            # remember best acc@1 and save checkpoint
            is_best = val_acc1 > best_acc1
            best_acc1 = max(val_acc1, best_acc1)

            save_checkpoint({
                'step': step,
                'state_dict': original_model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
                'scheduler' : scheduler.state_dict()
            }, is_best,checkpoint_path)

        scheduler.step()
        
# Run the full training loop, then close the W&B run.
train(
    train_loader=train_loader,
    val_loader=val_loader,
    start_step=start_step,
    total_steps=total_steps,
    original_model=original_model,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
)

wandb.finish()

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:04<00:00, 35178594.54it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33matharv_m[0m ([33matharv_m-iit-roorkee[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.18.7 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.17.7
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20241122_213057-sfgxihe6[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mBaseline_ImageNet_run[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/atharv_m-iit-roorkee/fractual_transformer[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/atharv_m-iit-roorkee/fractual_transformer/runs/sfgxihe6[0m


100
[  100/35156]	Time  0.204 ( 0.215)	Data  0.000 ( 0.003)	Loss 4.5854e+00 (4.5968e+00)	Acc@1   0.02 (  0.03)	Acc@5   0.11 (  0.13)
200
[  200/35156]	Time  0.204 ( 0.210)	Data  0.000 ( 0.002)	Loss 4.4474e+00 (4.5642e+00)	Acc@1   0.08 (  0.04)	Acc@5   0.23 (  0.15)
300
[  300/35156]	Time  0.204 ( 0.208)	Data  0.000 ( 0.002)	Loss 4.4273e+00 (4.5165e+00)	Acc@1   0.05 (  0.04)	Acc@5   0.16 (  0.16)
400
[  400/35156]	Time  0.203 ( 0.207)	Data  0.000 ( 0.002)	Loss 4.2234e+00 (4.4605e+00)	Acc@1   0.08 (  0.04)	Acc@5   0.24 (  0.18)
500
[  500/35156]	Time  0.204 ( 0.206)	Data  0.000 ( 0.002)	Loss 4.1439e+00 (4.4006e+00)	Acc@1   0.05 (  0.05)	Acc@5   0.30 (  0.20)
600
[  600/35156]	Time  0.203 ( 0.206)	Data  0.000 ( 0.002)	Loss 3.9667e+00 (4.3433e+00)	Acc@1   0.09 (  0.06)	Acc@5   0.30 (  0.21)
700
[  700/35156]	Time  0.202 ( 0.205)	Data  0.000 ( 0.002)	Loss 3.8297e+00 (4.2879e+00)	Acc@1   0.15 (  0.06)	Acc@5   0.37 (  0.23)
800
[  800/35156]	Time  0.204 ( 0.205)	Data  0.000 ( 0.002)	Loss 3.57



10100
[10100/35156]	Time  0.204 ( 0.207)	Data  0.000 ( 0.004)	Loss 3.7110e+00 (2.6262e+00)	Acc@1   0.29 (  0.40)	Acc@5   0.58 (  0.69)
10200
[10200/35156]	Time  0.204 ( 0.207)	Data  0.000 ( 0.004)	Loss 3.7433e+00 (2.6192e+00)	Acc@1   0.19 (  0.40)	Acc@5   0.49 (  0.69)
10300
[10300/35156]	Time  0.203 ( 0.207)	Data  0.000 ( 0.004)	Loss 1.0965e+00 (2.6109e+00)	Acc@1   0.71 (  0.40)	Acc@5   0.95 (  0.69)
10400
[10400/35156]	Time  0.210 ( 0.207)	Data  0.000 ( 0.004)	Loss 3.6550e+00 (2.6019e+00)	Acc@1   0.20 (  0.41)	Acc@5   0.50 (  0.70)
10500
[10500/35156]	Time  0.203 ( 0.207)	Data  0.000 ( 0.004)	Loss 1.2738e+00 (2.5953e+00)	Acc@1   0.73 (  0.41)	Acc@5   0.95 (  0.70)
10600
[10600/35156]	Time  0.204 ( 0.207)	Data  0.000 ( 0.004)	Loss 8.1518e-01 (2.5873e+00)	Acc@1   0.77 (  0.41)	Acc@5   0.98 (  0.70)
10700
[10700/35156]	Time  0.204 ( 0.206)	Data  0.000 ( 0.004)	Loss 9.1377e-01 (2.5799e+00)	Acc@1   0.76 (  0.41)	Acc@5   0.97 (  0.70)
10800
[10800/35156]	Time  0.202 ( 0.206)	Data  0.000 ( 

[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:                 batch_time ▃▃▄▃▄▃▄▃▄▄▁▃▃█▄▃▄▄▄▄▄▃▄▃▄▇▄▃▃▄▄▇▅▄▄▃▄▃▂▄
[34m[1mwandb[0m:                  data_time ▇▃▃▇▆▇▄▆▆▆▇▆▆▆▆▇▃▅▄▄▄▆▇▆▃▆▃█▃▅▅▆▃▇▆▁▃▆▅▃
[34m[1mwandb[0m:                   l2_grads ▆▇███▇▆▇▆▆▇▆▅▅▅▅▅▅▃▂▄▃▅▂▂▂▂▂▃▁▁▁▂▃▁▁▁▄▂▁
[34m[1mwandb[0m:                  l2_params ▁▁▁▁▁▁▁▂▂▃▄▅▆▇▇██████▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▅▄▄▄
[34m[1mwandb[0m:                         lr ▁▂▂▃▄▄▅▅▆▇▇█████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁
[34m[1mwandb[0m:         samples_per_second ▆▆▅▅▅▆▅▅▅▅█▆▆▁▅▆▅▅▅▅▅▆▅▆▅▂▅▆▅▅▅▂▄▅▅▆▅▆▇▅
[34m[1mwandb[0m: samples_per_second_per_gpu ▆▆▅▅▅▆▅▅▅▅█▆▆▁▅▆▅▅▅▅▅▆▅▆▅▂▅▆▅▅▅▂▄▅▅▆▅▆▇▅
[34m[1mwandb[0m:                train/acc@1 ▁▂▃▃▃▃▄▄▅▆▅▆▆▆▆▅▇▆██▅█▅██████████████▇██
[34m[1mwandb[0m:                train/acc@5 ▁▄▅▅▅▄▆▆▇▇▇▇██▇▆████▇█▆█████████████████
[34m[1mwandb[0m:                 train/loss █▆▆▆▅▇▆