In [1]:
import torch
import torch.nn as nn
import torchvision

import numpy as np

from BReGNeXt import BReGNeXt as Classif
import cv2

import os
from itertools import groupby

cuda = torch.device('cuda:0')

In [2]:
classif = Classif().to(cuda)
# tmp = torch.randn(10, 3, 224, 224, device=cuda)
# classif(tmp).shape

In [3]:
classif.load_state_dict(torch.load('graphs/bregxnet0'))

<All keys matched successfully>

In [2]:
# cnn = torchvision.models.mobilenet_v2(pretrained=True)
cnn = torchvision.models.resnet18(pretrained=True)
conv = nn.Sequential(*list(cnn.children())[:-1])
fc = nn.Sequential(
    nn.Flatten(),
    nn.Linear(512, 8),
    nn.Softmax(dim=-1)
)

classif = nn.Sequential(
    conv,
    fc
)

In [3]:
from itertools import groupby 
import cv2

class PackLoader():
    def __init__(self, psize, path):
        
        self.path = path
        files = [path + '/images/'+ imfile for imfile in os.listdir(path + '/images')]
        lfiles = [path + '/annotations/' + file[:-4] + '_exp.npy' for file in os.listdir(path + '/images')]
            
        self.files = np.array(files)
        self.labels = np.array([int(np.load(labfile).item()) for labfile in lfiles])
        self.idx = np.random.permutation(len(files))
        self.psize = psize

        return

    def pack(self, i):
        idx = self.idx[i * self.psize:(i + 1) * self.psize]
    
        files = self.files[idx]
        labels = self.labels[idx]
        pack = []
        for j, file in enumerate(files):
            im = cv2.imread(file)
            pack.append((im, labels[j]))
            
        return pack

In [4]:
_image_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.Resize((64, 64)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(brightness=(0, 0.05), contrast=(0.7,1.3), saturation=(0.6, 1.6), hue=0.08),
    torchvision.transforms.RandomResizedCrop((64, 64)),
    torchvision.transforms.ToTensor(),
])


def prepr(features):
    # features = features.reshape(64, 64, 3)
    features = _image_transform(features)
    features = features - torch.FloatTensor([0.5727663, 0.44812188, 0.39362228]).unsqueeze(-1).unsqueeze(-1)
    return features

In [5]:
prepr = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5])
])

In [5]:
from torch.utils.data import DataLoader
    
def gen_batch(bdata):
    inp = torch.zeros(len(bdata), 3, 64, 64)
    labels = torch.zeros(len(bdata), dtype=torch.long)
    for i, (im, label) in enumerate(bdata):
        inp[i] = prepr(im)
        labels[i] = label
    return inp.to(cuda), labels.to(cuda)

In [6]:
def focal_loss2(input_tensor, target_tensor, weight=None, gamma=2, reduction='mean'):
    log_prob = torch.nn.functional.log_softmax(input_tensor, dim=-1)
    probs = torch.exp(log_prob)
    return torch.nn.functional.nll_loss(((1 - probs) ** gamma) * log_prob,
                                        target_tensor, weight=weight, reduction=reduction)
ce_loss = nn.CrossEntropyLoss()
def cross_entropy(pred, target):
    probs = torch.nn.functional.softmax(pred, dim=-1)
    # probs = torch.exp(log_prob)
    return ce_loss(probs, target)

In [7]:
ploader = PackLoader(1000, 'data/affectnet/train_set')
ploader_val = PackLoader(200, 'data/affectnet/val_set')

In [13]:
optimizer = torch.optim.Adam(classif.parameters(), lr=1e-6 * 64, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.80)
# criterion = nn.CrossEntropyLoss()
criterion = focal_loss2
# criterion = cross_entropy
classif = classif.to(cuda)

for epoch in range(8):
    # ploader = PackLoader(200, 'data/affectnet/train_set')
    for pnum in range(len(ploader.files) // ploader.psize + 1):
        pack = ploader.pack(pnum)
        dloader = DataLoader(pack, batch_size=128, shuffle=True, collate_fn=gen_batch)
        
        for i, (x, labels) in enumerate(dloader):
            pred = classif(x)

            optimizer.zero_grad()
            loss = criterion(pred, labels)
            loss.backward()
            optimizer.step()

            print('Epoch {} Pack {} Loss {}'.format(epoch, 
                                                    pnum, 
                                                    loss.item()))
    scheduler.step()

Epoch 0 Pack 0 Loss 0.7300662994384766
Epoch 0 Pack 0 Loss 0.716070294380188
Epoch 0 Pack 0 Loss 0.8075171113014221
Epoch 0 Pack 0 Loss 0.822019100189209
Epoch 0 Pack 0 Loss 0.715714156627655
Epoch 0 Pack 0 Loss 0.6217144131660461
Epoch 0 Pack 0 Loss 0.7570478320121765
Epoch 0 Pack 0 Loss 0.6613480448722839
Epoch 0 Pack 1 Loss 0.813779890537262
Epoch 0 Pack 1 Loss 0.7181180119514465
Epoch 0 Pack 1 Loss 0.7561149597167969
Epoch 0 Pack 1 Loss 0.8151175379753113
Epoch 0 Pack 1 Loss 0.6664374470710754
Epoch 0 Pack 1 Loss 0.6587638854980469
Epoch 0 Pack 1 Loss 0.7016550302505493
Epoch 0 Pack 1 Loss 0.6025456190109253
Epoch 0 Pack 2 Loss 0.6725597381591797
Epoch 0 Pack 2 Loss 0.6080736517906189
Epoch 0 Pack 2 Loss 0.7008368968963623
Epoch 0 Pack 2 Loss 0.6500228047370911
Epoch 0 Pack 2 Loss 0.6540015339851379
Epoch 0 Pack 2 Loss 0.6860771775245667
Epoch 0 Pack 2 Loss 0.6175105571746826
Epoch 0 Pack 2 Loss 0.6647729277610779
Epoch 0 Pack 3 Loss 0.840909481048584
Epoch 0 Pack 3 Loss 0.66181969

Epoch 0 Pack 26 Loss 0.7274332642555237
Epoch 0 Pack 26 Loss 0.769182026386261
Epoch 0 Pack 26 Loss 0.6959049701690674
Epoch 0 Pack 26 Loss 0.724321186542511
Epoch 0 Pack 26 Loss 0.7007533311843872
Epoch 0 Pack 26 Loss 0.663105845451355
Epoch 0 Pack 26 Loss 0.6501336097717285
Epoch 0 Pack 26 Loss 0.7249974608421326
Epoch 0 Pack 27 Loss 0.6494816541671753
Epoch 0 Pack 27 Loss 0.7128504514694214
Epoch 0 Pack 27 Loss 0.802293062210083
Epoch 0 Pack 27 Loss 0.6733722686767578
Epoch 0 Pack 27 Loss 0.6170796155929565
Epoch 0 Pack 27 Loss 0.7704850435256958
Epoch 0 Pack 27 Loss 0.7069751024246216
Epoch 0 Pack 27 Loss 0.8328108191490173
Epoch 0 Pack 28 Loss 0.7265833616256714
Epoch 0 Pack 28 Loss 0.7297842502593994
Epoch 0 Pack 28 Loss 0.7660538554191589
Epoch 0 Pack 28 Loss 0.8134390115737915
Epoch 0 Pack 28 Loss 0.6640607118606567
Epoch 0 Pack 28 Loss 0.7831674814224243
Epoch 0 Pack 28 Loss 0.7337072491645813
Epoch 0 Pack 28 Loss 0.7980614900588989
Epoch 0 Pack 29 Loss 0.722895622253418
Epoch

Epoch 0 Pack 51 Loss 0.6120638847351074
Epoch 0 Pack 51 Loss 0.6421970129013062
Epoch 0 Pack 52 Loss 0.771701455116272
Epoch 0 Pack 52 Loss 0.6974979043006897
Epoch 0 Pack 52 Loss 0.7940302491188049
Epoch 0 Pack 52 Loss 0.7631093263626099
Epoch 0 Pack 52 Loss 0.6446461081504822
Epoch 0 Pack 52 Loss 0.7074437737464905
Epoch 0 Pack 52 Loss 0.6193161606788635
Epoch 0 Pack 52 Loss 0.6389595866203308
Epoch 0 Pack 53 Loss 0.604981005191803
Epoch 0 Pack 53 Loss 0.6207723617553711
Epoch 0 Pack 53 Loss 0.7497614026069641
Epoch 0 Pack 53 Loss 0.7094082236289978
Epoch 0 Pack 53 Loss 0.7131283283233643
Epoch 0 Pack 53 Loss 0.6509706974029541
Epoch 0 Pack 53 Loss 0.9406605362892151
Epoch 0 Pack 53 Loss 0.7005202174186707
Epoch 0 Pack 54 Loss 0.7248311638832092
Epoch 0 Pack 54 Loss 0.8024229407310486
Epoch 0 Pack 54 Loss 0.8125392198562622
Epoch 0 Pack 54 Loss 0.7393578290939331
Epoch 0 Pack 54 Loss 0.7457587122917175
Epoch 0 Pack 54 Loss 0.8777675628662109
Epoch 0 Pack 54 Loss 0.5792227387428284
Ep

Epoch 0 Pack 77 Loss 0.6295745968818665
Epoch 0 Pack 77 Loss 0.6229466795921326
Epoch 0 Pack 77 Loss 0.609272837638855
Epoch 0 Pack 77 Loss 0.7645947933197021
Epoch 0 Pack 78 Loss 0.6211947798728943
Epoch 0 Pack 78 Loss 0.6983494758605957
Epoch 0 Pack 78 Loss 0.7947655320167542
Epoch 0 Pack 78 Loss 0.8012849688529968
Epoch 0 Pack 78 Loss 0.7016460299491882
Epoch 0 Pack 78 Loss 0.7467050552368164
Epoch 0 Pack 78 Loss 0.8090361952781677
Epoch 0 Pack 78 Loss 0.6106287240982056
Epoch 0 Pack 79 Loss 0.8227038979530334
Epoch 0 Pack 79 Loss 0.6705199480056763
Epoch 0 Pack 79 Loss 0.7165796756744385
Epoch 0 Pack 79 Loss 0.6402063369750977
Epoch 0 Pack 79 Loss 0.6131652593612671
Epoch 0 Pack 79 Loss 0.8200572729110718
Epoch 0 Pack 79 Loss 0.6252464652061462
Epoch 0 Pack 79 Loss 0.8314741849899292
Epoch 0 Pack 80 Loss 0.7803547382354736
Epoch 0 Pack 80 Loss 0.7397570013999939
Epoch 0 Pack 80 Loss 0.670979380607605
Epoch 0 Pack 80 Loss 0.8028177618980408
Epoch 0 Pack 80 Loss 0.7501792311668396
Ep

Epoch 0 Pack 103 Loss 0.5643416047096252
Epoch 0 Pack 103 Loss 0.8290390968322754
Epoch 0 Pack 103 Loss 0.6301078796386719
Epoch 0 Pack 103 Loss 0.7329001426696777
Epoch 0 Pack 103 Loss 0.6010778546333313
Epoch 0 Pack 103 Loss 0.7987747192382812
Epoch 0 Pack 103 Loss 0.6771787405014038
Epoch 0 Pack 104 Loss 0.7016035914421082
Epoch 0 Pack 104 Loss 0.73887699842453
Epoch 0 Pack 104 Loss 0.795786440372467
Epoch 0 Pack 104 Loss 0.6752070784568787
Epoch 0 Pack 104 Loss 0.7728294730186462
Epoch 0 Pack 104 Loss 0.687274158000946
Epoch 0 Pack 104 Loss 0.6986631751060486
Epoch 0 Pack 104 Loss 0.7739216685295105
Epoch 0 Pack 105 Loss 0.6679486632347107
Epoch 0 Pack 105 Loss 0.7491772770881653
Epoch 0 Pack 105 Loss 0.6313050985336304
Epoch 0 Pack 105 Loss 0.6924847364425659
Epoch 0 Pack 105 Loss 0.7706329822540283
Epoch 0 Pack 105 Loss 0.5588668584823608
Epoch 0 Pack 105 Loss 0.7673169374465942
Epoch 0 Pack 105 Loss 0.5266313552856445
Epoch 0 Pack 106 Loss 0.916519045829773
Epoch 0 Pack 106 Loss

Epoch 0 Pack 128 Loss 0.6013200879096985
Epoch 0 Pack 128 Loss 0.6944103240966797
Epoch 0 Pack 128 Loss 0.6268292665481567
Epoch 0 Pack 128 Loss 0.7841353416442871
Epoch 0 Pack 128 Loss 0.6767542362213135
Epoch 0 Pack 128 Loss 0.7281631231307983
Epoch 0 Pack 129 Loss 0.6559398174285889
Epoch 0 Pack 129 Loss 0.5849480628967285
Epoch 0 Pack 129 Loss 0.7240815162658691
Epoch 0 Pack 129 Loss 0.6801363229751587
Epoch 0 Pack 129 Loss 0.8210688233375549
Epoch 0 Pack 129 Loss 0.807059109210968
Epoch 0 Pack 129 Loss 0.5858482718467712
Epoch 0 Pack 129 Loss 0.5373889803886414
Epoch 0 Pack 130 Loss 0.7988433241844177
Epoch 0 Pack 130 Loss 0.653607189655304
Epoch 0 Pack 130 Loss 0.8041138648986816
Epoch 0 Pack 130 Loss 0.6837820410728455
Epoch 0 Pack 130 Loss 0.8111532926559448
Epoch 0 Pack 130 Loss 0.6041640639305115
Epoch 0 Pack 130 Loss 0.6552101969718933
Epoch 0 Pack 130 Loss 0.6931160688400269
Epoch 0 Pack 131 Loss 0.8143197298049927
Epoch 0 Pack 131 Loss 0.7554063200950623
Epoch 0 Pack 131 L

Epoch 0 Pack 153 Loss 0.6312693953514099
Epoch 0 Pack 153 Loss 0.7079726457595825
Epoch 0 Pack 153 Loss 0.6654346585273743
Epoch 0 Pack 153 Loss 0.7695810198783875
Epoch 0 Pack 153 Loss 0.6649714708328247
Epoch 0 Pack 154 Loss 0.6088693141937256
Epoch 0 Pack 154 Loss 0.7266298532485962
Epoch 0 Pack 154 Loss 0.6688873767852783
Epoch 0 Pack 154 Loss 0.7875837683677673
Epoch 0 Pack 154 Loss 0.8365926742553711
Epoch 0 Pack 154 Loss 0.864716649055481
Epoch 0 Pack 154 Loss 0.7706628441810608
Epoch 0 Pack 154 Loss 0.622466504573822
Epoch 0 Pack 155 Loss 0.6579269170761108
Epoch 0 Pack 155 Loss 0.6373155117034912
Epoch 0 Pack 155 Loss 0.789879322052002
Epoch 0 Pack 155 Loss 0.530773401260376
Epoch 0 Pack 155 Loss 0.600293755531311
Epoch 0 Pack 155 Loss 0.7335379123687744
Epoch 0 Pack 155 Loss 0.6461290717124939
Epoch 0 Pack 155 Loss 0.7662257552146912
Epoch 0 Pack 156 Loss 0.6141861081123352
Epoch 0 Pack 156 Loss 0.7294763922691345
Epoch 0 Pack 156 Loss 0.7622271776199341
Epoch 0 Pack 156 Loss

Epoch 0 Pack 178 Loss 0.7260782122612
Epoch 0 Pack 178 Loss 0.634527325630188
Epoch 0 Pack 178 Loss 0.6696555614471436
Epoch 0 Pack 178 Loss 0.6987264156341553
Epoch 0 Pack 179 Loss 0.6161640286445618
Epoch 0 Pack 179 Loss 0.5898016691207886
Epoch 0 Pack 179 Loss 0.6835291385650635
Epoch 0 Pack 179 Loss 0.667963445186615
Epoch 0 Pack 179 Loss 0.7499179840087891
Epoch 0 Pack 179 Loss 0.5965445637702942
Epoch 0 Pack 179 Loss 0.6984299421310425
Epoch 0 Pack 179 Loss 0.6487895250320435
Epoch 0 Pack 180 Loss 0.8155221939086914
Epoch 0 Pack 180 Loss 0.7454356551170349
Epoch 0 Pack 180 Loss 0.8048859238624573
Epoch 0 Pack 180 Loss 0.7127739191055298
Epoch 0 Pack 180 Loss 0.7706242799758911
Epoch 0 Pack 180 Loss 0.7457444667816162
Epoch 0 Pack 180 Loss 0.7385208010673523
Epoch 0 Pack 180 Loss 0.6578514575958252
Epoch 0 Pack 181 Loss 0.6682485938072205
Epoch 0 Pack 181 Loss 0.7223203182220459
Epoch 0 Pack 181 Loss 0.7172122597694397
Epoch 0 Pack 181 Loss 0.7983598113059998
Epoch 0 Pack 181 Loss

Epoch 0 Pack 203 Loss 0.6693066358566284
Epoch 0 Pack 203 Loss 0.8331700563430786
Epoch 0 Pack 203 Loss 0.6609516739845276
Epoch 0 Pack 204 Loss 0.678307831287384
Epoch 0 Pack 204 Loss 0.6309409737586975
Epoch 0 Pack 204 Loss 0.6879286766052246
Epoch 0 Pack 204 Loss 0.5786482095718384
Epoch 0 Pack 204 Loss 0.7553805708885193
Epoch 0 Pack 204 Loss 0.6455824971199036
Epoch 0 Pack 204 Loss 0.8440299034118652
Epoch 0 Pack 204 Loss 0.6777769923210144
Epoch 0 Pack 205 Loss 0.6771286129951477
Epoch 0 Pack 205 Loss 0.8182247877120972
Epoch 0 Pack 205 Loss 0.7651206851005554
Epoch 0 Pack 205 Loss 0.6625799536705017
Epoch 0 Pack 205 Loss 0.5946934223175049
Epoch 0 Pack 205 Loss 0.7262890934944153
Epoch 0 Pack 205 Loss 0.7019707560539246
Epoch 0 Pack 205 Loss 0.8250668048858643
Epoch 0 Pack 206 Loss 0.8491796851158142
Epoch 0 Pack 206 Loss 0.7571026682853699
Epoch 0 Pack 206 Loss 0.7897329330444336
Epoch 0 Pack 206 Loss 0.6547117829322815
Epoch 0 Pack 206 Loss 0.789107084274292
Epoch 0 Pack 206 L

Epoch 0 Pack 228 Loss 0.7084945440292358
Epoch 0 Pack 228 Loss 0.6579149961471558
Epoch 0 Pack 229 Loss 0.7964068055152893
Epoch 0 Pack 229 Loss 0.8036507368087769
Epoch 0 Pack 229 Loss 0.6883438229560852
Epoch 0 Pack 229 Loss 0.6466311812400818
Epoch 0 Pack 229 Loss 0.7346715927124023
Epoch 0 Pack 229 Loss 0.6351751089096069
Epoch 0 Pack 229 Loss 0.5847271680831909
Epoch 0 Pack 229 Loss 0.7052361369132996
Epoch 0 Pack 230 Loss 0.6689901351928711
Epoch 0 Pack 230 Loss 0.6694002151489258
Epoch 0 Pack 230 Loss 0.752407968044281
Epoch 0 Pack 230 Loss 0.6925952434539795
Epoch 0 Pack 230 Loss 0.7514748573303223
Epoch 0 Pack 230 Loss 0.6172968745231628
Epoch 0 Pack 230 Loss 0.7888516187667847
Epoch 0 Pack 230 Loss 0.7746025323867798
Epoch 0 Pack 231 Loss 0.7098000645637512
Epoch 0 Pack 231 Loss 0.6948099732398987
Epoch 0 Pack 231 Loss 0.7273068428039551
Epoch 0 Pack 231 Loss 0.6613176465034485
Epoch 0 Pack 231 Loss 0.7042025327682495
Epoch 0 Pack 231 Loss 0.7963986396789551
Epoch 0 Pack 231 

Epoch 0 Pack 253 Loss 0.7320940494537354
Epoch 0 Pack 254 Loss 0.6490117907524109
Epoch 0 Pack 254 Loss 0.6682408452033997
Epoch 0 Pack 254 Loss 0.6340940594673157
Epoch 0 Pack 254 Loss 0.6973015069961548
Epoch 0 Pack 254 Loss 0.6477637887001038
Epoch 0 Pack 254 Loss 0.7571375966072083
Epoch 0 Pack 254 Loss 0.7131782174110413
Epoch 0 Pack 254 Loss 0.6584535837173462
Epoch 0 Pack 255 Loss 0.7039490938186646
Epoch 0 Pack 255 Loss 0.6621522903442383
Epoch 0 Pack 255 Loss 0.7643954753875732
Epoch 0 Pack 255 Loss 0.5734276175498962
Epoch 0 Pack 255 Loss 0.8538785576820374
Epoch 0 Pack 255 Loss 0.6378618478775024
Epoch 0 Pack 255 Loss 0.7130578756332397
Epoch 0 Pack 255 Loss 0.7112792134284973
Epoch 0 Pack 256 Loss 0.7258482575416565
Epoch 0 Pack 256 Loss 0.8129177093505859
Epoch 0 Pack 256 Loss 0.7621042132377625
Epoch 0 Pack 256 Loss 0.6875084638595581
Epoch 0 Pack 256 Loss 0.6761965155601501
Epoch 0 Pack 256 Loss 0.6328757405281067
Epoch 0 Pack 256 Loss 0.6079754829406738
Epoch 0 Pack 256

Epoch 0 Pack 279 Loss 0.536717414855957
Epoch 0 Pack 279 Loss 0.6407250761985779
Epoch 0 Pack 279 Loss 0.7716849446296692
Epoch 0 Pack 279 Loss 0.7322776317596436
Epoch 0 Pack 279 Loss 0.8430042862892151
Epoch 0 Pack 279 Loss 0.6621674299240112
Epoch 0 Pack 279 Loss 0.7119523286819458
Epoch 0 Pack 279 Loss 0.7123739719390869
Epoch 0 Pack 280 Loss 0.7817721366882324
Epoch 0 Pack 280 Loss 0.5373813509941101
Epoch 0 Pack 280 Loss 0.7619720697402954
Epoch 0 Pack 280 Loss 0.7734092473983765
Epoch 0 Pack 280 Loss 0.7806218266487122
Epoch 0 Pack 280 Loss 0.8128160238265991
Epoch 0 Pack 280 Loss 0.5577837228775024
Epoch 0 Pack 280 Loss 0.5638477206230164
Epoch 0 Pack 281 Loss 0.7371842861175537
Epoch 0 Pack 281 Loss 0.7486756443977356
Epoch 0 Pack 281 Loss 0.6104936599731445
Epoch 0 Pack 281 Loss 0.6116182208061218
Epoch 0 Pack 281 Loss 0.7360755801200867
Epoch 0 Pack 281 Loss 0.7360796928405762
Epoch 0 Pack 281 Loss 0.6023656129837036
Epoch 0 Pack 281 Loss 0.8642150163650513
Epoch 0 Pack 282 

KeyboardInterrupt: 

In [9]:
torch.save(classif.state_dict(), 'graphs/bregxnet2')

In [17]:
hits = 0
total = 0

# ploader_val = PackLoader(200, 'data/affectnet/val_set')
for pnum in range(len(ploader_val.files) // ploader.psize + 1):
    pack = ploader_val.pack(pnum)
    dloader = DataLoader(pack, batch_size=8, collate_fn=gen_batch)
    
    for i, (x, labels) in enumerate(dloader):
        pred = classif(x)
        
        pr_labels = torch.argmax(pred.detach(), dim=-1)
        hits += (labels == torch.argmax(pred, dim=-1)).sum()
        total += labels.shape[0]
        
print('Accuracy = {}'.format(hits / total))

Accuracy = 0.53207365179061819
