In [1]:
cd ..

D:\Kevin\Machine Learning\Cassava Leaf Disease Classification


In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch import optim
from torch.optim.optimizer import Optimizer
from adabelief_pytorch import AdaBelief
from ranger_adabelief import RangerAdaBelief
from warmup_scheduler import GradualWarmupScheduler

import timm

from src.dataset import get_loaders
from src.optim import get_optimizer_and_scheduler
from src.engine import get_device, get_net, train_one_epoch, valid_one_epoch
from src import config
from src.utils import *
from src.loss import FocalCosineLoss, SmoothCrossEntropyLoss, bi_tempered_logistic_loss

%matplotlib inline

In [3]:
from IPython.display import FileLinks
FileLinks(config.WEIGHTS_PATH)

In [4]:
class GeneralizedCassavaClassifier(nn.Module):
    def __init__(self, model_arch, n_class=5, pretrained=False):
        super().__init__()
        self.name = model_arch
        self.model = timm.create_model(model_arch, pretrained=pretrained)
        model_list = list(self.model.children())
        model_list[-1] = nn.Linear(
            in_features=model_list[-1].in_features,
            out_features=n_class,
            bias=True
        )
        self.model = nn.Sequential(*model_list)

    def forward(self, x):
        x = self.model(x)
        return x
    
class RAdam(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        self.buffer = [[None, None, None] for ind in range(10)]
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError(
                        'RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(
                        p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * \
                        state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (
                            N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        step_size = group['lr'] / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(-group['weight_decay']
                                     * group['lr'], p_data_fp32)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
                else:
                    p_data_fp32.add_(-step_size, exp_avg)

                p.data.copy_(p_data_fp32)

        return loss

class GradualWarmupSchedulerV2(GradualWarmupScheduler):
        def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
            super(GradualWarmupSchedulerV2, self).__init__(optimizer, multiplier, total_epoch, after_scheduler)
        def get_lr(self):
            if self.last_epoch > self.total_epoch:
                if self.after_scheduler:
                    if not self.finished:
                        self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                        self.finished = True
                    return self.after_scheduler.get_lr()
                return [base_lr * self.multiplier for base_lr in self.base_lrs]
            if self.multiplier == 1.0:
                return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
            else:
                return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]


net = GeneralizedCassavaClassifier(model_arch="resnext50_32x4d")

In [5]:
pull = False # [True, False]
fold = 0

if not pull:
    model = net
    train_loader, valid_loader          = get_loaders(fold)
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE * config.WARMUP_FACTOR)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=5, eta_min=0)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=14,
                T_mult=1,
                eta_min=0,
                last_epoch=-1)
    scheduler = optimizer = RAdam(
                    net.parameters(),
                    lr=config.LEARNING_RATE * m
                )
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
#                 optimizer, T_max=5, eta_min=0)
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
#                 optimizer, T_max=5, eta_min=0)
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
#                 optimizer, T_max=5, eta_min=0)
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
#                 optimizer, T_max=5, eta_min=0)
    scheduler = GradualWarmupSchedulerV2(optimizer, multiplier=config.WARMUP_FACTOR, total_epoch=config.WARMUP_EPOCHS, after_scheduler=scheduler)

if pull:
    train_loader, valid_loader          = get_loaders(fold)
    optimizer, scheduler                = get_optimizer_and_scheduler(net=net, dataloader=train_loader)

lrs = []
for epoch in range(config.MAX_EPOCHS):
    for step in range(len(train_loader)):
        scheduler.step(epoch + (step / len(train_loader)))
        lrs.append(optimizer.param_groups[0]["lr"])
plt.rcParams['figure.figsize'] = 20,3
plt.plot(lrs)

NameError: name 'm' is not defined

In [6]:
for lr in lrs:
    print(lr)

0.0007
0.0007039289055191768
0.0007078578110383536
0.0007117867165575304
0.0007157156220767071
0.0007196445275958839
0.0007235734331150609
0.0007275023386342376
0.0007314312441534144
0.0007353601496725912
0.000739289055191768
0.0007432179607109448
0.0007471468662301216
0.0007510757717492984
0.0007550046772684751
0.000758933582787652
0.0007628624883068289
0.0007667913938260055
0.0007707202993451825
0.0007746492048643592
0.000778578110383536
0.0007825070159027128
0.0007864359214218896
0.0007903648269410664
0.0007942937324602431
0.00079822263797942
0.0008021515434985968
0.0008060804490177735
0.0008100093545369505
0.0008139382600561271
0.000817867165575304
0.0008217960710944809
0.0008257249766136577
0.0008296538821328344
0.0008335827876520112
0.000837511693171188
0.0008414405986903648
0.0008453695042095415
0.0008492984097287184
0.0008532273152478953
0.000857156220767072
0.0008610851262862489
0.0008650140318054257
0.0008689429373246024
0.0008728718428437793
0.000876800748362956
0.0008807296

0.004858222924301669
0.0048581283015746825
0.004858033572734869
0.004857938737786405
0.004857843796733465
0.004857748749580237
0.004857653596330906
0.004857558336989665
0.004857462971560714
0.004857367500048252
0.004857271922456489
0.004857176238789635
0.0048570804490519065
0.004856984553247523
0.004856888551380712
0.004856792443455704
0.0048566962294767325
0.004856599909448038
0.004856503483373864
0.004856406951258461
0.004856310313106081
0.004856213568920982
0.004856116718707429
0.0048560197624696886
0.004855922700212032
0.004855825531938738
0.004855728257654088
0.004855630877362367
0.004855533391067868
0.004855435798774885
0.004855338100487718
0.004855240296210673
0.0048551423859480615
0.004855044369704194
0.004854946247483393
0.0048548480192899795
0.004854749685128283
0.004854651245002638
0.0048545526989173795
0.0048544540468768525
0.004854355288885401
0.00485425642494738
0.004854157455067143
0.004854058379249054
0.004853959197497476
0.004853859909816781
0.004853760516211343
0.0048

0.004699840653892613
0.004699637006265159
0.004699433259508707
0.004699229413632238
0.004699025468644732
0.004698821424555177
0.0046986172813725646
0.004698413039105888
0.004698208697764149
0.004698004257356351
0.004697799717891502
0.004697595079378617
0.004697390341826711
0.004697185505244807
0.004696980569641931
0.004696775535027112
0.0046965704014093855
0.004696365168797792
0.0046961598372013726
0.004695954406629177
0.004695748877090256
0.004695543248593667
0.0046953375211484715
0.0046951316947637335
0.004694925769448522
0.004694719745211914
0.0046945136220629844
0.004694307400010819
0.0046941010790645015
0.004693894659233126
0.004693688140525787
0.004693481522951585
0.004693274806519625
0.004693067991239015
0.004692861077118868
0.004692654064168303
0.0046924469523964415
0.004692239741812408
0.004692032432425335
0.004691825024244356
0.004691617517278612
0.0046914099115372455
0.004691202207029405
0.004690994403764243
0.004690786501750916
0.004690578500998585
0.004690370401516415
0.00

0.004315447343500218
0.004315113899895111
0.004314780374104803
0.004314446766143992
0.004314113076027377
0.0043137793037696616
0.004313445449385554
0.004313111512889763
0.004312777494297007
0.004312443393622003
0.004312109210879472
0.004311774946084139
0.0043114405992507356
0.0043111061703939935
0.0043107716595286475
0.004310437066669439
0.004310102391831113
0.004309767635028416
0.004309432796276097
0.004309097875588913
0.00430876287298162
0.00430842778846898
0.00430809262206576
0.004307757373786728
0.004307422043646656
0.00430708663166032
0.004306751137842501
0.004306415562207981
0.0043060799047715485
0.004305744165547992
0.0043054083445521075
0.0043050724417986914
0.0043047364573025465
0.004304400391078477
0.004304064243141291
0.004303728013505801
0.004303391702186824
0.004303055309199177
0.004302718834557684
0.004302382278277173
0.004302045640372471
0.004301708920858415
0.004301372119749842
0.004301035237061591
0.004300698272808507
0.00430036122700544
0.004300024099667239
0.00429968

0.0038712029020226127
0.0038707839498746737
0.003870364935120684
0.0038699458577791074
0.0038695267178684117
0.0038691075154070644
0.0038686882504135396
0.003868268922906309
0.003867849532903852
0.003867430080424648
0.0038670105654871807
0.0038665909881099355
0.003866171348311401
0.003865751646110067
0.0038653318815244295
0.003864912054572984
0.00386449216527423
0.00386407221364667
0.003863652199708809
0.0038632321234791543
0.003862811984976216
0.0038623917842185084
0.0038619715212245467
0.0038615511960128504
0.0038611308086019384
0.0038607103590103382
0.003860289847256574
0.003859869273359177
0.003859448637336679
0.003859027939207615
0.0038586071789905225
0.0038581863567039433
0.0038577654723664194
0.003857344525996497
0.003856923517612726
0.0038565024472336567
0.0038560813148778438
0.003855660120563844
0.0038552388643102176
0.003854817546135527
0.003854396166058336
0.003853974724097213
0.00385355322027073
0.0038531316545974592
0.0038527100270959766
0.003852288337784862
0.003851866586

0.0032822011222962875
0.003281717389592265
0.003281233620239081
0.0032807498142580532
0.0032802659716704994
0.0032797820924977397
0.0032792981767610957
0.003278814224481891
0.0032783302356814525
0.0032778462103811045
0.0032773621486021772
0.0032768780503659987
0.003276393915693901
0.003275909744607218
0.0032754255371272843
0.0032749412932754365
0.003274457013073011
0.003273972696541348
0.00327348834370179
0.0032730039545756785
0.0032725195291843584
0.0032720350675491745
0.0032715505696914752
0.0032710660356326097
0.0032705814653939294
0.0032700968589967853
0.0032696122164625314
0.003269127537812523
0.0032686428230681184
0.003268158072250676
0.003267673285381556
0.00326718846248212
0.003266703603573731
0.0032662187086777545
0.003265733777815557
0.0032652488110085074
0.0032647638082779745
0.0032642787696453303
0.003263793695131947
0.0032633085847591997
0.0032628234385484647
0.0032623382565211204
0.0032618530386985443
0.003261367785102118
0.0032608824957532246
0.0032603971706732468
0.0032

0.0025646453534207306
0.002564131621745519
0.0025636178850411626
0.002563104143330297
0.002562590396635561
0.002562076644979592
0.002561562888385029
0.002561049126874508
0.0025605353604706703
0.002560021589196154
0.002559507813073598
0.002558994032125641
0.0025584802463749236
0.002557966455844084
0.0025574526605557633
0.0025569388605326017
0.002556425055797238
0.0025559112463723146
0.0025553974322804697
0.0025548836135443477
0.0025543697901865875
0.002553855962229831
0.002553342129696719
0.0025528282926098943
0.002552314450991998
0.0025518006048656735
0.0025512867542535618
0.0025507728991783062
0.0025502590396625485
0.0025497451757289322
0.002549231307400101
0.002548717434698698
0.002548203557647366
0.00254768967626875
0.0025471757905854917
0.0025466619006202377
0.0025461480063956304
0.0025456341079343153
0.002545120205258936
0.0025446062983921387
0.002544092387356568
0.002543578472174868
0.002543064552869686
0.0025425506294636667
0.0025420367019794556
0.0025415227704396973
0.002541008

0.0018413474113402187
0.0018408492553220665
0.0018403511261458009
0.0018398530238333728
0.0018393549484067293
0.0018388568998878182
0.0018383588782985856
0.001837860883660978
0.0018373629159969373
0.0018368649753284076
0.0018363670616773295
0.0018358691750656442
0.0018353713155152898
0.0018348734830482046
0.001834375677686325
0.0018338778994515881
0.0018333801483659262
0.0018328824244512715
0.0018323847277295592
0.001831887058222719
0.0018313894159526793
0.001830891800941369
0.0018303942132107149
0.001829896652782644
0.0018293991196790798
0.001828901613921946
0.0018284041355331666
0.0018279066845346599
0.0018274092609483464
0.0018269118647961464
0.001826414496099978
0.0018259171548817546
0.0018254198411633928
0.0018249225549668065
0.0018244252963139087
0.00182392806522661
0.0018234308617268206
0.0018229336858364494
0.0018224365375774068
0.0018219394169715945
0.0018214423240409216
0.0018209452588072912
0.001820448221292607
0.0018199512115187689
0.0018194542295076781
0.001818957275281234

0.0011637668485208136
0.0011633291591637617
0.00116289152650314
0.001162453950558234
0.001162016431348324
0.00116157896889269
0.0011611415632106074
0.0011607042143213528
0.001160266922244194
0.001159829686998402
0.0011593925086032443
0.0011589553870779835
0.0011585183224418817
0.0011580813147141968
0.001157644363914186
0.0011572074700611046
0.0011567706331742028
0.001156333853272729
0.001155897130375931
0.0011554604645030526
0.0011550238556733334
0.001154587303906015
0.0011541508092203314
0.0011537143716355196
0.0011532779911708072
0.0011528416678454245
0.0011524054016785995
0.0011519691926895543
0.0011515330408975101
0.001151096946321687
0.0011506609089813006
0.0011502249288955634
0.0011497890060836887
0.0011493531405648834
0.0011489173323583563
0.0011484815814833073
0.0011480458879589395
0.0011476102518044524
0.001147174673039041
0.0011467391516818983
0.0011463036877522167
0.0011458682812691828
0.001145432932251985
0.0011449976407198054
0.0011445624066918237
0.0011441272301872193
0.0

0.0006743262528587159
0.0006739719462992734
0.0006736177179995138
0.000673263567975046
0.0006729094962414744
0.000672555502814402
0.0006722015877094287
0.0006718477509421468
0.0006714939925281497
0.0006711403124830247
0.0006707867108223574
0.0006704331875617285
0.0006700797427167145
0.0006697263763028926
0.000669373088335832
0.0006690198788311
0.0006686667478042614
0.0006683136952708764
0.0006679607212465018
0.0006676078257466904
0.0006672550087869942
0.0006669022703829589
0.0006665496105501271
0.0006661970293040398
0.0006658445266602321
0.0006654921026342382
0.0006651397572415866
0.0006647874904978024
0.00066443530241841
0.0006640831930189271
0.0006637311623148695
0.0006633792103217486
0.000663027337055074
0.0006626755425303506
0.0006623238267630781
0.0006619721897687577
0.0006616206315628822
0.0006612691521609439
0.0006609177515784299
0.0006605664298308238
0.0006602151869336082
0.0006598640229022592
0.0006595129377522506
0.0006591619314990523
0.0006588110041581319
0.00065846015574495

0.0003036534555344743
0.0003034055097075046
0.0003031576584690208
0.0003029099018299439
0.00030266223980119115
0.00030241467239367604
0.0003021671996183072
0.00030191982148598914
0.0003016725380076231
0.0003014253491941045
0.00030117825505632733
0.00030093125560517727
0.00030068435085153967
0.00030043754080629446
0.00030019082548031657
0.00029994420488447783
0.000299697679029644
0.00029945124792668033
0.00029920491158644464
0.000298958670019791
0.00029871252323757086
0.00029846647125063026
0.00029822051406981116
0.0002979746517059513
0.00029772888416988453
0.0002974832114724419
0.0002972376336244467
0.00029699215063672146
0.00029674676252008247
0.0002965014692853435
0.00029625627094331266
0.00029601116750479426
0.00029576615898058964
0.00029552124538149434
0.00029527642671829993
0.00029503170300179487
0.0002947870742427617
0.00029454254045198244
0.00029429810164022875
0.00029405375781827467
0.00029380950899688513
0.00029356535518682415
0.00029332129639884983
0.0002930773326437158
0.000

6.656382352919753e-05
6.644480908498575e-05
6.632589967079969e-05
6.620709529188033e-05
6.608839595346214e-05
6.596980166077494e-05
6.585131241904558e-05
6.573292823349411e-05
6.561464910933837e-05
6.549647505178863e-05
6.537840606605377e-05
6.526044215733534e-05
6.514258333083162e-05
6.502482959173653e-05
6.490718094523775e-05
6.47896373965205e-05
6.467219895076402e-05
6.45548656131424e-05
6.443763738882669e-05
6.432051428298228e-05
6.420349630077018e-05
6.408658344734652e-05
6.396977572786305e-05
6.385307314746692e-05
6.373647571130012e-05
6.361998342450134e-05
6.350359629220304e-05
6.338731431953415e-05
6.327113751161841e-05
6.315506587357496e-05
6.303909941051884e-05
6.292323812755966e-05
6.280748202980294e-05
6.269183112234904e-05
6.257628541029424e-05
6.246084489873075e-05
6.234550959274424e-05
6.223027949741763e-05
6.211515461782821e-05
6.200013495904882e-05
6.188522052614828e-05
6.177041132418916e-05
6.165570735823181e-05
6.154110863332954e-05
6.142661515453186e-05
6.1312226926

In [7]:
len(train_loader) * 15

16035

In [4]:
df = pd.read_csv(TRAIN_FOLDS)
dataset = CassavaDataset(df=df,
                         data_root=TRAIN_IMAGES_DIR,
                         transforms=get_valid_transforms())
dataloader = DataLoader(dataset,
                        batch_size=3,
                        drop_last=False,
                        num_workers=0,
                        shuffle=False)
device = get_device(n=0)
net = get_net(name=NET, pretrained=False)
net.load_state_dict(torch.load("./generated/weights\SEResNeXt50_32x4d_BH/SEResNeXt50_32x4d_BH_fold_2_11.bin"))
net = net.to(device)

Device:                      GPU


In [6]:
preds = np.empty((0, 5), dtype=np.float64)
for images, labels in tqdm(dataloader):
    images, labels = images.to(device), labels.to(device)
    predictions = net(images).detach().cpu().numpy()
    preds = np.concatenate([preds, predictions], axis=0)
    
print(preds.shape)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=7133.0), HTML(value='')))




KeyboardInterrupt: 

In [9]:
preds = np.empty((0, 5), dtype=np.float64)