In [None]:
#|default_exp tmp

# Original Chris 94.9 on 5ep (not reproduced)

In [None]:
import pickle,gzip,math,os,time,shutil,torch,random
import fastcore.all as fc,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
from collections.abc import Mapping
from pathlib import Path
from operator import attrgetter,itemgetter
from functools import partial
from copy import copy
from contextlib import contextmanager

import torchvision.transforms.functional as TF,torch.nn.functional as F
from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
from torch.nn import init
from torch.optim import lr_scheduler
from torcheval.metrics import MulticlassAccuracy
from datasets import load_dataset,load_dataset_builder

from miniai.datasets import *
from miniai.conv import *
from miniai.learner import *
from miniai.activations import *
from miniai.init import *
from miniai.sgd import *
from miniai.resnet import *

In [None]:
# speed mods
from miniai.utils import *

MetricsCB = LazyMetricsCB
ProgressCB = LazyProgressCB

In [None]:
from fastcore.test import test_close
from torch import distributions

torch.set_printoptions(precision=8, linewidth=140, sci_mode=False)
mpl.rcParams['image.cmap'] = 'gray'

import logging
logging.disable(logging.WARNING)

if fc.defaults.cpus>8: fc.defaults.cpus=8

In [None]:
class Dropout(nn.Module):
    """Hand-written inverted dropout.

    While training, every element of the input is kept with probability
    `1-p` (Binomial draw with one trial) and the survivors are scaled by
    `1/(1-p)`. In eval mode the input is returned unchanged.

    Args:
        p: probability of zeroing an element; must lie in [0, 1].

    Raises:
        ValueError: if `p` is outside [0, 1].
    """
    def __init__(self, p=0.1):
        super().__init__()
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
        self.p = p

    def forward(self, x):
        if not self.training: return x
        # p == 1 drops every element; return zeros directly because the
        # rescaling factor 1/(1-p) would otherwise divide by zero.
        if self.p == 1: return torch.zeros_like(x)
        dist = distributions.binomial.Binomial(tensor(1.0).to(x.device), probs=1-self.p)
        return x * dist.sample(x.size()) * 1/(1-self.p)

In [None]:
#Mish - "Mish: A Self Regularized Non-Monotonic Neural Activation Function"
#https://arxiv.org/abs/1908.08681v1
#implemented for PyTorch / FastAI by lessw2020 
#github: https://github.com/lessw2020/mish

class Mish(nn.Module):
    """Mish activation, `x * tanh(softplus(x))`; the module has no parameters."""

    def forward(self, x):
        # Deliberately one expression: the original author measured that
        # avoiding a temporary saves about 1 second per epoch on a V100.
        return x.mul(F.softplus(x).tanh())

In [None]:
from typing import Iterator
from torch.utils.data import DataLoader, WeightedRandomSampler

class TopLossesCallback(Callback):
  """Rank the samples of a training epoch by loss and hand the ranking to the sampler.

  Predictions and targets of every training batch are collected; after the
  epoch the per-sample cross-entropy is computed and the positions, sorted
  from highest to lowest loss, are stored on
  `learn.dls.train.sampler.top_losses`.
  """
  # Class-level defaults only; `before_epoch` rebinds both on the instance.
  epoch_preds = []
  epoch_targets = []

  @torch.no_grad()
  def _calculate_top_losses(self):
    """Return positions within the epoch ordered by descending loss."""
    preds = torch.cat(self.epoch_preds, dim=0)
    targets = torch.cat(self.epoch_targets, dim=0)
    # `reduction='none'` is the supported spelling of the deprecated `reduce=False`.
    losses = F.cross_entropy(preds, targets, reduction='none')
    # topk over the full length is a descending argsort of the losses.
    return torch.topk(losses, preds.shape[0]).indices

  def after_batch(self, learn):
    if not learn.model.training:
      return
    # Detach so the stored predictions do not keep each batch's autograd
    # graph referenced for the rest of the epoch.
    self.epoch_preds.append(learn.preds.detach())
    self.epoch_targets.append(learn.batch[1])

  def before_epoch(self, learn):
    self.epoch_preds = []
    self.epoch_targets = []

  def after_epoch(self, learn):
    if not learn.model.training:
      return
    learn.dls.train.sampler.top_losses = self._calculate_top_losses()

# tweaked from tommyc's version
# Before certain epoch drop a % of the training dataset with the lowest losses.
# Replace them with the same % of the training dataset with the highest losses.
# This gives the model two opportunities to train on the most challenging images.
class CustomTrainingSampler(WeightedRandomSampler):
  """Sampler that re-shows the highest-loss samples in selected epochs.

  Each `__iter__` call counts as one epoch. `self.n` maps the epoch number to
  the fraction of samples to swap: that many positions with the lowest loss in
  the previous epoch are overwritten by the dataset indexes of the positions
  with the highest loss, and the result is shuffled. `self.top_losses` must
  hold the previous epoch's positions sorted by descending loss (it is
  written by `TopLossesCallback`).
  """
  def __init__(self, *args, **kwargs):
    WeightedRandomSampler.__init__(self, *args, **kwargs)
    self.data_indexes_for_epoch = []  # dataset indexes in the order yielded last epoch
    self.top_losses = []              # positions of last epoch, highest loss first
    self.epoch = -1
    # epoch -> fraction of the dataset to replace
    self.n = {
        0: 0,
        1: 0.21,
        2: 0.42,
        3: 0.21,
        4: 0
    }

  def __iter__(self) -> Iterator[int]:
    self.epoch += 1
    rand_tensor = torch.randperm(self.num_samples, generator=self.generator).tolist()
    # Epochs missing from the schedule replace nothing instead of raising KeyError.
    n = int(self.n.get(self.epoch, 0) * self.num_samples)

    if n != 0:
      top_losses = torch.as_tensor(self.top_losses).detach().cpu()
      # The n highest-loss positions, reversed.
      hardest = torch.flip(top_losses[:n], dims=(0,))
      indexes = torch.as_tensor(self.data_indexes_for_epoch).cpu()
      # Overwrite the n lowest-loss positions with the hardest samples' indexes.
      indexes[top_losses[-n:]] = indexes[hardest]
      self.data_indexes_for_epoch = indexes[rand_tensor].tolist()
    else:
      self.data_indexes_for_epoch = rand_tensor

    yield from self.data_indexes_for_epoch


class CustomDataLoader:
    """Holder for a train/valid `DataLoader` pair whose training loader uses `CustomTrainingSampler`."""
    def __init__(self, *dls):
        # Only the first two loaders are used.
        self.train,self.valid = dls[:2]

    @staticmethod
    def get_sampler(num_samples, mode="train"):
        """Return a uniformly weighted `CustomTrainingSampler` for "train", otherwise `None`.

        Declared static so it can be called on the class and on instances alike.
        """
        if mode != "train":
            return None
        return CustomTrainingSampler(weights=[1 for _ in range(num_samples)], num_samples=num_samples)

    @classmethod
    def from_dd(cls, dd, batch_size, as_tuple=True, **kwargs):
        """Build one `DataLoader` per entry of `dd`; the key is passed to `get_sampler` as the mode.

        `as_tuple` is accepted but not used. Extra keyword arguments are
        forwarded to `DataLoader`.
        """
        return cls(*[DataLoader(ds, batch_size, sampler=cls.get_sampler(len(ds), mode), collate_fn=collate_dict(ds), **kwargs) for mode, ds in dd.items()])

In [None]:
# Names of the input and label entries in each dataset batch.
xl,yl = 'image','label'
name = "fashion_mnist"
bs = 256
# Constants used by `transformi` to normalise the images.
xmean,xstd = 0.28, 0.35

# Convert every image in the batch to a tensor and normalise it with xmean/xstd.
@inplace
def transformi(b): b[xl] = [(TF.to_tensor(o)-xmean)/xstd for o in b[xl]]

dsd = load_dataset(name)
# The transform is attached with `with_transform`, not applied eagerly here.
tds = dsd.with_transform(transformi)


  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
tds.cached = cache_dataset_as_dict(tds)

In [None]:
def get_model9(act=nn.ReLU, nfs=(32,288,288,288,288,288), norm=nn.BatchNorm2d):#,256
    """Build the ResNet-style classifier used throughout this notebook.

    A stride-1 stem `ResBlock` maps the single input channel to `nfs[0]`
    channels; it is followed by one stride-2 `ResBlock` per consecutive pair in
    `nfs`, then `Flatten`, a bias-free `Linear` to 10 outputs and `BatchNorm1d`.

    Args:
        act: activation class passed to every `ResBlock`.
        nfs: channel widths; `nfs[0]` is the stem width.
        norm: normalisation layer class passed to every `ResBlock`.

    Returns:
        The `nn.Sequential` model, moved to `def_device`.
    """
    # The stem width follows nfs[0] (previously hard-coded to 32) so a custom
    # `nfs` does not produce a channel mismatch with the first stride-2 block.
    layers = [ResBlock(1, nfs[0], ks=5, stride=1, act=act, norm=norm)]
    layers += [ResBlock(nfs[i], nfs[i+1], act=act, norm=norm, stride=2) for i in range(len(nfs)-1)]
    # NOTE(review): Linear(nfs[-1], ...) presumes the feature map is 1x1 after the strided blocks — confirm for other input sizes.
    layers += [nn.Flatten(), nn.Linear(nfs[-1], 10, bias=False), nn.BatchNorm1d(10)]
    return nn.Sequential(*layers).to(def_device)

In [None]:
from torchvision import transforms

In [None]:
def tfm_batch(b, tfm_x=fc.noop, tfm_y = fc.noop):
    """Return `(tfm_x(b[0]), tfm_y(b[1]))`; both transforms default to no-ops."""
    new_x = tfm_x(b[0])
    new_y = tfm_y(b[1])
    return new_x, new_y

In [None]:
# Augmentation pipeline: random 28-pixel crop after padding by 1, then a
# random horizontal flip (constructed with 0.65).
tfms = nn.Sequential(transforms.RandomCrop(28, padding=1),
                     transforms.RandomHorizontalFlip(0.65))
# Applies `tfms` to the inputs only (tfm_y stays the default no-op); on_val=False is passed.
augcb = BatchTransformCB(partial(tfm_batch, tfm_x=tfms), on_val=False)

In [None]:
#| export
class CapturePreds(Callback):
    """Accumulate predictions and targets of every batch, concatenating them when the fit ends."""

    def before_fit(self, learn):
        # Start each fit with empty buffers.
        self.all_preds, self.all_targs = [], []

    def after_batch(self, learn):
        # Both are passed through `to_cpu` before being stored.
        self.all_preds.append(to_cpu(learn.preds))
        self.all_targs.append(to_cpu(learn.batch[1]))

    def after_fit(self, learn):
        # Replace the per-batch lists with single concatenated tensors.
        joined_preds, joined_targs = torch.cat(self.all_preds), torch.cat(self.all_targs)
        self.all_preds, self.all_targs = joined_preds, joined_targs

In [None]:
#| export
@fc.patch
def capture_preds(self: Learner, cbs=None):
    """Run one non-training epoch and return `(all_preds, all_targs)` gathered by `CapturePreds`.

    `cbs` are extra callbacks for that run; they are appended after the collector.
    """
    collector = CapturePreds()
    run_cbs = [collector]+fc.L(cbs)
    self.fit(1, train=False, cbs=run_cbs)
    return collector.all_preds,collector.all_targs

# run

In [None]:
# Loaders are built from the cached dataset (`tds.cached`) with the global batch size.
dls = DataLoaders.from_dd(tds.cached, bs, num_workers=0)

# tweaked from rohitgeo's version
# `metrics` is a module-level name that `run` also reads when it assembles its callbacks.
metrics = MetricsCB(accuracy=MulticlassAccuracy())
cbs = [DeviceCB(), metrics] 
#0.0003 from https://github.com/digantamisra9

In [None]:
def upscale_cb(scale, mode='bilinear'):
    """Return a `BatchTransformCB` that resizes the inputs by `scale` during training and validation.

    Only `b[0]` goes through `F.interpolate`; `b[1]` is passed along untouched.
    """
    def _resize(b):
        resized = F.interpolate(b[0], scale_factor=scale, mode=mode)
        return (resized, b[1])

    return BatchTransformCB(_resize, on_val=True, on_train=True)

In [None]:
import timm

In [None]:
rng = rng_seed

In [None]:
def run(model, leaky=0.0003, seed=1, m=1, cbs=tuple(), fit=True, train_cb=TrainCB(), epochs=5, base_lr=2e-2, 
        loss_func=F.cross_entropy, bs=bs, tta=False, dls=None, verbose=True):
    """Seed the RNG, wrap `model` in a `Learner`, optionally fit it and optionally report TTA accuracy.

    Args:
        model: module to train; `init_weights` is applied to it unless `leaky` is None.
        leaky: value forwarded to `init_weights`; None skips re-initialisation.
        seed: value passed to `rng.set_seed`.
        m: multiplier applied to both `base_lr` and `bs`.
        cbs: extra callbacks, inserted before `train_cb`.
        fit: when true, call `learn.fit(epochs, ...)`.
        train_cb: training callback placed last in the callback list.
            NOTE(review): the default `TrainCB()` is created once at definition
            time and is therefore shared by every call that omits it.
        epochs: number of epochs; also used for the OneCycle `total_steps`.
        base_lr: learning rate before scaling by `m`.
        loss_func: loss passed to the `Learner`.
        bs: batch size before scaling by `m`; only used when `dls` is not given.
        tta: when true, average predictions of the plain and horizontally flipped inputs.
        dls: pre-built dataloaders; when falsy they are built from `tds`.
        verbose: print RNG probes, the batch size and the first batch of training labels.

    Returns:
        The `Learner`.
    """
    rng.set_seed(seed)
    # The verbose probes below draw from the torch RNG, so verbose=True and
    # verbose=False leave the generator in different states.
    if verbose: print(torch.randn([3]))
    iw = partial(init_weights, leaky=leaky) if leaky is not None else fc.noop
    lr = base_lr*m
    # This prints bs*m even when `dls` was supplied by the caller.
    if verbose: print("Batch size", bs*m)
    # Built from `tds` here, whereas the module-level `dls` is built from `tds.cached`.
    dls = dls or DataLoaders.from_dd(tds, bs*m, num_workers=0) 
    tmax = epochs * len(dls.train)
    sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
    # `metrics` and `rng_seed` are module-level objects, shared across calls.
    cbs = [DeviceCB(), rng_seed, metrics, BatchSchedCB(sched), *cbs, train_cb] 
    if verbose: print(torch.randn([3]))
    model=model.apply(iw)
    learn = Learner(model, dls, loss_func, lr=lr, cbs=cbs, opt_func=optim.AdamW)
    if verbose: print(torch.randn([3]))
    # Pulls one batch from the training loader just to show its labels.
    if verbose: print(next(iter(learn.dls.train))[1])
    if fit:
        learn.fit(epochs, cbs=[TimeItCB(), ProgressCB(plot=True)])
    if tta:
        ## TTA
        ap1, at = learn.capture_preds()
        ttacb = BatchTransformCB(partial(tfm_batch, tfm_x=TF.hflip), on_val=True)
        ap2, at = learn.capture_preds(cbs=[ttacb])
        # Average the two prediction sets, then take the arg-max class.
        ap = torch.stack([ap1,ap2]).mean(0).argmax(1)
        if verbose: print('TTA:', round((ap==at).float().mean().item(), 4))
    
    return learn

# MixUp v3

### dynamic mixup loss function

In [None]:
xb,yb = fc.first(dls.train)

In [None]:
yab = torch.tensor([1.,0.,0.,1.])
s = torch.tensor([0.1,0.2,0.3,0.4])

#yab.T @ s @ yab


In [None]:
from types import SimpleNamespace
def mixup_data(b2, b1=None, sampler=torch.distributions.Beta(tensor(0.5), tensor(0.5)).sample, classes=10, permute_1=True):
    """Mix batch `b2` with partner batch `b1` and return everything in a `SimpleNamespace`.

    The partner falls back to `b2` itself when `b1` is None or has a different
    batch size. With `permute_1` the partner is shuffled first. The result has
    `lam` (mixing weights from `sampler`), `mixed_x` (`lerp(x2, x1, lam)`),
    `y1`/`y2` (one-hot labels of partner and current batch) and `mixed_y`
    (the same interpolation applied to the one-hot labels).
    """
    cur_x, cur_y = b2
    if b1 is None:
        # first batch: mix with itself
        partner_x, partner_y = cur_x, cur_y
    else:
        partner_x, partner_y = b1

    batch_size = cur_x.shape[0]
    if partner_x.shape[0] != batch_size:
        # last (shorter) batch: mix with itself
        partner_x, partner_y = cur_x, cur_y

    if permute_1:
        shuffle = torch.randperm(batch_size).to(cur_x.device)
        partner_x, partner_y = partner_x[shuffle], partner_y[shuffle]

    out = SimpleNamespace()
    out.lam = sampler([batch_size]).to(cur_x.device)
    # Give lam one trailing singleton axis per non-batch axis of x so it broadcasts.
    lam_for_x = out.lam.reshape(-1, *[1]*(len(cur_x.shape)-1))
    out.mixed_x = torch.lerp(cur_x, partner_x, lam_for_x)

    out.y1 = F.one_hot(partner_y, num_classes=classes)
    out.y2 = F.one_hot(cur_y, num_classes=classes)
    out.mixed_y = out.y2 + out.lam.reshape(-1,1)*(out.y1-out.y2)
    return out


In [None]:
r = mixup_data((xb,yb))
show_images(r.mixed_x[:16], imsize=1.5)
r.lam[:16]

In [None]:
set_seed(1)
b = 20
a = 10.
x1 = tensor([[1,  1, a, 1],
             [a,  1, 1, 1]]).float()
x2 = tensor([[1,  b, 1, 1],
             [1,  1, 1, b]]).float()
y1 = (x1).argmax(-1)
y2 = (x2).argmax(-1) 

print(y1, y2)
r = mixup_data((x1,y1), (x2,y2), classes=4, permute_1=False,)


In [None]:
(0.73 * b+ 0.27*1), 0.2689*a+ 0.73*1

In [None]:
# lets construct the mask manually
maskb = r.mixed_x>14
maska = (r.mixed_x>2) & (~maskb)
maska

In [None]:
pred=r.mixed_x[~maskb].reshape(2, -1) 
lbl=maska[~maskb].reshape(2, -1).float()
print(pred,lbl)
F.cross_entropy(pred, lbl, reduction='mean').item()

In [None]:
def ce_masked(preds, y, ignored_y, **kwargs):
    """
    Cross-entropy computed after removing one class per row.

    preds     - (N, C) predictions
    y         - one-hot encoded label
    ignored_y - one-hot encoding of the class to drop from each row

    Extra keyword arguments are forwarded to `F.cross_entropy`.
    """
    n_rows, n_classes = preds.shape
    keep = ignored_y.eq(0)
    reduced_shape = (n_rows, n_classes-1)
    # The reshape fails early unless exactly one class per row is ignored.
    kept_preds = preds[keep].reshape(*reduced_shape)
    kept_targets = y[keep].float().reshape(*reduced_shape)
    return F.cross_entropy(kept_preds, kept_targets, **kwargs)

ce_masked(r.mixed_x, r.y2, r.y1).item()

In [None]:
ce_masked(r.mixed_x, r.y1, r.y1).item() + ce_masked(r.mixed_y, r.y2, r.y2).item() 

In [None]:
#https://github.com/Westlake-AI/openmixup/blob/c042813ee0af577d365f0e13b13a4c8486d6e8f7/openmixup/models/losses/cross_entropy_loss.py#L83

def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply element-wise weight and reduce loss.
    Args:
        loss (Tensor): Element-wise loss tensor.
        weight (Tensor): Element-wise weights.
        reduction (str): Same as built-in losses of PyTorch. Options are "none",
            "mean" and "sum".
        avg_factor (float): Average factor when computing the mean of losses.
    Returns:
        Tensor: Processed loss values.
    Raises:
        ValueError: if `avg_factor` is given together with reduction="sum".
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        loss = loss * weight

    # none: 0, elementwise_mean: 1, sum: 2
    reduction_enum = F._Reduction.get_enum(reduction)
    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        if reduction_enum == 1:
            loss = loss.mean()
        elif reduction_enum == 2:
            loss = loss.sum()
    else:
        # if reduction is 'mean', then average the loss by avg_factor
        if reduction_enum == 1:
            loss = loss.sum() / avg_factor
        # if reduction is 'none', then do nothing; otherwise raise an error.
        # Compare the enum: the previous `reduction != 0` tested the string
        # against an int, so 'none' was wrongly rejected too.
        elif reduction_enum != 0:
            raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss

def soft_mix_cross_entropy(pred,
                           label,
                           weight=None,
                           reduction='mean',
                           class_weight=None,
                           avg_factor=None,
                           eta_weight=None,
                           eps_smooth=1e-3,
                           verbose=False,
                           **kwargs):
    r"""Calculate the Soft Decoupled Mixup CrossEntropy loss using softmax
        The label can be float mixup label (class-wise sum to 1, k-mixup, k>=2).
       *** Warning: this mixup and label-smoothing cannot be set simultaneously ***
    Decoupled Mixup for Data-efficient Learning. In arXiv, 2022.
    <https://arxiv.org/abs/2203.10761>
    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        label (torch.Tensor): The gt label of the prediction with shape (N, C).
            When using "mixup", the label can be float (mixup one-hot label).
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): The weight for each class with
            shape (C), C is the number of classes. Default None.
        eta_weight (list): Reweight the global loss in mixup cls loss as,
            loss = loss_local + eta_weight[i] * loss_global[i]. Must be a list
            (asserted below), so the default None is rejected.
        eps_smooth (float): If using label smoothing, we assume eps < lam < 1-eps.
        verbose (bool): Print intermediate tensors.
        **kwargs: Accepted and not used.
    Returns:
        torch.Tensor: The calculated loss
    """
    # *** Assume k-mixup in C classes, k >= 2 and k << C ***
    # step 1: remove labels have less than k-hot (mixed between the
    #    same class will result in the original onehot)
    _eps = max(1e-3, eps_smooth)  # assuming _eps < lam < 1-_eps
    # number of classes above _eps in each row
    mask_one = (label > _eps).sum(dim=-1)
    mix_num = max(mask_one)
    # True for rows that have the maximal number of active classes
    mask_one = mask_one >= mix_num
    if mask_one.sum() < label.size(0):
        # rows with fewer active classes are split off and handled in step 3
        pred_one = pred[mask_one==False, :]
        label_one = label[mask_one==False, :]
        pred = pred[mask_one, :]
        label = label[mask_one, :]
        weight_one = None
        if weight is not None:
            weight_one = weight[mask_one==False, ...].float()
            weight = weight[mask_one, ...].float()
        if verbose: print(f"pred_one: {mask_one=} {pred_one=}")
    else:
        if weight is not None:
            weight = weight.float()
        pred_one, label_one, weight_one = None, None, None
        if verbose: print(f"no pred_one {mask_one=}")
    # step 2: select k-mixup for the local and global
    bs, cls_num = label.size()  # N, C
    assert isinstance(eta_weight, list)
    # local: between k classes
    mask_lam_k = label > _eps  # [N, C], top k is true
    # The mixing weights are read from the first row only.
    # NOTE(review): this looks like it presumes one lam shared by the whole batch — confirm.
    lam_k = label[0, label[0, :] > _eps]  # [k,] k-mix relevant classes

    # local: original mixup CE loss between C classes
    loss = -label * F.log_softmax(pred, dim=-1)  # [N, C]
    if class_weight is not None:
        loss *= class_weight
    loss = loss.sum(dim=-1)  # reduce class

    # global: between lam_i and C-k classes
    # skipped when two weights coincide or when only one class is active
    if len(set(lam_k.cpu().numpy())) == lam_k.size(0) and lam_k.size(0) > 1:
        if verbose: print("calculating global loss for", lam_k, 'loss so far', loss)
        # *** trivial solution: lam=0.5, lam=1.0 ***
        assert len(eta_weight) == lam_k.size(0), \
            "eta weight={}, lam_k={}".format(eta_weight, lam_k)
        for i in range(lam_k.size(0)):
            # selected (C-k+1), except lam_k[j], where j!=i (k-1)
            mask_lam_i = (label == lam_k[i]) | ~mask_lam_k  # [N, C]
            pred_lam_i  = pred.reshape([1, bs, -1])[:, mask_lam_i].reshape(
                [-1, cls_num+1-lam_k.size(0)])  # [N, C-k+1]
            label_lam_i = label.reshape([1, bs, -1])[:, mask_lam_i].reshape(
                [-1, cls_num+1-lam_k.size(0)])  # [N, C-k+1]
            # convert to onehot
            label_lam_i = (label_lam_i > 0).type(torch.float)
            # element-wise losses
            loss_global = -label_lam_i * F.log_softmax(pred_lam_i, dim=-1)  # [N, C-k+1]
            # NOTE(review): class_weight is documented as shape (C) but loss_global has C-k+1 columns — confirm this multiplies as intended.
            if class_weight is not None:
                loss_global *= class_weight
            # eta reweight
            if verbose: print(f"global loss: {loss_global.sum(dim=-1)} for {lam_k[i]}")
            loss += eta_weight[i] * loss_global.sum(dim=-1)  # reduce class
    # apply weight and do the reduction
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    # step 3: original soft CE loss
    if label_one is not None:
        loss_one = -label_one * F.log_softmax(pred_one, dim=-1)
        if class_weight is not None:
            loss_one *= class_weight
        loss_one = loss_one.sum(dim=-1)  # reduce class
        if verbose: print(f"loss_one: {loss_one=} {loss=}")

        loss_one = weight_reduce_loss(
            loss_one, weight=weight_one, reduction=reduction, avg_factor=avg_factor)
        # The two reduced losses are added as they are.
        # NOTE(review): `loss_one or 0.0` takes the truth value of a tensor, which only works for a single-element tensor — check reduction='none'.
        loss += loss_one or 0.0

    return loss


In [None]:
def dmce(preds, r, eta=0.1, verbose=False, **kw): 
    """Mixup cross-entropy plus two `eta`-weighted masked cross-entropy terms.

    The first masked term scores `r.y1` with the class of `r.y2` removed, the
    second scores `r.y2` with the class of `r.y1` removed. `**kw` is passed to
    every cross-entropy call.
    """
    mixed_term = F.cross_entropy(preds, r.mixed_y, **kw) 
    y1_term = ce_masked(preds, r.y1, r.y2, **kw)
    y2_term = ce_masked(preds, r.y2, r.y1, **kw)
    if verbose: print(f"{mixed_term} + {eta}*{y1_term} + {eta}*{y2_term}")
    total = mixed_term + eta*y1_term + eta*y2_term
    return total

def dmce_c(preds, r, eta=[0.1,0.1], verbose=False, **kw):
    """Adapter giving `soft_mix_cross_entropy` the `(preds, r, eta=...)` signature of the other mixup losses.

    Args:
        preds: model predictions.
        r: mixup namespace; only `r.mixed_y` is used.
        eta: per-mixed-class weights of the global loss terms.
        verbose: forwarded to `soft_mix_cross_entropy`.
        **kw: forwarded as well (e.g. `reduction`), matching how `dmce` passes
            its keyword arguments on; previously these were silently dropped.
    """
    # `soft_mix_cross_entropy` asserts that eta_weight is a list; converting
    # here lets callers pass a tuple and avoids handing out the shared default.
    return soft_mix_cross_entropy(preds, r.mixed_y, eta_weight=list(eta), verbose=verbose, **kw)

In [None]:
rng_seed.set_new(1)
b = 20
a = 11.
x1 = tensor([[1,  1, a, 1],
             [a,  1, 1, 1]]).float()
x2 = tensor([[1,  b, 1, 1],
             [1,  1, 1, b]]).float()
y1 = (x1).argmax(-1)
y2 = (x2).argmax(-1) 

print(y1, y2)
r = mixup_data((x1,y1), (x2,y2), classes=4, permute_1=False, sampler=lambda x: torch.tensor([0.3, 1.0]))

dmce_c(r.mixed_x, r, verbose=True), dmce(r.mixed_x, r, verbose=True)

In [None]:
r.mixed_y

In [None]:
def mce(preds, r, eta=None,**kw):
    """Plain mixup loss: cross-entropy of `preds` against the soft labels `r.mixed_y`.

    `eta` is accepted for signature compatibility with the other mixup losses and is not used.
    """
    soft_targets = r.mixed_y
    return F.cross_entropy(preds, soft_targets, **kw)

## MixUp4CB

In [None]:
class MixUp4CB(TrainCB):
    """Training callback that mixes each batch via `mixup_data` and applies a mixup-aware loss.

    Args:
        alpha: both parameters of the Beta distribution that lam is drawn
            from; `None` turns mixup off and the callback then behaves like `TrainCB`.
        use_prev: remember the current (unmixed) batch and use it as the
            partner for the next one.
        eta: forwarded to `loss_func` as `eta`.
        per_batch: draw a single lam for the whole batch instead of one per sample.
        loss_func: callable `(preds, mixup_namespace, eta=...)` used for training batches.
        **kw: forwarded to `TrainCB`.
    """
    def __init__(self,alpha=0.4, use_prev=False, eta=0.1, per_batch=False, loss_func=dmce, **kw): 
        super().__init__(**kw)
        self.alpha = alpha
        self.prev = None
        self.use_prev = use_prev
        self.eta = eta
        # The hooks below treat alpha=None as "mixup off", so the constructor
        # must accept it too; only build the Beta when there is an alpha.
        self.dist = None if alpha is None else torch.distributions.Beta(self.alpha,self.alpha)
        self.per_batch = per_batch
        self.loss_func=loss_func
        
    def before_epoch(self, learn):
        # Never mix with a batch remembered from a previous epoch.
        self.prev = None
            
    def sample(self, shape): 
        """Draw lam: one value when `per_batch`, otherwise a tensor of `shape`."""
        if self.per_batch: return self.dist.sample([1])
        return self.dist.sample(shape)
    
    def before_batch(self, learn):
        if learn.training and self.alpha is not None: 
            r = mixup_data(learn.batch, self.prev, sampler=self.sample)
            # Store the unmixed batch before `learn.batch` is replaced.
            if self.use_prev: self.prev = learn.batch
            learn.mixup = r
            # The hard label is the arg-max of the mixed label.
            learn.batch = r.mixed_x, r.mixed_y.argmax(-1)
            
    def get_loss(self, learn):
        if learn.training and self.alpha is not None:
            learn.loss = self.loss_func(learn.preds, learn.mixup, eta=self.eta)
        else:
            super().get_loss(learn)

## MixUpCB

In [None]:
def mixup_data_old(x, y, sampler=torch.distributions.Beta(tensor(1.), tensor(1.)).sample):
    """Classic mixup with one lam per batch.

    Args:
        x: batch of inputs, indexed along dim 0.
        y: batch of targets aligned with `x`.
        sampler: callable taking a shape list and returning lam. The default
            Beta(1, 1) is built from float tensors so that sampling from it
            works; the previous integer tensors could not be sampled.

    Returns:
        `(mixed_x, (y_a, y_b, lam))`, where `mixed_x = lam*x + (1-lam)*x[index]`,
        `y_a` is `y` and `y_b` is `y[index]` for a random permutation `index`.
    """
    # lam is drawn before the permutation; keep this order so the RNG stream is unchanged.
    lam = sampler([1]).to(x.device)
    batch_size = x.shape[0]
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, (y_a, y_b, lam)

def mixup_criterion_old(lf, pred, my):
    """Blend the loss `lf` of `pred` against both mixup targets.

    `my` is the `(y_a, y_b, lam)` triple; the result is
    `lam * lf(pred, y_a) + (1 - lam) * lf(pred, y_b)`.
    """
    targets_a, targets_b, mix = my
    loss_a = lf(pred, targets_a)
    loss_b = lf(pred, targets_b)
    return mix * loss_a + (1 - mix) * loss_b

class MixUpCB(TrainCB):
    """Training callback for classic mixup (one lam per batch by default).

    Args:
        alpha: both parameters of the Beta distribution that lam is drawn
            from; `None` turns mixup off and the callback then behaves like `TrainCB`.
        mix_data: callable `(x, y, sampler) -> (mixed_x, mixup_info)`.
        mix_loss: callable `(loss_func, preds, mixup_info) -> loss`.
        **kw: forwarded to `TrainCB`.
    """
    def __init__(self,alpha=0.4, mix_data=mixup_data_old, mix_loss=mixup_criterion_old,**kw): 
        super().__init__(**kw)
        self.alpha = alpha
        # The hooks below treat alpha=None as "mixup off", so the constructor
        # must accept it too; only build the Beta when there is an alpha.
        self.dist = None if alpha is None else torch.distributions.Beta(self.alpha,self.alpha)
        self.mix_data = mix_data
        self.mix_loss = mix_loss
        
    def before_fit(self, learn):
        # Remembered here but not read anywhere else in this class.
        self.base_lf = learn.loss_func
        
    def sample(self, n): return self.dist.sample(n)
    
    def before_batch(self, learn):
        if learn.training and self.alpha is not None: 
            bx, mixup = self.mix_data(*learn.batch, self.sample)
            # Only the inputs are replaced; the original targets stay in the batch.
            learn.batch = bx, learn.batch[1]
            learn.mixup = mixup

    def get_loss(self, learn):
        if learn.training and self.alpha is not None:
            learn.loss = self.mix_loss(learn.loss_func, learn.preds, learn.mixup) # todo  *learn.batch[self.n_inp:]   
        else:
            super().get_loss(learn)

### dbl check 


In [None]:
# Manual one-epoch MixUpCB run, assembled by hand to compare against `run(...)`.
# The activation must be bound before the model is built; it used to be
# assigned after its first use, which fails on a fresh kernel.
act_gr = nn.SiLU
rng.previous()
model = get_model9(act_gr, norm=nn.BatchNorm2d)
print(torch.randn([3]))
metrics = MetricsCB(accuracy=MulticlassAccuracy())
cbs = [DeviceCB(),rng, metrics, ProgressCB(plot=True)] 
iw = partial(init_weights, leaky=0.0003)
epochs = 1
lr = 1e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched), MixUpCB(0.4)] 
print(torch.randn([3]))
model = model.apply(iw)
learn2 = Learner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
print(torch.randn([3]))
print(next(iter(learn2.dls.train))[1])
learn2.fit(epochs, cbs=TimeItCB())
# Use the device the model was moved to rather than a hard-coded 'cuda'.
print(learn2.model(xb.to(def_device)).mean().item())

In [None]:
rng.previous()
learn = run(
    get_model9(nn.SiLU, norm=nn.BatchNorm2d), leaky=0.0003, 
    base_lr=1e-2, epochs=5, 
    cbs=[], 
    train_cb=MixUp4CB(0.4, use_prev=False, eta=None, per_batch=True, loss_func=mce))

In [None]:
self.sample??

In [None]:
# Reproducible runs: benchmarking lets cuDNN choose algorithms by timing them,
# so the choice can differ between runs. Turn it off and require
# deterministic algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

### We have weight decay !!!
```
AdamW (
Parameter Group 0
    amsgrad: False
    base_momentum: 0.85
    betas: (0.9499489264325014, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    initial_lr: 0.0004
    lr: 0.0004049030624798597
    max_lr: 0.01
    max_momentum: 0.95
    maximize: False
    min_lr: 4e-08
    weight_decay: 0.01
)
```

In [None]:
class CmpOneBatchCB(Callback):
    """Print the loss after every batch and cancel the fit once `batches` batches have run."""

    def __init__(self, batches=1):
        self.batches = batches

    def after_batch(self, learn):
        # Keep an already present `learn.batches`, otherwise record this callback's limit.
        learn.batches = getattr(learn, 'batches', self.batches)
        #print("y:",learn.batch[1])
        print("loss:",learn.loss.item(), "iter", learn.iter)
        batches_done = learn.iter+1
        if batches_done != self.batches:
            return
        raise CancelFitException()
    

In [None]:
learn1.iter

In [None]:
# Reference run: one batch with the classic MixUpCB, then record the RNG state.
set_seed(1)
learn1 = run(
    get_model9(nn.SiLU, norm=nn.BatchNorm2d), leaky=0.0003, 
    base_lr=1e-2, epochs=1, 
    cbs=[CmpOneBatchCB()], 
    train_cb=MixUpCB(0.4), fit=True)
#print('lam', learn1.mixup[-1])
# Use the device the model was moved to rather than a hard-coded 'cuda'.
print(learn1.model(xb.to(def_device)).mean().item())
A=torch.get_rng_state().clone()
print(torch.randn([10]))

In [None]:
# Same one-batch run with MixUp4CB configured as plain mixup (single lam, `mce` loss).
set_seed(1)
learn1 = run(
    get_model9(nn.SiLU, norm=nn.BatchNorm2d), leaky=0.0003, 
    base_lr=1e-2, epochs=1, 
    cbs=[CmpOneBatchCB()], 
    train_cb=MixUp4CB(0.4, use_prev=False, eta=None, per_batch=True, loss_func=mce), fit=True)
print('lam', learn1.mixup.lam)
# Use the device the model was moved to rather than a hard-coded 'cuda'.
print(learn1.model(xb.to(def_device)).mean().item())
B=torch.get_rng_state().clone()
print(torch.randn([10]))

In [None]:
(A == B).sum(), A.shape

In [None]:
A.max()

In [None]:
set_seed(1)
learn1 = run(
    get_model9(nn.SiLU, norm=nn.BatchNorm2d), leaky=0.0003, 
    base_lr=1e-2, epochs=1, 
    cbs=[], 
    train_cb=MixUp4CB(0.4, use_prev=False, eta=None, per_batch=True, loss_func=mce), dls=dls, fit=True)

In [None]:
# learn3 = run(
#     get_model9(nn.SiLU, norm=nn.BatchNorm2d), leaky=0.0003, 
#     base_lr=1e-2, epochs=1, 
#     cbs=[], 
#     train_cb=MixUpCB(0.4), dls=dls, fit=False)
# epochs = 1
# lr = 1e-2
# tmax = epochs * len(dls.train)
# sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
# xtra = [BatchSchedCB(sched), MixUpCB(0.4)] 

# learn3.cbs = cbs+xtra
# print(learn1.cbs)
# learn3.fit(1)
# print(learn3.model(xb).mean().item())

In [None]:
learn1.batch[1]

In [None]:
learn2.batch[1]

In [None]:
# learn2.   (incomplete attribute lookup left in the cell; commented out so the notebook parses)

In [None]:
# Manual 5-epoch MixUp4CB run on the module-level `dls`, without going through `run`.
set_seed(1)
# Rebinds the module-level `metrics` and `cbs`; `run` reads `metrics` when it is called later.
metrics = MetricsCB(accuracy=MulticlassAccuracy())
cbs = [DeviceCB(), metrics, ProgressCB()] 
act_gr = nn.SiLU
iw = partial(init_weights, leaky=0.0003)
epochs = 5
lr = 1e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
# eta=None with loss_func=mce: `mce` does not use eta.
xtra = [BatchSchedCB(sched), MixUp4CB(0.4, use_prev=False, eta=None, per_batch=True, loss_func=mce)] 
model = get_model9(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = Learner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs, cbs=TimeItCB())

### [93.8] MCE, single_lam

In [None]:
learn = run(
    get_model9(nn.SiLU, norm=nn.BatchNorm2d), leaky=0.0003, 
    base_lr=2e-2, epochs=5, 
    cbs=[], 
    train_cb=MixUp4CB(0.4, use_prev=True, eta=None, per_batch=True, loss_func=mce))

### [92.1] DMCE copy, single_lam eta 0.1, 0.1

In [None]:
learn = run(
    get_model9(nn.SiLU, norm=nn.BatchNorm2d), leaky=0.0003, 
    base_lr=2e-2, epochs=5, 
    cbs=[], 
    train_cb=MixUp4CB(0.4, use_prev=True, eta=[0.1,0.1], per_batch=True, loss_func=dmce_c))

### [93.8] DMCE (our), single_lam

In [None]:
learn = run(
    get_model9(nn.SiLU, norm=nn.BatchNorm2d), leaky=0.0003, 
    base_lr=2e-2, epochs=5, 
    cbs=[], 
    train_cb=MixUp4CB(0.4, use_prev=True, eta=0.1, per_batch=True, loss_func=dmce))

### [92.6] DMCE copy, single_lam eta[0.1,0.9]

In [None]:
learn = run(
    get_model9(nn.SiLU, norm=nn.BatchNorm2d), leaky=0.0003, 
    base_lr=2e-2, epochs=5, 
    cbs=[], 
    train_cb=MixUp4CB(0.4, use_prev=True, eta=[0.1,0.9], per_batch=True, loss_func=dmce_c))

### [93.9] DMCE (our), mlam

In [None]:
learn = run(
    get_model9(nn.SiLU, norm=nn.BatchNorm2d), leaky=0.0003, 
    base_lr=2e-2, epochs=5, 
    cbs=[], 
    train_cb=MixUp4CB(0.4, use_prev=True, eta=0.1, per_batch=False, loss_func=dmce))
