In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from fastai.vision.all import *
from self_supervised.layers import *
import sklearn

In [3]:
from utils.custom_vit import *
from utils.attention import *
from utils.object_crops import *
from utils.part_crops import *
from utils.multi_crop_model import *

from fastai.callback.wandb import WandbCallback
import wandb

In [4]:
# Root folder of the FGVC-Aircraft data, given relative to the notebook's working directory
datapath = Path("../data/fgvc-aircraft/")

In [5]:
# One annotation table per split; the displayed head shows the columns `filename`, `Classes` and `Labels`
train_df = pd.read_csv(datapath/'train.csv')
valid_df = pd.read_csv(datapath/'val.csv')
test_df = pd.read_csv(datapath/'test.csv')

In [6]:
# (rows, columns) of each table -- note the order is train, test, valid
train_df.shape, test_df.shape, valid_df.shape

((3334, 3), (3333, 3), (3333, 3))

In [7]:
# Peek at the first rows of the training annotations
train_df.head()

Unnamed: 0,filename,Classes,Labels
0,1025794.jpg,707-320,0
1,1340192.jpg,707-320,0
2,0056978.jpg,707-320,0
3,0698580.jpg,707-320,0
4,0450014.jpg,707-320,0


In [8]:
# File names present in the raw image folder (same folder `read_image` reads from)
(datapath/'fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images').ls().map(lambda o: o.name)

(#10000) ['0847415.jpg','1053510.jpg','0174925.jpg','1843606.jpg','1446516.jpg','0690072.jpg','1231579.jpg','1164220.jpg','2217846.jpg','0323097.jpg'...]

### Dataset

In [9]:
def read_image(filename):
    "Open the image called `filename` from the FGVC-Aircraft image folder with `PILImage.create`."
    images_dir = datapath/'fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images'
    return PILImage.create(images_dir/filename)
def read_image_size(filename):
    "Open the image called `filename` from the FGVC-Aircraft image folder and return its `.shape`."
    img = PILImage.create(datapath/'fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images'/filename)
    return img.shape

In [10]:
# Filename column of every split as a numpy array, in train / valid / test order
train_fns, valid_fns, test_fns = (split['filename'].values for split in (train_df, valid_df, test_df))
# One flat pool of all filenames; the split into train/valid subsets happens later
filenames = np.concatenate((train_fns, valid_fns, test_fns))

In [11]:
# filename -> class name over all three splits; on a duplicate filename the later split wins
# (train, then valid, then test)
fn2label = {fn: cls
            for split in (train_df, valid_df, test_df)
            for fn, cls in zip(split['filename'], split['Classes'])}

In [12]:
# Sanity check: equal counts mean no filename appears in more than one split
len(filenames), len(fn2label)

(10000, 10000)

In [13]:
def read_label(filename):
    "Look up the class name of `filename` in `fn2label` (raises `KeyError` for unknown files)."
    label = fn2label[filename]
    return label

In [14]:
# img_sizes = parallel(read_image_size, filenames)

In [15]:
# Counter(img_sizes)

In [16]:
# valid_filenames = sample_df.query("split == 'valid'")['filename'].values

In [17]:
# filenames = np.random.choice(filenames, 1000)

In [18]:
# Item resolution and batch size. 786 is also the `high_res` value handed to MultiCropViT below.
size,bs = 786,16

# x pipeline: filename -> image -> tensor -> random resized crop of side `size`
# y pipeline: filename -> class name -> category index
tfms = [[read_image, ToTensor, RandomResizedCrop(size, min_scale=.75)],
        [read_label, Categorize()]]

# The files listed in test.csv form the validation subset; train.csv and val.csv files are both
# trained on. A set gives constant-time membership instead of scanning the numpy array per item.
test_fn_set = set(test_fns)
def valid_splitter(o):
    "True when filename `o` belongs to the test split."
    return o in test_fn_set
dsets = Datasets(filenames, tfms=tfms, splits=FuncSplitter(valid_splitter)(filenames))

# Batch-level augmentation followed by normalisation with the ImageNet statistics
batch_augs = aug_transforms()

stats = imagenet_stats

batch_tfms = [IntToFloatTensor] + batch_augs + [Normalize.from_stats(*stats)]
dls = dsets.dataloaders(bs=bs, after_batch=batch_tfms)

In [19]:
# Number of items in the training and validation subsets after the split
len(dls.train_ds), len(dls.valid_ds)

(6667, 3333)

In [20]:
# `c` attribute of the DataLoaders -- presumably the number of target classes; the recorded output is 100
dls.c

100

### Model

In [20]:
def interpolate_bil(x,sz):
    "Resize the last two dims of `x` to `sz` x `sz` with bilinear interpolation (`align_corners=True`)."
    target_size = (sz, sz)
    return F.interpolate(x, size=target_size, mode='bilinear', align_corners=True)

def apply_attn_erasing(x, attn_maps, thresh, p=0.5): 
    """Zero out high-attention pixels of randomly chosen images in a batch.

    x: bs x c x h x w, attn_maps: bs x h x w

    Each image is selected independently with probability `p`; in a selected
    image every position whose attention value exceeds `thresh` is set to zero
    across all channels. Unselected images are returned unchanged in value.
    """
    # bs x 1 x h x w boolean mask of the positions that are candidates for erasing
    above_thresh = (attn_maps > thresh).unsqueeze(1)
    n_items = above_thresh.size(0)
    # one Bernoulli(p) draw per image, sampled on CPU and then moved next to the mask
    erase_flags = torch.zeros(n_items).float().bernoulli(p).to(above_thresh.device)
    # 0 where a selected image is above the threshold, 1 everywhere else
    keep_mask = 1 - above_thresh * erase_flags.view(-1, 1, 1, 1)
    return keep_mask * x

class ViTEncoder(Module):
    """Timm ViT encoder which returns encoder outputs and optionally attention weights, with gradient checkpointing.

    vit: model providing `patch_embed`, `cls_token`, `pos_embed`, `pos_drop` and `blocks`;
         each block is expected to return an `(output, attention_weights)` pair.
    nblocks: only the first `nblocks` blocks of `vit` are kept.
    checkpoint_nchunks: blocks whose index is below this value are run through `checkpoint`.
    return_attn_wgts: when True the forward pass returns `(tokens, [attn_wgt per block])`,
                      otherwise only `tokens`.
    """
    def __init__(self, vit, nblocks=12, checkpoint_nchunks=2, return_attn_wgts=True):
        # embedding pieces are shared with (not copied from) the wrapped ViT
        self.patch_embed, self.cls_token = vit.patch_embed, vit.cls_token
        self.pos_embed, self.pos_drop = vit.pos_embed, vit.pos_drop

        # transformer blocks, truncated to the requested depth
        self.blocks = vit.blocks[:nblocks]

        # how many leading blocks go through gradient checkpointing
        self.checkpoint_nchunks = checkpoint_nchunks

        # whether per-block attention weights are handed back to the caller
        self.return_attn_wgts = return_attn_wgts

    def forward_features(self, x):
        n_batch = x.shape[0]
        tokens = self.patch_embed(x)

        # prepend one class token per item, then add position embeddings
        cls_tokens = self.cls_token.expand(n_batch, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)

        # `collected` stays None when attention weights are not requested
        collected = [] if self.return_attn_wgts else None
        for idx, blk in enumerate(self.blocks):
            if idx < self.checkpoint_nchunks: tokens, attn_wgt = checkpoint(blk, tokens)
            else:                             tokens, attn_wgt = blk(tokens)
            if collected is not None: collected.append(attn_wgt)

        if collected is None: return tokens
        return tokens, collected

    def forward(self, x):
        return self.forward_features(x)
    
    
class MultiCropViT(Module):
    """Multi Scale Multi Crop ViT Model

    One shared `ViTEncoder` is applied to up to four views of every image: the
    down-scaled full image, an attention-guided object crop, and two part crops
    taken from the object crop. Each view is classified from the first token of
    the encoder output by the same LayerNorm + classifier head.

    encoder: ViT handed to `ViTEncoder`.
    input_res: side length every view is resized to before it is encoded.
    high_res: side length the attention maps are up-sampled to for object cropping;
              assumed to equal the side length of the incoming batch -- TODO confirm.
    min_obj_area: passed as `min_area` to `generate_attention_coordinates`.
    crop_sz: passed as `targ_bbox_sz` to `generate_batch_crops`.
    crop_object: also encode and classify an object crop.
    crop_object_parts: also encode and classify two part crops. Parts are cut out of
              the object crop, so this flag only has an effect when `crop_object` is True.
    do_attn_erasing / p_attn_erasing / attn_erasing_thresh: in training mode, apply
              `apply_attn_erasing` to the full image and the object crop.
    encoder_nblocks, checkpoint_nchunks: passed to `ViTEncoder`.
    n_classes: number of classifier outputs; the default of 120 keeps the previous head size.

    `forward` returns the full-image logits alone, or a tuple
    `(full, object)` / `(full, object, crop1, crop2)` depending on the crop flags.
    """
    def __init__(self, 
                 encoder, 
                 input_res=384, high_res=786, min_obj_area=112*112, crop_sz=224,
                 crop_object=True, crop_object_parts=True,
                 do_attn_erasing=True, p_attn_erasing=0.5, attn_erasing_thresh=0.7,
                 encoder_nblocks=12, checkpoint_nchunks=12, n_classes=120):
        
        store_attr()

        self.image_encoder = ViTEncoder(encoder, nblocks=encoder_nblocks, checkpoint_nchunks=checkpoint_nchunks)
        self.norm = partial(nn.LayerNorm, eps=1e-6)(768)        
        self.classifier = create_cls_module(768, n_classes, lin_ftrs=[768], use_bn=False, first_bn=False, ps=0.)

    def _attention_maps(self, xb_input_res):
        "Encode the full image once and turn the collected attention weights into detached maps."
        self.image_encoder.return_attn_wgts = True
        # always switch the flag back, even if the encoder raises, so later calls return features only
        try:     _, attn_wgts = self.image_encoder(xb_input_res)
        finally: self.image_encoder.return_attn_wgts = False
        return generate_batch_attention_maps(attn_wgts, None, mode=None).detach()

    def _crop_objects(self, xb_high_res, attn_maps_high_res):
        "Cut one bbox per image out of the high-res batch and its attention map; both resized to `input_res`."
        batch_object_bboxes = np.vstack([generate_attention_coordinates(attn_map, 
                                                                        num_bboxes=1,
                                                                        min_area=self.min_obj_area,
                                                                        random_crop_sz=self.input_res)
                                                for attn_map in to_np(attn_maps_high_res)])
        xb_objects, attn_maps_objects = [], []
        for i, (minr, minc, maxr, maxc) in enumerate(batch_object_bboxes):
            xb_objects.append(interpolate_bil(xb_high_res[i][:,minr:maxr,minc:maxc][None,...],self.input_res)[0])
            attn_maps_objects.append(interpolate_bil(attn_maps_high_res[i][minr:maxr,minc:maxc][None,None,...],self.input_res)[0][0])
        return torch.stack(xb_objects), torch.stack(attn_maps_objects)

    def _crop_parts(self, xb_objects, attn_maps_objects):
        "Cut two part bboxes per image out of the object crops; both resized to `input_res`."
        small_attn_maps_objects = interpolate_bil(attn_maps_objects[None,],self.input_res//3)[0] # to speed up calculation
        batch_crop_bboxes = generate_batch_crops(small_attn_maps_objects.cpu(),
                                                 source_sz=self.input_res//3, 
                                                 targ_sz=self.input_res, 
                                                 targ_bbox_sz=self.crop_sz,
                                                 num_bboxes=2,
                                                 nms_thresh=0.1)
        xb_crops1,xb_crops2 = [],[]
        for i, crop_bboxes in enumerate(batch_crop_bboxes):
            minr, minc, maxr, maxc = crop_bboxes[0]
            xb_crops1.append(interpolate_bil(xb_objects[i][:,minr:maxr,minc:maxc][None,...],self.input_res)[0])
            minr, minc, maxr, maxc = crop_bboxes[1]
            xb_crops2.append(interpolate_bil(xb_objects[i][:,minr:maxr,minc:maxc][None,...],self.input_res)[0])
        return torch.stack(xb_crops1), torch.stack(xb_crops2)
        
    def forward(self, xb_high_res):

        # get full image attention weights / attention maps (bs x h x w, detached)
        xb_input_res = F.interpolate(xb_high_res, size=(self.input_res,self.input_res))
        attn_maps = self._attention_maps(xb_input_res)

        erase = self.training and self.do_attn_erasing
        # part crops are cut from the object crop, so they need both flags
        with_parts = self.crop_object and self.crop_object_parts

        #### ORIGINAL IMAGE ####
        # original image attention erasing and features
        if erase:
            # the batch dim is passed as the channel dim so all maps are resized in one call
            attn_maps_input_res = interpolate_bil(attn_maps[None,...],self.input_res)[0]
            xb_input_res = apply_attn_erasing(xb_input_res, attn_maps_input_res, self.attn_erasing_thresh, self.p_attn_erasing)
        x_full = self.image_encoder(xb_input_res)

        #### OBJECT CROP ####        
        if self.crop_object:
            attn_maps_high_res = interpolate_bil(attn_maps[None,...],self.high_res)[0]
            xb_objects, attn_maps_objects = self._crop_objects(xb_high_res, attn_maps_high_res)

            # object image attention erasing and features
            if erase:
                xb_objects = apply_attn_erasing(xb_objects, attn_maps_objects, self.attn_erasing_thresh, self.p_attn_erasing)
            x_object = self.image_encoder(xb_objects)

        #### OBJECT CROP PARTS ####
        if with_parts:
            # parts are taken from the (possibly erased) object crops
            xb_crops1, xb_crops2 = self._crop_parts(xb_objects, attn_maps_objects)

            # crop features
            x_crops1 = self.image_encoder(xb_crops1)
            x_crops2 = self.image_encoder(xb_crops2)

        # predict from the first token of every view
        x_full = self.norm(x_full)[:,0]
        if self.crop_object:
            x_object = self.norm(x_object)[:,0]
            if with_parts:
                x_crops1 = self.norm(x_crops1)[:,0]
                x_crops2 = self.norm(x_crops2)[:,0]
                return self.classifier(x_full), self.classifier(x_object), self.classifier(x_crops1), self.classifier(x_crops2)
            return self.classifier(x_full), self.classifier(x_object)
        return  self.classifier(x_full)

### Training

In [21]:
def model_splitter(m):
    "Parameter groups of `m` in the order: image encoder, final norm, classifier head."
    groups = L(m.image_encoder, m.norm, m.classifier)
    return groups.map(params)

In [22]:
class LossFuncA(Module): # only object
    "Label-smoothed cross entropy (eps 0.1) on the second prediction, `preds[1]`, only."
    def __init__(self):
        self.lf = LabelSmoothingCrossEntropyFlat(0.1)

    def forward(self, preds, targs):
        object_pred = preds[1]
        return self.lf(object_pred, targs)
    
class LossFuncB(Module): # full + object
    "Sum of label-smoothed cross entropies (eps 0.1) on `preds[0]` and `preds[1]`."
    def __init__(self):
        self.lf = LabelSmoothingCrossEntropyFlat(0.1)

    def forward(self, preds, targs):
        full_loss = self.lf(preds[0], targs)
        object_loss = self.lf(preds[1], targs)
        return full_loss + object_loss
    
class LossFuncC(Module): # full + object + crops
    "Label-smoothed cross entropy (eps 0.1) on `preds[0]` and `preds[1]`, plus the mean of the losses on `preds[2]` and `preds[3]`."
    def __init__(self):
        self.lf = LabelSmoothingCrossEntropyFlat(0.1)

    def forward(self, preds, targs):
        full_loss = self.lf(preds[0], targs)
        object_loss = self.lf(preds[1], targs)
        # the two part crops share one unit of weight
        crops_loss = (self.lf(preds[2], targs) + self.lf(preds[3], targs)) / 2
        return full_loss + object_loss + crops_loss

In [23]:
def accuracyA(preds, targs):
    "Accuracy of the first prediction, `preds[0]` (full image)."
    full_pred = preds[0]
    return accuracy(full_pred, targs)
def accuracyB(preds, targs):
    "Accuracy of the second prediction, `preds[1]` (object crop)."
    object_pred = preds[1]
    return accuracy(object_pred, targs)
def accuracyC(preds, targs):
    "Accuracy of the averaged third and fourth predictions, `preds[2]` and `preds[3]` (the two part crops)."
    crops_mean = (preds[2] + preds[3]) / 2
    return accuracy(crops_mean, targs)

In [24]:
import gc

In [None]:
# Toggle Weights & Biases logging for every run started by this cell
WANDB = True

# Plain label-smoothed loss for the single-output (full image only) model
plain_ce = partial(LabelSmoothingCrossEntropyFlat, 0.1)

# experiment id -> (crop_object, crop_object_parts, do_attn_erasing, loss factory, metrics)
# Odd ids repeat the preceding even id with attention erasing switched on.
experiments = {
    0: (False, False, False, plain_ce,  [accuracy]),                         # exp 1 - full image
    1: (False, False, True,  plain_ce,  [accuracy]),                         # exp 2 - full image, attn erasing
    2: (True,  False, False, LossFuncA, [accuracyB]),                        # exp 3 - object
    3: (True,  False, True,  LossFuncA, [accuracyB]),                        # exp 4 - object, attn erasing
    4: (True,  False, False, LossFuncB, [accuracyA, accuracyB]),             # exp 5 - full image + object
    5: (True,  False, True,  LossFuncB, [accuracyA, accuracyB]),             # exp 6 - full image + object, attn erasing
    6: (True,  True,  False, LossFuncC, [accuracyA, accuracyB, accuracyC]),  # exp 7 - full image + object + crops
    7: (True,  True,  True,  LossFuncC, [accuracyA, accuracyB, accuracyC]),  # exp 8 - full image + object + crops, attn erasing
}

for i in [0,4,6]:

    # an unknown id raises KeyError here instead of silently reusing the previous run's settings
    crop_object, crop_object_parts, do_attn_erasing, make_loss, exp_metrics = experiments[i]
    model_config = dict(crop_object=crop_object, crop_object_parts=crop_object_parts,
                        do_attn_erasing=do_attn_erasing, p_attn_erasing=0.5, attn_erasing_thresh=0.7)
    loss_func = make_loss()      # fresh loss object per run
    metrics = list(exp_metrics)  # copy so the table entry is never shared with a Learner

    # modified timm vit encoder: pretrained weights are copied into the custom VisionTransformer
    arch = "vit_base_patch16_384"
    _encoder = create_encoder(arch, pretrained=True, n_in=3)
    encoder = VisionTransformer(img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12)
    encoder.head = Identity()
    encoder.load_state_dict(_encoder.state_dict())
    del _encoder  # weights now live in `encoder`; drop the donor model

    # NOTE(review): the classifier head inside MultiCropViT is built with 120 outputs while
    # `dls.c` reports 100 for this dataset -- confirm whether the extra outputs are intended.
    mcvit_model = MultiCropViT(encoder, input_res=384, high_res=786, min_obj_area=128*128, crop_sz=224,
                                 encoder_nblocks=12, checkpoint_nchunks=12, **model_config)

    cbs = []
    if WANDB:
        # log a copy so the kwargs used to build the model are not mutated
        xtra_config = {**model_config, "Dataset":"FGVC-Aircraft"}
        wandb.init(project="fgvc-2021", config=xtra_config)
        cbs += [WandbCallback(log_preds=False,log_model=False)]

    learn = None
    try:
        learn = Learner(dls, mcvit_model, opt_func=ranger, cbs=cbs, metrics=metrics, loss_func=loss_func, splitter=model_splitter)
        learn.to_fp16()

        epochs = 4
        lr = 3e-3

        # stage 1: first parameter group (the image encoder, see `model_splitter`) stays frozen
        learn.freeze_to(1)
        learn.fit_one_cycle(epochs, lr_max=(lr), pct_start=0.5)

        # stage 2: train everything for epochs**2 epochs, with a 25x smaller peak lr on the encoder group
        lr /= 3
        learn.unfreeze()
        learn.fit_one_cycle(int(epochs**2), lr_max=[lr/25,lr,lr], pct_start=0.5)
    finally:
        # release model memory and close the wandb run even when training fails,
        # so the next experiment starts clean
        del learn, encoder, mcvit_model
        gc.collect()

        torch.cuda.empty_cache()
        if WANDB: wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkeremturgutlu[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.25 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


epoch,train_loss,valid_loss,accuracy,time
0,3.636154,3.459458,0.181518,02:35
1,3.115287,3.159331,0.252025,02:36
2,2.553414,2.631083,0.415242,02:36
3,2.190985,2.462992,0.463846,02:36




epoch,train_loss,valid_loss,accuracy,time
0,2.020633,2.376261,0.49715,04:41
1,1.903404,2.253631,0.543354,04:41
2,1.814684,2.149034,0.574857,04:42
3,1.70384,2.035504,0.605161,04:42
4,1.560426,1.946473,0.642664,04:42
5,1.441374,1.93294,0.666967,04:42
6,1.3142,1.808912,0.686169,04:42
7,1.197915,1.81819,0.690969,04:42
8,1.102419,1.779369,0.709871,04:42
9,1.045751,1.716211,0.728173,04:42


VBox(children=(Label(value=' 0.02MB of 0.02MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,20.0
train_loss,0.87068
raw_loss,0.87086
wd_0,0.01
sqr_mom_0,0.99
lr_0,0.0
mom_0,0.95
eps_0,0.0
beta_0,0.0
wd_1,0.01


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train_loss,█▇▆▆▅▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
raw_loss,█▇▆▅▅▅▃▃▄▃▂▃▂▃▃▂▃▃▂▂▁▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁
wd_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sqr_mom_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_0,▂▃▆▇█▆▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mom_0,█▆▃▂▁▃▅████▇▇▆▆▅▄▄▃▂▂▁▁▁▁▁▁▂▂▃▄▄▅▆▆▇▇███
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
beta_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wd_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁


[34m[1mwandb[0m: wandb version 0.10.25 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


epoch,train_loss,valid_loss,accuracy,time
0,3.694692,3.513352,0.176118,02:37
1,3.254036,3.184973,0.230423,02:37
2,2.751527,2.627235,0.414341,02:37
3,2.34306,2.491151,0.461446,02:38


epoch,train_loss,valid_loss,accuracy,time
0,2.15637,2.416022,0.486349,04:42
1,2.083469,2.32056,0.509751,04:42
2,1.930871,2.169719,0.562856,04:42
3,1.8622,2.068355,0.59526,04:41
4,1.732558,1.945807,0.642664,04:41
5,1.614158,1.983082,0.624662,04:42
6,1.475708,1.803272,0.692469,04:42
7,1.354951,1.755215,0.69517,04:42
8,1.25389,1.818737,0.682568,04:43
9,1.156467,1.73438,0.708371,04:43


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,20.0
train_loss,0.92038
raw_loss,0.883
wd_0,0.01
sqr_mom_0,0.99
lr_0,0.0
mom_0,0.95
eps_0,0.0
beta_0,0.0
wd_1,0.01


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train_loss,█▇▆▆▅▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
raw_loss,█▇▅▅▄▄▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
wd_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sqr_mom_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_0,▂▃▆▇█▆▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mom_0,█▆▃▂▁▃▅████▇▇▆▆▅▄▄▃▂▂▁▁▁▁▁▁▂▂▃▄▄▅▆▆▇▇███
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
beta_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wd_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁


[34m[1mwandb[0m: wandb version 0.10.25 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


epoch,train_loss,valid_loss,accuracyA,accuracyB,time
0,7.221924,6.918906,0.188119,0.186619,09:57
1,6.209665,6.209541,0.254725,0.259826,09:55
2,4.996994,5.067193,0.432043,0.446445,09:55
3,4.222188,4.769153,0.483948,0.50195,09:55


epoch,train_loss,valid_loss,accuracyA,accuracyB,time
0,3.821311,4.632324,0.510651,0.522652,13:53
1,3.617161,4.369349,0.562856,0.563756,13:53
2,3.434632,4.129012,0.592859,0.615062,13:50
3,3.238214,3.882668,0.635464,0.651665,13:51
4,2.944451,3.723414,0.657666,0.685269,13:49
5,2.722154,3.713037,0.653165,0.673867,13:51
6,2.509734,3.491366,0.70267,0.724272,13:50
7,2.288327,3.581507,0.692169,0.707471,13:49
8,2.172228,3.365032,0.723972,0.751575,13:47
9,2.026994,3.265603,0.756676,0.774077,13:48


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,20.0
train_loss,1.7227
raw_loss,1.73072
wd_0,0.01
sqr_mom_0,0.99
lr_0,0.0
mom_0,0.95
eps_0,0.0
beta_0,0.0
wd_1,0.01


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train_loss,█▇▆▆▅▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
raw_loss,█▇▅▆▅▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wd_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sqr_mom_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_0,▂▃▆▇█▆▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mom_0,█▆▃▂▁▃▅████▇▇▆▆▅▄▄▃▂▂▁▁▁▁▁▁▂▂▃▄▄▅▆▆▇▇███
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
beta_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wd_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁


[34m[1mwandb[0m: wandb version 0.10.25 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


epoch,train_loss,valid_loss,accuracyA,accuracyB,time
0,7.316598,6.992489,0.174617,0.174917,09:54
1,6.321318,6.612463,0.209121,0.211221,09:55
2,5.397838,5.185391,0.417342,0.414641,09:54
3,4.453709,4.892651,0.457846,0.477648,09:54


epoch,train_loss,valid_loss,accuracyA,accuracyB,time
0,4.243174,4.756458,0.493549,0.49565,13:45
1,4.047066,4.449395,0.541854,0.555656,13:48
2,3.801645,4.205965,0.583858,0.60276,13:47
3,3.466686,4.174607,0.586259,0.59706,13:47
4,3.305728,3.721625,0.660066,0.683768,14:04
5,3.064926,3.586489,0.682568,0.708671,13:45
6,2.743096,3.646092,0.675067,0.690069,13:45
7,2.528816,3.320656,0.734173,0.759976,13:46
8,2.33376,3.381889,0.723972,0.749775,13:47


### fin