## Import packages

In [None]:
!pip install datasets --user

In [1]:
from datasets import load_dataset
from transformers import AutoFeatureExtractor
from functools import partial

import numpy as np
import torch
import torch.nn as nn

import timm.models.vision_transformer
import os

from tqdm import tqdm
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from timm.models.vision_transformer import PatchEmbed, Block

## Food101 Dataset 

original train set: 75,750

original validation set: 25,250

total: 101,000

In [2]:
# Load the Food-101 dataset; the result is indexed by split name
# ('train' / 'validation') in the cells below.
food101 = load_dataset('food101')

In [3]:
food101['train'][0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=384x512>,
 'label': 6}

In [4]:
# Compute the per-channel (R, G, B) mean and standard deviation of the
# training images after resizing them to 224 x 224 and scaling to [0, 1].
# Images whose resized array is not (224, 224, 3) are excluded from the
# statistics and their indices are recorded in `indices_to_drop`.
from tqdm import tqdm
import numpy as np

# Running sums of pixel values per channel.
total_r = 0
total_g = 0
total_b = 0
total_pixels = 0

# Variables for standard deviation
# (running sums of squared pixel values per channel).
sum_squared_r = 0
sum_squared_g = 0
sum_squared_b = 0

# Positions in food101['train'] whose image does not have the expected shape.
indices_to_drop = []

for i in tqdm(range(len(food101['train']))):
    # Resize, convert to an array and rescale the pixel values by 1 / 255.
    img = np.array(food101['train'][i]['image'].resize((224, 224))) / 255.
    
    if img.shape == (224, 224, 3):
        total_r += img[:, :, 0].sum()
        total_g += img[:, :, 1].sum()
        total_b += img[:, :, 2].sum()
        
        # For standard deviation
        sum_squared_r += np.sum(np.square(img[:, :, 0]))
        sum_squared_g += np.sum(np.square(img[:, :, 1]))
        sum_squared_b += np.sum(np.square(img[:, :, 2]))

        # Every accepted image contributes the same number of pixels per channel.
        total_pixels += 224 * 224
    else:
        indices_to_drop.append(i)

mean_r = total_r / total_pixels
mean_g = total_g / total_pixels
mean_b = total_b / total_pixels

# Compute std for each channel
# using std = sqrt(E[x^2] - E[x]^2) from the accumulated sums.
std_r = np.sqrt((sum_squared_r / total_pixels) - (mean_r ** 2))
std_g = np.sqrt((sum_squared_g / total_pixels) - (mean_g ** 2))
std_b = np.sqrt((sum_squared_b / total_pixels) - (mean_b ** 2))

print(f"Mean RGB: {mean_r}, {mean_g}, {mean_b}")
print(f"Std RGB: {std_r}, {std_g}, {std_b}")

100%|██████████| 75750/75750 [05:05<00:00, 248.06it/s]

Mean RGB: 0.5449871888617703, 0.4434935563380693, 0.34361316599832514
Std RGB: 0.27093838406970966, 0.2734508551865403, 0.2780531622290323





In [5]:
len(indices_to_drop)

3

In [6]:
# Pack the per-channel statistics into length-3 arrays in (R, G, B) order;
# `transform` broadcasts them over the last axis of each (224, 224, 3) image.
food101_mean = np.array([mean_r, mean_g, mean_b])
food101_std = np.array([std_r, std_g, std_b])

In [7]:
def transform(example_batch):
    """
    Resize the images of a batch to 224 * 224 and normalise them.

    Each image is resized, scaled to [0, 1], standardised with the
    module-level `food101_mean` / `food101_std` and returned as a float
    tensor with the channel axis first, i.e. shape (3, 224, 224).

    Images whose resized array is not (224, 224, 3) are skipped together with
    their label, so the returned lists can be shorter than the input batch.

    example_batch: dict with an 'image' list (objects exposing `.resize`) and
        a parallel 'label' list.
    return: dict with 'pixel_values' (list of tensors) and 'label' (list).
    """
    pixel_values = []
    labels = []
    for image, label in zip(example_batch['image'], example_batch['label']):
        # Resize and convert only once per image; the shape check and the
        # normalisation both work on this same array.
        resized = np.array(image.resize((224, 224)))
        if resized.shape != (224, 224, 3):
            continue

        normalised = (resized / 255. - food101_mean) / food101_std
        # (H, W, C) -> (C, H, W)
        pixel_values.append(torch.tensor(normalised, dtype = torch.float).permute(2, 0, 1))
        labels.append(label)

    inputs = {}
    inputs['pixel_values'] = pixel_values
    inputs['label'] = labels
    return inputs

In [8]:
np.array(food101['train'][0]['image'].resize((224, 224))).shape

(224, 224, 3)

In [9]:
# Register `transform` on the dataset so examples are resized / normalised
# when they are accessed, instead of materialising a processed copy.
processed_food101 = food101.with_transform(transform)

In [10]:
# Training split of the dataset that has `transform` attached.
train_dataset = processed_food101['train']

In [11]:
# Keep every training index except those recorded in `indices_to_drop`
# (images whose resized array was not (224, 224, 3)).
indices_to_choose = list(set(range(len(train_dataset))) - set(indices_to_drop))
filtered_train_dataset = train_dataset.select(indices_to_choose)

In [12]:
# Validation split of the dataset that has `transform` attached.
validation_dataset = processed_food101['validation']

In [13]:
# `indices_to_drop` holds positions in the *training* split, so subtracting it
# here would remove unrelated validation images and keep the malformed ones.
# Scan the validation split with the same shape check used for training.
valid_indices_to_drop = [
    i
    for i in tqdm(range(len(food101['validation'])))
    if np.array(food101['validation'][i]['image'].resize((224, 224))).shape != (224, 224, 3)
]
# Sorted so the evaluation order is deterministic.
indices_to_choose = sorted(set(range(len(validation_dataset))) - set(valid_indices_to_drop))
filtered_valid_dataset = validation_dataset.select(indices_to_choose)

## Masked Autoencoder

Positional Embedding

In [14]:
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine / cosine embedding of a set of scalar positions.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to encode; it is flattened to size (M,)
    out: (M, D) array whose first D/2 columns are sines and last D/2 cosines
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # Exponents 0, 2/D, 4/D, ... and the matching frequencies 1 / 10000^exponent.
    exponents = np.arange(half_dim, dtype = float) / (embed_dim / 2.)
    frequencies = 1. / 10000 ** exponents # (D/2,)

    flat_pos = pos.reshape(-1) # (M,)
    # Outer product of positions and frequencies via broadcasting.
    angles = flat_pos[:, None] * frequencies[None, :] # (M, D/2)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis = 1) # (M, D)

In [15]:
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Sine / cosine embedding of 2D grid coordinates.

    embed_dim: output dimension for each position (must be even)
    grid: indexable with two coordinate arrays, grid[0] and grid[1]
    out: (M, D) array; the first D/2 columns encode grid[0] and the last D/2
         columns encode grid[1], where M is the number of grid positions
    """
    assert embed_dim % 2 == 0

    per_axis_dim = embed_dim // 2
    # One 1D embedding per coordinate plane, each of shape (M, D/2).
    axis_embeddings = [
        get_1d_sincos_pos_embed_from_grid(per_axis_dim, axis_coords)
        for axis_coords in (grid[0], grid[1])
    ]
    return np.concatenate(axis_embeddings, axis = 1)

In [16]:
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token = False):
    """
    Fixed 2D sine / cosine positional embedding for a square grid.

    embed_dim: embedding dimension per position
    grid_size: int of the grid height and width
    cls_token: if True, an all-zero row is prepended for the class token
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    coords = np.arange(grid_size, dtype = np.float32)
    # Both axes share the same coordinates because the grid is square.
    coord_planes = np.stack(np.meshgrid(coords, coords), axis = 0)
    coord_planes = coord_planes.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, coord_planes)
    if not cls_token:
        return pos_embed
    return np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis = 0)

Model Structure

In [17]:
class MaskedAutoencoderViT(nn.Module):
    """
    Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size = 224, patch_size = 16, in_chans = 3,\
                 embed_dim = 1024, depth = 24, num_heads = 16,\
                decoder_embed_dim = 512, decoder_depth = 8, decoder_num_heads = 16,\
                mlp_ratio = 4., norm_layer = nn.LayerNorm, norm_pix_loss = False):
        super().__init__()
        
        # ---------------
        # MAE encoder specifics
        self.in_chans = in_chans
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad = False) # fixed sin-cos embedding
        
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias = True, norm_layer = norm_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
        # ---------------
        
        # ---------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias = True)
        
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad = False)
        
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias = True, norm_layer = norm_layer)
            for i in range(decoder_depth)
        ])
        
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias = True) # decoder to patch
        # ---------------
        
        self.norm_pix_loss = norm_pix_loss
        
        self.initialize_weights()
        
    def initialize_weights(self):
        # initilization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token = True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        
        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token = True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
        
        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        
        # timm's trunc_normal_(std=.02) is effectively normal_(std=.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        
        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)
        
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            
    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 * 3)
        """
        p = self.patch_embed.patch_size[0]
        
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape = (imgs.shape[0], self.in_chans, h, p, w, p))
        x = torch.einsum('nchpwq -> nhpwqc', x)
        x = x.reshape(shape = (imgs.shape[0], h * w, p**2 * self.in_chans))
        return x
    
    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 * 3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        
        x = x.reshape(shape = (x.shape[0], h, w, p, p, self.in_chans))
        x = torch.einsum('nhwpqc -> nchpwq', x)
        imgs = x.reshape(shape = (x.shape[0], self.in_chans, h * p, h * p))
        return imgs
    
    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        
        N, L, D = x.shape # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        
        noise = torch.rand(N, L, device = x.device) # noise in [0, 1]
        
        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim = 1) # ascend: small is kept, large is removed
        ids_restore = torch.argsort(ids_shuffle, dim = 1)
        
        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim = 1, index = ids_keep.unsqueeze(-1).repeat(1, 1, D))
        
        # generate the binary mask: 0 is kept, 1 is removed
        mask = torch.ones([N, L], device = x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim = 1, index = ids_restore)
        
        return x_masked, mask, ids_restore
    
    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)
        
        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]
        
        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)
        
        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim = 1)
        
        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
            
        x = self.norm(x)
        
        return x, mask, ids_restore
        
    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)
        
        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim = 1) # no cls token
        x_ = torch.gather(x_, dim = 1, index = ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim = 1) # append cls token
        
        # add pos embed
        x = x + self.decoder_pos_embed
        
        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        
        # predictor projection
        x = self.decoder_pred(x)
        
        # remove cls token
        x = x[:, 1:, :]
        
        return x
    
    def forward_loss(self, imgs, pred, mask, batch_mean):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is kept, 1 is removed
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim = -1, keepdim = True)
            var = target.var(dim = -1, keepdim = True)
            target = (target - mean) / (var + 1.e-6) ** .5
        
        loss = (pred - target) ** 2
        loss = loss.mean(dim = -1) # [N, L], mean loss per patch
        if batch_mean:
            loss = (loss * mask).sum() / mask.sum()
        else:
            loss = (loss * mask).sum(dim = 1) / mask.sum(dim = 1)
        return loss
    
    def forward_feature(self, imgs):
        latent, _, _ = self.forward_encoder(imgs)
        return latent[:, 0]
    
    def forward(self, imgs, mask_ratio = 0.75, batch_mean = True):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
        loss = self.forward_loss(imgs, pred, mask, batch_mean)
        return loss, self.unpatchify(pred), mask

Model Build-up

In [18]:
def mae_vit_base_patch16_dec512d8b(**kwargs):
    """
    Build a MaskedAutoencoderViT with patch size 16, a 768-dim / 12-block /
    12-head encoder and a 512-dim / 8-block / 16-head decoder.

    Extra keyword arguments are forwarded to the constructor.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs)


def mae_vit_large_patch16_dec512d8b(**kwargs):
    """
    Build a MaskedAutoencoderViT with patch size 16, a 1024-dim / 24-block /
    16-head encoder and a 512-dim / 8-block / 16-head decoder.

    Extra keyword arguments are forwarded to the constructor.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs)


def mae_vit_huge_patch14_dec512d8b(**kwargs):
    """
    Build a MaskedAutoencoderViT with patch size 14, a 1280-dim / 32-block /
    16-head encoder and a 512-dim / 8-block / 16-head decoder.

    Extra keyword arguments are forwarded to the constructor.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs)

In [19]:
# set recommended archs
# Short aliases for the factory functions; each alias is the same callable
# as the longer name that spells out the decoder configuration.
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks

## Fine-tune

In [20]:
from torch.optim import Adam, AdamW
from torch.utils.data import DataLoader
import torchvision
import numpy as np
from tqdm import tqdm
import torchvision.transforms as T
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.transforms import ToTensor, Compose, Normalize

In [21]:
class Finetune_cls(nn.Module):
    """
    Classification head on top of an encoder.

    model: object exposing `forward_encoder(x, mask_ratio)` whose first
        return value is a token tensor with the class token at index 0
    d_hidden: width of the class-token feature fed to the linear head
    label_classes: number of output classes
    """

    def __init__(self, model, d_hidden = 1024, label_classes = 101):
        super().__init__()
        self.model = model
        self.proj_label_classes = nn.Linear(d_hidden, label_classes)

    def forward(self, x):
        """Return class logits computed from the class token of `x`."""
        # Mask ratio 0: the encoder keeps every patch.
        encoder_out = self.model.forward_encoder(x, 0)
        tokens = encoder_out[0]
        cls_state = tokens[:, 0, :]
        logits = self.proj_label_classes(cls_state)
        return logits

In [22]:
def finetune_epoch(clf: Finetune_cls,
                   train_dataset,
                   batch_size=128,
                   lr=5e-5,
                   device="cuda:0"):
    """
    Train `clf` for one pass over `train_dataset`.

    clf: classifier to train; only parameters with `requires_grad` are optimised
    train_dataset: dataset whose batches expose 'pixel_values' and 'label'
    batch_size: mini-batch size (the last incomplete batch is dropped)
    lr: AdamW learning rate
    device: device the batches are moved to; `clf` must already live there
    return: list with the loss value of every training step
    """
    clf.train()
    loader = DataLoader(train_dataset, batch_size, drop_last=True, shuffle = True)

    params_to_update = [param for param in clf.parameters() if param.requires_grad]

    # NOTE(review): the optimizer is rebuilt on every call, so AdamW's running
    # statistics start from scratch each epoch.
    optimizer = torch.optim.AdamW(params_to_update, lr=lr)
    criterion = nn.CrossEntropyLoss().to(device)

    loss_list = []
    for batch in tqdm(loader, desc="Train"):
        # Train loop.
        optimizer.zero_grad()
        cls_logit = clf(batch['pixel_values'].to(device))

        # reshape(-1) keeps a 1-D target even for a batch of one,
        # where squeeze() would collapse it to a scalar.
        labels = batch['label'].to(device).reshape(-1)
        loss = criterion(cls_logit, labels)

        loss.backward()

        # Have gradients at this point.
        nn.utils.clip_grad_norm_(clf.parameters(), max_norm=5.0, norm_type=2)
        optimizer.step()

        # Record the step loss; previously nothing was appended, so the
        # function always returned an empty list.
        loss_list.append(loss.item())

    return loss_list

In [23]:
@torch.no_grad()
def evaluate(clf: Finetune_cls, eval_dataset, batch_size, device=None):
    """
    Compute and print the classification accuracy of `clf` on `eval_dataset`.

    clf: classifier to evaluate (switched to eval mode)
    eval_dataset: dataset whose batches expose 'pixel_values' and 'label'
    batch_size: evaluation batch size
    device: device the batches are moved to; defaults to the device of the
        classifier's parameters instead of relying on a global variable
    return: fraction of correctly classified samples (NaN if there are none)
    """
    clf.eval()
    if device is None:
        device = next(clf.parameters()).device

    # Keep the last incomplete batch so every sample is counted.
    loader = DataLoader(eval_dataset, batch_size=batch_size, drop_last=False)

    n_right_classes = 0
    n_total = 0

    for batch in tqdm(loader, desc="Eval"):
        # Compute accuracy.
        cls_logit = clf(batch['pixel_values'].to(device))
        labels = batch['label'].to(device).reshape(-1)

        pred = cls_logit.argmax(dim=1)

        # Tensor reduction instead of Python's builtin sum over elements.
        n_right_classes += (pred == labels).sum().item()
        n_total += pred.numel()

    accuracy = n_right_classes / n_total if n_total else float("nan")
    print("  Acc_cls:", accuracy)

    return accuracy

In [24]:
def finetune(clf: Finetune_cls, train_dataset, test_dataset, n_epochs: int = 1, model_name=None, model_path=None, **args):
    """
    Fine-tune `clf` for `n_epochs` epochs, evaluating after each one.

    clf: classifier to train
    train_dataset / test_dataset: datasets passed to `finetune_epoch` / `evaluate`
    n_epochs: number of training epochs
    model_name: if given, the whole classifier is saved after every epoch to
        `model_path + model_name + 'epoch_<k>.pt'`
    model_path: prefix for the checkpoint files (include a trailing path
        separator to save into a directory). When omitted, a module-level
        `model_path` is used if one exists, otherwise the empty string.
    **args: forwarded to `finetune_epoch` (e.g. batch_size, lr, device)
    return: (loss, acc) where `loss` concatenates the per-step losses of all
        epochs and `acc` holds one accuracy per epoch
    """
    # "device" is optional because finetune_epoch has its own default.
    print("Using device:", args.get("device", "cuda:0 (finetune_epoch default)"))

    if model_path is None:
        # Previously an undefined global was referenced unconditionally.
        model_path = globals().get("model_path", "")

    loss = []
    acc = []
    for epoch in range(n_epochs):
        print(f"Starting epoch {epoch+1}...")
        loss_list = finetune_epoch(clf, train_dataset, **args)
        loss += loss_list

        # Save a checkpoint of the model after every epoch
        if model_name is not None:
            torch.save(clf, model_path + model_name + 'epoch_' + str(epoch+1) + '.pt')

        acc_i = evaluate(clf, test_dataset, 32)
        acc.append(acc_i)

    return loss, acc

In [25]:
# Fine-tuning driver: load the pre-trained model, wrap it in the classifier
# and run EPOCHS epochs of training + evaluation.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
EPOCHS = 20
# NOTE(review): torch.load unpickles a complete model object from a hard-coded
# absolute path; only load checkpoints from a trusted source.
model = torch.load('/gpfs/home/ht2413/mae_pretrain/mae_pretrain_200.pt')
# Finetune_cls is built with its default d_hidden = 1024, so the loaded
# encoder is assumed to emit 1024-dim tokens — TODO confirm.
classifier = Finetune_cls(model.to(device))
classifier.to(device)
# NOTE(review): this optimizer is never passed to finetune(); finetune_epoch
# creates its own AdamW from `lr` only, so the betas / weight_decay configured
# here have no effect on training.
optimizer = AdamW(model.parameters(), lr = 5e-5, betas=(0.9, 0.95), weight_decay=0.05)

# finetune returns a (loss, acc) tuple, so `loss` holds both lists.
loss = finetune(classifier, filtered_train_dataset, filtered_valid_dataset, EPOCHS, batch_size=64, device=device)

Using device: cuda
Starting epoch 1...


Train: 100%|██████████| 1183/1183 [39:10<00:00,  1.99s/it]
Eval: 100%|██████████| 789/789 [06:22<00:00,  2.06it/s]


  Acc_cls: 0.25871356147021546
Starting epoch 2...


Train: 100%|██████████| 1183/1183 [39:01<00:00,  1.98s/it]
Eval: 100%|██████████| 789/789 [06:23<00:00,  2.06it/s]


  Acc_cls: 0.3986850443599493
Starting epoch 3...


Train: 100%|██████████| 1183/1183 [39:01<00:00,  1.98s/it]
Eval: 100%|██████████| 789/789 [06:21<00:00,  2.07it/s]


  Acc_cls: 0.4858206590621039
Starting epoch 4...


Train: 100%|██████████| 1183/1183 [39:01<00:00,  1.98s/it]
Eval: 100%|██████████| 789/789 [06:22<00:00,  2.06it/s]


  Acc_cls: 0.5266951837769328
Starting epoch 5...


Train: 100%|██████████| 1183/1183 [39:01<00:00,  1.98s/it]
Eval: 100%|██████████| 789/789 [06:17<00:00,  2.09it/s]


  Acc_cls: 0.5366761723700887
Starting epoch 6...


Train: 100%|██████████| 1183/1183 [39:02<00:00,  1.98s/it]
Eval: 100%|██████████| 789/789 [06:20<00:00,  2.07it/s]


  Acc_cls: 0.5430529150823827
Starting epoch 7...


Train: 100%|██████████| 1183/1183 [39:02<00:00,  1.98s/it]
Eval: 100%|██████████| 789/789 [06:21<00:00,  2.07it/s]


  Acc_cls: 0.5311311787072244
Starting epoch 8...


Train: 100%|██████████| 1183/1183 [39:01<00:00,  1.98s/it]
Eval: 100%|██████████| 789/789 [06:18<00:00,  2.09it/s]


  Acc_cls: 0.5226552598225602
Starting epoch 9...


Train: 100%|██████████| 1183/1183 [39:02<00:00,  1.98s/it]
Eval: 100%|██████████| 789/789 [06:21<00:00,  2.07it/s]


  Acc_cls: 0.5160408745247148
Starting epoch 10...


Train: 100%|██████████| 1183/1183 [39:05<00:00,  1.98s/it]
Eval: 100%|██████████| 789/789 [06:20<00:00,  2.07it/s]


  Acc_cls: 0.5062975285171103
Starting epoch 11...


Train: 100%|██████████| 1183/1183 [39:05<00:00,  1.98s/it]
Eval: 100%|██████████| 789/789 [06:20<00:00,  2.07it/s]


  Acc_cls: 0.5085947401774398
Starting epoch 12...


Train: 100%|██████████| 1183/1183 [39:08<00:00,  1.99s/it]
Eval: 100%|██████████| 789/789 [06:22<00:00,  2.06it/s]


  Acc_cls: 0.5020199619771863
Starting epoch 13...


Train: 100%|██████████| 1183/1183 [39:20<00:00,  2.00s/it]
Eval: 100%|██████████| 789/789 [06:21<00:00,  2.07it/s]


  Acc_cls: 0.5023764258555133
Starting epoch 14...


Train: 100%|██████████| 1183/1183 [39:03<00:00,  1.98s/it]
Eval: 100%|██████████| 789/789 [06:20<00:00,  2.07it/s]


  Acc_cls: 0.507802598225602
Starting epoch 15...


Train: 100%|██████████| 1183/1183 [39:04<00:00,  1.98s/it]
Eval: 100%|██████████| 789/789 [06:21<00:00,  2.07it/s]


  Acc_cls: 0.506020278833967
Starting epoch 16...


Train: 100%|██████████| 1183/1183 [39:03<00:00,  1.98s/it]
Eval: 100%|██████████| 789/789 [06:19<00:00,  2.08it/s]


  Acc_cls: 0.5123970215462611
Starting epoch 17...


Train: 100%|██████████| 1183/1183 [39:02<00:00,  1.98s/it]
Eval: 100%|██████████| 789/789 [06:20<00:00,  2.07it/s]


  Acc_cls: 0.5136644486692015
Starting epoch 18...


Train: 100%|██████████| 1183/1183 [39:08<00:00,  1.99s/it]
Eval: 100%|██████████| 789/789 [06:23<00:00,  2.06it/s]


  Acc_cls: 0.5079610266159695
Starting epoch 19...


Train: 100%|██████████| 1183/1183 [39:06<00:00,  1.98s/it]
Eval: 100%|██████████| 789/789 [06:19<00:00,  2.08it/s]


  Acc_cls: 0.5156448035487959
Starting epoch 20...


Train: 100%|██████████| 1183/1183 [39:03<00:00,  1.98s/it]
Eval: 100%|██████████| 789/789 [06:20<00:00,  2.07it/s]

  Acc_cls: 0.5101790240811154





In [26]:
# `loss` is the (loss, acc) tuple returned by finetune; index 1 is the list
# of per-epoch validation accuracies.
loss[1]

[0.25871356147021546,
 0.3986850443599493,
 0.4858206590621039,
 0.5266951837769328,
 0.5366761723700887,
 0.5430529150823827,
 0.5311311787072244,
 0.5226552598225602,
 0.5160408745247148,
 0.5062975285171103,
 0.5085947401774398,
 0.5020199619771863,
 0.5023764258555133,
 0.507802598225602,
 0.506020278833967,
 0.5123970215462611,
 0.5136644486692015,
 0.5079610266159695,
 0.5156448035487959,
 0.5101790240811154]