In [14]:
import torch
import torchvision
from torchvision import transforms

In [25]:
aug = torchvision.transforms.Compose([
             transforms.Resize(128),
             transforms.RandomCrop(110),
             transforms.RandomHorizontalFlip(),
             transforms.RandAugment(3, 15),
             transforms.ToTensor(),
             transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
ds = torchvision.datasets.ImageFolder('data', transform=aug)

In [35]:
train = torch.utils.data.DataLoader(ds, batch_size=128, shuffle=True, num_workers=4,
           pin_memory=True, drop_last=False)

In [36]:
crit = torch.nn.BCELoss()

In [37]:
model = torchvision.models.mobilenet_v3_small(pretrained=False)

In [38]:
model.classifier[3] = torch.nn.Linear(1024, 1)
model = model.cuda()

In [39]:
opt = torch.optim.AdamW(model.parameters())

In [40]:
from tqdm.notebook import tqdm

In [41]:
act = torch.nn.Sigmoid()

In [44]:
for e in range(100):
    l_acc = 0
    for b in tqdm(train):
        x, y = b
        x = x.cuda()
        pred = act(model(x)).squeeze(1)
        loss = crit(pred, y.float().cuda())
        loss.backward()
        
        print(((torch.round(pred) == torch.round(y.float().cuda())).sum()/len(x)).item()*100, "% accuracy")
        
        opt.step()
        opt.zero_grad()
        l_acc += loss.item()
    print(l_acc)
    

  0%|          | 0/1 [00:00<?, ?it/s]

81.73912763595581 % accuracy
0.3966490924358368


  0%|          | 0/1 [00:00<?, ?it/s]

79.99999523162842 % accuracy
0.43641480803489685


  0%|          | 0/1 [00:00<?, ?it/s]

80.86956143379211 % accuracy
0.41116729378700256


  0%|          | 0/1 [00:00<?, ?it/s]

79.99999523162842 % accuracy
0.46512630581855774


  0%|          | 0/1 [00:00<?, ?it/s]

81.73912763595581 % accuracy
0.4297642409801483


  0%|          | 0/1 [00:00<?, ?it/s]

84.34782028198242 % accuracy
0.4369170367717743


  0%|          | 0/1 [00:00<?, ?it/s]

84.34782028198242 % accuracy
0.3995804190635681


  0%|          | 0/1 [00:00<?, ?it/s]

80.86956143379211 % accuracy
0.4242219030857086


  0%|          | 0/1 [00:00<?, ?it/s]

78.2608687877655 % accuracy
0.45270317792892456


  0%|          | 0/1 [00:00<?, ?it/s]

83.4782600402832 % accuracy
0.3999115228652954


  0%|          | 0/1 [00:00<?, ?it/s]

81.73912763595581 % accuracy
0.4595164358615875


  0%|          | 0/1 [00:00<?, ?it/s]

80.86956143379211 % accuracy
0.43642234802246094


  0%|          | 0/1 [00:00<?, ?it/s]

81.73912763595581 % accuracy
0.38438719511032104


  0%|          | 0/1 [00:00<?, ?it/s]

85.21738648414612 % accuracy
0.39802271127700806


  0%|          | 0/1 [00:00<?, ?it/s]

80.86956143379211 % accuracy
0.48014020919799805


  0%|          | 0/1 [00:00<?, ?it/s]

82.6086938381195 % accuracy
0.4108193814754486


  0%|          | 0/1 [00:00<?, ?it/s]

84.34782028198242 % accuracy
0.41847702860832214


  0%|          | 0/1 [00:00<?, ?it/s]

79.13042902946472 % accuracy
0.45874860882759094


  0%|          | 0/1 [00:00<?, ?it/s]

84.34782028198242 % accuracy
0.3853466808795929


  0%|          | 0/1 [00:00<?, ?it/s]

83.4782600402832 % accuracy
0.4165520966053009


  0%|          | 0/1 [00:00<?, ?it/s]

81.73912763595581 % accuracy
0.3733541667461395


  0%|          | 0/1 [00:00<?, ?it/s]

79.99999523162842 % accuracy
0.4231729507446289


  0%|          | 0/1 [00:00<?, ?it/s]

85.21738648414612 % accuracy
0.3914753794670105


  0%|          | 0/1 [00:00<?, ?it/s]

85.21738648414612 % accuracy
0.3639938533306122


  0%|          | 0/1 [00:00<?, ?it/s]

82.6086938381195 % accuracy
0.4088582992553711


  0%|          | 0/1 [00:00<?, ?it/s]

83.4782600402832 % accuracy
0.38625553250312805


  0%|          | 0/1 [00:00<?, ?it/s]

83.4782600402832 % accuracy
0.3835814893245697


  0%|          | 0/1 [00:00<?, ?it/s]

85.21738648414612 % accuracy
0.35077792406082153


  0%|          | 0/1 [00:00<?, ?it/s]

79.99999523162842 % accuracy
0.4780641496181488


  0%|          | 0/1 [00:00<?, ?it/s]

80.86956143379211 % accuracy
0.45914074778556824


  0%|          | 0/1 [00:00<?, ?it/s]

83.4782600402832 % accuracy
0.4230870008468628


  0%|          | 0/1 [00:00<?, ?it/s]

77.3913025856018 % accuracy
0.4852885901927948


  0%|          | 0/1 [00:00<?, ?it/s]

81.73912763595581 % accuracy
0.4552576243877411


  0%|          | 0/1 [00:00<?, ?it/s]

83.4782600402832 % accuracy
0.3978196978569031


  0%|          | 0/1 [00:00<?, ?it/s]

82.6086938381195 % accuracy
0.3998379111289978


  0%|          | 0/1 [00:00<?, ?it/s]

81.73912763595581 % accuracy
0.4334625005722046


  0%|          | 0/1 [00:00<?, ?it/s]

79.99999523162842 % accuracy
0.4586438834667206


  0%|          | 0/1 [00:00<?, ?it/s]

82.6086938381195 % accuracy
0.3964727222919464


  0%|          | 0/1 [00:00<?, ?it/s]

83.4782600402832 % accuracy
0.4065505862236023


  0%|          | 0/1 [00:00<?, ?it/s]

84.34782028198242 % accuracy
0.3822098970413208


  0%|          | 0/1 [00:00<?, ?it/s]

81.73912763595581 % accuracy
0.4392859935760498


  0%|          | 0/1 [00:00<?, ?it/s]

81.73912763595581 % accuracy
0.4289354979991913


  0%|          | 0/1 [00:00<?, ?it/s]

79.13042902946472 % accuracy
0.4840013086795807


  0%|          | 0/1 [00:00<?, ?it/s]

84.34782028198242 % accuracy
0.35259348154067993


  0%|          | 0/1 [00:00<?, ?it/s]

86.95651888847351 % accuracy
0.3126876950263977


  0%|          | 0/1 [00:00<?, ?it/s]

82.6086938381195 % accuracy
0.44403839111328125


  0%|          | 0/1 [00:00<?, ?it/s]

80.86956143379211 % accuracy
0.42986056208610535


  0%|          | 0/1 [00:00<?, ?it/s]

79.13042902946472 % accuracy
0.4590206444263458


  0%|          | 0/1 [00:00<?, ?it/s]

83.4782600402832 % accuracy
0.36862891912460327


  0%|          | 0/1 [00:00<?, ?it/s]

74.78260397911072 % accuracy
0.5032674074172974


  0%|          | 0/1 [00:00<?, ?it/s]

84.34782028198242 % accuracy
0.3737611472606659


  0%|          | 0/1 [00:00<?, ?it/s]

80.86956143379211 % accuracy
0.42600560188293457


  0%|          | 0/1 [00:00<?, ?it/s]

85.21738648414612 % accuracy
0.339153915643692


  0%|          | 0/1 [00:00<?, ?it/s]

81.73912763595581 % accuracy
0.38897180557250977


  0%|          | 0/1 [00:00<?, ?it/s]

83.4782600402832 % accuracy
0.36869004368782043


  0%|          | 0/1 [00:00<?, ?it/s]

86.08695268630981 % accuracy
0.3756314516067505


  0%|          | 0/1 [00:00<?, ?it/s]

84.34782028198242 % accuracy
0.37206530570983887


  0%|          | 0/1 [00:00<?, ?it/s]

84.34782028198242 % accuracy
0.3512461185455322


  0%|          | 0/1 [00:00<?, ?it/s]

82.6086938381195 % accuracy
0.36214718222618103


  0%|          | 0/1 [00:00<?, ?it/s]

85.21738648414612 % accuracy
0.37030842900276184


  0%|          | 0/1 [00:00<?, ?it/s]

82.6086938381195 % accuracy
0.36940786242485046


  0%|          | 0/1 [00:00<?, ?it/s]

79.13042902946472 % accuracy
0.4728609621524811


  0%|          | 0/1 [00:00<?, ?it/s]

79.13042902946472 % accuracy
0.49540749192237854


  0%|          | 0/1 [00:00<?, ?it/s]

81.73912763595581 % accuracy
0.4025600850582123


  0%|          | 0/1 [00:00<?, ?it/s]

85.21738648414612 % accuracy
0.3551033139228821


  0%|          | 0/1 [00:00<?, ?it/s]

76.52173638343811 % accuracy
0.49492213129997253


  0%|          | 0/1 [00:00<?, ?it/s]

80.86956143379211 % accuracy
0.4867911636829376


  0%|          | 0/1 [00:00<?, ?it/s]

82.6086938381195 % accuracy
0.40124908089637756


  0%|          | 0/1 [00:00<?, ?it/s]

82.6086938381195 % accuracy
0.3926010727882385


  0%|          | 0/1 [00:00<?, ?it/s]

82.6086938381195 % accuracy
0.3810902237892151


  0%|          | 0/1 [00:00<?, ?it/s]

78.2608687877655 % accuracy
0.46311262249946594


  0%|          | 0/1 [00:00<?, ?it/s]

83.4782600402832 % accuracy
0.36647453904151917


  0%|          | 0/1 [00:00<?, ?it/s]

79.99999523162842 % accuracy
0.42574307322502136


  0%|          | 0/1 [00:00<?, ?it/s]

84.34782028198242 % accuracy
0.3899058401584625


  0%|          | 0/1 [00:00<?, ?it/s]

83.4782600402832 % accuracy
0.3947485685348511


  0%|          | 0/1 [00:00<?, ?it/s]

86.08695268630981 % accuracy
0.33066555857658386


  0%|          | 0/1 [00:00<?, ?it/s]

87.8260850906372 % accuracy
0.3208453357219696


  0%|          | 0/1 [00:00<?, ?it/s]

85.21738648414612 % accuracy
0.3969605565071106


  0%|          | 0/1 [00:00<?, ?it/s]

83.4782600402832 % accuracy
0.345598429441452


  0%|          | 0/1 [00:00<?, ?it/s]

80.86956143379211 % accuracy
0.3528376519680023


  0%|          | 0/1 [00:00<?, ?it/s]

86.08695268630981 % accuracy
0.3726082742214203


  0%|          | 0/1 [00:00<?, ?it/s]

81.73912763595581 % accuracy
0.3979179859161377


  0%|          | 0/1 [00:00<?, ?it/s]

87.8260850906372 % accuracy
0.3012440502643585


  0%|          | 0/1 [00:00<?, ?it/s]

84.34782028198242 % accuracy
0.4023791253566742


  0%|          | 0/1 [00:00<?, ?it/s]

86.08695268630981 % accuracy
0.33645376563072205


  0%|          | 0/1 [00:00<?, ?it/s]

82.6086938381195 % accuracy
0.36518463492393494


  0%|          | 0/1 [00:00<?, ?it/s]

86.08695268630981 % accuracy
0.38208019733428955


  0%|          | 0/1 [00:00<?, ?it/s]

86.95651888847351 % accuracy
0.37664636969566345


  0%|          | 0/1 [00:00<?, ?it/s]

81.73912763595581 % accuracy
0.3784125745296478


  0%|          | 0/1 [00:00<?, ?it/s]

82.6086938381195 % accuracy
0.33944615721702576


  0%|          | 0/1 [00:00<?, ?it/s]

84.34782028198242 % accuracy
0.336292028427124


  0%|          | 0/1 [00:00<?, ?it/s]

86.08695268630981 % accuracy
0.3406885266304016


  0%|          | 0/1 [00:00<?, ?it/s]

83.4782600402832 % accuracy
0.35200628638267517


  0%|          | 0/1 [00:00<?, ?it/s]

85.21738648414612 % accuracy
0.35827532410621643


  0%|          | 0/1 [00:00<?, ?it/s]

85.21738648414612 % accuracy
0.34998664259910583


  0%|          | 0/1 [00:00<?, ?it/s]

86.95651888847351 % accuracy
0.313495934009552


  0%|          | 0/1 [00:00<?, ?it/s]

86.95651888847351 % accuracy
0.31226852536201477


  0%|          | 0/1 [00:00<?, ?it/s]

79.99999523162842 % accuracy
0.37682387232780457


  0%|          | 0/1 [00:00<?, ?it/s]

84.34782028198242 % accuracy
0.3771178722381592


  0%|          | 0/1 [00:00<?, ?it/s]

88.69564533233643 % accuracy
0.349774032831192
